1use std::sync::{Arc, Mutex, RwLock};
14use std::time::Duration;
15
16use tracing::{debug, info};
17
18use crate::conf_change::{ConfChange, ConfChangeType};
19use crate::error::{ClusterError, Result};
20use crate::ghost::{GhostStub, GhostTable};
21use crate::migration::{MigrationPhase, MigrationState};
22use crate::multi_raft::MultiRaft;
23use crate::routing::RoutingTable;
24use crate::topology::ClusterTopology;
25use crate::transport::NexarTransport;
26
/// Parameters describing one vShard migration between two nodes.
///
/// All fields are plain integers, so the type is cheaply copyable and
/// comparable (`Copy`, `PartialEq`, `Eq` derived).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MigrationRequest {
    /// Identifier of the vShard to move.
    pub vshard_id: u16,
    /// Node id the vShard is migrating away from.
    pub source_node: u64,
    /// Node id the vShard is migrating to.
    pub target_node: u64,
    /// Maximum tolerated write pause during cut-over, in microseconds.
    pub write_pause_budget_us: u64,
}
36
37impl Default for MigrationRequest {
38 fn default() -> Self {
39 Self {
40 vshard_id: 0,
41 source_node: 0,
42 target_node: 0,
43 write_pause_budget_us: 500_000, }
45 }
46}
47
/// Outcome summary returned by [`MigrationExecutor::execute`] on success.
#[derive(Debug)]
pub struct MigrationResult {
    // vShard that was migrated.
    pub vshard_id: u16,
    // Node the vShard moved away from.
    pub source_node: u64,
    // Node the vShard moved to.
    pub target_node: u64,
    // Final migration phase as recorded in the `MigrationState`.
    pub phase: MigrationPhase,
    // Wall-clock duration of the migration, if the state recorded one.
    pub elapsed: Option<Duration>,
}
57
/// Drives three-phase vShard migrations (base copy, WAL catch-up, cut-over)
/// against shared cluster subsystems.
pub struct MigrationExecutor {
    // Multi-group raft instance used to propose conf changes and read
    // commit/match indices.
    multi_raft: Arc<Mutex<MultiRaft>>,
    // vShard -> raft-group routing; updated atomically during cut-over.
    routing: Arc<RwLock<RoutingTable>>,
    // Cluster membership/address book, consulted for the target's address.
    topology: Arc<RwLock<ClusterTopology>>,
    // Network layer; peers are registered here and connection identity is
    // checked during catch-up.
    transport: Arc<NexarTransport>,
    // Ghost stubs registered at cut-over for transparent request forwarding;
    // owned by this executor (created in `new`).
    ghost_table: Arc<Mutex<GhostTable>>,
}
69
impl MigrationExecutor {
    /// Creates an executor over the shared cluster subsystems.
    ///
    /// A fresh, empty `GhostTable` is allocated here and owned by the
    /// executor; callers reach it through [`Self::ghost_table`].
    pub fn new(
        multi_raft: Arc<Mutex<MultiRaft>>,
        routing: Arc<RwLock<RoutingTable>>,
        topology: Arc<RwLock<ClusterTopology>>,
        transport: Arc<NexarTransport>,
    ) -> Self {
        Self {
            multi_raft,
            routing,
            topology,
            transport,
            ghost_table: Arc::new(Mutex::new(GhostTable::new())),
        }
    }

    /// Shared handle to the ghost-stub table populated during cut-over.
    pub fn ghost_table(&self) -> &Arc<Mutex<GhostTable>> {
        &self.ghost_table
    }

    /// Runs a full migration for `req.vshard_id`: phase 1 (base copy),
    /// phase 2 (WAL catch-up), phase 3 (atomic cut-over).
    ///
    /// Returns a [`MigrationResult`] summary on success. Any phase error is
    /// propagated immediately and the remaining phases are skipped.
    pub async fn execute(&self, req: MigrationRequest) -> Result<MigrationResult> {
        // Resolve which raft group currently serves the vShard. The
        // `unwrap_or_else(|p| p.into_inner())` pattern recovers the guard
        // even if a previous holder panicked (mutex/rwlock poisoning).
        let source_group = {
            let routing = self.routing.read().unwrap_or_else(|p| p.into_inner());
            routing.group_for_vshard(req.vshard_id)?
        };

        // NOTE(review): `source_group` is passed for both group arguments —
        // presumably source and target group coincide in this flow; confirm
        // against `MigrationState::new`'s parameter order.
        let mut state = MigrationState::new(
            req.vshard_id,
            source_group,
            source_group,
            req.source_node,
            req.target_node,
            req.write_pause_budget_us,
        );

        info!(
            vshard = req.vshard_id,
            source = req.source_node,
            target = req.target_node,
            group = source_group,
            "starting vShard migration"
        );

        self.phase1_base_copy(&mut state, source_group, &req)
            .await?;

        self.phase2_wal_catchup(&mut state, source_group, &req)
            .await?;

        self.phase3_cutover(&mut state, source_group, &req).await?;

        let elapsed = state.elapsed();
        let phase = state.phase().clone();

        info!(
            vshard = req.vshard_id,
            source = req.source_node,
            target = req.target_node,
            elapsed_ms = elapsed.map(|d| d.as_millis() as u64).unwrap_or(0),
            "vShard migration completed"
        );

        Ok(MigrationResult {
            vshard_id: req.vshard_id,
            source_node: req.source_node,
            target_node: req.target_node,
            phase,
            elapsed,
        })
    }

    /// Phase 1: add the target node to the vShard's raft group so raft
    /// replication performs the base data copy, and register the target's
    /// address with the transport (best-effort).
    async fn phase1_base_copy(
        &self,
        state: &mut MigrationState,
        group_id: u64,
        req: &MigrationRequest,
    ) -> Result<()> {
        // Snapshot the group's current commit index; it is recorded as the
        // base-copy progress marker (0 if the group is unknown).
        let committed = {
            let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
            let statuses = mr.group_statuses();
            statuses
                .iter()
                .find(|s| s.group_id == group_id)
                .map(|s| s.commit_index)
                .unwrap_or(0)
        };
        state.start_base_copy(committed);

        info!(
            vshard = req.vshard_id,
            group = group_id,
            target = req.target_node,
            entries = committed,
            "phase 1: adding target to raft group"
        );

        // Membership change goes through raft so it is replicated to the
        // whole group, not applied locally only.
        let change = ConfChange {
            change_type: ConfChangeType::AddNode,
            node_id: req.target_node,
        };

        {
            let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
            mr.propose_conf_change(group_id, &change)?;
        }

        // Best-effort peer registration: only if the topology knows the
        // target and its address parses as a socket address. Silently
        // skipped otherwise (raft messages would then have no route —
        // NOTE(review): confirm this is acceptable here).
        if let Some(node_info) = {
            let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
            topo.get_node(req.target_node).map(|n| n.addr.clone())
        } && let Ok(addr) = node_info.parse()
        {
            self.transport.register_peer(req.target_node, addr);
        }

        state.update_base_copy(committed);

        debug!(
            vshard = req.vshard_id,
            "phase 1 complete: target added to raft group"
        );

        Ok(())
    }

    /// Phase 2: poll until the target's raft match index catches up with
    /// the leader's commit index, guarding against peer-identity and
    /// address changes mid-migration. Fails after a 60 s deadline.
    async fn phase2_wal_catchup(
        &self,
        state: &mut MigrationState,
        group_id: u64,
        req: &MigrationRequest,
    ) -> Result<()> {
        // Initial leader commit index for this group (0 if unknown).
        let leader_commit = {
            let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
            let statuses = mr.group_statuses();
            statuses
                .iter()
                .find(|s| s.group_id == group_id)
                .map(|s| s.commit_index)
                .unwrap_or(0)
        };

        state.start_wal_catchup(leader_commit, leader_commit);

        info!(
            vshard = req.vshard_id,
            leader_commit, "phase 2: monitoring replication lag"
        );

        // Remember the connection's stable id and the target's advertised
        // address now; any change later means we may no longer be talking
        // to the same process, and the migration is aborted.
        let initial_stable_id = self.transport.peer_connection_stable_id(req.target_node);

        let initial_target_addr = {
            let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
            topo.get_node(req.target_node).map(|n| n.addr.clone())
        };

        let poll_interval = Duration::from_millis(100);
        let timeout = Duration::from_secs(60);
        let deadline = std::time::Instant::now() + timeout;

        loop {
            tokio::time::sleep(poll_interval).await;

            // Identity check: the connection's stable id must not change,
            // and the connection must not be lost. Only enforced when an
            // initial id was observed.
            if let Some(initial_id) = initial_stable_id {
                match self.transport.peer_connection_stable_id(req.target_node) {
                    Some(current_id) if current_id != initial_id => {
                        let reason = format!(
                            "peer identity changed mid-migration: connection stable_id {} -> {} for node {}",
                            initial_id, current_id, req.target_node
                        );
                        state.fail(reason.clone());
                        return Err(ClusterError::Transport { detail: reason });
                    }
                    None => {
                        let reason = format!(
                            "connection to target node {} lost during migration",
                            req.target_node
                        );
                        state.fail(reason.clone());
                        return Err(ClusterError::Transport { detail: reason });
                    }
                    _ => {}
                }
            }

            // Address check: the topology entry for the target must still
            // advertise the same address as when the phase started.
            {
                let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
                let current_addr = topo.get_node(req.target_node).map(|n| n.addr.clone());
                if current_addr != initial_target_addr {
                    let reason = format!(
                        "target node {} address changed during migration: {:?} -> {:?}",
                        req.target_node, initial_target_addr, current_addr
                    );
                    state.fail(reason.clone());
                    return Err(ClusterError::Transport { detail: reason });
                }
            }

            // Current replication lag: leader commit vs. target match index.
            let (leader_commit, target_match) = {
                let mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
                let statuses = mr.group_statuses();
                let commit = statuses
                    .iter()
                    .find(|s| s.group_id == group_id)
                    .map(|s| s.commit_index)
                    .unwrap_or(0);
                let target_match = mr.match_index_for(group_id, req.target_node).unwrap_or(0);
                (commit, target_match)
            };

            state.update_wal_catchup(leader_commit, target_match);

            // The state object decides when lag is small enough to cut over.
            if state.is_catchup_ready() {
                debug!(
                    vshard = req.vshard_id,
                    leader_commit, target_match, "phase 2 complete: target caught up"
                );
                return Ok(());
            }

            if std::time::Instant::now() >= deadline {
                let reason = format!(
                    "WAL catch-up timed out after {}s (leader={leader_commit}, target={target_match})",
                    timeout.as_secs()
                );
                state.fail(reason.clone());
                return Err(ClusterError::Transport { detail: reason });
            }
        }
    }

    /// Phase 3: atomic cut-over — propose the routing change through raft,
    /// reassign the vShard in the routing table, and register a ghost stub
    /// so stale clients are forwarded transparently.
    async fn phase3_cutover(
        &self,
        state: &mut MigrationState,
        group_id: u64,
        req: &MigrationRequest,
    ) -> Result<()> {
        // Fixed 10 ms pause estimate checked against the request's budget by
        // `start_cutover` — NOTE(review): looks like a placeholder; confirm
        // whether this should be measured rather than hard-coded.
        let estimated_pause_us = 10_000;
        state.start_cutover(estimated_pause_us).map_err(|e| {
            state.fail(format!("cutover rejected: {e}"));
            e
        })?;

        let cutover_start = std::time::Instant::now();

        info!(
            vshard = req.vshard_id,
            estimated_pause_us, "phase 3: atomic cut-over"
        );

        // NOTE(review): this re-proposes AddNode for the target (same as
        // phase 1) rather than a dedicated routing-change entry — confirm
        // this is the intended conf-change payload for cut-over.
        let routing_change = ConfChange {
            change_type: ConfChangeType::AddNode,
            node_id: req.target_node,
        };
        {
            let mut mr = self.multi_raft.lock().unwrap_or_else(|p| p.into_inner());
            mr.propose_conf_change(group_id, &routing_change)?;
        }

        // Local routing table update: point the vShard at `group_id`.
        {
            let mut routing = self.routing.write().unwrap_or_else(|p| p.into_inner());
            routing.reassign_vshard(req.vshard_id, group_id);
        }

        // Register a ghost stub so requests arriving at the old location
        // can be forwarded; `created_at_ms` is a wall-clock Unix timestamp
        // (0 if the system clock is before the epoch).
        {
            let mut ghosts = self.ghost_table.lock().unwrap_or_else(|p| p.into_inner());
            ghosts.insert(GhostStub {
                node_id: format!("vshard-{}", req.vshard_id),
                target_shard: req.vshard_id,
                refcount: 1,
                created_at_ms: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_millis() as u64,
            });
        }
        debug!(
            vshard = req.vshard_id,
            target = req.target_node,
            "ghost stub registered for transparent forwarding"
        );

        // Record the actually observed cut-over pause, not the estimate.
        let actual_pause_us = cutover_start.elapsed().as_micros() as u64;
        state.complete(actual_pause_us);

        debug!(
            vshard = req.vshard_id,
            actual_pause_us, "phase 3 complete: routing updated via raft"
        );

        Ok(())
    }
}
407
/// Thread-safe registry of migration states, both in-flight and finished
/// (finished entries remain until pruned by `gc`).
pub struct MigrationTracker {
    // Every tracked state; despite the name it also holds completed ones.
    active: Mutex<Vec<MigrationState>>,
}
412
413impl MigrationTracker {
414 pub fn new() -> Self {
415 Self {
416 active: Mutex::new(Vec::new()),
417 }
418 }
419
420 pub fn add(&self, state: MigrationState) {
421 let mut active = self.active.lock().unwrap_or_else(|p| p.into_inner());
422 active.push(state);
423 }
424
425 pub fn active_count(&self) -> usize {
426 let active = self.active.lock().unwrap_or_else(|p| p.into_inner());
427 active.iter().filter(|s| s.is_active()).count()
428 }
429
430 pub fn snapshot(&self) -> Vec<MigrationSnapshot> {
432 let active = self.active.lock().unwrap_or_else(|p| p.into_inner());
433 active
434 .iter()
435 .map(|s| MigrationSnapshot {
436 vshard_id: s.vshard_id(),
437 phase: format!("{:?}", s.phase()),
438 elapsed_ms: s.elapsed().map(|d| d.as_millis() as u64).unwrap_or(0),
439 is_active: s.is_active(),
440 })
441 .collect()
442 }
443
444 pub fn gc(&self, max_age: Duration) {
446 let mut active = self.active.lock().unwrap_or_else(|p| p.into_inner());
447 active.retain(|s| s.is_active() || s.elapsed().map(|d| d < max_age).unwrap_or(true));
448 }
449}
450
impl Default for MigrationTracker {
    /// Equivalent to [`MigrationTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
456
/// Read-only summary of one tracked migration, produced by
/// [`MigrationTracker::snapshot`].
///
/// All fields are plain values, so equality and hashing are derived in
/// addition to the existing `Debug`/`Clone`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MigrationSnapshot {
    /// vShard the migration concerns.
    pub vshard_id: u16,
    /// Debug-formatted name of the migration's current phase.
    pub phase: String,
    /// Elapsed wall-clock time in milliseconds (0 when not recorded).
    pub elapsed_ms: u64,
    /// Whether the migration is still in progress.
    pub is_active: bool,
}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::RoutingTable;
    use crate::topology::ClusterTopology;

    // Tracker counts and snapshots an added, active migration state.
    #[test]
    fn migration_tracker_lifecycle() {
        let tracker = MigrationTracker::new();
        assert_eq!(tracker.active_count(), 0);

        let mut state = MigrationState::new(0, 0, 1, 1, 2, 500_000);
        // Starting base copy marks the state active.
        state.start_base_copy(100);
        tracker.add(state);

        assert_eq!(tracker.active_count(), 1);
        assert_eq!(tracker.snapshot().len(), 1);
        assert!(tracker.snapshot()[0].is_active);
    }

    // End-to-end check of phase 1 against a single-node raft group.
    #[tokio::test]
    async fn migration_executor_phase1() {
        let dir = tempfile::tempdir().unwrap();
        let rt = RoutingTable::uniform(1, &[1], 1);
        let mut mr = crate::multi_raft::MultiRaft::new(1, rt.clone(), dir.path().to_path_buf());
        mr.add_group(0, vec![]).unwrap();

        use std::time::Instant;
        // Force an immediate election so the lone node becomes leader.
        for node in mr.groups_mut().values_mut() {
            node.election_deadline_override(Instant::now() - Duration::from_millis(1));
        }
        let _ = mr.tick();
        // Drain committed entries and advance the applied index per group.
        for (gid, ready) in mr.tick().groups {
            if let Some(last) = ready.committed_entries.last() {
                mr.advance_applied(gid, last.index).unwrap();
            }
        }

        let multi_raft = Arc::new(Mutex::new(mr));
        let routing = Arc::new(RwLock::new(rt));
        let topology = Arc::new(RwLock::new(ClusterTopology::new()));
        // Port 0 lets the OS pick a free port for the transport.
        let transport = Arc::new(NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap());

        let executor = MigrationExecutor::new(multi_raft.clone(), routing, topology, transport);

        let mut state = MigrationState::new(0, 0, 0, 1, 2, 500_000);

        let req = MigrationRequest {
            vshard_id: 0,
            source_node: 1,
            target_node: 2,
            write_pause_budget_us: 500_000,
        };

        // Phase 1 should succeed: the AddNode conf change is proposed on the
        // leader group created above.
        executor
            .phase1_base_copy(&mut state, 0, &req)
            .await
            .unwrap();

    }

    // Default request carries the 500 ms write-pause budget.
    #[test]
    fn migration_request_default() {
        let req = MigrationRequest::default();
        assert_eq!(req.write_pause_budget_us, 500_000);
    }
}