use super::*;
use crate::datatypes::values::Value;
use crate::graph::languages::cypher::ast::*;
pub(crate) fn fuse_node_scan_top_k(query: &mut CypherQuery) {
use crate::graph::languages::cypher::ast::is_aggregate_expression;
if query.clauses.len() < 4 {
return;
}
let mut i = 0;
while i + 3 < query.clauses.len() {
if i > 0 {
i += 1;
continue;
}
let (match_idx, where_idx, return_idx, orderby_idx, limit_idx) =
if matches!(&query.clauses[i], Clause::Match(_))
&& matches!(&query.clauses[i + 1], Clause::Where(_))
&& i + 4 < query.clauses.len()
&& matches!(&query.clauses[i + 2], Clause::Return(_))
&& matches!(&query.clauses[i + 3], Clause::OrderBy(_))
&& matches!(&query.clauses[i + 4], Clause::Limit(_))
{
(i, Some(i + 1), i + 2, i + 3, i + 4)
} else if matches!(&query.clauses[i], Clause::Match(_))
&& matches!(&query.clauses[i + 1], Clause::Return(_))
&& matches!(&query.clauses[i + 2], Clause::OrderBy(_))
&& matches!(&query.clauses[i + 3], Clause::Limit(_))
{
(i, None, i + 1, i + 2, i + 3)
} else {
i += 1;
continue;
};
let is_single_node = if let Clause::Match(mc) = &query.clauses[match_idx] {
mc.patterns.len() == 1
&& mc.patterns[0].elements.len() == 1
&& matches!(
mc.patterns[0].elements[0],
crate::graph::core::pattern_matching::PatternElement::Node(_)
)
&& mc.path_assignments.is_empty()
} else {
false
};
if !is_single_node {
i += 1;
continue;
}
let return_ok = if let Clause::Return(r) = &query.clauses[return_idx] {
!r.distinct
&& !r
.items
.iter()
.any(|item| is_aggregate_expression(&item.expression))
&& !r
.items
.iter()
.any(|item| matches!(item.expression, Expression::FunctionCall { .. }))
} else {
false
};
if !return_ok {
i += 1;
continue;
}
let sort_info = if let Clause::OrderBy(o) = &query.clauses[orderby_idx] {
if o.items.len() == 1 {
Some((o.items[0].expression.clone(), !o.items[0].ascending))
} else {
None
}
} else {
None
};
let Some((sort_expr, descending)) = sort_info else {
i += 1;
continue;
};
if let Clause::Return(r) = &query.clauses[return_idx] {
let return_aliases: std::collections::HashSet<String> = r
.items
.iter()
.filter_map(|item| item.alias.clone())
.collect();
if expression_touches_vars(&sort_expr, &return_aliases) {
i += 1;
continue;
}
}
let limit_val = if let Clause::Limit(l) = &query.clauses[limit_idx] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => Some(*n as usize),
_ => None,
}
} else {
None
};
let Some(limit) = limit_val else {
i += 1;
continue;
};
query.clauses.remove(limit_idx);
query.clauses.remove(orderby_idx);
let return_clause = if let Clause::Return(r) = query.clauses.remove(return_idx) {
r
} else {
unreachable!()
};
let where_predicate = if let Some(wi) = where_idx {
if let Clause::Where(w) = query.clauses.remove(wi) {
Some(w.predicate)
} else {
None
}
} else {
None
};
let match_clause = if let Clause::Match(mc) = query.clauses.remove(match_idx) {
mc
} else {
unreachable!()
};
query.clauses.insert(
match_idx,
Clause::FusedNodeScanTopK {
match_clause,
where_predicate,
return_clause,
sort_expression: sort_expr,
descending,
limit,
},
);
i += 1;
}
}
pub(crate) fn fuse_vector_score_order_limit(query: &mut CypherQuery) {
use crate::graph::languages::cypher::ast::is_aggregate_expression;
if query.clauses.len() < 3 {
return;
}
let mut i = 0;
while i + 2 < query.clauses.len() {
let is_pattern = matches!(
(
&query.clauses[i],
&query.clauses[i + 1],
&query.clauses[i + 2]
),
(Clause::Return(_), Clause::OrderBy(_), Clause::Limit(_))
);
if !is_pattern {
i += 1;
continue;
}
let (score_idx, alias) = if let Clause::Return(r) = &query.clauses[i] {
if r.distinct
|| r.items
.iter()
.any(|item| is_aggregate_expression(&item.expression))
{
i += 1;
continue;
}
let found = r.items.iter().enumerate().find(|(_, item)| {
matches!(
&item.expression,
Expression::FunctionCall { name, .. }
if name == "vector_score"
)
});
match found {
Some((idx, item)) => {
let col = return_item_column_name(item);
(idx, col)
}
None => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
};
let descending = if let Clause::OrderBy(o) = &query.clauses[i + 1] {
if o.items.len() != 1 {
i += 1;
continue;
}
let sort_name = match &o.items[0].expression {
Expression::Variable(v) => v.clone(),
other => expression_to_column_name(other),
};
if sort_name != alias {
i += 1;
continue;
}
!o.items[0].ascending
} else {
i += 1;
continue;
};
let limit = if let Clause::Limit(l) = &query.clauses[i + 2] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => *n as usize,
_ => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
};
query.clauses.remove(i + 2); query.clauses.remove(i + 1); let return_clause = if let Clause::Return(r) = query.clauses.remove(i) {
r
} else {
unreachable!()
};
query.clauses.insert(
i,
Clause::FusedVectorScoreTopK {
return_clause,
score_item_index: score_idx,
descending,
limit,
},
);
i += 1;
}
}
pub(crate) fn return_item_column_name(item: &ReturnItem) -> String {
if let Some(ref alias) = item.alias {
alias.clone()
} else {
expression_to_column_name(&item.expression)
}
}
pub(crate) fn expression_to_column_name(expr: &Expression) -> String {
match expr {
Expression::Variable(name) => name.clone(),
Expression::PropertyAccess { variable, property } => format!("{}.{}", variable, property),
Expression::FunctionCall { name, args, .. } => {
let args_str: Vec<String> = args.iter().map(expression_to_column_name).collect();
format!("{}({})", name, args_str.join(", "))
}
_ => format!("{:?}", expr),
}
}
pub(crate) fn fuse_order_by_top_k(query: &mut CypherQuery) {
if query.clauses.len() < 3 {
return;
}
let mut i = 0;
while i + 2 < query.clauses.len() {
let is_pattern = matches!(
(
&query.clauses[i],
&query.clauses[i + 1],
&query.clauses[i + 2]
),
(Clause::Return(_), Clause::OrderBy(_), Clause::Limit(_))
);
if !is_pattern {
i += 1;
continue;
}
let (score_idx, sort_expression) = if let Clause::Return(r) = &query.clauses[i] {
if r.distinct {
i += 1;
continue;
}
if r.items.iter().any(|item| {
crate::graph::languages::cypher::ast::is_aggregate_expression(&item.expression)
}) {
i += 1;
continue;
}
if r.items
.iter()
.any(|item| matches!(item.expression, Expression::WindowFunction { .. }))
{
i += 1;
continue;
}
let order_info = if let Clause::OrderBy(o) = &query.clauses[i + 1] {
if o.items.len() != 1 {
i += 1;
continue;
}
let order_alias = match &o.items[0].expression {
Expression::Variable(v) => v.clone(),
other => expression_to_column_name(other),
};
let found = r
.items
.iter()
.enumerate()
.find(|(_, item)| return_item_column_name(item) == order_alias);
match found {
Some((idx, _)) => (idx, None), None => {
(0, Some(o.items[0].expression.clone()))
}
}
} else {
i += 1;
continue;
};
order_info
} else {
i += 1;
continue;
};
let descending = if let Clause::OrderBy(o) = &query.clauses[i + 1] {
!o.items[0].ascending
} else {
i += 1;
continue;
};
let limit = if let Clause::Limit(l) = &query.clauses[i + 2] {
match &l.count {
Expression::Literal(Value::Int64(n)) if *n > 0 => *n as usize,
_ => {
i += 1;
continue;
}
}
} else {
i += 1;
continue;
};
query.clauses.remove(i + 2); query.clauses.remove(i + 1); let return_clause = if let Clause::Return(r) = query.clauses.remove(i) {
r
} else {
unreachable!()
};
query.clauses.insert(
i,
Clause::FusedOrderByTopK {
return_clause,
score_item_index: score_idx,
descending,
limit,
sort_expression,
},
);
i += 1;
}
}