use {
super::{
KeyspaceError,
KeyspaceResult,
interval::Interval,
node::KeyspaceNode,
sharding::Shards,
},
std::{collections::HashMap, fmt, ops::Deref},
};
/// Data-migration plan derived by comparing two shard layouts.
///
/// For every node that is added to a shard's replica set, the plan records the
/// key intervals that node has to pull. Each interval is paired with the nodes of
/// that shard's previous replica set.
pub struct MigrationPlan<N: KeyspaceNode> {
    /// Caller-supplied version tag for this plan.
    version: u64,
    /// Pending intervals, grouped by the ID of the node that must pull them.
    intervals: HashMap<N::Id, Vec<Interval<N>>>,
}
/// Gives read-only access to the underlying `node ID -> intervals` map.
impl<N> Deref for MigrationPlan<N>
where
    N: KeyspaceNode,
{
    type Target = HashMap<N::Id, Vec<Interval<N>>>;

    fn deref(&self) -> &Self::Target {
        let Self { intervals, .. } = self;
        intervals
    }
}
impl<N> fmt::Debug for MigrationPlan<N>
where
    N: KeyspaceNode,
{
    /// Formats the plan with its version tag and per-node intervals.
    ///
    /// Every field is printed, so the output is closed with `finish()` rather
    /// than `finish_non_exhaustive()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MigrationPlan")
            .field("version", &self.version)
            .field("intervals", &self.intervals)
            .finish()
    }
}
impl<N: KeyspaceNode> MigrationPlan<N> {
    /// Builds the migration plan for moving from `old_shards` to `new_shards`.
    ///
    /// Shards are paired positionally. A pair is processed only if its replica
    /// sets differ. In that case, every node that is in the new replica set but
    /// not in the old one gets an [`Interval`]. The interval covers the shard's
    /// key range and is built from the old replica set's nodes. Pairs with
    /// identical replica sets contribute nothing.
    ///
    /// # Errors
    ///
    /// Returns [`KeyspaceError::ShardCountMismatch`] if the two layouts contain
    /// a different number of shards.
    ///
    /// # Panics
    ///
    /// Panics if a pair of shards with differing replica sets does not report
    /// the same key range.
    pub(crate) fn new<const RF: usize>(
        version: u64,
        old_shards: &Shards<N, RF>,
        new_shards: &Shards<N, RF>,
    ) -> KeyspaceResult<Self> {
        if old_shards.len() != new_shards.len() {
            return Err(KeyspaceError::ShardCountMismatch);
        }

        let mut intervals: HashMap<N::Id, Vec<Interval<N>>> = HashMap::new();

        for (old_shard, new_shard) in old_shards.iter().zip(new_shards.iter()) {
            let old_replica_set = old_shard.replica_set();
            let new_replica_set = new_shard.replica_set();
            if old_replica_set == new_replica_set {
                continue;
            }

            // Positional pairing is only meaningful when both layouts split the
            // keyspace identically.
            assert_eq!(
                old_shard.key_range(),
                new_shard.key_range(),
                "shards paired for migration must cover the same key range"
            );
            let key_range = old_shard.key_range();

            // Nodes already in the old replica set are skipped; only newcomers
            // receive pull intervals. Iterate by reference so a node is cloned
            // only when it actually becomes a pull target.
            for target_node in new_replica_set.iter() {
                if old_replica_set.contains(target_node) {
                    continue;
                }
                intervals
                    .entry(target_node.id().clone())
                    .or_default()
                    .push(Interval::new(key_range, old_replica_set.iter().cloned()));
            }
        }

        Ok(Self { version, intervals })
    }

    /// Returns the version tag that was passed to [`MigrationPlan::new`].
    pub fn version(&self) -> u64 {
        self.version
    }

    /// Iterates over the intervals that the node identified by `node_id` must
    /// pull.
    ///
    /// Yields nothing if the plan holds no intervals for that node.
    pub fn pull_intervals(&self, node_id: &N::Id) -> impl Iterator<Item = &Interval<N>> {
        self.intervals.get(node_id).into_iter().flatten()
    }
}