use crate::core::error::{Error, Result, VectorError};
use crate::core::hasher::IdentityHasher;
use crate::core::id::NodeId;
use crate::core::vector::validate_vector;
use crate::index::vector::{DistanceMetric, Quantization, VectorIndex, merge_top_k_results};
use rayon::prelude::*;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{BuildHasherDefault, Hash, Hasher};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Largest `k` that `search` accepts; larger values are rejected with an error.
const MAX_K: usize = 100_000;
/// Multiplier applied to `k` by `search_with_filter` before the predicate is applied.
const FILTER_OVERFETCH_FACTOR: usize = 10;
/// Default `timeout` assigned by `VectorNodeConfig::new`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Default `CircuitBreakerConfig::failure_threshold`.
const DEFAULT_FAILURE_THRESHOLD: usize = 5;
/// Default `CircuitBreakerConfig::open_duration`.
const DEFAULT_OPEN_DURATION: Duration = Duration::from_secs(30);
/// Suggested imbalance ratio.
/// NOTE(review): presumably intended as the `threshold` argument of `needs_rebalancing` — confirm.
pub const RECOMMENDED_IMBALANCE_THRESHOLD: f64 = 2.0;
/// Errors specific to the distributed vector index.
///
/// NOTE(review): none of these variants is constructed in the visible part of
/// this file; the descriptions below follow the `Display` messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributedError {
/// The distributed index has no nodes available.
NoNodesAvailable,
/// A specific node is unavailable.
NodeUnavailable {
/// Identifier of the unavailable node.
node_id: u16,
/// Free-form explanation.
reason: String,
},
/// Every node failed during a query.
AllNodesFailed {
/// Number of nodes that failed.
failed_count: usize,
/// Text of one of the errors.
sample_error: String,
},
/// An operation timed out.
Timeout {
/// Name of the operation.
operation: String,
/// Duration reported in the message as the elapsed time before the timeout.
duration: Duration,
},
/// The circuit breaker for a node is open.
CircuitOpen {
/// Identifier of the affected node.
node_id: u16,
/// Time remaining for the open circuit (rendered in whole seconds by `Display`).
remaining: Duration,
},
/// Configuration problem described by the message.
ConfigError(String),
}
impl fmt::Display for DistributedError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
DistributedError::NoNodesAvailable => {
write!(f, "No nodes available in the distributed index")
}
DistributedError::NodeUnavailable { node_id, reason } => {
write!(f, "Node {} is unavailable: {}", node_id, reason)
}
DistributedError::AllNodesFailed {
failed_count,
sample_error,
} => {
write!(
f,
"All {} nodes failed during query. Sample error: {}",
failed_count, sample_error
)
}
DistributedError::Timeout {
operation,
duration,
} => {
write!(
f,
"Operation '{}' timed out after {:?}",
operation, duration
)
}
DistributedError::CircuitOpen { node_id, remaining } => {
write!(
f,
"Circuit breaker open for node {}, {} seconds remaining",
node_id,
remaining.as_secs()
)
}
DistributedError::ConfigError(msg) => {
write!(f, "Configuration error: {}", msg)
}
}
}
}
// Marker impl so `DistributedError` can be used wherever a `std::error::Error` is expected.
impl std::error::Error for DistributedError {}
/// Client-side interface to one node of the distributed vector index.
///
/// Implementations must be shareable across threads (`Send + Sync`) because
/// the index calls nodes from parallel iterators.
pub trait VectorNodeClient: Send + Sync + fmt::Debug {
    /// Identifier of the node this client talks to.
    fn node_id(&self) -> u16;

    /// Whether the node currently considers itself healthy.
    fn is_healthy(&self) -> bool;

    /// Stores `vector` under `id` on this node.
    fn add(&self, id: NodeId, vector: &[f32]) -> Result<()>;

    /// Removes the vector stored under `id` from this node.
    fn remove(&self, id: NodeId) -> Result<()>;

    /// Searches this node for `query`, returning `(id, score)` pairs.
    /// NOTE(review): presumably at most `k` results, best first — confirm per implementation.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>>;

    /// Number of vectors held by this node.
    fn len(&self) -> Result<usize>;

    /// `true` when [`len`](Self::len) reports zero; propagates its error.
    fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Probes the node, returning an error if the probe fails.
    fn health_check(&self) -> Result<()>;
}
/// State of a `NodeCircuitBreaker`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Requests are allowed; failures are being counted.
Closed,
/// Requests are rejected until the configured open duration has elapsed.
Open,
/// Requests are allowed on trial; successes are counted towards closing, a failure re-opens.
HalfOpen,
}
/// Tuning parameters for a `NodeCircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Failures recorded while closed (since the last success) that open the circuit.
pub failure_threshold: usize,
/// How long the circuit stays open before it may move to half-open.
pub open_duration: Duration,
/// Successes recorded while half-open that close the circuit again.
pub success_threshold: usize,
}
impl Default for CircuitBreakerConfig {
    /// Uses the module-level default failure threshold and open duration,
    /// and requires three half-open successes to close.
    fn default() -> Self {
        let failure_threshold = DEFAULT_FAILURE_THRESHOLD;
        let open_duration = DEFAULT_OPEN_DURATION;
        Self {
            failure_threshold,
            open_duration,
            success_threshold: 3,
        }
    }
}
/// Circuit breaker guarding calls to a single node.
#[derive(Debug)]
pub struct NodeCircuitBreaker {
// Thresholds and open duration.
config: CircuitBreakerConfig,
// Current breaker state.
state: RwLock<CircuitState>,
// Failures counted while closed; reset by a success in the closed state.
failure_count: AtomicUsize,
// Successes counted while half-open.
success_count: AtomicUsize,
// When the circuit was last opened, if it has been.
opened_at: RwLock<Option<Instant>>,
}
impl NodeCircuitBreaker {
    /// Creates a breaker in the `Closed` state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicUsize::new(0),
            success_count: AtomicUsize::new(0),
            opened_at: RwLock::new(None),
        }
    }

    /// Returns the current state, first promoting `Open` to `HalfOpen` when
    /// the open duration has elapsed.
    ///
    /// A poisoned state lock is reported as `Closed` (fail-open).
    pub fn state(&self) -> CircuitState {
        self.maybe_transition();
        self.read_state().unwrap_or(CircuitState::Closed)
    }

    /// Whether a request may be sent right now (`Closed` or `HalfOpen`).
    pub fn should_allow(&self) -> bool {
        matches!(
            self.state(),
            CircuitState::Closed | CircuitState::HalfOpen
        )
    }

    /// Records a successful call.
    ///
    /// In `Closed` this clears the failure counter. In `HalfOpen` it counts
    /// towards `success_threshold` and closes the circuit once reached.
    /// Does nothing if the state lock is poisoned.
    pub fn record_success(&self) {
        let Some(state) = self.read_state() else {
            return;
        };
        match state {
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let successes = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
                if successes >= self.config.success_threshold {
                    self.close();
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed call.
    ///
    /// In `Closed` this counts towards `failure_threshold` and opens the
    /// circuit once reached. In `HalfOpen` a single failure re-opens it.
    /// Does nothing if the state lock is poisoned.
    pub fn record_failure(&self) {
        let Some(state) = self.read_state() else {
            return;
        };
        match state {
            CircuitState::Closed => {
                let failures = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if failures >= self.config.failure_threshold {
                    self.trip();
                }
            }
            CircuitState::HalfOpen => {
                self.trip();
                self.success_count.store(0, Ordering::SeqCst);
            }
            CircuitState::Open => {}
        }
    }

    /// Moves `Open` to `HalfOpen` once the open duration has elapsed.
    fn maybe_transition(&self) {
        // Cheap pre-check without touching the state lock.
        if !self.open_window_elapsed() {
            return;
        }
        let Ok(mut state_guard) = self.state.write() else {
            return;
        };
        // Re-check under the state lock: `trip` refreshes the timestamp while
        // holding this lock, so a stale pre-check cannot half-open a circuit
        // that was opened a moment ago.
        if *state_guard == CircuitState::Open && self.open_window_elapsed() {
            *state_guard = CircuitState::HalfOpen;
            self.success_count.store(0, Ordering::SeqCst);
        }
    }

    /// Time left before an `Open` circuit may become `HalfOpen`.
    ///
    /// Returns `None` when the circuit is not `Open`, when the open duration
    /// has already elapsed, or when a lock is poisoned.
    pub fn remaining_open_time(&self) -> Option<Duration> {
        let state = self.state.read().ok()?;
        if *state != CircuitState::Open {
            return None;
        }
        if let Ok(opened) = self.opened_at.read()
            && let Some(opened_time) = *opened
        {
            let elapsed = opened_time.elapsed();
            if elapsed < self.config.open_duration {
                return Some(self.config.open_duration - elapsed);
            }
        }
        None
    }

    /// Forces the circuit back to `Closed` and clears counters and timestamp.
    pub fn reset(&self) {
        if let Ok(mut state) = self.state.write() {
            if let Ok(mut opened) = self.opened_at.write() {
                *opened = None;
            }
            *state = CircuitState::Closed;
        }
        self.failure_count.store(0, Ordering::SeqCst);
        self.success_count.store(0, Ordering::SeqCst);
    }

    /// Snapshot of the state, or `None` if the lock is poisoned.
    fn read_state(&self) -> Option<CircuitState> {
        self.state.read().ok().map(|guard| *guard)
    }

    /// Whether an open timestamp exists and is at least `open_duration` old.
    fn open_window_elapsed(&self) -> bool {
        self.opened_at
            .read()
            .ok()
            .and_then(|opened| *opened)
            .is_some_and(|opened_time| opened_time.elapsed() >= self.config.open_duration)
    }

    /// Opens the circuit. The timestamp is written while the state write lock
    /// is held (lock order: state, then opened_at) so both change together.
    fn trip(&self) {
        let Ok(mut state) = self.state.write() else {
            return;
        };
        if let Ok(mut opened) = self.opened_at.write() {
            *opened = Some(Instant::now());
        }
        *state = CircuitState::Open;
    }

    /// Closes a half-open circuit after enough successes.
    ///
    /// Clearing the timestamp keeps later `should_allow` calls off the state
    /// write lock. If another thread re-opened the circuit in the meantime,
    /// that newer decision is left in place.
    fn close(&self) {
        let Ok(mut state) = self.state.write() else {
            return;
        };
        if *state != CircuitState::HalfOpen {
            return;
        }
        if let Ok(mut opened) = self.opened_at.write() {
            *opened = None;
        }
        *state = CircuitState::Closed;
        self.failure_count.store(0, Ordering::SeqCst);
        self.success_count.store(0, Ordering::SeqCst);
    }
}
/// A node client paired with its circuit breaker and request counters.
pub struct NodeConnection<C: VectorNodeClient> {
// Shared handle to the node client.
client: Arc<C>,
// Breaker consulted before, and updated after, every `execute` call.
circuit_breaker: NodeCircuitBreaker,
// Number of `execute` calls, including those rejected by the breaker.
request_count: AtomicU64,
// Number of `execute` calls whose closure returned an error.
failure_count: AtomicU64,
}
impl<C: VectorNodeClient> fmt::Debug for NodeConnection<C> {
    /// Prints the node id, breaker state and both counters (not the client itself).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let node_id = self.client.node_id();
        let circuit_state = self.circuit_breaker.state();
        let requests = self.request_count.load(Ordering::Relaxed);
        let failures = self.failure_count.load(Ordering::Relaxed);
        let mut out = f.debug_struct("NodeConnection");
        out.field("node_id", &node_id);
        out.field("circuit_state", &circuit_state);
        out.field("request_count", &requests);
        out.field("failure_count", &failures);
        out.finish()
    }
}
impl<C: VectorNodeClient> NodeConnection<C> {
    /// Wraps `client` with a fresh circuit breaker built from `circuit_config`.
    pub fn new(client: Arc<C>, circuit_config: CircuitBreakerConfig) -> Self {
        let circuit_breaker = NodeCircuitBreaker::new(circuit_config);
        Self {
            client,
            circuit_breaker,
            request_count: AtomicU64::new(0),
            failure_count: AtomicU64::new(0),
        }
    }

    /// Identifier reported by the underlying client.
    pub fn node_id(&self) -> u16 {
        self.client.node_id()
    }

    /// `true` when the breaker allows requests and the client reports healthy.
    pub fn is_available(&self) -> bool {
        if !self.circuit_breaker.should_allow() {
            return false;
        }
        self.client.is_healthy()
    }

    /// Current breaker state.
    pub fn circuit_state(&self) -> CircuitState {
        self.circuit_breaker.state()
    }

    /// Runs `f` against the client, guarded by the circuit breaker.
    ///
    /// The request counter is bumped on every call. If the breaker rejects
    /// the call, `f` is not run and an `IndexError` is returned. Otherwise
    /// the outcome of `f` is recorded on the breaker and returned unchanged.
    pub fn execute<T, F>(&self, f: F) -> Result<T>
    where
        F: FnOnce(&C) -> Result<T>,
    {
        self.request_count.fetch_add(1, Ordering::Relaxed);
        if !self.circuit_breaker.should_allow() {
            return Err(Error::Vector(VectorError::IndexError(format!(
                "Circuit breaker open for node {}",
                self.client.node_id()
            ))));
        }
        let outcome = f(&self.client);
        if outcome.is_ok() {
            self.circuit_breaker.record_success();
        } else {
            self.failure_count.fetch_add(1, Ordering::Relaxed);
            self.circuit_breaker.record_failure();
        }
        outcome
    }

    /// Snapshot of this connection's id, breaker state and counters.
    pub fn stats(&self) -> NodeConnectionStats {
        let node_id = self.client.node_id();
        let circuit_state = self.circuit_breaker.state();
        let request_count = self.request_count.load(Ordering::Relaxed);
        let failure_count = self.failure_count.load(Ordering::Relaxed);
        NodeConnectionStats {
            node_id,
            circuit_state,
            request_count,
            failure_count,
        }
    }

    /// Resets the circuit breaker to its closed state.
    pub fn reset_circuit(&self) {
        self.circuit_breaker.reset();
    }
}
/// Point-in-time counters for one `NodeConnection`, as produced by `NodeConnection::stats`.
#[derive(Debug, Clone)]
pub struct NodeConnectionStats {
/// Identifier of the node.
pub node_id: u16,
/// Breaker state when the snapshot was taken.
pub circuit_state: CircuitState,
/// Total `execute` calls so far.
pub request_count: u64,
/// `execute` calls whose closure returned an error.
pub failure_count: u64,
}
impl NodeConnectionStats {
    /// Fraction of requests that did not fail, as `1 - failures / requests`.
    ///
    /// Returns `1.0` when no request has been made yet.
    pub fn success_rate(&self) -> f64 {
        match self.request_count {
            0 => 1.0,
            total => 1.0 - (self.failure_count as f64 / total as f64),
        }
    }
}
/// Static configuration for one node of the distributed index.
#[derive(Debug, Clone)]
pub struct VectorNodeConfig {
/// Identifier of the node; must be unique within a `DistributedVectorConfig`.
pub node_id: u16,
/// Endpoint string for the node.
/// NOTE(review): not read anywhere in the visible code; format (e.g. `host:port`) unconfirmed.
pub endpoint: String,
/// Timeout for the node.
/// NOTE(review): not enforced anywhere in the visible code — confirm where it is applied.
pub timeout: Duration,
/// Circuit-breaker settings used for this node's connection.
pub circuit_breaker: CircuitBreakerConfig,
}
impl VectorNodeConfig {
    /// Creates a node configuration with the default timeout and default
    /// circuit-breaker settings.
    pub fn new(node_id: u16, endpoint: impl Into<String>) -> Self {
        let endpoint: String = endpoint.into();
        Self {
            node_id,
            endpoint,
            timeout: DEFAULT_TIMEOUT,
            circuit_breaker: CircuitBreakerConfig::default(),
        }
    }

    /// Builder-style setter for `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Builder-style setter for `circuit_breaker`.
    pub fn with_circuit_breaker(self, config: CircuitBreakerConfig) -> Self {
        Self {
            circuit_breaker: config,
            ..self
        }
    }
}
/// How a vector id is mapped to the node that stores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingStrategy {
/// Hash the id and take it modulo the node count (the default).
#[default]
HashBased,
/// Split the `u64` id space into equal contiguous ranges, one per node.
RangeBased,
}
/// Configuration for a `DistributedVectorIndex`.
#[derive(Debug, Clone)]
pub struct DistributedVectorConfig {
/// Required length of every stored vector and every query; must be non-zero.
pub dimensions: usize,
/// Distance metric reported by the index.
pub metric: DistanceMetric,
/// Per-node configuration; must be non-empty with unique node ids.
pub nodes: Vec<VectorNodeConfig>,
/// Strategy used to pick the node for an id.
pub routing_strategy: RoutingStrategy,
/// `search` fails when fewer than this many nodes are available.
pub min_nodes_for_search: usize,
/// When `false`, `search` fails if any queried node returns an error.
pub allow_partial_results: bool,
}
impl DistributedVectorConfig {
pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
Self {
dimensions,
metric,
nodes: Vec::new(),
routing_strategy: RoutingStrategy::default(),
min_nodes_for_search: 1,
allow_partial_results: true,
}
}
pub fn with_node(mut self, node: VectorNodeConfig) -> Self {
self.nodes.push(node);
self
}
pub fn with_routing_strategy(mut self, strategy: RoutingStrategy) -> Self {
self.routing_strategy = strategy;
self
}
pub fn with_min_nodes_for_search(mut self, min_nodes: usize) -> Self {
self.min_nodes_for_search = min_nodes;
self
}
pub fn with_allow_partial_results(mut self, allow: bool) -> Self {
self.allow_partial_results = allow;
self
}
pub fn validate(&self) -> Result<()> {
if self.dimensions == 0 {
return Err(Error::Vector(VectorError::InvalidVector {
reason: "dimensions must be > 0".to_string(),
}));
}
if self.nodes.is_empty() {
return Err(Error::Vector(VectorError::IndexError(
"At least one node must be configured".to_string(),
)));
}
let mut seen_ids = std::collections::HashSet::new();
for node in &self.nodes {
if !seen_ids.insert(node.node_id) {
return Err(Error::Vector(VectorError::IndexError(format!(
"Duplicate node ID: {}",
node.node_id
))));
}
}
Ok(())
}
}
/// Aggregate statistics returned by `DistributedVectorIndex::stats`.
#[derive(Debug, Clone)]
pub struct DistributedIndexStats {
/// Sum of the vector counts reported by the currently available nodes.
pub total_vectors: usize,
/// Number of configured nodes.
pub node_count: usize,
/// Number of nodes that were available when the snapshot was taken.
pub available_nodes: usize,
/// Per-node connection statistics, in node order.
pub node_stats: Vec<NodeConnectionStats>,
}
/// Load-distribution summary returned by `DistributedVectorIndex::rebalance_stats`.
///
/// Only nodes that were available and answered the size query are included.
#[derive(Debug, Clone)]
pub struct RebalanceStats {
/// Sum of the reported node sizes.
pub total_vectors: usize,
/// Number of nodes that reported a size.
pub node_count: usize,
/// Smallest reported node size (0 when no node reported).
pub min_node_size: usize,
/// Largest reported node size (0 when no node reported).
pub max_node_size: usize,
/// `max / min`; infinity when the smallest node is empty but another is not; 1.0 when all are empty.
pub imbalance_ratio: f64,
/// Total number of vectors above the per-node average (integer division) across all nodes.
pub vectors_to_move: usize,
/// `(node_id, size)` for every node that reported.
pub node_sizes: Vec<(u16, usize)>,
}
/// Vector index whose vectors are spread across several nodes, each reached through a client `C`.
pub struct DistributedVectorIndex<C: VectorNodeClient> {
// Validated configuration (dimensions, metric, routing, search policy).
config: DistributedVectorConfig,
// One connection per configured node, in configuration order.
nodes: Vec<Arc<NodeConnection<C>>>,
}
impl<C: VectorNodeClient> fmt::Debug for DistributedVectorIndex<C> {
    /// Prints a summary (dimensions, metric, node count, routing strategy)
    /// rather than the individual connections.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cfg = &self.config;
        let node_count = self.nodes.len();
        let mut out = f.debug_struct("DistributedVectorIndex");
        out.field("dimensions", &cfg.dimensions);
        out.field("metric", &cfg.metric);
        out.field("node_count", &node_count);
        out.field("routing_strategy", &cfg.routing_strategy);
        out.finish()
    }
}
impl<C: VectorNodeClient + 'static> DistributedVectorIndex<C> {
    /// Builds an index from a configuration and one client per configured node.
    ///
    /// Clients are paired with `config.nodes` by position.
    ///
    /// # Errors
    /// Fails when the configuration is invalid or the number of clients
    /// differs from the number of configured nodes.
    pub fn new(config: DistributedVectorConfig, clients: Vec<Arc<C>>) -> Result<Self> {
        config.validate()?;
        if clients.len() != config.nodes.len() {
            return Err(Error::Vector(VectorError::IndexError(format!(
                "Number of clients ({}) doesn't match number of configured nodes ({})",
                clients.len(),
                config.nodes.len()
            ))));
        }
        let mut nodes: Vec<Arc<NodeConnection<C>>> = Vec::with_capacity(clients.len());
        for (client, node_config) in clients.into_iter().zip(config.nodes.iter()) {
            let breaker_config = node_config.circuit_breaker.clone();
            nodes.push(Arc::new(NodeConnection::new(client, breaker_config)));
        }
        Ok(Self { config, nodes })
    }

    /// Number of configured nodes.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of nodes that are currently available.
    pub fn available_node_count(&self) -> usize {
        self.nodes
            .iter()
            .filter(|node| node.is_available())
            .count()
    }

    /// Routing strategy taken from the configuration.
    pub fn routing_strategy(&self) -> RoutingStrategy {
        self.config.routing_strategy
    }

    /// Index (into `self.nodes`) of the node responsible for `id`.
    ///
    /// NOTE(review): the hash-based route uses `DefaultHasher`, whose
    /// algorithm the standard library does not promise to keep stable across
    /// Rust releases — confirm that is acceptable for data placement.
    fn node_for_id(&self, id: NodeId) -> usize {
        debug_assert!(!self.nodes.is_empty(), "nodes cannot be empty");
        let node_total = self.nodes.len();
        let raw_id = id.as_u64();
        match self.config.routing_strategy {
            RoutingStrategy::HashBased => {
                let mut hasher = DefaultHasher::new();
                raw_id.hash(&mut hasher);
                (hasher.finish() as usize) % node_total
            }
            RoutingStrategy::RangeBased => {
                // floor(id * n / 2^64): scale the id onto [0, n).
                let scaled = (raw_id as u128 * node_total as u128) >> u64::BITS;
                (scaled as usize).min(node_total - 1)
            }
        }
    }

    /// Connection at position `index`, if any.
    pub fn get_node(&self, index: usize) -> Option<&Arc<NodeConnection<C>>> {
        self.nodes.get(index)
    }

    /// Collects per-node statistics, the available-node count and the total
    /// vector count (which queries every available node).
    pub fn stats(&self) -> DistributedIndexStats {
        // Per-node stats are captured before `len()` issues its own requests.
        let node_stats: Vec<NodeConnectionStats> =
            self.nodes.iter().map(|node| node.stats()).collect();
        let available_nodes = self.available_node_count();
        let total_vectors = self.len();
        let node_count = self.nodes.len();
        DistributedIndexStats {
            total_vectors,
            node_count,
            available_nodes,
            node_stats,
        }
    }

    /// Resets the circuit breaker of every node.
    pub fn reset_all_circuits(&self) {
        self.nodes.iter().for_each(|node| node.reset_circuit());
    }

    /// Whether the largest responding node holds more than `threshold` times
    /// as many vectors as the smallest.
    ///
    /// Returns `false` when fewer than two nodes report a size. If the
    /// smallest node is empty, any non-empty node counts as imbalance.
    pub fn needs_rebalancing(&self, threshold: f64) -> bool {
        let sizes: Vec<usize> = self
            .nodes
            .par_iter()
            .filter(|node| node.is_available())
            .filter_map(|node| node.execute(|client| client.len()).ok())
            .collect();
        if sizes.len() < 2 {
            return false;
        }
        let smallest = sizes.iter().copied().min().unwrap_or(0);
        let largest = sizes.iter().copied().max().unwrap_or(0);
        match smallest {
            0 => largest > 0,
            _ => (largest as f64 / smallest as f64) > threshold,
        }
    }

    /// Summarises how vectors are distributed over the responding nodes.
    pub fn rebalance_stats(&self) -> RebalanceStats {
        let node_sizes: Vec<(u16, usize)> = self
            .nodes
            .par_iter()
            .filter(|node| node.is_available())
            .filter_map(|node| {
                node.execute(|client| client.len().map(|len| (client.node_id(), len)))
                    .ok()
            })
            .collect();
        let node_count = node_sizes.len();
        let total_vectors: usize = node_sizes.iter().map(|&(_, size)| size).sum();
        let min_node_size = node_sizes.iter().map(|&(_, size)| size).min().unwrap_or(0);
        let max_node_size = node_sizes.iter().map(|&(_, size)| size).max().unwrap_or(0);
        let imbalance_ratio = match (min_node_size, max_node_size) {
            (0, 0) => 1.0,
            (0, _) => f64::INFINITY,
            (low, high) => high as f64 / low as f64,
        };
        // Integer average; zero when no node responded.
        let target_per_node = total_vectors.checked_div(node_count).unwrap_or(0);
        // Nodes at or below the target contribute nothing.
        let vectors_to_move: usize = node_sizes
            .iter()
            .map(|&(_, size)| size.saturating_sub(target_per_node))
            .sum();
        RebalanceStats {
            total_vectors,
            node_count,
            min_node_size,
            max_node_size,
            imbalance_ratio,
            vectors_to_move,
            node_sizes,
        }
    }

    /// Merges per-node result lists into one list of at most `k` entries by
    /// delegating to `merge_top_k_results`.
    fn merge_results(node_results: Vec<Vec<(NodeId, f32)>>, k: usize) -> Vec<(NodeId, f32)> {
        merge_top_k_results(node_results, k)
    }
}
impl<C: VectorNodeClient + 'static> VectorIndex for DistributedVectorIndex<C> {
fn add(&self, id: NodeId, vector: &[f32]) -> Result<()> {
validate_vector(vector)?;
if vector.len() != self.config.dimensions {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: self.config.dimensions,
actual: vector.len(),
}));
}
let node_idx = self.node_for_id(id);
self.nodes[node_idx].execute(|client| client.add(id, vector))
}
fn remove(&self, id: NodeId) -> Result<()> {
let node_idx = self.node_for_id(id);
self.nodes[node_idx].execute(|client| client.remove(id))
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
validate_vector(query)?;
if query.len() != self.config.dimensions {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: self.config.dimensions,
actual: query.len(),
}));
}
if k > MAX_K {
return Err(Error::Vector(VectorError::IndexError(format!(
"k={} exceeds maximum allowed value of {}",
k, MAX_K
))));
}
let available_nodes: Vec<_> = self.nodes.iter().filter(|n| n.is_available()).collect();
if available_nodes.len() < self.config.min_nodes_for_search {
return Err(Error::Vector(VectorError::IndexError(format!(
"Not enough nodes available: {} < {}",
available_nodes.len(),
self.config.min_nodes_for_search
))));
}
let results: Vec<Result<Vec<(NodeId, f32)>>> = available_nodes
.par_iter()
.map(|node| node.execute(|client| client.search(query, k)))
.collect();
let mut successful_results = Vec::with_capacity(results.len()); let mut failed_count = 0;
let mut sample_error = String::new();
for result in results {
match result {
Ok(r) => successful_results.push(r),
Err(e) => {
failed_count += 1;
if sample_error.is_empty() {
sample_error = e.to_string();
}
}
}
}
if successful_results.is_empty() {
return Err(Error::Vector(VectorError::IndexError(format!(
"All nodes failed: {}",
sample_error
))));
}
if !self.config.allow_partial_results && failed_count > 0 {
return Err(Error::Vector(VectorError::IndexError(format!(
"{} nodes failed during search",
failed_count
))));
}
Ok(Self::merge_results(successful_results, k))
}
fn search_with_filter<F>(
&self,
query: &[f32],
k: usize,
predicate: F,
) -> Result<Vec<(NodeId, f32)>>
where
F: Fn(&NodeId) -> bool + Send + Sync,
{
let results = self.search(query, k.saturating_mul(FILTER_OVERFETCH_FACTOR))?;
let filtered: Vec<(NodeId, f32)> = results
.into_iter()
.filter(|(id, _)| predicate(id))
.take(k)
.collect();
Ok(filtered)
}
fn len(&self) -> usize {
self.nodes
.par_iter()
.filter(|n| n.is_available())
.filter_map(|node| node.execute(|client| client.len()).ok())
.sum()
}
fn dimensions(&self) -> usize {
self.config.dimensions
}
fn distance_metric(&self) -> DistanceMetric {
self.config.metric
}
fn memory_usage(&self) -> usize {
std::mem::size_of::<Self>()
+ self.nodes.len() * std::mem::size_of::<Arc<NodeConnection<C>>>()
}
fn quantization(&self) -> Quantization {
Quantization::F32 }
}
/// In-memory `VectorNodeClient` with switchable health and one-shot failure injection.
#[derive(Debug)]
pub struct MockVectorNodeClient {
// Identifier reported by `node_id()`.
node_id: u16,
// Health flag toggled by `set_healthy`.
healthy: AtomicBool,
// Stored vectors keyed by id.
vectors:
RwLock<std::collections::HashMap<NodeId, Vec<f32>, BuildHasherDefault<IdentityHasher>>>,
// Required length of vectors passed to `add`.
dimensions: usize,
// Metric used to score search results.
metric: DistanceMetric,
// Error message to return from the next operation, consumed when used.
fail_next: RwLock<Option<String>>,
}
impl MockVectorNodeClient {
pub fn new(node_id: u16, dimensions: usize, metric: DistanceMetric) -> Self {
Self {
node_id,
healthy: AtomicBool::new(true),
vectors: RwLock::new(std::collections::HashMap::with_hasher(
BuildHasherDefault::default(),
)),
dimensions,
metric,
fail_next: RwLock::new(None),
}
}
pub fn set_healthy(&self, healthy: bool) {
self.healthy.store(healthy, Ordering::SeqCst);
}
pub fn fail_next(&self, error: impl Into<String>) {
*self.fail_next.write().unwrap() = Some(error.into());
}
fn check_fail(&self) -> Result<()> {
if let Some(err) = self.fail_next.write().unwrap().take() {
return Err(Error::Vector(VectorError::IndexError(err)));
}
Ok(())
}
fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
match self.metric {
DistanceMetric::Cosine => {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if mag_a > 0.0 && mag_b > 0.0 {
dot / (mag_a * mag_b)
} else {
0.0
}
}
DistanceMetric::Euclidean => {
let dist: f32 = a
.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt();
-dist }
DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
other => panic!(
"MockVectorNodeClient does not support {:?} distance metric",
other
),
}
}
}
impl VectorNodeClient for MockVectorNodeClient {
fn node_id(&self) -> u16 {
self.node_id
}
fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::SeqCst)
}
fn add(&self, id: NodeId, vector: &[f32]) -> Result<()> {
self.check_fail()?;
if !self.is_healthy() {
return Err(Error::Vector(VectorError::IndexError(format!(
"Node {} is unavailable",
self.node_id
))));
}
if vector.len() != self.dimensions {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: self.dimensions,
actual: vector.len(),
}));
}
self.vectors.write().unwrap().insert(id, vector.to_vec());
Ok(())
}
fn remove(&self, id: NodeId) -> Result<()> {
self.check_fail()?;
if !self.is_healthy() {
return Err(Error::Vector(VectorError::IndexError(format!(
"Node {} is unavailable",
self.node_id
))));
}
self.vectors.write().unwrap().remove(&id);
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
self.check_fail()?;
if !self.is_healthy() {
return Err(Error::Vector(VectorError::IndexError(format!(
"Node {} is unavailable",
self.node_id
))));
}
let vectors = self.vectors.read().unwrap();
let mut results: Vec<(NodeId, f32)> = vectors
.iter()
.map(|(id, vec)| (*id, self.compute_similarity(query, vec)))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(k);
Ok(results)
}
fn len(&self) -> Result<usize> {
self.check_fail()?;
if !self.is_healthy() {
return Err(Error::Vector(VectorError::IndexError(format!(
"Node {} is unavailable",
self.node_id
))));
}
Ok(self.vectors.read().unwrap().len())
}
fn health_check(&self) -> Result<()> {
self.check_fail()?;
if self.is_healthy() {
Ok(())
} else {
Err(Error::Vector(VectorError::IndexError(format!(
"Node {} is unavailable",
self.node_id
))))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a 4-dimensional cosine config with `num_nodes` nodes numbered from 0.
fn create_test_config(num_nodes: usize) -> DistributedVectorConfig {
    (0..num_nodes).fold(
        DistributedVectorConfig::new(4, DistanceMetric::Cosine),
        |config, i| config.with_node(VectorNodeConfig::new(i as u16, format!("node{}:9000", i))),
    )
}
/// Builds `num_nodes` healthy 4-dimensional cosine mock clients numbered from 0.
fn create_test_clients(num_nodes: usize) -> Vec<Arc<MockVectorNodeClient>> {
    let mut clients = Vec::with_capacity(num_nodes);
    for i in 0..num_nodes {
        let client = MockVectorNodeClient::new(i as u16, 4, DistanceMetric::Cosine);
        clients.push(Arc::new(client));
    }
    clients
}
#[test]
fn test_config_creation() {
    // A fresh config keeps the given dimensions/metric and starts without nodes.
    let fresh = DistributedVectorConfig::new(384, DistanceMetric::Cosine);
    assert!(fresh.nodes.is_empty());
    assert_eq!(fresh.metric, DistanceMetric::Cosine);
    assert_eq!(fresh.dimensions, 384);
}
#[test]
fn test_config_with_nodes() {
    // Nodes are appended in call order.
    let first = VectorNodeConfig::new(0, "node0:9000");
    let second = VectorNodeConfig::new(1, "node1:9000");
    let built = DistributedVectorConfig::new(384, DistanceMetric::Cosine)
        .with_node(first)
        .with_node(second);
    let ids: Vec<u16> = built.nodes.iter().map(|node| node.node_id).collect();
    assert_eq!(ids, vec![0, 1]);
}
#[test]
fn test_config_validation_zero_dimensions() {
    // Zero dimensions must be rejected even when a node is configured.
    let node = VectorNodeConfig::new(0, "node0:9000");
    let invalid = DistributedVectorConfig::new(0, DistanceMetric::Cosine).with_node(node);
    let outcome = invalid.validate();
    assert!(outcome.is_err());
}
#[test]
fn test_config_validation_no_nodes() {
    // A config without any node must be rejected.
    let outcome = DistributedVectorConfig::new(384, DistanceMetric::Cosine).validate();
    assert!(outcome.is_err());
}
#[test]
fn test_config_validation_duplicate_nodes() {
    // Two nodes sharing id 0 must be rejected.
    let mut invalid = DistributedVectorConfig::new(384, DistanceMetric::Cosine);
    for endpoint in ["node0:9000", "node1:9000"] {
        invalid = invalid.with_node(VectorNodeConfig::new(0, endpoint));
    }
    assert!(invalid.validate().is_err());
}
#[test]
fn test_create_distributed_index() -> Result<()> {
    // Construction succeeds and exposes the configured shape.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    assert_eq!(index.distance_metric(), DistanceMetric::Cosine);
    assert_eq!(index.dimensions(), 4);
    assert_eq!(index.node_count(), 3);
    Ok(())
}
#[test]
fn test_create_mismatched_clients() {
    // Three configured nodes but only two clients must fail.
    let three_nodes = create_test_config(3);
    let two_clients = create_test_clients(2);
    assert!(DistributedVectorIndex::new(three_nodes, two_clients).is_err());
}
#[test]
fn test_add_vector() -> Result<()> {
    // One added vector is reflected in the total length.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let id = NodeId::new(1).unwrap();
    let unit_x = vec![1.0, 0.0, 0.0, 0.0];
    index.add(id, &unit_x)?;
    assert_eq!(index.len(), 1);
    Ok(())
}
#[test]
fn test_add_multiple_vectors() -> Result<()> {
    // A hundred distinct ids end up as a hundred stored vectors.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    for i in 1..=100 {
        let id = NodeId::new(i).unwrap();
        let vector = vec![i as f32, 0.0, 0.0, 0.0];
        index.add(id, &vector)?;
    }
    let stored = index.len();
    assert_eq!(stored, 100);
    Ok(())
}
#[test]
fn test_add_dimension_mismatch() {
    // A 2-element vector does not fit a 4-dimensional index.
    let index =
        DistributedVectorIndex::new(create_test_config(3), create_test_clients(3)).unwrap();
    let id = NodeId::new(1).unwrap();
    let too_short = vec![1.0, 0.0];
    let outcome = index.add(id, &too_short);
    assert!(outcome.is_err());
}
#[test]
fn test_add_with_nan() {
    // Vectors containing NaN are rejected by validation.
    let index =
        DistributedVectorIndex::new(create_test_config(3), create_test_clients(3)).unwrap();
    let id = NodeId::new(1).unwrap();
    let poisoned = vec![1.0, f32::NAN, 0.0, 0.0];
    let outcome = index.add(id, &poisoned);
    assert!(outcome.is_err());
}
#[test]
fn test_remove_vector() -> Result<()> {
    // Adding then removing the same id returns the index to empty.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let id = NodeId::new(1).unwrap();
    let unit_x = vec![1.0, 0.0, 0.0, 0.0];

    index.add(id, &unit_x)?;
    assert_eq!(index.len(), 1);

    index.remove(id)?;
    assert_eq!(index.len(), 0);
    Ok(())
}
#[test]
fn test_search_empty_index() -> Result<()> {
    // Searching an empty index succeeds with no hits.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let hits = index.search(&query, 10)?;
    assert!(hits.is_empty());
    Ok(())
}
#[test]
fn test_search_basic() -> Result<()> {
    // The vector identical to the query must rank first.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let exact = NodeId::new(1).unwrap();
    let close = NodeId::new(2).unwrap();
    let orthogonal = NodeId::new(3).unwrap();
    index.add(exact, &[1.0, 0.0, 0.0, 0.0])?;
    index.add(close, &[0.9, 0.1, 0.0, 0.0])?;
    index.add(orthogonal, &[0.0, 1.0, 0.0, 0.0])?;

    let query = vec![1.0, 0.0, 0.0, 0.0];
    let hits = index.search(&query, 3)?;
    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].0, exact);
    Ok(())
}
#[test]
fn test_search_dimension_mismatch() {
    // A 2-element query does not fit a 4-dimensional index.
    let index =
        DistributedVectorIndex::new(create_test_config(3), create_test_clients(3)).unwrap();
    let id = NodeId::new(1).unwrap();
    index.add(id, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    let too_short = vec![1.0, 0.0];
    let outcome = index.search(&too_short, 10);
    assert!(outcome.is_err());
}
#[test]
fn test_search_with_filter() -> Result<()> {
    // Only ids accepted by the predicate (even ids) may be returned.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    for i in 1..=10 {
        let id = NodeId::new(i).unwrap();
        index.add(id, &[i as f32, 0.0, 0.0, 0.0])?;
    }
    let query = vec![5.0, 0.0, 0.0, 0.0];
    let hits = index.search_with_filter(&query, 10, |id| id.as_u64() % 2 == 0)?;
    for (id, _score) in hits.iter() {
        assert_eq!(id.as_u64() % 2, 0);
    }
    Ok(())
}
#[test]
fn test_consistent_routing() -> Result<()> {
    // Routing the same id repeatedly always yields the same node.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let id = NodeId::new(42).unwrap();
    let first = index.node_for_id(id);
    let second = index.node_for_id(id);
    let third = index.node_for_id(id);
    assert_eq!(first, second);
    assert_eq!(second, third);
    Ok(())
}
#[test]
fn test_range_based_routing() -> Result<()> {
    // Ids at both ends of the id space must map to a valid node index.
    let mut config = create_test_config(3);
    config.routing_strategy = RoutingStrategy::RangeBased;
    let index = DistributedVectorIndex::new(config, create_test_clients(3))?;

    let lowest = NodeId::new(0).unwrap();
    let highest = NodeId::new(u64::MAX - 1000).unwrap();
    let low_route = index.node_for_id(lowest);
    let high_route = index.node_for_id(highest);
    assert!(low_route < 3);
    assert!(high_route < 3);
    Ok(())
}
#[test]
fn test_circuit_breaker_opens_on_failures() {
    // The breaker stays closed below the threshold and opens on reaching it.
    let breaker = NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 3,
        open_duration: Duration::from_millis(100),
        success_threshold: 2,
    });
    assert_eq!(breaker.state(), CircuitState::Closed);

    breaker.record_failure();
    breaker.record_failure();
    assert_eq!(breaker.state(), CircuitState::Closed);

    breaker.record_failure();
    assert_eq!(breaker.state(), CircuitState::Open);
    assert!(!breaker.should_allow());
}
#[test]
fn test_circuit_breaker_half_open_transition() {
    // After the open duration elapses the breaker reports half-open.
    let breaker = NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 1,
        open_duration: Duration::from_millis(10),
        success_threshold: 1,
    });
    breaker.record_failure();
    assert_eq!(breaker.state(), CircuitState::Open);

    std::thread::sleep(Duration::from_millis(20));
    assert_eq!(breaker.state(), CircuitState::HalfOpen);
}
#[test]
fn test_circuit_breaker_closes_from_half_open() {
    // Two half-open successes are needed to close; one is not enough.
    let breaker = NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 1,
        open_duration: Duration::from_millis(10),
        success_threshold: 2,
    });
    breaker.record_failure();
    std::thread::sleep(Duration::from_millis(20));
    assert_eq!(breaker.state(), CircuitState::HalfOpen);

    breaker.record_success();
    assert_eq!(breaker.state(), CircuitState::HalfOpen);

    breaker.record_success();
    assert_eq!(breaker.state(), CircuitState::Closed);
}
#[test]
fn test_search_with_unhealthy_node() -> Result<()> {
    // Search still succeeds when one of three nodes is unhealthy.
    let clients = create_test_clients(3);
    clients[1].set_healthy(false);
    let index = DistributedVectorIndex::new(create_test_config(3), clients)?;
    for i in 1..=10 {
        let id = NodeId::new(i).unwrap();
        // Adds routed to the unhealthy node are expected to fail; ignore them.
        let _ = index.add(id, &[i as f32, 0.0, 0.0, 0.0]);
    }
    let query = vec![5.0, 0.0, 0.0, 0.0];
    let hits = index.search(&query, 10)?;
    assert!(!hits.is_empty() || index.len() == 0);
    Ok(())
}
#[test]
fn test_search_all_nodes_unavailable() {
    // With every node unhealthy, search must fail.
    let mut config = create_test_config(3);
    config.min_nodes_for_search = 1;
    let clients = create_test_clients(3);
    clients.iter().for_each(|client| client.set_healthy(false));
    let index = DistributedVectorIndex::new(config, clients).unwrap();
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let outcome = index.search(&query, 10);
    assert!(outcome.is_err());
}
#[test]
fn test_stats() -> Result<()> {
    // A fresh three-node index reports three nodes, all available.
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let snapshot = index.stats();
    assert_eq!(snapshot.node_count, 3);
    assert_eq!(snapshot.available_nodes, 3);
    assert_eq!(snapshot.node_stats.len(), 3);
    Ok(())
}
#[test]
fn test_node_connection_stats() -> Result<()> {
    // Two adds on a single-node index are counted as at least two requests.
    let index = DistributedVectorIndex::new(create_test_config(1), create_test_clients(1))?;
    let first = NodeId::new(1).unwrap();
    let second = NodeId::new(2).unwrap();
    index.add(first, &[1.0, 0.0, 0.0, 0.0])?;
    index.add(second, &[0.0, 1.0, 0.0, 0.0])?;
    let snapshot = index.stats();
    assert!(snapshot.node_stats[0].request_count >= 2);
    Ok(())
}
#[test]
fn test_merge_results_empty() {
    // Merging no per-node lists yields no results.
    let merged = DistributedVectorIndex::<MockVectorNodeClient>::merge_results(Vec::new(), 10);
    assert!(merged.is_empty());
}
#[test]
fn test_merge_results_single_node() {
    // With one input list, merging keeps the two best entries in order.
    let first = NodeId::new(1).unwrap();
    let second = NodeId::new(2).unwrap();
    let third = NodeId::new(3).unwrap();
    let node_results = vec![vec![(first, 0.9), (second, 0.8), (third, 0.7)]];
    let merged = DistributedVectorIndex::<MockVectorNodeClient>::merge_results(node_results, 2);
    assert_eq!(merged.len(), 2);
    assert_eq!(merged[0].0, first);
    assert_eq!(merged[1].0, second);
}
#[test]
fn test_merge_results_multiple_nodes() {
    // Entries from two nodes are interleaved by descending score.
    let id = |raw| NodeId::new(raw).unwrap();
    let from_node_a = vec![(id(1), 0.9), (id(2), 0.7)];
    let from_node_b = vec![(id(3), 0.85), (id(4), 0.6)];
    let merged = DistributedVectorIndex::<MockVectorNodeClient>::merge_results(
        vec![from_node_a, from_node_b],
        3,
    );
    assert_eq!(merged.len(), 3);
    assert_eq!(merged[0].0, id(1));
    assert_eq!(merged[1].0, id(3));
    assert_eq!(merged[2].0, id(2));
}
#[test]
fn test_mock_client_add_search() -> Result<()> {
    // A stored vector is found by an identical query with cosine score ~1.
    let client = MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine);
    let id = NodeId::new(1).unwrap();
    let unit_x = [1.0, 0.0, 0.0, 0.0];
    client.add(id, &unit_x)?;
    let hits = client.search(&unit_x, 10)?;
    assert_eq!(hits.len(), 1);
    let (found_id, score) = hits[0];
    assert_eq!(found_id, id);
    assert!((score - 1.0).abs() < 0.001);
    Ok(())
}
#[test]
fn test_mock_client_fail_next() {
    // An injected failure affects exactly one operation.
    let client = MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine);
    client.fail_next("Test error");
    let failing = client.add(NodeId::new(1).unwrap(), &[1.0, 0.0, 0.0, 0.0]);
    assert!(failing.is_err());
    let recovering = client.add(NodeId::new(2).unwrap(), &[0.0, 1.0, 0.0, 0.0]);
    assert!(recovering.is_ok());
}
/// A mock client marked unhealthy must reject `add`.
#[test]
fn test_mock_client_unhealthy() {
    let mock = MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine);
    mock.set_healthy(false);

    let outcome = mock.add(NodeId::new(1).unwrap(), &[1.0, 0.0, 0.0, 0.0]);
    assert!(outcome.is_err());
}
/// An index that holds no vectors must not report a need for rebalancing at a
/// threshold of 2.0.
#[test]
fn test_needs_rebalancing_empty() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;

    let wants_rebalance = index.needs_rebalancing(2.0);
    assert!(!wants_rebalance);
    Ok(())
}
/// Inserts 30 vectors into a three-node index and checks that the rebalance
/// statistics account for all of them.
///
/// NOTE(review): despite its name, this test never calls `needs_rebalancing`.
/// Whether 30 routed vectors end up "balanced" depends on the routing
/// strategy of the test config, which is not visible here, so no such
/// assertion is made — confirm and extend if the distribution is deterministic.
#[test]
fn test_needs_rebalancing_balanced() -> Result<()> {
    let config = create_test_config(3);
    let clients = create_test_clients(3);
    let index = DistributedVectorIndex::new(config, clients)?;
    for i in 1..=30 {
        let node = NodeId::new(i).unwrap();
        index.add(node, &[i as f32, 0.0, 0.0, 0.0])?;
    }
    let stats = index.rebalance_stats();
    // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
    assert_eq!(stats.total_vectors, 30);
    Ok(())
}
/// Inserts 15 vectors into a three-node index and sanity-checks the rebalance
/// statistics: totals, node count, min/max ordering and an imbalance ratio
/// of at least 1.0.
#[test]
fn test_rebalance_stats() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;

    for raw_id in 1..=15 {
        let id = NodeId::new(raw_id).unwrap();
        let vector = [raw_id as f32, 0.0, 0.0, 0.0];
        index.add(id, &vector)?;
    }

    let rebalance = index.rebalance_stats();
    assert_eq!(rebalance.total_vectors, 15);
    assert_eq!(rebalance.node_count, 3);
    assert!(rebalance.min_node_size <= rebalance.max_node_size);
    assert!(rebalance.imbalance_ratio >= 1.0);
    Ok(())
}
/// `VectorNodeConfig::new` must store the given id and endpoint and fall back
/// to `DEFAULT_TIMEOUT`.
#[test]
fn test_node_config_defaults() {
    let node_config = VectorNodeConfig::new(0, "node0:9000");

    assert_eq!(node_config.node_id, 0);
    assert_eq!(node_config.endpoint, "node0:9000");
    assert_eq!(node_config.timeout, DEFAULT_TIMEOUT);
}
/// `with_timeout` must override the timeout stored in the node config.
#[test]
fn test_node_config_with_timeout() {
    let custom_timeout = Duration::from_secs(60);
    let node_config = VectorNodeConfig::new(0, "node0:9000").with_timeout(custom_timeout);
    assert_eq!(node_config.timeout, Duration::from_secs(60));
}
/// Checks the `Display` output of the `NoNodesAvailable`, `NodeUnavailable`
/// and `CircuitOpen` variants for their key fragments.
#[test]
fn test_distributed_error_display() {
    let no_nodes = DistributedError::NoNodesAvailable.to_string();
    assert!(no_nodes.contains("No nodes available"));

    let unavailable = DistributedError::NodeUnavailable {
        node_id: 0,
        reason: "connection refused".to_string(),
    }
    .to_string();
    assert!(unavailable.contains("Node 0"));
    assert!(unavailable.contains("connection refused"));

    let circuit_open = DistributedError::CircuitOpen {
        node_id: 1,
        remaining: Duration::from_secs(10),
    }
    .to_string();
    assert!(circuit_open.contains("Circuit breaker"));
    assert!(circuit_open.contains("node 1"));
}
/// Checks the `Display` output of the `AllNodesFailed`, `Timeout` and
/// `ConfigError` variants for their key fragments.
#[test]
fn test_distributed_error_display_all_variants() {
    let all_failed = DistributedError::AllNodesFailed {
        failed_count: 3,
        sample_error: "connection timeout".to_string(),
    }
    .to_string();
    assert!(all_failed.contains("All 3 nodes failed"));
    assert!(all_failed.contains("connection timeout"));

    let timed_out = DistributedError::Timeout {
        operation: "search".to_string(),
        duration: Duration::from_secs(30),
    }
    .to_string();
    assert!(timed_out.contains("search"));
    assert!(timed_out.contains("timed out"));

    let bad_config = DistributedError::ConfigError("invalid dimensions".to_string()).to_string();
    assert!(bad_config.contains("Configuration error"));
    assert!(bad_config.contains("invalid dimensions"));
}
/// Requesting `MAX_K + 1` results must be rejected with an error whose
/// message mentions "exceeds maximum".
#[test]
fn test_search_k_exceeds_max() {
    let index =
        DistributedVectorIndex::new(create_test_config(3), create_test_clients(3)).unwrap();
    let query = vec![1.0, 0.0, 0.0, 0.0];

    let outcome = index.search(&query, MAX_K + 1);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("exceeds maximum"));
}
/// Requesting exactly `MAX_K` results must be accepted; nothing was added to
/// the index, so the search returns no hits.
#[test]
fn test_search_k_at_max() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;
    let query = vec![1.0, 0.0, 0.0, 0.0];

    let hits = index.search(&query, MAX_K)?;

    assert!(hits.is_empty());
    Ok(())
}
/// Inserts 15 vectors into a three-node index and checks the aggregate
/// statistics: vector total, node count and number of available nodes.
#[test]
fn test_stats_with_vectors() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;

    for raw_id in 1..=15 {
        let id = NodeId::new(raw_id).unwrap();
        let vector = [raw_id as f32, 0.0, 0.0, 0.0];
        index.add(id, &vector)?;
    }

    let summary = index.stats();
    assert_eq!(summary.total_vectors, 15);
    assert_eq!(summary.node_count, 3);
    assert_eq!(summary.available_nodes, 3);
    Ok(())
}
/// Exercises `add` and `search` with partial results disabled while one of
/// the three mock clients is marked unhealthy.
///
/// NOTE(review): both outcomes are deliberately discarded, so this only
/// verifies that neither call panics. The expected result (presumably an
/// error from `search`) is not asserted — confirm and tighten.
#[test]
fn test_search_partial_results_disabled() {
    let mut config = create_test_config(3);
    config.allow_partial_results = false;

    let clients = create_test_clients(3);
    // Take the middle client out of service before the index is built.
    clients[1].set_healthy(false);

    let index = DistributedVectorIndex::new(config, clients).unwrap();

    let id = NodeId::new(1).unwrap();
    let _ = index.add(id, &[1.0, 0.0, 0.0, 0.0]);

    let query = vec![1.0, 0.0, 0.0, 0.0];
    let _ = index.search(&query, 10);
}
/// Trips the circuit breaker of every node, then checks that
/// `reset_all_circuits` returns all of them to `Closed`.
#[test]
fn test_reset_all_circuits() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;

    // Ten recorded failures per node; the `Open` assertions below confirm
    // that this is enough to trip each breaker under the test config.
    index.nodes.iter().for_each(|node| {
        (0..10).for_each(|_| {
            node.circuit_breaker.record_failure();
        });
    });
    for node in index.nodes.iter() {
        assert_eq!(node.circuit_state(), CircuitState::Open);
    }

    index.reset_all_circuits();

    for node in index.nodes.iter() {
        assert_eq!(node.circuit_state(), CircuitState::Closed);
    }
    Ok(())
}
/// `is_empty` must report `true` for a fresh mock client and `false` once a
/// vector has been added.
#[test]
fn test_mock_client_is_empty() -> Result<()> {
    let mock = MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine);

    let empty_before = mock.is_empty()?;
    assert!(empty_before);

    mock.add(NodeId::new(1).unwrap(), &[1.0, 0.0, 0.0, 0.0])?;

    let empty_after = mock.is_empty()?;
    assert!(!empty_after);
    Ok(())
}
/// `remaining_open_time` must be `None` before the breaker trips and, once a
/// single failure opens it, lie in `(0, 60]` seconds.
#[test]
fn test_circuit_breaker_remaining_time() {
    let breaker = NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 1,
        open_duration: Duration::from_secs(60),
        success_threshold: 1,
    });
    assert!(breaker.remaining_open_time().is_none());

    // A failure threshold of 1 means the first failure opens the circuit.
    breaker.record_failure();
    assert_eq!(breaker.state(), CircuitState::Open);

    let remaining = breaker.remaining_open_time();
    assert!(remaining.is_some());
    let time_left = remaining.unwrap();
    assert!(time_left > Duration::from_secs(0));
    assert!(time_left <= Duration::from_secs(60));
}
/// Once a connection's circuit is open, `execute` must return an error.
#[test]
fn test_node_connection_execute_with_circuit_open() {
    let breaker_config = CircuitBreakerConfig {
        failure_threshold: 1,
        open_duration: Duration::from_secs(60),
        success_threshold: 1,
    };
    let mock = Arc::new(MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine));
    let conn = NodeConnection::new(mock, breaker_config);

    // With a failure threshold of 1, one recorded failure opens the circuit.
    conn.circuit_breaker.record_failure();
    assert_eq!(conn.circuit_state(), CircuitState::Open);

    let outcome = conn.execute(|c| c.health_check());
    assert!(outcome.is_err());
}
/// The `Debug` output of a `NodeConnection` must name the type and include a
/// `node_id` field.
#[test]
fn test_node_connection_debug() {
    let mock = Arc::new(MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine));
    let conn = NodeConnection::new(mock, CircuitBreakerConfig::default());

    let rendered = format!("{:?}", conn);

    assert!(rendered.contains("NodeConnection"));
    assert!(rendered.contains("node_id"));
}
/// The `Debug` output of a `DistributedVectorIndex` must name the type and
/// include the `dimensions` and `node_count` fields.
#[test]
fn test_distributed_index_debug() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;

    let rendered = format!("{:?}", index);

    for fragment in ["DistributedVectorIndex", "dimensions", "node_count"] {
        assert!(rendered.contains(fragment));
    }
    Ok(())
}
/// Adding a 2-element vector to a mock client created with 4 dimensions must
/// fail.
#[test]
fn test_mock_client_dimension_mismatch() {
    let mock = MockVectorNodeClient::new(0, 4, DistanceMetric::Cosine);

    let too_short = [1.0, 0.0];
    let outcome = mock.add(NodeId::new(1).unwrap(), &too_short);

    assert!(outcome.is_err());
}
/// `with_routing_strategy` must store the chosen strategy in the config.
#[test]
fn test_config_routing_strategy() {
    let base = DistributedVectorConfig::new(384, DistanceMetric::Cosine);
    let with_one_node = base.with_node(VectorNodeConfig::new(0, "node0:9000"));
    let config = with_one_node.with_routing_strategy(RoutingStrategy::RangeBased);

    assert_eq!(config.routing_strategy, RoutingStrategy::RangeBased);
}
/// `with_min_nodes_for_search` must store the requested minimum in the config.
#[test]
fn test_config_min_nodes_for_search() {
    let base = DistributedVectorConfig::new(384, DistanceMetric::Cosine);
    let with_one_node = base.with_node(VectorNodeConfig::new(0, "node0:9000"));
    let config = with_one_node.with_min_nodes_for_search(2);

    assert_eq!(config.min_nodes_for_search, 2);
}
/// With 3 failures out of 10 requests, `success_rate` must be 0.7 within a
/// tolerance of 0.001.
#[test]
fn test_node_connection_stats_success_rate() {
    let sample = NodeConnectionStats {
        node_id: 0,
        circuit_state: CircuitState::Closed,
        request_count: 10,
        failure_count: 3,
    };

    let deviation = (sample.success_rate() - 0.7).abs();
    assert!(deviation < 0.001);
}
/// With no requests recorded, `success_rate` must report exactly 1.0.
#[test]
fn test_node_connection_stats_success_rate_zero_requests() {
    let untouched = NodeConnectionStats {
        node_id: 0,
        circuit_state: CircuitState::Closed,
        request_count: 0,
        failure_count: 0,
    };

    let rate = untouched.success_rate();
    assert_eq!(rate, 1.0);
}
/// Records 20 failures from four threads (five each) against a breaker with
/// a failure threshold of 10 and expects it to end up `Open`.
#[test]
fn test_circuit_breaker_concurrent_failures() {
    use std::thread;

    let breaker = Arc::new(NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 10,
        open_duration: Duration::from_secs(60),
        success_threshold: 3,
    }));

    // Spawn every worker before joining any, so the failures can interleave.
    let workers: Vec<_> = (0..4)
        .map(|_| {
            let shared = Arc::clone(&breaker);
            thread::spawn(move || {
                for _ in 0..5 {
                    shared.record_failure();
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    assert_eq!(breaker.state(), CircuitState::Open);
}
/// Runs two success-recording and two failure-recording threads against one
/// breaker. The final state depends on scheduling, so the test only requires
/// that every thread finishes and the state is `Closed`, `Open` or `HalfOpen`.
#[test]
fn test_circuit_breaker_concurrent_success_failure_mix() {
    use std::thread;

    let breaker = Arc::new(NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 5,
        open_duration: Duration::from_millis(10),
        success_threshold: 2,
    }));

    let mut workers = Vec::new();

    // Success recorders are spawned first, as in the failure group below:
    // ten events each, 100 microseconds apart.
    workers.extend((0..2).map(|_| {
        let shared = Arc::clone(&breaker);
        thread::spawn(move || {
            for _ in 0..10 {
                shared.record_success();
                thread::sleep(Duration::from_micros(100));
            }
        })
    }));

    workers.extend((0..2).map(|_| {
        let shared = Arc::clone(&breaker);
        thread::spawn(move || {
            for _ in 0..10 {
                shared.record_failure();
                thread::sleep(Duration::from_micros(100));
            }
        })
    }));

    for worker in workers {
        worker.join().unwrap();
    }

    let final_state = breaker.state();
    let accepted = [
        CircuitState::Closed,
        CircuitState::Open,
        CircuitState::HalfOpen,
    ];
    assert!(accepted.contains(&final_state));
}
/// Opens a breaker, then has eight threads call `state`, `should_allow` and
/// `remaining_open_time` 100 times each. Nothing is asserted about the values
/// they see; the test passes when every thread joins without panicking.
#[test]
fn test_circuit_breaker_concurrent_state_checks() {
    use std::thread;

    let breaker = Arc::new(NodeCircuitBreaker::new(CircuitBreakerConfig {
        failure_threshold: 3,
        open_duration: Duration::from_millis(50),
        success_threshold: 2,
    }));

    // Five failures against a threshold of three open the circuit.
    (0..5).for_each(|_| {
        breaker.record_failure();
    });
    assert_eq!(breaker.state(), CircuitState::Open);

    let readers: Vec<_> = (0..8)
        .map(|_| {
            let shared = Arc::clone(&breaker);
            thread::spawn(move || {
                for _ in 0..100 {
                    let _ = shared.state();
                    let _ = shared.should_allow();
                    let _ = shared.remaining_open_time();
                }
            })
        })
        .collect();

    for reader in readers {
        reader.join().unwrap();
    }
}
/// Pins `RECOMMENDED_IMBALANCE_THRESHOLD` to 2.0 within a tolerance of 0.001.
#[test]
fn test_recommended_imbalance_threshold() {
    let deviation = (RECOMMENDED_IMBALANCE_THRESHOLD - 2.0).abs();
    assert!(deviation < 0.001);
}
/// An index that holds no vectors must not report a need for rebalancing
/// when checked against `RECOMMENDED_IMBALANCE_THRESHOLD`.
#[test]
fn test_needs_rebalancing_with_threshold_constant() -> Result<()> {
    let index = DistributedVectorIndex::new(create_test_config(3), create_test_clients(3))?;

    let wants_rebalance = index.needs_rebalancing(RECOMMENDED_IMBALANCE_THRESHOLD);
    assert!(!wants_rebalance);
    Ok(())
}
}