use std::rc::Rc;
use indexmap::IndexSet;
use super::super::types::*;
use super::RescaleMessage;
use super::{collect::CollectRouter, MessageRouter};
use crate::types::{Key, WorkerId};
/// Router state for the interrogation phase of a rescale.
///
/// Messages are routed by the old worker set (see `route_message` in the impl
/// below). Keys that this worker owns under the old set but not under the new
/// set are recorded on `interrogate_msg`. `lifecycle` moves to the collect
/// phase once that message can be unwrapped.
#[derive(Debug)]
pub(crate) struct InterrogateRouter<K> {
/// Version supplied to `new`; not modified by the impl in this file.
pub(super) version: Version,
/// Worker set in effect before the rescale.
pub(super) old_worker_set: IndexSet<WorkerId>,
/// Worker set after the rescale, cloned from `trigger.get_new_workers()` in `new`.
pub(super) new_worker_set: IndexSet<WorkerId>,
/// Clone of the `Interrogate` message that `new` also returns to the caller.
/// `route_message` appends keys to it; `lifecycle` attempts `try_unwrap` on it.
interrogate_msg: Interrogate<K>,
/// The rescale message that started this phase; handed to `CollectRouter::new`.
trigger: RescaleMessage,
}
impl<K> InterrogateRouter<K>
where
    K: Key,
{
    /// Enters the interrogation phase of a rescale.
    ///
    /// The post-rescale worker set is read from `trigger`. The returned
    /// [`Interrogate`] message is built from a predicate. The predicate reports
    /// whether `partitioner` sends a key to a different worker under the new
    /// set than under `old_worker_set`. The router keeps its own clone of that
    /// message; the other copy is returned to the caller.
    pub(super) fn new(
        version: Version,
        old_worker_set: IndexSet<WorkerId>,
        trigger: RescaleMessage,
        partitioner: WorkerPartitioner<K>,
    ) -> (Self, Interrogate<K>)
    where
        K: Key,
    {
        let new_worker_set = trigger.get_new_workers().clone();

        // The predicate is stored inside the message, so it owns its own copies
        // of both worker sets rather than borrowing the router's.
        let (before, after) = (old_worker_set.clone(), new_worker_set.clone());
        let key_moves = Rc::new(move |key: &K| {
            partitioner(key, &before) != partitioner(key, &after)
        });

        let handed_out = Interrogate::new(key_moves);
        let router = Self {
            version,
            old_worker_set,
            new_worker_set,
            interrogate_msg: handed_out.clone(),
            trigger,
        };
        (router, handed_out)
    }

    /// Picks the destination worker for a message with `key` during interrogation.
    ///
    /// - If the old worker set assigns `key` to some other worker, the message
    ///   goes to that worker.
    /// - Otherwise this worker keeps handling it. If the new worker set assigns
    ///   `key` elsewhere, the key is also added to the interrogate message.
    pub(super) fn route_message(
        &mut self,
        key: &K,
        partitioner: WorkerPartitioner<K>,
        this_worker: WorkerId,
    ) -> WorkerId {
        let owner_before = partitioner(key, &self.old_worker_set);
        let owner_after = partitioner(key, &self.new_worker_set);

        if owner_before != this_worker {
            return owner_before;
        }
        if owner_after != this_worker {
            // Owned here now, but will migrate: remember it for the collect phase.
            self.interrogate_msg.add_keys(&[key.clone()]);
        }
        this_worker
    }

    /// Advances the rescale state machine by one step.
    ///
    /// If `try_unwrap` on the interrogate message succeeds, its whitelist seeds a
    /// [`CollectRouter`] and the router moves to the collect phase. Otherwise the
    /// router stays in the interrogate phase with the message handed back by
    /// `try_unwrap`. The tests in this file show the transition happens only
    /// after the caller drops its copy.
    pub(crate) fn lifecycle<V, T>(self) -> MessageRouter<K, V, T> {
        let Self {
            version,
            old_worker_set,
            new_worker_set,
            interrogate_msg,
            trigger,
        } = self;

        match interrogate_msg.try_unwrap() {
            Ok(whitelist) => MessageRouter::Collect(CollectRouter::new(
                whitelist,
                old_worker_set,
                new_worker_set,
                trigger,
            )),
            Err(still_shared) => MessageRouter::Interrogate(Self {
                version,
                old_worker_set,
                new_worker_set,
                interrogate_msg: still_shared,
                trigger,
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keyed::partitioners::index_select;

    // The new worker set must come from the trigger, whether scaling up or down.
    #[test]
    fn create_new_worker_set() {
        let trigger = RescaleMessage::new(IndexSet::from([0, 1, 2, 3]), 1);
        let (router, _) = InterrogateRouter::new(0, IndexSet::from([0]), trigger, index_select);
        assert_eq!(router.new_worker_set, IndexSet::from([0, 1, 2, 3]));
        let trigger = RescaleMessage::new(IndexSet::from([0, 1]), 2);
        let (router, _) =
            InterrogateRouter::new(0, IndexSet::from([0, 1, 2, 3]), trigger, index_select);
        assert_eq!(router.new_worker_set, IndexSet::from([0, 1]));
    }

    // Keys added directly to the interrogate message and keys recorded by
    // `route_message` both end up in the collect router's whitelist.
    #[test]
    fn creates_interrogate() {
        let trigger = RescaleMessage::new(IndexSet::from([0, 1]), 1);
        let (mut router, mut interrogate) =
            InterrogateRouter::new(0, IndexSet::from([0]), trigger, index_select);
        interrogate.add_keys(&[0, 1, 2, 3, 4]);
        router.route_message(&5, index_select, 0);
        drop(interrogate);
        let collect: MessageRouter<u64, i32, i32> = router.lifecycle();
        match collect {
            MessageRouter::Collect(c) => {
                assert_eq!(c.whitelist, IndexSet::from([1, 3, 5]))
            }
            _ => panic!("expected Collect router after interrogate was dropped"),
        }
    }

    // While the interrogate message is still held elsewhere, `lifecycle` must
    // keep the router in the Interrogate state; after dropping it, Collect.
    #[test]
    fn noop_if_interrogate_is_running() {
        let trigger = RescaleMessage::new(IndexSet::from([1]), 1);
        let (router, interrogate) =
            InterrogateRouter::new(0, IndexSet::from([0]), trigger, index_select);
        let router: MessageRouter<u64, i32, i32> = router.lifecycle();
        let router = match router {
            MessageRouter::Interrogate(x) => x,
            _ => panic!("expected Interrogate router while interrogate is alive"),
        };
        drop(interrogate);
        let collect: MessageRouter<u64, i32, i32> = router.lifecycle();
        assert!(matches!(collect, MessageRouter::Collect(_)));
    }

    // Owned here before the rescale, owned by worker 1 after: the key is
    // recorded in the whitelist.
    #[test]
    fn handle_data_rule_1_1() {
        let trigger = RescaleMessage::new(IndexSet::from([1]), 1);
        let (mut router, interrogate) =
            InterrogateRouter::new(0, IndexSet::from([0]), trigger, index_select);
        let target = router.route_message(&43, index_select, 0);
        assert_eq!(target, 0);
        drop(interrogate);
        let collect: MessageRouter<u64, i32, i32> = router.lifecycle();
        match collect {
            MessageRouter::Collect(c) => {
                assert!(c.whitelist.contains(&43))
            }
            _ => panic!("expected Collect router after interrogate was dropped"),
        }
    }

    // Same ownership change as rule 1.1: the message is still processed locally.
    #[test]
    fn handle_data_rule_1_2() {
        let trigger = RescaleMessage::new(IndexSet::from([1]), 1);
        let (mut router, _interrogate) =
            InterrogateRouter::new(0, IndexSet::from([0]), trigger, index_select);
        let target = router.route_message(&44, index_select, 0);
        assert_eq!(target, 0);
    }

    // Owned by this worker both before and after the rescale (single-worker
    // sets {0} -> {0}, so any partitioner must pick worker 0). The message
    // stays local and the key must NOT be added to the whitelist.
    #[test]
    fn handle_data_owner_unchanged() {
        let trigger = RescaleMessage::new(IndexSet::from([0]), 1);
        let (mut router, interrogate) =
            InterrogateRouter::new(0, IndexSet::from([0]), trigger, index_select);
        let target = router.route_message(&7, index_select, 0);
        assert_eq!(target, 0);
        drop(interrogate);
        let collect: MessageRouter<u64, i32, i32> = router.lifecycle();
        match collect {
            MessageRouter::Collect(c) => {
                assert!(!c.whitelist.contains(&7))
            }
            _ => panic!("expected Collect router after interrogate was dropped"),
        }
    }

    // Key owned by another worker under the old set: forward it there.
    #[test]
    fn handle_data_rule_2() {
        let trigger = RescaleMessage::new(IndexSet::from([1]), 1);
        let (mut router, _interrogate) =
            InterrogateRouter::new(0, IndexSet::from([0, 1]), trigger, index_select);
        let target = router.route_message(&11, index_select, 0);
        assert_eq!(target, 1)
    }
}