1use std::sync::Arc;
8
9use relational_engine::{Row, Value};
10use serde::{Deserialize, Serialize};
11use tensor_store::{PartitionResult, Partitioner, SemanticPartitioner};
12
13use crate::{QueryResult, Result, SimilarResult};
14
/// Index of a shard. `QueryPlanner` numbers shards by their position in the
/// partitioner's node list (`0..nodes.len()`).
pub type ShardId = usize;
17
/// How a query is executed across the cluster, as produced by `QueryPlanner`.
#[derive(Debug, Clone)]
pub enum QueryPlan {
    /// Run the query on the local node only.
    Local { query: String },
    /// Forward the query to a single remote shard.
    Remote { shard: ShardId, query: String },
    /// Send the query to every listed shard and combine the per-shard results.
    ScatterGather {
        /// Shards to contact.
        shards: Vec<ShardId>,
        /// Query text as given to the planner.
        query: String,
        /// How the per-shard results are combined (see `ResultMerger::merge`).
        merge: MergeStrategy,
    },
}
32
/// How `ResultMerger::merge` combines results coming from several shards.
// NOTE: variant order is part of the serialized form for index-based serde
// formats; append new variants rather than reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MergeStrategy {
    /// Concatenate rows, nodes, edges and similarity hits from all shards.
    Union,
    /// Keep the `k` highest-scoring similarity hits across all shards.
    TopK(usize),
    /// Combine scalar per-shard partials with the given function.
    Aggregate(AggregateFunction),
    /// Return the first shard result that is not `QueryResult::Empty`.
    FirstNonEmpty,
    /// Currently merged exactly like `Union`.
    Concat,
}
47
/// Aggregate applied when merging per-shard partial results.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Sum of the partials.
    Sum,
    /// Total count; per-shard counts are summed, exactly like `Sum`.
    Count,
    /// Integer mean of the per-shard partials (each shard weighs equally).
    Avg,
    /// Largest partial.
    Max,
    /// Smallest partial.
    Min,
}
62
/// The outcome of running a query on one shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardResult {
    /// Shard that produced this result.
    pub shard: ShardId,
    /// Returned data; `QueryResult::Empty` for results built with `ShardResult::error`.
    pub result: QueryResult,
    /// Execution time on the shard (microseconds, per the `_us` suffix); 0 for errors.
    pub execution_time_us: u64,
    /// Failure description; results with an error are skipped by `ResultMerger::merge`.
    pub error: Option<String>,
}
75
76impl ShardResult {
77 #[must_use]
79 pub const fn success(shard: ShardId, result: QueryResult, execution_time_us: u64) -> Self {
80 Self {
81 shard,
82 result,
83 execution_time_us,
84 error: None,
85 }
86 }
87
88 #[must_use]
90 pub const fn error(shard: ShardId, error: String) -> Self {
91 Self {
92 shard,
93 result: QueryResult::Empty,
94 execution_time_us: 0,
95 error: Some(error),
96 }
97 }
98}
99
/// Tuning knobs for distributed query execution.
///
/// Nothing in this module reads these fields. The meanings below follow the
/// field names; confirm them against the executor that consumes this config.
#[derive(Debug, Clone)]
pub struct DistributedQueryConfig {
    /// Maximum number of shard requests in flight at once.
    pub max_concurrent: usize,
    /// Per-shard timeout, in milliseconds.
    pub shard_timeout_ms: u64,
    /// Number of retries for a failed shard request.
    pub retry_count: usize,
    /// Whether to abort the whole query on the first shard failure.
    pub fail_fast: bool,
}

impl Default for DistributedQueryConfig {
    /// 10 concurrent shards, a 5000 ms shard timeout, 2 retries, and
    /// `fail_fast` disabled.
    fn default() -> Self {
        const MAX_CONCURRENT: usize = 10;
        const SHARD_TIMEOUT_MS: u64 = 5_000;
        const RETRY_COUNT: usize = 2;

        Self {
            fail_fast: false,
            retry_count: RETRY_COUNT,
            shard_timeout_ms: SHARD_TIMEOUT_MS,
            max_concurrent: MAX_CONCURRENT,
        }
    }
}
123
/// Turns query strings into `QueryPlan`s using the cluster's partitioning.
#[derive(Debug)]
pub struct QueryPlanner {
    /// Key-to-node partitioner. Its node list also defines the shard numbering.
    partitioner: Arc<dyn Partitioner + Send + Sync>,
    /// Optional embedding-aware partitioner used to narrow similarity searches.
    semantic_partitioner: Option<Arc<SemanticPartitioner>>,
    // Stored at construction but not read by any planning logic yet, hence
    // the dead_code allowance.
    #[allow(dead_code)]
    local_shard: ShardId,
}
135
136impl QueryPlanner {
137 pub fn new(partitioner: Arc<dyn Partitioner + Send + Sync>, local_shard: ShardId) -> Self {
139 Self {
140 partitioner,
141 semantic_partitioner: None,
142 local_shard,
143 }
144 }
145
146 #[must_use]
148 pub fn with_semantic_partitioner(mut self, partitioner: Arc<SemanticPartitioner>) -> Self {
149 self.semantic_partitioner = Some(partitioner);
150 self
151 }
152
153 #[must_use]
155 pub fn plan(&self, query: &str) -> QueryPlan {
156 let query_type = Self::classify_query(query);
158
159 match query_type {
160 QueryType::PointLookup { key } => {
161 let result = self.partitioner.partition(&key);
162 if result.is_local {
163 QueryPlan::Local {
164 query: query.to_string(),
165 }
166 } else {
167 QueryPlan::Remote {
168 shard: self.shard_from_result(&result),
169 query: query.to_string(),
170 }
171 }
172 },
173 QueryType::SimilaritySearch { k } => {
174 QueryPlan::ScatterGather {
176 shards: self.all_shards(),
177 query: query.to_string(),
178 merge: MergeStrategy::TopK(k),
179 }
180 },
181 QueryType::TableScan => {
182 QueryPlan::ScatterGather {
184 shards: self.all_shards(),
185 query: query.to_string(),
186 merge: MergeStrategy::Union,
187 }
188 },
189 QueryType::Aggregate { func } => {
190 QueryPlan::ScatterGather {
192 shards: self.all_shards(),
193 query: query.to_string(),
194 merge: MergeStrategy::Aggregate(func),
195 }
196 },
197 QueryType::Unknown => {
198 QueryPlan::Local {
200 query: query.to_string(),
201 }
202 },
203 }
204 }
205
206 #[must_use]
208 pub fn plan_with_embedding(&self, query: &str, embedding: &[f32]) -> QueryPlan {
209 let relevant_shards = self.shards_for_embedding(embedding);
211
212 if relevant_shards.is_empty() {
213 return self.plan(query);
215 }
216
217 let query_type = Self::classify_query(query);
218
219 match query_type {
220 QueryType::SimilaritySearch { k } => QueryPlan::ScatterGather {
221 shards: relevant_shards,
222 query: query.to_string(),
223 merge: MergeStrategy::TopK(k),
224 },
225 _ => self.plan(query),
226 }
227 }
228
229 fn all_shards(&self) -> Vec<ShardId> {
231 let nodes = self.partitioner.nodes();
232 (0..nodes.len()).collect()
233 }
234
235 fn shard_from_result(&self, result: &PartitionResult) -> ShardId {
237 let nodes = self.partitioner.nodes();
238 nodes.iter().position(|n| *n == result.primary).unwrap_or(0)
239 }
240
241 fn shards_for_embedding(&self, embedding: &[f32]) -> Vec<ShardId> {
243 if let Some(sp) = &self.semantic_partitioner {
244 let results = sp.shards_for_embedding(embedding);
245 if !results.is_empty() {
246 return results.into_iter().map(|(shard, _score)| shard).collect();
247 }
248 }
249 self.all_shards()
251 }
252
253 fn classify_query(query: &str) -> QueryType {
255 let query_upper = query.to_uppercase();
256 let query_trimmed = query_upper.trim();
257
258 if query_trimmed.starts_with("GET ")
260 || query_trimmed.starts_with("NODE GET ")
261 || query_trimmed.starts_with("ENTITY GET ")
262 {
263 if let Some(key) = Self::extract_key(query) {
265 return QueryType::PointLookup { key };
266 }
267 }
268
269 if query_trimmed.starts_with("SIMILAR ") {
271 let k = Self::extract_top_k(query).unwrap_or(10);
272 return QueryType::SimilaritySearch { k };
273 }
274
275 if query_trimmed.starts_with("SELECT ") || query_trimmed.starts_with("NODE LIST") {
277 if query_trimmed.contains("COUNT(") {
279 return QueryType::Aggregate {
280 func: AggregateFunction::Count,
281 };
282 }
283 if query_trimmed.contains("SUM(") {
284 return QueryType::Aggregate {
285 func: AggregateFunction::Sum,
286 };
287 }
288 if query_trimmed.contains("AVG(") {
289 return QueryType::Aggregate {
290 func: AggregateFunction::Avg,
291 };
292 }
293 return QueryType::TableScan;
294 }
295
296 QueryType::Unknown
297 }
298
299 fn extract_key(query: &str) -> Option<String> {
301 let parts: Vec<&str> = query.split_whitespace().collect();
302 if parts.len() >= 2 {
303 for (i, part) in parts.iter().enumerate() {
305 if part.eq_ignore_ascii_case("GET") && i + 1 < parts.len() {
306 return Some(parts[i + 1].to_string());
307 }
308 }
309 }
310 None
311 }
312
313 fn extract_top_k(query: &str) -> Option<usize> {
315 let query_upper = query.to_uppercase();
316 if let Some(pos) = query_upper.find("TOP ") {
317 let rest = &query_upper[pos + 4..];
318 let num_str: String = rest.chars().take_while(char::is_ascii_digit).collect();
319 return num_str.parse().ok();
320 }
321 None
322 }
323}
324
/// Coarse classification of a query string, produced by `QueryPlanner::classify_query`.
#[derive(Debug)]
enum QueryType {
    /// Single-key lookup (`GET`, `NODE GET`, `ENTITY GET`).
    PointLookup { key: String },
    /// `SIMILAR` search returning the top `k` hits.
    SimilaritySearch { k: usize },
    /// `SELECT` / `NODE LIST` without a recognized aggregate call.
    TableScan,
    /// `SELECT` / `NODE LIST` containing an aggregate call.
    Aggregate { func: AggregateFunction },
    /// Anything not recognized above.
    Unknown,
}
339
/// Stateless helper that combines `ShardResult`s according to a `MergeStrategy`.
#[derive(Debug)]
pub struct ResultMerger;
343
344impl ResultMerger {
345 pub fn merge(results: Vec<ShardResult>, strategy: &MergeStrategy) -> Result<QueryResult> {
352 let successful: Vec<_> = results.into_iter().filter(|r| r.error.is_none()).collect();
354
355 if successful.is_empty() {
356 return Ok(QueryResult::Empty);
357 }
358
359 Ok(match strategy {
360 MergeStrategy::Union => Self::merge_union(successful),
361 MergeStrategy::TopK(k) => Self::merge_top_k(successful, *k),
362 MergeStrategy::Aggregate(func) => Self::merge_aggregate(successful, *func),
363 MergeStrategy::FirstNonEmpty => Self::merge_first_non_empty(successful),
364 MergeStrategy::Concat => Self::merge_concat(successful),
365 })
366 }
367
368 fn merge_union(results: Vec<ShardResult>) -> QueryResult {
370 let mut all_rows = Vec::new();
371 let mut all_nodes = Vec::new();
372 let mut all_edges = Vec::new();
373 let mut all_similar = Vec::new();
374
375 for shard_result in results {
376 match shard_result.result {
377 QueryResult::Rows(rows) => all_rows.extend(rows),
378 QueryResult::Nodes(nodes) => all_nodes.extend(nodes),
379 QueryResult::Edges(edges) => all_edges.extend(edges),
380 QueryResult::Similar(similar) => all_similar.extend(similar),
381 QueryResult::Count(n) => {
382 #[allow(clippy::cast_possible_wrap)]
385 let count_val = n as i64;
386 all_rows.push(Row {
387 id: 0,
388 values: vec![("count".to_string(), Value::Int(count_val))],
389 });
390 },
391 _ => {},
392 }
393 }
394
395 if !all_similar.is_empty() {
397 return QueryResult::Similar(all_similar);
398 }
399 if !all_nodes.is_empty() {
400 return QueryResult::Nodes(all_nodes);
401 }
402 if !all_edges.is_empty() {
403 return QueryResult::Edges(all_edges);
404 }
405 if !all_rows.is_empty() {
406 return QueryResult::Rows(all_rows);
407 }
408
409 QueryResult::Empty
410 }
411
412 fn merge_top_k(results: Vec<ShardResult>, k: usize) -> QueryResult {
414 let mut all_similar: Vec<SimilarResult> = Vec::new();
415
416 for shard_result in results {
417 if let QueryResult::Similar(similar) = shard_result.result {
418 all_similar.extend(similar);
419 }
420 }
421
422 all_similar.sort_by(|a, b| {
424 b.score
425 .partial_cmp(&a.score)
426 .unwrap_or(std::cmp::Ordering::Equal)
427 });
428
429 all_similar.truncate(k);
431
432 QueryResult::Similar(all_similar)
433 }
434
435 fn merge_aggregate(results: Vec<ShardResult>, func: AggregateFunction) -> QueryResult {
437 let mut values: Vec<i64> = Vec::new();
438
439 for shard_result in results {
440 match shard_result.result {
441 QueryResult::Count(n) => {
442 #[allow(clippy::cast_possible_wrap)]
445 let count_val = n as i64;
446 values.push(count_val);
447 },
448 QueryResult::Value(s) => {
449 if let Ok(n) = s.parse::<i64>() {
450 values.push(n);
451 }
452 },
453 _ => {},
454 }
455 }
456
457 if values.is_empty() {
458 return QueryResult::Count(0);
459 }
460
461 #[allow(
465 clippy::cast_possible_truncation,
466 clippy::cast_sign_loss,
467 clippy::cast_possible_wrap
468 )]
469 let result = match func {
470 AggregateFunction::Sum | AggregateFunction::Count => {
471 values.iter().sum::<i64>() as usize
472 },
473 AggregateFunction::Max => *values.iter().max().unwrap_or(&0) as usize,
474 AggregateFunction::Min => *values.iter().min().unwrap_or(&0) as usize,
475 AggregateFunction::Avg => (values.iter().sum::<i64>() / (values.len() as i64)) as usize,
476 };
477
478 QueryResult::Count(result)
479 }
480
481 fn merge_first_non_empty(results: Vec<ShardResult>) -> QueryResult {
483 for shard_result in results {
484 if !matches!(&shard_result.result, QueryResult::Empty) {
485 return shard_result.result;
486 }
487 }
488 QueryResult::Empty
489 }
490
491 fn merge_concat(results: Vec<ShardResult>) -> QueryResult {
493 Self::merge_union(results)
495 }
496}
497
/// Running counters for executed distributed queries, updated by `record_query`.
#[derive(Debug, Clone, Default)]
pub struct DistributedQueryStats {
    /// Total number of recorded queries.
    pub queries_executed: u64,
    /// Queries planned as `QueryPlan::Local`.
    pub local_queries: u64,
    /// Queries planned as `QueryPlan::Remote`.
    pub remote_queries: u64,
    /// Queries planned as `QueryPlan::ScatterGather`.
    pub scatter_gather_queries: u64,
    /// Sum, over all recorded queries, of the number of shards each plan targeted.
    pub shards_contacted: u64,
    /// Running mean of the recorded latencies (microseconds, per the `_us` suffix).
    pub avg_latency_us: u64,
    /// Total shard errors reported to `record_query`.
    pub shard_errors: u64,
}
516
517impl DistributedQueryStats {
518 pub const fn record_query(&mut self, plan: &QueryPlan, latency_us: u64, errors: usize) {
520 self.queries_executed += 1;
521
522 match plan {
523 QueryPlan::Local { .. } => {
524 self.local_queries += 1;
525 self.shards_contacted += 1;
526 },
527 QueryPlan::Remote { .. } => {
528 self.remote_queries += 1;
529 self.shards_contacted += 1;
530 },
531 QueryPlan::ScatterGather { shards, .. } => {
532 self.scatter_gather_queries += 1;
533 self.shards_contacted += shards.len() as u64;
534 },
535 }
536
537 self.shard_errors += errors as u64;
538
539 if self.queries_executed == 1 {
541 self.avg_latency_us = latency_us;
542 } else {
543 self.avg_latency_us = (self.avg_latency_us * (self.queries_executed - 1) + latency_us)
544 / self.queries_executed;
545 }
546 }
547}
548
// Unit tests for planning (`QueryPlanner`), merging (`ResultMerger`),
// statistics (`DistributedQueryStats`) and the derived trait impls.
#[cfg(test)]
mod tests {
    use tensor_store::{ConsistentHashConfig, ConsistentHashPartitioner};

    use super::*;

    // Three-node consistent-hash partitioner shared by most planner tests.
    fn create_test_partitioner() -> Arc<dyn Partitioner + Send + Sync> {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let mut partitioner = ConsistentHashPartitioner::new(config);
        partitioner.add_node("node1".to_string());
        partitioner.add_node("node2".to_string());
        partitioner.add_node("node3".to_string());
        Arc::new(partitioner)
    }

    // ---- Planning ----

    #[test]
    fn test_query_plan_local() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("GET some_key");
        // Which shard owns the key depends on hashing, so either is valid.
        assert!(
            matches!(plan, QueryPlan::Local { .. } | QueryPlan::Remote { .. }),
            "Expected Local or Remote plan"
        );
    }

    #[test]
    fn test_query_plan_scatter_gather() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("SELECT users");
        assert!(
            matches!(
                plan,
                QueryPlan::ScatterGather {
                    merge: MergeStrategy::Union,
                    ..
                }
            ),
            "Expected ScatterGather with Union merge"
        );
    }

    #[test]
    fn test_query_plan_similar() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("SIMILAR key TOP 5");
        assert!(
            matches!(
                plan,
                QueryPlan::ScatterGather {
                    merge: MergeStrategy::TopK(5),
                    ..
                }
            ),
            "Expected ScatterGather with TopK(5) merge"
        );
    }

    #[test]
    fn test_query_plan_aggregate() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("SELECT COUNT(*) FROM users");
        assert!(
            matches!(
                plan,
                QueryPlan::ScatterGather {
                    merge: MergeStrategy::Aggregate(AggregateFunction::Count),
                    ..
                }
            ),
            "Expected ScatterGather with Count aggregate"
        );
    }

    // ---- Merging ----

    #[test]
    fn test_merge_union() {
        // Count partials are turned into one "count" row each.
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        let QueryResult::Rows(rows) = merged else {
            panic!("Expected Rows result");
        };
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_merge_top_k() {
        let results = vec![
            ShardResult::success(
                0,
                QueryResult::Similar(vec![
                    SimilarResult {
                        key: "a".to_string(),
                        score: 0.9,
                    },
                    SimilarResult {
                        key: "b".to_string(),
                        score: 0.8,
                    },
                ]),
                100,
            ),
            ShardResult::success(
                1,
                QueryResult::Similar(vec![SimilarResult {
                    key: "c".to_string(),
                    score: 0.95,
                }]),
                150,
            ),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::TopK(2)).unwrap();
        match merged {
            QueryResult::Similar(similar) => {
                assert_eq!(similar.len(), 2);
                // Highest score first, across shards.
                assert_eq!(similar[0].key, "c");
                assert_eq!(similar[1].key, "a");
            },
            _ => panic!("Expected Similar result"),
        }
    }

    #[test]
    fn test_merge_aggregate_sum() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 60),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_aggregate_avg() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Avg))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 20),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_first_non_empty() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Value("found".to_string()), 150),
            ShardResult::success(2, QueryResult::Value("also_found".to_string()), 200),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::FirstNonEmpty).unwrap();
        match merged {
            QueryResult::Value(s) => assert_eq!(s, "found"),
            _ => panic!("Expected Value result"),
        }
    }

    // ---- Constructors and defaults ----

    #[test]
    fn test_shard_result_success() {
        let result = ShardResult::success(0, QueryResult::Count(10), 100);
        assert_eq!(result.shard, 0);
        assert!(result.error.is_none());
        assert_eq!(result.execution_time_us, 100);
    }

    #[test]
    fn test_shard_result_error() {
        let result = ShardResult::error(1, "timeout".to_string());
        assert_eq!(result.shard, 1);
        assert!(result.error.is_some());
        assert_eq!(result.error.unwrap(), "timeout");
    }

    #[test]
    fn test_config_default() {
        let config = DistributedQueryConfig::default();
        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.shard_timeout_ms, 5000);
        assert_eq!(config.retry_count, 2);
        assert!(!config.fail_fast);
    }

    // ---- Statistics ----

    #[test]
    fn test_stats_record_local() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::Local {
            query: "GET key".to_string(),
        };

        stats.record_query(&plan, 100, 0);

        assert_eq!(stats.queries_executed, 1);
        assert_eq!(stats.local_queries, 1);
        assert_eq!(stats.shards_contacted, 1);
        assert_eq!(stats.avg_latency_us, 100);
    }

    #[test]
    fn test_stats_record_scatter_gather() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::ScatterGather {
            shards: vec![0, 1, 2],
            query: "SELECT users".to_string(),
            merge: MergeStrategy::Union,
        };

        stats.record_query(&plan, 500, 1);

        assert_eq!(stats.queries_executed, 1);
        assert_eq!(stats.scatter_gather_queries, 1);
        assert_eq!(stats.shards_contacted, 3);
        assert_eq!(stats.shard_errors, 1);
    }

    #[test]
    fn test_merge_empty_results() {
        let results: Vec<ShardResult> = vec![];
        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_merge_filters_errors() {
        // The failed shard must not contribute to the sum.
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::error(1, "timeout".to_string()),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 10),
            _ => panic!("Expected Count result"),
        }
    }

    // ---- Query parsing helpers ----

    #[test]
    fn test_planner_extract_key() {
        assert_eq!(
            QueryPlanner::extract_key("GET mykey"),
            Some("mykey".to_string())
        );
        assert_eq!(
            QueryPlanner::extract_key("NODE GET user:123"),
            Some("user:123".to_string())
        );
    }

    #[test]
    fn test_planner_extract_top_k() {
        assert_eq!(QueryPlanner::extract_top_k("SIMILAR key TOP 5"), Some(5));
        assert_eq!(
            QueryPlanner::extract_top_k("SIMILAR key TOP 100"),
            Some(100)
        );
        assert_eq!(QueryPlanner::extract_top_k("SIMILAR key"), None);
    }

    #[test]
    fn test_aggregate_function_equality() {
        assert_eq!(AggregateFunction::Sum, AggregateFunction::Sum);
        assert_ne!(AggregateFunction::Sum, AggregateFunction::Count);
    }

    #[test]
    fn test_all_shards() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let shards = planner.all_shards();
        assert_eq!(shards.len(), 3);
        assert_eq!(shards, vec![0, 1, 2]);
    }

    #[test]
    fn test_plan_with_embedding() {
        // No semantic partitioner is attached, so all shards are searched.
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let embedding = vec![1.0, 0.0, 0.0, 0.0];
        let plan = planner.plan_with_embedding("SIMILAR key TOP 10", &embedding);

        match plan {
            QueryPlan::ScatterGather { .. } => {},
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_merge_max() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(50), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Max))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 50),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_min() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(50), 150),
            ShardResult::success(2, QueryResult::Count(30), 200),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Min))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 10),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_stats_avg_latency_updates() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::Local {
            query: "GET key".to_string(),
        };

        stats.record_query(&plan, 100, 0);
        assert_eq!(stats.avg_latency_us, 100);

        stats.record_query(&plan, 200, 0);
        assert_eq!(stats.avg_latency_us, 150);
    }

    #[test]
    fn test_merge_concat() {
        let results = vec![
            ShardResult::success(0, QueryResult::Count(10), 100),
            ShardResult::success(1, QueryResult::Count(20), 150),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::Concat).unwrap();
        match merged {
            QueryResult::Rows(rows) => assert_eq!(rows.len(), 2),
            _ => panic!("Expected Rows result"),
        }
    }

    #[test]
    fn test_merge_union_nodes() {
        use crate::NodeResult;

        let results = vec![
            ShardResult::success(
                0,
                QueryResult::Nodes(vec![
                    NodeResult {
                        id: 1,
                        label: "Person".to_string(),
                        properties: std::collections::HashMap::new(),
                    },
                    NodeResult {
                        id: 2,
                        label: "Person".to_string(),
                        properties: std::collections::HashMap::new(),
                    },
                ]),
                100,
            ),
            ShardResult::success(
                1,
                QueryResult::Nodes(vec![NodeResult {
                    id: 3,
                    label: "Person".to_string(),
                    properties: std::collections::HashMap::new(),
                }]),
                150,
            ),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        match merged {
            QueryResult::Nodes(nodes) => assert_eq!(nodes.len(), 3),
            _ => panic!("Expected Nodes result"),
        }
    }

    #[test]
    fn test_merge_union_edges() {
        use crate::EdgeResult;

        let results = vec![
            ShardResult::success(
                0,
                QueryResult::Edges(vec![EdgeResult {
                    id: 1,
                    from: 1,
                    to: 2,
                    label: "KNOWS".to_string(),
                }]),
                100,
            ),
            ShardResult::success(
                1,
                QueryResult::Edges(vec![EdgeResult {
                    id: 2,
                    from: 2,
                    to: 3,
                    label: "KNOWS".to_string(),
                }]),
                150,
            ),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        match merged {
            QueryResult::Edges(edges) => assert_eq!(edges.len(), 2),
            _ => panic!("Expected Edges result"),
        }
    }

    #[test]
    fn test_merge_union_empty_all() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Empty, 150),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_merge_first_non_empty_all_empty() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Empty, 150),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::FirstNonEmpty).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_merge_aggregate_value_strings() {
        // Numeric `Value` partials are parsed and aggregated like counts.
        let results = vec![
            ShardResult::success(0, QueryResult::Value("100".to_string()), 100),
            ShardResult::success(1, QueryResult::Value("200".to_string()), 150),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 300),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_merge_aggregate_empty_values() {
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Empty, 150),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 0),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_query_plan_node_list() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("NODE LIST users");
        match plan {
            QueryPlan::ScatterGather {
                merge: MergeStrategy::Union,
                ..
            } => {},
            _ => panic!("Expected ScatterGather with Union merge"),
        }
    }

    #[test]
    fn test_query_plan_select_sum() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("SELECT SUM(amount) FROM orders");
        match plan {
            QueryPlan::ScatterGather {
                merge: MergeStrategy::Aggregate(AggregateFunction::Sum),
                ..
            } => {},
            _ => panic!("Expected ScatterGather with Sum aggregate"),
        }
    }

    #[test]
    fn test_query_plan_select_avg() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("SELECT AVG(price) FROM products");
        match plan {
            QueryPlan::ScatterGather {
                merge: MergeStrategy::Aggregate(AggregateFunction::Avg),
                ..
            } => {},
            _ => panic!("Expected ScatterGather with Avg aggregate"),
        }
    }

    #[test]
    fn test_query_plan_unknown() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("FOOBAR something");
        // Unrecognized queries fall back to local execution.
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for unknown query"),
        }
    }

    #[test]
    fn test_plan_with_embedding_non_similar() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let embedding = vec![1.0, 0.0, 0.0, 0.0];
        let plan = planner.plan_with_embedding("SELECT * FROM users", &embedding);
        // Non-similarity queries are planned as by `plan`.

        match plan {
            QueryPlan::ScatterGather { .. } => {},
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_extract_key_no_get() {
        assert!(QueryPlanner::extract_key("something else").is_none());
    }

    #[test]
    fn test_extract_key_empty() {
        assert!(QueryPlanner::extract_key("").is_none());
    }

    #[test]
    fn test_query_plan_node_get() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("NODE GET user:123");
        match plan {
            QueryPlan::Local { .. } | QueryPlan::Remote { .. } => {},
            _ => panic!("Expected Local or Remote plan"),
        }
    }

    #[test]
    fn test_query_plan_entity_get() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("ENTITY GET entity:456");
        match plan {
            QueryPlan::Local { .. } | QueryPlan::Remote { .. } => {},
            _ => panic!("Expected Local or Remote plan"),
        }
    }

    #[test]
    fn test_merge_top_k_non_similar_results() {
        // Non-Similar results are ignored; TopK still yields a Similar list.
        let results = vec![
            ShardResult::success(0, QueryResult::Empty, 100),
            ShardResult::success(1, QueryResult::Count(10), 150),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::TopK(5)).unwrap();
        match merged {
            QueryResult::Similar(similar) => assert!(similar.is_empty()),
            _ => panic!("Expected Similar result"),
        }
    }

    #[test]
    fn test_merge_aggregate_avg_empty() {
        let results = vec![ShardResult::success(0, QueryResult::Rows(vec![]), 100)];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Avg))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 0),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_query_plan_get_only_no_key() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("GET");
        // A bare GET has no key, so it is not a point lookup.
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for GET without key"),
        }
    }

    #[test]
    fn test_query_plan_node_get_only() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("NODE GET");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for NODE GET without key"),
        }
    }

    #[test]
    fn test_merge_union_other_result_types() {
        // Path and Value results are not collected by Union.
        let results = vec![
            ShardResult::success(0, QueryResult::Path(vec![1, 2, 3]), 100),
            ShardResult::success(1, QueryResult::Value("test".to_string()), 150),
        ];

        let merged = ResultMerger::merge(results, &MergeStrategy::Union).unwrap();
        assert!(matches!(merged, QueryResult::Empty));
    }

    #[test]
    fn test_stats_record_remote() {
        let mut stats = DistributedQueryStats::default();
        let plan = QueryPlan::Remote {
            shard: 1,
            query: "GET key".to_string(),
        };

        stats.record_query(&plan, 100, 0);

        assert_eq!(stats.queries_executed, 1);
        assert_eq!(stats.remote_queries, 1);
        assert_eq!(stats.shards_contacted, 1);
    }

    #[test]
    fn test_extract_key_get_at_end() {
        assert!(QueryPlanner::extract_key("something GET").is_none());
    }

    // ---- Empty partitioner edge cases ----

    #[test]
    fn test_plan_with_embedding_empty_partitioner() {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let partitioner = ConsistentHashPartitioner::new(config);
        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
        let planner = QueryPlanner::new(partitioner, 0);

        let embedding = vec![1.0, 0.0, 0.0, 0.0];
        let plan = planner.plan_with_embedding("SIMILAR key TOP 10", &embedding);

        match plan {
            QueryPlan::Local { .. } | QueryPlan::ScatterGather { .. } => {},
            _ => panic!("Expected Local or ScatterGather plan"),
        }
    }

    #[test]
    fn test_all_shards_empty() {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let partitioner = ConsistentHashPartitioner::new(config);
        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
        let planner = QueryPlanner::new(partitioner, 0);

        let shards = planner.all_shards();
        assert!(shards.is_empty());
    }

    #[test]
    fn test_plan_select_with_empty_partitioner() {
        let config = ConsistentHashConfig::new("node1").with_virtual_nodes(10);
        let partitioner = ConsistentHashPartitioner::new(config);
        let partitioner: Arc<dyn Partitioner + Send + Sync> = Arc::new(partitioner);
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("SELECT * FROM users");
        match plan {
            QueryPlan::ScatterGather { shards, .. } => {
                assert!(shards.is_empty());
            },
            _ => panic!("Expected ScatterGather plan"),
        }
    }

    #[test]
    fn test_get_with_trailing_space_no_key() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        // Trimming removes the trailing space, so "GET " is not a lookup prefix.
        let plan = planner.plan("GET ");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan for GET without key"),
        }
    }

    #[test]
    fn test_merge_aggregate_unparseable_value() {
        // The non-numeric Value is skipped; only the Count contributes.
        let results = vec![
            ShardResult::success(0, QueryResult::Value("not_a_number".to_string()), 100),
            ShardResult::success(1, QueryResult::Count(100), 150),
        ];

        let merged =
            ResultMerger::merge(results, &MergeStrategy::Aggregate(AggregateFunction::Sum))
                .unwrap();
        match merged {
            QueryResult::Count(n) => assert_eq!(n, 100),
            _ => panic!("Expected Count result"),
        }
    }

    #[test]
    fn test_node_get_trailing_space() {
        let partitioner = create_test_partitioner();
        let planner = QueryPlanner::new(partitioner, 0);

        let plan = planner.plan("NODE GET ");
        match plan {
            QueryPlan::Local { .. } => {},
            _ => panic!("Expected Local plan"),
        }
    }

    // ---- Derived trait impls ----

    #[test]
    fn test_debug_impls() {
        let config = DistributedQueryConfig::default();
        let _ = format!("{:?}", config);

        let plan_local = QueryPlan::Local {
            query: "test".to_string(),
        };
        let plan_remote = QueryPlan::Remote {
            shard: 0,
            query: "test".to_string(),
        };
        let plan_scatter = QueryPlan::ScatterGather {
            shards: vec![0, 1],
            query: "test".to_string(),
            merge: MergeStrategy::Union,
        };
        let _ = format!("{:?}", plan_local);
        let _ = format!("{:?}", plan_remote);
        let _ = format!("{:?}", plan_scatter);

        let _ = format!("{:?}", MergeStrategy::TopK(10));
        let _ = format!("{:?}", MergeStrategy::Aggregate(AggregateFunction::Count));
        let _ = format!("{:?}", MergeStrategy::FirstNonEmpty);
        let _ = format!("{:?}", MergeStrategy::Concat);

        let _ = format!("{:?}", AggregateFunction::Max);
        let _ = format!("{:?}", AggregateFunction::Min);

        let result = ShardResult::success(0, QueryResult::Empty, 100);
        let _ = format!("{:?}", result);

        let stats = DistributedQueryStats::default();
        let _ = format!("{:?}", stats);
    }

    #[test]
    fn test_shard_result_clone() {
        let result = ShardResult::success(0, QueryResult::Count(10), 100);
        let cloned = result.clone();
        assert_eq!(cloned.shard, result.shard);
    }

    #[test]
    fn test_config_clone() {
        let config = DistributedQueryConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.max_concurrent, config.max_concurrent);
    }

    #[test]
    fn test_stats_clone() {
        let mut stats = DistributedQueryStats::default();
        stats.queries_executed = 10;
        let cloned = stats.clone();
        assert_eq!(cloned.queries_executed, 10);
    }

    #[test]
    fn test_merge_strategy_clone() {
        let strategy = MergeStrategy::TopK(5);
        let cloned = strategy.clone();
        assert!(matches!(cloned, MergeStrategy::TopK(5)));
    }

    #[test]
    fn test_aggregate_function_copy() {
        let func = AggregateFunction::Sum;
        let copied: AggregateFunction = func;
        assert_eq!(copied, AggregateFunction::Sum);
    }

    #[test]
    fn test_query_plan_clone() {
        let plan = QueryPlan::Local {
            query: "test".to_string(),
        };
        let cloned = plan.clone();
        assert!(matches!(cloned, QueryPlan::Local { .. }));
    }
}