use crate::core::error::Result;
use crate::query::plan::{LogicalOp, LogicalPlan, UnaryOp};
use super::{OptimizationRule, Statistics};
/// Optimization rule that moves `Filter` operators toward the leaves of a
/// logical plan by swapping them below operators they commute with.
///
/// The rule is stateless, so it is a cheap `Copy` unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct PredicatePushdown;
impl OptimizationRule for PredicatePushdown {
    /// Stable identifier of this rule.
    fn name(&self) -> &str {
        "predicate-pushdown"
    }

    /// Runs one pushdown pass over `plan`.
    ///
    /// Returns `Ok(Some(rewritten))` when at least one filter moved; the
    /// rewritten plan carries over the original `temporal_context` and `hints`.
    /// Returns `Ok(None)` when the pass changed nothing. Statistics are not
    /// consulted by this rule.
    fn apply(&self, plan: &LogicalPlan, _stats: &Statistics) -> Result<Option<LogicalPlan>> {
        let (root, changed) = self.push_down(&plan.root)?;
        // Only build (and clone metadata for) a new plan when something moved.
        Ok(changed.then(|| LogicalPlan {
            root,
            temporal_context: plan.temporal_context.clone(),
            hints: plan.hints.clone(),
        }))
    }
}
impl PredicatePushdown {
    /// Rewrites `op` bottom-up, swapping each `Filter` beneath its child when
    /// the two operators commute.
    ///
    /// A filter is moved below:
    /// - `Sort`, since filtering does not depend on the order of its input;
    /// - `VectorRank` with `top_k: None` (the rule assumes an uncapped ranking
    ///   keeps every input item).
    ///
    /// It stays above every other child: scans, traversals, a `VectorRank` with
    /// a `top_k` cutoff (filtering before the cutoff would change which items
    /// survive it), limits, and so on.
    ///
    /// A filter moves down at most one level per call, because the filter placed
    /// under the swapped child is not revisited. Reaching a fixpoint therefore
    /// requires re-applying the rule until it reports no change.
    ///
    /// Returns the rewritten subtree and `true` if anything in it moved.
    fn push_down(&self, op: &LogicalOp) -> Result<(LogicalOp, bool)> {
        match op {
            LogicalOp::Unary {
                op: UnaryOp::Filter(predicate),
                input,
            } => {
                // Optimize the subtree first so nested filters sink before this one.
                let (child, child_changed) = self.push_down(input)?;
                let filtered = |subtree: LogicalOp| {
                    LogicalOp::unary(UnaryOp::Filter(predicate.clone()), subtree)
                };

                match &child {
                    LogicalOp::Unary {
                        op:
                            child_op @ (UnaryOp::Sort { .. }
                            | UnaryOp::VectorRank { top_k: None, .. }),
                        input: grandchild,
                    } => {
                        // Commuting child: rebuild as child_op(Filter(grandchild)).
                        let swapped = LogicalOp::unary(
                            child_op.clone(),
                            filtered((**grandchild).clone()),
                        );
                        Ok((swapped, true))
                    }
                    // Non-commuting child: the filter stays on top, and only
                    // changes from deeper in the tree are reported.
                    _ => Ok((filtered(child), child_changed)),
                }
            }
            LogicalOp::Unary { op, input } => {
                let (child, changed) = self.push_down(input)?;
                Ok((LogicalOp::unary(op.clone(), child), changed))
            }
            LogicalOp::Binary { op, left, right } => {
                let (new_left, left_changed) = self.push_down(left)?;
                let (new_right, right_changed) = self.push_down(right)?;
                let changed = left_changed || right_changed;
                Ok((LogicalOp::binary(op.clone(), new_left, new_right), changed))
            }
            // Leaves have nothing beneath them to push into.
            LogicalOp::Scan(_) | LogicalOp::Empty => Ok((op.clone(), false)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::NodeId;
    use crate::query::ir::{Direction, Predicate, TraversalDepth};
    use crate::query::plan::{BinaryOp, ScanOp, SortKey};
    use std::sync::Arc;

    /// Default statistics; `PredicatePushdown` does not consult them.
    fn test_stats() -> Statistics {
        Statistics::default()
    }

    #[test]
    fn test_no_change_on_simple_filter() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_push_filter_below_vector_rank_no_limit() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: None,
                    property_key: None,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        let expected_plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::VectorRank {
                embedding: Arc::from([0.1f32; 4].as_slice()),
                top_k: None,
                property_key: None,
            },
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("name", "Alice")),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        assert_eq!(result, Some(expected_plan));
    }

    #[test]
    fn test_stop_filter_at_traverse() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::unary(
                UnaryOp::Traverse {
                    label: None,
                    direction: Direction::Outgoing,
                    depth: TraversalDepth::Exact(1),
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_stop_filter_at_scan() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_stop_filter_at_vector_rank_with_limit() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: Some(10),
                    property_key: None,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_push_filter_below_sort() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("active", true)),
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("created".to_string()),
                    descending: true,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        let expected_plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Sort {
                key: SortKey::Property("created".to_string()),
                descending: true,
            },
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("active", true)),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        assert_eq!(result, Some(expected_plan));
    }

    #[test]
    fn test_multi_level_pushdown() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("active", true)),
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("created".to_string()),
                    descending: true,
                },
                LogicalOp::unary(
                    UnaryOp::VectorRank {
                        embedding: Arc::from([0.1f32; 4].as_slice()),
                        top_k: None,
                        property_key: None,
                    },
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));

        // The rule moves a filter one level per pass, so re-apply it until it
        // reports no change; bail out loudly if it never converges.
        let mut current_plan = plan;
        let mut iterations = 0;
        while let Some(new_plan) = rule.apply(&current_plan, &stats).unwrap() {
            current_plan = new_plan;
            iterations += 1;
            assert!(iterations < 10, "predicate pushdown did not reach a fixpoint");
        }

        let expected_plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Sort {
                key: SortKey::Property("created".to_string()),
                descending: true,
            },
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: None,
                    property_key: None,
                },
                LogicalOp::unary(
                    UnaryOp::Filter(Predicate::eq("active", true)),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        assert_eq!(current_plan, expected_plan);
    }

    #[test]
    fn test_binary_op_recursion_logic() {
        let rule = PredicatePushdown;
        let stats = test_stats();
        let left_op = LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("active", true)),
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("created".to_string()),
                    descending: true,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        );
        let right_op = LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()]));
        let plan = LogicalPlan::new(LogicalOp::binary(BinaryOp::Union, left_op, right_op));
        let result = rule.apply(&plan, &stats).unwrap();
        let expected_plan = LogicalPlan::new(LogicalOp::binary(
            BinaryOp::Union,
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("created".to_string()),
                    descending: true,
                },
                LogicalOp::unary(
                    UnaryOp::Filter(Predicate::eq("active", true)),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()])),
        ));
        assert_eq!(
            result,
            Some(expected_plan),
            "Binary op with one changed branch should return Some"
        );
    }

    #[test]
    fn test_nested_filters_keep_their_order() {
        // Both stacked filters sink below the sort in one pass and keep their
        // relative order (outer filter stays outermost).
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("active", true)),
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("name", "Alice")),
                LogicalOp::unary(
                    UnaryOp::Sort {
                        key: SortKey::Property("created".to_string()),
                        descending: true,
                    },
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        let expected_plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Sort {
                key: SortKey::Property("created".to_string()),
                descending: true,
            },
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("active", true)),
                LogicalOp::unary(
                    UnaryOp::Filter(Predicate::eq("name", "Alice")),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        assert_eq!(result, Some(expected_plan));
    }

    #[test]
    fn test_change_below_bounded_rank_is_reported() {
        // The outer filter must stay above the capped ranking, but a pushdown
        // deeper in the tree still has to be reported as a change.
        let rule = PredicatePushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: Some(10),
                    property_key: None,
                },
                LogicalOp::unary(
                    UnaryOp::Filter(Predicate::eq("active", true)),
                    LogicalOp::unary(
                        UnaryOp::Sort {
                            key: SortKey::Property("created".to_string()),
                            descending: true,
                        },
                        LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                    ),
                ),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        let expected_plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("name", "Alice")),
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: Some(10),
                    property_key: None,
                },
                LogicalOp::unary(
                    UnaryOp::Sort {
                        key: SortKey::Property("created".to_string()),
                        descending: true,
                    },
                    LogicalOp::unary(
                        UnaryOp::Filter(Predicate::eq("active", true)),
                        LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                    ),
                ),
            ),
        ));
        assert_eq!(result, Some(expected_plan));
    }
}
#[cfg(test)]
mod sentry_tests {
    use super::*;
    use crate::core::NodeId;
    use crate::query::ir::{Direction, Predicate, TraversalDepth};
    use crate::query::plan::{BinaryOp, ScanOp, SortKey};

    #[test]
    fn test_apply_unchanged() {
        let rule = PredicatePushdown;
        let stats = Statistics::default();
        let plan = LogicalPlan::new(LogicalOp::Scan(ScanOp::NodeLookup(vec![
            NodeId::new(1).unwrap(),
        ])));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_apply_empty_unchanged() {
        // The `Empty` leaf has nothing to push into.
        let rule = PredicatePushdown;
        let stats = Statistics::default();
        let plan = LogicalPlan::new(LogicalOp::Empty);
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_apply_changed() {
        let rule = PredicatePushdown;
        let stats = Statistics::default();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("a", 1)),
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("a".to_string()),
                    descending: true,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        // Pin the exact rewrite, not just that something changed.
        let expected_plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Sort {
                key: SortKey::Property("a".to_string()),
                descending: true,
            },
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("a", 1)),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        assert_eq!(result, Some(expected_plan));
    }

    #[test]
    fn test_pushdown_traverse() {
        let rule = PredicatePushdown;
        let stats = Statistics::default();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("a", 1)),
            LogicalOp::unary(
                UnaryOp::Traverse {
                    label: None,
                    direction: Direction::Outgoing,
                    depth: TraversalDepth::Exact(1),
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_pushdown_unsupported_unary_op() {
        let rule = PredicatePushdown;
        let stats = Statistics::default();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(10),
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("a", 1)),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());

        let plan2 = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("a", 1)),
            LogicalOp::unary(
                UnaryOp::Limit(10),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result2 = rule.apply(&plan2, &stats).unwrap();
        assert!(result2.is_none());
    }

    #[test]
    fn test_pushdown_name() {
        let rule = PredicatePushdown;
        assert_eq!(rule.name(), "predicate-pushdown");
    }

    #[test]
    fn test_binary_op_partial_optimization() {
        let rule = PredicatePushdown;
        let stats = Statistics::default();
        let left = LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("a", 1)),
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("a".to_string()),
                    descending: true,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        );
        let right = LogicalOp::unary(
            UnaryOp::Filter(Predicate::eq("b", 2)),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()])),
        );
        let root = LogicalOp::binary(BinaryOp::Union, left, right);
        let plan = LogicalPlan::new(root);
        let result = rule.apply(&plan, &stats).unwrap();
        let expected_plan = LogicalPlan::new(LogicalOp::binary(
            BinaryOp::Union,
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: SortKey::Property("a".to_string()),
                    descending: true,
                },
                LogicalOp::unary(
                    UnaryOp::Filter(Predicate::eq("a", 1)),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq("b", 2)),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()])),
            ),
        ));
        assert_eq!(
            result,
            Some(expected_plan),
            "Partial optimization (left branch) should trigger change"
        );
    }
}