use crate::compute::EncryptedType;
use crate::compute::circuit::Circuit;
use crate::compute::predicate::PredicateCompiler;
use crate::error::{AmateRSError, ErrorContext, Result};
use crate::types::{CipherBlob, ColumnRef, Key, Predicate, Query};
use dashmap::DashMap;
use std::collections::HashSet;
use std::sync::Arc;
pub use super::plan_cache::{CacheKey, CacheStats, CachedPlan, PlanCache, PlanCacheConfig};
/// Logical query plan tree built by `QueryPlanner` from a `Query`, rewritten by
/// the optimizer passes, and then lowered to a `PhysicalPlan`.
#[derive(Debug, Clone)]
pub enum LogicalPlan {
    /// Read every entry of `collection`.
    Scan {
        collection: String,
    },
    /// Read the entries of `collection` whose keys lie between the optional
    /// bounds (`None` = unbounded on that side).
    /// NOTE(review): bound inclusivity is decided by the executor, not here.
    RangeScan {
        collection: String,
        start_key: Option<Vec<u8>>,
        end_key: Option<Vec<u8>>,
    },
    /// Keep only the rows of `input` that satisfy `predicate`.
    Filter {
        input: Box<LogicalPlan>,
        predicate: Predicate,
    },
    /// Restrict the rows of `input` to the listed `columns`.
    Project {
        input: Box<LogicalPlan>,
        columns: Vec<String>,
    },
    /// Emit at most `count` rows from `input`.
    Limit {
        input: Box<LogicalPlan>,
        count: usize,
    },
    /// Fetch the single entry stored under `key` in `collection`.
    PointLookup {
        collection: String,
        key: Key,
    },
}
/// Executable plan tree produced by lowering a `LogicalPlan`; this is the form
/// returned by `QueryPlanner::plan` and costed by `estimate_cost`.
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
    /// Sequential scan over every entry of `collection` (from `LogicalPlan::Scan`).
    SeqScan {
        collection: String,
    },
    /// Key-range scan over `collection` (from `LogicalPlan::RangeScan`);
    /// `None` bounds are open-ended.
    IndexScan {
        collection: String,
        start: Option<Vec<u8>>,
        end: Option<Vec<u8>>,
    },
    /// Filter evaluated homomorphically: `circuit` is compiled from `predicate`
    /// by the planner; `predicate` is kept for display and inspection.
    FheFilter {
        input: Box<PhysicalPlan>,
        circuit: Circuit,
        predicate: Predicate,
    },
    /// Column projection over `input`.
    Projection {
        input: Box<PhysicalPlan>,
        columns: Vec<String>,
    },
    /// Emit at most `count` rows from `input`.
    Limit {
        input: Box<PhysicalPlan>,
        count: usize,
    },
    /// Direct fetch of a single `key` in `collection`.
    PointGet {
        collection: String,
        key: Key,
    },
}
/// Estimated cost of executing a `PhysicalPlan`.
///
/// `total_cost` is a unitless weighted sum of the three component estimates,
/// intended only for comparing plans against each other.
#[derive(Debug, Clone)]
pub struct PlanCost {
    /// Estimated number of rows produced/processed.
    pub estimated_rows: u64,
    /// Estimated number of FHE gate evaluations.
    pub estimated_fhe_ops: u64,
    /// Estimated bytes read or produced.
    pub estimated_io_bytes: u64,
    /// Weighted sum of the components above.
    pub total_cost: f64,
}

impl PlanCost {
    const IO_COST_PER_BYTE: f64 = 0.001;
    const FHE_COST_PER_OP: f64 = 100.0;
    const SCAN_COST_PER_ROW: f64 = 0.01;
    const POINT_LOOKUP_COST: f64 = 1.0;

    /// Builds a `PlanCost` from its components, deriving `total_cost` as
    /// `rows * SCAN_COST_PER_ROW + fhe_ops * FHE_COST_PER_OP + io_bytes * IO_COST_PER_BYTE`.
    fn compute(estimated_rows: u64, estimated_fhe_ops: u64, estimated_io_bytes: u64) -> Self {
        let row_cost = estimated_rows as f64 * Self::SCAN_COST_PER_ROW;
        let fhe_cost = estimated_fhe_ops as f64 * Self::FHE_COST_PER_OP;
        let io_cost = estimated_io_bytes as f64 * Self::IO_COST_PER_BYTE;
        let total_cost = row_cost + fhe_cost + io_cost;
        PlanCost {
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
            total_cost,
        }
    }
}

impl std::fmt::Display for PlanCost {
    /// Formats as `PlanCost(rows=…, fhe_ops=…, io_bytes=…, total=…)` with the
    /// total rounded to two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PlanCost {
            estimated_rows,
            estimated_fhe_ops,
            estimated_io_bytes,
            total_cost,
        } = self;
        write!(
            f,
            "PlanCost(rows={}, fhe_ops={}, io_bytes={}, total={:.2})",
            estimated_rows, estimated_fhe_ops, estimated_io_bytes, total_cost
        )
    }
}
/// Statistics consulted by the planner's cost model.
///
/// Collection sizes live in a `DashMap`, so `set_collection_size` only needs
/// `&self` and can update statistics shared behind an `Arc`.
pub struct PlannerStats {
    /// Estimated row counts keyed by collection name.
    pub estimated_collection_sizes: DashMap<String, u64>,
    /// Assumed average value size in bytes; multiplied by row counts to
    /// estimate I/O volume.
    pub average_value_size: u64,
    /// Assumed latency of one FHE operation (presumably microseconds, per the
    /// `_us` suffix). NOTE(review): not read by the cost model in this file.
    pub fhe_op_latency_us: u64,
}

impl PlannerStats {
    /// Row count assumed for any collection without a recorded size.
    const DEFAULT_COLLECTION_SIZE: u64 = 1000;
    /// Initial value of `average_value_size`.
    const DEFAULT_AVERAGE_VALUE_SIZE: u64 = 256;
    /// Initial value of `fhe_op_latency_us`.
    const DEFAULT_FHE_OP_LATENCY_US: u64 = 1000;

    /// Creates statistics with no recorded collection sizes and default
    /// per-value and per-operation assumptions.
    fn new() -> Self {
        PlannerStats {
            estimated_collection_sizes: DashMap::new(),
            average_value_size: Self::DEFAULT_AVERAGE_VALUE_SIZE,
            fhe_op_latency_us: Self::DEFAULT_FHE_OP_LATENCY_US,
        }
    }

    /// Returns the recorded size of `collection`, or the default of 1000 rows
    /// when nothing has been recorded.
    fn collection_size(&self, collection: &str) -> u64 {
        self.estimated_collection_sizes
            .get(collection)
            .map_or(Self::DEFAULT_COLLECTION_SIZE, |entry| *entry)
    }

    /// Records (or overwrites) the estimated row count for `collection`.
    pub fn set_collection_size(&self, collection: impl Into<String>, size: u64) {
        let name: String = collection.into();
        self.estimated_collection_sizes.insert(name, size);
    }
}

impl Default for PlannerStats {
    /// Same as `PlannerStats::new`.
    fn default() -> Self {
        PlannerStats::new()
    }
}
/// Query planner: builds a `LogicalPlan` from a `Query`, applies logical
/// rewrite passes, lowers the result to a `PhysicalPlan`, and optionally caches
/// the physical plan.
pub struct QueryPlanner {
    /// Statistics read by `estimate_cost`. Held in an `Arc` so a caller using
    /// `with_stats` can keep a handle and keep updating them.
    stats: Arc<PlannerStats>,
    /// Plan cache; `None` unless enabled via `with_cache`.
    cache: Option<Arc<PlanCache>>,
}
impl QueryPlanner {
pub fn new() -> Self {
Self {
stats: Arc::new(PlannerStats::new()),
cache: None,
}
}
pub fn with_stats(stats: Arc<PlannerStats>) -> Self {
Self { stats, cache: None }
}
pub fn with_cache(mut self, config: PlanCacheConfig) -> Self {
self.cache = Some(Arc::new(PlanCache::new(config)));
self
}
pub fn stats(&self) -> &PlannerStats {
&self.stats
}
pub fn plan_cache(&self) -> Option<&PlanCache> {
self.cache.as_deref()
}
pub fn cache_stats(&self) -> CacheStats {
self.cache
.as_ref()
.map(|c| c.cache_stats())
.unwrap_or_default()
}
pub fn invalidate_all(&self) {
if let Some(cache) = &self.cache {
cache.invalidate_all();
}
}
pub fn invalidate_prefix(&self, prefix: &str) {
if let Some(cache) = &self.cache {
cache.invalidate_prefix(prefix);
}
}
/// Produces a physical plan for `query`, consulting the plan cache when one is
/// configured.
///
/// With caching enabled, a hit returns the cached plan without re-planning. On
/// a miss, the freshly built plan is stored together with a normalized form of
/// the query's `Debug` text. With caching disabled, no cache key is derived.
///
/// # Errors
///
/// Propagates any error from building the logical plan or lowering it (for
/// example, predicate circuit compilation in `to_physical`).
pub fn plan(&self, query: &Query) -> Result<PhysicalPlan> {
    let cache = match &self.cache {
        Some(cache) => cache,
        // Without a cache there is nothing to look up or store, so skip
        // deriving the cache key entirely.
        None => return self.plan_uncached(query),
    };
    let cache_key = CacheKey::from_query(query);
    if let Some(cached_plan) = cache.get(&cache_key) {
        return Ok(cached_plan);
    }
    let physical = self.plan_uncached(query)?;
    let normalized = CacheKey::normalize(&format!("{:?}", query));
    cache.insert(cache_key, physical.clone(), normalized);
    Ok(physical)
}

/// Runs the planning pipeline (logical build, rewrite passes, physical
/// lowering) without touching the plan cache.
fn plan_uncached(&self, query: &Query) -> Result<PhysicalPlan> {
    let logical = self.to_logical(query)?;
    let optimized = self.optimize_logical(logical);
    self.to_physical(&optimized)
}
/// Translates a `Query` into its initial (unoptimized) logical plan.
///
/// - `Get` and `Delete` become a `PointLookup` of the key.
/// - `Filter` and `Update` become a `Filter` over a full `Scan`.
/// - `Range` becomes a `RangeScan` with both bounds set.
/// - `Set` is planned as a plain `Scan` of the collection.
///
/// Every arm currently succeeds; the `Result` return leaves room for
/// fallible translations.
fn to_logical(&self, query: &Query) -> Result<LogicalPlan> {
    let logical = match query {
        Query::Get { collection, key } | Query::Delete { collection, key } => {
            LogicalPlan::PointLookup {
                collection: collection.clone(),
                key: key.clone(),
            }
        }
        Query::Filter {
            collection,
            predicate,
        }
        | Query::Update {
            collection,
            predicate,
            ..
        } => {
            let scan = LogicalPlan::Scan {
                collection: collection.clone(),
            };
            LogicalPlan::Filter {
                input: Box::new(scan),
                predicate: predicate.clone(),
            }
        }
        Query::Range {
            collection,
            start,
            end,
        } => LogicalPlan::RangeScan {
            collection: collection.clone(),
            start_key: Some(start.to_vec()),
            end_key: Some(end.to_vec()),
        },
        Query::Set { collection, .. } => LogicalPlan::Scan {
            collection: collection.clone(),
        },
    };
    Ok(logical)
}
/// Applies the logical rewrite passes in order: predicate pushdown, filter
/// merging, then conversion of key-range filters into range scans.
///
/// Merging runs before range conversion so that stacked filters over a scan
/// reach the range pass as a single `Filter` directly above the `Scan`.
fn optimize_logical(&self, plan: LogicalPlan) -> LogicalPlan {
    self.convert_filter_to_range_scan(self.merge_filters(self.push_predicates_down(plan)))
}
/// Moves each `Filter` beneath a `Project` that sits directly under it.
///
/// If the predicate references columns the projection drops, a widened inner
/// projection (original columns plus the missing predicate columns, in sorted
/// order) is placed under the filter. The original projection is kept on top so
/// the visible columns do not change. Filters are never pushed below a `Limit`,
/// because that would change which rows are kept.
fn push_predicates_down(&self, plan: LogicalPlan) -> LogicalPlan {
    match plan {
        LogicalPlan::Filter { input, predicate } => match self.push_predicates_down(*input) {
            LogicalPlan::Project {
                input: below,
                columns,
            } => {
                // Predicate columns that the projection would otherwise hide.
                let missing: Vec<String> = {
                    let visible: HashSet<&str> = columns.iter().map(String::as_str).collect();
                    Self::referenced_columns(&predicate)
                        .into_iter()
                        .filter(|name| !visible.contains(name.as_str()))
                        .collect()
                };
                let filter_input = if missing.is_empty() {
                    below
                } else {
                    let mut widened = columns.clone();
                    widened.extend(missing);
                    Box::new(LogicalPlan::Project {
                        input: below,
                        columns: widened,
                    })
                };
                LogicalPlan::Project {
                    input: Box::new(LogicalPlan::Filter {
                        input: filter_input,
                        predicate,
                    }),
                    columns,
                }
            }
            rest => LogicalPlan::Filter {
                input: Box::new(rest),
                predicate,
            },
        },
        LogicalPlan::Project { input, columns } => {
            let pushed = self.push_predicates_down(*input);
            LogicalPlan::Project {
                input: Box::new(pushed),
                columns,
            }
        }
        LogicalPlan::Limit { input, count } => {
            let pushed = self.push_predicates_down(*input);
            LogicalPlan::Limit {
                input: Box::new(pushed),
                count,
            }
        }
        leaf => leaf,
    }
}
/// Collapses directly stacked `Filter` nodes into one filter whose predicate is
/// `And(inner, outer)`, recursing through `Project` and `Limit`.
fn merge_filters(&self, plan: LogicalPlan) -> LogicalPlan {
    match plan {
        LogicalPlan::Filter {
            input,
            predicate: outer,
        } => match self.merge_filters(*input) {
            // The child was already merged, so `base` is never a Filter itself.
            LogicalPlan::Filter {
                input: base,
                predicate: inner,
            } => LogicalPlan::Filter {
                input: base,
                predicate: Predicate::And(Box::new(inner), Box::new(outer)),
            },
            merged => LogicalPlan::Filter {
                input: Box::new(merged),
                predicate: outer,
            },
        },
        LogicalPlan::Project { input, columns } => {
            let merged = self.merge_filters(*input);
            LogicalPlan::Project {
                input: Box::new(merged),
                columns,
            }
        }
        LogicalPlan::Limit { input, count } => {
            let merged = self.merge_filters(*input);
            LogicalPlan::Limit {
                input: Box::new(merged),
                count,
            }
        }
        leaf => leaf,
    }
}
fn convert_filter_to_range_scan(&self, plan: LogicalPlan) -> LogicalPlan {
match plan {
LogicalPlan::Filter { input, predicate } => {
let optimized_input = self.convert_filter_to_range_scan(*input);
if let LogicalPlan::Scan { ref collection } = optimized_input {
if let Some((start, end)) = Self::extract_key_range(&predicate) {
return LogicalPlan::RangeScan {
collection: collection.clone(),
start_key: start,
end_key: end,
};
}
}
LogicalPlan::Filter {
input: Box::new(optimized_input),
predicate,
}
}
LogicalPlan::Project { input, columns } => LogicalPlan::Project {
input: Box::new(self.convert_filter_to_range_scan(*input)),
columns,
},
LogicalPlan::Limit { input, count } => LogicalPlan::Limit {
input: Box::new(self.convert_filter_to_range_scan(*input)),
count,
},
other => other,
}
}
/// Lowers a logical plan to a physical plan, one node at a time.
///
/// `Scan` becomes `SeqScan`, `RangeScan` becomes `IndexScan`, `PointLookup`
/// becomes `PointGet`, and each `Filter` becomes an `FheFilter` whose circuit is
/// compiled from its predicate.
///
/// # Errors
///
/// Propagates predicate-circuit compilation errors. The child is lowered before
/// the node's own circuit is compiled, so a failure deeper in the tree is
/// reported first.
fn to_physical(&self, plan: &LogicalPlan) -> Result<PhysicalPlan> {
    let physical = match plan {
        LogicalPlan::PointLookup { collection, key } => PhysicalPlan::PointGet {
            collection: collection.clone(),
            key: key.clone(),
        },
        LogicalPlan::Scan { collection } => PhysicalPlan::SeqScan {
            collection: collection.clone(),
        },
        LogicalPlan::RangeScan {
            collection,
            start_key,
            end_key,
        } => PhysicalPlan::IndexScan {
            collection: collection.clone(),
            start: start_key.clone(),
            end: end_key.clone(),
        },
        LogicalPlan::Filter { input, predicate } => {
            let child = Box::new(self.to_physical(input)?);
            let circuit = self.compile_predicate_circuit(predicate)?;
            PhysicalPlan::FheFilter {
                input: child,
                circuit,
                predicate: predicate.clone(),
            }
        }
        LogicalPlan::Project { input, columns } => PhysicalPlan::Projection {
            input: Box::new(self.to_physical(input)?),
            columns: columns.clone(),
        },
        LogicalPlan::Limit { input, count } => PhysicalPlan::Limit {
            input: Box::new(self.to_physical(input)?),
            count: *count,
        },
    };
    Ok(physical)
}
/// Estimates the execution cost of `plan` from the planner statistics.
///
/// Heuristics used:
/// - `SeqScan` reads every row of the collection.
/// - `IndexScan` reads 10% of the rows with both bounds, 30% with one bound,
///   all rows with none, and at least one row.
/// - `FheFilter` evaluates `circuit.gate_count` FHE gates per input row and is
///   assumed to pass half the rows (at least one).
/// - `Projection` reduces I/O volume by 20%.
/// - `Limit` caps the row count at `count`.
/// - `PointGet` reads exactly one row.
///
/// Row, operation and byte counts saturate at `u64::MAX` instead of
/// overflowing, which would panic in debug builds for very large size
/// estimates.
pub fn estimate_cost(&self, plan: &PhysicalPlan) -> PlanCost {
    let value_size = self.stats.average_value_size;
    match plan {
        PhysicalPlan::SeqScan { collection } => {
            let rows = self.stats.collection_size(collection);
            PlanCost::compute(rows, 0, rows.saturating_mul(value_size))
        }
        PhysicalPlan::IndexScan {
            collection,
            start,
            end,
        } => {
            let total = self.stats.collection_size(collection);
            let selectivity = match (start, end) {
                (Some(_), Some(_)) => 0.10,
                (Some(_), None) | (None, Some(_)) => 0.30,
                (None, None) => 1.0,
            };
            // f64 -> u64 `as` casts saturate, so this cannot overflow.
            let rows = ((total as f64) * selectivity).max(1.0) as u64;
            PlanCost::compute(rows, 0, rows.saturating_mul(value_size))
        }
        PhysicalPlan::FheFilter { input, circuit, .. } => {
            let input_cost = self.estimate_cost(input);
            let fhe_ops = input_cost
                .estimated_rows
                .saturating_mul(circuit.gate_count as u64);
            let output_rows = (input_cost.estimated_rows / 2).max(1);
            let io_bytes = output_rows.saturating_mul(value_size);
            PlanCost::compute(
                input_cost.estimated_rows,
                input_cost.estimated_fhe_ops.saturating_add(fhe_ops),
                input_cost.estimated_io_bytes.saturating_add(io_bytes),
            )
        }
        PhysicalPlan::Projection { input, .. } => {
            let input_cost = self.estimate_cost(input);
            let io_bytes = (input_cost.estimated_io_bytes as f64 * 0.8) as u64;
            // Reuse the shared cost formula instead of re-deriving total_cost here.
            PlanCost::compute(input_cost.estimated_rows, input_cost.estimated_fhe_ops, io_bytes)
        }
        PhysicalPlan::Limit { input, count } => {
            let input_cost = self.estimate_cost(input);
            let rows = (*count as u64).min(input_cost.estimated_rows);
            PlanCost::compute(rows, input_cost.estimated_fhe_ops, rows.saturating_mul(value_size))
        }
        PhysicalPlan::PointGet { .. } => PlanCost::compute(1, 0, value_size),
    }
}
/// Returns whichever of `a` and `b` has the lower estimated `total_cost`.
/// Ties go to `a`, so callers can pass their preferred plan first.
pub fn choose_cheaper<'a>(&self, a: &'a PhysicalPlan, b: &'a PhysicalPlan) -> &'a PhysicalPlan {
    let a_total = self.estimate_cost(a).total_cost;
    let b_total = self.estimate_cost(b).total_cost;
    let prefer_a = a_total <= b_total;
    if prefer_a {
        a
    } else {
        b
    }
}
/// Returns the distinct column names referenced anywhere in `predicate`,
/// sorted ascending.
fn referenced_columns(predicate: &Predicate) -> Vec<String> {
    let mut names = Vec::new();
    Self::collect_columns(predicate, &mut names);
    // Equal strings are indistinguishable, so an unstable sort gives the same result.
    names.sort_unstable();
    names.dedup();
    names
}
/// Appends to `out` the column name of every comparison leaf in `predicate`,
/// visiting left operands before right (duplicates are kept).
fn collect_columns(predicate: &Predicate, out: &mut Vec<String>) {
    // Explicit stack; the right child is pushed first so the left one pops first.
    let mut pending: Vec<&Predicate> = vec![predicate];
    while let Some(node) = pending.pop() {
        match node {
            Predicate::Eq(column, _)
            | Predicate::Gt(column, _)
            | Predicate::Lt(column, _)
            | Predicate::Gte(column, _)
            | Predicate::Lte(column, _) => out.push(column.name.clone()),
            Predicate::And(lhs, rhs) | Predicate::Or(lhs, rhs) => {
                pending.push(&**rhs);
                pending.push(&**lhs);
            }
            Predicate::Not(inner) => pending.push(&**inner),
        }
    }
}
/// Extracts `(start, end)` key bounds implied by comparisons on the reserved
/// `_key` column.
///
/// `>`/`>=` yield a start bound and `<`/`<=` yield an end bound; strictness is
/// not recorded. Under `And`, each side contributes its bounds, and when both
/// sides bound the same end, the left one wins. Any other shape (`Or`, `Not`,
/// `Eq`, non-key columns) yields `None`.
fn extract_key_range(predicate: &Predicate) -> Option<(Option<Vec<u8>>, Option<Vec<u8>>)> {
    match predicate {
        Predicate::Gt(column, bound) | Predicate::Gte(column, bound) if column.name == "_key" => {
            Some((Some(bound.as_bytes().to_vec()), None))
        }
        Predicate::Lt(column, bound) | Predicate::Lte(column, bound) if column.name == "_key" => {
            Some((None, Some(bound.as_bytes().to_vec())))
        }
        Predicate::And(lhs, rhs) => {
            match (Self::extract_key_range(lhs), Self::extract_key_range(rhs)) {
                (Some((lo_left, hi_left)), Some((lo_right, hi_right))) => {
                    Some((lo_left.or(lo_right), hi_left.or(hi_right)))
                }
                // At most one side produced a range; keep whichever did.
                (left, right) => left.or(right),
            }
        }
        _ => None,
    }
}
/// Compiles `predicate` into an FHE circuit using a fresh `PredicateCompiler`,
/// treating comparison operands as `EncryptedType::U8`.
/// NOTE(review): wider operand types are not considered here.
fn compile_predicate_circuit(&self, predicate: &Predicate) -> Result<Circuit> {
    PredicateCompiler::new().compile(predicate, EncryptedType::U8)
}
}
impl Default for QueryPlanner {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for LogicalPlan {
    /// Renders the plan as an indented tree, one operator per line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.fmt_indented(f, 0)
    }
}

impl LogicalPlan {
    /// Writes this node on one line, prefixed by `indent` spaces, then renders
    /// its single child (if any) one level deeper.
    fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
        // Left padding: an empty string padded to `indent` columns.
        write!(f, "{:width$}", "", width = indent)?;
        let child = match self {
            LogicalPlan::Scan { collection } => {
                writeln!(f, "Scan({})", collection)?;
                None
            }
            LogicalPlan::RangeScan {
                collection,
                start_key,
                end_key,
            } => {
                // Only bound presence is shown; raw key bytes are omitted.
                writeln!(
                    f,
                    "RangeScan({}, start={}, end={})",
                    collection,
                    start_key.is_some(),
                    end_key.is_some()
                )?;
                None
            }
            LogicalPlan::PointLookup { collection, key } => {
                writeln!(f, "PointLookup({}, key={})", collection, key)?;
                None
            }
            LogicalPlan::Filter { input, predicate } => {
                writeln!(f, "Filter(pred={:?})", predicate)?;
                Some(input)
            }
            LogicalPlan::Project { input, columns } => {
                writeln!(f, "Project({:?})", columns)?;
                Some(input)
            }
            LogicalPlan::Limit { input, count } => {
                writeln!(f, "Limit({})", count)?;
                Some(input)
            }
        };
        match child {
            Some(next) => next.fmt_indented(f, indent + 1),
            None => Ok(()),
        }
    }
}
impl std::fmt::Display for PhysicalPlan {
    /// Renders the plan as an indented tree, one operator per line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.fmt_indented(f, 0)
    }
}

impl PhysicalPlan {
    /// Writes this node on one line, prefixed by `indent` spaces, then renders
    /// its single child (if any) one level deeper. `FheFilter` shows its
    /// predicate but not the compiled circuit.
    fn fmt_indented(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
        // Left padding: an empty string padded to `indent` columns.
        write!(f, "{:width$}", "", width = indent)?;
        let child = match self {
            PhysicalPlan::SeqScan { collection } => {
                writeln!(f, "SeqScan({})", collection)?;
                None
            }
            PhysicalPlan::IndexScan {
                collection,
                start,
                end,
            } => {
                // Only bound presence is shown; raw key bytes are omitted.
                writeln!(
                    f,
                    "IndexScan({}, start={}, end={})",
                    collection,
                    start.is_some(),
                    end.is_some()
                )?;
                None
            }
            PhysicalPlan::PointGet { collection, key } => {
                writeln!(f, "PointGet({}, key={})", collection, key)?;
                None
            }
            PhysicalPlan::FheFilter {
                input, predicate, ..
            } => {
                writeln!(f, "FheFilter(pred={:?})", predicate)?;
                Some(input)
            }
            PhysicalPlan::Projection { input, columns } => {
                writeln!(f, "Projection({:?})", columns)?;
                Some(input)
            }
            PhysicalPlan::Limit { input, count } => {
                writeln!(f, "Limit({})", count)?;
                Some(input)
            }
        };
        match child {
            Some(next) => next.fmt_indented(f, indent + 1),
            None => Ok(()),
        }
    }
}
// Unit tests for the planner: logical construction, rewrite passes, physical
// lowering, cost estimation, and `Display` output. Failures are reported by
// returning an `AmateRSError` rather than panicking, so tests return `Result`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::col;

    /// Builds a one-byte `CipherBlob` used as a comparison operand in predicates.
    fn make_blob(v: u8) -> CipherBlob {
        CipherBlob::new(vec![v])
    }

    /// A filter on a non-key column is planned as an `FheFilter` over a `SeqScan`.
    #[test]
    fn test_scan_plan() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Filter {
            collection: "users".to_string(),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };
        let plan = planner.plan(&query)?;
        match &plan {
            PhysicalPlan::FheFilter { input, .. } => {
                assert!(matches!(input.as_ref(), PhysicalPlan::SeqScan { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected FheFilter, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// `_key >= 10 AND _key < 50` is rewritten into a bounded `IndexScan`.
    #[test]
    fn test_range_scan_pushdown() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Filter {
            collection: "data".to_string(),
            predicate: Predicate::And(
                Box::new(Predicate::Gte(col("_key"), make_blob(10))),
                Box::new(Predicate::Lt(col("_key"), make_blob(50))),
            ),
        };
        let plan = planner.plan(&query)?;
        match &plan {
            PhysicalPlan::IndexScan {
                collection,
                start,
                end,
            } => {
                assert_eq!(collection, "data");
                assert!(start.is_some());
                assert!(end.is_some());
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected IndexScan, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// A filter above a projection that already exposes the predicate column
    /// is moved below the projection.
    #[test]
    fn test_predicate_pushdown() -> Result<()> {
        let planner = QueryPlanner::new();
        let scan = LogicalPlan::Scan {
            collection: "users".to_string(),
        };
        let project = LogicalPlan::Project {
            input: Box::new(scan),
            columns: vec!["age".to_string(), "name".to_string()],
        };
        let filter = LogicalPlan::Filter {
            input: Box::new(project),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };
        let optimized = planner.push_predicates_down(filter);
        match &optimized {
            LogicalPlan::Project { input, columns } => {
                assert!(columns.contains(&"age".to_string()));
                assert!(matches!(input.as_ref(), LogicalPlan::Filter { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Project, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// Two stacked filters collapse into one `And` filter directly over the scan.
    #[test]
    fn test_filter_merge() -> Result<()> {
        let planner = QueryPlanner::new();
        let scan = LogicalPlan::Scan {
            collection: "users".to_string(),
        };
        let filter1 = LogicalPlan::Filter {
            input: Box::new(scan),
            predicate: Predicate::Gt(col("age"), make_blob(18)),
        };
        let filter2 = LogicalPlan::Filter {
            input: Box::new(filter1),
            predicate: Predicate::Lt(col("age"), make_blob(65)),
        };
        let optimized = planner.merge_filters(filter2);
        match &optimized {
            LogicalPlan::Filter { input, predicate } => {
                assert!(matches!(predicate, Predicate::And(_, _)));
                assert!(matches!(input.as_ref(), LogicalPlan::Scan { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Filter, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// Cost ordering for a 10k-row collection: PointGet < bounded IndexScan < SeqScan.
    #[test]
    fn test_cost_estimation() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("data", 10_000);
        let seq_scan = PhysicalPlan::SeqScan {
            collection: "data".to_string(),
        };
        let seq_cost = planner.estimate_cost(&seq_scan);
        let idx_scan = PhysicalPlan::IndexScan {
            collection: "data".to_string(),
            start: Some(vec![10]),
            end: Some(vec![50]),
        };
        let idx_cost = planner.estimate_cost(&idx_scan);
        assert!(
            idx_cost.total_cost < seq_cost.total_cost,
            "IndexScan cost ({}) should be less than SeqScan cost ({})",
            idx_cost.total_cost,
            seq_cost.total_cost,
        );
        let point = PhysicalPlan::PointGet {
            collection: "data".to_string(),
            key: Key::from_str("k"),
        };
        let point_cost = planner.estimate_cost(&point);
        assert!(
            point_cost.total_cost < idx_cost.total_cost,
            "PointGet cost ({}) should be less than IndexScan cost ({})",
            point_cost.total_cost,
            idx_cost.total_cost,
        );
        Ok(())
    }

    /// `Limit` over `Filter` lowers to a physical `Limit` over an `FheFilter`.
    #[test]
    fn test_limit_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let scan = LogicalPlan::Scan {
            collection: "logs".to_string(),
        };
        let filter = LogicalPlan::Filter {
            input: Box::new(scan),
            predicate: Predicate::Eq(col("level"), make_blob(1)),
        };
        let limited = LogicalPlan::Limit {
            input: Box::new(filter),
            count: 10,
        };
        let physical = planner.to_physical(&limited)?;
        match &physical {
            PhysicalPlan::Limit { input, count } => {
                assert_eq!(*count, 10);
                assert!(matches!(input.as_ref(), PhysicalPlan::FheFilter { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Limit, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// A non-key range predicate compiles into a non-empty circuit with a
    /// boolean result type.
    #[test]
    fn test_plan_with_fhe_filter() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Filter {
            collection: "accounts".to_string(),
            predicate: Predicate::And(
                Box::new(Predicate::Gt(col("balance"), make_blob(100))),
                Box::new(Predicate::Lt(col("balance"), make_blob(200))),
            ),
        };
        let plan = planner.plan(&query)?;
        match &plan {
            PhysicalPlan::FheFilter { circuit, .. } => {
                assert!(circuit.gate_count > 0);
                assert_eq!(circuit.result_type, EncryptedType::Bool);
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected FheFilter, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// A nested `Or`/`And` predicate plans, costs (with FHE ops > 0), and
    /// renders to non-empty text.
    #[test]
    fn test_complex_plan() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("orders", 50_000);
        let query = Query::Filter {
            collection: "orders".to_string(),
            predicate: Predicate::Or(
                Box::new(Predicate::Eq(col("status"), make_blob(1))),
                Box::new(Predicate::And(
                    Box::new(Predicate::Gt(col("amount"), make_blob(100))),
                    Box::new(Predicate::Lt(col("amount"), make_blob(255))),
                )),
            ),
        };
        let plan = planner.plan(&query)?;
        let cost = planner.estimate_cost(&plan);
        assert!(cost.estimated_fhe_ops > 0);
        assert!(cost.total_cost > 0.0);
        let plan_str = format!("{}", plan);
        assert!(!plan_str.is_empty());
        Ok(())
    }

    /// `Query::Get` plans as a `PointGet` costing one row and no FHE ops.
    #[test]
    fn test_get_query_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Get {
            collection: "users".to_string(),
            key: Key::from_str("user:42"),
        };
        let plan = planner.plan(&query)?;
        match &plan {
            PhysicalPlan::PointGet { collection, key } => {
                assert_eq!(collection, "users");
                assert_eq!(key.to_string_lossy(), "user:42");
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected PointGet, got: {:?}",
                    other
                ))));
            }
        }
        let cost = planner.estimate_cost(&plan);
        assert_eq!(cost.estimated_rows, 1);
        assert_eq!(cost.estimated_fhe_ops, 0);
        Ok(())
    }

    /// `Query::Range` plans as an `IndexScan` with both bounds set.
    #[test]
    fn test_range_query_planning() -> Result<()> {
        let planner = QueryPlanner::new();
        let query = Query::Range {
            collection: "events".to_string(),
            start: Key::from_str("2024-01"),
            end: Key::from_str("2024-12"),
        };
        let plan = planner.plan(&query)?;
        match &plan {
            PhysicalPlan::IndexScan {
                collection,
                start,
                end,
            } => {
                assert_eq!(collection, "events");
                assert!(start.is_some());
                assert!(end.is_some());
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected IndexScan, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// `choose_cheaper` prefers a bounded index scan over a full scan.
    #[test]
    fn test_cost_comparison() -> Result<()> {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("items", 100_000);
        let scan = PhysicalPlan::SeqScan {
            collection: "items".to_string(),
        };
        let idx = PhysicalPlan::IndexScan {
            collection: "items".to_string(),
            start: Some(vec![1]),
            end: Some(vec![10]),
        };
        let cheaper = planner.choose_cheaper(&scan, &idx);
        assert!(matches!(cheaper, PhysicalPlan::IndexScan { .. }));
        Ok(())
    }

    /// Pushdown must not move a filter below a `Limit`, since that would
    /// change which rows are kept.
    #[test]
    fn test_filter_not_pushed_below_limit() -> Result<()> {
        let planner = QueryPlanner::new();
        let scan = LogicalPlan::Scan {
            collection: "data".to_string(),
        };
        let limited = LogicalPlan::Limit {
            input: Box::new(scan),
            count: 10,
        };
        let filter = LogicalPlan::Filter {
            input: Box::new(limited),
            predicate: Predicate::Gt(col("x"), make_blob(5)),
        };
        let optimized = planner.push_predicates_down(filter);
        match &optimized {
            LogicalPlan::Filter { input, .. } => {
                assert!(matches!(input.as_ref(), LogicalPlan::Limit { .. }));
            }
            other => {
                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Expected Filter on top, got: {:?}",
                    other
                ))));
            }
        }
        Ok(())
    }

    /// Recorded sizes are returned as set; unknown collections fall back to 1000.
    #[test]
    fn test_stats_update() {
        let planner = QueryPlanner::new();
        planner.stats().set_collection_size("big_table", 1_000_000);
        let size = planner.stats().collection_size("big_table");
        assert_eq!(size, 1_000_000);
        let default_size = planner.stats().collection_size("unknown");
        assert_eq!(default_size, 1000);
    }

    /// Column extraction is sorted and de-duplicated across nested predicates.
    #[test]
    fn test_referenced_columns() {
        let pred = Predicate::And(
            Box::new(Predicate::Gt(col("age"), make_blob(18))),
            Box::new(Predicate::Or(
                Box::new(Predicate::Lt(col("salary"), make_blob(100))),
                Box::new(Predicate::Eq(col("age"), make_blob(30))),
            )),
        );
        let cols = QueryPlanner::referenced_columns(&pred);
        assert_eq!(cols, vec!["age".to_string(), "salary".to_string()]);
    }

    /// `PlanCost`'s `Display` output includes its component counts.
    #[test]
    fn test_display_plan_cost() {
        let cost = PlanCost::compute(1000, 50, 256_000);
        let display = format!("{}", cost);
        assert!(display.contains("1000"));
        assert!(display.contains("50"));
    }

    /// `LogicalPlan`'s `Display` output names each operator in the tree.
    #[test]
    fn test_logical_plan_display() {
        let plan = LogicalPlan::Filter {
            input: Box::new(LogicalPlan::Scan {
                collection: "t".to_string(),
            }),
            predicate: Predicate::Eq(col("x"), make_blob(1)),
        };
        let s = format!("{}", plan);
        assert!(s.contains("Filter"));
        assert!(s.contains("Scan"));
    }
}