Skip to main content

query_router/
distributed.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Distributed query execution with semantic routing and scatter-gather.
3//!
4//! Routes queries to appropriate shards based on key or embedding similarity.
5//! Supports multi-shard scatter-gather for table scans and similarity searches.
6
7use std::sync::Arc;
8
9use relational_engine::{Row, Value};
10use serde::{Deserialize, Serialize};
11use tensor_store::{PartitionResult, Partitioner, SemanticPartitioner};
12
13use crate::{QueryResult, Result, SimilarResult};
14
/// Index of a shard within the cluster.
///
/// The planner in this module derives it from a node's position in the
/// partitioner's node list.
pub type ShardId = usize;
17
18/// Query execution plan.
19#[derive(Debug, Clone)]
20pub enum QueryPlan {
21    /// Execute locally on this node.
22    Local { query: String },
23    /// Forward to a remote shard.
24    Remote { shard: ShardId, query: String },
25    /// Scatter to multiple shards and gather results.
26    ScatterGather {
27        shards: Vec<ShardId>,
28        query: String,
29        merge: MergeStrategy,
30    },
31}
32
33/// Strategy for merging results from multiple shards.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum MergeStrategy {
36    /// Union all results (for SELECT, NODE queries).
37    Union,
38    /// Keep top K by similarity score (for SIMILAR queries).
39    TopK(usize),
40    /// Aggregate results (SUM, COUNT, AVG).
41    Aggregate(AggregateFunction),
42    /// First non-empty result (for point lookups).
43    FirstNonEmpty,
44    /// Concatenate all results in order.
45    Concat,
46}
47
48/// Aggregate function for distributed aggregation.
49#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
50pub enum AggregateFunction {
51    /// Sum all values.
52    Sum,
53    /// Count all results.
54    Count,
55    /// Average (sum/count).
56    Avg,
57    /// Maximum value.
58    Max,
59    /// Minimum value.
60    Min,
61}
62
63/// Result from a single shard.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ShardResult {
66    /// Shard that produced this result.
67    pub shard: ShardId,
68    /// Query result from the shard.
69    pub result: QueryResult,
70    /// Execution time in microseconds.
71    pub execution_time_us: u64,
72    /// Whether this shard had any errors.
73    pub error: Option<String>,
74}
75
76impl ShardResult {
77    /// Create a successful shard result.
78    #[must_use]
79    pub const fn success(shard: ShardId, result: QueryResult, execution_time_us: u64) -> Self {
80        Self {
81            shard,
82            result,
83            execution_time_us,
84            error: None,
85        }
86    }
87
88    /// Create an error shard result.
89    #[must_use]
90    pub const fn error(shard: ShardId, error: String) -> Self {
91        Self {
92            shard,
93            result: QueryResult::Empty,
94            execution_time_us: 0,
95            error: Some(error),
96        }
97    }
98}
99
/// Tunables for distributed query execution.
#[derive(Debug, Clone)]
pub struct DistributedQueryConfig {
    /// Upper bound on shard queries in flight at once.
    pub max_concurrent: usize,
    /// Per-shard query timeout, in milliseconds.
    pub shard_timeout_ms: u64,
    /// Number of retries for a failed shard.
    pub retry_count: usize,
    /// Abort on the first shard error instead of continuing.
    pub fail_fast: bool,
}

impl Default for DistributedQueryConfig {
    /// Defaults: 10 concurrent shard queries, 5000 ms timeout, 2 retries, no fail-fast.
    fn default() -> Self {
        Self {
            fail_fast: false,
            retry_count: 2,
            shard_timeout_ms: 5_000,
            max_concurrent: 10,
        }
    }
}
123
124/// Query planner for distributed execution.
125#[derive(Debug)]
126pub struct QueryPlanner {
127    /// Partitioner for routing decisions.
128    partitioner: Arc<dyn Partitioner + Send + Sync>,
129    /// Semantic partitioner for embedding-based shard routing.
130    semantic_partitioner: Option<Arc<SemanticPartitioner>>,
131    /// Local shard ID (reserved for future optimizations like local-first execution).
132    #[allow(dead_code)]
133    local_shard: ShardId,
134}
135
136impl QueryPlanner {
137    /// Create a new query planner.
138    pub fn new(partitioner: Arc<dyn Partitioner + Send + Sync>, local_shard: ShardId) -> Self {
139        Self {
140            partitioner,
141            semantic_partitioner: None,
142            local_shard,
143        }
144    }
145
146    /// Set the semantic partitioner for embedding-based routing.
147    #[must_use]
148    pub fn with_semantic_partitioner(mut self, partitioner: Arc<SemanticPartitioner>) -> Self {
149        self.semantic_partitioner = Some(partitioner);
150        self
151    }
152
153    /// Plan query execution.
154    #[must_use]
155    pub fn plan(&self, query: &str) -> QueryPlan {
156        // Parse query to determine routing
157        let query_type = Self::classify_query(query);
158
159        match query_type {
160            QueryType::PointLookup { key } => {
161                let result = self.partitioner.partition(&key);
162                if result.is_local {
163                    QueryPlan::Local {
164                        query: query.to_string(),
165                    }
166                } else {
167                    QueryPlan::Remote {
168                        shard: self.shard_from_result(&result),
169                        query: query.to_string(),
170                    }
171                }
172            },
173            QueryType::SimilaritySearch { k } => {
174                // Scatter to all shards, merge top K
175                QueryPlan::ScatterGather {
176                    shards: self.all_shards(),
177                    query: query.to_string(),
178                    merge: MergeStrategy::TopK(k),
179                }
180            },
181            QueryType::TableScan => {
182                // Scatter to all shards, union results
183                QueryPlan::ScatterGather {
184                    shards: self.all_shards(),
185                    query: query.to_string(),
186                    merge: MergeStrategy::Union,
187                }
188            },
189            QueryType::Aggregate { func } => {
190                // Scatter to all shards, aggregate
191                QueryPlan::ScatterGather {
192                    shards: self.all_shards(),
193                    query: query.to_string(),
194                    merge: MergeStrategy::Aggregate(func),
195                }
196            },
197            QueryType::Unknown => {
198                // Default to local execution
199                QueryPlan::Local {
200                    query: query.to_string(),
201                }
202            },
203        }
204    }
205
206    /// Plan query with explicit embedding for semantic routing.
207    #[must_use]
208    pub fn plan_with_embedding(&self, query: &str, embedding: &[f32]) -> QueryPlan {
209        // Get semantically relevant shards
210        let relevant_shards = self.shards_for_embedding(embedding);
211
212        if relevant_shards.is_empty() {
213            // Fall back to all shards
214            return self.plan(query);
215        }
216
217        let query_type = Self::classify_query(query);
218
219        match query_type {
220            QueryType::SimilaritySearch { k } => QueryPlan::ScatterGather {
221                shards: relevant_shards,
222                query: query.to_string(),
223                merge: MergeStrategy::TopK(k),
224            },
225            _ => self.plan(query),
226        }
227    }
228
229    /// Get all shards in the cluster.
230    fn all_shards(&self) -> Vec<ShardId> {
231        let nodes = self.partitioner.nodes();
232        (0..nodes.len()).collect()
233    }
234
235    /// Convert partition result to shard ID.
236    fn shard_from_result(&self, result: &PartitionResult) -> ShardId {
237        let nodes = self.partitioner.nodes();
238        nodes.iter().position(|n| *n == result.primary).unwrap_or(0)
239    }
240
241    /// Get shards relevant to an embedding using semantic routing.
242    fn shards_for_embedding(&self, embedding: &[f32]) -> Vec<ShardId> {
243        if let Some(sp) = &self.semantic_partitioner {
244            let results = sp.shards_for_embedding(embedding);
245            if !results.is_empty() {
246                return results.into_iter().map(|(shard, _score)| shard).collect();
247            }
248        }
249        // Fall back to all shards if no semantic partitioner or no results
250        self.all_shards()
251    }
252
253    /// Classify query type for routing.
254    fn classify_query(query: &str) -> QueryType {
255        let query_upper = query.to_uppercase();
256        let query_trimmed = query_upper.trim();
257
258        // Point lookups
259        if query_trimmed.starts_with("GET ")
260            || query_trimmed.starts_with("NODE GET ")
261            || query_trimmed.starts_with("ENTITY GET ")
262        {
263            // Extract key from query
264            if let Some(key) = Self::extract_key(query) {
265                return QueryType::PointLookup { key };
266            }
267        }
268
269        // Similarity search
270        if query_trimmed.starts_with("SIMILAR ") {
271            let k = Self::extract_top_k(query).unwrap_or(10);
272            return QueryType::SimilaritySearch { k };
273        }
274
275        // Table scans
276        if query_trimmed.starts_with("SELECT ") || query_trimmed.starts_with("NODE LIST") {
277            // Check for aggregates
278            if query_trimmed.contains("COUNT(") {
279                return QueryType::Aggregate {
280                    func: AggregateFunction::Count,
281                };
282            }
283            if query_trimmed.contains("SUM(") {
284                return QueryType::Aggregate {
285                    func: AggregateFunction::Sum,
286                };
287            }
288            if query_trimmed.contains("AVG(") {
289                return QueryType::Aggregate {
290                    func: AggregateFunction::Avg,
291                };
292            }
293            return QueryType::TableScan;
294        }
295
296        QueryType::Unknown
297    }
298
299    /// Extract key from a point lookup query.
300    fn extract_key(query: &str) -> Option<String> {
301        let parts: Vec<&str> = query.split_whitespace().collect();
302        if parts.len() >= 2 {
303            // Handle "GET key", "NODE GET key", etc.
304            for (i, part) in parts.iter().enumerate() {
305                if part.eq_ignore_ascii_case("GET") && i + 1 < parts.len() {
306                    return Some(parts[i + 1].to_string());
307                }
308            }
309        }
310        None
311    }
312
313    /// Extract TOP K value from query.
314    fn extract_top_k(query: &str) -> Option<usize> {
315        let query_upper = query.to_uppercase();
316        if let Some(pos) = query_upper.find("TOP ") {
317            let rest = &query_upper[pos + 4..];
318            let num_str: String = rest.chars().take_while(char::is_ascii_digit).collect();
319            return num_str.parse().ok();
320        }
321        None
322    }
323}
324
325/// Query type classification.
326#[derive(Debug)]
327enum QueryType {
328    /// Single key lookup.
329    PointLookup { key: String },
330    /// Similarity search with top K.
331    SimilaritySearch { k: usize },
332    /// Full table/entity scan.
333    TableScan,
334    /// Aggregate query.
335    Aggregate { func: AggregateFunction },
336    /// Unknown query type.
337    Unknown,
338}
339
/// Stateless helper that combines per-shard results into a single result.
#[derive(Debug)]
pub struct ResultMerger;
343
344impl ResultMerger {
345    /// Merge shard results using the specified strategy.
346    ///
347    /// # Errors
348    ///
349    /// This function currently never returns an error, but returns `Result` for
350    /// forward compatibility with future merge strategies that may fail.
351    pub fn merge(results: Vec<ShardResult>, strategy: &MergeStrategy) -> Result<QueryResult> {
352        // Filter out errors if not fail-fast
353        let successful: Vec<_> = results.into_iter().filter(|r| r.error.is_none()).collect();
354
355        if successful.is_empty() {
356            return Ok(QueryResult::Empty);
357        }
358
359        Ok(match strategy {
360            MergeStrategy::Union => Self::merge_union(successful),
361            MergeStrategy::TopK(k) => Self::merge_top_k(successful, *k),
362            MergeStrategy::Aggregate(func) => Self::merge_aggregate(successful, *func),
363            MergeStrategy::FirstNonEmpty => Self::merge_first_non_empty(successful),
364            MergeStrategy::Concat => Self::merge_concat(successful),
365        })
366    }
367
368    /// Merge results using union (combine all).
369    fn merge_union(results: Vec<ShardResult>) -> QueryResult {
370        let mut all_rows = Vec::new();
371        let mut all_nodes = Vec::new();
372        let mut all_edges = Vec::new();
373        let mut all_similar = Vec::new();
374
375        for shard_result in results {
376            match shard_result.result {
377                QueryResult::Rows(rows) => all_rows.extend(rows),
378                QueryResult::Nodes(nodes) => all_nodes.extend(nodes),
379                QueryResult::Edges(edges) => all_edges.extend(edges),
380                QueryResult::Similar(similar) => all_similar.extend(similar),
381                QueryResult::Count(n) => {
382                    // Safety: usize to i64 wraps on 64-bit if n > i64::MAX, but count
383                    // values are expected to be within reasonable bounds
384                    #[allow(clippy::cast_possible_wrap)]
385                    let count_val = n as i64;
386                    all_rows.push(Row {
387                        id: 0,
388                        values: vec![("count".to_string(), Value::Int(count_val))],
389                    });
390                },
391                _ => {},
392            }
393        }
394
395        // Return appropriate type based on what we collected
396        if !all_similar.is_empty() {
397            return QueryResult::Similar(all_similar);
398        }
399        if !all_nodes.is_empty() {
400            return QueryResult::Nodes(all_nodes);
401        }
402        if !all_edges.is_empty() {
403            return QueryResult::Edges(all_edges);
404        }
405        if !all_rows.is_empty() {
406            return QueryResult::Rows(all_rows);
407        }
408
409        QueryResult::Empty
410    }
411
412    /// Merge similarity results keeping top K.
413    fn merge_top_k(results: Vec<ShardResult>, k: usize) -> QueryResult {
414        let mut all_similar: Vec<SimilarResult> = Vec::new();
415
416        for shard_result in results {
417            if let QueryResult::Similar(similar) = shard_result.result {
418                all_similar.extend(similar);
419            }
420        }
421
422        // Sort by score descending
423        all_similar.sort_by(|a, b| {
424            b.score
425                .partial_cmp(&a.score)
426                .unwrap_or(std::cmp::Ordering::Equal)
427        });
428
429        // Take top K
430        all_similar.truncate(k);
431
432        QueryResult::Similar(all_similar)
433    }
434
435    /// Merge using aggregate function.
436    fn merge_aggregate(results: Vec<ShardResult>, func: AggregateFunction) -> QueryResult {
437        let mut values: Vec<i64> = Vec::new();
438
439        for shard_result in results {
440            match shard_result.result {
441                QueryResult::Count(n) => {
442                    // Safety: usize to i64 wraps on 64-bit if n > i64::MAX, but count
443                    // values are expected to be within reasonable bounds
444                    #[allow(clippy::cast_possible_wrap)]
445                    let count_val = n as i64;
446                    values.push(count_val);
447                },
448                QueryResult::Value(s) => {
449                    if let Ok(n) = s.parse::<i64>() {
450                        values.push(n);
451                    }
452                },
453                _ => {},
454            }
455        }
456
457        if values.is_empty() {
458            return QueryResult::Count(0);
459        }
460
461        // Safety: i64 to usize casts below may truncate on 32-bit systems or lose sign,
462        // but aggregate results are expected to be non-negative and within usize range.
463        // The len() to i64 cast may wrap if len > i64::MAX, but this is unrealistic.
464        #[allow(
465            clippy::cast_possible_truncation,
466            clippy::cast_sign_loss,
467            clippy::cast_possible_wrap
468        )]
469        let result = match func {
470            AggregateFunction::Sum | AggregateFunction::Count => {
471                values.iter().sum::<i64>() as usize
472            },
473            AggregateFunction::Max => *values.iter().max().unwrap_or(&0) as usize,
474            AggregateFunction::Min => *values.iter().min().unwrap_or(&0) as usize,
475            AggregateFunction::Avg => (values.iter().sum::<i64>() / (values.len() as i64)) as usize,
476        };
477
478        QueryResult::Count(result)
479    }
480
481    /// Return first non-empty result.
482    fn merge_first_non_empty(results: Vec<ShardResult>) -> QueryResult {
483        for shard_result in results {
484            if !matches!(&shard_result.result, QueryResult::Empty) {
485                return shard_result.result;
486            }
487        }
488        QueryResult::Empty
489    }
490
491    /// Concatenate all results in order.
492    fn merge_concat(results: Vec<ShardResult>) -> QueryResult {
493        // Same as union for most types
494        Self::merge_union(results)
495    }
496}
497
/// Running counters describing distributed query execution.
#[derive(Debug, Clone, Default)]
pub struct DistributedQueryStats {
    /// Number of queries recorded so far.
    pub queries_executed: u64,
    /// Queries that ran on the local node only.
    pub local_queries: u64,
    /// Queries forwarded to a single remote shard.
    pub remote_queries: u64,
    /// Queries fanned out to several shards.
    pub scatter_gather_queries: u64,
    /// Sum of the shards contacted by all recorded queries.
    pub shards_contacted: u64,
    /// Running mean latency in microseconds (integer arithmetic).
    pub avg_latency_us: u64,
    /// Total shard errors reported with the recorded queries.
    pub shard_errors: u64,
}
516
517impl DistributedQueryStats {
518    /// Record a query execution.
519    pub const fn record_query(&mut self, plan: &QueryPlan, latency_us: u64, errors: usize) {
520        self.queries_executed += 1;
521
522        match plan {
523            QueryPlan::Local { .. } => {
524                self.local_queries += 1;
525                self.shards_contacted += 1;
526            },
527            QueryPlan::Remote { .. } => {
528                self.remote_queries += 1;
529                self.shards_contacted += 1;
530            },
531            QueryPlan::ScatterGather { shards, .. } => {
532                self.scatter_gather_queries += 1;
533                self.shards_contacted += shards.len() as u64;
534            },
535        }
536
537        self.shard_errors += errors as u64;
538
539        // Update average latency
540        if self.queries_executed == 1 {
541            self.avg_latency_us = latency_us;
542        } else {
543            self.avg_latency_us = (self.avg_latency_us * (self.queries_executed - 1) + latency_us)
544                / self.queries_executed;
545        }
546    }
547}
548
549#[cfg(test)]
550mod tests {
551    use tensor_store::{ConsistentHashConfig, ConsistentHashPartitioner};
552
553    use super::*;
554
555    fn create_test_partitioner() -> Arc<dyn Partitioner + Send + Sync> {
556        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
557        let mut partitioner = ConsistentHashPartitioner::new(config);
558        partitioner.add_node("node1".to_string());
559        partitioner.add_node("node2".to_string());
560        partitioner.add_node("node3".to_string());
561        Arc::new(partitioner)
562    }
563
564    #[test]
565    fn test_query_plan_local() {
566        let partitioner = create_test_partitioner();
567        let planner = QueryPlanner::new(partitioner, 0);
568
569        let plan = planner.plan("GET some_key");
570        assert!(
571            matches!(plan, QueryPlan::Local { .. } | QueryPlan::Remote { .. }),
572            "Expected Local or Remote plan"
573        );
574    }
575
576    #[test]
577    fn test_query_plan_scatter_gather() {
578        let partitioner = create_test_partitioner();
579        let planner = QueryPlanner::new(partitioner, 0);
580
581        let plan = planner.plan("SELECT users");
582        assert!(
583            matches!(
584                plan,
585                QueryPlan::ScatterGather {
586                    merge: MergeStrategy::Union,
587                    ..
588                }
589            ),
590            "Expected ScatterGather with Union merge"
591        );
592    }
593
594    #[test]
595    fn test_query_plan_similar() {
596        let partitioner = create_test_partitioner();
597        let planner = QueryPlanner::new(partitioner, 0);
598
599        let plan = planner.plan("SIMILAR key TOP 5");
600        assert!(
601            matches!(
602                plan,
603                QueryPlan::ScatterGather {
604                    merge: MergeStrategy::TopK(5),
605                    ..
606                }
607            ),
608            "Expected ScatterGather with TopK(5) merge"
609        );
610    }
611
612    #[test]
613    fn test_query_plan_aggregate() {
614        let partitioner = create_test_partitioner();
615        let planner = QueryPlanner::new(partitioner, 0);
616
617        let plan = planner.plan("SELECT COUNT(*) FROM users");
618        assert!(
619            matches!(
620                plan,
621                QueryPlan::ScatterGather {
622                    merge: MergeStrategy::Aggregate(AggregateFunction::Count),
623                    ..
624                }
625            ),
626            "Expected ScatterGather with Count aggregate"
627        );
628    }
629
630    #[test]
631    fn test_merge_union() {
632        let results = vec![
633            ShardResult::success(0, QueryResult::Count(10), 100),
634            ShardResult::success(1, QueryResult::Count(20), 150),
635        ];
636
637        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
638        let QueryResult::Rows(rows) = merged else {
639            panic!("Expected Rows result");
640        };
641        assert_eq!(rows.len(), 2);
642    }
643
644    #[test]
645    fn test_merge_top_k() {
646        let results = vec![
647            ShardResult::success(
648                0,
649                QueryResult::Similar(vec![
650                    SimilarResult {
651                        key: "a".to_string(),
652                        score: 0.9,
653                    },
654                    SimilarResult {
655                        key: "b".to_string(),
656                        score: 0.8,
657                    },
658                ]),
659                100,
660            ),
661            ShardResult::success(
662                1,
663                QueryResult::Similar(vec![SimilarResult {
664                    key: "c".to_string(),
665                    score: 0.95,
666                }]),
667                150,
668            ),
669        ];
670
671        let merged = ResultMerger::merge(results, &MergeStrategy::TopK(2)).unwrap();
672        match merged {
673            QueryResult::Similar(similar) => {
674                assert_eq!(similar.len(), 2);
675                assert_eq!(similar[0].key, "c"); // Highest score
676                assert_eq!(similar[1].key, "a");
677            },
678            _ => panic!("Expected Similar result"),
679        }
680    }
681
682    #[test]
683    fn test_merge_aggregate_sum() {
684        let results = vec![
685            ShardResult::success(0, QueryResult::Count(10), 100),
686            ShardResult::success(1, QueryResult::Count(20), 150),
687            ShardResult::success(2, QueryResult::Count(30), 200),
688        ];
689
690        let merged =
691            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
692                .unwrap();
693        match merged {
694            QueryResult::Count(n) => assert_eq!(n, 60),
695            _ => panic!("Expected Count result"),
696        }
697    }
698
699    #[test]
700    fn test_merge_aggregate_avg() {
701        let results = vec![
702            ShardResult::success(0, QueryResult::Count(10), 100),
703            ShardResult::success(1, QueryResult::Count(20), 150),
704            ShardResult::success(2, QueryResult::Count(30), 200),
705        ];
706
707        let merged =
708            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Avg))
709                .unwrap();
710        match merged {
711            QueryResult::Count(n) => assert_eq!(n, 20),
712            _ => panic!("Expected Count result"),
713        }
714    }
715
716    #[test]
717    fn test_merge_first_non_empty() {
718        let results = vec![
719            ShardResult::success(0, QueryResult::Empty, 100),
720            ShardResult::success(1, QueryResult::Value("found".to_string()), 150),
721            ShardResult::success(2, QueryResult::Value("also_found".to_string()), 200),
722        ];
723
724        let merged = ResultMerger::merge(results, &MergeStrategy::FirstNonEmpty).unwrap();
725        match merged {
726            QueryResult::Value(s) => assert_eq!(s, "found"),
727            _ => panic!("Expected Value result"),
728        }
729    }
730
731    #[test]
732    fn test_shard_result_success() {
733        let result = ShardResult::success(0, QueryResult::Count(10), 100);
734        assert_eq!(result.shard, 0);
735        assert!(result.error.is_none());
736        assert_eq!(result.execution_time_us, 100);
737    }
738
739    #[test]
740    fn test_shard_result_error() {
741        let result = ShardResult::error(1, "timeout".to_string());
742        assert_eq!(result.shard, 1);
743        assert!(result.error.is_some());
744        assert_eq!(result.error.unwrap(), "timeout");
745    }
746
747    #[test]
748    fn test_config_default() {
749        let config = DistributedQueryConfig::default();
750        assert_eq!(config.max_concurrent, 10);
751        assert_eq!(config.shard_timeout_ms, 5000);
752        assert_eq!(config.retry_count, 2);
753        assert!(!config.fail_fast);
754    }
755
756    #[test]
757    fn test_stats_record_local() {
758        let mut stats = DistributedQueryStats::default();
759        let plan = QueryPlan::Local {
760            query: "GET key".to_string(),
761        };
762
763        stats.record_query(&plan, 100, 0);
764
765        assert_eq!(stats.queries_executed, 1);
766        assert_eq!(stats.local_queries, 1);
767        assert_eq!(stats.shards_contacted, 1);
768        assert_eq!(stats.avg_latency_us, 100);
769    }
770
771    #[test]
772    fn test_stats_record_scatter_gather() {
773        let mut stats = DistributedQueryStats::default();
774        let plan = QueryPlan::ScatterGather {
775            shards: vec![0, 1, 2],
776            query: "SELECT users".to_string(),
777            merge: MergeStrategy::Union,
778        };
779
780        stats.record_query(&plan, 500, 1);
781
782        assert_eq!(stats.queries_executed, 1);
783        assert_eq!(stats.scatter_gather_queries, 1);
784        assert_eq!(stats.shards_contacted, 3);
785        assert_eq!(stats.shard_errors, 1);
786    }
787
788    #[test]
789    fn test_merge_empty_results() {
790        let results: Vec<ShardResult> = vec![];
791        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
792        assert!(matches!(merged, QueryResult::Empty));
793    }
794
795    #[test]
796    fn test_merge_filters_errors() {
797        let results = vec![
798            ShardResult::success(0, QueryResult::Count(10), 100),
799            ShardResult::error(1, "timeout".to_string()),
800        ];
801
802        let merged =
803            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
804                .unwrap();
805        match merged {
806            QueryResult::Count(n) => assert_eq!(n, 10), // Only successful shard
807            _ => panic!("Expected Count result"),
808        }
809    }
810
811    #[test]
812    fn test_planner_extract_key() {
813        // Test various GET formats
814        assert_eq!(
815            QueryPlanner::extract_key("GET mykey"),
816            Some("mykey".to_string())
817        );
818        assert_eq!(
819            QueryPlanner::extract_key("NODE GET user:123"),
820            Some("user:123".to_string())
821        );
822    }
823
824    #[test]
825    fn test_planner_extract_top_k() {
826        assert_eq!(QueryPlanner::extract_top_k("SIMILAR key TOP 5"), Some(5));
827        assert_eq!(
828            QueryPlanner::extract_top_k("SIMILAR key TOP 100"),
829            Some(100)
830        );
831        assert_eq!(QueryPlanner::extract_top_k("SIMILAR key"), None);
832    }
833
834    #[test]
835    fn test_aggregate_function_equality() {
836        assert_eq!(AggregateFunction::Sum, AggregateFunction::Sum);
837        assert_ne!(AggregateFunction::Sum, AggregateFunction::Count);
838    }
839
840    #[test]
841    fn test_all_shards() {
842        let partitioner = create_test_partitioner();
843        let planner = QueryPlanner::new(partitioner, 0);
844
845        let shards = planner.all_shards();
846        assert_eq!(shards.len(), 3);
847        assert_eq!(shards, vec![0, 1, 2]);
848    }
849
850    #[test]
851    fn test_plan_with_embedding() {
852        let partitioner = create_test_partitioner();
853        let planner = QueryPlanner::new(partitioner, 0);
854
855        let embedding = vec![1.0, 0.0, 0.0, 0.0];
856        let plan = planner.plan_with_embedding("SIMILAR key TOP 10", &embedding);
857
858        match plan {
859            QueryPlan::ScatterGather { .. } => {},
860            _ => panic!("Expected ScatterGather plan"),
861        }
862    }
863
864    #[test]
865    fn test_merge_max() {
866        let results = vec![
867            ShardResult::success(0, QueryResult::Count(10), 100),
868            ShardResult::success(1, QueryResult::Count(50), 150),
869            ShardResult::success(2, QueryResult::Count(30), 200),
870        ];
871
872        let merged =
873            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Max))
874                .unwrap();
875        match merged {
876            QueryResult::Count(n) => assert_eq!(n, 50),
877            _ => panic!("Expected Count result"),
878        }
879    }
880
881    #[test]
882    fn test_merge_min() {
883        let results = vec![
884            ShardResult::success(0, QueryResult::Count(10), 100),
885            ShardResult::success(1, QueryResult::Count(50), 150),
886            ShardResult::success(2, QueryResult::Count(30), 200),
887        ];
888
889        let merged =
890            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Min))
891                .unwrap();
892        match merged {
893            QueryResult::Count(n) => assert_eq!(n, 10),
894            _ => panic!("Expected Count result"),
895        }
896    }
897
898    #[test]
899    fn test_stats_avg_latency_updates() {
900        let mut stats = DistributedQueryStats::default();
901        let plan = QueryPlan::Local {
902            query: "GET key".to_string(),
903        };
904
905        stats.record_query(&plan, 100, 0);
906        assert_eq!(stats.avg_latency_us, 100);
907
908        stats.record_query(&plan, 200, 0);
909        assert_eq!(stats.avg_latency_us, 150);
910    }
911
912    #[test]
913    fn test_merge_concat() {
914        let results = vec![
915            ShardResult::success(0, QueryResult::Count(10), 100),
916            ShardResult::success(1, QueryResult::Count(20), 150),
917        ];
918
919        let merged = ResultMerger::merge(results, &MergeStrategy::Concat).unwrap();
920        match merged {
921            QueryResult::Rows(rows) => assert_eq!(rows.len(), 2),
922            _ => panic!("Expected Rows result"),
923        }
924    }
925
926    #[test]
927    fn test_merge_union_nodes() {
928        use crate::NodeResult;
929
930        let results = vec![
931            ShardResult::success(
932                0,
933                QueryResult::Nodes(vec![
934                    NodeResult {
935                        id: 1,
936                        label: "Person".to_string(),
937                        properties: std::collections::HashMap::new(),
938                    },
939                    NodeResult {
940                        id: 2,
941                        label: "Person".to_string(),
942                        properties: std::collections::HashMap::new(),
943                    },
944                ]),
945                100,
946            ),
947            ShardResult::success(
948                1,
949                QueryResult::Nodes(vec![NodeResult {
950                    id: 3,
951                    label: "Person".to_string(),
952                    properties: std::collections::HashMap::new(),
953                }]),
954                150,
955            ),
956        ];
957
958        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
959        match merged {
960            QueryResult::Nodes(nodes) => assert_eq!(nodes.len(), 3),
961            _ => panic!("Expected Nodes result"),
962        }
963    }
964
965    #[test]
966    fn test_merge_union_edges() {
967        use crate::EdgeResult;
968
969        let results = vec![
970            ShardResult::success(
971                0,
972                QueryResult::Edges(vec![EdgeResult {
973                    id: 1,
974                    from: 1,
975                    to: 2,
976                    label: "KNOWS".to_string(),
977                }]),
978                100,
979            ),
980            ShardResult::success(
981                1,
982                QueryResult::Edges(vec![EdgeResult {
983                    id: 2,
984                    from: 2,
985                    to: 3,
986                    label: "KNOWS".to_string(),
987                }]),
988                150,
989            ),
990        ];
991
992        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
993        match merged {
994            QueryResult::Edges(edges) => assert_eq!(edges.len(), 2),
995            _ => panic!("Expected Edges result"),
996        }
997    }
998
999    #[test]
1000    fn test_merge_union_empty_all() {
1001        let results = vec![
1002            ShardResult::success(0, QueryResult::Empty, 100),
1003            ShardResult::success(1, QueryResult::Empty, 150),
1004        ];
1005
1006        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
1007        assert!(matches!(merged, QueryResult::Empty));
1008    }
1009
1010    #[test]
1011    fn test_merge_first_non_empty_all_empty() {
1012        let results = vec![
1013            ShardResult::success(0, QueryResult::Empty, 100),
1014            ShardResult::success(1, QueryResult::Empty, 150),
1015        ];
1016
1017        let merged = ResultMerger::merge(results, &MergeStrategy::FirstNonEmpty).unwrap();
1018        assert!(matches!(merged, QueryResult::Empty));
1019    }
1020
1021    #[test]
1022    fn test_merge_aggregate_value_strings() {
1023        let results = vec![
1024            ShardResult::success(0, QueryResult::Value("100".to_string()), 100),
1025            ShardResult::success(1, QueryResult::Value("200".to_string()), 150),
1026        ];
1027
1028        let merged =
1029            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
1030                .unwrap();
1031        match merged {
1032            QueryResult::Count(n) => assert_eq!(n, 300),
1033            _ => panic!("Expected Count result"),
1034        }
1035    }
1036
1037    #[test]
1038    fn test_merge_aggregate_empty_values() {
1039        let results = vec![
1040            ShardResult::success(0, QueryResult::Empty, 100),
1041            ShardResult::success(1, QueryResult::Empty, 150),
1042        ];
1043
1044        let merged =
1045            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
1046                .unwrap();
1047        match merged {
1048            QueryResult::Count(n) => assert_eq!(n, 0),
1049            _ => panic!("Expected Count result"),
1050        }
1051    }
1052
1053    #[test]
1054    fn test_query_plan_node_list() {
1055        let partitioner = create_test_partitioner();
1056        let planner = QueryPlanner::new(partitioner, 0);
1057
1058        let plan = planner.plan("NODE LIST users");
1059        match plan {
1060            QueryPlan::ScatterGather {
1061                merge: MergeStrategy::Union,
1062                ..
1063            } => {},
1064            _ => panic!("Expected ScatterGather with Union merge"),
1065        }
1066    }
1067
1068    #[test]
1069    fn test_query_plan_select_sum() {
1070        let partitioner = create_test_partitioner();
1071        let planner = QueryPlanner::new(partitioner, 0);
1072
1073        let plan = planner.plan("SELECT SUM(amount) FROM orders");
1074        match plan {
1075            QueryPlan::ScatterGather {
1076                merge: MergeStrategy::Aggregate(AggregateFunction::Sum),
1077                ..
1078            } => {},
1079            _ => panic!("Expected ScatterGather with Sum aggregate"),
1080        }
1081    }
1082
1083    #[test]
1084    fn test_query_plan_select_avg() {
1085        let partitioner = create_test_partitioner();
1086        let planner = QueryPlanner::new(partitioner, 0);
1087
1088        let plan = planner.plan("SELECT AVG(price) FROM products");
1089        match plan {
1090            QueryPlan::ScatterGather {
1091                merge: MergeStrategy::Aggregate(AggregateFunction::Avg),
1092                ..
1093            } => {},
1094            _ => panic!("Expected ScatterGather with Avg aggregate"),
1095        }
1096    }
1097
1098    #[test]
1099    fn test_query_plan_unknown() {
1100        let partitioner = create_test_partitioner();
1101        let planner = QueryPlanner::new(partitioner, 0);
1102
1103        // Unknown query type should default to local
1104        let plan = planner.plan("FOOBAR something");
1105        match plan {
1106            QueryPlan::Local { .. } => {},
1107            _ => panic!("Expected Local plan for unknown query"),
1108        }
1109    }
1110
1111    #[test]
1112    fn test_plan_with_embedding_non_similar() {
1113        let partitioner = create_test_partitioner();
1114        let planner = QueryPlanner::new(partitioner, 0);
1115
1116        let embedding = vec![1.0, 0.0, 0.0, 0.0];
1117        // Non-similarity query with embedding should fall back to plan()
1118        let plan = planner.plan_with_embedding("SELECT * FROM users", &embedding);
1119
1120        match plan {
1121            QueryPlan::ScatterGather { .. } => {},
1122            _ => panic!("Expected ScatterGather plan"),
1123        }
1124    }
1125
1126    #[test]
1127    fn test_extract_key_no_get() {
1128        // Query without GET keyword
1129        assert!(QueryPlanner::extract_key("something else").is_none());
1130    }
1131
1132    #[test]
1133    fn test_extract_key_empty() {
1134        // Empty query
1135        assert!(QueryPlanner::extract_key("").is_none());
1136    }
1137
1138    #[test]
1139    fn test_query_plan_node_get() {
1140        let partitioner = create_test_partitioner();
1141        let planner = QueryPlanner::new(partitioner, 0);
1142
1143        let plan = planner.plan("NODE GET user:123");
1144        match plan {
1145            QueryPlan::Local { .. } | QueryPlan::Remote { .. } => {},
1146            _ => panic!("Expected Local or Remote plan"),
1147        }
1148    }
1149
1150    #[test]
1151    fn test_query_plan_entity_get() {
1152        let partitioner = create_test_partitioner();
1153        let planner = QueryPlanner::new(partitioner, 0);
1154
1155        let plan = planner.plan("ENTITY GET entity:456");
1156        match plan {
1157            QueryPlan::Local { .. } | QueryPlan::Remote { .. } => {},
1158            _ => panic!("Expected Local or Remote plan"),
1159        }
1160    }
1161
1162    #[test]
1163    fn test_merge_top_k_non_similar_results() {
1164        // TopK merge with non-similar results should handle gracefully
1165        let results = vec![
1166            ShardResult::success(0, QueryResult::Empty, 100),
1167            ShardResult::success(1, QueryResult::Count(10), 150),
1168        ];
1169
1170        let merged = ResultMerger::merge(results, &MergeStrategy::TopK(5)).unwrap();
1171        match merged {
1172            QueryResult::Similar(similar) => assert!(similar.is_empty()),
1173            _ => panic!("Expected Similar result"),
1174        }
1175    }
1176
1177    #[test]
1178    fn test_merge_aggregate_avg_empty() {
1179        // Edge case: empty values in avg should return 0
1180        let results = vec![ShardResult::success(0, QueryResult::Rows(vec![]), 100)];
1181
1182        let merged =
1183            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Avg))
1184                .unwrap();
1185        match merged {
1186            QueryResult::Count(n) => assert_eq!(n, 0),
1187            _ => panic!("Expected Count result"),
1188        }
1189    }
1190
1191    #[test]
1192    fn test_query_plan_get_only_no_key() {
1193        let partitioner = create_test_partitioner();
1194        let planner = QueryPlanner::new(partitioner, 0);
1195
1196        // "GET" without a key should fall through to Unknown -> Local
1197        let plan = planner.plan("GET");
1198        match plan {
1199            QueryPlan::Local { .. } => {},
1200            _ => panic!("Expected Local plan for GET without key"),
1201        }
1202    }
1203
1204    #[test]
1205    fn test_query_plan_node_get_only() {
1206        let partitioner = create_test_partitioner();
1207        let planner = QueryPlanner::new(partitioner, 0);
1208
1209        // "NODE GET" without a key should fall through to Unknown -> Local
1210        let plan = planner.plan("NODE GET");
1211        match plan {
1212            QueryPlan::Local { .. } => {},
1213            _ => panic!("Expected Local plan for NODE GET without key"),
1214        }
1215    }
1216
1217    #[test]
1218    fn test_merge_union_other_result_types() {
1219        // Test with other QueryResult types that fall through to empty handling
1220        let results = vec![
1221            ShardResult::success(0, QueryResult::Path(vec![1, 2, 3]), 100),
1222            ShardResult::success(1, QueryResult::Value("test".to_string()), 150),
1223        ];
1224
1225        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
1226        // Non-row, non-node, non-edge, non-similar types fall through
1227        assert!(matches!(merged, QueryResult::Empty));
1228    }
1229
1230    #[test]
1231    fn test_stats_record_remote() {
1232        let mut stats = DistributedQueryStats::default();
1233        let plan = QueryPlan::Remote {
1234            shard: 1,
1235            query: "GET key".to_string(),
1236        };
1237
1238        stats.record_query(&plan, 100, 0);
1239
1240        assert_eq!(stats.queries_executed, 1);
1241        assert_eq!(stats.remote_queries, 1);
1242        assert_eq!(stats.shards_contacted, 1);
1243    }
1244
1245    #[test]
1246    fn test_extract_key_get_at_end() {
1247        // "GET" at end without following key
1248        assert!(QueryPlanner::extract_key("something GET").is_none());
1249    }
1250
1251    #[test]
1252    fn test_plan_with_embedding_empty_partitioner() {
1253        // Create partitioner with no nodes to trigger empty shards fallback
1254        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
1255        let partitioner = ConsistentHashPartitioner::new(config);
1256        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
1257        let planner = QueryPlanner::new(partitioner, 0);
1258
1259        let embedding = vec![1.0, 0.0, 0.0, 0.0];
1260        // With empty shards, should fall back to plan()
1261        let plan = planner.plan_with_embedding("SIMILAR key TOP 10", &embedding);
1262
1263        // Falls back to plan() which returns Local for unknown query when no shards
1264        match plan {
1265            QueryPlan::Local { .. } | QueryPlan::ScatterGather { .. } => {},
1266            _ => panic!("Expected Local or ScatterGather plan"),
1267        }
1268    }
1269
1270    #[test]
1271    fn test_all_shards_empty() {
1272        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
1273        let partitioner = ConsistentHashPartitioner::new(config);
1274        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
1275        let planner = QueryPlanner::new(partitioner, 0);
1276
1277        let shards = planner.all_shards();
1278        assert!(shards.is_empty());
1279    }
1280
1281    #[test]
1282    fn test_plan_select_with_empty_partitioner() {
1283        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
1284        let partitioner = ConsistentHashPartitioner::new(config);
1285        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
1286        let planner = QueryPlanner::new(partitioner, 0);
1287
1288        // SELECT with no shards
1289        let plan = planner.plan("SELECT * FROM users");
1290        match plan {
1291            QueryPlan::ScatterGather { shards, .. } => {
1292                assert!(shards.is_empty());
1293            },
1294            _ => panic!("Expected ScatterGather plan"),
1295        }
1296    }
1297
1298    #[test]
1299    fn test_get_with_trailing_space_no_key() {
1300        let partitioner = create_test_partitioner();
1301        let planner = QueryPlanner::new(partitioner, 0);
1302
1303        // "GET " with trailing space but no key - triggers the GET block
1304        // but extract_key returns None since split_whitespace gives ["GET"]
1305        let plan = planner.plan("GET ");
1306        match plan {
1307            QueryPlan::Local { .. } => {},
1308            _ => panic!("Expected Local plan for GET without key"),
1309        }
1310    }
1311
1312    #[test]
1313    fn test_merge_aggregate_unparseable_value() {
1314        // Test with Value string that cannot be parsed as i64
1315        let results = vec![
1316            ShardResult::success(0, QueryResult::Value("not_a_number".to_string()), 100),
1317            ShardResult::success(1, QueryResult::Count(100), 150),
1318        ];
1319
1320        let merged =
1321            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
1322                .unwrap();
1323        match merged {
1324            QueryResult::Count(n) => assert_eq!(n, 100), // Only the Count value is used
1325            _ => panic!("Expected Count result"),
1326        }
1327    }
1328
1329    #[test]
1330    fn test_node_get_trailing_space() {
1331        let partitioner = create_test_partitioner();
1332        let planner = QueryPlanner::new(partitioner, 0);
1333
1334        // "NODE GET " triggers the block but extract_key fails
1335        let plan = planner.plan("NODE GET ");
1336        match plan {
1337            QueryPlan::Local { .. } => {},
1338            _ => panic!("Expected Local plan"),
1339        }
1340    }
1341
1342    #[test]
1343    fn test_debug_impls() {
1344        // Test Debug implementations for coverage
1345        let config = DistributedQueryConfig::default();
1346        let _ = format!("{:?}", config);
1347
1348        let plan_local = QueryPlan::Local {
1349            query: "test".to_string(),
1350        };
1351        let plan_remote = QueryPlan::Remote {
1352            shard: 0,
1353            query: "test".to_string(),
1354        };
1355        let plan_scatter = QueryPlan::ScatterGather {
1356            shards: vec![0, 1],
1357            query: "test".to_string(),
1358            merge: MergeStrategy::Union,
1359        };
1360        let _ = format!("{:?}", plan_local);
1361        let _ = format!("{:?}", plan_remote);
1362        let _ = format!("{:?}", plan_scatter);
1363
1364        let _ = format!("{:?}", MergeStrategy::TopK(10));
1365        let _ = format!("{:?}", MergeStrategy::Aggregate(AggregateFunction::Count));
1366        let _ = format!("{:?}", MergeStrategy::FirstNonEmpty);
1367        let _ = format!("{:?}", MergeStrategy::Concat);
1368
1369        let _ = format!("{:?}", AggregateFunction::Max);
1370        let _ = format!("{:?}", AggregateFunction::Min);
1371
1372        let result = ShardResult::success(0, QueryResult::Empty, 100);
1373        let _ = format!("{:?}", result);
1374
1375        let stats = DistributedQueryStats::default();
1376        let _ = format!("{:?}", stats);
1377    }
1378
1379    #[test]
1380    fn test_shard_result_clone() {
1381        let result = ShardResult::success(0, QueryResult::Count(10), 100);
1382        let cloned = result.clone();
1383        assert_eq!(cloned.shard, result.shard);
1384    }
1385
1386    #[test]
1387    fn test_config_clone() {
1388        let config = DistributedQueryConfig::default();
1389        let cloned = config.clone();
1390        assert_eq!(cloned.max_concurrent, config.max_concurrent);
1391    }
1392
1393    #[test]
1394    fn test_stats_clone() {
1395        let mut stats = DistributedQueryStats::default();
1396        stats.queries_executed = 10;
1397        let cloned = stats.clone();
1398        assert_eq!(cloned.queries_executed, 10);
1399    }
1400
1401    #[test]
1402    fn test_merge_strategy_clone() {
1403        let strategy = MergeStrategy::TopK(5);
1404        let cloned = strategy.clone();
1405        assert!(matches!(cloned, MergeStrategy::TopK(5)));
1406    }
1407
1408    #[test]
1409    fn test_aggregate_function_copy() {
1410        let func = AggregateFunction::Sum;
1411        let copied: AggregateFunction = func;
1412        assert_eq!(copied, AggregateFunction::Sum);
1413    }
1414
1415    #[test]
1416    fn test_query_plan_clone() {
1417        let plan = QueryPlan::Local {
1418            query: "test".to_string(),
1419        };
1420        let cloned = plan.clone();
1421        assert!(matches!(cloned, QueryPlan::Local { .. }));
1422    }
1423}