1use crate::VectorError;
17use std::collections::HashMap;
18use std::time::Instant;
19
/// Lifecycle state of a single shard replica.
#[derive(Clone, Debug, PartialEq)]
pub enum ReplicaState {
    /// The shard's primary copy. `ReplicaManager` allows at most one per shard.
    Primary,
    /// A fully synchronised, non-primary copy.
    Replica,
    /// A copy that is still synchronising.
    ///
    /// `progress` is the last reported sync progress. Presumably it is a fraction in `[0, 1]`;
    /// `ReplicaManager::update_sync_progress` moves the replica to `Replica` once it reaches `1.0`.
    CatchingUp { progress: f64 },
    /// A copy that has been marked unusable.
    Failed,
}

impl ReplicaState {
    /// Returns `true` for `Primary` and `Replica`.
    ///
    /// Returns `false` for `CatchingUp` and `Failed`.
    pub fn is_healthy(&self) -> bool {
        matches!(self, Self::Primary | Self::Replica)
    }

    /// Returns `true` only for the `Primary` state.
    pub fn is_primary(&self) -> bool {
        // `Primary` carries no data, so derived equality is an exact variant test.
        *self == Self::Primary
    }
}
51
52#[derive(Debug, Clone)]
58pub struct ShardReplica {
59 pub shard_id: u64,
61 pub replica_id: String,
63 pub node_id: String,
65 pub state: ReplicaState,
67 pub last_sync: Instant,
69 pub vector_count: usize,
71}
72
73impl ShardReplica {
74 pub fn new(
76 shard_id: u64,
77 replica_id: impl Into<String>,
78 node_id: impl Into<String>,
79 state: ReplicaState,
80 vector_count: usize,
81 ) -> Self {
82 Self {
83 shard_id,
84 replica_id: replica_id.into(),
85 node_id: node_id.into(),
86 state,
87 last_sync: Instant::now(),
88 vector_count,
89 }
90 }
91
92 pub fn touch(&mut self) {
94 self.last_sync = Instant::now();
95 }
96}
97
/// Aggregate replication health snapshot, as built by `ReplicaManager::replication_status`.
#[derive(Clone, Debug)]
pub struct ReplicationStatus {
    /// Number of shards known to the manager.
    pub total_shards: usize,
    /// Shards whose healthy-replica count is below the replication factor.
    pub under_replicated: usize,
    /// Shards whose healthy-replica count is above the replication factor.
    pub over_replicated: usize,
    /// Total replicas in the `Failed` state, summed over all shards.
    pub failed_replicas: usize,
    /// `true` only when no shard is under- or over-replicated and no replica has failed.
    pub healthy: bool,
}
117
118pub struct ReplicaManager {
124 shards: HashMap<u64, Vec<ShardReplica>>,
126 replication_factor: usize,
128}
129
130impl ReplicaManager {
131 pub fn new(replication_factor: usize) -> Self {
139 let factor = replication_factor.max(1);
140 Self {
141 shards: HashMap::new(),
142 replication_factor: factor,
143 }
144 }
145
146 pub fn register_replica(&mut self, replica: ShardReplica) -> Result<(), VectorError> {
158 let shard_id = replica.shard_id;
159 let replica_id = replica.replica_id.clone();
160 let is_primary = replica.state.is_primary();
161
162 let entry = self.shards.entry(shard_id).or_default();
163
164 if entry.iter().any(|r| r.replica_id == replica_id) {
166 return Err(VectorError::InvalidData(format!(
167 "Replica '{}' for shard {} is already registered",
168 replica_id, shard_id
169 )));
170 }
171
172 if is_primary && entry.iter().any(|r| r.state.is_primary()) {
174 return Err(VectorError::InvalidData(format!(
175 "Shard {} already has a primary; cannot register another",
176 shard_id
177 )));
178 }
179
180 entry.push(replica);
181 Ok(())
182 }
183
184 pub fn unregister_replica(&mut self, shard_id: u64, replica_id: &str) -> bool {
188 let Some(replicas) = self.shards.get_mut(&shard_id) else {
189 return false;
190 };
191 let before = replicas.len();
192 replicas.retain(|r| r.replica_id != replica_id);
193 replicas.len() < before
194 }
195
196 pub fn promote_to_primary(
209 &mut self,
210 shard_id: u64,
211 replica_id: &str,
212 ) -> Result<(), VectorError> {
213 let replicas = self
214 .shards
215 .get_mut(&shard_id)
216 .ok_or_else(|| VectorError::InvalidData(format!("Shard {} not found", shard_id)))?;
217
218 let target_exists = replicas.iter().any(|r| r.replica_id == replica_id);
220 if !target_exists {
221 return Err(VectorError::InvalidData(format!(
222 "Replica '{}' not found in shard {}",
223 replica_id, shard_id
224 )));
225 }
226
227 let target_failed = replicas
228 .iter()
229 .find(|r| r.replica_id == replica_id)
230 .map(|r| matches!(r.state, ReplicaState::Failed))
231 .unwrap_or(false);
232 if target_failed {
233 return Err(VectorError::InvalidData(format!(
234 "Cannot promote failed replica '{}' in shard {}",
235 replica_id, shard_id
236 )));
237 }
238
239 for r in replicas.iter_mut() {
241 if r.replica_id != replica_id && matches!(r.state, ReplicaState::Primary) {
242 r.state = ReplicaState::Replica;
243 }
244 }
245
246 for r in replicas.iter_mut() {
248 if r.replica_id == replica_id {
249 r.state = ReplicaState::Primary;
250 r.touch();
251 }
252 }
253
254 Ok(())
255 }
256
257 pub fn mark_failed(&mut self, shard_id: u64, replica_id: &str) {
264 if let Some(replicas) = self.shards.get_mut(&shard_id) {
265 for r in replicas.iter_mut() {
266 if r.replica_id == replica_id {
267 r.state = ReplicaState::Failed;
268 }
269 }
270 }
271 }
272
273 pub fn auto_failover(&mut self, shard_id: u64) -> Result<String, VectorError> {
282 let best_id = {
283 let replicas = self
284 .shards
285 .get(&shard_id)
286 .ok_or_else(|| VectorError::InvalidData(format!("Shard {} not found", shard_id)))?;
287
288 replicas
289 .iter()
290 .filter(|r| r.state.is_healthy() && !r.state.is_primary())
291 .max_by_key(|r| r.vector_count)
292 .map(|r| r.replica_id.clone())
293 .ok_or_else(|| {
294 VectorError::InvalidData(format!(
295 "No healthy replica available to promote for shard {}",
296 shard_id
297 ))
298 })?
299 };
300
301 self.promote_to_primary(shard_id, &best_id)?;
302 Ok(best_id)
303 }
304
305 pub fn update_sync_progress(&mut self, shard_id: u64, replica_id: &str, progress: f64) {
309 let Some(replicas) = self.shards.get_mut(&shard_id) else {
310 return;
311 };
312 for r in replicas.iter_mut() {
313 if r.replica_id == replica_id {
314 if progress >= 1.0 {
315 r.state = ReplicaState::Replica;
316 } else {
317 r.state = ReplicaState::CatchingUp { progress };
318 }
319 r.touch();
320 }
321 }
322 }
323
324 pub fn get_primary(&self, shard_id: u64) -> Option<&ShardReplica> {
330 self.shards
331 .get(&shard_id)?
332 .iter()
333 .find(|r| r.state.is_primary())
334 }
335
336 pub fn get_replicas(&self, shard_id: u64) -> Vec<&ShardReplica> {
338 self.shards
339 .get(&shard_id)
340 .map(|v| v.iter().collect())
341 .unwrap_or_default()
342 }
343
344 pub fn get_healthy_replicas(&self, shard_id: u64) -> Vec<&ShardReplica> {
346 self.shards
347 .get(&shard_id)
348 .map(|v| v.iter().filter(|r| r.state.is_healthy()).collect())
349 .unwrap_or_default()
350 }
351
352 pub fn shard_ids(&self) -> Vec<u64> {
354 self.shards.keys().cloned().collect()
355 }
356
357 pub fn replication_factor(&self) -> usize {
359 self.replication_factor
360 }
361
362 pub fn needs_rebalancing(&self) -> bool {
369 self.shards.values().any(|replicas| {
370 let healthy = replicas.iter().filter(|r| r.state.is_healthy()).count();
371 healthy != self.replication_factor
372 })
373 }
374
375 pub fn replication_status(&self) -> ReplicationStatus {
377 let total_shards = self.shards.len();
378 let mut under_replicated = 0usize;
379 let mut over_replicated = 0usize;
380 let mut failed_replicas = 0usize;
381
382 for replicas in self.shards.values() {
383 let healthy = replicas.iter().filter(|r| r.state.is_healthy()).count();
384 let failed = replicas
385 .iter()
386 .filter(|r| matches!(r.state, ReplicaState::Failed))
387 .count();
388
389 failed_replicas += failed;
390
391 match healthy.cmp(&self.replication_factor) {
392 std::cmp::Ordering::Less => under_replicated += 1,
393 std::cmp::Ordering::Greater => over_replicated += 1,
394 std::cmp::Ordering::Equal => {}
395 }
396 }
397
398 let healthy = under_replicated == 0 && over_replicated == 0 && failed_replicas == 0;
399
400 ReplicationStatus {
401 total_shards,
402 under_replicated,
403 over_replicated,
404 failed_replicas,
405 healthy,
406 }
407 }
408}
409
#[cfg(test)]
mod tests {
    //! Unit tests for replica registration, promotion, failover, sync progress and status.

    use super::*;

    /// Builds a `Primary` replica holding 1000 vectors.
    fn primary(shard: u64, rid: &str, node: &str) -> ShardReplica {
        ShardReplica::new(shard, rid, node, ReplicaState::Primary, 1000)
    }

    /// Builds a `Replica` holding 1000 vectors.
    fn replica(shard: u64, rid: &str, node: &str) -> ShardReplica {
        ShardReplica::new(shard, rid, node, ReplicaState::Replica, 1000)
    }

    /// Builds a `CatchingUp` replica holding 500 vectors.
    fn catching_up(shard: u64, rid: &str, node: &str, progress: f64) -> ShardReplica {
        ShardReplica::new(shard, rid, node, ReplicaState::CatchingUp { progress }, 500)
    }

    #[test]
    fn test_register_primary_and_replicas() {
        let mut mgr = ReplicaManager::new(3);

        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("primary");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("replica 1");
        mgr.register_replica(replica(1, "r2", "node-c"))
            .expect("replica 2");

        assert_eq!(mgr.get_replicas(1).len(), 3);
        assert!(mgr.get_primary(1).is_some());
    }

    #[test]
    fn test_duplicate_primary_rejected() {
        let mut mgr = ReplicaManager::new(2);

        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("first primary");
        let err = mgr.register_replica(primary(1, "r1", "node-b"));
        assert!(err.is_err(), "duplicate primary must be rejected");
    }

    #[test]
    fn test_duplicate_replica_id_rejected() {
        let mut mgr = ReplicaManager::new(2);

        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("first");
        let err = mgr.register_replica(replica(1, "r0", "node-b"));
        assert!(err.is_err(), "duplicate replica_id must be rejected");
    }

    #[test]
    fn test_promote_to_primary() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        mgr.promote_to_primary(1, "r1").expect("promote failed");

        let new_primary = mgr.get_primary(1).expect("primary should exist");
        assert_eq!(new_primary.replica_id, "r1");

        let replicas = mgr.get_replicas(1);
        let old = replicas
            .iter()
            .find(|r| r.replica_id == "r0")
            .expect("r0 should still exist");
        assert!(matches!(old.state, ReplicaState::Replica));
    }

    #[test]
    fn test_promote_failed_replica_rejected() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        mgr.mark_failed(1, "r1");
        let err = mgr.promote_to_primary(1, "r1");
        assert!(err.is_err(), "promoting a failed replica must fail");
    }

    #[test]
    fn test_promote_nonexistent_replica_rejected() {
        let mut mgr = ReplicaManager::new(1);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");

        let err = mgr.promote_to_primary(1, "ghost");
        assert!(err.is_err());
    }

    #[test]
    fn test_promote_and_failover_unknown_shard_rejected() {
        let mut mgr = ReplicaManager::new(2);
        assert!(mgr.promote_to_primary(7, "r0").is_err());
        assert!(mgr.auto_failover(7).is_err());
    }

    #[test]
    fn test_mark_failed() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        mgr.mark_failed(1, "r1");

        let replicas = mgr.get_replicas(1);
        let r1 = replicas
            .iter()
            .find(|r| r.replica_id == "r1")
            .expect("r1 exists");
        assert!(matches!(r1.state, ReplicaState::Failed));
    }

    #[test]
    fn test_mark_failed_noop_unknown() {
        let mut mgr = ReplicaManager::new(1);
        // Must neither panic nor create an entry for the unknown shard.
        mgr.mark_failed(99, "ghost");
        assert!(mgr.get_replicas(99).is_empty());
    }

    #[test]
    fn test_auto_failover_selects_best_replica() {
        let mut mgr = ReplicaManager::new(3);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");

        // r1 holds the most vectors, so it should win the failover.
        let mut r1 = replica(1, "r1", "node-b");
        r1.vector_count = 2000;
        let mut r2 = replica(1, "r2", "node-c");
        r2.vector_count = 1500;

        mgr.register_replica(r1).expect("ok");
        mgr.register_replica(r2).expect("ok");

        mgr.mark_failed(1, "r0");

        let promoted = mgr.auto_failover(1).expect("auto_failover failed");
        assert_eq!(promoted, "r1");
    }

    #[test]
    fn test_auto_failover_skips_catching_up_replicas() {
        let mut mgr = ReplicaManager::new(3);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        // Largest vector_count, but not healthy, so it must not be chosen.
        let mut lagging = catching_up(1, "r1", "node-b", 0.9);
        lagging.vector_count = 5000;
        mgr.register_replica(lagging).expect("ok");
        mgr.register_replica(replica(1, "r2", "node-c"))
            .expect("ok");

        mgr.mark_failed(1, "r0");

        let promoted = mgr.auto_failover(1).expect("auto_failover failed");
        assert_eq!(promoted, "r2");
        assert_eq!(
            mgr.get_primary(1).map(|r| r.replica_id.as_str()),
            Some("r2")
        );
    }

    #[test]
    fn test_auto_failover_fails_when_no_healthy_replica() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        mgr.mark_failed(1, "r0");
        mgr.mark_failed(1, "r1");

        let err = mgr.auto_failover(1);
        assert!(err.is_err(), "no healthy replica → should fail");
    }

    #[test]
    fn test_sync_progress_promotes_when_complete() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(catching_up(1, "r1", "node-b", 0.3))
            .expect("ok");

        mgr.update_sync_progress(1, "r1", 1.0);

        let replicas = mgr.get_replicas(1);
        let r1 = replicas.iter().find(|r| r.replica_id == "r1").expect("r1");
        assert!(matches!(r1.state, ReplicaState::Replica));
    }

    #[test]
    fn test_sync_progress_partial() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(catching_up(1, "r1", "node-b", 0.1))
            .expect("ok");

        mgr.update_sync_progress(1, "r1", 0.7);

        let replicas = mgr.get_replicas(1);
        let r1 = replicas.iter().find(|r| r.replica_id == "r1").expect("r1");
        if let ReplicaState::CatchingUp { progress } = r1.state {
            assert!((progress - 0.7).abs() < 1e-10);
        } else {
            panic!("Expected CatchingUp state");
        }
    }

    #[test]
    fn test_needs_rebalancing_false_when_healthy() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        assert!(!mgr.needs_rebalancing());
    }

    #[test]
    fn test_needs_rebalancing_true_when_under_replicated() {
        let mut mgr = ReplicaManager::new(3);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        assert!(mgr.needs_rebalancing());
    }

    #[test]
    fn test_needs_rebalancing_true_when_over_replicated() {
        let mut mgr = ReplicaManager::new(1);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        assert!(mgr.needs_rebalancing());
    }

    #[test]
    fn test_replication_status_healthy() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        let status = mgr.replication_status();
        assert_eq!(status.total_shards, 1);
        assert_eq!(status.under_replicated, 0);
        assert_eq!(status.over_replicated, 0);
        assert_eq!(status.failed_replicas, 0);
        assert!(status.healthy);
    }

    #[test]
    fn test_replication_status_with_failures() {
        let mut mgr = ReplicaManager::new(3);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");
        mgr.register_replica(replica(1, "r2", "node-c"))
            .expect("ok");
        mgr.mark_failed(1, "r2");

        let status = mgr.replication_status();
        assert!(!status.healthy);
        assert_eq!(status.failed_replicas, 1);
        // Only two healthy replicas remain against a factor of three.
        assert_eq!(status.under_replicated, 1);
    }

    #[test]
    fn test_replication_status_multiple_shards() {
        let mut mgr = ReplicaManager::new(2);

        // Shard 1 is fully replicated.
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        // Shard 2 has only its primary.
        mgr.register_replica(primary(2, "r0", "node-c"))
            .expect("ok");

        let status = mgr.replication_status();
        assert_eq!(status.total_shards, 2);
        assert_eq!(status.under_replicated, 1);
        assert!(!status.healthy);
    }

    #[test]
    fn test_unregister_replica() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");

        let removed = mgr.unregister_replica(1, "r1");
        assert!(removed);
        assert_eq!(mgr.get_replicas(1).len(), 1);
    }

    #[test]
    fn test_unregister_unknown_returns_false() {
        let mut mgr = ReplicaManager::new(2);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");

        assert!(!mgr.unregister_replica(1, "ghost"));
        assert!(!mgr.unregister_replica(99, "r0"));
        assert_eq!(mgr.get_replicas(1).len(), 1);
    }

    #[test]
    fn test_get_healthy_replicas() {
        let mut mgr = ReplicaManager::new(3);
        mgr.register_replica(primary(1, "r0", "node-a"))
            .expect("ok");
        mgr.register_replica(replica(1, "r1", "node-b"))
            .expect("ok");
        mgr.register_replica(replica(1, "r2", "node-c"))
            .expect("ok");
        mgr.mark_failed(1, "r2");

        let healthy = mgr.get_healthy_replicas(1);
        assert_eq!(healthy.len(), 2);
        assert!(healthy.iter().all(|r| r.state.is_healthy()));
    }

    #[test]
    fn test_new_clamps_zero_replication_factor() {
        let mgr = ReplicaManager::new(0);
        assert_eq!(mgr.replication_factor(), 1);
    }

    #[test]
    fn test_lookups_on_unknown_shard() {
        let mgr = ReplicaManager::new(2);
        assert!(mgr.get_primary(42).is_none());
        assert!(mgr.get_replicas(42).is_empty());
        assert!(mgr.get_healthy_replicas(42).is_empty());
        assert!(mgr.shard_ids().is_empty());
    }

    #[test]
    fn test_shard_ids_lists_every_shard() {
        let mut mgr = ReplicaManager::new(1);
        mgr.register_replica(primary(2, "a", "node-a"))
            .expect("ok");
        mgr.register_replica(primary(1, "b", "node-b"))
            .expect("ok");

        // Sort locally so the assertion does not depend on the returned order.
        let mut ids = mgr.shard_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 2]);
    }
}