use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::{ClusterError, MigrationCheckpointError, Result};
use nodedb_types::Hlc;
/// Identifier of a single migration; an alias for [`Uuid`].
pub type MigrationId = Uuid;
/// Data-free tag naming the phase a migration checkpoint belongs to.
///
/// There is exactly one tag per [`MigrationCheckpointPayload`] variant.
/// The derived `PartialOrd`/`Ord` compare variants by declaration order
/// (`AddLearner` lowest, `Complete` highest), so reordering the variants
/// changes comparison results.
///
/// NOTE(review): reordering may also change the encoded form, depending on
/// how `zerompk`/`serde` represent unit variants — confirm before touching.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
PartialOrd,
Ord,
Serialize,
Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
pub enum MigrationPhaseTag {
/// Tag for [`MigrationCheckpointPayload::AddLearner`].
AddLearner,
/// Tag for [`MigrationCheckpointPayload::CatchUp`].
CatchUp,
/// Tag for [`MigrationCheckpointPayload::PromoteLearner`].
PromoteLearner,
/// Tag for [`MigrationCheckpointPayload::LeadershipTransfer`].
LeadershipTransfer,
/// Tag for [`MigrationCheckpointPayload::Cutover`].
Cutover,
/// Tag for [`MigrationCheckpointPayload::Complete`].
Complete,
}
/// Phase-specific data recorded in a migration checkpoint.
///
/// There is one variant per [`MigrationPhaseTag`];
/// [`MigrationCheckpointPayload::phase_tag`] maps a payload to its tag.
/// Every variant carries the `vshard_id` it refers to.
///
/// Values of this type are encoded with `zerompk` (see `to_bytes`), so
/// variant and field order should be treated as part of the stored format.
///
/// NOTE(review): the per-field notes below are inferred from field names
/// only — confirm against the code that drives the migration.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
pub enum MigrationCheckpointPayload {
/// Checkpoint data for the `AddLearner` phase.
AddLearner {
vshard_id: u32,
source_node: u64,
target_node: u64,
source_group: u64,
// Presumably a duration in microseconds (per the `_us` suffix) — TODO confirm.
write_pause_budget_us: u64,
// Presumably the hybrid-logical-clock reading when the migration started — TODO confirm.
started_at_hlc: Hlc,
},
/// Checkpoint data for the `CatchUp` phase.
CatchUp {
vshard_id: u32,
learner_log_index_at_add: u64,
},
/// Checkpoint data for the `PromoteLearner` phase.
PromoteLearner {
vshard_id: u32,
target_node: u64,
source_group: u64,
},
/// Checkpoint data for the `LeadershipTransfer` phase.
LeadershipTransfer {
vshard_id: u32,
target_is_voter: bool,
new_leader_node_id: u64,
source_group: u64,
},
/// Checkpoint data for the `Cutover` phase.
Cutover {
vshard_id: u32,
new_leader_node_id: u64,
source_group: u64,
},
/// Checkpoint data for the `Complete` phase.
Complete {
vshard_id: u32,
// Presumably the measured pause in microseconds (per the `_us` suffix) — TODO confirm.
actual_pause_us: u64,
ghost_stub_installed: bool,
},
}
impl MigrationCheckpointPayload {
pub fn phase_tag(&self) -> MigrationPhaseTag {
match self {
Self::AddLearner { .. } => MigrationPhaseTag::AddLearner,
Self::CatchUp { .. } => MigrationPhaseTag::CatchUp,
Self::PromoteLearner { .. } => MigrationPhaseTag::PromoteLearner,
Self::LeadershipTransfer { .. } => MigrationPhaseTag::LeadershipTransfer,
Self::Cutover { .. } => MigrationPhaseTag::Cutover,
Self::Complete { .. } => MigrationPhaseTag::Complete,
}
}
pub fn to_bytes(&self) -> Result<Vec<u8>> {
zerompk::to_msgpack_vec(self).map_err(|e| {
ClusterError::MigrationCheckpoint(MigrationCheckpointError::Codec {
detail: format!("payload encode: {e}"),
})
})
}
pub fn crc32c(&self) -> Result<u32> {
let bytes = self.to_bytes()?;
Ok(crc32c::crc32c(&bytes))
}
}
/// A checkpoint row as held by [`MigrationStateTable`]: the phase payload
/// plus bookkeeping fields.
///
/// `migration_id` is the migration's id in string form; it is the key under
/// which `MigrationStateTable::upsert` stores the row, and `migration_uuid`
/// parses it back into a [`MigrationId`].
///
/// `attempt` is compared, together with the payload's phase tag, by
/// `MigrationStateTable::upsert` to decide whether an incoming row duplicates
/// the cached one.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
pub struct PersistedMigrationCheckpoint {
// `migration_id` and `attempt` are described in the struct-level docs above.
pub migration_id: String, pub attempt: u32,
/// Phase-specific data for this checkpoint.
pub payload: MigrationCheckpointPayload,
/// Checksum stored alongside the payload. The tests in this file fill it
/// with `payload.crc32c()`; nothing in this file verifies it on load.
pub crc32c: u32,
/// Timestamp — presumably milliseconds (per the `_ms` suffix); the clock
/// and epoch are not visible here, TODO confirm.
pub ts_ms: u64,
}
impl PersistedMigrationCheckpoint {
pub fn migration_uuid(&self) -> Option<MigrationId> {
self.migration_id.parse().ok()
}
}
/// Store of migration checkpoints: an in-memory map backed by a redb table.
///
/// Reads (`get`, `all_checkpoints`) are served from `mem` only; `upsert` and
/// `remove` write to `db`. `mem` starts empty and is filled from `db` by
/// `load_all`.
pub struct MigrationStateTable {
// Cached rows keyed by the migration id string (the same key used in redb).
mem: HashMap<String, PersistedMigrationCheckpoint>,
// Shared handle to the redb database holding `MigrationStateTable::TABLE`.
db: Arc<redb::Database>,
}
impl MigrationStateTable {
pub const TABLE: redb::TableDefinition<'static, &'static str, &'static [u8]> =
redb::TableDefinition::new("_cluster.migration_state");
pub fn new(db: Arc<redb::Database>) -> Self {
Self {
mem: HashMap::new(),
db,
}
}
pub fn load_all(&mut self) -> Result<()> {
let txn = self.db.begin_read().map_err(|e| ClusterError::Storage {
detail: format!("migration_state begin_read: {e}"),
})?;
let table = txn
.open_table(Self::TABLE)
.map_err(|e| ClusterError::Storage {
detail: format!("migration_state open_table: {e}"),
})?;
let range = table.range::<&str>(..).map_err(|e| ClusterError::Storage {
detail: format!("migration_state range: {e}"),
})?;
for entry in range {
let (key, value) = entry.map_err(|e| ClusterError::Storage {
detail: format!("migration_state iter: {e}"),
})?;
let key_str = key.value().to_owned();
match zerompk::from_msgpack::<PersistedMigrationCheckpoint>(value.value()) {
Ok(row) => {
self.mem.insert(key_str, row);
}
Err(e) => {
tracing::warn!(key = %key_str, error = %e, "migration_state: corrupt row skipped");
}
}
}
Ok(())
}
pub fn upsert(&mut self, row: PersistedMigrationCheckpoint) -> Result<()> {
let key = row.migration_id.clone();
if let Some(existing) = self.mem.get(&key)
&& existing.payload.phase_tag() == row.payload.phase_tag()
&& existing.attempt == row.attempt
{
return Ok(());
}
let bytes = zerompk::to_msgpack_vec(&row).map_err(|e| ClusterError::Codec {
detail: format!("migration_state encode: {e}"),
})?;
let txn = self.db.begin_write().map_err(|e| ClusterError::Storage {
detail: format!("migration_state begin_write: {e}"),
})?;
{
let mut table = txn
.open_table(Self::TABLE)
.map_err(|e| ClusterError::Storage {
detail: format!("migration_state open_table: {e}"),
})?;
table
.insert(key.as_str(), bytes.as_slice())
.map_err(|e| ClusterError::Storage {
detail: format!("migration_state insert: {e}"),
})?;
}
txn.commit().map_err(|e| ClusterError::Storage {
detail: format!("migration_state commit: {e}"),
})?;
self.mem.insert(key, row);
Ok(())
}
pub fn remove(&mut self, migration_id: &MigrationId) -> Result<()> {
let key = migration_id.hyphenated().to_string();
self.mem.remove(&key);
let txn = self.db.begin_write().map_err(|e| ClusterError::Storage {
detail: format!("migration_state begin_write: {e}"),
})?;
{
let mut table = txn
.open_table(Self::TABLE)
.map_err(|e| ClusterError::Storage {
detail: format!("migration_state open_table: {e}"),
})?;
let _ = table
.remove(key.as_str())
.map_err(|e| ClusterError::Storage {
detail: format!("migration_state remove: {e}"),
})?;
}
txn.commit().map_err(|e| ClusterError::Storage {
detail: format!("migration_state commit: {e}"),
})?;
Ok(())
}
pub fn all_checkpoints(&self) -> Vec<PersistedMigrationCheckpoint> {
self.mem.values().cloned().collect()
}
pub fn get(&self, migration_id: &MigrationId) -> Option<&PersistedMigrationCheckpoint> {
self.mem.get(&migration_id.hyphenated().to_string())
}
}
/// A [`MigrationStateTable`] behind `Arc<Mutex<_>>` so it can be shared between owners.
pub type SharedMigrationStateTable = Arc<Mutex<MigrationStateTable>>;
/// Builds an empty [`MigrationStateTable`] over `db` and wraps it in `Arc<Mutex<_>>`.
///
/// The cache starts empty; nothing is loaded from `db` here.
pub fn new_shared(db: Arc<redb::Database>) -> SharedMigrationStateTable {
    let table = MigrationStateTable::new(db);
    Arc::new(Mutex::new(table))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a redb database in a fresh temporary directory with the
    /// migration-state table already created.
    ///
    /// The `TempDir` guard is returned alongside the handle; callers keep it
    /// alive for the duration of the test so the directory is deleted when
    /// the test finishes instead of being leaked on disk.
    fn temp_db() -> (tempfile::TempDir, Arc<redb::Database>) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.redb");
        let db = redb::Database::create(&path).unwrap();
        let txn = db.begin_write().unwrap();
        {
            // Opening the table in a write transaction creates it, so that
            // `load_all` can open it from a read transaction later.
            let _ = txn.open_table(MigrationStateTable::TABLE).unwrap();
        }
        txn.commit().unwrap();
        (dir, Arc::new(db))
    }

    /// Builds a row for `id` whose `crc32c` matches `phase`.
    fn make_row(
        id: MigrationId,
        phase: MigrationCheckpointPayload,
        attempt: u32,
    ) -> PersistedMigrationCheckpoint {
        let crc = phase.crc32c().unwrap();
        PersistedMigrationCheckpoint {
            migration_id: id.hyphenated().to_string(),
            attempt,
            payload: phase,
            crc32c: crc,
            ts_ms: 0,
        }
    }

    #[test]
    fn upsert_and_load_roundtrip() {
        let (_dir, db) = temp_db();
        let mut table = MigrationStateTable::new(Arc::clone(&db));
        let id = Uuid::new_v4();
        let payload = MigrationCheckpointPayload::AddLearner {
            vshard_id: 5,
            source_node: 1,
            target_node: 2,
            source_group: 0,
            write_pause_budget_us: 500_000,
            started_at_hlc: Hlc::default(),
        };
        let row = make_row(id, payload.clone(), 0);
        table.upsert(row).unwrap();

        // A second table over the same database must see the row after loading.
        let mut table2 = MigrationStateTable::new(Arc::clone(&db));
        table2.load_all().unwrap();
        let loaded = table2.get(&id).unwrap();
        assert_eq!(loaded.payload, payload);
        assert_eq!(loaded.attempt, 0);
    }

    #[test]
    fn idempotent_upsert_same_phase_attempt() {
        let (_dir, db) = temp_db();
        let mut table = MigrationStateTable::new(Arc::clone(&db));
        let id = Uuid::new_v4();
        let payload = MigrationCheckpointPayload::CatchUp {
            vshard_id: 3,
            learner_log_index_at_add: 10,
        };
        let row = make_row(id, payload, 0);
        table.upsert(row.clone()).unwrap();
        // Same phase and attempt: the second call is a no-op.
        table.upsert(row).unwrap();
        assert_eq!(table.all_checkpoints().len(), 1);
    }

    #[test]
    fn remove_deletes_from_redb() {
        let (_dir, db) = temp_db();
        let mut table = MigrationStateTable::new(Arc::clone(&db));
        let id = Uuid::new_v4();
        let payload = MigrationCheckpointPayload::Complete {
            vshard_id: 7,
            actual_pause_us: 100,
            ghost_stub_installed: true,
        };
        table.upsert(make_row(id, payload, 1)).unwrap();
        table.remove(&id).unwrap();
        // The cache of the table that performed the removal is updated too.
        assert!(table.get(&id).is_none());

        let mut table2 = MigrationStateTable::new(Arc::clone(&db));
        table2.load_all().unwrap();
        assert!(table2.get(&id).is_none());
    }

    #[test]
    fn payload_crc32c_detects_corruption() {
        let payload = MigrationCheckpointPayload::CatchUp {
            vshard_id: 9,
            learner_log_index_at_add: 42,
        };
        let mut bytes = payload.to_bytes().unwrap();
        bytes[0] ^= 0xFF;
        let original_crc = payload.crc32c().unwrap();
        let corrupted_crc = crc32c::crc32c(&bytes);
        assert_ne!(original_crc, corrupted_crc);
    }

    #[test]
    fn zerompk_payload_roundtrip() {
        let payload = MigrationCheckpointPayload::LeadershipTransfer {
            vshard_id: 11,
            target_is_voter: true,
            new_leader_node_id: 7,
            source_group: 2,
        };
        let bytes = zerompk::to_msgpack_vec(&payload).unwrap();
        let decoded: MigrationCheckpointPayload = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(payload, decoded);
    }
}