Skip to main content

amaters_core/compute/planner/
mod.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4pub use super::plan_cache::{CacheKey, CacheStats, CachedPlan, PlanCache, PlanCacheConfig};
5use crate::compute::EncryptedType;
6use crate::compute::circuit::Circuit;
7use crate::compute::predicate::PredicateCompiler;
8use crate::error::{AmateRSError, ErrorContext, Result};
9use crate::types::{CipherBlob, ColumnRef, JoinType, Key, Predicate, Query};
10use dashmap::DashMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13/// Logical query plan node
14///
15/// Represents the *intent* of a query before physical execution details
16/// are decided. The logical plan is the subject of optimization rewrites.
17#[derive(Debug, Clone)]
18pub enum LogicalPlan {
19    /// Full table/collection scan
20    Scan {
21        /// Name of the collection to scan
22        collection: String,
23    },
24    /// Range scan with start/end keys
25    RangeScan {
26        /// Name of the collection
27        collection: String,
28        /// Inclusive start key (None = beginning)
29        start_key: Option<Vec<u8>>,
30        /// Exclusive end key (None = end)
31        end_key: Option<Vec<u8>>,
32    },
33    /// Filter with predicate (operates on encrypted data via FHE)
34    Filter {
35        /// Input plan to filter
36        input: Box<LogicalPlan>,
37        /// Predicate to evaluate
38        predicate: Predicate,
39    },
40    /// Projection (select specific columns)
41    Project {
42        /// Input plan to project
43        input: Box<LogicalPlan>,
44        /// Column names to retain
45        columns: Vec<String>,
46    },
47    /// Limit number of results
48    Limit {
49        /// Input plan to limit
50        input: Box<LogicalPlan>,
51        /// Maximum number of results
52        count: usize,
53    },
54    /// Point lookup by key
55    PointLookup {
56        /// Collection name
57        collection: String,
58        /// Key to look up
59        key: Key,
60    },
61    /// Two-collection join
62    Join {
63        /// Left input plan
64        left: Box<LogicalPlan>,
65        /// Right input plan
66        right: Box<LogicalPlan>,
67        /// Join condition
68        on: Predicate,
69        /// Join type (Inner / Left / Right)
70        join_type: JoinType,
71    },
72}
73/// Physical query plan (executable)
74///
75/// Each variant maps directly to a concrete execution strategy.
76#[derive(Debug, Clone)]
77pub enum PhysicalPlan {
78    /// Sequential full scan
79    SeqScan {
80        /// Collection to scan
81        collection: String,
82    },
83    /// Index/range scan (pushdown to storage layer)
84    IndexScan {
85        /// Collection to scan
86        collection: String,
87        /// Inclusive start key
88        start: Option<Vec<u8>>,
89        /// Exclusive end key
90        end: Option<Vec<u8>>,
91    },
92    /// FHE filter evaluation (evaluated on encrypted data)
93    FheFilter {
94        /// Input physical plan
95        input: Box<PhysicalPlan>,
96        /// Compiled FHE circuit for the filter
97        circuit: Circuit,
98        /// Original predicate (kept for introspection / explain)
99        predicate: Predicate,
100    },
101    /// Client-side projection
102    Projection {
103        /// Input physical plan
104        input: Box<PhysicalPlan>,
105        /// Columns to retain
106        columns: Vec<String>,
107    },
108    /// Limit result count
109    Limit {
110        /// Input physical plan
111        input: Box<PhysicalPlan>,
112        /// Maximum results
113        count: usize,
114    },
115    /// Point lookup by key
116    PointGet {
117        /// Collection name
118        collection: String,
119        /// Key to look up
120        key: Key,
121    },
122    /// Nested-loop join — O(n*m), used for encrypted-key / non-Eq predicates
123    NestedLoopJoin {
124        /// Outer (driving) side
125        outer: Box<PhysicalPlan>,
126        /// Build (inner) side iterated for every outer row
127        build: Box<PhysicalPlan>,
128        /// Join condition
129        on: Predicate,
130        /// Join type
131        join_type: JoinType,
132    },
133    /// Hash join — O(n+m), used when the join condition is a single Eq predicate
134    HashJoin {
135        /// Probe side (larger estimated input)
136        probe: Box<PhysicalPlan>,
137        /// Build side hashed into memory (smaller estimated input)
138        build: Box<PhysicalPlan>,
139        /// Join condition (must be Predicate::Eq)
140        on: Predicate,
141        /// Join type
142        join_type: JoinType,
143    },
144}
/// Estimated cost of running a physical plan.
#[derive(Debug, Clone)]
pub struct PlanCost {
    /// How many rows the plan is expected to touch
    pub estimated_rows: u64,
    /// How many FHE gate operations the plan is expected to perform
    pub estimated_fhe_ops: u64,
    /// How many bytes of I/O the plan is expected to transfer
    pub estimated_io_bytes: u64,
    /// Single weighted figure combining the estimates (lower is better)
    pub total_cost: f64,
}

impl PlanCost {
    /// Weight applied to each byte of I/O
    const IO_COST_PER_BYTE: f64 = 0.001;
    /// Weight applied to each FHE gate operation (FHE is *very* expensive)
    const FHE_COST_PER_OP: f64 = 100.0;
    /// Weight applied to each scanned row
    const SCAN_COST_PER_ROW: f64 = 0.01;
    /// Flat cost charged for a point lookup
    const POINT_LOOKUP_COST: f64 = 1.0;

    /// Build a `PlanCost` from the three raw estimates.
    ///
    /// `total_cost` is the weighted sum of the row, FHE-op and I/O components,
    /// accumulated in that order.
    fn compute(estimated_rows: u64, estimated_fhe_ops: u64, estimated_io_bytes: u64) -> Self {
        let row_component = estimated_rows as f64 * Self::SCAN_COST_PER_ROW;
        let fhe_component = estimated_fhe_ops as f64 * Self::FHE_COST_PER_OP;
        let io_component = estimated_io_bytes as f64 * Self::IO_COST_PER_BYTE;

        Self {
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
            total_cost: row_component + fhe_component + io_component,
        }
    }
}

impl std::fmt::Display for PlanCost {
    /// Render as `PlanCost(rows=…, fhe_ops=…, io_bytes=…, total=…)` with the
    /// total shown to two decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
            total_cost,
        } = self;
        write!(
            f,
            "PlanCost(rows={}, fhe_ops={}, io_bytes={}, total={:.2})",
            estimated_rows, estimated_fhe_ops, estimated_io_bytes, total_cost
        )
    }
}
188/// Statistics used for cost estimation
189///
190/// Maintains per-collection cardinality estimates and global latency hints
191/// so that the planner can make informed decisions.
192pub struct PlannerStats {
193    /// Estimated row count per collection
194    pub estimated_collection_sizes: DashMap<String, u64>,
195    /// Average value size in bytes across all collections
196    pub average_value_size: u64,
197    /// Estimated microsecond latency of a single FHE gate operation
198    pub fhe_op_latency_us: u64,
199    /// Cost of a single FHE comparison operation (Eq / Lt / Gt / Lte / Gte)
200    pub fhe_comparison_cost: f64,
201    /// Cost of a single FHE boolean operation (And / Or)
202    pub fhe_boolean_cost: f64,
203}
204impl PlannerStats {
205    /// Create default statistics with reasonable starting values
206    fn new() -> Self {
207        Self {
208            estimated_collection_sizes: DashMap::new(),
209            average_value_size: 256,
210            fhe_op_latency_us: 1000,
211            fhe_comparison_cost: 100.0,
212            fhe_boolean_cost: 10.0,
213        }
214    }
215    /// Return the estimated size of a collection, defaulting to 1000
216    fn collection_size(&self, collection: &str) -> u64 {
217        self.estimated_collection_sizes
218            .get(collection)
219            .map(|v| *v)
220            .unwrap_or(1000)
221    }
222    /// Update the estimated size for a collection
223    pub fn set_collection_size(&self, collection: impl Into<String>, size: u64) {
224        self.estimated_collection_sizes
225            .insert(collection.into(), size);
226    }
227    /// Estimate the fraction of rows a predicate will pass (0.0–1.0).
228    ///
229    /// Heuristics (no histograms available):
230    /// - `Eq`            → 0.001  (high selectivity, rare match)
231    /// - `Lt/Gt/Lte/Gte` → 0.3   (moderate selectivity)
232    /// - `And(p1, p2)`   → s1 * s2
233    /// - `Or(p1, p2)`    → 1 - (1-s1)*(1-s2)
234    /// - `Not(p)`        → 1 - s
235    pub fn predicate_selectivity(&self, pred: &Predicate) -> f64 {
236        match pred {
237            Predicate::Eq(_, _) => 0.001,
238            Predicate::Lt(_, _)
239            | Predicate::Gt(_, _)
240            | Predicate::Lte(_, _)
241            | Predicate::Gte(_, _) => 0.3,
242            Predicate::And(p1, p2) => {
243                self.predicate_selectivity(p1) * self.predicate_selectivity(p2)
244            }
245            Predicate::Or(p1, p2) => {
246                let s1 = self.predicate_selectivity(p1);
247                let s2 = self.predicate_selectivity(p2);
248                1.0 - (1.0 - s1) * (1.0 - s2)
249            }
250            Predicate::Not(inner) => 1.0 - self.predicate_selectivity(inner),
251        }
252    }
253    /// Estimate the relative FHE cost of evaluating a predicate (in abstract units).
254    ///
255    /// - Leaf comparisons (`Eq/Lt/Gt/Lte/Gte`) each cost 1.0 comparison unit.
256    /// - `And`/`Or` cost 1.0 boolean-op + sum of children costs.
257    /// - `Not` costs 0.5 boolean-op + child cost.
258    pub fn predicate_fhe_cost(&self, pred: &Predicate) -> f64 {
259        match pred {
260            Predicate::Eq(_, _)
261            | Predicate::Lt(_, _)
262            | Predicate::Gt(_, _)
263            | Predicate::Lte(_, _)
264            | Predicate::Gte(_, _) => self.fhe_comparison_cost,
265            Predicate::And(p1, p2) | Predicate::Or(p1, p2) => {
266                self.fhe_boolean_cost + self.predicate_fhe_cost(p1) + self.predicate_fhe_cost(p2)
267            }
268            Predicate::Not(inner) => self.fhe_boolean_cost * 0.5 + self.predicate_fhe_cost(inner),
269        }
270    }
271}
272impl Default for PlannerStats {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277/// Query planner that converts `Query` into optimized `PhysicalPlan`
278///
279/// The planner applies predicate pushdown, filter merging, and cost-based
280/// selection to produce an execution plan that minimises expensive FHE
281/// operations and I/O.
282///
283/// Optionally maintains a plan cache to avoid re-planning identical queries.
284pub struct QueryPlanner {
285    /// Statistics for cost estimation
286    stats: Arc<PlannerStats>,
287    /// Optional plan cache
288    cache: Option<Arc<PlanCache>>,
289}
290impl QueryPlanner {
291    /// Create a new query planner with default statistics
292    pub fn new() -> Self {
293        Self {
294            stats: Arc::new(PlannerStats::new()),
295            cache: None,
296        }
297    }
298    /// Create a planner with custom statistics
299    pub fn with_stats(stats: Arc<PlannerStats>) -> Self {
300        Self { stats, cache: None }
301    }
302    /// Enable plan caching with the given configuration
303    pub fn with_cache(mut self, config: PlanCacheConfig) -> Self {
304        self.cache = Some(Arc::new(PlanCache::new(config)));
305        self
306    }
307    /// Get a reference to the planner statistics
308    pub fn stats(&self) -> &PlannerStats {
309        &self.stats
310    }
311    /// Get a reference to the plan cache, if enabled
312    pub fn plan_cache(&self) -> Option<&PlanCache> {
313        self.cache.as_deref()
314    }
315    /// Return cache statistics, or default stats if caching is not enabled
316    pub fn cache_stats(&self) -> CacheStats {
317        self.cache
318            .as_ref()
319            .map(|c| c.cache_stats())
320            .unwrap_or_default()
321    }
322    /// Invalidate all cached plans (e.g., after a schema change)
323    pub fn invalidate_all(&self) {
324        if let Some(cache) = &self.cache {
325            cache.invalidate_all();
326        }
327    }
328    /// Invalidate cached plans matching a prefix (e.g., a collection name)
329    pub fn invalidate_prefix(&self, prefix: &str) {
330        if let Some(cache) = &self.cache {
331            cache.invalidate_prefix(prefix);
332        }
333    }
334    /// Plan a query
335    ///
336    /// If caching is enabled, checks the cache first and returns a cached
337    /// plan if one exists and has not expired. Otherwise, plans the query
338    /// from scratch and inserts the result into the cache.
339    pub fn plan(&self, query: &Query) -> Result<PhysicalPlan> {
340        let cache_key = CacheKey::from_query(query);
341        if let Some(cache) = &self.cache {
342            if let Some(cached_plan) = cache.get(&cache_key) {
343                return Ok(cached_plan);
344            }
345        }
346        let logical = self.to_logical(query)?;
347        let optimized = self.optimize_logical(logical);
348        let physical = self.to_physical(&optimized)?;
349        if let Some(cache) = &self.cache {
350            let normalized = CacheKey::normalize(&format!("{:?}", query));
351            cache.insert(cache_key, physical.clone(), normalized);
352        }
353        Ok(physical)
354    }
355    /// Convert a high-level `Query` into a `LogicalPlan`
356    fn to_logical(&self, query: &Query) -> Result<LogicalPlan> {
357        match query {
358            Query::Get { collection, key } => Ok(LogicalPlan::PointLookup {
359                collection: collection.clone(),
360                key: key.clone(),
361            }),
362            Query::Filter {
363                collection,
364                predicate,
365            } => Ok(LogicalPlan::Filter {
366                input: Box::new(LogicalPlan::Scan {
367                    collection: collection.clone(),
368                }),
369                predicate: predicate.clone(),
370            }),
371            Query::Range {
372                collection,
373                start,
374                end,
375            } => Ok(LogicalPlan::RangeScan {
376                collection: collection.clone(),
377                start_key: Some(start.to_vec()),
378                end_key: Some(end.to_vec()),
379            }),
380            Query::Set { collection, .. } => Ok(LogicalPlan::Scan {
381                collection: collection.clone(),
382            }),
383            Query::Delete { collection, key } => Ok(LogicalPlan::PointLookup {
384                collection: collection.clone(),
385                key: key.clone(),
386            }),
387            Query::Update {
388                collection,
389                predicate,
390                ..
391            } => Ok(LogicalPlan::Filter {
392                input: Box::new(LogicalPlan::Scan {
393                    collection: collection.clone(),
394                }),
395                predicate: predicate.clone(),
396            }),
397            Query::Join {
398                left_collection,
399                right_collection,
400                on,
401                join_type,
402                left_limit,
403                right_limit,
404            } => {
405                let mut left: LogicalPlan = LogicalPlan::Scan {
406                    collection: left_collection.clone(),
407                };
408                if let Some(n) = left_limit {
409                    left = LogicalPlan::Limit {
410                        input: Box::new(left),
411                        count: *n,
412                    };
413                }
414                let mut right: LogicalPlan = LogicalPlan::Scan {
415                    collection: right_collection.clone(),
416                };
417                if let Some(n) = right_limit {
418                    right = LogicalPlan::Limit {
419                        input: Box::new(right),
420                        count: *n,
421                    };
422                }
423                Ok(LogicalPlan::Join {
424                    left: Box::new(left),
425                    right: Box::new(right),
426                    on: on.clone(),
427                    join_type: join_type.clone(),
428                })
429            }
430        }
431    }
432    /// Apply all logical optimization passes
433    fn optimize_logical(&self, plan: LogicalPlan) -> LogicalPlan {
434        let plan = self.push_predicates_down(plan);
435        let plan = self.merge_filters(plan);
436        let plan = self.convert_filter_to_range_scan(plan);
437        self.reorder_predicates_by_cost(plan)
438    }
439    /// Predicate pushdown: move filters closer to the data source
440    ///
441    /// Rules applied:
442    /// - `Filter(And(p1, p2), input)` (non-Limit input) → split, recurse each conjunct.
443    /// - `Filter(Project(input, cols), pred)` → push the filter below the projection.
444    /// - `Filter(Join{..}, pred)` → push single-side conjuncts into the appropriate join arm.
445    /// - `Filter(Filter(input, p1), p2)` is handled by `merge_filters`.
446    fn push_predicates_down(&self, plan: LogicalPlan) -> LogicalPlan {
447        match plan {
448            LogicalPlan::Filter {
449                input,
450                predicate: Predicate::And(p1, p2),
451            } if !matches!(*input, LogicalPlan::Limit { .. }) => {
452                let inner = LogicalPlan::Filter {
453                    input,
454                    predicate: *p2,
455                };
456                let outer = LogicalPlan::Filter {
457                    input: Box::new(inner),
458                    predicate: *p1,
459                };
460                self.push_predicates_down(outer)
461            }
462            LogicalPlan::Filter { input, predicate }
463                if matches!(*input, LogicalPlan::Join { .. }) =>
464            {
465                if let LogicalPlan::Join {
466                    left,
467                    right,
468                    on,
469                    join_type,
470                } = *input
471                {
472                    let left_cols = Self::referenced_columns(&predicate);
473                    let right_input_cols = Self::plan_output_columns(&right);
474                    let left_input_cols = Self::plan_output_columns(&left);
475                    let touches_left = left_cols.iter().any(|c| left_input_cols.contains(c));
476                    let touches_right = left_cols.iter().any(|c| right_input_cols.contains(c));
477                    match (touches_left, touches_right) {
478                        (true, false) => {
479                            let new_left = self.push_predicates_down(LogicalPlan::Filter {
480                                input: left,
481                                predicate,
482                            });
483                            self.push_predicates_down(LogicalPlan::Join {
484                                left: Box::new(new_left),
485                                right,
486                                on,
487                                join_type,
488                            })
489                        }
490                        (false, true) => {
491                            let new_right = self.push_predicates_down(LogicalPlan::Filter {
492                                input: right,
493                                predicate,
494                            });
495                            self.push_predicates_down(LogicalPlan::Join {
496                                left,
497                                right: Box::new(new_right),
498                                on,
499                                join_type,
500                            })
501                        }
502                        _ => {
503                            let joined = self.push_predicates_down(LogicalPlan::Join {
504                                left,
505                                right,
506                                on,
507                                join_type,
508                            });
509                            LogicalPlan::Filter {
510                                input: Box::new(joined),
511                                predicate,
512                            }
513                        }
514                    }
515                } else {
516                    unreachable!("guard confirmed Join variant")
517                }
518            }
519            LogicalPlan::Filter { input, predicate } => {
520                let optimized_input = self.push_predicates_down(*input);
521                match optimized_input {
522                    LogicalPlan::Project {
523                        input: proj_input,
524                        columns,
525                    } => {
526                        let pred_cols = Self::referenced_columns(&predicate);
527                        let proj_set: HashSet<&str> = columns.iter().map(|c| c.as_str()).collect();
528                        if pred_cols.iter().all(|c| proj_set.contains(c.as_str())) {
529                            LogicalPlan::Project {
530                                input: Box::new(LogicalPlan::Filter {
531                                    input: proj_input,
532                                    predicate,
533                                }),
534                                columns,
535                            }
536                        } else {
537                            let mut extended_cols = columns.clone();
538                            for col in &pred_cols {
539                                if !proj_set.contains(col.as_str()) {
540                                    extended_cols.push(col.clone());
541                                }
542                            }
543                            LogicalPlan::Project {
544                                input: Box::new(LogicalPlan::Filter {
545                                    input: Box::new(LogicalPlan::Project {
546                                        input: proj_input,
547                                        columns: extended_cols,
548                                    }),
549                                    predicate,
550                                }),
551                                columns,
552                            }
553                        }
554                    }
555                    other => LogicalPlan::Filter {
556                        input: Box::new(other),
557                        predicate,
558                    },
559                }
560            }
561            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
562                input: Box::new(self.push_predicates_down(*input)),
563                columns,
564            },
565            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
566                input: Box::new(self.push_predicates_down(*input)),
567                count,
568            },
569            LogicalPlan::Join {
570                left,
571                right,
572                on,
573                join_type,
574            } => LogicalPlan::Join {
575                left: Box::new(self.push_predicates_down(*left)),
576                right: Box::new(self.push_predicates_down(*right)),
577                on,
578                join_type,
579            },
580            other => other,
581        }
582    }
583    /// Collect the set of column names that a plan might output.
584    ///
585    /// Used to determine whether a predicate touches columns from a specific
586    /// join arm. For scans we have no schema, so we return an empty set (which
587    /// causes cross-side classification and keeps the filter above the join,
588    /// the safe default).
589    fn plan_output_columns(plan: &LogicalPlan) -> HashSet<String> {
590        match plan {
591            LogicalPlan::Project { columns, .. } => columns.iter().cloned().collect(),
592            _ => HashSet::new(),
593        }
594    }
595    /// Merge adjacent filters into a single AND predicate
596    ///
597    /// `Filter(Filter(input, p1), p2)` => `Filter(input, And(p1, p2))`
598    fn merge_filters(&self, plan: LogicalPlan) -> LogicalPlan {
599        match plan {
600            LogicalPlan::Filter { input, predicate } => {
601                let optimized_input = self.merge_filters(*input);
602                match optimized_input {
603                    LogicalPlan::Filter {
604                        input: inner_input,
605                        predicate: inner_pred,
606                    } => LogicalPlan::Filter {
607                        input: inner_input,
608                        predicate: Predicate::And(Box::new(inner_pred), Box::new(predicate)),
609                    },
610                    other => LogicalPlan::Filter {
611                        input: Box::new(other),
612                        predicate,
613                    },
614                }
615            }
616            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
617                input: Box::new(self.merge_filters(*input)),
618                columns,
619            },
620            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
621                input: Box::new(self.merge_filters(*input)),
622                count,
623            },
624            LogicalPlan::Join {
625                left,
626                right,
627                on,
628                join_type,
629            } => LogicalPlan::Join {
630                left: Box::new(self.merge_filters(*left)),
631                right: Box::new(self.merge_filters(*right)),
632                on,
633                join_type,
634            },
635            other => other,
636        }
637    }
638    /// Convert a filter on key range into a `RangeScan` when possible
639    ///
640    /// If a `Filter(Scan(collection), pred)` has a predicate that is purely
641    /// a key-range comparison (Gt/Lt/Gte/Lte on the `_key` column), we can
642    /// replace the scan+filter with a more efficient `RangeScan`.
643    fn convert_filter_to_range_scan(&self, plan: LogicalPlan) -> LogicalPlan {
644        match plan {
645            LogicalPlan::Filter { input, predicate } => {
646                let optimized_input = self.convert_filter_to_range_scan(*input);
647                if let LogicalPlan::Scan { ref collection } = optimized_input {
648                    if let Some((start, end)) = Self::extract_key_range(&predicate) {
649                        return LogicalPlan::RangeScan {
650                            collection: collection.clone(),
651                            start_key: start,
652                            end_key: end,
653                        };
654                    }
655                }
656                LogicalPlan::Filter {
657                    input: Box::new(optimized_input),
658                    predicate,
659                }
660            }
661            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
662                input: Box::new(self.convert_filter_to_range_scan(*input)),
663                columns,
664            },
665            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
666                input: Box::new(self.convert_filter_to_range_scan(*input)),
667                count,
668            },
669            LogicalPlan::Join {
670                left,
671                right,
672                on,
673                join_type,
674            } => LogicalPlan::Join {
675                left: Box::new(self.convert_filter_to_range_scan(*left)),
676                right: Box::new(self.convert_filter_to_range_scan(*right)),
677                on,
678                join_type,
679            },
680            other => other,
681        }
682    }
683    /// Reorder conjuncts in `And` predicates so the cheaper (lower
684    /// selectivity × fhe_cost) predicate is evaluated first (outer And-branch).
685    ///
686    /// This is a pure structural rewrite; semantics are unchanged because AND
687    /// is commutative.
688    fn reorder_predicates_by_cost(&self, plan: LogicalPlan) -> LogicalPlan {
689        match plan {
690            LogicalPlan::Filter { input, predicate } => {
691                let reordered_pred = self.reorder_pred(&predicate);
692                let optimized_input = self.reorder_predicates_by_cost(*input);
693                LogicalPlan::Filter {
694                    input: Box::new(optimized_input),
695                    predicate: reordered_pred,
696                }
697            }
698            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
699                input: Box::new(self.reorder_predicates_by_cost(*input)),
700                columns,
701            },
702            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
703                input: Box::new(self.reorder_predicates_by_cost(*input)),
704                count,
705            },
706            LogicalPlan::Join {
707                left,
708                right,
709                on,
710                join_type,
711            } => {
712                let reordered_on = self.reorder_pred(&on);
713                LogicalPlan::Join {
714                    left: Box::new(self.reorder_predicates_by_cost(*left)),
715                    right: Box::new(self.reorder_predicates_by_cost(*right)),
716                    on: reordered_on,
717                    join_type,
718                }
719            }
720            other => other,
721        }
722    }
723    /// Recursively reorder `And` sub-predicates cheapest-first.
724    fn reorder_pred(&self, pred: &Predicate) -> Predicate {
725        match pred {
726            Predicate::And(p1, p2) => {
727                let r1 = self.reorder_pred(p1);
728                let r2 = self.reorder_pred(p2);
729                let cost1 =
730                    self.stats.predicate_selectivity(&r1) * self.stats.predicate_fhe_cost(&r1);
731                let cost2 =
732                    self.stats.predicate_selectivity(&r2) * self.stats.predicate_fhe_cost(&r2);
733                if cost1 <= cost2 {
734                    Predicate::And(Box::new(r1), Box::new(r2))
735                } else {
736                    Predicate::And(Box::new(r2), Box::new(r1))
737                }
738            }
739            Predicate::Or(p1, p2) => Predicate::Or(
740                Box::new(self.reorder_pred(p1)),
741                Box::new(self.reorder_pred(p2)),
742            ),
743            Predicate::Not(inner) => Predicate::Not(Box::new(self.reorder_pred(inner))),
744            other => other.clone(),
745        }
746    }
747    /// Convert an optimized logical plan into a physical plan
748    fn to_physical(&self, plan: &LogicalPlan) -> Result<PhysicalPlan> {
749        match plan {
750            LogicalPlan::Scan { collection } => Ok(PhysicalPlan::SeqScan {
751                collection: collection.clone(),
752            }),
753            LogicalPlan::RangeScan {
754                collection,
755                start_key,
756                end_key,
757            } => Ok(PhysicalPlan::IndexScan {
758                collection: collection.clone(),
759                start: start_key.clone(),
760                end: end_key.clone(),
761            }),
762            LogicalPlan::Filter { input, predicate } => {
763                let physical_input = self.to_physical(input)?;
764                let circuit = self.compile_predicate_circuit(predicate)?;
765                Ok(PhysicalPlan::FheFilter {
766                    input: Box::new(physical_input),
767                    circuit,
768                    predicate: predicate.clone(),
769                })
770            }
771            LogicalPlan::Project { input, columns } => {
772                let physical_input = self.to_physical(input)?;
773                Ok(PhysicalPlan::Projection {
774                    input: Box::new(physical_input),
775                    columns: columns.clone(),
776                })
777            }
778            LogicalPlan::Limit { input, count } => {
779                let physical_input = self.to_physical(input)?;
780                Ok(PhysicalPlan::Limit {
781                    input: Box::new(physical_input),
782                    count: *count,
783                })
784            }
785            LogicalPlan::PointLookup { collection, key } => Ok(PhysicalPlan::PointGet {
786                collection: collection.clone(),
787                key: key.clone(),
788            }),
789            LogicalPlan::Join {
790                left,
791                right,
792                on,
793                join_type,
794            } => {
795                let left_phys = self.to_physical(left)?;
796                let right_phys = self.to_physical(right)?;
797                let left_rows = self.estimate_cost(&left_phys).estimated_rows;
798                let right_rows = self.estimate_cost(&right_phys).estimated_rows;
799                let use_hash = matches!(on, Predicate::Eq(_, _));
800                if use_hash {
801                    let (probe, build) = if left_rows <= right_rows {
802                        (right_phys, left_phys)
803                    } else {
804                        (left_phys, right_phys)
805                    };
806                    Ok(PhysicalPlan::HashJoin {
807                        probe: Box::new(probe),
808                        build: Box::new(build),
809                        on: on.clone(),
810                        join_type: join_type.clone(),
811                    })
812                } else {
813                    let (outer, build) = if left_rows <= right_rows {
814                        (left_phys, right_phys)
815                    } else {
816                        (right_phys, left_phys)
817                    };
818                    Ok(PhysicalPlan::NestedLoopJoin {
819                        outer: Box::new(outer),
820                        build: Box::new(build),
821                        on: on.clone(),
822                        join_type: join_type.clone(),
823                    })
824                }
825            }
826        }
827    }
828    /// Estimate the cost of a physical plan
829    pub fn estimate_cost(&self, plan: &PhysicalPlan) -> PlanCost {
830        match plan {
831            PhysicalPlan::SeqScan { collection } => {
832                let rows = self.stats.collection_size(collection);
833                let io_bytes = rows * self.stats.average_value_size;
834                PlanCost::compute(rows, 0, io_bytes)
835            }
836            PhysicalPlan::IndexScan {
837                collection,
838                start,
839                end,
840            } => {
841                let total = self.stats.collection_size(collection);
842                let selectivity = match (start, end) {
843                    (Some(_), Some(_)) => 0.10,
844                    (Some(_), None) | (None, Some(_)) => 0.30,
845                    (None, None) => 1.0,
846                };
847                let rows = ((total as f64) * selectivity).max(1.0) as u64;
848                let io_bytes = rows * self.stats.average_value_size;
849                PlanCost::compute(rows, 0, io_bytes)
850            }
851            PhysicalPlan::FheFilter { input, circuit, .. } => {
852                let input_cost = self.estimate_cost(input);
853                let fhe_ops = input_cost.estimated_rows * (circuit.gate_count as u64);
854                let output_rows = (input_cost.estimated_rows / 2).max(1);
855                let io_bytes = output_rows * self.stats.average_value_size;
856                PlanCost::compute(
857                    input_cost.estimated_rows,
858                    input_cost.estimated_fhe_ops + fhe_ops,
859                    input_cost.estimated_io_bytes + io_bytes,
860                )
861            }
862            PhysicalPlan::Projection { input, .. } => {
863                let mut cost = self.estimate_cost(input);
864                cost.estimated_io_bytes = (cost.estimated_io_bytes as f64 * 0.8) as u64;
865                cost.total_cost = (cost.estimated_rows as f64 * PlanCost::SCAN_COST_PER_ROW)
866                    + (cost.estimated_fhe_ops as f64 * PlanCost::FHE_COST_PER_OP)
867                    + (cost.estimated_io_bytes as f64 * PlanCost::IO_COST_PER_BYTE);
868                cost
869            }
870            PhysicalPlan::Limit { input, count } => {
871                let input_cost = self.estimate_cost(input);
872                let rows = (*count as u64).min(input_cost.estimated_rows);
873                let io_bytes = rows * self.stats.average_value_size;
874                PlanCost::compute(rows, input_cost.estimated_fhe_ops, io_bytes)
875            }
876            PhysicalPlan::PointGet { .. } => PlanCost::compute(1, 0, self.stats.average_value_size),
877            PhysicalPlan::NestedLoopJoin { outer, build, .. } => {
878                let outer_cost = self.estimate_cost(outer);
879                let build_cost = self.estimate_cost(build);
880                let outer_rows = outer_cost.estimated_rows;
881                let build_rows = build_cost.estimated_rows;
882                let fhe_ops = outer_rows.saturating_mul(build_rows);
883                let estimated_rows = outer_rows.saturating_mul(build_rows) / 2;
884                let io_bytes = outer_cost.estimated_io_bytes + build_cost.estimated_io_bytes;
885                PlanCost::compute(estimated_rows, fhe_ops, io_bytes)
886            }
887            PhysicalPlan::HashJoin { probe, build, .. } => {
888                let probe_cost = self.estimate_cost(probe);
889                let build_cost = self.estimate_cost(build);
890                let probe_rows = probe_cost.estimated_rows;
891                let build_rows = build_cost.estimated_rows;
892                let fhe_ops = probe_cost.estimated_fhe_ops + build_cost.estimated_fhe_ops;
893                let estimated_rows = probe_rows.saturating_mul(build_rows) / 2;
894                let io_bytes = probe_cost.estimated_io_bytes + build_cost.estimated_io_bytes;
895                PlanCost::compute(estimated_rows, fhe_ops, io_bytes)
896            }
897        }
898    }
899    /// Compare two physical plans by cost and return the cheaper one
900    pub fn choose_cheaper<'a>(&self, a: &'a PhysicalPlan, b: &'a PhysicalPlan) -> &'a PhysicalPlan {
901        let cost_a = self.estimate_cost(a);
902        let cost_b = self.estimate_cost(b);
903        if cost_a.total_cost <= cost_b.total_cost {
904            a
905        } else {
906            b
907        }
908    }
909    /// Extract all column names referenced in a predicate
910    fn referenced_columns(predicate: &Predicate) -> Vec<String> {
911        let mut cols = Vec::new();
912        Self::collect_columns(predicate, &mut cols);
913        cols.sort();
914        cols.dedup();
915        cols
916    }
917    fn collect_columns(predicate: &Predicate, out: &mut Vec<String>) {
918        match predicate {
919            Predicate::Eq(col, _)
920            | Predicate::Gt(col, _)
921            | Predicate::Lt(col, _)
922            | Predicate::Gte(col, _)
923            | Predicate::Lte(col, _) => {
924                out.push(col.name.clone());
925            }
926            Predicate::And(l, r) | Predicate::Or(l, r) => {
927                Self::collect_columns(l, out);
928                Self::collect_columns(r, out);
929            }
930            Predicate::Not(inner) => {
931                Self::collect_columns(inner, out);
932            }
933        }
934    }
935    /// Try to extract a key range from a predicate on the `_key` column
936    ///
937    /// Returns `Some((start, end))` where either bound may be `None`.
938    /// Returns `None` if the predicate is not a simple key-range filter.
939    fn extract_key_range(predicate: &Predicate) -> Option<(Option<Vec<u8>>, Option<Vec<u8>>)> {
940        match predicate {
941            Predicate::Gt(col, blob) if col.name == "_key" => {
942                Some((Some(blob.as_bytes().to_vec()), None))
943            }
944            Predicate::Gte(col, blob) if col.name == "_key" => {
945                Some((Some(blob.as_bytes().to_vec()), None))
946            }
947            Predicate::Lt(col, blob) if col.name == "_key" => {
948                Some((None, Some(blob.as_bytes().to_vec())))
949            }
950            Predicate::Lte(col, blob) if col.name == "_key" => {
951                Some((None, Some(blob.as_bytes().to_vec())))
952            }
953            Predicate::And(left, right) => {
954                let lr = Self::extract_key_range(left);
955                let rr = Self::extract_key_range(right);
956                match (lr, rr) {
957                    (Some((s1, e1)), Some((s2, e2))) => {
958                        let start = s1.or(s2);
959                        let end = e1.or(e2);
960                        Some((start, end))
961                    }
962                    (Some(range), None) | (None, Some(range)) => Some(range),
963                    (None, None) => None,
964                }
965            }
966            _ => None,
967        }
968    }
969    /// Compile a predicate into an FHE circuit
970    fn compile_predicate_circuit(&self, predicate: &Predicate) -> Result<Circuit> {
971        let mut compiler = PredicateCompiler::new();
972        compiler.compile(predicate, EncryptedType::U8)
973    }
974}
975impl Default for QueryPlanner {
976    fn default() -> Self {
977        Self::new()
978    }
979}
980impl std::fmt::Display for LogicalPlan {
981    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
982        self.fmt_indented(f, 0)
983    }
984}
985impl LogicalPlan {
986    fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
987        let pad = "  ".repeat(indent);
988        match self {
989            LogicalPlan::Scan { collection } => {
990                writeln!(f, "{}Scan({})", pad, collection)
991            }
992            LogicalPlan::RangeScan {
993                collection,
994                start_key,
995                end_key,
996            } => {
997                writeln!(
998                    f,
999                    "{}RangeScan({}, start={}, end={})",
1000                    pad,
1001                    collection,
1002                    start_key.is_some(),
1003                    end_key.is_some()
1004                )
1005            }
1006            LogicalPlan::Filter { input, predicate } => {
1007                writeln!(f, "{}Filter(pred={:?})", pad, predicate)?;
1008                input.fmt_indented(f, indent + 1)
1009            }
1010            LogicalPlan::Project { input, columns } => {
1011                writeln!(f, "{}Project({:?})", pad, columns)?;
1012                input.fmt_indented(f, indent + 1)
1013            }
1014            LogicalPlan::Limit { input, count } => {
1015                writeln!(f, "{}Limit({})", pad, count)?;
1016                input.fmt_indented(f, indent + 1)
1017            }
1018            LogicalPlan::PointLookup { collection, key } => {
1019                writeln!(f, "{}PointLookup({}, key={})", pad, collection, key)
1020            }
1021            LogicalPlan::Join {
1022                left,
1023                right,
1024                on,
1025                join_type,
1026            } => {
1027                let jt = match join_type {
1028                    JoinType::Inner => "Inner",
1029                    JoinType::Left => "Left",
1030                    JoinType::Right => "Right",
1031                };
1032                writeln!(f, "{}{}Join(on={:?})", pad, jt, on)?;
1033                left.fmt_indented(f, indent + 1)?;
1034                right.fmt_indented(f, indent + 1)
1035            }
1036        }
1037    }
1038}
1039impl std::fmt::Display for PhysicalPlan {
1040    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041        self.fmt_indented(f, 0)
1042    }
1043}
1044impl PhysicalPlan {
1045    fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
1046        let pad = "  ".repeat(indent);
1047        match self {
1048            PhysicalPlan::SeqScan { collection } => {
1049                writeln!(f, "{}SeqScan({})", pad, collection)
1050            }
1051            PhysicalPlan::IndexScan {
1052                collection,
1053                start,
1054                end,
1055            } => {
1056                writeln!(
1057                    f,
1058                    "{}IndexScan({}, start={}, end={})",
1059                    pad,
1060                    collection,
1061                    start.is_some(),
1062                    end.is_some()
1063                )
1064            }
1065            PhysicalPlan::FheFilter {
1066                input, predicate, ..
1067            } => {
1068                writeln!(f, "{}FheFilter(pred={:?})", pad, predicate)?;
1069                input.fmt_indented(f, indent + 1)
1070            }
1071            PhysicalPlan::Projection { input, columns } => {
1072                writeln!(f, "{}Projection({:?})", pad, columns)?;
1073                input.fmt_indented(f, indent + 1)
1074            }
1075            PhysicalPlan::Limit { input, count } => {
1076                writeln!(f, "{}Limit({})", pad, count)?;
1077                input.fmt_indented(f, indent + 1)
1078            }
1079            PhysicalPlan::PointGet { collection, key } => {
1080                writeln!(f, "{}PointGet({}, key={})", pad, collection, key)
1081            }
1082            PhysicalPlan::NestedLoopJoin {
1083                outer,
1084                build,
1085                on,
1086                join_type,
1087            } => {
1088                let jt = match join_type {
1089                    JoinType::Inner => "Inner",
1090                    JoinType::Left => "Left",
1091                    JoinType::Right => "Right",
1092                };
1093                writeln!(f, "{}NestedLoopJoin[{}](on={:?})", pad, jt, on)?;
1094                outer.fmt_indented(f, indent + 1)?;
1095                build.fmt_indented(f, indent + 1)
1096            }
1097            PhysicalPlan::HashJoin {
1098                probe,
1099                build,
1100                on,
1101                join_type,
1102            } => {
1103                let jt = match join_type {
1104                    JoinType::Inner => "Inner",
1105                    JoinType::Left => "Left",
1106                    JoinType::Right => "Right",
1107                };
1108                writeln!(f, "{}HashJoin[{}](on={:?})", pad, jt, on)?;
1109                probe.fmt_indented(f, indent + 1)?;
1110                build.fmt_indented(f, indent + 1)
1111            }
1112        }
1113    }
1114}
1115
1116#[cfg(test)]
1117mod tests;