use crate::VectorError;
use std::collections::HashMap;
use std::time::Instant;
/// State of a single shard replica as tracked by `ReplicaManager`.
///
/// Only `Primary` and `Replica` are reported as healthy by `is_healthy`.
#[derive(Debug, Clone, PartialEq)]
pub enum ReplicaState {
/// The shard's primary copy; `register_replica` rejects a second primary for the same shard.
Primary,
/// A healthy non-primary copy; `auto_failover` picks its promotion candidate from these.
Replica,
/// A copy that is still synchronising; not counted as healthy.
CatchingUp {
/// Last progress value reported through `update_sync_progress`
/// (presumably a fraction in `[0, 1)`; a report of `1.0` or more yields `Replica` instead).
progress: f64,
},
/// A copy flagged by `mark_failed`; it cannot be promoted to primary.
Failed,
}
impl ReplicaState {
    /// Reports whether this state counts towards a shard's healthy replicas.
    ///
    /// Only `Primary` and `Replica` qualify; a replica that is still catching
    /// up, or one that has failed, does not.
    pub fn is_healthy(&self) -> bool {
        match self {
            ReplicaState::Primary | ReplicaState::Replica => true,
            ReplicaState::CatchingUp { .. } | ReplicaState::Failed => false,
        }
    }

    /// Reports whether this state is `Primary`.
    pub fn is_primary(&self) -> bool {
        match self {
            ReplicaState::Primary => true,
            ReplicaState::Replica | ReplicaState::CatchingUp { .. } | ReplicaState::Failed => false,
        }
    }
}
/// Bookkeeping record for one copy of a shard.
#[derive(Debug, Clone)]
pub struct ShardReplica {
/// Identifier of the shard this replica belongs to.
pub shard_id: u64,
/// Identifier of this replica; `ReplicaManager::register_replica` requires it to be unique within a shard.
pub replica_id: String,
/// Identifier of the node associated with this replica (presumably the host storing it — not interpreted in this file).
pub node_id: String,
/// Current lifecycle state.
pub state: ReplicaState,
/// Set at construction and refreshed by `touch`.
pub last_sync: Instant,
/// Vector count recorded for this replica; `auto_failover` prefers the candidate with the largest value.
pub vector_count: usize,
}
impl ShardReplica {
pub fn new(
shard_id: u64,
replica_id: impl Into<String>,
node_id: impl Into<String>,
state: ReplicaState,
vector_count: usize,
) -> Self {
Self {
shard_id,
replica_id: replica_id.into(),
node_id: node_id.into(),
state,
last_sync: Instant::now(),
vector_count,
}
}
pub fn touch(&mut self) {
self.last_sync = Instant::now();
}
}
/// Summary of replication health across all shards, as produced by
/// `ReplicaManager::replication_status`.
#[derive(Debug, Clone)]
pub struct ReplicationStatus {
/// Number of shards known to the manager.
pub total_shards: usize,
/// Number of shards with fewer healthy replicas than the replication factor.
pub under_replicated: usize,
/// Number of shards with more healthy replicas than the replication factor.
pub over_replicated: usize,
/// Total number of replicas in the `Failed` state, summed over all shards.
pub failed_replicas: usize,
/// `true` only when the three counters above are all zero.
pub healthy: bool,
}
/// Tracks the replicas of every shard and the replication factor they are measured against.
pub struct ReplicaManager {
// Replicas grouped by shard id, in registration order.
shards: HashMap<u64, Vec<ShardReplica>>,
// Number of healthy replicas each shard is expected to have; `new` raises 0 to 1.
replication_factor: usize,
}
impl ReplicaManager {
    /// Creates an empty manager that expects `replication_factor` healthy
    /// replicas per shard.
    ///
    /// A factor of `0` is raised to `1`.
    pub fn new(replication_factor: usize) -> Self {
        Self {
            shards: HashMap::new(),
            replication_factor: replication_factor.max(1),
        }
    }

    /// Adds `replica` to the shard named by its `shard_id`, creating the shard
    /// entry on first use.
    ///
    /// # Errors
    ///
    /// Returns `VectorError::InvalidData` when the shard already contains a
    /// replica with the same `replica_id`, or when `replica` is a primary and
    /// the shard already has one.
    pub fn register_replica(&mut self, replica: ShardReplica) -> Result<(), VectorError> {
        let replicas = self.shards.entry(replica.shard_id).or_default();
        if replicas.iter().any(|r| r.replica_id == replica.replica_id) {
            return Err(VectorError::InvalidData(format!(
                "Replica '{}' for shard {} is already registered",
                replica.replica_id, replica.shard_id
            )));
        }
        if replica.state.is_primary() && replicas.iter().any(|r| r.state.is_primary()) {
            return Err(VectorError::InvalidData(format!(
                "Shard {} already has a primary; cannot register another",
                replica.shard_id
            )));
        }
        replicas.push(replica);
        Ok(())
    }

    /// Removes the replica `replica_id` from `shard_id`.
    ///
    /// Returns `true` if a replica was removed and `false` if the shard or the
    /// replica is unknown. The shard entry itself is kept even when it ends up
    /// with no replicas.
    pub fn unregister_replica(&mut self, shard_id: u64, replica_id: &str) -> bool {
        let Some(replicas) = self.shards.get_mut(&shard_id) else {
            return false;
        };
        let before = replicas.len();
        replicas.retain(|r| r.replica_id != replica_id);
        replicas.len() < before
    }

    /// Makes `replica_id` the primary of `shard_id`, demoting the current
    /// primary (if any) to `Replica` and refreshing the new primary's
    /// `last_sync`.
    ///
    /// # Errors
    ///
    /// Returns `VectorError::InvalidData` when the shard is unknown, the
    /// replica is not part of the shard, or the replica is in the `Failed`
    /// state. Nothing is modified in any of those cases.
    pub fn promote_to_primary(
        &mut self,
        shard_id: u64,
        replica_id: &str,
    ) -> Result<(), VectorError> {
        let replicas = self
            .shards
            .get_mut(&shard_id)
            .ok_or_else(|| VectorError::InvalidData(format!("Shard {} not found", shard_id)))?;

        // Validate the target before touching any state so a rejected
        // promotion leaves the shard exactly as it was.
        let Some(target) = replicas.iter().find(|r| r.replica_id == replica_id) else {
            return Err(VectorError::InvalidData(format!(
                "Replica '{}' not found in shard {}",
                replica_id, shard_id
            )));
        };
        if matches!(target.state, ReplicaState::Failed) {
            return Err(VectorError::InvalidData(format!(
                "Cannot promote failed replica '{}' in shard {}",
                replica_id, shard_id
            )));
        }

        for r in replicas.iter_mut() {
            if r.replica_id == replica_id {
                r.state = ReplicaState::Primary;
                r.touch();
            } else if r.state.is_primary() {
                r.state = ReplicaState::Replica;
            }
        }
        Ok(())
    }

    /// Puts the replica `replica_id` of `shard_id` into the `Failed` state.
    ///
    /// Unknown shards or replicas are ignored.
    pub fn mark_failed(&mut self, shard_id: u64, replica_id: &str) {
        let Some(replicas) = self.shards.get_mut(&shard_id) else {
            return;
        };
        for r in replicas.iter_mut().filter(|r| r.replica_id == replica_id) {
            r.state = ReplicaState::Failed;
        }
    }

    /// Promotes the healthy non-primary replica of `shard_id` with the largest
    /// `vector_count` and returns its id.
    ///
    /// # Errors
    ///
    /// Returns `VectorError::InvalidData` when the shard is unknown or has no
    /// healthy non-primary replica to promote.
    pub fn auto_failover(&mut self, shard_id: u64) -> Result<String, VectorError> {
        let replicas = self
            .shards
            .get(&shard_id)
            .ok_or_else(|| VectorError::InvalidData(format!("Shard {} not found", shard_id)))?;
        let best_id = replicas
            .iter()
            .filter(|r| r.state.is_healthy() && !r.state.is_primary())
            .max_by_key(|r| r.vector_count)
            .map(|r| r.replica_id.clone())
            .ok_or_else(|| {
                VectorError::InvalidData(format!(
                    "No healthy replica available to promote for shard {}",
                    shard_id
                ))
            })?;
        self.promote_to_primary(shard_id, &best_id)?;
        Ok(best_id)
    }

    /// Records a sync progress report for the replica `replica_id` of
    /// `shard_id` and refreshes its `last_sync`.
    ///
    /// A report of `1.0` or more moves the replica to `Replica`; anything
    /// lower moves it to `CatchingUp`, with negative values stored as `0.0`.
    /// A replica that is currently `Primary` keeps its state, so a stray
    /// report can never leave the shard without a primary. NaN reports and
    /// unknown shards or replicas are ignored.
    pub fn update_sync_progress(&mut self, shard_id: u64, replica_id: &str, progress: f64) {
        // NaN compares false against everything and would otherwise be stored
        // as a meaningless progress value.
        if progress.is_nan() {
            return;
        }
        let Some(replicas) = self.shards.get_mut(&shard_id) else {
            return;
        };
        for r in replicas.iter_mut().filter(|r| r.replica_id == replica_id) {
            if !r.state.is_primary() {
                r.state = if progress >= 1.0 {
                    ReplicaState::Replica
                } else {
                    ReplicaState::CatchingUp {
                        progress: progress.max(0.0),
                    }
                };
            }
            r.touch();
        }
    }

    /// Returns the primary replica of `shard_id`, or `None` when the shard is
    /// unknown or currently has no primary.
    pub fn get_primary(&self, shard_id: u64) -> Option<&ShardReplica> {
        self.shards
            .get(&shard_id)?
            .iter()
            .find(|r| r.state.is_primary())
    }

    /// Returns every replica of `shard_id` regardless of state; empty for an
    /// unknown shard.
    pub fn get_replicas(&self, shard_id: u64) -> Vec<&ShardReplica> {
        self.shards
            .get(&shard_id)
            .map(|v| v.iter().collect())
            .unwrap_or_default()
    }

    /// Returns the replicas of `shard_id` whose state is healthy (`Primary`
    /// or `Replica`); empty for an unknown shard.
    pub fn get_healthy_replicas(&self, shard_id: u64) -> Vec<&ShardReplica> {
        self.shards
            .get(&shard_id)
            .map(|v| v.iter().filter(|r| r.state.is_healthy()).collect())
            .unwrap_or_default()
    }

    /// Returns the ids of all known shards, in unspecified order.
    pub fn shard_ids(&self) -> Vec<u64> {
        self.shards.keys().copied().collect()
    }

    /// Returns the number of healthy replicas each shard is expected to have.
    pub fn replication_factor(&self) -> usize {
        self.replication_factor
    }

    /// Returns `true` when any shard's healthy replica count differs from the
    /// replication factor, in either direction.
    pub fn needs_rebalancing(&self) -> bool {
        self.shards
            .values()
            .any(|replicas| Self::healthy_count(replicas) != self.replication_factor)
    }

    /// Summarises replication health across all shards.
    ///
    /// The result is `healthy` only when no shard is under- or
    /// over-replicated and no replica is in the `Failed` state.
    pub fn replication_status(&self) -> ReplicationStatus {
        let mut under_replicated = 0usize;
        let mut over_replicated = 0usize;
        let mut failed_replicas = 0usize;
        for replicas in self.shards.values() {
            failed_replicas += replicas
                .iter()
                .filter(|r| matches!(r.state, ReplicaState::Failed))
                .count();
            match Self::healthy_count(replicas).cmp(&self.replication_factor) {
                std::cmp::Ordering::Less => under_replicated += 1,
                std::cmp::Ordering::Greater => over_replicated += 1,
                std::cmp::Ordering::Equal => {}
            }
        }
        ReplicationStatus {
            total_shards: self.shards.len(),
            under_replicated,
            over_replicated,
            failed_replicas,
            healthy: under_replicated == 0 && over_replicated == 0 && failed_replicas == 0,
        }
    }

    /// Counts the replicas in `replicas` whose state is healthy.
    fn healthy_count(replicas: &[ShardReplica]) -> usize {
        replicas.iter().filter(|r| r.state.is_healthy()).count()
    }
}
// Unit tests for `ReplicaManager`: registration rules, promotion, failure
// marking, failover, sync progress and replication health reporting.
#[cfg(test)]
mod tests {
use super::*;
// Fixture: a `Primary` replica with a vector count of 1000.
fn primary(shard: u64, rid: &str, node: &str) -> ShardReplica {
ShardReplica::new(shard, rid, node, ReplicaState::Primary, 1000)
}
// Fixture: a healthy `Replica` with a vector count of 1000.
fn replica(shard: u64, rid: &str, node: &str) -> ShardReplica {
ShardReplica::new(shard, rid, node, ReplicaState::Replica, 1000)
}
// Fixture: a `CatchingUp` replica at `progress` with a vector count of 500.
fn catching_up(shard: u64, rid: &str, node: &str, progress: f64) -> ShardReplica {
ShardReplica::new(shard, rid, node, ReplicaState::CatchingUp { progress }, 500)
}
// One primary plus two replicas can be registered for the same shard.
#[test]
fn test_register_primary_and_replicas() {
let mut mgr = ReplicaManager::new(3);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("primary");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("replica 1");
mgr.register_replica(replica(1, "r2", "node-c"))
.expect("replica 2");
assert_eq!(mgr.get_replicas(1).len(), 3);
assert!(mgr.get_primary(1).is_some());
}
// A second primary for a shard that already has one is refused.
#[test]
fn test_duplicate_primary_rejected() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("first primary");
let err = mgr.register_replica(primary(1, "r1", "node-b"));
assert!(err.is_err(), "duplicate primary must be rejected");
}
// Replica ids must be unique within a shard, whatever the state or node.
#[test]
fn test_duplicate_replica_id_rejected() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("first");
let err = mgr.register_replica(replica(1, "r0", "node-b"));
assert!(err.is_err(), "duplicate replica_id must be rejected");
}
// Promotion swaps roles: the target becomes primary and the old primary
// stays registered as a plain replica.
#[test]
fn test_promote_to_primary() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.promote_to_primary(1, "r1").expect("promote failed");
let new_primary = mgr.get_primary(1).expect("primary should exist");
assert_eq!(new_primary.replica_id, "r1");
let replicas = mgr.get_replicas(1);
let old = replicas
.iter()
.find(|r| r.replica_id == "r0")
.expect("r0 should still exist");
assert!(matches!(old.state, ReplicaState::Replica));
}
// A replica in the `Failed` state cannot be promoted.
#[test]
fn test_promote_failed_replica_rejected() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.mark_failed(1, "r1");
let err = mgr.promote_to_primary(1, "r1");
assert!(err.is_err(), "promoting a failed replica must fail");
}
// Promoting a replica id that is not part of the shard is an error.
#[test]
fn test_promote_nonexistent_replica_rejected() {
let mut mgr = ReplicaManager::new(1);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
let err = mgr.promote_to_primary(1, "ghost");
assert!(err.is_err());
}
// `mark_failed` moves the named replica into the `Failed` state.
#[test]
fn test_mark_failed() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.mark_failed(1, "r1");
let replicas = mgr.get_replicas(1);
let r1 = replicas
.iter()
.find(|r| r.replica_id == "r1")
.expect("r1 exists");
assert!(matches!(r1.state, ReplicaState::Failed));
}
// `mark_failed` on an unknown shard must simply return without panicking.
#[test]
fn test_mark_failed_noop_unknown() {
let mut mgr = ReplicaManager::new(1);
mgr.mark_failed(99, "ghost"); }
// With the primary failed, failover picks the healthy replica with the
// larger vector count (r1 at 2000 over r2 at 1500).
#[test]
fn test_auto_failover_selects_best_replica() {
let mut mgr = ReplicaManager::new(3);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
let mut r1 = replica(1, "r1", "node-b");
r1.vector_count = 2000;
let mut r2 = replica(1, "r2", "node-c");
r2.vector_count = 1500;
mgr.register_replica(r1).expect("ok");
mgr.register_replica(r2).expect("ok");
mgr.mark_failed(1, "r0");
let promoted = mgr.auto_failover(1).expect("auto_failover failed");
assert_eq!(promoted, "r1");
}
// Failover reports an error when every replica of the shard has failed.
#[test]
fn test_auto_failover_fails_when_no_healthy_replica() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.mark_failed(1, "r0");
mgr.mark_failed(1, "r1");
let err = mgr.auto_failover(1);
assert!(err.is_err(), "no healthy replica → should fail");
}
// Reporting progress 1.0 turns a catching-up replica into a `Replica`.
#[test]
fn test_sync_progress_promotes_when_complete() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(catching_up(1, "r1", "node-b", 0.3))
.expect("ok");
mgr.update_sync_progress(1, "r1", 1.0);
let replicas = mgr.get_replicas(1);
let r1 = replicas.iter().find(|r| r.replica_id == "r1").expect("r1");
assert!(matches!(r1.state, ReplicaState::Replica));
}
// Progress below 1.0 keeps the replica in `CatchingUp` with the new value.
#[test]
fn test_sync_progress_partial() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(catching_up(1, "r1", "node-b", 0.1))
.expect("ok");
mgr.update_sync_progress(1, "r1", 0.7);
let replicas = mgr.get_replicas(1);
let r1 = replicas.iter().find(|r| r.replica_id == "r1").expect("r1");
if let ReplicaState::CatchingUp { progress } = r1.state {
assert!((progress - 0.7).abs() < 1e-10);
} else {
panic!("Expected CatchingUp state");
}
}
// Two healthy replicas against a factor of 2: nothing to rebalance.
#[test]
fn test_needs_rebalancing_false_when_healthy() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
assert!(!mgr.needs_rebalancing());
}
// One healthy replica against a factor of 3: under-replicated.
#[test]
fn test_needs_rebalancing_true_when_under_replicated() {
let mut mgr = ReplicaManager::new(3);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
assert!(mgr.needs_rebalancing());
}
// Two healthy replicas against a factor of 1: over-replication also counts.
#[test]
fn test_needs_rebalancing_true_when_over_replicated() {
let mut mgr = ReplicaManager::new(1);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
assert!(mgr.needs_rebalancing());
}
// A shard that exactly meets the factor with no failures reports healthy.
#[test]
fn test_replication_status_healthy() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
let status = mgr.replication_status();
assert_eq!(status.total_shards, 1);
assert_eq!(status.under_replicated, 0);
assert_eq!(status.over_replicated, 0);
assert_eq!(status.failed_replicas, 0);
assert!(status.healthy);
}
// One failed replica out of three (factor 3) shows up both as a failure and
// as an under-replicated shard, since only two healthy replicas remain.
#[test]
fn test_replication_status_with_failures() {
let mut mgr = ReplicaManager::new(3);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.register_replica(replica(1, "r2", "node-c"))
.expect("ok");
mgr.mark_failed(1, "r2");
let status = mgr.replication_status();
assert!(!status.healthy);
assert_eq!(status.failed_replicas, 1);
assert_eq!(status.under_replicated, 1); }
// Shards are assessed independently: shard 1 meets the factor of 2, shard 2
// has a single replica. The same replica id may be reused across shards.
#[test]
fn test_replication_status_multiple_shards() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.register_replica(primary(2, "r0", "node-c"))
.expect("ok");
let status = mgr.replication_status();
assert_eq!(status.total_shards, 2);
assert_eq!(status.under_replicated, 1);
assert!(!status.healthy);
}
// Unregistering reports success and removes exactly the named replica.
#[test]
fn test_unregister_replica() {
let mut mgr = ReplicaManager::new(2);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
let removed = mgr.unregister_replica(1, "r1");
assert!(removed);
assert_eq!(mgr.get_replicas(1).len(), 1);
}
// Failed replicas are excluded from the healthy listing.
#[test]
fn test_get_healthy_replicas() {
let mut mgr = ReplicaManager::new(3);
mgr.register_replica(primary(1, "r0", "node-a"))
.expect("ok");
mgr.register_replica(replica(1, "r1", "node-b"))
.expect("ok");
mgr.register_replica(replica(1, "r2", "node-c"))
.expect("ok");
mgr.mark_failed(1, "r2");
let healthy = mgr.get_healthy_replicas(1);
assert_eq!(healthy.len(), 2);
assert!(healthy.iter().all(|r| r.state.is_healthy()));
}
}