use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use osproxy_core::{Epoch, PartitionId};
use osproxy_spi::{MigrationPhase, Placement, PlacementAt};
use crate::migration::{MigrationError, PartitionState, Phase, WriteAdmission};
/// Translates the table's internal [`PartitionState`] into the
/// [`MigrationPhase`] exposed by `osproxy_spi`.
///
/// * `Active` maps to `Settled`.
/// * `Migrating` maps to `Draining` or `Cutover`, mirroring its internal
///   [`Phase`].
fn migration_phase(state: &PartitionState) -> MigrationPhase {
    match state {
        PartitionState::Active(_) => MigrationPhase::Settled,
        // Both matches are exhaustive on purpose: adding a variant to either
        // enum must fail to compile here rather than fall into a default.
        PartitionState::Migrating { phase, .. } => match phase {
            Phase::Draining => MigrationPhase::Draining,
            Phase::Cutover => MigrationPhase::Cutover,
        },
    }
}
/// One partition's row in the table: its current state together with the
/// epoch that was stamped on it when that state was written.
#[derive(Clone, Debug)]
struct Entry {
/// Settled placement (`Active`) or in-flight migration (`Migrating`).
state: PartitionState,
/// Epoch allocated when this entry was last written.
epoch: Epoch,
}
/// Table mapping each partition to its placement/migration state and the
/// epoch of its last change.
///
/// Both fields use interior mutability (`RwLock`, `AtomicU64`), which is why
/// the table's methods operate on `&self`.
#[derive(Debug)]
pub struct PlacementTable {
/// Per-partition entries, guarded by a reader-writer lock.
entries: RwLock<HashMap<PartitionId, Entry>>,
/// Table-wide counter from which epochs are allocated. It starts at 0 and
/// is bumped once for every successful write to `entries`.
generation: AtomicU64,
}
impl PlacementTable {
#[must_use]
pub fn new() -> Self {
Self {
entries: RwLock::new(HashMap::new()),
generation: AtomicU64::new(0),
}
}
pub fn set(&self, partition: PartitionId, placement: Placement) -> Epoch {
let epoch = self.next_epoch();
self.write_lock().insert(
partition,
Entry::new(PartitionState::Active(placement), epoch),
);
epoch
}
pub fn begin_migration(
&self,
partition: &PartitionId,
to: Placement,
) -> Result<Epoch, MigrationError> {
self.transition(partition, |state| match state {
PartitionState::Active(from) => Ok(PartitionState::Migrating {
from,
to,
phase: Phase::Draining,
}),
PartitionState::Migrating { .. } => Err(MigrationError::AlreadyMigrating),
})
}
pub fn enter_cutover(&self, partition: &PartitionId) -> Result<Epoch, MigrationError> {
self.transition(partition, |state| match state {
PartitionState::Migrating {
from,
to,
phase: Phase::Draining,
} => Ok(PartitionState::Migrating {
from,
to,
phase: Phase::Cutover,
}),
PartitionState::Migrating { .. } => Err(MigrationError::NotDraining),
PartitionState::Active(_) => Err(MigrationError::NotMigrating),
})
}
pub fn complete_migration(&self, partition: &PartitionId) -> Result<Epoch, MigrationError> {
self.transition(partition, |state| match state {
PartitionState::Migrating {
to,
phase: Phase::Cutover,
..
} => Ok(PartitionState::Active(to)),
PartitionState::Migrating { .. } => Err(MigrationError::NotCutover),
PartitionState::Active(_) => Err(MigrationError::NotMigrating),
})
}
pub fn abort_migration(&self, partition: &PartitionId) -> Result<Epoch, MigrationError> {
self.transition(partition, |state| match state {
PartitionState::Migrating { from, .. } => Ok(PartitionState::Active(from)),
PartitionState::Active(_) => Err(MigrationError::NotMigrating),
})
}
#[must_use]
pub fn state(&self, partition: &PartitionId) -> Option<(PartitionState, Epoch)> {
self.read_lock()
.get(partition)
.map(|e| (e.state.clone(), e.epoch))
}
#[must_use]
pub fn get(&self, partition: &PartitionId) -> Option<PlacementAt> {
self.read_lock().get(partition).map(|e| {
PlacementAt::new(e.state.read_placement().clone(), e.epoch)
.with_phase(migration_phase(&e.state))
})
}
#[must_use]
pub fn admit_write(&self, partition: &PartitionId, epoch: Epoch) -> WriteAdmission {
let admit = self
.read_lock()
.get(partition)
.is_some_and(|e| e.state.write_placement().is_some() && e.epoch == epoch);
if admit {
WriteAdmission::Admit
} else {
WriteAdmission::Reject
}
}
#[must_use]
pub fn current_epoch(&self) -> Epoch {
Epoch::new(self.generation.load(Ordering::SeqCst))
}
fn next_epoch(&self) -> Epoch {
Epoch::new(self.generation.fetch_add(1, Ordering::SeqCst) + 1)
}
fn transition(
&self,
partition: &PartitionId,
f: impl FnOnce(PartitionState) -> Result<PartitionState, MigrationError>,
) -> Result<Epoch, MigrationError> {
let mut entries = self.write_lock();
let current = entries
.get(partition)
.ok_or(MigrationError::UnknownPartition)?;
let next = f(current.state.clone())?;
let epoch = self.next_epoch();
entries.insert(partition.clone(), Entry::new(next, epoch));
Ok(epoch)
}
fn read_lock(&self) -> std::sync::RwLockReadGuard<'_, HashMap<PartitionId, Entry>> {
self.entries
.read()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
fn write_lock(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<PartitionId, Entry>> {
self.entries
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
impl Entry {
fn new(state: PartitionState, epoch: Epoch) -> Self {
Self { state, epoch }
}
}
impl Default for PlacementTable {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use osproxy_core::{ClusterId, IndexName};

    /// Builds a `Placement::SharedIndex` on the given cluster and index with
    /// an empty `inject` list.
    fn shared(cluster: &str, index: &str) -> Placement {
        let cluster = ClusterId::from(cluster);
        let index = IndexName::from(index);
        Placement::SharedIndex {
            cluster,
            index,
            inject: Vec::new(),
        }
    }

    /// A fresh table knows no partitions and has allocated no epoch.
    #[test]
    fn missing_partition_resolves_to_none() {
        let table = PlacementTable::new();
        let absent = PartitionId::from("absent");
        assert!(table.get(&absent).is_none());
        assert_eq!(table.current_epoch(), Epoch::ZERO);
    }

    /// Epochs are table-wide: writes to different partitions draw from the
    /// same increasing sequence, starting at 1.
    #[test]
    fn set_assigns_monotonic_epochs() {
        let table = PlacementTable::new();
        let first = table.set(PartitionId::from("a"), shared("c", "i"));
        let second = table.set(PartitionId::from("b"), shared("c", "i"));
        assert_eq!(first, Epoch::new(1));
        assert_eq!(second, Epoch::new(2));
        assert!(second > first);
        assert_eq!(table.current_epoch(), second);
    }

    /// Re-placing a partition swaps its placement and moves its epoch forward.
    ///
    /// Note: despite the name, this goes through plain `set`; the staged
    /// begin/cutover/complete API is not exercised here.
    #[test]
    fn migration_replaces_placement_and_advances_epoch() {
        let table = PlacementTable::new();
        let partition = PartitionId::from("t");

        table.set(partition.clone(), shared("old", "i"));
        let before = table.get(&partition).unwrap();
        assert_eq!(before.placement.cluster().as_str(), "old");

        let new_epoch = table.set(partition.clone(), shared("new", "i"));
        let after = table.get(&partition).unwrap();
        assert_eq!(after.placement.cluster().as_str(), "new");
        assert_eq!(after.epoch, new_epoch);
        assert!(after.epoch > before.epoch);
    }
}