1use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicUsize, Ordering};
12
13use super::cluster::{ClusterCoordinator, NodeInfo};
14use super::sharding::ShardManager;
15use common::types::{ReadConsistency, StalenessConfig};
16
/// Configuration knobs for the distributed query router.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    /// Strategy used to choose the serving node for each shard.
    pub strategy: RoutingStrategy,
    /// Maximum number of shards queried concurrently during scatter queries.
    /// NOTE(review): enforced by the query executor, not in this file — confirm.
    pub max_concurrent_shards: usize,
    /// Per-shard query timeout in milliseconds.
    pub shard_timeout_ms: u64,
    /// Whether a failed shard query is retried (presumably on fallback nodes).
    pub retry_failed_shards: bool,
    /// Maximum number of retry attempts per failed shard.
    pub max_retries: u32,
}
31
32impl Default for RouterConfig {
33 fn default() -> Self {
34 Self {
35 strategy: RoutingStrategy::RoundRobin,
36 max_concurrent_shards: 10,
37 shard_timeout_ms: 5000,
38 retry_failed_shards: true,
39 max_retries: 2,
40 }
41 }
42}
43
/// How the router picks the primary serving node among a shard's healthy replicas.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Rotate through healthy nodes using a shared atomic counter.
    RoundRobin,
    /// Pick the node reporting the fewest active connections.
    LeastConnections,
    /// Pick a pseudo-random node (index derived from the system clock).
    Random,
    /// Prefer the local node when it holds the shard; otherwise round-robin.
    PreferLocal,
    /// Always target the shard's designated primary node.
    PrimaryOnly,
}
58
/// Result of query planning: which node serves each shard involved in a query.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Chosen target (primary + fallbacks) per shard id.
    pub shard_targets: HashMap<u32, NodeTarget>,
    /// Number of shards this query touches.
    pub total_shards: usize,
    /// True when the plan fans out to every shard (scatter/gather).
    pub is_scatter: bool,
}
69
/// The node selected to serve one shard, plus ordered fallback candidates.
#[derive(Debug, Clone)]
pub struct NodeTarget {
    /// First-choice node for this shard.
    pub primary: NodeInfo,
    /// Remaining healthy nodes to try if the primary fails.
    pub fallbacks: Vec<NodeInfo>,
    /// Shard this target serves.
    pub shard_id: u32,
}
80
/// Per-shard response collected during a distributed query.
#[derive(Debug, Clone)]
pub struct ShardResult<T> {
    /// Shard that produced these results.
    pub shard_id: u32,
    /// Identifier of the node that actually served the shard query.
    pub served_by: String,
    /// Result items from this shard (empty is treated as "did not succeed"
    /// by `merge_similarity_results`).
    pub results: Vec<T>,
    /// Wall-clock latency of this shard's query in milliseconds.
    pub latency_ms: u64,
    /// Whether this result came from a retry attempt
    /// (see `RouterConfig::retry_failed_shards`).
    pub was_retry: bool,
}
95
/// Final merged output of a scatter/gather query across shards.
#[derive(Debug, Clone)]
pub struct MergedResults<T> {
    /// Globally ranked, top-k truncated result items.
    pub results: Vec<T>,
    /// Number of shards that were queried.
    pub shards_queried: usize,
    /// Number of shards that returned at least one result.
    pub shards_succeeded: usize,
    /// Maximum per-shard latency (shards are assumed queried in parallel).
    pub total_latency_ms: u64,
    /// Latency per shard id in milliseconds.
    pub shard_latencies: HashMap<u32, u64>,
}
110
/// Routes distributed queries to shard-owning nodes and merges their results.
pub struct QueryRouter {
    // Routing behavior knobs (strategy, timeouts, retries).
    config: RouterConfig,
    // Maps vector ids to shard ids.
    shard_manager: ShardManager,
    // Live view of cluster membership and node health.
    cluster: ClusterCoordinator,
    // Shared counter driving round-robin node selection.
    rr_counter: AtomicUsize,
    // Identifier of this node; used by the PreferLocal strategy.
    local_node_id: String,
}
124
125impl QueryRouter {
126 pub fn new(
128 config: RouterConfig,
129 shard_manager: ShardManager,
130 cluster: ClusterCoordinator,
131 local_node_id: String,
132 ) -> Self {
133 Self {
134 config,
135 shard_manager,
136 cluster,
137 rr_counter: AtomicUsize::new(0),
138 local_node_id,
139 }
140 }
141
142 pub fn plan_point_query(&self, vector_id: &str) -> QueryPlan {
144 let assignment = self.shard_manager.get_shard(vector_id);
145 let targets = self.get_node_targets(assignment.shard_id);
146
147 let mut shard_targets = HashMap::new();
148 shard_targets.insert(assignment.shard_id, targets);
149
150 QueryPlan {
151 shard_targets,
152 total_shards: 1,
153 is_scatter: false,
154 }
155 }
156
157 pub fn plan_scatter_query(&self) -> QueryPlan {
159 let shards = self.shard_manager.get_all_shards();
160 let mut shard_targets = HashMap::new();
161
162 for shard_id in &shards {
163 let targets = self.get_node_targets(*shard_id);
164 shard_targets.insert(*shard_id, targets);
165 }
166
167 QueryPlan {
168 shard_targets,
169 total_shards: shards.len(),
170 is_scatter: true,
171 }
172 }
173
174 pub fn plan_batch_query(&self, vector_ids: &[String]) -> QueryPlan {
176 let shard_batches = self.shard_manager.get_shards_batch(vector_ids);
177 let mut shard_targets = HashMap::new();
178
179 for shard_id in shard_batches.keys() {
180 let targets = self.get_node_targets(*shard_id);
181 shard_targets.insert(*shard_id, targets);
182 }
183
184 QueryPlan {
185 shard_targets,
186 total_shards: shard_batches.len(),
187 is_scatter: false,
188 }
189 }
190
191 fn get_node_targets(&self, shard_id: u32) -> NodeTarget {
193 let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);
194
195 if healthy_nodes.is_empty() {
196 return NodeTarget {
198 primary: NodeInfo::new(
199 format!("unavailable-{}", shard_id),
200 "unavailable".to_string(),
201 super::cluster::NodeRole::Replica,
202 ),
203 fallbacks: Vec::new(),
204 shard_id,
205 };
206 }
207
208 let (primary, fallbacks) = match self.config.strategy {
209 RoutingStrategy::RoundRobin => {
210 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
211 let primary = healthy_nodes[idx].clone();
212 let fallbacks: Vec<_> = healthy_nodes
213 .into_iter()
214 .enumerate()
215 .filter(|(i, _)| *i != idx)
216 .map(|(_, n)| n)
217 .collect();
218 (primary, fallbacks)
219 }
220 RoutingStrategy::LeastConnections => {
221 let mut sorted = healthy_nodes.clone();
223 sorted.sort_by_key(|n| n.health.active_connections);
224 let primary = sorted.remove(0);
225 (primary, sorted)
226 }
227 RoutingStrategy::Random => {
228 let idx = (std::time::SystemTime::now()
230 .duration_since(std::time::UNIX_EPOCH)
231 .unwrap_or_default()
232 .as_nanos() as usize)
233 % healthy_nodes.len();
234 let primary = healthy_nodes[idx].clone();
235 let fallbacks: Vec<_> = healthy_nodes
236 .into_iter()
237 .enumerate()
238 .filter(|(i, _)| *i != idx)
239 .map(|(_, n)| n)
240 .collect();
241 (primary, fallbacks)
242 }
243 RoutingStrategy::PreferLocal => {
244 let local = healthy_nodes
246 .iter()
247 .find(|n| n.node_id == self.local_node_id);
248 if let Some(local_node) = local {
249 let primary = local_node.clone();
250 let fallbacks: Vec<_> = healthy_nodes
251 .into_iter()
252 .filter(|n| n.node_id != self.local_node_id)
253 .collect();
254 (primary, fallbacks)
255 } else {
256 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
258 let primary = healthy_nodes[idx].clone();
259 let fallbacks: Vec<_> = healthy_nodes
260 .into_iter()
261 .enumerate()
262 .filter(|(i, _)| *i != idx)
263 .map(|(_, n)| n)
264 .collect();
265 (primary, fallbacks)
266 }
267 }
268 RoutingStrategy::PrimaryOnly => {
269 let primary_node = self.cluster.get_primary_for_shard(shard_id);
271 if let Some(primary) = primary_node {
272 let fallbacks: Vec<_> = healthy_nodes
273 .into_iter()
274 .filter(|n| n.node_id != primary.node_id)
275 .collect();
276 (primary, fallbacks)
277 } else {
278 let primary = healthy_nodes[0].clone();
280 let fallbacks = healthy_nodes.into_iter().skip(1).collect();
281 (primary, fallbacks)
282 }
283 }
284 };
285
286 NodeTarget {
287 primary,
288 fallbacks,
289 shard_id,
290 }
291 }
292
293 pub fn merge_similarity_results<T: Clone>(
295 &self,
296 shard_results: Vec<ShardResult<T>>,
297 top_k: usize,
298 score_fn: impl Fn(&T) -> f32,
299 ) -> MergedResults<T> {
300 let shards_queried = shard_results.len();
301 let shards_succeeded = shard_results
302 .iter()
303 .filter(|r| !r.results.is_empty())
304 .count();
305
306 let mut shard_latencies = HashMap::new();
307 let mut total_latency = 0u64;
308
309 let mut all_results: Vec<(T, f32)> = Vec::new();
311
312 for shard_result in shard_results {
313 shard_latencies.insert(shard_result.shard_id, shard_result.latency_ms);
314 total_latency = total_latency.max(shard_result.latency_ms);
315
316 for result in shard_result.results {
317 let score = score_fn(&result);
318 all_results.push((result, score));
319 }
320 }
321
322 all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
324
325 let results: Vec<T> = all_results
327 .into_iter()
328 .take(top_k)
329 .map(|(r, _)| r)
330 .collect();
331
332 MergedResults {
333 results,
334 shards_queried,
335 shards_succeeded,
336 total_latency_ms: total_latency,
337 shard_latencies,
338 }
339 }
340
341 pub fn get_stats(&self) -> RouterStats {
343 let state = self.cluster.get_state();
344 let partitions = self.shard_manager.get_partition_info();
345
346 RouterStats {
347 total_nodes: state.total_node_count,
348 healthy_nodes: state.healthy_node_count,
349 total_shards: partitions.len() as u32,
350 healthy_shards: partitions.iter().filter(|p| p.is_healthy).count() as u32,
351 cluster_healthy: state.is_healthy,
352 has_quorum: state.has_quorum,
353 }
354 }
355
356 pub fn plan_scatter_query_with_consistency(
362 &self,
363 consistency: ReadConsistency,
364 staleness_config: Option<StalenessConfig>,
365 ) -> QueryPlan {
366 let shards = self.shard_manager.get_all_shards();
367 let mut shard_targets = HashMap::new();
368
369 for shard_id in &shards {
370 let targets =
371 self.get_node_targets_with_consistency(*shard_id, consistency, staleness_config);
372 shard_targets.insert(*shard_id, targets);
373 }
374
375 QueryPlan {
376 shard_targets,
377 total_shards: shards.len(),
378 is_scatter: true,
379 }
380 }
381
382 fn get_node_targets_with_consistency(
384 &self,
385 shard_id: u32,
386 consistency: ReadConsistency,
387 staleness_config: Option<StalenessConfig>,
388 ) -> NodeTarget {
389 let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);
390
391 if healthy_nodes.is_empty() {
392 return NodeTarget {
393 primary: NodeInfo::new(
394 format!("unavailable-{}", shard_id),
395 "unavailable".to_string(),
396 super::cluster::NodeRole::Replica,
397 ),
398 fallbacks: Vec::new(),
399 shard_id,
400 };
401 }
402
403 match consistency {
404 ReadConsistency::Strong => {
405 self.get_primary_target(shard_id, healthy_nodes)
407 }
408 ReadConsistency::Eventual => {
409 self.get_node_targets(shard_id)
411 }
412 ReadConsistency::BoundedStaleness => {
413 let max_staleness_ms = staleness_config.map(|c| c.max_staleness_ms).unwrap_or(5000);
415 self.get_bounded_staleness_target(shard_id, healthy_nodes, max_staleness_ms)
416 }
417 }
418 }
419
420 fn get_primary_target(&self, shard_id: u32, healthy_nodes: Vec<NodeInfo>) -> NodeTarget {
422 let primary_node = self.cluster.get_primary_for_shard(shard_id);
423 if let Some(primary) = primary_node {
424 let fallbacks: Vec<_> = healthy_nodes
425 .into_iter()
426 .filter(|n| n.node_id != primary.node_id)
427 .collect();
428 NodeTarget {
429 primary,
430 fallbacks,
431 shard_id,
432 }
433 } else {
434 let primary = healthy_nodes[0].clone();
436 let fallbacks = healthy_nodes.into_iter().skip(1).collect();
437 NodeTarget {
438 primary,
439 fallbacks,
440 shard_id,
441 }
442 }
443 }
444
445 fn get_bounded_staleness_target(
447 &self,
448 shard_id: u32,
449 healthy_nodes: Vec<NodeInfo>,
450 max_staleness_ms: u64,
451 ) -> NodeTarget {
452 let eligible_nodes: Vec<_> = healthy_nodes
455 .iter()
456 .filter(|n| {
457 n.health.replication_lag_ms.unwrap_or(0) <= max_staleness_ms
460 })
461 .cloned()
462 .collect();
463
464 if eligible_nodes.is_empty() {
465 return self.get_primary_target(shard_id, healthy_nodes);
467 }
468
469 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % eligible_nodes.len();
471 let primary = eligible_nodes[idx].clone();
472 let fallbacks: Vec<_> = eligible_nodes
473 .into_iter()
474 .enumerate()
475 .filter(|(i, _)| *i != idx)
476 .map(|(_, n)| n)
477 .collect();
478
479 NodeTarget {
480 primary,
481 fallbacks,
482 shard_id,
483 }
484 }
485
486 pub fn consistency_to_strategy(&self, consistency: ReadConsistency) -> RoutingStrategy {
488 match consistency {
489 ReadConsistency::Strong => RoutingStrategy::PrimaryOnly,
490 ReadConsistency::Eventual => self.config.strategy,
491 ReadConsistency::BoundedStaleness => RoutingStrategy::RoundRobin, }
493 }
494}
495
/// Aggregate health statistics exposed by [`QueryRouter::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Total registered nodes in the cluster.
    pub total_nodes: u32,
    /// Nodes currently reported healthy.
    pub healthy_nodes: u32,
    /// Total number of shard partitions.
    pub total_shards: u32,
    /// Partitions reported healthy.
    pub healthy_shards: u32,
    /// Overall cluster health flag from the coordinator.
    pub cluster_healthy: bool,
    /// Whether the cluster currently holds quorum.
    pub has_quorum: bool,
}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::cluster::{ClusterConfig, NodeRole};
    use crate::distributed::sharding::ShardingConfig;

    /// Builds a router over a 4-shard cluster with four healthy nodes, each
    /// owning two consecutive shards; node-0 is the primary.
    fn setup_router() -> QueryRouter {
        let cluster = ClusterCoordinator::new(ClusterConfig::default(), "local".to_string());

        for idx in 0..4 {
            let role = if idx == 0 {
                NodeRole::Primary
            } else {
                NodeRole::Replica
            };
            let mut node = NodeInfo::new(
                format!("node-{}", idx),
                format!("localhost:{}", 8080 + idx),
                role,
            );
            node.shard_ids = vec![idx as u32, (idx + 1) as u32 % 4];
            node.health.status = super::super::cluster::NodeStatus::Healthy;
            cluster.register_node(node).unwrap();
        }

        let shard_manager = ShardManager::new(ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        });

        QueryRouter::new(
            RouterConfig::default(),
            shard_manager,
            cluster,
            "local".to_string(),
        )
    }

    #[test]
    fn test_point_query_plan() {
        let plan = setup_router().plan_point_query("test-vector-123");

        assert!(!plan.is_scatter);
        assert_eq!(plan.total_shards, 1);
        assert_eq!(plan.shard_targets.len(), 1);
    }

    #[test]
    fn test_scatter_query_plan() {
        let plan = setup_router().plan_scatter_query();

        assert!(plan.is_scatter);
        assert_eq!(plan.total_shards, 4);
        assert_eq!(plan.shard_targets.len(), 4);
    }

    #[test]
    fn test_batch_query_plan() {
        let ids: Vec<String> = (0..10).map(|i| format!("vec-{}", i)).collect();
        let plan = setup_router().plan_batch_query(&ids);

        // Ten ids must land on between one and four of the four shards.
        assert!(plan.total_shards > 0);
        assert!(plan.total_shards <= 4);
        assert!(!plan.is_scatter);
    }

    #[test]
    fn test_merge_results() {
        let router = setup_router();

        let shard_results = vec![
            ShardResult {
                shard_id: 0,
                served_by: "node-0".to_string(),
                results: vec![("a", 0.9), ("b", 0.7)],
                latency_ms: 10,
                was_retry: false,
            },
            ShardResult {
                shard_id: 1,
                served_by: "node-1".to_string(),
                results: vec![("c", 0.95), ("d", 0.6)],
                latency_ms: 15,
                was_retry: false,
            },
        ];

        let merged = router.merge_similarity_results(shard_results, 3, |(_id, score)| *score);

        assert_eq!(merged.shards_queried, 2);
        assert_eq!(merged.shards_succeeded, 2);
        assert_eq!(merged.results.len(), 3);

        // Ranked by descending score: c (0.95), a (0.9), b (0.7).
        assert_eq!(merged.results[0].0, "c");
        assert_eq!(merged.results[1].0, "a");
        assert_eq!(merged.results[2].0, "b");
    }

    #[test]
    fn test_router_stats() {
        let stats = setup_router().get_stats();

        assert!(stats.cluster_healthy);
        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.total_shards, 4);
    }
}
625}