Skip to main content

amaters_core/compute/
planner.rs

1//! Query planner with predicate pushdown and cost-based optimization
2//!
3//! This module provides a query planner that transforms high-level `Query` objects
4//! into optimized physical execution plans. It applies several optimization strategies:
5//!
6//! 1. **Predicate Pushdown** - Push filter predicates as close to the data source as possible
7//! 2. **Filter Merging** - Combine adjacent filter operations into compound predicates
8//! 3. **Cost-Based Optimization** - Estimate and compare plan costs to choose the cheapest one
9//! 4. **Range Scan Conversion** - Convert key-range filters into efficient range scans
10//!
11//! # Architecture
12//!
13//! The planner works in three phases:
14//!
15//! 1. **Logical Planning** - Convert a `Query` into a `LogicalPlan` tree
16//! 2. **Logical Optimization** - Apply rewrite rules (predicate pushdown, filter merge, etc.)
17//! 3. **Physical Planning** - Convert the optimized logical plan into a `PhysicalPlan`
18//!
19//! # Example
20//!
21//! ```rust,ignore
22//! use amaters_core::compute::planner::QueryPlanner;
23//! use amaters_core::types::{Query, QueryBuilder, Predicate, col, CipherBlob};
24//!
25//! let planner = QueryPlanner::new();
26//! let query = QueryBuilder::new("users").filter(
27//!     Predicate::Gt(col("age"), CipherBlob::new(vec![18]))
28//! );
29//!
30//! let plan = planner.plan(&query)?;
31//! let cost = planner.estimate_cost(&plan);
32//! println!("Estimated cost: {}", cost.total_cost);
33//! ```
34
35use crate::compute::EncryptedType;
36use crate::compute::circuit::Circuit;
37use crate::compute::predicate::PredicateCompiler;
38use crate::error::{AmateRSError, ErrorContext, Result};
39use crate::types::{CipherBlob, ColumnRef, Key, Predicate, Query};
40use dashmap::DashMap;
41use std::collections::HashSet;
42use std::sync::Arc;
43
44pub use super::plan_cache::{CacheKey, CacheStats, CachedPlan, PlanCache, PlanCacheConfig};
45
46// ---------------------------------------------------------------------------
47// Logical plan
48// ---------------------------------------------------------------------------
49
50/// Logical query plan node
51///
52/// Represents the *intent* of a query before physical execution details
53/// are decided. The logical plan is the subject of optimization rewrites.
54#[derive(Debug, Clone)]
55pub enum LogicalPlan {
56    /// Full table/collection scan
57    Scan {
58        /// Name of the collection to scan
59        collection: String,
60    },
61
62    /// Range scan with start/end keys
63    RangeScan {
64        /// Name of the collection
65        collection: String,
66        /// Inclusive start key (None = beginning)
67        start_key: Option<Vec<u8>>,
68        /// Exclusive end key (None = end)
69        end_key: Option<Vec<u8>>,
70    },
71
72    /// Filter with predicate (operates on encrypted data via FHE)
73    Filter {
74        /// Input plan to filter
75        input: Box<LogicalPlan>,
76        /// Predicate to evaluate
77        predicate: Predicate,
78    },
79
80    /// Projection (select specific columns)
81    Project {
82        /// Input plan to project
83        input: Box<LogicalPlan>,
84        /// Column names to retain
85        columns: Vec<String>,
86    },
87
88    /// Limit number of results
89    Limit {
90        /// Input plan to limit
91        input: Box<LogicalPlan>,
92        /// Maximum number of results
93        count: usize,
94    },
95
96    /// Point lookup by key
97    PointLookup {
98        /// Collection name
99        collection: String,
100        /// Key to look up
101        key: Key,
102    },
103}
104
105// ---------------------------------------------------------------------------
106// Physical plan
107// ---------------------------------------------------------------------------
108
109/// Physical query plan (executable)
110///
111/// Each variant maps directly to a concrete execution strategy.
112#[derive(Debug, Clone)]
113pub enum PhysicalPlan {
114    /// Sequential full scan
115    SeqScan {
116        /// Collection to scan
117        collection: String,
118    },
119
120    /// Index/range scan (pushdown to storage layer)
121    IndexScan {
122        /// Collection to scan
123        collection: String,
124        /// Inclusive start key
125        start: Option<Vec<u8>>,
126        /// Exclusive end key
127        end: Option<Vec<u8>>,
128    },
129
130    /// FHE filter evaluation (evaluated on encrypted data)
131    FheFilter {
132        /// Input physical plan
133        input: Box<PhysicalPlan>,
134        /// Compiled FHE circuit for the filter
135        circuit: Circuit,
136        /// Original predicate (kept for introspection / explain)
137        predicate: Predicate,
138    },
139
140    /// Client-side projection
141    Projection {
142        /// Input physical plan
143        input: Box<PhysicalPlan>,
144        /// Columns to retain
145        columns: Vec<String>,
146    },
147
148    /// Limit result count
149    Limit {
150        /// Input physical plan
151        input: Box<PhysicalPlan>,
152        /// Maximum results
153        count: usize,
154    },
155
156    /// Point lookup by key
157    PointGet {
158        /// Collection name
159        collection: String,
160        /// Key to look up
161        key: Key,
162    },
163}
164
165// ---------------------------------------------------------------------------
166// Cost model
167// ---------------------------------------------------------------------------
168
/// Estimated resource usage of a physical plan.
///
/// The three raw estimates are folded into a single weighted scalar,
/// `total_cost`, which is what plan comparison uses (lower is cheaper).
#[derive(Debug, Clone)]
pub struct PlanCost {
    /// Estimated number of rows touched
    pub estimated_rows: u64,
    /// Estimated number of FHE gate operations
    pub estimated_fhe_ops: u64,
    /// Estimated number of I/O bytes transferred
    pub estimated_io_bytes: u64,
    /// Weighted sum of the estimates above (lower is better)
    pub total_cost: f64,
}

impl PlanCost {
    /// Weight applied to each byte of I/O
    const IO_COST_PER_BYTE: f64 = 0.001;
    /// Weight applied to each FHE gate operation; FHE dominates everything else
    const FHE_COST_PER_OP: f64 = 100.0;
    /// Weight applied to each scanned row
    const SCAN_COST_PER_ROW: f64 = 0.01;
    /// Fixed cost of a point lookup
    const POINT_LOOKUP_COST: f64 = 1.0;

    /// Build a cost record, deriving `total_cost` from the raw estimates.
    ///
    /// `total_cost = rows * SCAN_COST_PER_ROW + fhe_ops * FHE_COST_PER_OP
    ///             + io_bytes * IO_COST_PER_BYTE`
    fn compute(estimated_rows: u64, estimated_fhe_ops: u64, estimated_io_bytes: u64) -> Self {
        let scan_component = estimated_rows as f64 * Self::SCAN_COST_PER_ROW;
        let fhe_component = estimated_fhe_ops as f64 * Self::FHE_COST_PER_OP;
        let io_component = estimated_io_bytes as f64 * Self::IO_COST_PER_BYTE;
        Self {
            total_cost: scan_component + fhe_component + io_component,
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
        }
    }
}

impl std::fmt::Display for PlanCost {
    /// Render as `PlanCost(rows=…, fhe_ops=…, io_bytes=…, total=…)`, with
    /// the total rounded to two decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PlanCost {
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
            total_cost,
        } = self;
        write!(
            f,
            "PlanCost(rows={}, fhe_ops={}, io_bytes={}, total={:.2})",
            estimated_rows, estimated_fhe_ops, estimated_io_bytes, total_cost
        )
    }
}
215
216// ---------------------------------------------------------------------------
217// Planner statistics
218// ---------------------------------------------------------------------------
219
220/// Statistics used for cost estimation
221///
222/// Maintains per-collection cardinality estimates and global latency hints
223/// so that the planner can make informed decisions.
224pub struct PlannerStats {
225    /// Estimated row count per collection
226    pub estimated_collection_sizes: DashMap<String, u64>,
227    /// Average value size in bytes across all collections
228    pub average_value_size: u64,
229    /// Estimated microsecond latency of a single FHE gate operation
230    pub fhe_op_latency_us: u64,
231}
232
233impl PlannerStats {
234    /// Create default statistics with reasonable starting values
235    fn new() -> Self {
236        Self {
237            estimated_collection_sizes: DashMap::new(),
238            average_value_size: 256,
239            fhe_op_latency_us: 1000,
240        }
241    }
242
243    /// Return the estimated size of a collection, defaulting to 1000
244    fn collection_size(&self, collection: &str) -> u64 {
245        self.estimated_collection_sizes
246            .get(collection)
247            .map(|v| *v)
248            .unwrap_or(1000)
249    }
250
251    /// Update the estimated size for a collection
252    pub fn set_collection_size(&self, collection: impl Into<String>, size: u64) {
253        self.estimated_collection_sizes
254            .insert(collection.into(), size);
255    }
256}
257
258impl Default for PlannerStats {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264// ---------------------------------------------------------------------------
265// Query Planner
266// ---------------------------------------------------------------------------
267
268/// Query planner that converts `Query` into optimized `PhysicalPlan`
269///
270/// The planner applies predicate pushdown, filter merging, and cost-based
271/// selection to produce an execution plan that minimises expensive FHE
272/// operations and I/O.
273///
274/// Optionally maintains a plan cache to avoid re-planning identical queries.
275pub struct QueryPlanner {
276    /// Statistics for cost estimation
277    stats: Arc<PlannerStats>,
278    /// Optional plan cache
279    cache: Option<Arc<PlanCache>>,
280}
281
282impl QueryPlanner {
283    /// Create a new query planner with default statistics
284    pub fn new() -> Self {
285        Self {
286            stats: Arc::new(PlannerStats::new()),
287            cache: None,
288        }
289    }
290
291    /// Create a planner with custom statistics
292    pub fn with_stats(stats: Arc<PlannerStats>) -> Self {
293        Self { stats, cache: None }
294    }
295
296    /// Enable plan caching with the given configuration
297    pub fn with_cache(mut self, config: PlanCacheConfig) -> Self {
298        self.cache = Some(Arc::new(PlanCache::new(config)));
299        self
300    }
301
302    /// Get a reference to the planner statistics
303    pub fn stats(&self) -> &PlannerStats {
304        &self.stats
305    }
306
307    /// Get a reference to the plan cache, if enabled
308    pub fn plan_cache(&self) -> Option<&PlanCache> {
309        self.cache.as_deref()
310    }
311
312    /// Return cache statistics, or default stats if caching is not enabled
313    pub fn cache_stats(&self) -> CacheStats {
314        self.cache
315            .as_ref()
316            .map(|c| c.cache_stats())
317            .unwrap_or_default()
318    }
319
320    /// Invalidate all cached plans (e.g., after a schema change)
321    pub fn invalidate_all(&self) {
322        if let Some(cache) = &self.cache {
323            cache.invalidate_all();
324        }
325    }
326
327    /// Invalidate cached plans matching a prefix (e.g., a collection name)
328    pub fn invalidate_prefix(&self, prefix: &str) {
329        if let Some(cache) = &self.cache {
330            cache.invalidate_prefix(prefix);
331        }
332    }
333
334    // -----------------------------------------------------------------------
335    // Public entry point
336    // -----------------------------------------------------------------------
337
338    /// Plan a query
339    ///
340    /// If caching is enabled, checks the cache first and returns a cached
341    /// plan if one exists and has not expired. Otherwise, plans the query
342    /// from scratch and inserts the result into the cache.
343    pub fn plan(&self, query: &Query) -> Result<PhysicalPlan> {
344        let cache_key = CacheKey::from_query(query);
345
346        // Check cache first
347        if let Some(cache) = &self.cache {
348            if let Some(cached_plan) = cache.get(&cache_key) {
349                return Ok(cached_plan);
350            }
351        }
352
353        // Plan from scratch
354        let logical = self.to_logical(query)?;
355        let optimized = self.optimize_logical(logical);
356        let physical = self.to_physical(&optimized)?;
357
358        // Insert into cache
359        if let Some(cache) = &self.cache {
360            let normalized = CacheKey::normalize(&format!("{:?}", query));
361            cache.insert(cache_key, physical.clone(), normalized);
362        }
363
364        Ok(physical)
365    }
366
367    // -----------------------------------------------------------------------
368    // Logical plan construction
369    // -----------------------------------------------------------------------
370
371    /// Convert a high-level `Query` into a `LogicalPlan`
372    fn to_logical(&self, query: &Query) -> Result<LogicalPlan> {
373        match query {
374            Query::Get { collection, key } => Ok(LogicalPlan::PointLookup {
375                collection: collection.clone(),
376                key: key.clone(),
377            }),
378
379            Query::Filter {
380                collection,
381                predicate,
382            } => Ok(LogicalPlan::Filter {
383                input: Box::new(LogicalPlan::Scan {
384                    collection: collection.clone(),
385                }),
386                predicate: predicate.clone(),
387            }),
388
389            Query::Range {
390                collection,
391                start,
392                end,
393            } => Ok(LogicalPlan::RangeScan {
394                collection: collection.clone(),
395                start_key: Some(start.to_vec()),
396                end_key: Some(end.to_vec()),
397            }),
398
399            Query::Set { collection, .. } => {
400                // Write operations do not really need a read plan, but we model
401                // them as a point lookup for the target key so that upstream can
402                // check for existence first.
403                Ok(LogicalPlan::Scan {
404                    collection: collection.clone(),
405                })
406            }
407
408            Query::Delete { collection, key } => Ok(LogicalPlan::PointLookup {
409                collection: collection.clone(),
410                key: key.clone(),
411            }),
412
413            Query::Update {
414                collection,
415                predicate,
416                ..
417            } => Ok(LogicalPlan::Filter {
418                input: Box::new(LogicalPlan::Scan {
419                    collection: collection.clone(),
420                }),
421                predicate: predicate.clone(),
422            }),
423        }
424    }
425
426    // -----------------------------------------------------------------------
427    // Logical optimizations
428    // -----------------------------------------------------------------------
429
430    /// Apply all logical optimization passes
431    fn optimize_logical(&self, plan: LogicalPlan) -> LogicalPlan {
432        let plan = self.push_predicates_down(plan);
433        let plan = self.merge_filters(plan);
434        self.convert_filter_to_range_scan(plan)
435    }
436
437    /// Predicate pushdown: move filters closer to the data source
438    ///
439    /// Rules applied:
440    /// - `Filter(Project(input, cols), pred)` -> if pred only references
441    ///   columns in `cols`, push the filter below the projection.
442    /// - `Filter(Filter(input, p1), p2)` is handled by `merge_filters`.
443    fn push_predicates_down(&self, plan: LogicalPlan) -> LogicalPlan {
444        match plan {
445            // Rule: push filter below projection when possible
446            LogicalPlan::Filter { input, predicate } => {
447                let optimized_input = self.push_predicates_down(*input);
448
449                match optimized_input {
450                    // Filter over Project -> check if we can push through
451                    LogicalPlan::Project {
452                        input: proj_input,
453                        columns,
454                    } => {
455                        let pred_cols = Self::referenced_columns(&predicate);
456                        let proj_set: HashSet<&str> = columns.iter().map(|c| c.as_str()).collect();
457
458                        if pred_cols.iter().all(|c| proj_set.contains(c.as_str())) {
459                            // All predicate columns exist in the projection,
460                            // so we can push the filter below.
461                            LogicalPlan::Project {
462                                input: Box::new(LogicalPlan::Filter {
463                                    input: proj_input,
464                                    predicate,
465                                }),
466                                columns,
467                            }
468                        } else {
469                            // Some columns not in projection; need to widen
470                            // the projection to include predicate columns,
471                            // then re-project afterwards.
472                            let mut extended_cols = columns.clone();
473                            for col in &pred_cols {
474                                if !proj_set.contains(col.as_str()) {
475                                    extended_cols.push(col.clone());
476                                }
477                            }
478
479                            LogicalPlan::Project {
480                                input: Box::new(LogicalPlan::Filter {
481                                    input: Box::new(LogicalPlan::Project {
482                                        input: proj_input,
483                                        columns: extended_cols,
484                                    }),
485                                    predicate,
486                                }),
487                                columns,
488                            }
489                        }
490                    }
491
492                    // Filter over Limit: cannot push filter below Limit
493                    // because Limit is a cardinality-changing operation on
494                    // encrypted data where we cannot peek.
495                    other => LogicalPlan::Filter {
496                        input: Box::new(other),
497                        predicate,
498                    },
499                }
500            }
501
502            // Recurse into other plan nodes
503            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
504                input: Box::new(self.push_predicates_down(*input)),
505                columns,
506            },
507
508            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
509                input: Box::new(self.push_predicates_down(*input)),
510                count,
511            },
512
513            // Leaf nodes are returned unchanged
514            other => other,
515        }
516    }
517
518    /// Merge adjacent filters into a single AND predicate
519    ///
520    /// `Filter(Filter(input, p1), p2)` => `Filter(input, And(p1, p2))`
521    fn merge_filters(&self, plan: LogicalPlan) -> LogicalPlan {
522        match plan {
523            LogicalPlan::Filter { input, predicate } => {
524                let optimized_input = self.merge_filters(*input);
525
526                match optimized_input {
527                    LogicalPlan::Filter {
528                        input: inner_input,
529                        predicate: inner_pred,
530                    } => {
531                        // Merge the two predicates with AND
532                        LogicalPlan::Filter {
533                            input: inner_input,
534                            predicate: Predicate::And(Box::new(inner_pred), Box::new(predicate)),
535                        }
536                    }
537                    other => LogicalPlan::Filter {
538                        input: Box::new(other),
539                        predicate,
540                    },
541                }
542            }
543
544            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
545                input: Box::new(self.merge_filters(*input)),
546                columns,
547            },
548
549            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
550                input: Box::new(self.merge_filters(*input)),
551                count,
552            },
553
554            other => other,
555        }
556    }
557
558    /// Convert a filter on key range into a `RangeScan` when possible
559    ///
560    /// If a `Filter(Scan(collection), pred)` has a predicate that is purely
561    /// a key-range comparison (Gt/Lt/Gte/Lte on the `_key` column), we can
562    /// replace the scan+filter with a more efficient `RangeScan`.
563    fn convert_filter_to_range_scan(&self, plan: LogicalPlan) -> LogicalPlan {
564        match plan {
565            LogicalPlan::Filter { input, predicate } => {
566                let optimized_input = self.convert_filter_to_range_scan(*input);
567
568                if let LogicalPlan::Scan { ref collection } = optimized_input {
569                    if let Some((start, end)) = Self::extract_key_range(&predicate) {
570                        return LogicalPlan::RangeScan {
571                            collection: collection.clone(),
572                            start_key: start,
573                            end_key: end,
574                        };
575                    }
576                }
577
578                LogicalPlan::Filter {
579                    input: Box::new(optimized_input),
580                    predicate,
581                }
582            }
583
584            LogicalPlan::Project { input, columns } => LogicalPlan::Project {
585                input: Box::new(self.convert_filter_to_range_scan(*input)),
586                columns,
587            },
588
589            LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
590                input: Box::new(self.convert_filter_to_range_scan(*input)),
591                count,
592            },
593
594            other => other,
595        }
596    }
597
598    // -----------------------------------------------------------------------
599    // Physical plan construction
600    // -----------------------------------------------------------------------
601
602    /// Convert an optimized logical plan into a physical plan
603    fn to_physical(&self, plan: &LogicalPlan) -> Result<PhysicalPlan> {
604        match plan {
605            LogicalPlan::Scan { collection } => Ok(PhysicalPlan::SeqScan {
606                collection: collection.clone(),
607            }),
608
609            LogicalPlan::RangeScan {
610                collection,
611                start_key,
612                end_key,
613            } => Ok(PhysicalPlan::IndexScan {
614                collection: collection.clone(),
615                start: start_key.clone(),
616                end: end_key.clone(),
617            }),
618
619            LogicalPlan::Filter { input, predicate } => {
620                let physical_input = self.to_physical(input)?;
621                let circuit = self.compile_predicate_circuit(predicate)?;
622
623                Ok(PhysicalPlan::FheFilter {
624                    input: Box::new(physical_input),
625                    circuit,
626                    predicate: predicate.clone(),
627                })
628            }
629
630            LogicalPlan::Project { input, columns } => {
631                let physical_input = self.to_physical(input)?;
632                Ok(PhysicalPlan::Projection {
633                    input: Box::new(physical_input),
634                    columns: columns.clone(),
635                })
636            }
637
638            LogicalPlan::Limit { input, count } => {
639                let physical_input = self.to_physical(input)?;
640                Ok(PhysicalPlan::Limit {
641                    input: Box::new(physical_input),
642                    count: *count,
643                })
644            }
645
646            LogicalPlan::PointLookup { collection, key } => Ok(PhysicalPlan::PointGet {
647                collection: collection.clone(),
648                key: key.clone(),
649            }),
650        }
651    }
652
653    // -----------------------------------------------------------------------
654    // Cost estimation
655    // -----------------------------------------------------------------------
656
657    /// Estimate the cost of a physical plan
658    pub fn estimate_cost(&self, plan: &PhysicalPlan) -> PlanCost {
659        match plan {
660            PhysicalPlan::SeqScan { collection } => {
661                let rows = self.stats.collection_size(collection);
662                let io_bytes = rows * self.stats.average_value_size;
663                PlanCost::compute(rows, 0, io_bytes)
664            }
665
666            PhysicalPlan::IndexScan {
667                collection,
668                start,
669                end,
670            } => {
671                let total = self.stats.collection_size(collection);
672                // Estimate selectivity: a range scan typically touches a fraction.
673                // Without histograms we use a heuristic: if both bounds present
674                // assume 10%, one bound 30%, no bounds = full scan.
675                let selectivity = match (start, end) {
676                    (Some(_), Some(_)) => 0.10,
677                    (Some(_), None) | (None, Some(_)) => 0.30,
678                    (None, None) => 1.0,
679                };
680                let rows = ((total as f64) * selectivity).max(1.0) as u64;
681                let io_bytes = rows * self.stats.average_value_size;
682                PlanCost::compute(rows, 0, io_bytes)
683            }
684
685            PhysicalPlan::FheFilter { input, circuit, .. } => {
686                let input_cost = self.estimate_cost(input);
687                // FHE filter applies the circuit to every row from the input
688                let fhe_ops = input_cost.estimated_rows * (circuit.gate_count as u64);
689                // After filter, assume 50% selectivity without better stats
690                let output_rows = (input_cost.estimated_rows / 2).max(1);
691                let io_bytes = output_rows * self.stats.average_value_size;
692                PlanCost::compute(
693                    input_cost.estimated_rows,
694                    input_cost.estimated_fhe_ops + fhe_ops,
695                    input_cost.estimated_io_bytes + io_bytes,
696                )
697            }
698
699            PhysicalPlan::Projection { input, .. } => {
700                // Projection is cheap; just trim columns
701                let mut cost = self.estimate_cost(input);
702                // Slightly reduce IO since we return fewer bytes
703                cost.estimated_io_bytes = (cost.estimated_io_bytes as f64 * 0.8) as u64;
704                cost.total_cost = (cost.estimated_rows as f64 * PlanCost::SCAN_COST_PER_ROW)
705                    + (cost.estimated_fhe_ops as f64 * PlanCost::FHE_COST_PER_OP)
706                    + (cost.estimated_io_bytes as f64 * PlanCost::IO_COST_PER_BYTE);
707                cost
708            }
709
710            PhysicalPlan::Limit { input, count } => {
711                let input_cost = self.estimate_cost(input);
712                let rows = (*count as u64).min(input_cost.estimated_rows);
713                let io_bytes = rows * self.stats.average_value_size;
714                // Note: FHE ops from input still happen because we do not know
715                // which rows will survive until after FHE evaluation.
716                PlanCost::compute(rows, input_cost.estimated_fhe_ops, io_bytes)
717            }
718
719            PhysicalPlan::PointGet { .. } => PlanCost::compute(1, 0, self.stats.average_value_size),
720        }
721    }
722
723    /// Compare two physical plans by cost and return the cheaper one
724    pub fn choose_cheaper<'a>(&self, a: &'a PhysicalPlan, b: &'a PhysicalPlan) -> &'a PhysicalPlan {
725        let cost_a = self.estimate_cost(a);
726        let cost_b = self.estimate_cost(b);
727        if cost_a.total_cost <= cost_b.total_cost {
728            a
729        } else {
730            b
731        }
732    }
733
734    // -----------------------------------------------------------------------
735    // Helpers
736    // -----------------------------------------------------------------------
737
738    /// Extract all column names referenced in a predicate
739    fn referenced_columns(predicate: &Predicate) -> Vec<String> {
740        let mut cols = Vec::new();
741        Self::collect_columns(predicate, &mut cols);
742        cols.sort();
743        cols.dedup();
744        cols
745    }
746
747    fn collect_columns(predicate: &Predicate, out: &mut Vec<String>) {
748        match predicate {
749            Predicate::Eq(col, _)
750            | Predicate::Gt(col, _)
751            | Predicate::Lt(col, _)
752            | Predicate::Gte(col, _)
753            | Predicate::Lte(col, _) => {
754                out.push(col.name.clone());
755            }
756            Predicate::And(l, r) | Predicate::Or(l, r) => {
757                Self::collect_columns(l, out);
758                Self::collect_columns(r, out);
759            }
760            Predicate::Not(inner) => {
761                Self::collect_columns(inner, out);
762            }
763        }
764    }
765
766    /// Try to extract a key range from a predicate on the `_key` column
767    ///
768    /// Returns `Some((start, end))` where either bound may be `None`.
769    /// Returns `None` if the predicate is not a simple key-range filter.
770    fn extract_key_range(predicate: &Predicate) -> Option<(Option<Vec<u8>>, Option<Vec<u8>>)> {
771        match predicate {
772            Predicate::Gt(col, blob) if col.name == "_key" => {
773                Some((Some(blob.as_bytes().to_vec()), None))
774            }
775            Predicate::Gte(col, blob) if col.name == "_key" => {
776                Some((Some(blob.as_bytes().to_vec()), None))
777            }
778            Predicate::Lt(col, blob) if col.name == "_key" => {
779                Some((None, Some(blob.as_bytes().to_vec())))
780            }
781            Predicate::Lte(col, blob) if col.name == "_key" => {
782                Some((None, Some(blob.as_bytes().to_vec())))
783            }
784            Predicate::And(left, right) => {
785                // Combine two half-ranges
786                let lr = Self::extract_key_range(left);
787                let rr = Self::extract_key_range(right);
788
789                match (lr, rr) {
790                    (Some((s1, e1)), Some((s2, e2))) => {
791                        let start = s1.or(s2);
792                        let end = e1.or(e2);
793                        Some((start, end))
794                    }
795                    (Some(range), None) | (None, Some(range)) => Some(range),
796                    (None, None) => None,
797                }
798            }
799            _ => None,
800        }
801    }
802
803    /// Compile a predicate into an FHE circuit
804    fn compile_predicate_circuit(&self, predicate: &Predicate) -> Result<Circuit> {
805        let mut compiler = PredicateCompiler::new();
806        // Default to U8 type for now; in a full implementation the type
807        // would be inferred from schema metadata.
808        compiler.compile(predicate, EncryptedType::U8)
809    }
810}
811
812impl Default for QueryPlanner {
813    fn default() -> Self {
814        Self::new()
815    }
816}
817
818// ---------------------------------------------------------------------------
819// Display implementations for explain/debugging
820// ---------------------------------------------------------------------------
821
822impl std::fmt::Display for LogicalPlan {
823    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
824        self.fmt_indented(f, 0)
825    }
826}
827
828impl LogicalPlan {
829    fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
830        let pad = "  ".repeat(indent);
831        match self {
832            LogicalPlan::Scan { collection } => {
833                writeln!(f, "{}Scan({})", pad, collection)
834            }
835            LogicalPlan::RangeScan {
836                collection,
837                start_key,
838                end_key,
839            } => {
840                writeln!(
841                    f,
842                    "{}RangeScan({}, start={}, end={})",
843                    pad,
844                    collection,
845                    start_key.is_some(),
846                    end_key.is_some()
847                )
848            }
849            LogicalPlan::Filter { input, predicate } => {
850                writeln!(f, "{}Filter(pred={:?})", pad, predicate)?;
851                input.fmt_indented(f, indent + 1)
852            }
853            LogicalPlan::Project { input, columns } => {
854                writeln!(f, "{}Project({:?})", pad, columns)?;
855                input.fmt_indented(f, indent + 1)
856            }
857            LogicalPlan::Limit { input, count } => {
858                writeln!(f, "{}Limit({})", pad, count)?;
859                input.fmt_indented(f, indent + 1)
860            }
861            LogicalPlan::PointLookup { collection, key } => {
862                writeln!(f, "{}PointLookup({}, key={})", pad, collection, key)
863            }
864        }
865    }
866}
867
868impl std::fmt::Display for PhysicalPlan {
869    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
870        self.fmt_indented(f, 0)
871    }
872}
873
874impl PhysicalPlan {
875    fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
876        let pad = "  ".repeat(indent);
877        match self {
878            PhysicalPlan::SeqScan { collection } => {
879                writeln!(f, "{}SeqScan({})", pad, collection)
880            }
881            PhysicalPlan::IndexScan {
882                collection,
883                start,
884                end,
885            } => {
886                writeln!(
887                    f,
888                    "{}IndexScan({}, start={}, end={})",
889                    pad,
890                    collection,
891                    start.is_some(),
892                    end.is_some()
893                )
894            }
895            PhysicalPlan::FheFilter {
896                input, predicate, ..
897            } => {
898                writeln!(f, "{}FheFilter(pred={:?})", pad, predicate)?;
899                input.fmt_indented(f, indent + 1)
900            }
901            PhysicalPlan::Projection { input, columns } => {
902                writeln!(f, "{}Projection({:?})", pad, columns)?;
903                input.fmt_indented(f, indent + 1)
904            }
905            PhysicalPlan::Limit { input, count } => {
906                writeln!(f, "{}Limit({})", pad, count)?;
907                input.fmt_indented(f, indent + 1)
908            }
909            PhysicalPlan::PointGet { collection, key } => {
910                writeln!(f, "{}PointGet({}, key={})", pad, collection, key)
911            }
912        }
913    }
914}
915
916// ---------------------------------------------------------------------------
917// Tests
918// ---------------------------------------------------------------------------
919
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::col;

    /// One-byte `CipherBlob` used as a stand-in comparison operand.
    fn make_blob(v: u8) -> CipherBlob {
        CipherBlob::new(vec![v])
    }

    /// Error returned when a plan node does not have the expected shape.
    ///
    /// Tests return this as `Err` instead of panicking, so a mismatch is
    /// reported together with the offending plan's `Debug` output.
    fn unexpected_plan(expected: &str, actual: &dyn std::fmt::Debug) -> AmateRSError {
        AmateRSError::FheComputation(ErrorContext::new(format!(
            "Expected {}, got: {:?}",
            expected, actual
        )))
    }

    // -- Basic planning tests -----------------------------------------------

    #[test]
    fn test_scan_plan() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Filter {
            collection: "users".to_string(),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };

        let plan = planner.plan(&query)?;

        // "age" is not the `_key` column, so no range scan is possible: expect an
        // FHE filter evaluated over a sequential scan.
        if let PhysicalPlan::FheFilter { input, .. } = &plan {
            assert!(matches!(**input, PhysicalPlan::SeqScan { .. }));
            Ok(())
        } else {
            Err(unexpected_plan("FheFilter", &plan))
        }
    }

    #[test]
    fn test_range_scan_pushdown() -> Result<()> {
        let planner = QueryPlanner::new();

        // Both conjuncts constrain `_key`, so the filter becomes an IndexScan.
        let lower = Predicate::Gte(col("_key"), make_blob(10));
        let upper = Predicate::Lt(col("_key"), make_blob(50));
        let query = Query::Filter {
            collection: "data".to_string(),
            predicate: Predicate::And(Box::new(lower), Box::new(upper)),
        };

        let plan = planner.plan(&query)?;

        if let PhysicalPlan::IndexScan {
            collection,
            start,
            end,
        } = &plan
        {
            assert_eq!(collection, "data");
            assert!(start.is_some());
            assert!(end.is_some());
            Ok(())
        } else {
            Err(unexpected_plan("IndexScan", &plan))
        }
    }

    #[test]
    fn test_predicate_pushdown() -> Result<()> {
        let planner = QueryPlanner::new();

        // Filter(Project(Scan, [age, name]), age > 18): the filter only reads a
        // projected column, so it should sink below the projection.
        let plan = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Project {
                input: Box::new(LogicalPlan::Scan {
                    collection: "users".to_string(),
                }),
                columns: vec!["age".to_string(), "name".to_string()],
            }),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };

        let optimized = planner.push_predicates_down(plan);

        // Expected shape: Project([age, name], Filter(Scan, pred)).
        if let LogicalPlan::Project { input, columns } = &optimized {
            assert!(columns.iter().any(|c| c == "age"));
            assert!(matches!(**input, LogicalPlan::Filter { .. }));
            Ok(())
        } else {
            Err(unexpected_plan("Project", &optimized))
        }
    }

    #[test]
    fn test_filter_merge() -> Result<()> {
        let planner = QueryPlanner::new();

        // Filter(Filter(Scan, age > 18), age < 65) -> Filter(Scan, And(..)).
        let inner = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Scan {
                collection: "users".to_string(),
            }),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };
        let outer = LogicalPlan::Filter {
            input: Box::new(inner),
            predicate: Predicate::Lt(col("age"), make_blob(65)),
        };

        let merged = planner.merge_filters(outer);

        if let LogicalPlan::Filter { input, predicate } = &merged {
            // A single filter carrying the conjunction, sitting directly on the scan.
            assert!(matches!(predicate, Predicate::And(_, _)));
            assert!(matches!(**input, LogicalPlan::Scan { .. }));
            Ok(())
        } else {
            Err(unexpected_plan("Filter", &merged))
        }
    }

    #[test]
    fn test_cost_estimation() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("data", 10_000);

        let seq_cost = planner.estimate_cost(&PhysicalPlan::SeqScan {
            collection: "data".to_string(),
        });
        let idx_cost = planner.estimate_cost(&PhysicalPlan::IndexScan {
            collection: "data".to_string(),
            start: Some(vec![10]),
            end: Some(vec![50]),
        });

        // A bounded range scan must be cheaper than reading the whole collection.
        assert!(
            idx_cost.total_cost < seq_cost.total_cost,
            "IndexScan cost ({}) should be less than SeqScan cost ({})",
            idx_cost.total_cost,
            seq_cost.total_cost,
        );

        // A single-key lookup must beat even the range scan.
        let point_cost = planner.estimate_cost(&PhysicalPlan::PointGet {
            collection: "data".to_string(),
            key: Key::from_str("k"),
        });
        assert!(
            point_cost.total_cost < idx_cost.total_cost,
            "PointGet cost ({}) should be less than IndexScan cost ({})",
            point_cost.total_cost,
            idx_cost.total_cost,
        );

        Ok(())
    }

    #[test]
    fn test_limit_planning() -> Result<()> {
        let planner = QueryPlanner::new();

        // Limit(Filter(Scan, level == 1), 10), built directly as a logical plan.
        let limited = LogicalPlan::Limit {
            input: Box::new(LogicalPlan::Filter {
                input: Box::new(LogicalPlan::Scan {
                    collection: "logs".to_string(),
                }),
                predicate: Predicate::Eq(col("level"), make_blob(1)),
            }),
            count: 10,
        };

        let physical = planner.to_physical(&limited)?;

        // The limit stays the root operator, directly above the FHE filter.
        if let PhysicalPlan::Limit { input, count } = &physical {
            assert_eq!(*count, 10);
            assert!(matches!(**input, PhysicalPlan::FheFilter { .. }));
            Ok(())
        } else {
            Err(unexpected_plan("Limit", &physical))
        }
    }

    #[test]
    fn test_plan_with_fhe_filter() -> Result<()> {
        let planner = QueryPlanner::new();
        let above = Predicate::Gt(col("balance"), make_blob(100));
        let below = Predicate::Lt(col("balance"), make_blob(200));
        let query = Query::Filter {
            collection: "accounts".to_string(),
            predicate: Predicate::And(Box::new(above), Box::new(below)),
        };

        let plan = planner.plan(&query)?;

        if let PhysicalPlan::FheFilter { circuit, .. } = &plan {
            // An AND of two comparisons needs real gates and yields a boolean.
            assert!(circuit.gate_count > 0);
            assert_eq!(circuit.result_type, EncryptedType::Bool);
            Ok(())
        } else {
            Err(unexpected_plan("FheFilter", &plan))
        }
    }

    #[test]
    fn test_complex_plan() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("orders", 50_000);

        // status == 1 OR (100 < amount < 255): no `_key` bound, so it stays an FHE filter.
        let amount_window = Predicate::And(
            Box::new(Predicate::Gt(col("amount"), make_blob(100))),
            Box::new(Predicate::Lt(col("amount"), make_blob(255))),
        );
        let query = Query::Filter {
            collection: "orders".to_string(),
            predicate: Predicate::Or(
                Box::new(Predicate::Eq(col("status"), make_blob(1))),
                Box::new(amount_window),
            ),
        };

        let plan = planner.plan(&query)?;
        let cost = planner.estimate_cost(&plan);

        // FHE evaluation must show up in the estimate.
        assert!(cost.estimated_fhe_ops > 0);
        assert!(cost.total_cost > 0.0);

        // The explain output must render something.
        let rendered = format!("{}", plan);
        assert!(!rendered.is_empty());

        Ok(())
    }

    #[test]
    fn test_get_query_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Get {
            collection: "users".to_string(),
            key: Key::from_str("user:42"),
        };

        let plan = planner.plan(&query)?;

        match &plan {
            PhysicalPlan::PointGet { collection, key } => {
                assert_eq!(collection, "users");
                assert_eq!(key.to_string_lossy(), "user:42");
            }
            _ => return Err(unexpected_plan("PointGet", &plan)),
        }

        // A point lookup touches one row and needs no homomorphic work.
        let cost = planner.estimate_cost(&plan);
        assert_eq!(cost.estimated_rows, 1);
        assert_eq!(cost.estimated_fhe_ops, 0);

        Ok(())
    }

    #[test]
    fn test_range_query_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Range {
            collection: "events".to_string(),
            start: Key::from_str("2024-01"),
            end: Key::from_str("2024-12"),
        };

        let plan = planner.plan(&query)?;

        if let PhysicalPlan::IndexScan {
            collection,
            start,
            end,
        } = &plan
        {
            assert_eq!(collection, "events");
            assert!(start.is_some());
            assert!(end.is_some());
            Ok(())
        } else {
            Err(unexpected_plan("IndexScan", &plan))
        }
    }

    #[test]
    fn test_cost_comparison() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("items", 100_000);

        let seq = PhysicalPlan::SeqScan {
            collection: "items".to_string(),
        };
        let index = PhysicalPlan::IndexScan {
            collection: "items".to_string(),
            start: Some(vec![1]),
            end: Some(vec![10]),
        };

        // The bounded index scan should be preferred over the full scan.
        assert!(matches!(
            planner.choose_cheaper(&seq, &index),
            PhysicalPlan::IndexScan { .. }
        ));

        Ok(())
    }

    #[test]
    fn test_filter_not_pushed_below_limit() -> Result<()> {
        let planner = QueryPlanner::new();

        // Filter(Limit(Scan, 10), x > 5): moving the filter under the limit would
        // change which rows survive, so it must stay on top.
        let plan = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Limit {
                input: Box::new(LogicalPlan::Scan {
                    collection: "data".to_string(),
                }),
                count: 10,
            }),
            predicate: Predicate::Gt(col("x"), make_blob(5)),
        };

        let optimized = planner.push_predicates_down(plan);

        if let LogicalPlan::Filter { input, .. } = &optimized {
            assert!(matches!(**input, LogicalPlan::Limit { .. }));
            Ok(())
        } else {
            Err(unexpected_plan("Filter on top", &optimized))
        }
    }

    #[test]
    fn test_stats_update() {
        let planner = QueryPlanner::new();
        let stats = planner.stats();

        stats.set_collection_size("big_table", 1_000_000);
        assert_eq!(stats.collection_size("big_table"), 1_000_000);

        // Collections without recorded stats fall back to 1000 rows.
        assert_eq!(stats.collection_size("unknown"), 1000);
    }

    #[test]
    fn test_referenced_columns() {
        let salary_or_age = Predicate::Or(
            Box::new(Predicate::Lt(col("salary"), make_blob(100))),
            Box::new(Predicate::Eq(col("age"), make_blob(30))),
        );
        let pred = Predicate::And(
            Box::new(Predicate::Gt(col("age"), make_blob(18))),
            Box::new(salary_or_age),
        );

        // "age" is referenced twice but is listed once.
        let cols = QueryPlanner::referenced_columns(&pred);
        assert_eq!(cols, vec!["age".to_string(), "salary".to_string()]);
    }

    #[test]
    fn test_display_plan_cost() {
        let rendered = PlanCost::compute(1000, 50, 256_000).to_string();
        assert!(rendered.contains("1000"));
        assert!(rendered.contains("50"));
    }

    #[test]
    fn test_logical_plan_display() {
        let rendered = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Scan {
                collection: "t".to_string(),
            }),
            predicate: Predicate::Eq(col("x"), make_blob(1)),
        }
        .to_string();

        assert!(rendered.contains("Filter"));
        assert!(rendered.contains("Scan"));
    }
}