use crate::network::{NetworkService, RpcMessage};
use crate::raft::OxirsNodeId;
use crate::shard::ShardId;
use crate::storage::StorageBackend;
use anyhow::Result;
use oxirs_core::model::Triple;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, RwLock, Semaphore};
use tokio::time::Instant;
use tracing::{debug, error, info, warn};
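/// Returns the current wall-clock time as seconds since the UNIX epoch.
fn current_timestamp_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("SystemTime should be after UNIX_EPOCH")
        .as_secs()
}

/// Strategy used to move a shard's data between node sets.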
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MigrationStrategy {
CopyThenSwitch,
LiveMigration,
HotMigration,
}
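/// Lifecycle phases of a migration operation.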
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum MigrationPhase {
Planned,
InitialCopy,
IncrementalSync,
Cutover,
Completed,
Failed,
Cancelled,
}
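/// Tracks the state and progress of a single shard migration.
///
/// Sizes are measured in triples, `progress` is a percentage in
/// `[0.0, 100.0]`, and all timestamps are seconds since the UNIX epoch.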
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationOperation {
pub migration_id: String,
pub shard_id: ShardId,
pub source_nodes: HashSet<OxirsNodeId>,
pub target_nodes: HashSet<OxirsNodeId>,
pub strategy: MigrationStrategy,
pub phase: MigrationPhase,
pub progress: f64,
pub estimated_size: u64,
pub migrated_size: u64,
pub stats: MigrationStats,
pub created_at: u64,
pub started_at: Option<u64>,
pub completed_at: Option<u64>,
pub error_message: Option<String>,
}
impl MigrationOperation {
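    /// Creates a migration in the `Planned` phase with a freshly generated UUID.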
pub fn new(
shard_id: ShardId,
source_nodes: HashSet<OxirsNodeId>,
target_nodes: HashSet<OxirsNodeId>,
strategy: MigrationStrategy,
estimated_size: u64,
) -> Self {
Self {
migration_id: uuid::Uuid::new_v4().to_string(),
shard_id,
source_nodes,
target_nodes,
strategy,
phase: MigrationPhase::Planned,
progress: 0.0,
estimated_size,
migrated_size: 0,
stats: MigrationStats::default(),
            created_at: current_timestamp_secs(),
started_at: None,
completed_at: None,
error_message: None,
}
}
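    /// Records the migrated size and recomputes the completion percentage.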
pub fn update_progress(&mut self, migrated: u64) {
self.migrated_size = migrated;
if self.estimated_size > 0 {
self.progress = (migrated as f64 / self.estimated_size as f64 * 100.0).min(100.0);
}
}
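    /// Transitions to `new_phase`, stamping `started_at` on the first copy
    /// phase and `completed_at` on any terminal phase.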
    pub fn advance_phase(&mut self, new_phase: MigrationPhase) {
        if new_phase == MigrationPhase::InitialCopy && self.started_at.is_none() {
            self.started_at = Some(current_timestamp_secs());
        }
        if matches!(
            new_phase,
            MigrationPhase::Completed | MigrationPhase::Failed | MigrationPhase::Cancelled
        ) {
            self.completed_at = Some(current_timestamp_secs());
        }
        self.phase = new_phase;
    }
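    /// Records an error message and marks the migration as failed.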
    pub fn set_error(&mut self, error: String) {
        self.error_message = Some(error);
        self.advance_phase(MigrationPhase::Failed);
    }
}
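/// Runtime statistics for a migration.
///
/// `throughput` is measured in triples per second. `network_bandwidth`,
/// `avg_latency_ms`, and `retries` are reserved fields that this module does
/// not yet populate.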
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationStats {
pub triples_migrated: u64,
pub total_triples: u64,
pub throughput: f64,
pub network_bandwidth: f64,
pub avg_latency_ms: f64,
pub retries: u32,
pub last_updated: u64,
}
impl Default for MigrationStats {
fn default() -> Self {
Self {
triples_migrated: 0,
total_triples: 0,
throughput: 0.0,
network_bandwidth: 0.0,
avg_latency_ms: 0.0,
retries: 0,
            last_updated: current_timestamp_secs(),
}
}
}
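/// A checksummed chunk of triples transferred during a migration. The CRC32
/// checksum covers the serialized triples so receivers can verify integrity.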
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationBatch {
pub batch_id: String,
pub migration_id: String,
pub sequence: u64,
pub triples: Vec<Triple>,
pub checksum: u32,
pub created_at: u64,
}
impl MigrationBatch {
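    /// Builds a batch, checksumming the serialized triples with CRC32.
    /// A serialization failure falls back to an empty byte string, so the
    /// checksum then reflects an empty payload rather than an error.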
pub fn new(migration_id: String, sequence: u64, triples: Vec<Triple>) -> Self {
let serialized = oxicode::serde::encode_to_vec(&triples, oxicode::config::standard())
.unwrap_or_default();
let checksum = crc32fast::hash(&serialized);
Self {
batch_id: uuid::Uuid::new_v4().to_string(),
migration_id,
sequence,
triples,
checksum,
            created_at: current_timestamp_secs(),
}
}
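    /// Re-serializes the triples and compares the result against the stored checksum.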
pub fn verify_integrity(&self) -> bool {
let serialized = oxicode::serde::encode_to_vec(&self.triples, oxicode::config::standard())
.unwrap_or_default();
let computed_checksum = crc32fast::hash(&serialized);
computed_checksum == self.checksum
}
}
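/// Tunable parameters for migration execution. `batch_size` is the number of
/// triples per transfer batch, and `max_concurrent_batches` bounds in-flight
/// batch transfers via a semaphore; the remaining knobs are not yet wired
/// into the execution path.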
#[derive(Debug, Clone)]
pub struct MigrationConfig {
pub batch_size: usize,
pub max_concurrent_batches: usize,
pub max_retries: u32,
pub operation_timeout: Duration,
pub progress_interval: Duration,
pub verify_consistency: bool,
pub enable_rollback: bool,
}
impl Default for MigrationConfig {
fn default() -> Self {
Self {
batch_size: 10000,
max_concurrent_batches: 5,
max_retries: 3,
operation_timeout: Duration::from_secs(30),
progress_interval: Duration::from_secs(5),
verify_consistency: true,
enable_rollback: true,
}
}
}
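/// Coordinates shard data migrations across cluster nodes.
///
/// A minimal usage sketch (illustrative only; assumes `node_id`, `shard_id`,
/// the node sets, `storage`, and `network` are constructed elsewhere):
///
/// ```ignore
/// let manager = ShardMigrationManager::new(node_id, storage, network, MigrationConfig::default());
/// let migration_id = manager
///     .start_migration(shard_id, source_nodes, target_nodes, MigrationStrategy::LiveMigration)
///     .await?;
/// let status = manager.get_migration_status(&migration_id).await;
/// ```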
pub struct ShardMigrationManager {
node_id: OxirsNodeId,
storage: Arc<dyn StorageBackend>,
network: Arc<NetworkService>,
active_migrations: Arc<RwLock<HashMap<String, MigrationOperation>>>,
config: MigrationConfig,
migration_semaphore: Arc<Semaphore>,
#[allow(dead_code)]
shutdown_tx: Option<mpsc::Sender<()>>,
}
impl ShardMigrationManager {
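    /// Creates a manager whose batch-transfer concurrency is bounded by
    /// `config.max_concurrent_batches`.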
pub fn new(
node_id: OxirsNodeId,
storage: Arc<dyn StorageBackend>,
network: Arc<NetworkService>,
config: MigrationConfig,
) -> Self {
let max_concurrent = config.max_concurrent_batches;
Self {
node_id,
storage,
network,
active_migrations: Arc::new(RwLock::new(HashMap::new())),
config,
migration_semaphore: Arc::new(Semaphore::new(max_concurrent)),
shutdown_tx: None,
}
}
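    /// Validates and registers a migration, spawns its execution in the
    /// background, and returns the migration ID.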
pub async fn start_migration(
        &self,
shard_id: ShardId,
source_nodes: HashSet<OxirsNodeId>,
target_nodes: HashSet<OxirsNodeId>,
strategy: MigrationStrategy,
) -> Result<String> {
let estimated_size = self.estimate_migration_size(shard_id).await?;
let migration = MigrationOperation::new(
shard_id,
source_nodes,
target_nodes,
strategy,
estimated_size,
);
let migration_id = migration.migration_id.clone();
self.validate_migration(&migration).await?;
{
let mut migrations = self.active_migrations.write().await;
migrations.insert(migration_id.clone(), migration);
}
let manager = self.clone();
let migration_id_clone = migration_id.clone();
tokio::spawn(async move {
if let Err(e) = manager.execute_migration(&migration_id_clone).await {
error!("Migration {} failed: {}", migration_id_clone, e);
let mut migrations = manager.active_migrations.write().await;
if let Some(migration) = migrations.get_mut(&migration_id_clone) {
migration.set_error(e.to_string());
}
}
});
info!("Started migration {} for shard {}", migration_id, shard_id);
Ok(migration_id)
}
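    /// Dispatches execution according to the migration's strategy.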
async fn execute_migration(&self, migration_id: &str) -> Result<()> {
info!("Executing migration {}", migration_id);
let (shard_id, strategy) = {
let migrations = self.active_migrations.read().await;
let migration = migrations
.get(migration_id)
.ok_or_else(|| anyhow::anyhow!("Migration {} not found", migration_id))?;
(migration.shard_id, migration.strategy.clone())
};
match strategy {
MigrationStrategy::CopyThenSwitch => {
self.execute_copy_then_switch_migration(migration_id, shard_id)
.await
}
MigrationStrategy::LiveMigration => {
self.execute_live_migration(migration_id, shard_id).await
}
MigrationStrategy::HotMigration => {
self.execute_hot_migration(migration_id, shard_id).await
}
}
}
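    /// Copy-then-switch: copies all shard data, then performs the cutover.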
async fn execute_copy_then_switch_migration(
&self,
migration_id: &str,
shard_id: ShardId,
) -> Result<()> {
self.update_migration_phase(migration_id, MigrationPhase::InitialCopy)
.await?;
self.copy_shard_data(migration_id, shard_id).await?;
self.update_migration_phase(migration_id, MigrationPhase::Cutover)
.await?;
self.switch_shard_ownership(migration_id, shard_id).await?;
self.update_migration_phase(migration_id, MigrationPhase::Completed)
.await?;
Ok(())
}
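    /// Live migration: copies data, syncs incremental changes, then cuts over.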
async fn execute_live_migration(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
self.update_migration_phase(migration_id, MigrationPhase::InitialCopy)
.await?;
self.copy_shard_data(migration_id, shard_id).await?;
self.update_migration_phase(migration_id, MigrationPhase::IncrementalSync)
.await?;
self.sync_incremental_changes(migration_id, shard_id)
.await?;
self.update_migration_phase(migration_id, MigrationPhase::Cutover)
.await?;
self.switch_shard_ownership(migration_id, shard_id).await?;
self.update_migration_phase(migration_id, MigrationPhase::Completed)
.await?;
Ok(())
}
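    /// Hot migration: dual-writes during the copy, cuts over, then stops dual writes.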
async fn execute_hot_migration(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
self.update_migration_phase(migration_id, MigrationPhase::InitialCopy)
.await?;
self.start_dual_writes(migration_id, shard_id).await?;
self.copy_shard_data(migration_id, shard_id).await?;
self.update_migration_phase(migration_id, MigrationPhase::Cutover)
.await?;
self.switch_shard_ownership(migration_id, shard_id).await?;
self.stop_dual_writes(migration_id, shard_id).await?;
self.update_migration_phase(migration_id, MigrationPhase::Completed)
.await?;
Ok(())
}
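    /// Exports the shard and transfers it to the targets in fixed-size
    /// batches, updating progress after each batch.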
async fn copy_shard_data(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
info!(
"Copying shard {} data for migration {}",
shard_id, migration_id
);
let triples = self.storage.export_shard(shard_id).await?;
let total_triples = triples.len() as u64;
        {
            let mut migrations = self.active_migrations.write().await;
            if let Some(migration) = migrations.get_mut(migration_id) {
                // Refresh the estimate with the actual triple count so that
                // progress percentages are computed against the real total.
                migration.estimated_size = total_triples;
                migration.stats.total_triples = total_triples;
            }
        }
let batch_size = self.config.batch_size;
let mut processed = 0u64;
for (sequence, chunk) in triples.chunks(batch_size).enumerate() {
let batch =
MigrationBatch::new(migration_id.to_string(), sequence as u64, chunk.to_vec());
self.transfer_batch(migration_id, &batch).await?;
processed += chunk.len() as u64;
self.update_migration_progress(migration_id, processed)
.await?;
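            // Brief pause between batches to avoid flooding the network and targets.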
tokio::time::sleep(Duration::from_millis(10)).await;
}
info!(
"Completed copying {} triples for migration {}",
processed, migration_id
);
Ok(())
}
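    /// Sends a batch to every target node, retrying each send up to
    /// `max_retries` times with exponential backoff.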
async fn transfer_batch(&self, migration_id: &str, batch: &MigrationBatch) -> Result<()> {
let target_nodes = {
let migrations = self.active_migrations.read().await;
let migration = migrations
.get(migration_id)
.ok_or_else(|| anyhow::anyhow!("Migration {} not found", migration_id))?;
migration.target_nodes.clone()
};
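        // Hold a permit so batch transfers stay within the configured concurrency bound.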
let _permit = self.migration_semaphore.acquire().await?;
for &target_node in &target_nodes {
let message = RpcMessage::MigrationBatch {
migration_id: migration_id.to_string(),
batch: batch.clone(),
};
let mut attempts = 0;
while attempts < self.config.max_retries {
match self
.network
.send_message(target_node, message.clone())
.await
{
Ok(_) => break,
Err(e) => {
attempts += 1;
warn!(
"Failed to send batch {} to node {} (attempt {}): {}",
batch.batch_id, target_node, attempts, e
);
if attempts >= self.config.max_retries {
return Err(anyhow::anyhow!(
"Failed to send batch after {} attempts: {}",
self.config.max_retries,
e
));
}
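                        // Exponential backoff: 1s, 2s, 4s, ... between retry attempts.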
tokio::time::sleep(Duration::from_secs(1 << (attempts - 1))).await;
}
}
}
}
debug!(
"Successfully transferred batch {} for migration {}",
batch.batch_id, migration_id
);
Ok(())
}
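    /// Placeholder: polls for a fixed window instead of applying a real change feed.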
async fn sync_incremental_changes(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
info!("Syncing incremental changes for migration {}", migration_id);
let sync_duration = Duration::from_secs(5);
let start_time = Instant::now();
while start_time.elapsed() < sync_duration {
tokio::time::sleep(Duration::from_millis(100)).await;
debug!("Checking for incremental changes for shard {}", shard_id);
}
info!("Completed incremental sync for migration {}", migration_id);
Ok(())
}
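    /// Placeholder: simulates the ownership switch with a short delay.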
async fn switch_shard_ownership(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
info!(
"Switching ownership for shard {} in migration {}",
shard_id, migration_id
);
tokio::time::sleep(Duration::from_millis(500)).await;
info!("Completed ownership switch for migration {}", migration_id);
Ok(())
}
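    /// Placeholder: dual-write setup is not yet implemented.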
async fn start_dual_writes(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
info!(
"Starting dual writes for shard {} in migration {}",
shard_id, migration_id
);
Ok(())
}
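    /// Placeholder: dual-write teardown is not yet implemented.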
async fn stop_dual_writes(&self, migration_id: &str, shard_id: ShardId) -> Result<()> {
info!(
"Stopping dual writes for shard {} in migration {}",
shard_id, migration_id
);
Ok(())
}
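    /// Advances the stored migration to `phase`, if the migration still exists.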
async fn update_migration_phase(
&self,
migration_id: &str,
phase: MigrationPhase,
) -> Result<()> {
let mut migrations = self.active_migrations.write().await;
if let Some(migration) = migrations.get_mut(migration_id) {
migration.advance_phase(phase);
info!(
"Migration {} advanced to phase {:?}",
migration_id, migration.phase
);
}
Ok(())
}
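    /// Updates progress counters and recomputes throughput in triples per second.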
async fn update_migration_progress(&self, migration_id: &str, migrated: u64) -> Result<()> {
let mut migrations = self.active_migrations.write().await;
if let Some(migration) = migrations.get_mut(migration_id) {
let old_migrated = migration.stats.triples_migrated;
migration.stats.triples_migrated = migrated;
migration.update_progress(migrated);
            let elapsed = migration
                .started_at
                .map(|start| current_timestamp_secs().saturating_sub(start))
                .unwrap_or(1);
if elapsed > 0 {
migration.stats.throughput = migrated as f64 / elapsed as f64;
}
if migrated > old_migrated {
debug!(
"Migration {} progress: {:.1}% ({} / {} triples)",
migration_id, migration.progress, migrated, migration.stats.total_triples
);
}
}
Ok(())
}
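    /// Estimates the migration size in triples, matching the unit used for
    /// progress tracking.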
    async fn estimate_migration_size(&self, shard_id: ShardId) -> Result<u64> {
        // Progress is tracked in triples (see `copy_shard_data`), so the
        // estimate must be a triple count rather than a byte size.
        let triple_count = self.storage.get_shard_triple_count(shard_id).await? as u64;
        Ok(triple_count)
    }
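    /// Ensures both node sets are non-empty and disjoint.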
async fn validate_migration(&self, migration: &MigrationOperation) -> Result<()> {
if migration.source_nodes.is_empty() {
return Err(anyhow::anyhow!("No source nodes specified"));
}
if migration.target_nodes.is_empty() {
return Err(anyhow::anyhow!("No target nodes specified"));
}
if migration
.source_nodes
.intersection(&migration.target_nodes)
.count()
> 0
{
return Err(anyhow::anyhow!("Source and target nodes must be disjoint"));
}
Ok(())
}
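    /// Returns a snapshot of the migration's current state, if known.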
pub async fn get_migration_status(&self, migration_id: &str) -> Option<MigrationOperation> {
let migrations = self.active_migrations.read().await;
migrations.get(migration_id).cloned()
}
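    /// Returns snapshots of all tracked migrations.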
pub async fn list_active_migrations(&self) -> Vec<MigrationOperation> {
let migrations = self.active_migrations.read().await;
migrations.values().cloned().collect()
}
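    /// Cancels a migration unless it has already reached a terminal phase.
    /// Unknown migration IDs are ignored.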
pub async fn cancel_migration(&self, migration_id: &str) -> Result<()> {
let mut migrations = self.active_migrations.write().await;
if let Some(migration) = migrations.get_mut(migration_id) {
            if matches!(
                migration.phase,
                MigrationPhase::Completed | MigrationPhase::Failed | MigrationPhase::Cancelled
            ) {
                return Err(anyhow::anyhow!(
                    "Cannot cancel a completed, failed, or cancelled migration"
                ));
}
migration.advance_phase(MigrationPhase::Cancelled);
info!("Cancelled migration {}", migration_id);
}
Ok(())
}
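    /// Drops terminal migrations whose completion time is older than the
    /// retention window; returns the number removed.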
pub async fn cleanup_completed_migrations(&self, retention_hours: u64) -> Result<usize> {
        let cutoff_time = current_timestamp_secs().saturating_sub(retention_hours * 3600);
let mut migrations = self.active_migrations.write().await;
let initial_count = migrations.len();
migrations.retain(|_id, migration| {
!matches!(
migration.phase,
MigrationPhase::Completed | MigrationPhase::Failed | MigrationPhase::Cancelled
) || migration.completed_at.unwrap_or(u64::MAX) > cutoff_time
});
let cleaned_count = initial_count - migrations.len();
if cleaned_count > 0 {
info!("Cleaned up {} completed migrations", cleaned_count);
}
Ok(cleaned_count)
}
}
impl Clone for ShardMigrationManager {
fn clone(&self) -> Self {
Self {
node_id: self.node_id,
storage: Arc::clone(&self.storage),
network: Arc::clone(&self.network),
active_migrations: Arc::clone(&self.active_migrations),
config: self.config.clone(),
migration_semaphore: Arc::clone(&self.migration_semaphore),
shutdown_tx: None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_migration() -> MigrationOperation {
let source_nodes = [1, 2].iter().cloned().collect();
let target_nodes = [3, 4].iter().cloned().collect();
MigrationOperation::new(
1,
source_nodes,
target_nodes,
MigrationStrategy::LiveMigration,
1000,
)
}
#[tokio::test]
async fn test_migration_operation_creation() {
let migration = create_test_migration();
assert_eq!(migration.shard_id, 1);
assert_eq!(migration.phase, MigrationPhase::Planned);
assert_eq!(migration.progress, 0.0);
assert!(migration.started_at.is_none());
}
#[tokio::test]
async fn test_migration_progress_update() {
let mut migration = create_test_migration();
migration.update_progress(500);
assert_eq!(migration.migrated_size, 500);
assert_eq!(migration.progress, 50.0);
}
#[tokio::test]
async fn test_migration_phase_advancement() {
let mut migration = create_test_migration();
migration.advance_phase(MigrationPhase::InitialCopy);
assert_eq!(migration.phase, MigrationPhase::InitialCopy);
assert!(migration.started_at.is_some());
migration.advance_phase(MigrationPhase::Completed);
assert_eq!(migration.phase, MigrationPhase::Completed);
assert!(migration.completed_at.is_some());
}
#[test]
fn test_migration_batch_integrity() {
use oxirs_core::model::{NamedNode, Triple as CoreTriple};
let triples = vec![CoreTriple::new(
NamedNode::new("http://example.org/s").unwrap(),
NamedNode::new("http://example.org/p").unwrap(),
NamedNode::new("http://example.org/o").unwrap(),
)];
let batch = MigrationBatch::new("test-migration".to_string(), 1, triples);
assert!(batch.verify_integrity());
}
#[test]
fn test_migration_config_defaults() {
let config = MigrationConfig::default();
assert_eq!(config.batch_size, 10000);
assert_eq!(config.max_concurrent_batches, 5);
assert_eq!(config.max_retries, 3);
assert!(config.verify_consistency);
assert!(config.enable_rollback);
}
}