use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::cluster::{ClusterCoordinator, NodeInfo};
use super::sharding::ShardManager;
use common::types::{ReadConsistency, StalenessConfig};
/// Configuration knobs for the [`QueryRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    /// Node-selection strategy applied when routing a shard query.
    pub strategy: RoutingStrategy,
    /// Maximum shards queried in parallel during a scatter query.
    /// NOTE(review): not consumed in this file — presumably enforced by the
    /// query executor; confirm against the caller.
    pub max_concurrent_shards: usize,
    /// Per-shard query timeout in milliseconds.
    /// NOTE(review): not consumed in this file — confirm where it is applied.
    pub shard_timeout_ms: u64,
    /// Whether a failed shard query should be retried on a fallback node.
    pub retry_failed_shards: bool,
    /// Maximum retry attempts per shard when retries are enabled.
    pub max_retries: u32,
}
impl Default for RouterConfig {
    /// Defaults: round-robin routing, up to 10 concurrent shards, a 5 s
    /// per-shard timeout, and retries enabled with at most 2 attempts.
    fn default() -> Self {
        Self {
            strategy: RoutingStrategy::RoundRobin,
            max_concurrent_shards: 10,
            shard_timeout_ms: 5000,
            retry_failed_shards: true,
            max_retries: 2,
        }
    }
}
/// How a primary node is chosen among the healthy replicas of a shard.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Rotate through healthy nodes via a shared atomic counter.
    RoundRobin,
    /// Pick the node with the fewest active connections.
    LeastConnections,
    /// Pick a pseudo-random node (derived from the wall clock, see
    /// `get_node_targets`).
    Random,
    /// Prefer the local node when it serves the shard; otherwise round-robin.
    PreferLocal,
    /// Always route to the shard's designated primary.
    PrimaryOnly,
}
/// The routing decision for one query: which node(s) to contact per shard.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Selected primary + fallbacks, keyed by shard id.
    pub shard_targets: HashMap<u32, NodeTarget>,
    /// Number of shards this plan touches.
    pub total_shards: usize,
    /// True when the plan fans out to every shard (scatter/gather).
    pub is_scatter: bool,
}
/// The node chosen to serve one shard, plus ordered retry candidates.
#[derive(Debug, Clone)]
pub struct NodeTarget {
    /// First node to contact for this shard.
    pub primary: NodeInfo,
    /// Remaining healthy nodes, in retry order.
    pub fallbacks: Vec<NodeInfo>,
    /// Shard this target was computed for.
    pub shard_id: u32,
}
/// The outcome of querying a single shard.
#[derive(Debug, Clone)]
pub struct ShardResult<T> {
    /// Shard that produced these results.
    pub shard_id: u32,
    /// Identifier of the node that served the query (e.g. "node-0").
    pub served_by: String,
    /// Raw results returned by the shard.
    pub results: Vec<T>,
    /// Observed latency for this shard's query, in milliseconds.
    pub latency_ms: u64,
    /// Whether this result came from a retry on a fallback node.
    pub was_retry: bool,
}
/// Scatter/gather results after merging, see `merge_similarity_results`.
#[derive(Debug, Clone)]
pub struct MergedResults<T> {
    /// Globally top-k results, sorted by descending score.
    pub results: Vec<T>,
    /// Number of shard results fed into the merge.
    pub shards_queried: usize,
    /// Shards that returned at least one result.
    pub shards_succeeded: usize,
    /// Maximum per-shard latency (shards are assumed queried in parallel).
    pub total_latency_ms: u64,
    /// Per-shard latency, keyed by shard id.
    pub shard_latencies: HashMap<u32, u64>,
}
/// Plans which nodes serve which shards for point, batch, and scatter
/// queries, and merges per-shard results.
pub struct QueryRouter {
    config: RouterConfig,
    shard_manager: ShardManager,
    cluster: ClusterCoordinator,
    // Shared cursor for round-robin node selection across all shards.
    rr_counter: AtomicUsize,
    // Identifier of this node, used by `RoutingStrategy::PreferLocal`.
    local_node_id: String,
}
impl QueryRouter {
pub fn new(
config: RouterConfig,
shard_manager: ShardManager,
cluster: ClusterCoordinator,
local_node_id: String,
) -> Self {
Self {
config,
shard_manager,
cluster,
rr_counter: AtomicUsize::new(0),
local_node_id,
}
}
pub fn plan_point_query(&self, vector_id: &str) -> QueryPlan {
let assignment = self.shard_manager.get_shard(vector_id);
let targets = self.get_node_targets(assignment.shard_id);
let mut shard_targets = HashMap::new();
shard_targets.insert(assignment.shard_id, targets);
QueryPlan {
shard_targets,
total_shards: 1,
is_scatter: false,
}
}
pub fn plan_scatter_query(&self) -> QueryPlan {
let shards = self.shard_manager.get_all_shards();
let mut shard_targets = HashMap::new();
for shard_id in &shards {
let targets = self.get_node_targets(*shard_id);
shard_targets.insert(*shard_id, targets);
}
QueryPlan {
shard_targets,
total_shards: shards.len(),
is_scatter: true,
}
}
pub fn plan_batch_query(&self, vector_ids: &[String]) -> QueryPlan {
let shard_batches = self.shard_manager.get_shards_batch(vector_ids);
let mut shard_targets = HashMap::new();
for shard_id in shard_batches.keys() {
let targets = self.get_node_targets(*shard_id);
shard_targets.insert(*shard_id, targets);
}
QueryPlan {
shard_targets,
total_shards: shard_batches.len(),
is_scatter: false,
}
}
fn get_node_targets(&self, shard_id: u32) -> NodeTarget {
let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);
if healthy_nodes.is_empty() {
return NodeTarget {
primary: NodeInfo::new(
format!("unavailable-{}", shard_id),
"unavailable".to_string(),
super::cluster::NodeRole::Replica,
),
fallbacks: Vec::new(),
shard_id,
};
}
let (primary, fallbacks) = match self.config.strategy {
RoutingStrategy::RoundRobin => {
let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
let primary = healthy_nodes[idx].clone();
let fallbacks: Vec<_> = healthy_nodes
.into_iter()
.enumerate()
.filter(|(i, _)| *i != idx)
.map(|(_, n)| n)
.collect();
(primary, fallbacks)
}
RoutingStrategy::LeastConnections => {
let mut sorted = healthy_nodes.clone();
sorted.sort_by_key(|n| n.health.active_connections);
let primary = sorted.remove(0);
(primary, sorted)
}
RoutingStrategy::Random => {
let idx = (std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as usize)
% healthy_nodes.len();
let primary = healthy_nodes[idx].clone();
let fallbacks: Vec<_> = healthy_nodes
.into_iter()
.enumerate()
.filter(|(i, _)| *i != idx)
.map(|(_, n)| n)
.collect();
(primary, fallbacks)
}
RoutingStrategy::PreferLocal => {
let local = healthy_nodes
.iter()
.find(|n| n.node_id == self.local_node_id);
if let Some(local_node) = local {
let primary = local_node.clone();
let fallbacks: Vec<_> = healthy_nodes
.into_iter()
.filter(|n| n.node_id != self.local_node_id)
.collect();
(primary, fallbacks)
} else {
let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
let primary = healthy_nodes[idx].clone();
let fallbacks: Vec<_> = healthy_nodes
.into_iter()
.enumerate()
.filter(|(i, _)| *i != idx)
.map(|(_, n)| n)
.collect();
(primary, fallbacks)
}
}
RoutingStrategy::PrimaryOnly => {
let primary_node = self.cluster.get_primary_for_shard(shard_id);
if let Some(primary) = primary_node {
let fallbacks: Vec<_> = healthy_nodes
.into_iter()
.filter(|n| n.node_id != primary.node_id)
.collect();
(primary, fallbacks)
} else {
let primary = healthy_nodes[0].clone();
let fallbacks = healthy_nodes.into_iter().skip(1).collect();
(primary, fallbacks)
}
}
};
NodeTarget {
primary,
fallbacks,
shard_id,
}
}
/// Merges per-shard similarity results into a single top-`top_k` list,
/// sorted by descending score as computed by `score_fn`.
///
/// `total_latency_ms` is the slowest shard's latency (shards are assumed
/// to have been queried in parallel), and a shard counts as "succeeded"
/// when it returned at least one result.
pub fn merge_similarity_results<T: Clone>(
    &self,
    shard_results: Vec<ShardResult<T>>,
    top_k: usize,
    score_fn: impl Fn(&T) -> f32,
) -> MergedResults<T> {
    let shards_queried = shard_results.len();
    let mut shards_succeeded = 0usize;
    let mut shard_latencies = HashMap::with_capacity(shards_queried);
    let mut total_latency_ms = 0u64;
    let mut scored: Vec<(T, f32)> = Vec::new();
    for shard in shard_results {
        if !shard.results.is_empty() {
            shards_succeeded += 1;
        }
        shard_latencies.insert(shard.shard_id, shard.latency_ms);
        if shard.latency_ms > total_latency_ms {
            total_latency_ms = shard.latency_ms;
        }
        scored.extend(shard.results.into_iter().map(|item| {
            let score = score_fn(&item);
            (item, score)
        }));
    }
    // Descending by score; NaN scores compare as Equal so the sort stays total.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(top_k);
    MergedResults {
        results: scored.into_iter().map(|(item, _)| item).collect(),
        shards_queried,
        shards_succeeded,
        total_latency_ms,
        shard_latencies,
    }
}
/// Snapshot of cluster and shard health as seen by this router.
pub fn get_stats(&self) -> RouterStats {
    let state = self.cluster.get_state();
    let partitions = self.shard_manager.get_partition_info();
    let healthy_shards = partitions.iter().filter(|p| p.is_healthy).count() as u32;
    RouterStats {
        total_nodes: state.total_node_count,
        healthy_nodes: state.healthy_node_count,
        total_shards: partitions.len() as u32,
        healthy_shards,
        cluster_healthy: state.is_healthy,
        has_quorum: state.has_quorum,
    }
}
pub fn plan_scatter_query_with_consistency(
&self,
consistency: ReadConsistency,
staleness_config: Option<StalenessConfig>,
) -> QueryPlan {
let shards = self.shard_manager.get_all_shards();
let mut shard_targets = HashMap::new();
for shard_id in &shards {
let targets =
self.get_node_targets_with_consistency(*shard_id, consistency, staleness_config);
shard_targets.insert(*shard_id, targets);
}
QueryPlan {
shard_targets,
total_shards: shards.len(),
is_scatter: true,
}
}
/// Like `get_node_targets`, but the node-selection policy is derived
/// from the requested read consistency rather than the static config.
fn get_node_targets_with_consistency(
    &self,
    shard_id: u32,
    consistency: ReadConsistency,
    staleness_config: Option<StalenessConfig>,
) -> NodeTarget {
    let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);
    if healthy_nodes.is_empty() {
        // No healthy replica at all: hand back a sentinel target so the
        // caller can report a per-shard failure.
        return NodeTarget {
            primary: NodeInfo::new(
                format!("unavailable-{}", shard_id),
                "unavailable".to_string(),
                super::cluster::NodeRole::Replica,
            ),
            fallbacks: Vec::new(),
            shard_id,
        };
    }
    match consistency {
        // Strong reads must be served by the shard's primary.
        ReadConsistency::Strong => self.get_primary_target(shard_id, healthy_nodes),
        // Eventual reads use the configured strategy.
        // NOTE(review): this path re-fetches the healthy node list inside
        // `get_node_targets`, discarding `healthy_nodes` — redundant but
        // harmless; kept for behavioral compatibility.
        ReadConsistency::Eventual => self.get_node_targets(shard_id),
        // Bounded staleness: prefer replicas within the lag budget,
        // defaulting the budget to 5000 ms when unconfigured.
        ReadConsistency::BoundedStaleness => {
            let max_staleness_ms = staleness_config
                .map(|c| c.max_staleness_ms)
                .unwrap_or(5000);
            self.get_bounded_staleness_target(shard_id, healthy_nodes, max_staleness_ms)
        }
    }
}
/// Routes to the shard's designated primary with every other healthy
/// node as a fallback; if no primary is registered, the first healthy
/// node stands in. `healthy_nodes` must be non-empty (callers check).
fn get_primary_target(&self, shard_id: u32, healthy_nodes: Vec<NodeInfo>) -> NodeTarget {
    let (primary, fallbacks) = match self.cluster.get_primary_for_shard(shard_id) {
        Some(primary) => {
            let rest: Vec<_> = healthy_nodes
                .into_iter()
                .filter(|n| n.node_id != primary.node_id)
                .collect();
            (primary, rest)
        }
        None => {
            let mut nodes = healthy_nodes;
            let primary = nodes.remove(0);
            (primary, nodes)
        }
    };
    NodeTarget {
        primary,
        fallbacks,
        shard_id,
    }
}
/// Routes a bounded-staleness read: round-robins over replicas whose
/// reported replication lag is within `max_staleness_ms`.
///
/// A node with *unknown* lag (`None`) cannot be proven to satisfy the
/// staleness bound, so it is excluded from the eligible set. (Previously
/// unknown lag was treated as zero, which could serve reads staler than
/// the caller requested.) When no replica qualifies, the read falls back
/// to the primary, which is authoritative by definition.
fn get_bounded_staleness_target(
    &self,
    shard_id: u32,
    healthy_nodes: Vec<NodeInfo>,
    max_staleness_ms: u64,
) -> NodeTarget {
    let eligible_nodes: Vec<_> = healthy_nodes
        .iter()
        .filter(|n| {
            // Only nodes with a known, in-budget lag are eligible.
            n.health
                .replication_lag_ms
                .map_or(false, |lag| lag <= max_staleness_ms)
        })
        .cloned()
        .collect();
    if eligible_nodes.is_empty() {
        // Nothing provably fresh enough: serve from the primary.
        return self.get_primary_target(shard_id, healthy_nodes);
    }
    let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % eligible_nodes.len();
    let primary = eligible_nodes[idx].clone();
    let fallbacks: Vec<_> = eligible_nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| *i != idx)
        .map(|(_, n)| n)
        .collect();
    NodeTarget {
        primary,
        fallbacks,
        shard_id,
    }
}
/// Maps a read-consistency level onto the routing strategy that honors it.
pub fn consistency_to_strategy(&self, consistency: ReadConsistency) -> RoutingStrategy {
    match consistency {
        // Strong reads must hit the shard primary.
        ReadConsistency::Strong => RoutingStrategy::PrimaryOnly,
        // Eventual reads defer to whatever the router was configured with.
        ReadConsistency::Eventual => self.config.strategy,
        // Bounded staleness load-balances across fresh-enough replicas.
        ReadConsistency::BoundedStaleness => RoutingStrategy::RoundRobin,
    }
}
}
/// Aggregate health counters reported by [`QueryRouter::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Total nodes registered in the cluster.
    pub total_nodes: u32,
    /// Nodes currently considered healthy.
    pub healthy_nodes: u32,
    /// Total number of shard partitions.
    pub total_shards: u32,
    /// Partitions currently marked healthy.
    pub healthy_shards: u32,
    /// Overall cluster health flag from the coordinator.
    pub cluster_healthy: bool,
    /// Whether the cluster currently holds quorum.
    pub has_quorum: bool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::cluster::{ClusterConfig, NodeRole};
    use crate::distributed::sharding::ShardingConfig;

    /// Builds a 4-shard, 4-node router fixture. node-0 is the primary;
    /// each node serves two shards: its own index and `(i + 1) % 4`.
    fn setup_router() -> QueryRouter {
        let shard_config = ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        };
        let shard_manager = ShardManager::new(shard_config);
        let cluster_config = ClusterConfig::default();
        let cluster = ClusterCoordinator::new(cluster_config, "local".to_string());
        for i in 0..4 {
            let mut node = NodeInfo::new(
                format!("node-{}", i),
                format!("localhost:{}", 8080 + i),
                if i == 0 {
                    NodeRole::Primary
                } else {
                    NodeRole::Replica
                },
            );
            // Each node owns its own shard plus the next one, wrapping at 4.
            node.shard_ids = vec![i as u32, (i + 1) as u32 % 4];
            node.health.status = super::super::cluster::NodeStatus::Healthy;
            cluster.register_node(node).unwrap();
        }
        let router_config = RouterConfig::default();
        QueryRouter::new(router_config, shard_manager, cluster, "local".to_string())
    }

    /// A point query must target exactly one shard and not scatter.
    #[test]
    fn test_point_query_plan() {
        let router = setup_router();
        let plan = router.plan_point_query("test-vector-123");
        assert_eq!(plan.total_shards, 1);
        assert!(!plan.is_scatter);
        assert_eq!(plan.shard_targets.len(), 1);
    }

    /// A scatter query must fan out to all 4 shards.
    #[test]
    fn test_scatter_query_plan() {
        let router = setup_router();
        let plan = router.plan_scatter_query();
        assert_eq!(plan.total_shards, 4);
        assert!(plan.is_scatter);
        assert_eq!(plan.shard_targets.len(), 4);
    }

    /// A batch query only touches shards that own at least one id,
    /// so the shard count is between 1 and 4 inclusive.
    #[test]
    fn test_batch_query_plan() {
        let router = setup_router();
        let ids: Vec<String> = (0..10).map(|i| format!("vec-{}", i)).collect();
        let plan = router.plan_batch_query(&ids);
        assert!(plan.total_shards > 0);
        assert!(plan.total_shards <= 4);
        assert!(!plan.is_scatter);
    }

    /// Merging two shards' scored results yields a globally sorted
    /// top-3 list: c (0.95), a (0.9), b (0.7).
    #[test]
    fn test_merge_results() {
        let router = setup_router();
        let shard_results = vec![
            ShardResult {
                shard_id: 0,
                served_by: "node-0".to_string(),
                results: vec![("a", 0.9), ("b", 0.7)],
                latency_ms: 10,
                was_retry: false,
            },
            ShardResult {
                shard_id: 1,
                served_by: "node-1".to_string(),
                results: vec![("c", 0.95), ("d", 0.6)],
                latency_ms: 15,
                was_retry: false,
            },
        ];
        let merged = router.merge_similarity_results(shard_results, 3, |(_id, score)| *score);
        assert_eq!(merged.results.len(), 3);
        assert_eq!(merged.shards_queried, 2);
        assert_eq!(merged.shards_succeeded, 2);
        // Descending score order across shards.
        assert_eq!(merged.results[0].0, "c");
        assert_eq!(merged.results[1].0, "a");
        assert_eq!(merged.results[2].0, "b");
    }

    /// Stats reflect the 4-node, 4-shard fixture and a healthy cluster.
    #[test]
    fn test_router_stats() {
        let router = setup_router();
        let stats = router.get_stats();
        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.total_shards, 4);
        assert!(stats.cluster_healthy);
    }
}