1use crate::compute::EncryptedType;
36use crate::compute::circuit::Circuit;
37use crate::compute::predicate::PredicateCompiler;
38use crate::error::{AmateRSError, ErrorContext, Result};
39use crate::types::{CipherBlob, ColumnRef, Key, Predicate, Query};
40use dashmap::DashMap;
41use std::collections::HashSet;
42use std::sync::Arc;
43
44pub use super::plan_cache::{CacheKey, CacheStats, CachedPlan, PlanCache, PlanCacheConfig};
45
/// Logical query plan: *what* to compute, independent of how it is executed.
///
/// The planner builds one from a [`Query`], rewrites it with rule-based
/// optimizations, and then lowers it to a [`PhysicalPlan`].
#[derive(Debug, Clone)]
pub enum LogicalPlan {
    /// Read every entry of a collection.
    Scan {
        /// Collection to read.
        collection: String,
    },

    /// Read the entries of a collection whose keys fall inside a key range.
    RangeScan {
        /// Collection to read.
        collection: String,
        /// Lower key bound as raw bytes; `None` means unbounded below.
        /// NOTE(review): presumably inclusive — confirm with the storage layer.
        start_key: Option<Vec<u8>>,
        /// Upper key bound as raw bytes; `None` means unbounded above.
        /// NOTE(review): presumably exclusive — confirm with the storage layer.
        end_key: Option<Vec<u8>>,
    },

    /// Keep only the rows of `input` that satisfy `predicate`.
    Filter {
        /// Plan producing the rows to filter.
        input: Box<LogicalPlan>,
        /// Condition each row must satisfy.
        predicate: Predicate,
    },

    /// Restrict the rows of `input` to the named columns.
    Project {
        /// Plan producing the rows to project.
        input: Box<LogicalPlan>,
        /// Column names to keep.
        columns: Vec<String>,
    },

    /// Return at most `count` rows of `input`.
    Limit {
        /// Plan producing the rows to truncate.
        input: Box<LogicalPlan>,
        /// Maximum number of rows to return.
        count: usize,
    },

    /// Fetch a single entry by its exact key.
    PointLookup {
        /// Collection to read.
        collection: String,
        /// Key of the entry to fetch.
        key: Key,
    },
}
104
/// Physical (executable) query plan: *how* a [`LogicalPlan`] is carried out.
///
/// Produced by the planner's lowering step and costed by
/// [`QueryPlanner::estimate_cost`].
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
    /// Sequentially read every entry of a collection.
    SeqScan {
        /// Collection to read.
        collection: String,
    },

    /// Read the entries whose keys fall inside a key range.
    IndexScan {
        /// Collection to read.
        collection: String,
        /// Lower key bound as raw bytes; `None` means unbounded below.
        start: Option<Vec<u8>>,
        /// Upper key bound as raw bytes; `None` means unbounded above.
        end: Option<Vec<u8>>,
    },

    /// Filter rows by evaluating a compiled FHE circuit on each row of `input`.
    FheFilter {
        /// Plan producing the rows to filter.
        input: Box<PhysicalPlan>,
        /// Circuit compiled from `predicate`; its gate count drives the FHE cost estimate.
        circuit: Circuit,
        /// Source predicate, kept alongside the circuit (it is what `Display` prints).
        predicate: Predicate,
    },

    /// Restrict the rows of `input` to the named columns.
    Projection {
        /// Plan producing the rows to project.
        input: Box<PhysicalPlan>,
        /// Column names to keep.
        columns: Vec<String>,
    },

    /// Return at most `count` rows of `input`.
    Limit {
        /// Plan producing the rows to truncate.
        input: Box<PhysicalPlan>,
        /// Maximum number of rows to return.
        count: usize,
    },

    /// Fetch a single entry by its exact key.
    PointGet {
        /// Collection to read.
        collection: String,
        /// Key of the entry to fetch.
        key: Key,
    },
}
164
/// Estimated cost of executing a [`PhysicalPlan`].
///
/// Returned by [`QueryPlanner::estimate_cost`]. `total_cost` is a weighted sum
/// of the three component estimates and is what plans are compared by.
#[derive(Debug, Clone)]
pub struct PlanCost {
    /// Estimated number of rows handled by the plan.
    pub estimated_rows: u64,
    /// Estimated number of homomorphic (FHE) gate evaluations.
    pub estimated_fhe_ops: u64,
    /// Estimated number of bytes of IO.
    pub estimated_io_bytes: u64,
    /// Weighted total of the estimates above; lower means cheaper.
    pub total_cost: f64,
}
181
182impl PlanCost {
183 const IO_COST_PER_BYTE: f64 = 0.001;
185 const FHE_COST_PER_OP: f64 = 100.0;
187 const SCAN_COST_PER_ROW: f64 = 0.01;
189 const POINT_LOOKUP_COST: f64 = 1.0;
191
192 fn compute(estimated_rows: u64, estimated_fhe_ops: u64, estimated_io_bytes: u64) -> Self {
194 let total_cost = (estimated_rows as f64 * Self::SCAN_COST_PER_ROW)
195 + (estimated_fhe_ops as f64 * Self::FHE_COST_PER_OP)
196 + (estimated_io_bytes as f64 * Self::IO_COST_PER_BYTE);
197 Self {
198 estimated_rows,
199 estimated_fhe_ops,
200 estimated_io_bytes,
201 total_cost,
202 }
203 }
204}
205
206impl std::fmt::Display for PlanCost {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 write!(
209 f,
210 "PlanCost(rows={}, fhe_ops={}, io_bytes={}, total={:.2})",
211 self.estimated_rows, self.estimated_fhe_ops, self.estimated_io_bytes, self.total_cost
212 )
213 }
214}
215
/// Statistics consulted by the planner's cost model.
///
/// Collection sizes are kept in a [`DashMap`], so they can be updated through a
/// shared reference (for example via `PlannerStats::set_collection_size` on stats
/// held behind an `Arc`).
pub struct PlannerStats {
    /// Estimated row count per collection name. Collections without an entry fall
    /// back to a default size.
    pub estimated_collection_sizes: DashMap<String, u64>,
    /// Average stored value size in bytes; used to convert row counts into IO bytes.
    pub average_value_size: u64,
    /// Estimated latency of one FHE operation (the `_us` suffix suggests microseconds).
    /// NOTE(review): not read by the cost model in this file.
    pub fhe_op_latency_us: u64,
}
232
233impl PlannerStats {
234 fn new() -> Self {
236 Self {
237 estimated_collection_sizes: DashMap::new(),
238 average_value_size: 256,
239 fhe_op_latency_us: 1000,
240 }
241 }
242
243 fn collection_size(&self, collection: &str) -> u64 {
245 self.estimated_collection_sizes
246 .get(collection)
247 .map(|v| *v)
248 .unwrap_or(1000)
249 }
250
251 pub fn set_collection_size(&self, collection: impl Into<String>, size: u64) {
253 self.estimated_collection_sizes
254 .insert(collection.into(), size);
255 }
256}
257
258impl Default for PlannerStats {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
/// Translates [`Query`] values into executable [`PhysicalPlan`]s.
///
/// The pipeline is: query -> [`LogicalPlan`] -> rule-based logical rewrites ->
/// [`PhysicalPlan`]. Finished plans can optionally be memoised in a [`PlanCache`].
pub struct QueryPlanner {
    /// Statistics used by the cost model. Held in an `Arc` so a caller can share
    /// its own instance (see `QueryPlanner::with_stats`).
    stats: Arc<PlannerStats>,
    /// Plan cache; `None` unless enabled with `QueryPlanner::with_cache`.
    cache: Option<Arc<PlanCache>>,
}
281
282impl QueryPlanner {
283 pub fn new() -> Self {
285 Self {
286 stats: Arc::new(PlannerStats::new()),
287 cache: None,
288 }
289 }
290
291 pub fn with_stats(stats: Arc<PlannerStats>) -> Self {
293 Self { stats, cache: None }
294 }
295
296 pub fn with_cache(mut self, config: PlanCacheConfig) -> Self {
298 self.cache = Some(Arc::new(PlanCache::new(config)));
299 self
300 }
301
302 pub fn stats(&self) -> &PlannerStats {
304 &self.stats
305 }
306
307 pub fn plan_cache(&self) -> Option<&PlanCache> {
309 self.cache.as_deref()
310 }
311
312 pub fn cache_stats(&self) -> CacheStats {
314 self.cache
315 .as_ref()
316 .map(|c| c.cache_stats())
317 .unwrap_or_default()
318 }
319
320 pub fn invalidate_all(&self) {
322 if let Some(cache) = &self.cache {
323 cache.invalidate_all();
324 }
325 }
326
327 pub fn invalidate_prefix(&self, prefix: &str) {
329 if let Some(cache) = &self.cache {
330 cache.invalidate_prefix(prefix);
331 }
332 }
333
334 pub fn plan(&self, query: &Query) -> Result<PhysicalPlan> {
344 let cache_key = CacheKey::from_query(query);
345
346 if let Some(cache) = &self.cache {
348 if let Some(cached_plan) = cache.get(&cache_key) {
349 return Ok(cached_plan);
350 }
351 }
352
353 let logical = self.to_logical(query)?;
355 let optimized = self.optimize_logical(logical);
356 let physical = self.to_physical(&optimized)?;
357
358 if let Some(cache) = &self.cache {
360 let normalized = CacheKey::normalize(&format!("{:?}", query));
361 cache.insert(cache_key, physical.clone(), normalized);
362 }
363
364 Ok(physical)
365 }
366
367 fn to_logical(&self, query: &Query) -> Result<LogicalPlan> {
373 match query {
374 Query::Get { collection, key } => Ok(LogicalPlan::PointLookup {
375 collection: collection.clone(),
376 key: key.clone(),
377 }),
378
379 Query::Filter {
380 collection,
381 predicate,
382 } => Ok(LogicalPlan::Filter {
383 input: Box::new(LogicalPlan::Scan {
384 collection: collection.clone(),
385 }),
386 predicate: predicate.clone(),
387 }),
388
389 Query::Range {
390 collection,
391 start,
392 end,
393 } => Ok(LogicalPlan::RangeScan {
394 collection: collection.clone(),
395 start_key: Some(start.to_vec()),
396 end_key: Some(end.to_vec()),
397 }),
398
399 Query::Set { collection, .. } => {
400 Ok(LogicalPlan::Scan {
404 collection: collection.clone(),
405 })
406 }
407
408 Query::Delete { collection, key } => Ok(LogicalPlan::PointLookup {
409 collection: collection.clone(),
410 key: key.clone(),
411 }),
412
413 Query::Update {
414 collection,
415 predicate,
416 ..
417 } => Ok(LogicalPlan::Filter {
418 input: Box::new(LogicalPlan::Scan {
419 collection: collection.clone(),
420 }),
421 predicate: predicate.clone(),
422 }),
423 }
424 }
425
426 fn optimize_logical(&self, plan: LogicalPlan) -> LogicalPlan {
432 let plan = self.push_predicates_down(plan);
433 let plan = self.merge_filters(plan);
434 self.convert_filter_to_range_scan(plan)
435 }
436
437 fn push_predicates_down(&self, plan: LogicalPlan) -> LogicalPlan {
444 match plan {
445 LogicalPlan::Filter { input, predicate } => {
447 let optimized_input = self.push_predicates_down(*input);
448
449 match optimized_input {
450 LogicalPlan::Project {
452 input: proj_input,
453 columns,
454 } => {
455 let pred_cols = Self::referenced_columns(&predicate);
456 let proj_set: HashSet<&str> = columns.iter().map(|c| c.as_str()).collect();
457
458 if pred_cols.iter().all(|c| proj_set.contains(c.as_str())) {
459 LogicalPlan::Project {
462 input: Box::new(LogicalPlan::Filter {
463 input: proj_input,
464 predicate,
465 }),
466 columns,
467 }
468 } else {
469 let mut extended_cols = columns.clone();
473 for col in &pred_cols {
474 if !proj_set.contains(col.as_str()) {
475 extended_cols.push(col.clone());
476 }
477 }
478
479 LogicalPlan::Project {
480 input: Box::new(LogicalPlan::Filter {
481 input: Box::new(LogicalPlan::Project {
482 input: proj_input,
483 columns: extended_cols,
484 }),
485 predicate,
486 }),
487 columns,
488 }
489 }
490 }
491
492 other => LogicalPlan::Filter {
496 input: Box::new(other),
497 predicate,
498 },
499 }
500 }
501
502 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
504 input: Box::new(self.push_predicates_down(*input)),
505 columns,
506 },
507
508 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
509 input: Box::new(self.push_predicates_down(*input)),
510 count,
511 },
512
513 other => other,
515 }
516 }
517
518 fn merge_filters(&self, plan: LogicalPlan) -> LogicalPlan {
522 match plan {
523 LogicalPlan::Filter { input, predicate } => {
524 let optimized_input = self.merge_filters(*input);
525
526 match optimized_input {
527 LogicalPlan::Filter {
528 input: inner_input,
529 predicate: inner_pred,
530 } => {
531 LogicalPlan::Filter {
533 input: inner_input,
534 predicate: Predicate::And(Box::new(inner_pred), Box::new(predicate)),
535 }
536 }
537 other => LogicalPlan::Filter {
538 input: Box::new(other),
539 predicate,
540 },
541 }
542 }
543
544 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
545 input: Box::new(self.merge_filters(*input)),
546 columns,
547 },
548
549 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
550 input: Box::new(self.merge_filters(*input)),
551 count,
552 },
553
554 other => other,
555 }
556 }
557
558 fn convert_filter_to_range_scan(&self, plan: LogicalPlan) -> LogicalPlan {
564 match plan {
565 LogicalPlan::Filter { input, predicate } => {
566 let optimized_input = self.convert_filter_to_range_scan(*input);
567
568 if let LogicalPlan::Scan { ref collection } = optimized_input {
569 if let Some((start, end)) = Self::extract_key_range(&predicate) {
570 return LogicalPlan::RangeScan {
571 collection: collection.clone(),
572 start_key: start,
573 end_key: end,
574 };
575 }
576 }
577
578 LogicalPlan::Filter {
579 input: Box::new(optimized_input),
580 predicate,
581 }
582 }
583
584 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
585 input: Box::new(self.convert_filter_to_range_scan(*input)),
586 columns,
587 },
588
589 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
590 input: Box::new(self.convert_filter_to_range_scan(*input)),
591 count,
592 },
593
594 other => other,
595 }
596 }
597
598 fn to_physical(&self, plan: &LogicalPlan) -> Result<PhysicalPlan> {
604 match plan {
605 LogicalPlan::Scan { collection } => Ok(PhysicalPlan::SeqScan {
606 collection: collection.clone(),
607 }),
608
609 LogicalPlan::RangeScan {
610 collection,
611 start_key,
612 end_key,
613 } => Ok(PhysicalPlan::IndexScan {
614 collection: collection.clone(),
615 start: start_key.clone(),
616 end: end_key.clone(),
617 }),
618
619 LogicalPlan::Filter { input, predicate } => {
620 let physical_input = self.to_physical(input)?;
621 let circuit = self.compile_predicate_circuit(predicate)?;
622
623 Ok(PhysicalPlan::FheFilter {
624 input: Box::new(physical_input),
625 circuit,
626 predicate: predicate.clone(),
627 })
628 }
629
630 LogicalPlan::Project { input, columns } => {
631 let physical_input = self.to_physical(input)?;
632 Ok(PhysicalPlan::Projection {
633 input: Box::new(physical_input),
634 columns: columns.clone(),
635 })
636 }
637
638 LogicalPlan::Limit { input, count } => {
639 let physical_input = self.to_physical(input)?;
640 Ok(PhysicalPlan::Limit {
641 input: Box::new(physical_input),
642 count: *count,
643 })
644 }
645
646 LogicalPlan::PointLookup { collection, key } => Ok(PhysicalPlan::PointGet {
647 collection: collection.clone(),
648 key: key.clone(),
649 }),
650 }
651 }
652
653 pub fn estimate_cost(&self, plan: &PhysicalPlan) -> PlanCost {
659 match plan {
660 PhysicalPlan::SeqScan { collection } => {
661 let rows = self.stats.collection_size(collection);
662 let io_bytes = rows * self.stats.average_value_size;
663 PlanCost::compute(rows, 0, io_bytes)
664 }
665
666 PhysicalPlan::IndexScan {
667 collection,
668 start,
669 end,
670 } => {
671 let total = self.stats.collection_size(collection);
672 let selectivity = match (start, end) {
676 (Some(_), Some(_)) => 0.10,
677 (Some(_), None) | (None, Some(_)) => 0.30,
678 (None, None) => 1.0,
679 };
680 let rows = ((total as f64) * selectivity).max(1.0) as u64;
681 let io_bytes = rows * self.stats.average_value_size;
682 PlanCost::compute(rows, 0, io_bytes)
683 }
684
685 PhysicalPlan::FheFilter { input, circuit, .. } => {
686 let input_cost = self.estimate_cost(input);
687 let fhe_ops = input_cost.estimated_rows * (circuit.gate_count as u64);
689 let output_rows = (input_cost.estimated_rows / 2).max(1);
691 let io_bytes = output_rows * self.stats.average_value_size;
692 PlanCost::compute(
693 input_cost.estimated_rows,
694 input_cost.estimated_fhe_ops + fhe_ops,
695 input_cost.estimated_io_bytes + io_bytes,
696 )
697 }
698
699 PhysicalPlan::Projection { input, .. } => {
700 let mut cost = self.estimate_cost(input);
702 cost.estimated_io_bytes = (cost.estimated_io_bytes as f64 * 0.8) as u64;
704 cost.total_cost = (cost.estimated_rows as f64 * PlanCost::SCAN_COST_PER_ROW)
705 + (cost.estimated_fhe_ops as f64 * PlanCost::FHE_COST_PER_OP)
706 + (cost.estimated_io_bytes as f64 * PlanCost::IO_COST_PER_BYTE);
707 cost
708 }
709
710 PhysicalPlan::Limit { input, count } => {
711 let input_cost = self.estimate_cost(input);
712 let rows = (*count as u64).min(input_cost.estimated_rows);
713 let io_bytes = rows * self.stats.average_value_size;
714 PlanCost::compute(rows, input_cost.estimated_fhe_ops, io_bytes)
717 }
718
719 PhysicalPlan::PointGet { .. } => PlanCost::compute(1, 0, self.stats.average_value_size),
720 }
721 }
722
723 pub fn choose_cheaper<'a>(&self, a: &'a PhysicalPlan, b: &'a PhysicalPlan) -> &'a PhysicalPlan {
725 let cost_a = self.estimate_cost(a);
726 let cost_b = self.estimate_cost(b);
727 if cost_a.total_cost <= cost_b.total_cost {
728 a
729 } else {
730 b
731 }
732 }
733
734 fn referenced_columns(predicate: &Predicate) -> Vec<String> {
740 let mut cols = Vec::new();
741 Self::collect_columns(predicate, &mut cols);
742 cols.sort();
743 cols.dedup();
744 cols
745 }
746
747 fn collect_columns(predicate: &Predicate, out: &mut Vec<String>) {
748 match predicate {
749 Predicate::Eq(col, _)
750 | Predicate::Gt(col, _)
751 | Predicate::Lt(col, _)
752 | Predicate::Gte(col, _)
753 | Predicate::Lte(col, _) => {
754 out.push(col.name.clone());
755 }
756 Predicate::And(l, r) | Predicate::Or(l, r) => {
757 Self::collect_columns(l, out);
758 Self::collect_columns(r, out);
759 }
760 Predicate::Not(inner) => {
761 Self::collect_columns(inner, out);
762 }
763 }
764 }
765
766 fn extract_key_range(predicate: &Predicate) -> Option<(Option<Vec<u8>>, Option<Vec<u8>>)> {
771 match predicate {
772 Predicate::Gt(col, blob) if col.name == "_key" => {
773 Some((Some(blob.as_bytes().to_vec()), None))
774 }
775 Predicate::Gte(col, blob) if col.name == "_key" => {
776 Some((Some(blob.as_bytes().to_vec()), None))
777 }
778 Predicate::Lt(col, blob) if col.name == "_key" => {
779 Some((None, Some(blob.as_bytes().to_vec())))
780 }
781 Predicate::Lte(col, blob) if col.name == "_key" => {
782 Some((None, Some(blob.as_bytes().to_vec())))
783 }
784 Predicate::And(left, right) => {
785 let lr = Self::extract_key_range(left);
787 let rr = Self::extract_key_range(right);
788
789 match (lr, rr) {
790 (Some((s1, e1)), Some((s2, e2))) => {
791 let start = s1.or(s2);
792 let end = e1.or(e2);
793 Some((start, end))
794 }
795 (Some(range), None) | (None, Some(range)) => Some(range),
796 (None, None) => None,
797 }
798 }
799 _ => None,
800 }
801 }
802
803 fn compile_predicate_circuit(&self, predicate: &Predicate) -> Result<Circuit> {
805 let mut compiler = PredicateCompiler::new();
806 compiler.compile(predicate, EncryptedType::U8)
809 }
810}
811
812impl Default for QueryPlanner {
813 fn default() -> Self {
814 Self::new()
815 }
816}
817
818impl std::fmt::Display for LogicalPlan {
823 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
824 self.fmt_indented(f, 0)
825 }
826}
827
828impl LogicalPlan {
829 fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
830 let pad = " ".repeat(indent);
831 match self {
832 LogicalPlan::Scan { collection } => {
833 writeln!(f, "{}Scan({})", pad, collection)
834 }
835 LogicalPlan::RangeScan {
836 collection,
837 start_key,
838 end_key,
839 } => {
840 writeln!(
841 f,
842 "{}RangeScan({}, start={}, end={})",
843 pad,
844 collection,
845 start_key.is_some(),
846 end_key.is_some()
847 )
848 }
849 LogicalPlan::Filter { input, predicate } => {
850 writeln!(f, "{}Filter(pred={:?})", pad, predicate)?;
851 input.fmt_indented(f, indent + 1)
852 }
853 LogicalPlan::Project { input, columns } => {
854 writeln!(f, "{}Project({:?})", pad, columns)?;
855 input.fmt_indented(f, indent + 1)
856 }
857 LogicalPlan::Limit { input, count } => {
858 writeln!(f, "{}Limit({})", pad, count)?;
859 input.fmt_indented(f, indent + 1)
860 }
861 LogicalPlan::PointLookup { collection, key } => {
862 writeln!(f, "{}PointLookup({}, key={})", pad, collection, key)
863 }
864 }
865 }
866}
867
868impl std::fmt::Display for PhysicalPlan {
869 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
870 self.fmt_indented(f, 0)
871 }
872}
873
874impl PhysicalPlan {
875 fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
876 let pad = " ".repeat(indent);
877 match self {
878 PhysicalPlan::SeqScan { collection } => {
879 writeln!(f, "{}SeqScan({})", pad, collection)
880 }
881 PhysicalPlan::IndexScan {
882 collection,
883 start,
884 end,
885 } => {
886 writeln!(
887 f,
888 "{}IndexScan({}, start={}, end={})",
889 pad,
890 collection,
891 start.is_some(),
892 end.is_some()
893 )
894 }
895 PhysicalPlan::FheFilter {
896 input, predicate, ..
897 } => {
898 writeln!(f, "{}FheFilter(pred={:?})", pad, predicate)?;
899 input.fmt_indented(f, indent + 1)
900 }
901 PhysicalPlan::Projection { input, columns } => {
902 writeln!(f, "{}Projection({:?})", pad, columns)?;
903 input.fmt_indented(f, indent + 1)
904 }
905 PhysicalPlan::Limit { input, count } => {
906 writeln!(f, "{}Limit({})", pad, count)?;
907 input.fmt_indented(f, indent + 1)
908 }
909 PhysicalPlan::PointGet { collection, key } => {
910 writeln!(f, "{}PointGet({}, key={})", pad, collection, key)
911 }
912 }
913 }
914}
915
#[cfg(test)]
/// Unit tests for logical planning, the optimizer rewrites, physical lowering,
/// the cost model and plan display.
mod tests {
    use super::*;
    use crate::types::col;

    /// Build a one-byte `CipherBlob` to use as a predicate operand.
    fn make_blob(v: u8) -> CipherBlob {
        CipherBlob::new(vec![v])
    }

    #[test]
    /// A filter on a non-key column plans as `FheFilter(SeqScan)`.
    fn test_scan_plan() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Filter {
            collection: "users".to_string(),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };

        let plan = planner.plan(&query)?;

        match &plan {
            // Expected shape: FheFilter over a sequential scan.
            PhysicalPlan::FheFilter { input, .. } => {
                assert!(matches!(input.as_ref(), PhysicalPlan::SeqScan { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected FheFilter, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    #[test]
    /// `_key >= 10 AND _key < 50` is answered by a two-sided index scan alone.
    fn test_range_scan_pushdown() -> Result<()> {
        let planner = QueryPlanner::new();

        let query = Query::Filter {
            // Both conjuncts constrain the `_key` pseudo-column.
            collection: "data".to_string(),
            predicate: Predicate::And(
                Box::new(Predicate::Gte(col("_key"), make_blob(10))),
                Box::new(Predicate::Lt(col("_key"), make_blob(50))),
            ),
        };

        let plan = planner.plan(&query)?;

        match &plan {
            PhysicalPlan::IndexScan {
                collection,
                start,
                end,
            } => {
                assert_eq!(collection, "data");
                assert!(start.is_some());
                assert!(end.is_some());
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected IndexScan, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    #[test]
    /// A filter above a projection that keeps its column is moved below the projection.
    fn test_predicate_pushdown() -> Result<()> {
        let planner = QueryPlanner::new();

        // Input shape: Filter(age > 18) -> Project([age, name]) -> Scan(users).
        let scan = LogicalPlan::Scan {
            collection: "users".to_string(),
        };
        let project = LogicalPlan::Project {
            input: Box::new(scan),
            columns: vec!["age".to_string(), "name".to_string()],
        };
        let filter = LogicalPlan::Filter {
            input: Box::new(project),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };

        let optimized = planner.push_predicates_down(filter);

        match &optimized {
            // Expected shape: Project on top, Filter directly beneath it.
            LogicalPlan::Project { input, columns } => {
                assert!(columns.contains(&"age".to_string()));
                assert!(matches!(input.as_ref(), LogicalPlan::Filter { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Project, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    #[test]
    /// Two stacked filters are merged into a single `And` filter over the scan.
    fn test_filter_merge() -> Result<()> {
        let planner = QueryPlanner::new();

        // Input shape: Filter(age < 65) -> Filter(age > 18) -> Scan(users).
        let scan = LogicalPlan::Scan {
            collection: "users".to_string(),
        };
        let filter1 = LogicalPlan::Filter {
            input: Box::new(scan),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };
        let filter2 = LogicalPlan::Filter {
            input: Box::new(filter1),
            predicate: Predicate::Lt(col("age"), make_blob(65)),
        };

        let optimized = planner.merge_filters(filter2);

        match &optimized {
            LogicalPlan::Filter { input, predicate } => {
                // The two predicates are combined with And ...
                assert!(matches!(predicate, Predicate::And(_, _)));
                // ... and the merged filter sits directly on the scan.
                assert!(matches!(input.as_ref(), LogicalPlan::Scan { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Filter, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    #[test]
    /// Cost ordering on a 10 000-row collection: PointGet < IndexScan < SeqScan.
    fn test_cost_estimation() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("data", 10_000);

        // Full sequential scan reads every row.
        let seq_scan = PhysicalPlan::SeqScan {
            collection: "data".to_string(),
        };
        let seq_cost = planner.estimate_cost(&seq_scan);

        // Two-sided index scan reads only a fraction of the rows.
        let idx_scan = PhysicalPlan::IndexScan {
            collection: "data".to_string(),
            start: Some(vec![10]),
            end: Some(vec![50]),
        };
        let idx_cost = planner.estimate_cost(&idx_scan);

        assert!(
            // The bounded range must be estimated cheaper than the full scan.
            idx_cost.total_cost < seq_cost.total_cost,
            "IndexScan cost ({}) should be less than SeqScan cost ({})",
            idx_cost.total_cost,
            seq_cost.total_cost,
        );

        let point = PhysicalPlan::PointGet {
            // A single-key lookup touches one row.
            collection: "data".to_string(),
            key: Key::from_str("k"),
        };
        let point_cost = planner.estimate_cost(&point);
        assert!(
            point_cost.total_cost < idx_cost.total_cost,
            "PointGet cost ({}) should be less than IndexScan cost ({})",
            point_cost.total_cost,
            idx_cost.total_cost,
        );

        Ok(())
    }

    #[test]
    /// `Limit(Filter(Scan))` lowers to `Limit(FheFilter(..))`, keeping the count.
    fn test_limit_planning() -> Result<()> {
        let planner = QueryPlanner::new();

        let scan = LogicalPlan::Scan {
            // Input shape: Limit(10) -> Filter(level == 1) -> Scan(logs).
            collection: "logs".to_string(),
        };
        let filter = LogicalPlan::Filter {
            input: Box::new(scan),
            predicate: Predicate::Eq(col("level"), make_blob(1)),
        };
        let limited = LogicalPlan::Limit {
            input: Box::new(filter),
            count: 10,
        };

        let physical = planner.to_physical(&limited)?;

        match &physical {
            // The limit stays on top of the lowered filter.
            PhysicalPlan::Limit { input, count } => {
                assert_eq!(*count, 10);
                assert!(matches!(input.as_ref(), PhysicalPlan::FheFilter { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Limit, got: {:?}",
                    other
                ))));
            }
        }

        Ok(())
    }

    #[test]
    /// A range predicate on a non-key column compiles into a Bool-valued circuit.
    fn test_plan_with_fhe_filter() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Filter {
            collection: "accounts".to_string(),
            predicate: Predicate::And(
                Box::new(Predicate::Gt(col("balance"), make_blob(100))),
                Box::new(Predicate::Lt(col("balance"), make_blob(200))),
            ),
        };

        let plan = planner.plan(&query)?;

        match &plan {
            // `balance` is not `_key`, so no range-scan conversion is possible.
            PhysicalPlan::FheFilter { circuit, .. } => {
                assert!(circuit.gate_count > 0);
                // A predicate circuit yields an encrypted boolean.
                assert_eq!(circuit.result_type, EncryptedType::Bool);
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected FheFilter, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    #[test]
    /// A nested Or/And predicate plans, has non-zero FHE cost, and displays.
    fn test_complex_plan() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("orders", 50_000);

        let query = Query::Filter {
            // status == 1 OR (amount > 100 AND amount < 255)
            collection: "orders".to_string(),
            predicate: Predicate::Or(
                Box::new(Predicate::Eq(col("status"), make_blob(1))),
                Box::new(Predicate::And(
                    Box::new(Predicate::Gt(col("amount"), make_blob(100))),
                    Box::new(Predicate::Lt(col("amount"), make_blob(255))),
                )),
            ),
        };

        let plan = planner.plan(&query)?;
        let cost = planner.estimate_cost(&plan);

        assert!(cost.estimated_fhe_ops > 0);
        // The FHE work must show up in the total.
        assert!(cost.total_cost > 0.0);

        let plan_str = format!("{}", plan);
        // Display must render something for a non-trivial plan.
        assert!(!plan_str.is_empty());

        Ok(())
    }

    #[test]
    /// A Get query becomes a one-row PointGet with no FHE work.
    fn test_get_query_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Get {
            collection: "users".to_string(),
            key: Key::from_str("user:42"),
        };

        let plan = planner.plan(&query)?;

        match &plan {
            PhysicalPlan::PointGet { collection, key } => {
                assert_eq!(collection, "users");
                assert_eq!(key.to_string_lossy(), "user:42");
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected PointGet, got: {:?}",
                    other
                ))));
            }
        }

        let cost = planner.estimate_cost(&plan);
        assert_eq!(cost.estimated_rows, 1);
        assert_eq!(cost.estimated_fhe_ops, 0);

        Ok(())
    }

    #[test]
    /// A Range query becomes a two-sided IndexScan on the same collection.
    fn test_range_query_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Range {
            collection: "events".to_string(),
            start: Key::from_str("2024-01"),
            end: Key::from_str("2024-12"),
        };

        let plan = planner.plan(&query)?;

        match &plan {
            PhysicalPlan::IndexScan {
                collection,
                start,
                end,
            } => {
                assert_eq!(collection, "events");
                assert!(start.is_some());
                assert!(end.is_some());
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected IndexScan, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    #[test]
    /// `choose_cheaper` prefers a bounded IndexScan over a full SeqScan.
    fn test_cost_comparison() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("items", 100_000);

        let scan = PhysicalPlan::SeqScan {
            collection: "items".to_string(),
        };

        let idx = PhysicalPlan::IndexScan {
            collection: "items".to_string(),
            start: Some(vec![1]),
            end: Some(vec![10]),
        };

        let cheaper = planner.choose_cheaper(&scan, &idx);

        assert!(matches!(cheaper, PhysicalPlan::IndexScan { .. }));
        // The scan was passed first, so this is not just the tie-break picking `a`.

        Ok(())
    }

    #[test]
    /// A filter above a Limit must stay above it (pushing it down changes results).
    fn test_filter_not_pushed_below_limit() -> Result<()> {
        let planner = QueryPlanner::new();

        let scan = LogicalPlan::Scan {
            // Input shape: Filter(x > 5) -> Limit(10) -> Scan(data).
            collection: "data".to_string(),
        };
        let limited = LogicalPlan::Limit {
            input: Box::new(scan),
            count: 10,
        };
        let filter = LogicalPlan::Filter {
            input: Box::new(limited),
            predicate: Predicate::Gt(col("x"), make_blob(5)),
        };

        let optimized = planner.push_predicates_down(filter);

        match &optimized {
            // Shape must be unchanged: Filter on top, Limit beneath it.
            LogicalPlan::Filter { input, .. } => {
                assert!(matches!(input.as_ref(), LogicalPlan::Limit { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Filter on top, got: {:?}",
                    other
                ))));
            }
        }

        Ok(())
    }

    #[test]
    /// Recorded collection sizes are returned; unknown collections use the default.
    fn test_stats_update() {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("big_table", 1_000_000);

        let size = planner.stats().collection_size("big_table");
        assert_eq!(size, 1_000_000);

        let default_size = planner.stats().collection_size("unknown");
        // No size was recorded for "unknown", so the fallback of 1000 applies.
        assert_eq!(default_size, 1000);
    }

    #[test]
    /// Referenced columns come back sorted with duplicates removed.
    fn test_referenced_columns() {
        let pred = Predicate::And(
            Box::new(Predicate::Gt(col("age"), make_blob(18))),
            Box::new(Predicate::Or(
                Box::new(Predicate::Lt(col("salary"), make_blob(100))),
                Box::new(Predicate::Eq(col("age"), make_blob(30))),
            )),
        );

        let cols = QueryPlanner::referenced_columns(&pred);
        assert_eq!(cols, vec!["age".to_string(), "salary".to_string()]);
    }

    #[test]
    /// `PlanCost`'s Display output includes the row and FHE-op counts.
    fn test_display_plan_cost() {
        let cost = PlanCost::compute(1000, 50, 256_000);
        let display = format!("{}", cost);
        assert!(display.contains("1000"));
        assert!(display.contains("50"));
    }

    #[test]
    /// `LogicalPlan`'s Display output names both the filter and the scan.
    fn test_logical_plan_display() {
        let plan = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Scan {
                collection: "t".to_string(),
            }),
            predicate: Predicate::Eq(col("x"), make_blob(1)),
        };

        let s = format!("{}", plan);
        assert!(s.contains("Filter"));
        assert!(s.contains("Scan"));
    }
}