// engine/distributed/routing.rs

//! Query Routing for Distributed Dakera
//!
//! Handles routing of queries across shards:
//! - Scatter-gather for similarity search
//! - Load balancing across replicas
//! - Result merging and ranking
//! - Consistency-aware routing (Turbopuffer-inspired)
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use serde::{Deserialize, Serialize};

use common::types::{ReadConsistency, StalenessConfig};

use super::cluster::{ClusterCoordinator, NodeInfo};
use super::sharding::ShardManager;
/// Configuration for query routing.
///
/// Only `strategy` is consumed inside this module (by `QueryRouter`);
/// the concurrency/timeout/retry knobs are presumably enforced by the
/// query executor that runs the plans — TODO confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    /// Strategy for picking a primary target among a shard's healthy replicas
    pub strategy: RoutingStrategy,
    /// Maximum number of shard queries issued concurrently
    pub max_concurrent_shards: usize,
    /// Timeout for a single shard query, in milliseconds
    pub shard_timeout_ms: u64,
    /// Whether to retry shards that fail (using the plan's fallback nodes)
    pub retry_failed_shards: bool,
    /// Maximum retries per shard
    pub max_retries: u32,
}
31
32impl Default for RouterConfig {
33    fn default() -> Self {
34        Self {
35            strategy: RoutingStrategy::RoundRobin,
36            max_concurrent_shards: 10,
37            shard_timeout_ms: 5000,
38            retry_failed_shards: true,
39            max_retries: 2,
40        }
41    }
42}
43
/// Strategy for selecting among a shard's replica nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Round-robin across healthy replicas (shared atomic counter)
    RoundRobin,
    /// Pick the replica with the fewest active connections (load-based)
    LeastConnections,
    /// Pseudo-random selection (seeded from the wall clock, not an RNG)
    Random,
    /// Prefer the local node when it serves the shard; otherwise round-robin
    PreferLocal,
    /// Route to the shard's primary only (used for strong reads)
    PrimaryOnly,
}
58
/// Plan for executing a distributed query: which node to ask for each shard.
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Shard ID to target-node mapping
    pub shard_targets: HashMap<u32, NodeTarget>,
    /// Total number of shards to query (equals `shard_targets.len()`)
    pub total_shards: usize,
    /// Whether this is a scatter query (fans out to all shards)
    pub is_scatter: bool,
}
69
/// Target node for a single shard query.
///
/// When no healthy node serves the shard, `primary` is a sentinel node
/// whose id is `unavailable-{shard_id}` with address `"unavailable"`.
#[derive(Debug, Clone)]
pub struct NodeTarget {
    /// Node the query should be sent to first
    pub primary: NodeInfo,
    /// Remaining healthy nodes, usable as retry fallbacks
    pub fallbacks: Vec<NodeInfo>,
    /// Shard ID this target serves
    pub shard_id: u32,
}
80
/// Result from querying a single shard.
#[derive(Debug, Clone)]
pub struct ShardResult<T> {
    /// Shard ID that produced these results
    pub shard_id: u32,
    /// ID of the node that served the request
    pub served_by: String,
    /// Query results from this shard (empty on failure or no matches)
    pub results: Vec<T>,
    /// Query latency in milliseconds
    pub latency_ms: u64,
    /// Whether this result came from a retry against a fallback node
    pub was_retry: bool,
}
95
/// Merged results from a scatter-gather query.
#[derive(Debug, Clone)]
pub struct MergedResults<T> {
    /// Combined results, ranked by score (descending), truncated to top-k
    pub results: Vec<T>,
    /// Number of shards queried
    pub shards_queried: usize,
    /// Number of shards that returned at least one result
    /// (see NOTE in `merge_similarity_results`)
    pub shards_succeeded: usize,
    /// End-to-end latency in milliseconds (max across shards, since shard
    /// queries run concurrently)
    pub total_latency_ms: u64,
    /// Per-shard latencies, keyed by shard ID
    pub shard_latencies: HashMap<u32, u64>,
}
110
/// Query router for distributed operations.
///
/// Stateless apart from `rr_counter`; safe to share behind `&self` since
/// node selection only performs atomic increments.
pub struct QueryRouter {
    /// Routing configuration
    config: RouterConfig,
    /// Shard manager: maps vector IDs to shards
    shard_manager: ShardManager,
    /// Cluster coordinator: node health and shard placement
    cluster: ClusterCoordinator,
    /// Round-robin counter, shared across all shards and strategies
    rr_counter: AtomicUsize,
    /// ID of the local node (used by `RoutingStrategy::PreferLocal`)
    local_node_id: String,
}
124
125impl QueryRouter {
126    /// Create a new query router
127    pub fn new(
128        config: RouterConfig,
129        shard_manager: ShardManager,
130        cluster: ClusterCoordinator,
131        local_node_id: String,
132    ) -> Self {
133        Self {
134            config,
135            shard_manager,
136            cluster,
137            rr_counter: AtomicUsize::new(0),
138            local_node_id,
139        }
140    }
141
142    /// Plan a point query (single vector lookup)
143    pub fn plan_point_query(&self, vector_id: &str) -> QueryPlan {
144        let assignment = self.shard_manager.get_shard(vector_id);
145        let targets = self.get_node_targets(assignment.shard_id);
146
147        let mut shard_targets = HashMap::new();
148        shard_targets.insert(assignment.shard_id, targets);
149
150        QueryPlan {
151            shard_targets,
152            total_shards: 1,
153            is_scatter: false,
154        }
155    }
156
157    /// Plan a scatter query (similarity search across all shards)
158    pub fn plan_scatter_query(&self) -> QueryPlan {
159        let shards = self.shard_manager.get_all_shards();
160        let mut shard_targets = HashMap::new();
161
162        for shard_id in &shards {
163            let targets = self.get_node_targets(*shard_id);
164            shard_targets.insert(*shard_id, targets);
165        }
166
167        QueryPlan {
168            shard_targets,
169            total_shards: shards.len(),
170            is_scatter: true,
171        }
172    }
173
174    /// Plan a batch query (multiple vectors)
175    pub fn plan_batch_query(&self, vector_ids: &[String]) -> QueryPlan {
176        let shard_batches = self.shard_manager.get_shards_batch(vector_ids);
177        let mut shard_targets = HashMap::new();
178
179        for shard_id in shard_batches.keys() {
180            let targets = self.get_node_targets(*shard_id);
181            shard_targets.insert(*shard_id, targets);
182        }
183
184        QueryPlan {
185            shard_targets,
186            total_shards: shard_batches.len(),
187            is_scatter: false,
188        }
189    }
190
191    /// Get target nodes for a shard based on routing strategy
192    fn get_node_targets(&self, shard_id: u32) -> NodeTarget {
193        let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);
194
195        if healthy_nodes.is_empty() {
196            // Return a placeholder for error handling
197            return NodeTarget {
198                primary: NodeInfo::new(
199                    format!("unavailable-{}", shard_id),
200                    "unavailable".to_string(),
201                    super::cluster::NodeRole::Replica,
202                ),
203                fallbacks: Vec::new(),
204                shard_id,
205            };
206        }
207
208        let (primary, fallbacks) = match self.config.strategy {
209            RoutingStrategy::RoundRobin => {
210                let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
211                let primary = healthy_nodes[idx].clone();
212                let fallbacks: Vec<_> = healthy_nodes
213                    .into_iter()
214                    .enumerate()
215                    .filter(|(i, _)| *i != idx)
216                    .map(|(_, n)| n)
217                    .collect();
218                (primary, fallbacks)
219            }
220            RoutingStrategy::LeastConnections => {
221                // Sort by active connections
222                let mut sorted = healthy_nodes.clone();
223                sorted.sort_by_key(|n| n.health.active_connections);
224                let primary = sorted.remove(0);
225                (primary, sorted)
226            }
227            RoutingStrategy::Random => {
228                // Simple pseudo-random selection
229                let idx = (std::time::SystemTime::now()
230                    .duration_since(std::time::UNIX_EPOCH)
231                    .unwrap_or_default()
232                    .as_nanos() as usize)
233                    % healthy_nodes.len();
234                let primary = healthy_nodes[idx].clone();
235                let fallbacks: Vec<_> = healthy_nodes
236                    .into_iter()
237                    .enumerate()
238                    .filter(|(i, _)| *i != idx)
239                    .map(|(_, n)| n)
240                    .collect();
241                (primary, fallbacks)
242            }
243            RoutingStrategy::PreferLocal => {
244                // Check if local node serves this shard
245                let local = healthy_nodes
246                    .iter()
247                    .find(|n| n.node_id == self.local_node_id);
248                if let Some(local_node) = local {
249                    let primary = local_node.clone();
250                    let fallbacks: Vec<_> = healthy_nodes
251                        .into_iter()
252                        .filter(|n| n.node_id != self.local_node_id)
253                        .collect();
254                    (primary, fallbacks)
255                } else {
256                    // Fall back to round-robin
257                    let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
258                    let primary = healthy_nodes[idx].clone();
259                    let fallbacks: Vec<_> = healthy_nodes
260                        .into_iter()
261                        .enumerate()
262                        .filter(|(i, _)| *i != idx)
263                        .map(|(_, n)| n)
264                        .collect();
265                    (primary, fallbacks)
266                }
267            }
268            RoutingStrategy::PrimaryOnly => {
269                // Find primary node for this shard
270                let primary_node = self.cluster.get_primary_for_shard(shard_id);
271                if let Some(primary) = primary_node {
272                    let fallbacks: Vec<_> = healthy_nodes
273                        .into_iter()
274                        .filter(|n| n.node_id != primary.node_id)
275                        .collect();
276                    (primary, fallbacks)
277                } else {
278                    // No primary available, use first healthy node
279                    let primary = healthy_nodes[0].clone();
280                    let fallbacks = healthy_nodes.into_iter().skip(1).collect();
281                    (primary, fallbacks)
282                }
283            }
284        };
285
286        NodeTarget {
287            primary,
288            fallbacks,
289            shard_id,
290        }
291    }
292
293    /// Merge results from multiple shards for similarity search
294    pub fn merge_similarity_results<T: Clone>(
295        &self,
296        shard_results: Vec<ShardResult<T>>,
297        top_k: usize,
298        score_fn: impl Fn(&T) -> f32,
299    ) -> MergedResults<T> {
300        let shards_queried = shard_results.len();
301        let shards_succeeded = shard_results
302            .iter()
303            .filter(|r| !r.results.is_empty())
304            .count();
305
306        let mut shard_latencies = HashMap::new();
307        let mut total_latency = 0u64;
308
309        // Collect all results with scores
310        let mut all_results: Vec<(T, f32)> = Vec::new();
311
312        for shard_result in shard_results {
313            shard_latencies.insert(shard_result.shard_id, shard_result.latency_ms);
314            total_latency = total_latency.max(shard_result.latency_ms);
315
316            for result in shard_result.results {
317                let score = score_fn(&result);
318                all_results.push((result, score));
319            }
320        }
321
322        // Sort by score (descending for similarity)
323        all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
324
325        // Take top-k
326        let results: Vec<T> = all_results
327            .into_iter()
328            .take(top_k)
329            .map(|(r, _)| r)
330            .collect();
331
332        MergedResults {
333            results,
334            shards_queried,
335            shards_succeeded,
336            total_latency_ms: total_latency,
337            shard_latencies,
338        }
339    }
340
341    /// Get routing statistics
342    pub fn get_stats(&self) -> RouterStats {
343        let state = self.cluster.get_state();
344        let partitions = self.shard_manager.get_partition_info();
345
346        RouterStats {
347            total_nodes: state.total_node_count,
348            healthy_nodes: state.healthy_node_count,
349            total_shards: partitions.len() as u32,
350            healthy_shards: partitions.iter().filter(|p| p.is_healthy).count() as u32,
351            cluster_healthy: state.is_healthy,
352            has_quorum: state.has_quorum,
353        }
354    }
355
356    // =========================================================================
357    // Consistency-aware routing methods (Turbopuffer-inspired)
358    // =========================================================================
359
360    /// Plan a scatter query with consistency requirements
361    pub fn plan_scatter_query_with_consistency(
362        &self,
363        consistency: ReadConsistency,
364        staleness_config: Option<StalenessConfig>,
365    ) -> QueryPlan {
366        let shards = self.shard_manager.get_all_shards();
367        let mut shard_targets = HashMap::new();
368
369        for shard_id in &shards {
370            let targets =
371                self.get_node_targets_with_consistency(*shard_id, consistency, staleness_config);
372            shard_targets.insert(*shard_id, targets);
373        }
374
375        QueryPlan {
376            shard_targets,
377            total_shards: shards.len(),
378            is_scatter: true,
379        }
380    }
381
382    /// Get target nodes for a shard based on consistency requirements
383    fn get_node_targets_with_consistency(
384        &self,
385        shard_id: u32,
386        consistency: ReadConsistency,
387        staleness_config: Option<StalenessConfig>,
388    ) -> NodeTarget {
389        let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);
390
391        if healthy_nodes.is_empty() {
392            return NodeTarget {
393                primary: NodeInfo::new(
394                    format!("unavailable-{}", shard_id),
395                    "unavailable".to_string(),
396                    super::cluster::NodeRole::Replica,
397                ),
398                fallbacks: Vec::new(),
399                shard_id,
400            };
401        }
402
403        match consistency {
404            ReadConsistency::Strong => {
405                // Strong consistency: always route to primary
406                self.get_primary_target(shard_id, healthy_nodes)
407            }
408            ReadConsistency::Eventual => {
409                // Eventual: use default routing strategy for best latency
410                self.get_node_targets(shard_id)
411            }
412            ReadConsistency::BoundedStaleness => {
413                // Bounded staleness: select replicas within staleness window
414                let max_staleness_ms = staleness_config.map(|c| c.max_staleness_ms).unwrap_or(5000);
415                self.get_bounded_staleness_target(shard_id, healthy_nodes, max_staleness_ms)
416            }
417        }
418    }
419
420    /// Get primary node target for strong consistency
421    fn get_primary_target(&self, shard_id: u32, healthy_nodes: Vec<NodeInfo>) -> NodeTarget {
422        let primary_node = self.cluster.get_primary_for_shard(shard_id);
423        if let Some(primary) = primary_node {
424            let fallbacks: Vec<_> = healthy_nodes
425                .into_iter()
426                .filter(|n| n.node_id != primary.node_id)
427                .collect();
428            NodeTarget {
429                primary,
430                fallbacks,
431                shard_id,
432            }
433        } else {
434            // No primary available, use first healthy node
435            let primary = healthy_nodes[0].clone();
436            let fallbacks = healthy_nodes.into_iter().skip(1).collect();
437            NodeTarget {
438                primary,
439                fallbacks,
440                shard_id,
441            }
442        }
443    }
444
445    /// Get node target within staleness bounds
446    fn get_bounded_staleness_target(
447        &self,
448        shard_id: u32,
449        healthy_nodes: Vec<NodeInfo>,
450        max_staleness_ms: u64,
451    ) -> NodeTarget {
452        // Filter nodes to those within staleness bounds
453        // Nodes track their replication lag via health metrics
454        let eligible_nodes: Vec<_> = healthy_nodes
455            .iter()
456            .filter(|n| {
457                // Check replication lag is within bounds
458                // If no lag info, assume it's acceptable for bounded reads
459                n.health.replication_lag_ms.unwrap_or(0) <= max_staleness_ms
460            })
461            .cloned()
462            .collect();
463
464        if eligible_nodes.is_empty() {
465            // Fall back to primary if no nodes meet staleness requirement
466            return self.get_primary_target(shard_id, healthy_nodes);
467        }
468
469        // Among eligible nodes, use round-robin for load balancing
470        let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % eligible_nodes.len();
471        let primary = eligible_nodes[idx].clone();
472        let fallbacks: Vec<_> = eligible_nodes
473            .into_iter()
474            .enumerate()
475            .filter(|(i, _)| *i != idx)
476            .map(|(_, n)| n)
477            .collect();
478
479        NodeTarget {
480            primary,
481            fallbacks,
482            shard_id,
483        }
484    }
485
486    /// Convert ReadConsistency to effective RoutingStrategy
487    pub fn consistency_to_strategy(&self, consistency: ReadConsistency) -> RoutingStrategy {
488        match consistency {
489            ReadConsistency::Strong => RoutingStrategy::PrimaryOnly,
490            ReadConsistency::Eventual => self.config.strategy,
491            ReadConsistency::BoundedStaleness => RoutingStrategy::RoundRobin, // Among eligible replicas
492        }
493    }
494}
495
/// Statistics about routing state, snapshotted by `QueryRouter::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Total number of nodes registered in the cluster
    pub total_nodes: u32,
    /// Number of nodes currently reported healthy
    pub healthy_nodes: u32,
    /// Total number of shard partitions
    pub total_shards: u32,
    /// Number of shard partitions reported healthy
    pub healthy_shards: u32,
    /// Whether the cluster as a whole is healthy
    pub cluster_healthy: bool,
    /// Whether the cluster currently has quorum
    pub has_quorum: bool,
}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::cluster::{ClusterConfig, NodeRole};
    use crate::distributed::sharding::ShardingConfig;

    /// Build a 4-shard router over a 4-node cluster with replication factor 2.
    /// Node i serves shards i and (i + 1) % 4; node-0 is the primary, the
    /// local node id is "local" (which matches none of the registered nodes).
    fn setup_router() -> QueryRouter {
        let shard_config = ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        };
        let shard_manager = ShardManager::new(shard_config);

        let cluster_config = ClusterConfig::default();
        let cluster = ClusterCoordinator::new(cluster_config, "local".to_string());

        // Register some nodes
        for i in 0..4 {
            let mut node = NodeInfo::new(
                format!("node-{}", i),
                format!("localhost:{}", 8080 + i),
                if i == 0 {
                    NodeRole::Primary
                } else {
                    NodeRole::Replica
                },
            );
            node.shard_ids = vec![i as u32, (i + 1) as u32 % 4];
            node.health.status = super::super::cluster::NodeStatus::Healthy;
            cluster.register_node(node).unwrap();
        }

        let router_config = RouterConfig::default();
        QueryRouter::new(router_config, shard_manager, cluster, "local".to_string())
    }

    /// A point query targets exactly one shard and is not a scatter.
    #[test]
    fn test_point_query_plan() {
        let router = setup_router();
        let plan = router.plan_point_query("test-vector-123");

        assert_eq!(plan.total_shards, 1);
        assert!(!plan.is_scatter);
        assert_eq!(plan.shard_targets.len(), 1);
    }

    /// A scatter query fans out to every configured shard.
    #[test]
    fn test_scatter_query_plan() {
        let router = setup_router();
        let plan = router.plan_scatter_query();

        assert_eq!(plan.total_shards, 4);
        assert!(plan.is_scatter);
        assert_eq!(plan.shard_targets.len(), 4);
    }

    /// A batch query targets only the shards that own the requested IDs.
    #[test]
    fn test_batch_query_plan() {
        let router = setup_router();
        let ids: Vec<String> = (0..10).map(|i| format!("vec-{}", i)).collect();
        let plan = router.plan_batch_query(&ids);

        // Should have some shards (depends on hashing)
        assert!(plan.total_shards > 0);
        assert!(plan.total_shards <= 4);
        assert!(!plan.is_scatter);
    }

    /// Merging ranks all results by score (descending) and truncates to top-k.
    #[test]
    fn test_merge_results() {
        let router = setup_router();

        // Create mock shard results
        let shard_results = vec![
            ShardResult {
                shard_id: 0,
                served_by: "node-0".to_string(),
                results: vec![("a", 0.9), ("b", 0.7)],
                latency_ms: 10,
                was_retry: false,
            },
            ShardResult {
                shard_id: 1,
                served_by: "node-1".to_string(),
                results: vec![("c", 0.95), ("d", 0.6)],
                latency_ms: 15,
                was_retry: false,
            },
        ];

        let merged = router.merge_similarity_results(shard_results, 3, |(_id, score)| *score);

        assert_eq!(merged.results.len(), 3);
        assert_eq!(merged.shards_queried, 2);
        assert_eq!(merged.shards_succeeded, 2);

        // Results should be sorted by score
        assert_eq!(merged.results[0].0, "c"); // 0.95
        assert_eq!(merged.results[1].0, "a"); // 0.9
        assert_eq!(merged.results[2].0, "b"); // 0.7
    }

    /// Stats reflect the 4-node, 4-shard setup from `setup_router`.
    #[test]
    fn test_router_stats() {
        let router = setup_router();
        let stats = router.get_stats();

        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.total_shards, 4);
        assert!(stats.cluster_healthy);
    }
}