use crate::core::error::Result;
use crate::query::plan::{LogicalOp, LogicalPlan, UnaryOp};
use super::{OptimizationRule, Statistics};
/// Optimization rule that pushes `Limit` bounds toward the leaves of a logical plan.
///
/// Stateless unit struct: it carries no configuration, so it is cheap to copy and
/// has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LimitPushdown;
impl OptimizationRule for LimitPushdown {
    /// Name identifying this rule.
    fn name(&self) -> &str {
        "limit-pushdown"
    }

    /// Runs a single limit-pushdown pass over `plan`.
    ///
    /// Returns `Ok(None)` when the pass leaves the operator tree untouched. Otherwise
    /// it returns a new plan whose root is the rewritten tree; `temporal_context` and
    /// `hints` are cloned from the input plan. Statistics are not consulted.
    fn apply(&self, plan: &LogicalPlan, _stats: &Statistics) -> Result<Option<LogicalPlan>> {
        let (rewritten_root, modified) = self.push_down(&plan.root, None)?;
        if !modified {
            return Ok(None);
        }
        let rewritten = LogicalPlan {
            root: rewritten_root,
            temporal_context: plan.temporal_context.clone(),
            hints: plan.hints.clone(),
        };
        Ok(Some(rewritten))
    }
}
impl LimitPushdown {
    /// Recursively rewrites `op`, carrying a row bound downward where that is safe.
    ///
    /// `limit` is the bound inherited from an ancestor `Limit` (possibly forwarded
    /// through `Project`), or `None` when no bound may be applied at this node.
    ///
    /// Returns the rewritten operator and a flag that is `true` only when something
    /// in the subtree actually changed: a limit was tightened, two stacked limits
    /// were merged, or a `VectorRank` `top_k` was clamped.
    fn push_down(&self, op: &LogicalOp, limit: Option<usize>) -> Result<(LogicalOp, bool)> {
        match op {
            LogicalOp::Unary {
                op: UnaryOp::Limit(n),
                input,
            } => self.push_into_limit(*n, input, limit),
            LogicalOp::Unary {
                op:
                    UnaryOp::VectorRank {
                        embedding,
                        top_k,
                        property_key,
                    },
                input,
            } => {
                // Conservatively, no bound is forwarded below the ranking itself.
                let (ranked_input, input_dirty) = self.push_down(input, None)?;
                // An inherited bound caps `top_k`; without one, `top_k` is kept as-is.
                let current: Option<usize> = *top_k;
                let clamped = match limit {
                    Some(bound) => Some(current.map_or(bound, |k| k.min(bound))),
                    None => current,
                };
                let rank = UnaryOp::VectorRank {
                    embedding: embedding.clone(),
                    top_k: clamped,
                    property_key: property_key.clone(),
                };
                Ok((
                    LogicalOp::unary(rank, ranked_input),
                    input_dirty || clamped != current,
                ))
            }
            LogicalOp::Unary { op, input } => {
                // `Project` forwards the inherited bound unchanged (it is treated as
                // row-preserving). Every other unary operator is a barrier: a bound
                // below a `Sort` or `Filter` would change which rows reach the parent.
                let forwarded = if matches!(op, UnaryOp::Project(_)) {
                    limit
                } else {
                    None
                };
                let (child, dirty) = self.push_down(input, forwarded)?;
                Ok((LogicalOp::unary(op.clone(), child), dirty))
            }
            LogicalOp::Binary { op, left, right } => {
                // The rule does not push a bound into either side of a binary operator;
                // each side is still optimized independently, left first.
                let (left_out, left_dirty) = self.push_down(left, None)?;
                let (right_out, right_dirty) = self.push_down(right, None)?;
                Ok((
                    LogicalOp::binary(op.clone(), left_out, right_out),
                    left_dirty || right_dirty,
                ))
            }
            // Leaves have nothing below them to rewrite.
            LogicalOp::Scan(_) | LogicalOp::Empty => Ok((op.clone(), false)),
        }
    }

    /// Rewrites a `Limit(n)` node whose input is `input`, given an inherited bound.
    ///
    /// The node's bound becomes the tighter of `n` and `inherited`, and that bound is
    /// forwarded into `input`. If the rewritten input is itself a `Limit`, the two
    /// nodes are merged into one.
    fn push_into_limit(
        &self,
        n: usize,
        input: &LogicalOp,
        inherited: Option<usize>,
    ) -> Result<(LogicalOp, bool)> {
        let bound = inherited.map_or(n, |outer| outer.min(n));
        let (child, child_dirty) = self.push_down(input, Some(bound))?;
        if let LogicalOp::Unary {
            op: UnaryOp::Limit(inner),
            input: grandchild,
        } = &child
        {
            // Collapsing two stacked limits is always a structural change.
            let merged =
                LogicalOp::unary(UnaryOp::Limit(bound.min(*inner)), (**grandchild).clone());
            return Ok((merged, true));
        }
        Ok((
            LogicalOp::unary(UnaryOp::Limit(bound), child),
            child_dirty || bound != n,
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::NodeId;
    use crate::query::plan::ScanOp;
    use std::sync::Arc;

    /// Default statistics; `LimitPushdown::apply` ignores them.
    fn test_stats() -> Statistics {
        Statistics::default()
    }

    /// `Limit(5)` over `Limit(10)` collapses into a single `Limit(5)`.
    #[test]
    fn test_combine_consecutive_limits() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::unary(
                UnaryOp::Limit(10),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_some());
        let new_plan = result.unwrap();
        match &new_plan.root {
            LogicalOp::Unary {
                op: UnaryOp::Limit(n),
                input,
            } => {
                assert_eq!(*n, 5);
                assert!(matches!(input.as_ref(), LogicalOp::Scan(_)));
            }
            _ => panic!("Expected Limit"),
        }
    }

    /// A limit directly above a `VectorRank` clamps its `top_k`.
    #[test]
    fn test_propagate_limit_to_vector_rank() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: Some(10),
                    property_key: None,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_some());
        let new_plan = result.unwrap();
        match &new_plan.root {
            LogicalOp::Unary {
                op: UnaryOp::Limit(_),
                input,
            } => match input.as_ref() {
                LogicalOp::Unary {
                    op: UnaryOp::VectorRank { top_k, .. },
                    ..
                } => {
                    assert_eq!(*top_k, Some(5));
                }
                _ => panic!("Expected VectorRank"),
            },
            _ => panic!("Expected Limit"),
        }
    }

    /// A lone limit over a scan has nothing to push into, so the rule reports no change.
    #[test]
    fn test_no_change_for_simple_limit() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(10),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    /// `Filter` is a barrier: the outer `Limit(5)` must NOT tighten the inner
    /// `Limit(10)`, because that would drop rows before they are filtered.
    #[test]
    fn test_propagate_limit_through_filter() {
        use crate::query::ir::{Predicate, PredicateValue};
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::unary(
                UnaryOp::Filter(Predicate::eq(
                    "name".to_string(),
                    PredicateValue::String("Alice".to_string()),
                )),
                LogicalOp::unary(
                    UnaryOp::Limit(10),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    /// `Project` forwards the bound: the inner `Limit(10)` is tightened to 5, and a
    /// second pass over the rewritten plan reports no further change (fixpoint).
    #[test]
    fn test_propagate_limit_through_project() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::unary(
                UnaryOp::Project(vec!["name".to_string()]),
                LogicalOp::unary(
                    UnaryOp::Limit(10),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        let expected = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::unary(
                UnaryOp::Project(vec!["name".to_string()]),
                LogicalOp::unary(
                    UnaryOp::Limit(5),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        assert!(rule.apply(&expected, &stats).unwrap().is_none());
        assert_eq!(result, Some(expected));
    }

    /// The limit is not pushed into the branches of a binary operator, and the
    /// branches themselves have nothing to simplify, so no change is reported.
    #[test]
    fn test_binary_op_limit_pushdown() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::binary(
                crate::query::plan::BinaryOp::Union,
                LogicalOp::unary(
                    UnaryOp::Limit(10),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
                LogicalOp::unary(
                    UnaryOp::Limit(15),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()])),
                ),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    /// Branches of a binary operator are still optimized on their own.
    #[test]
    fn test_binary_op_limit_pushdown_children() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::binary(
            crate::query::plan::BinaryOp::Union,
            LogicalOp::unary(
                UnaryOp::Limit(10),
                LogicalOp::unary(
                    UnaryOp::Limit(20),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()])),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_some());
    }

    /// When `top_k` already equals the limit there is nothing to clamp.
    #[test]
    fn test_propagate_limit_to_vector_rank_equal_limit() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(10),
            LogicalOp::unary(
                UnaryOp::VectorRank {
                    embedding: Arc::from([0.1f32; 4].as_slice()),
                    top_k: Some(10),
                    property_key: None,
                },
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }

    /// `Sort` is a barrier: the outer `Limit(5)` must NOT tighten the inner
    /// `Limit(10)`, because that would change which rows get sorted.
    #[test]
    fn test_propagate_limit_through_sort() {
        let rule = LimitPushdown;
        let stats = test_stats();
        let plan = LogicalPlan::new(LogicalOp::unary(
            UnaryOp::Limit(5),
            LogicalOp::unary(
                UnaryOp::Sort {
                    key: crate::query::plan::SortKey::Property("age".into()),
                    descending: true,
                },
                LogicalOp::unary(
                    UnaryOp::Limit(10),
                    LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
                ),
            ),
        ));
        let result = rule.apply(&plan, &stats).unwrap();
        assert!(result.is_none());
    }
}
#[cfg(test)]
mod sentry_tests {
    use super::*;
    use crate::core::NodeId;
    use crate::query::plan::{BinaryOp, ScanOp};

    /// Default statistics; the rule never reads them.
    fn test_stats() -> Statistics {
        Statistics::default()
    }

    /// A rewrite confined to the left branch of a `Union` must still surface as a
    /// changed plan whose left branch is the merged `Limit(10)`.
    #[test]
    fn test_pushdown_binary_partial_change() {
        let rule = LimitPushdown;
        let stats = test_stats();

        let stacked_limits = LogicalOp::unary(
            UnaryOp::Limit(10),
            LogicalOp::unary(
                UnaryOp::Limit(20),
                LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
            ),
        );
        let plain_scan = LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()]));
        let input_plan =
            LogicalPlan::new(LogicalOp::binary(BinaryOp::Union, stacked_limits, plain_scan));

        let merged_left = LogicalOp::unary(
            UnaryOp::Limit(10),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
        );
        let untouched_right = LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(2).unwrap()]));
        let want = LogicalPlan::new(LogicalOp::binary(
            BinaryOp::Union,
            merged_left,
            untouched_right,
        ));

        let got = rule.apply(&input_plan, &stats).unwrap();
        assert_eq!(
            got,
            Some(want),
            "Partial optimization in left branch should propagate with exact limit structure"
        );
    }

    /// An inherited bound tighter than the node's own limit shrinks that limit.
    #[test]
    fn test_pushdown_limit_neq_child_limit() {
        let rule = LimitPushdown;
        let limited_scan = LogicalOp::unary(
            UnaryOp::Limit(10),
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
        );
        let (rewritten, changed) = rule.push_down(&limited_scan, Some(5)).unwrap();
        assert!(changed, "Expected limit to shrink from 10 to 5");
        match rewritten {
            LogicalOp::Unary {
                op: UnaryOp::Limit(bound),
                ..
            } => assert_eq!(bound, 5),
            _ => panic!("Expected limit"),
        }
    }

    /// An inherited bound gives an unbounded `VectorRank` a concrete `top_k`.
    #[test]
    fn test_pushdown_vector_rank_neq() {
        let rule = LimitPushdown;
        let unbounded_rank = LogicalOp::unary(
            UnaryOp::VectorRank {
                embedding: vec![0.1f32].into(),
                top_k: None,
                property_key: None,
            },
            LogicalOp::Scan(ScanOp::NodeLookup(vec![NodeId::new(1).unwrap()])),
        );
        let (rewritten, changed) = rule.push_down(&unbounded_rank, Some(5)).unwrap();
        assert!(changed, "Vector rank should apply new top_k");
        match rewritten {
            LogicalOp::Unary {
                op: UnaryOp::VectorRank { top_k, .. },
                ..
            } => assert_eq!(top_k, Some(5)),
            _ => panic!("Expected VectorRank"),
        }
    }
}