1pub use super::plan_cache::{CacheKey, CacheStats, CachedPlan, PlanCache, PlanCacheConfig};
5use crate::compute::EncryptedType;
6use crate::compute::circuit::Circuit;
7use crate::compute::predicate::PredicateCompiler;
8use crate::error::{AmateRSError, ErrorContext, Result};
9use crate::types::{CipherBlob, ColumnRef, JoinType, Key, Predicate, Query};
10use dashmap::DashMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13#[derive(Debug, Clone)]
18pub enum LogicalPlan {
19 Scan {
21 collection: String,
23 },
24 RangeScan {
26 collection: String,
28 start_key: Option<Vec<u8>>,
30 end_key: Option<Vec<u8>>,
32 },
33 Filter {
35 input: Box<LogicalPlan>,
37 predicate: Predicate,
39 },
40 Project {
42 input: Box<LogicalPlan>,
44 columns: Vec<String>,
46 },
47 Limit {
49 input: Box<LogicalPlan>,
51 count: usize,
53 },
54 PointLookup {
56 collection: String,
58 key: Key,
60 },
61 Join {
63 left: Box<LogicalPlan>,
65 right: Box<LogicalPlan>,
67 on: Predicate,
69 join_type: JoinType,
71 },
72}
73#[derive(Debug, Clone)]
77pub enum PhysicalPlan {
78 SeqScan {
80 collection: String,
82 },
83 IndexScan {
85 collection: String,
87 start: Option<Vec<u8>>,
89 end: Option<Vec<u8>>,
91 },
92 FheFilter {
94 input: Box<PhysicalPlan>,
96 circuit: Circuit,
98 predicate: Predicate,
100 },
101 Projection {
103 input: Box<PhysicalPlan>,
105 columns: Vec<String>,
107 },
108 Limit {
110 input: Box<PhysicalPlan>,
112 count: usize,
114 },
115 PointGet {
117 collection: String,
119 key: Key,
121 },
122 NestedLoopJoin {
124 outer: Box<PhysicalPlan>,
126 build: Box<PhysicalPlan>,
128 on: Predicate,
130 join_type: JoinType,
132 },
133 HashJoin {
135 probe: Box<PhysicalPlan>,
137 build: Box<PhysicalPlan>,
139 on: Predicate,
141 join_type: JoinType,
143 },
144}
/// Estimated cost of executing a physical plan.
#[derive(Debug, Clone)]
pub struct PlanCost {
    /// Estimated number of rows.
    pub estimated_rows: u64,
    /// Estimated number of FHE operations.
    pub estimated_fhe_ops: u64,
    /// Estimated number of bytes of I/O.
    pub estimated_io_bytes: u64,
    /// Weighted sum of the three estimates above.
    pub total_cost: f64,
}

impl PlanCost {
    /// Weight applied to every estimated I/O byte.
    const IO_COST_PER_BYTE: f64 = 0.001;
    /// Weight applied to every estimated FHE operation.
    const FHE_COST_PER_OP: f64 = 100.0;
    /// Weight applied to every estimated row.
    const SCAN_COST_PER_ROW: f64 = 0.01;
    /// Cost constant for a point lookup (not referenced by `compute`).
    const POINT_LOOKUP_COST: f64 = 1.0;

    /// Builds a cost record and derives `total_cost` as
    /// `rows * SCAN_COST_PER_ROW + fhe_ops * FHE_COST_PER_OP + io_bytes * IO_COST_PER_BYTE`.
    fn compute(estimated_rows: u64, estimated_fhe_ops: u64, estimated_io_bytes: u64) -> Self {
        let row_term = estimated_rows as f64 * Self::SCAN_COST_PER_ROW;
        let fhe_term = estimated_fhe_ops as f64 * Self::FHE_COST_PER_OP;
        let io_term = estimated_io_bytes as f64 * Self::IO_COST_PER_BYTE;
        Self {
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
            total_cost: row_term + fhe_term + io_term,
        }
    }
}

impl std::fmt::Display for PlanCost {
    /// Renders all four fields on one line; the total is shown with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            estimated_rows: rows,
            estimated_fhe_ops: fhe_ops,
            estimated_io_bytes: io_bytes,
            total_cost: total,
        } = self;
        write!(
            f,
            "PlanCost(rows={}, fhe_ops={}, io_bytes={}, total={:.2})",
            rows, fhe_ops, io_bytes, total
        )
    }
}
188pub struct PlannerStats {
193 pub estimated_collection_sizes: DashMap<String, u64>,
195 pub average_value_size: u64,
197 pub fhe_op_latency_us: u64,
199 pub fhe_comparison_cost: f64,
201 pub fhe_boolean_cost: f64,
203}
204impl PlannerStats {
205 fn new() -> Self {
207 Self {
208 estimated_collection_sizes: DashMap::new(),
209 average_value_size: 256,
210 fhe_op_latency_us: 1000,
211 fhe_comparison_cost: 100.0,
212 fhe_boolean_cost: 10.0,
213 }
214 }
215 fn collection_size(&self, collection: &str) -> u64 {
217 self.estimated_collection_sizes
218 .get(collection)
219 .map(|v| *v)
220 .unwrap_or(1000)
221 }
222 pub fn set_collection_size(&self, collection: impl Into<String>, size: u64) {
224 self.estimated_collection_sizes
225 .insert(collection.into(), size);
226 }
227 pub fn predicate_selectivity(&self, pred: &Predicate) -> f64 {
236 match pred {
237 Predicate::Eq(_, _) => 0.001,
238 Predicate::Lt(_, _)
239 | Predicate::Gt(_, _)
240 | Predicate::Lte(_, _)
241 | Predicate::Gte(_, _) => 0.3,
242 Predicate::And(p1, p2) => {
243 self.predicate_selectivity(p1) * self.predicate_selectivity(p2)
244 }
245 Predicate::Or(p1, p2) => {
246 let s1 = self.predicate_selectivity(p1);
247 let s2 = self.predicate_selectivity(p2);
248 1.0 - (1.0 - s1) * (1.0 - s2)
249 }
250 Predicate::Not(inner) => 1.0 - self.predicate_selectivity(inner),
251 }
252 }
253 pub fn predicate_fhe_cost(&self, pred: &Predicate) -> f64 {
259 match pred {
260 Predicate::Eq(_, _)
261 | Predicate::Lt(_, _)
262 | Predicate::Gt(_, _)
263 | Predicate::Lte(_, _)
264 | Predicate::Gte(_, _) => self.fhe_comparison_cost,
265 Predicate::And(p1, p2) | Predicate::Or(p1, p2) => {
266 self.fhe_boolean_cost + self.predicate_fhe_cost(p1) + self.predicate_fhe_cost(p2)
267 }
268 Predicate::Not(inner) => self.fhe_boolean_cost * 0.5 + self.predicate_fhe_cost(inner),
269 }
270 }
271}
272impl Default for PlannerStats {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277pub struct QueryPlanner {
285 stats: Arc<PlannerStats>,
287 cache: Option<Arc<PlanCache>>,
289}
290impl QueryPlanner {
291 pub fn new() -> Self {
293 Self {
294 stats: Arc::new(PlannerStats::new()),
295 cache: None,
296 }
297 }
298 pub fn with_stats(stats: Arc<PlannerStats>) -> Self {
300 Self { stats, cache: None }
301 }
302 pub fn with_cache(mut self, config: PlanCacheConfig) -> Self {
304 self.cache = Some(Arc::new(PlanCache::new(config)));
305 self
306 }
307 pub fn stats(&self) -> &PlannerStats {
309 &self.stats
310 }
311 pub fn plan_cache(&self) -> Option<&PlanCache> {
313 self.cache.as_deref()
314 }
315 pub fn cache_stats(&self) -> CacheStats {
317 self.cache
318 .as_ref()
319 .map(|c| c.cache_stats())
320 .unwrap_or_default()
321 }
322 pub fn invalidate_all(&self) {
324 if let Some(cache) = &self.cache {
325 cache.invalidate_all();
326 }
327 }
328 pub fn invalidate_prefix(&self, prefix: &str) {
330 if let Some(cache) = &self.cache {
331 cache.invalidate_prefix(prefix);
332 }
333 }
334 pub fn plan(&self, query: &Query) -> Result<PhysicalPlan> {
340 let cache_key = CacheKey::from_query(query);
341 if let Some(cache) = &self.cache {
342 if let Some(cached_plan) = cache.get(&cache_key) {
343 return Ok(cached_plan);
344 }
345 }
346 let logical = self.to_logical(query)?;
347 let optimized = self.optimize_logical(logical);
348 let physical = self.to_physical(&optimized)?;
349 if let Some(cache) = &self.cache {
350 let normalized = CacheKey::normalize(&format!("{:?}", query));
351 cache.insert(cache_key, physical.clone(), normalized);
352 }
353 Ok(physical)
354 }
355 fn to_logical(&self, query: &Query) -> Result<LogicalPlan> {
357 match query {
358 Query::Get { collection, key } => Ok(LogicalPlan::PointLookup {
359 collection: collection.clone(),
360 key: key.clone(),
361 }),
362 Query::Filter {
363 collection,
364 predicate,
365 } => Ok(LogicalPlan::Filter {
366 input: Box::new(LogicalPlan::Scan {
367 collection: collection.clone(),
368 }),
369 predicate: predicate.clone(),
370 }),
371 Query::Range {
372 collection,
373 start,
374 end,
375 } => Ok(LogicalPlan::RangeScan {
376 collection: collection.clone(),
377 start_key: Some(start.to_vec()),
378 end_key: Some(end.to_vec()),
379 }),
380 Query::Set { collection, .. } => Ok(LogicalPlan::Scan {
381 collection: collection.clone(),
382 }),
383 Query::Delete { collection, key } => Ok(LogicalPlan::PointLookup {
384 collection: collection.clone(),
385 key: key.clone(),
386 }),
387 Query::Update {
388 collection,
389 predicate,
390 ..
391 } => Ok(LogicalPlan::Filter {
392 input: Box::new(LogicalPlan::Scan {
393 collection: collection.clone(),
394 }),
395 predicate: predicate.clone(),
396 }),
397 Query::Join {
398 left_collection,
399 right_collection,
400 on,
401 join_type,
402 left_limit,
403 right_limit,
404 } => {
405 let mut left: LogicalPlan = LogicalPlan::Scan {
406 collection: left_collection.clone(),
407 };
408 if let Some(n) = left_limit {
409 left = LogicalPlan::Limit {
410 input: Box::new(left),
411 count: *n,
412 };
413 }
414 let mut right: LogicalPlan = LogicalPlan::Scan {
415 collection: right_collection.clone(),
416 };
417 if let Some(n) = right_limit {
418 right = LogicalPlan::Limit {
419 input: Box::new(right),
420 count: *n,
421 };
422 }
423 Ok(LogicalPlan::Join {
424 left: Box::new(left),
425 right: Box::new(right),
426 on: on.clone(),
427 join_type: join_type.clone(),
428 })
429 }
430 }
431 }
432 fn optimize_logical(&self, plan: LogicalPlan) -> LogicalPlan {
434 let plan = self.push_predicates_down(plan);
435 let plan = self.merge_filters(plan);
436 let plan = self.convert_filter_to_range_scan(plan);
437 self.reorder_predicates_by_cost(plan)
438 }
439 fn push_predicates_down(&self, plan: LogicalPlan) -> LogicalPlan {
447 match plan {
448 LogicalPlan::Filter {
449 input,
450 predicate: Predicate::And(p1, p2),
451 } if !matches!(*input, LogicalPlan::Limit { .. }) => {
452 let inner = LogicalPlan::Filter {
453 input,
454 predicate: *p2,
455 };
456 let outer = LogicalPlan::Filter {
457 input: Box::new(inner),
458 predicate: *p1,
459 };
460 self.push_predicates_down(outer)
461 }
462 LogicalPlan::Filter { input, predicate }
463 if matches!(*input, LogicalPlan::Join { .. }) =>
464 {
465 if let LogicalPlan::Join {
466 left,
467 right,
468 on,
469 join_type,
470 } = *input
471 {
472 let left_cols = Self::referenced_columns(&predicate);
473 let right_input_cols = Self::plan_output_columns(&right);
474 let left_input_cols = Self::plan_output_columns(&left);
475 let touches_left = left_cols.iter().any(|c| left_input_cols.contains(c));
476 let touches_right = left_cols.iter().any(|c| right_input_cols.contains(c));
477 match (touches_left, touches_right) {
478 (true, false) => {
479 let new_left = self.push_predicates_down(LogicalPlan::Filter {
480 input: left,
481 predicate,
482 });
483 self.push_predicates_down(LogicalPlan::Join {
484 left: Box::new(new_left),
485 right,
486 on,
487 join_type,
488 })
489 }
490 (false, true) => {
491 let new_right = self.push_predicates_down(LogicalPlan::Filter {
492 input: right,
493 predicate,
494 });
495 self.push_predicates_down(LogicalPlan::Join {
496 left,
497 right: Box::new(new_right),
498 on,
499 join_type,
500 })
501 }
502 _ => {
503 let joined = self.push_predicates_down(LogicalPlan::Join {
504 left,
505 right,
506 on,
507 join_type,
508 });
509 LogicalPlan::Filter {
510 input: Box::new(joined),
511 predicate,
512 }
513 }
514 }
515 } else {
516 unreachable!("guard confirmed Join variant")
517 }
518 }
519 LogicalPlan::Filter { input, predicate } => {
520 let optimized_input = self.push_predicates_down(*input);
521 match optimized_input {
522 LogicalPlan::Project {
523 input: proj_input,
524 columns,
525 } => {
526 let pred_cols = Self::referenced_columns(&predicate);
527 let proj_set: HashSet<&str> = columns.iter().map(|c| c.as_str()).collect();
528 if pred_cols.iter().all(|c| proj_set.contains(c.as_str())) {
529 LogicalPlan::Project {
530 input: Box::new(LogicalPlan::Filter {
531 input: proj_input,
532 predicate,
533 }),
534 columns,
535 }
536 } else {
537 let mut extended_cols = columns.clone();
538 for col in &pred_cols {
539 if !proj_set.contains(col.as_str()) {
540 extended_cols.push(col.clone());
541 }
542 }
543 LogicalPlan::Project {
544 input: Box::new(LogicalPlan::Filter {
545 input: Box::new(LogicalPlan::Project {
546 input: proj_input,
547 columns: extended_cols,
548 }),
549 predicate,
550 }),
551 columns,
552 }
553 }
554 }
555 other => LogicalPlan::Filter {
556 input: Box::new(other),
557 predicate,
558 },
559 }
560 }
561 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
562 input: Box::new(self.push_predicates_down(*input)),
563 columns,
564 },
565 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
566 input: Box::new(self.push_predicates_down(*input)),
567 count,
568 },
569 LogicalPlan::Join {
570 left,
571 right,
572 on,
573 join_type,
574 } => LogicalPlan::Join {
575 left: Box::new(self.push_predicates_down(*left)),
576 right: Box::new(self.push_predicates_down(*right)),
577 on,
578 join_type,
579 },
580 other => other,
581 }
582 }
583 fn plan_output_columns(plan: &LogicalPlan) -> HashSet<String> {
590 match plan {
591 LogicalPlan::Project { columns, .. } => columns.iter().cloned().collect(),
592 _ => HashSet::new(),
593 }
594 }
595 fn merge_filters(&self, plan: LogicalPlan) -> LogicalPlan {
599 match plan {
600 LogicalPlan::Filter { input, predicate } => {
601 let optimized_input = self.merge_filters(*input);
602 match optimized_input {
603 LogicalPlan::Filter {
604 input: inner_input,
605 predicate: inner_pred,
606 } => LogicalPlan::Filter {
607 input: inner_input,
608 predicate: Predicate::And(Box::new(inner_pred), Box::new(predicate)),
609 },
610 other => LogicalPlan::Filter {
611 input: Box::new(other),
612 predicate,
613 },
614 }
615 }
616 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
617 input: Box::new(self.merge_filters(*input)),
618 columns,
619 },
620 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
621 input: Box::new(self.merge_filters(*input)),
622 count,
623 },
624 LogicalPlan::Join {
625 left,
626 right,
627 on,
628 join_type,
629 } => LogicalPlan::Join {
630 left: Box::new(self.merge_filters(*left)),
631 right: Box::new(self.merge_filters(*right)),
632 on,
633 join_type,
634 },
635 other => other,
636 }
637 }
638 fn convert_filter_to_range_scan(&self, plan: LogicalPlan) -> LogicalPlan {
644 match plan {
645 LogicalPlan::Filter { input, predicate } => {
646 let optimized_input = self.convert_filter_to_range_scan(*input);
647 if let LogicalPlan::Scan { ref collection } = optimized_input {
648 if let Some((start, end)) = Self::extract_key_range(&predicate) {
649 return LogicalPlan::RangeScan {
650 collection: collection.clone(),
651 start_key: start,
652 end_key: end,
653 };
654 }
655 }
656 LogicalPlan::Filter {
657 input: Box::new(optimized_input),
658 predicate,
659 }
660 }
661 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
662 input: Box::new(self.convert_filter_to_range_scan(*input)),
663 columns,
664 },
665 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
666 input: Box::new(self.convert_filter_to_range_scan(*input)),
667 count,
668 },
669 LogicalPlan::Join {
670 left,
671 right,
672 on,
673 join_type,
674 } => LogicalPlan::Join {
675 left: Box::new(self.convert_filter_to_range_scan(*left)),
676 right: Box::new(self.convert_filter_to_range_scan(*right)),
677 on,
678 join_type,
679 },
680 other => other,
681 }
682 }
683 fn reorder_predicates_by_cost(&self, plan: LogicalPlan) -> LogicalPlan {
689 match plan {
690 LogicalPlan::Filter { input, predicate } => {
691 let reordered_pred = self.reorder_pred(&predicate);
692 let optimized_input = self.reorder_predicates_by_cost(*input);
693 LogicalPlan::Filter {
694 input: Box::new(optimized_input),
695 predicate: reordered_pred,
696 }
697 }
698 LogicalPlan::Project { input, columns } => LogicalPlan::Project {
699 input: Box::new(self.reorder_predicates_by_cost(*input)),
700 columns,
701 },
702 LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
703 input: Box::new(self.reorder_predicates_by_cost(*input)),
704 count,
705 },
706 LogicalPlan::Join {
707 left,
708 right,
709 on,
710 join_type,
711 } => {
712 let reordered_on = self.reorder_pred(&on);
713 LogicalPlan::Join {
714 left: Box::new(self.reorder_predicates_by_cost(*left)),
715 right: Box::new(self.reorder_predicates_by_cost(*right)),
716 on: reordered_on,
717 join_type,
718 }
719 }
720 other => other,
721 }
722 }
723 fn reorder_pred(&self, pred: &Predicate) -> Predicate {
725 match pred {
726 Predicate::And(p1, p2) => {
727 let r1 = self.reorder_pred(p1);
728 let r2 = self.reorder_pred(p2);
729 let cost1 =
730 self.stats.predicate_selectivity(&r1) * self.stats.predicate_fhe_cost(&r1);
731 let cost2 =
732 self.stats.predicate_selectivity(&r2) * self.stats.predicate_fhe_cost(&r2);
733 if cost1 <= cost2 {
734 Predicate::And(Box::new(r1), Box::new(r2))
735 } else {
736 Predicate::And(Box::new(r2), Box::new(r1))
737 }
738 }
739 Predicate::Or(p1, p2) => Predicate::Or(
740 Box::new(self.reorder_pred(p1)),
741 Box::new(self.reorder_pred(p2)),
742 ),
743 Predicate::Not(inner) => Predicate::Not(Box::new(self.reorder_pred(inner))),
744 other => other.clone(),
745 }
746 }
747 fn to_physical(&self, plan: &LogicalPlan) -> Result<PhysicalPlan> {
749 match plan {
750 LogicalPlan::Scan { collection } => Ok(PhysicalPlan::SeqScan {
751 collection: collection.clone(),
752 }),
753 LogicalPlan::RangeScan {
754 collection,
755 start_key,
756 end_key,
757 } => Ok(PhysicalPlan::IndexScan {
758 collection: collection.clone(),
759 start: start_key.clone(),
760 end: end_key.clone(),
761 }),
762 LogicalPlan::Filter { input, predicate } => {
763 let physical_input = self.to_physical(input)?;
764 let circuit = self.compile_predicate_circuit(predicate)?;
765 Ok(PhysicalPlan::FheFilter {
766 input: Box::new(physical_input),
767 circuit,
768 predicate: predicate.clone(),
769 })
770 }
771 LogicalPlan::Project { input, columns } => {
772 let physical_input = self.to_physical(input)?;
773 Ok(PhysicalPlan::Projection {
774 input: Box::new(physical_input),
775 columns: columns.clone(),
776 })
777 }
778 LogicalPlan::Limit { input, count } => {
779 let physical_input = self.to_physical(input)?;
780 Ok(PhysicalPlan::Limit {
781 input: Box::new(physical_input),
782 count: *count,
783 })
784 }
785 LogicalPlan::PointLookup { collection, key } => Ok(PhysicalPlan::PointGet {
786 collection: collection.clone(),
787 key: key.clone(),
788 }),
789 LogicalPlan::Join {
790 left,
791 right,
792 on,
793 join_type,
794 } => {
795 let left_phys = self.to_physical(left)?;
796 let right_phys = self.to_physical(right)?;
797 let left_rows = self.estimate_cost(&left_phys).estimated_rows;
798 let right_rows = self.estimate_cost(&right_phys).estimated_rows;
799 let use_hash = matches!(on, Predicate::Eq(_, _));
800 if use_hash {
801 let (probe, build) = if left_rows <= right_rows {
802 (right_phys, left_phys)
803 } else {
804 (left_phys, right_phys)
805 };
806 Ok(PhysicalPlan::HashJoin {
807 probe: Box::new(probe),
808 build: Box::new(build),
809 on: on.clone(),
810 join_type: join_type.clone(),
811 })
812 } else {
813 let (outer, build) = if left_rows <= right_rows {
814 (left_phys, right_phys)
815 } else {
816 (right_phys, left_phys)
817 };
818 Ok(PhysicalPlan::NestedLoopJoin {
819 outer: Box::new(outer),
820 build: Box::new(build),
821 on: on.clone(),
822 join_type: join_type.clone(),
823 })
824 }
825 }
826 }
827 }
828 pub fn estimate_cost(&self, plan: &PhysicalPlan) -> PlanCost {
830 match plan {
831 PhysicalPlan::SeqScan { collection } => {
832 let rows = self.stats.collection_size(collection);
833 let io_bytes = rows * self.stats.average_value_size;
834 PlanCost::compute(rows, 0, io_bytes)
835 }
836 PhysicalPlan::IndexScan {
837 collection,
838 start,
839 end,
840 } => {
841 let total = self.stats.collection_size(collection);
842 let selectivity = match (start, end) {
843 (Some(_), Some(_)) => 0.10,
844 (Some(_), None) | (None, Some(_)) => 0.30,
845 (None, None) => 1.0,
846 };
847 let rows = ((total as f64) * selectivity).max(1.0) as u64;
848 let io_bytes = rows * self.stats.average_value_size;
849 PlanCost::compute(rows, 0, io_bytes)
850 }
851 PhysicalPlan::FheFilter { input, circuit, .. } => {
852 let input_cost = self.estimate_cost(input);
853 let fhe_ops = input_cost.estimated_rows * (circuit.gate_count as u64);
854 let output_rows = (input_cost.estimated_rows / 2).max(1);
855 let io_bytes = output_rows * self.stats.average_value_size;
856 PlanCost::compute(
857 input_cost.estimated_rows,
858 input_cost.estimated_fhe_ops + fhe_ops,
859 input_cost.estimated_io_bytes + io_bytes,
860 )
861 }
862 PhysicalPlan::Projection { input, .. } => {
863 let mut cost = self.estimate_cost(input);
864 cost.estimated_io_bytes = (cost.estimated_io_bytes as f64 * 0.8) as u64;
865 cost.total_cost = (cost.estimated_rows as f64 * PlanCost::SCAN_COST_PER_ROW)
866 + (cost.estimated_fhe_ops as f64 * PlanCost::FHE_COST_PER_OP)
867 + (cost.estimated_io_bytes as f64 * PlanCost::IO_COST_PER_BYTE);
868 cost
869 }
870 PhysicalPlan::Limit { input, count } => {
871 let input_cost = self.estimate_cost(input);
872 let rows = (*count as u64).min(input_cost.estimated_rows);
873 let io_bytes = rows * self.stats.average_value_size;
874 PlanCost::compute(rows, input_cost.estimated_fhe_ops, io_bytes)
875 }
876 PhysicalPlan::PointGet { .. } => PlanCost::compute(1, 0, self.stats.average_value_size),
877 PhysicalPlan::NestedLoopJoin { outer, build, .. } => {
878 let outer_cost = self.estimate_cost(outer);
879 let build_cost = self.estimate_cost(build);
880 let outer_rows = outer_cost.estimated_rows;
881 let build_rows = build_cost.estimated_rows;
882 let fhe_ops = outer_rows.saturating_mul(build_rows);
883 let estimated_rows = outer_rows.saturating_mul(build_rows) / 2;
884 let io_bytes = outer_cost.estimated_io_bytes + build_cost.estimated_io_bytes;
885 PlanCost::compute(estimated_rows, fhe_ops, io_bytes)
886 }
887 PhysicalPlan::HashJoin { probe, build, .. } => {
888 let probe_cost = self.estimate_cost(probe);
889 let build_cost = self.estimate_cost(build);
890 let probe_rows = probe_cost.estimated_rows;
891 let build_rows = build_cost.estimated_rows;
892 let fhe_ops = probe_cost.estimated_fhe_ops + build_cost.estimated_fhe_ops;
893 let estimated_rows = probe_rows.saturating_mul(build_rows) / 2;
894 let io_bytes = probe_cost.estimated_io_bytes + build_cost.estimated_io_bytes;
895 PlanCost::compute(estimated_rows, fhe_ops, io_bytes)
896 }
897 }
898 }
899 pub fn choose_cheaper<'a>(&self, a: &'a PhysicalPlan, b: &'a PhysicalPlan) -> &'a PhysicalPlan {
901 let cost_a = self.estimate_cost(a);
902 let cost_b = self.estimate_cost(b);
903 if cost_a.total_cost <= cost_b.total_cost {
904 a
905 } else {
906 b
907 }
908 }
909 fn referenced_columns(predicate: &Predicate) -> Vec<String> {
911 let mut cols = Vec::new();
912 Self::collect_columns(predicate, &mut cols);
913 cols.sort();
914 cols.dedup();
915 cols
916 }
917 fn collect_columns(predicate: &Predicate, out: &mut Vec<String>) {
918 match predicate {
919 Predicate::Eq(col, _)
920 | Predicate::Gt(col, _)
921 | Predicate::Lt(col, _)
922 | Predicate::Gte(col, _)
923 | Predicate::Lte(col, _) => {
924 out.push(col.name.clone());
925 }
926 Predicate::And(l, r) | Predicate::Or(l, r) => {
927 Self::collect_columns(l, out);
928 Self::collect_columns(r, out);
929 }
930 Predicate::Not(inner) => {
931 Self::collect_columns(inner, out);
932 }
933 }
934 }
935 fn extract_key_range(predicate: &Predicate) -> Option<(Option<Vec<u8>>, Option<Vec<u8>>)> {
940 match predicate {
941 Predicate::Gt(col, blob) if col.name == "_key" => {
942 Some((Some(blob.as_bytes().to_vec()), None))
943 }
944 Predicate::Gte(col, blob) if col.name == "_key" => {
945 Some((Some(blob.as_bytes().to_vec()), None))
946 }
947 Predicate::Lt(col, blob) if col.name == "_key" => {
948 Some((None, Some(blob.as_bytes().to_vec())))
949 }
950 Predicate::Lte(col, blob) if col.name == "_key" => {
951 Some((None, Some(blob.as_bytes().to_vec())))
952 }
953 Predicate::And(left, right) => {
954 let lr = Self::extract_key_range(left);
955 let rr = Self::extract_key_range(right);
956 match (lr, rr) {
957 (Some((s1, e1)), Some((s2, e2))) => {
958 let start = s1.or(s2);
959 let end = e1.or(e2);
960 Some((start, end))
961 }
962 (Some(range), None) | (None, Some(range)) => Some(range),
963 (None, None) => None,
964 }
965 }
966 _ => None,
967 }
968 }
969 fn compile_predicate_circuit(&self, predicate: &Predicate) -> Result<Circuit> {
971 let mut compiler = PredicateCompiler::new();
972 compiler.compile(predicate, EncryptedType::U8)
973 }
974}
975impl Default for QueryPlanner {
976 fn default() -> Self {
977 Self::new()
978 }
979}
980impl std::fmt::Display for LogicalPlan {
981 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
982 self.fmt_indented(f, 0)
983 }
984}
985impl LogicalPlan {
986 fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
987 let pad = " ".repeat(indent);
988 match self {
989 LogicalPlan::Scan { collection } => {
990 writeln!(f, "{}Scan({})", pad, collection)
991 }
992 LogicalPlan::RangeScan {
993 collection,
994 start_key,
995 end_key,
996 } => {
997 writeln!(
998 f,
999 "{}RangeScan({}, start={}, end={})",
1000 pad,
1001 collection,
1002 start_key.is_some(),
1003 end_key.is_some()
1004 )
1005 }
1006 LogicalPlan::Filter { input, predicate } => {
1007 writeln!(f, "{}Filter(pred={:?})", pad, predicate)?;
1008 input.fmt_indented(f, indent + 1)
1009 }
1010 LogicalPlan::Project { input, columns } => {
1011 writeln!(f, "{}Project({:?})", pad, columns)?;
1012 input.fmt_indented(f, indent + 1)
1013 }
1014 LogicalPlan::Limit { input, count } => {
1015 writeln!(f, "{}Limit({})", pad, count)?;
1016 input.fmt_indented(f, indent + 1)
1017 }
1018 LogicalPlan::PointLookup { collection, key } => {
1019 writeln!(f, "{}PointLookup({}, key={})", pad, collection, key)
1020 }
1021 LogicalPlan::Join {
1022 left,
1023 right,
1024 on,
1025 join_type,
1026 } => {
1027 let jt = match join_type {
1028 JoinType::Inner => "Inner",
1029 JoinType::Left => "Left",
1030 JoinType::Right => "Right",
1031 };
1032 writeln!(f, "{}{}Join(on={:?})", pad, jt, on)?;
1033 left.fmt_indented(f, indent + 1)?;
1034 right.fmt_indented(f, indent + 1)
1035 }
1036 }
1037 }
1038}
1039impl std::fmt::Display for PhysicalPlan {
1040 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041 self.fmt_indented(f, 0)
1042 }
1043}
1044impl PhysicalPlan {
1045 fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
1046 let pad = " ".repeat(indent);
1047 match self {
1048 PhysicalPlan::SeqScan { collection } => {
1049 writeln!(f, "{}SeqScan({})", pad, collection)
1050 }
1051 PhysicalPlan::IndexScan {
1052 collection,
1053 start,
1054 end,
1055 } => {
1056 writeln!(
1057 f,
1058 "{}IndexScan({}, start={}, end={})",
1059 pad,
1060 collection,
1061 start.is_some(),
1062 end.is_some()
1063 )
1064 }
1065 PhysicalPlan::FheFilter {
1066 input, predicate, ..
1067 } => {
1068 writeln!(f, "{}FheFilter(pred={:?})", pad, predicate)?;
1069 input.fmt_indented(f, indent + 1)
1070 }
1071 PhysicalPlan::Projection { input, columns } => {
1072 writeln!(f, "{}Projection({:?})", pad, columns)?;
1073 input.fmt_indented(f, indent + 1)
1074 }
1075 PhysicalPlan::Limit { input, count } => {
1076 writeln!(f, "{}Limit({})", pad, count)?;
1077 input.fmt_indented(f, indent + 1)
1078 }
1079 PhysicalPlan::PointGet { collection, key } => {
1080 writeln!(f, "{}PointGet({}, key={})", pad, collection, key)
1081 }
1082 PhysicalPlan::NestedLoopJoin {
1083 outer,
1084 build,
1085 on,
1086 join_type,
1087 } => {
1088 let jt = match join_type {
1089 JoinType::Inner => "Inner",
1090 JoinType::Left => "Left",
1091 JoinType::Right => "Right",
1092 };
1093 writeln!(f, "{}NestedLoopJoin[{}](on={:?})", pad, jt, on)?;
1094 outer.fmt_indented(f, indent + 1)?;
1095 build.fmt_indented(f, indent + 1)
1096 }
1097 PhysicalPlan::HashJoin {
1098 probe,
1099 build,
1100 on,
1101 join_type,
1102 } => {
1103 let jt = match join_type {
1104 JoinType::Inner => "Inner",
1105 JoinType::Left => "Left",
1106 JoinType::Right => "Right",
1107 };
1108 writeln!(f, "{}HashJoin[{}](on={:?})", pad, jt, on)?;
1109 probe.fmt_indented(f, indent + 1)?;
1110 build.fmt_indented(f, indent + 1)
1111 }
1112 }
1113 }
1114}
1115
1116#[cfg(test)]
1117mod tests;