// dakera-engine 0.10.2
// Vector search engine for the Dakera AI memory platform
//! Query Routing for Distributed Dakera
//!
//! Handles routing of queries across shards:
//! - Scatter-gather for similarity search
//! - Load balancing across replicas
//! - Result merging and ranking
//! - Consistency-aware routing (Turbopuffer-inspired)

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use super::cluster::{ClusterCoordinator, NodeInfo};
use super::sharding::ShardManager;
use common::types::{ReadConsistency, StalenessConfig};

/// Configuration for query routing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig {
    /// Strategy used to pick a replica for each shard (see `RoutingStrategy`).
    pub strategy: RoutingStrategy,
    /// Maximum number of shard queries in flight at once during scatter-gather.
    /// NOTE(review): consumed by the query executor, not by the planning code
    /// in this module — confirm against the executor.
    pub max_concurrent_shards: usize,
    /// Per-shard query timeout, in milliseconds (executor-enforced).
    pub shard_timeout_ms: u64,
    /// Whether a failed shard query is retried against fallback nodes.
    pub retry_failed_shards: bool,
    /// Maximum retry attempts per shard (only relevant when
    /// `retry_failed_shards` is true).
    pub max_retries: u32,
}

impl Default for RouterConfig {
    fn default() -> Self {
        Self {
            strategy: RoutingStrategy::RoundRobin,
            max_concurrent_shards: 10,
            shard_timeout_ms: 5000,
            retry_failed_shards: true,
            max_retries: 2,
        }
    }
}

/// Strategy for selecting among replica nodes
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Cycle through healthy replicas via a shared atomic counter.
    RoundRobin,
    /// Pick the replica with the fewest active connections (load-based).
    LeastConnections,
    /// Pseudo-random selection (clock-derived; not cryptographic).
    Random,
    /// Use the local node when it hosts the shard; otherwise round-robin.
    PreferLocal,
    /// Route to the shard's primary replica only (used for strong reads).
    PrimaryOnly,
}

/// Plan for executing a distributed query
#[derive(Debug, Clone)]
pub struct QueryPlan {
    /// Map from shard id to the node(s) chosen to serve that shard.
    pub shard_targets: HashMap<u32, NodeTarget>,
    /// Total number of shards this plan touches (equals `shard_targets.len()`).
    pub total_shards: usize,
    /// True for scatter-gather plans that fan out to every shard.
    pub is_scatter: bool,
}

/// Target node for a shard query
#[derive(Debug, Clone)]
pub struct NodeTarget {
    /// Node to try first. May be a synthetic `unavailable-<shard>` placeholder
    /// when the shard has no healthy replicas (see `get_node_targets`).
    pub primary: NodeInfo,
    /// Alternative nodes to try, in order, if the primary fails.
    pub fallbacks: Vec<NodeInfo>,
    /// Shard this target serves.
    pub shard_id: u32,
}

/// Result from a single shard
#[derive(Debug, Clone)]
pub struct ShardResult<T> {
    /// Shard that produced these results.
    pub shard_id: u32,
    /// Id of the node that actually served the request.
    pub served_by: String,
    /// Raw (unmerged) results from this shard.
    pub results: Vec<T>,
    /// Time the shard took to answer, in milliseconds.
    pub latency_ms: u64,
    /// True if this response came from a retry/fallback attempt.
    pub was_retry: bool,
}

/// Merged results from scatter-gather
#[derive(Debug, Clone)]
pub struct MergedResults<T> {
    /// Combined results, ranked by score and truncated to top-k.
    pub results: Vec<T>,
    /// Number of shards that were queried.
    pub shards_queried: usize,
    /// Number of shards that returned at least one result. Per
    /// `merge_similarity_results`, a shard with an empty result set is
    /// counted as failed — there is no explicit success flag.
    pub shards_succeeded: usize,
    /// End-to-end latency: the slowest shard's latency, not the sum
    /// (shards are queried in parallel).
    pub total_latency_ms: u64,
    /// Latency of each shard, keyed by shard id.
    pub shard_latencies: HashMap<u32, u64>,
}

/// Query router for distributed operations
/// Query router for distributed operations: plans which nodes serve which
/// shards and merges scatter-gather results.
pub struct QueryRouter {
    /// Routing configuration (strategy, timeouts, retry policy).
    config: RouterConfig,
    /// Maps vector ids to shards.
    shard_manager: ShardManager,
    /// Source of node health and primary/replica membership.
    cluster: ClusterCoordinator,
    /// Shared counter backing round-robin selection (index taken modulo the
    /// candidate count; wrapping overflow is harmless).
    rr_counter: AtomicUsize,
    /// This node's id, used by the `PreferLocal` strategy.
    local_node_id: String,
}

impl QueryRouter {
    /// Create a new query router
    pub fn new(
        config: RouterConfig,
        shard_manager: ShardManager,
        cluster: ClusterCoordinator,
        local_node_id: String,
    ) -> Self {
        Self {
            config,
            shard_manager,
            cluster,
            rr_counter: AtomicUsize::new(0),
            local_node_id,
        }
    }

    /// Plan a point query (single vector lookup)
    pub fn plan_point_query(&self, vector_id: &str) -> QueryPlan {
        let assignment = self.shard_manager.get_shard(vector_id);
        let targets = self.get_node_targets(assignment.shard_id);

        let mut shard_targets = HashMap::new();
        shard_targets.insert(assignment.shard_id, targets);

        QueryPlan {
            shard_targets,
            total_shards: 1,
            is_scatter: false,
        }
    }

    /// Plan a scatter query (similarity search across all shards)
    pub fn plan_scatter_query(&self) -> QueryPlan {
        let shards = self.shard_manager.get_all_shards();
        let mut shard_targets = HashMap::new();

        for shard_id in &shards {
            let targets = self.get_node_targets(*shard_id);
            shard_targets.insert(*shard_id, targets);
        }

        QueryPlan {
            shard_targets,
            total_shards: shards.len(),
            is_scatter: true,
        }
    }

    /// Plan a batch query (multiple vectors)
    pub fn plan_batch_query(&self, vector_ids: &[String]) -> QueryPlan {
        let shard_batches = self.shard_manager.get_shards_batch(vector_ids);
        let mut shard_targets = HashMap::new();

        for shard_id in shard_batches.keys() {
            let targets = self.get_node_targets(*shard_id);
            shard_targets.insert(*shard_id, targets);
        }

        QueryPlan {
            shard_targets,
            total_shards: shard_batches.len(),
            is_scatter: false,
        }
    }

    /// Get target nodes for a shard based on routing strategy
    fn get_node_targets(&self, shard_id: u32) -> NodeTarget {
        let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);

        if healthy_nodes.is_empty() {
            // Return a placeholder for error handling
            return NodeTarget {
                primary: NodeInfo::new(
                    format!("unavailable-{}", shard_id),
                    "unavailable".to_string(),
                    super::cluster::NodeRole::Replica,
                ),
                fallbacks: Vec::new(),
                shard_id,
            };
        }

        let (primary, fallbacks) = match self.config.strategy {
            RoutingStrategy::RoundRobin => {
                let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
                let primary = healthy_nodes[idx].clone();
                let fallbacks: Vec<_> = healthy_nodes
                    .into_iter()
                    .enumerate()
                    .filter(|(i, _)| *i != idx)
                    .map(|(_, n)| n)
                    .collect();
                (primary, fallbacks)
            }
            RoutingStrategy::LeastConnections => {
                // Sort by active connections
                let mut sorted = healthy_nodes.clone();
                sorted.sort_by_key(|n| n.health.active_connections);
                let primary = sorted.remove(0);
                (primary, sorted)
            }
            RoutingStrategy::Random => {
                // Simple pseudo-random selection
                let idx = (std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_nanos() as usize)
                    % healthy_nodes.len();
                let primary = healthy_nodes[idx].clone();
                let fallbacks: Vec<_> = healthy_nodes
                    .into_iter()
                    .enumerate()
                    .filter(|(i, _)| *i != idx)
                    .map(|(_, n)| n)
                    .collect();
                (primary, fallbacks)
            }
            RoutingStrategy::PreferLocal => {
                // Check if local node serves this shard
                let local = healthy_nodes
                    .iter()
                    .find(|n| n.node_id == self.local_node_id);
                if let Some(local_node) = local {
                    let primary = local_node.clone();
                    let fallbacks: Vec<_> = healthy_nodes
                        .into_iter()
                        .filter(|n| n.node_id != self.local_node_id)
                        .collect();
                    (primary, fallbacks)
                } else {
                    // Fall back to round-robin
                    let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_nodes.len();
                    let primary = healthy_nodes[idx].clone();
                    let fallbacks: Vec<_> = healthy_nodes
                        .into_iter()
                        .enumerate()
                        .filter(|(i, _)| *i != idx)
                        .map(|(_, n)| n)
                        .collect();
                    (primary, fallbacks)
                }
            }
            RoutingStrategy::PrimaryOnly => {
                // Find primary node for this shard
                let primary_node = self.cluster.get_primary_for_shard(shard_id);
                if let Some(primary) = primary_node {
                    let fallbacks: Vec<_> = healthy_nodes
                        .into_iter()
                        .filter(|n| n.node_id != primary.node_id)
                        .collect();
                    (primary, fallbacks)
                } else {
                    // No primary available, use first healthy node
                    let primary = healthy_nodes[0].clone();
                    let fallbacks = healthy_nodes.into_iter().skip(1).collect();
                    (primary, fallbacks)
                }
            }
        };

        NodeTarget {
            primary,
            fallbacks,
            shard_id,
        }
    }

    /// Merge results from multiple shards for similarity search
    pub fn merge_similarity_results<T: Clone>(
        &self,
        shard_results: Vec<ShardResult<T>>,
        top_k: usize,
        score_fn: impl Fn(&T) -> f32,
    ) -> MergedResults<T> {
        let shards_queried = shard_results.len();
        let shards_succeeded = shard_results
            .iter()
            .filter(|r| !r.results.is_empty())
            .count();

        let mut shard_latencies = HashMap::new();
        let mut total_latency = 0u64;

        // Collect all results with scores
        let mut all_results: Vec<(T, f32)> = Vec::new();

        for shard_result in shard_results {
            shard_latencies.insert(shard_result.shard_id, shard_result.latency_ms);
            total_latency = total_latency.max(shard_result.latency_ms);

            for result in shard_result.results {
                let score = score_fn(&result);
                all_results.push((result, score));
            }
        }

        // Sort by score (descending for similarity)
        all_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take top-k
        let results: Vec<T> = all_results
            .into_iter()
            .take(top_k)
            .map(|(r, _)| r)
            .collect();

        MergedResults {
            results,
            shards_queried,
            shards_succeeded,
            total_latency_ms: total_latency,
            shard_latencies,
        }
    }

    /// Get routing statistics
    pub fn get_stats(&self) -> RouterStats {
        let state = self.cluster.get_state();
        let partitions = self.shard_manager.get_partition_info();

        RouterStats {
            total_nodes: state.total_node_count,
            healthy_nodes: state.healthy_node_count,
            total_shards: partitions.len() as u32,
            healthy_shards: partitions.iter().filter(|p| p.is_healthy).count() as u32,
            cluster_healthy: state.is_healthy,
            has_quorum: state.has_quorum,
        }
    }

    // =========================================================================
    // Consistency-aware routing methods (Turbopuffer-inspired)
    // =========================================================================

    /// Plan a scatter query with consistency requirements
    pub fn plan_scatter_query_with_consistency(
        &self,
        consistency: ReadConsistency,
        staleness_config: Option<StalenessConfig>,
    ) -> QueryPlan {
        let shards = self.shard_manager.get_all_shards();
        let mut shard_targets = HashMap::new();

        for shard_id in &shards {
            let targets =
                self.get_node_targets_with_consistency(*shard_id, consistency, staleness_config);
            shard_targets.insert(*shard_id, targets);
        }

        QueryPlan {
            shard_targets,
            total_shards: shards.len(),
            is_scatter: true,
        }
    }

    /// Get target nodes for a shard based on consistency requirements
    fn get_node_targets_with_consistency(
        &self,
        shard_id: u32,
        consistency: ReadConsistency,
        staleness_config: Option<StalenessConfig>,
    ) -> NodeTarget {
        let healthy_nodes = self.cluster.get_healthy_nodes_for_shard(shard_id);

        if healthy_nodes.is_empty() {
            return NodeTarget {
                primary: NodeInfo::new(
                    format!("unavailable-{}", shard_id),
                    "unavailable".to_string(),
                    super::cluster::NodeRole::Replica,
                ),
                fallbacks: Vec::new(),
                shard_id,
            };
        }

        match consistency {
            ReadConsistency::Strong => {
                // Strong consistency: always route to primary
                self.get_primary_target(shard_id, healthy_nodes)
            }
            ReadConsistency::Eventual => {
                // Eventual: use default routing strategy for best latency
                self.get_node_targets(shard_id)
            }
            ReadConsistency::BoundedStaleness => {
                // Bounded staleness: select replicas within staleness window
                let max_staleness_ms = staleness_config.map(|c| c.max_staleness_ms).unwrap_or(5000);
                self.get_bounded_staleness_target(shard_id, healthy_nodes, max_staleness_ms)
            }
        }
    }

    /// Get primary node target for strong consistency
    fn get_primary_target(&self, shard_id: u32, healthy_nodes: Vec<NodeInfo>) -> NodeTarget {
        let primary_node = self.cluster.get_primary_for_shard(shard_id);
        if let Some(primary) = primary_node {
            let fallbacks: Vec<_> = healthy_nodes
                .into_iter()
                .filter(|n| n.node_id != primary.node_id)
                .collect();
            NodeTarget {
                primary,
                fallbacks,
                shard_id,
            }
        } else {
            // No primary available, use first healthy node
            let primary = healthy_nodes[0].clone();
            let fallbacks = healthy_nodes.into_iter().skip(1).collect();
            NodeTarget {
                primary,
                fallbacks,
                shard_id,
            }
        }
    }

    /// Get node target within staleness bounds
    fn get_bounded_staleness_target(
        &self,
        shard_id: u32,
        healthy_nodes: Vec<NodeInfo>,
        max_staleness_ms: u64,
    ) -> NodeTarget {
        // Filter nodes to those within staleness bounds
        // Nodes track their replication lag via health metrics
        let eligible_nodes: Vec<_> = healthy_nodes
            .iter()
            .filter(|n| {
                // Check replication lag is within bounds
                // If no lag info, assume it's acceptable for bounded reads
                n.health.replication_lag_ms.unwrap_or(0) <= max_staleness_ms
            })
            .cloned()
            .collect();

        if eligible_nodes.is_empty() {
            // Fall back to primary if no nodes meet staleness requirement
            return self.get_primary_target(shard_id, healthy_nodes);
        }

        // Among eligible nodes, use round-robin for load balancing
        let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % eligible_nodes.len();
        let primary = eligible_nodes[idx].clone();
        let fallbacks: Vec<_> = eligible_nodes
            .into_iter()
            .enumerate()
            .filter(|(i, _)| *i != idx)
            .map(|(_, n)| n)
            .collect();

        NodeTarget {
            primary,
            fallbacks,
            shard_id,
        }
    }

    /// Convert ReadConsistency to effective RoutingStrategy
    pub fn consistency_to_strategy(&self, consistency: ReadConsistency) -> RoutingStrategy {
        match consistency {
            ReadConsistency::Strong => RoutingStrategy::PrimaryOnly,
            ReadConsistency::Eventual => self.config.strategy,
            ReadConsistency::BoundedStaleness => RoutingStrategy::RoundRobin, // Among eligible replicas
        }
    }
}

/// Statistics about routing state
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterStats {
    /// Total number of nodes registered in the cluster.
    pub total_nodes: u32,
    /// Number of nodes currently reported healthy.
    pub healthy_nodes: u32,
    /// Total number of shard partitions.
    pub total_shards: u32,
    /// Number of partitions currently reported healthy.
    pub healthy_shards: u32,
    /// Whether the cluster as a whole is healthy.
    pub cluster_healthy: bool,
    /// Whether the cluster currently holds quorum.
    pub has_quorum: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::cluster::{ClusterConfig, NodeRole};
    use crate::distributed::sharding::ShardingConfig;

    /// Build a router over 4 shards with 4 healthy nodes; node-0 is primary,
    /// and each node hosts two consecutive shards (i and i+1 mod 4).
    fn setup_router() -> QueryRouter {
        let shard_config = ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        };
        let shard_manager = ShardManager::new(shard_config);

        let cluster_config = ClusterConfig::default();
        let cluster = ClusterCoordinator::new(cluster_config, "local".to_string());

        // Register some nodes
        for i in 0..4 {
            let mut node = NodeInfo::new(
                format!("node-{}", i),
                format!("localhost:{}", 8080 + i),
                if i == 0 {
                    NodeRole::Primary
                } else {
                    NodeRole::Replica
                },
            );
            node.shard_ids = vec![i as u32, (i + 1) as u32 % 4];
            node.health.status = super::super::cluster::NodeStatus::Healthy;
            cluster.register_node(node).unwrap();
        }

        let router_config = RouterConfig::default();
        QueryRouter::new(router_config, shard_manager, cluster, "local".to_string())
    }

    /// A point lookup should touch exactly one shard and not be a scatter plan.
    #[test]
    fn test_point_query_plan() {
        let router = setup_router();
        let plan = router.plan_point_query("test-vector-123");

        assert_eq!(plan.total_shards, 1);
        assert!(!plan.is_scatter);
        assert_eq!(plan.shard_targets.len(), 1);
    }

    /// A scatter plan should fan out to all 4 configured shards.
    #[test]
    fn test_scatter_query_plan() {
        let router = setup_router();
        let plan = router.plan_scatter_query();

        assert_eq!(plan.total_shards, 4);
        assert!(plan.is_scatter);
        assert_eq!(plan.shard_targets.len(), 4);
    }

    /// A batch plan targets only the shards that own the hashed ids.
    #[test]
    fn test_batch_query_plan() {
        let router = setup_router();
        let ids: Vec<String> = (0..10).map(|i| format!("vec-{}", i)).collect();
        let plan = router.plan_batch_query(&ids);

        // Should have some shards (depends on hashing)
        assert!(plan.total_shards > 0);
        assert!(plan.total_shards <= 4);
        assert!(!plan.is_scatter);
    }

    /// Merging ranks results globally by score (descending) and truncates to top-k.
    #[test]
    fn test_merge_results() {
        let router = setup_router();

        // Create mock shard results
        let shard_results = vec![
            ShardResult {
                shard_id: 0,
                served_by: "node-0".to_string(),
                results: vec![("a", 0.9), ("b", 0.7)],
                latency_ms: 10,
                was_retry: false,
            },
            ShardResult {
                shard_id: 1,
                served_by: "node-1".to_string(),
                results: vec![("c", 0.95), ("d", 0.6)],
                latency_ms: 15,
                was_retry: false,
            },
        ];

        let merged = router.merge_similarity_results(shard_results, 3, |(_id, score)| *score);

        assert_eq!(merged.results.len(), 3);
        assert_eq!(merged.shards_queried, 2);
        assert_eq!(merged.shards_succeeded, 2);

        // Results should be sorted by score
        assert_eq!(merged.results[0].0, "c"); // 0.95
        assert_eq!(merged.results[1].0, "a"); // 0.9
        assert_eq!(merged.results[2].0, "b"); // 0.7
    }

    /// Stats should reflect the 4 registered nodes and 4 shards from setup.
    #[test]
    fn test_router_stats() {
        let router = setup_router();
        let stats = router.get_stats();

        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.total_shards, 4);
        assert!(stats.cluster_healthy);
    }
}