use std::collections::HashMap;
use std::sync::Arc;
use super::ast::*;
use super::error::CypherError;
use super::parser::CypherParser;
use crate::core::temporal::{TimeRange, Timestamp};
use crate::query::builder::Query;
use crate::query::ir::{Predicate, PredicateValue, QueryOp, SortKey, TraversalDepth};
use crate::query::plan::{QueryHints, TemporalContext};
/// A value that can be bound to a `$name` parameter of a Cypher query.
///
/// The scalar variants can be converted into predicate values; the
/// `Embedding` variant is rejected in predicate position and is only
/// accepted as the query vector of a vector similarity function.
#[derive(Debug, Clone, PartialEq)]
pub enum CypherParameterValue {
/// The null value.
Null,
/// A boolean.
Bool(bool),
/// A signed 64-bit integer.
Int(i64),
/// A 64-bit float.
Float(f64),
/// An owned string.
String(String),
/// A shared, immutable slice of `f32` components.
Embedding(Arc<[f32]>),
}
impl CypherParameterValue {
    /// Converts a bound parameter into the value form used by predicates.
    ///
    /// # Errors
    ///
    /// Returns [`CypherError::ParameterError`] for `Embedding`, which has no
    /// predicate-value counterpart.
    fn to_predicate_value(&self) -> Result<PredicateValue, CypherError> {
        let value = match self {
            // The only variant that cannot be represented; bail out first.
            Self::Embedding(_) => {
                return Err(CypherError::ParameterError(
                    "embedding parameters cannot be used as predicate values".to_string(),
                ));
            }
            Self::Null => PredicateValue::Null,
            Self::Bool(flag) => PredicateValue::Bool(*flag),
            Self::Int(number) => PredicateValue::Int(*number),
            Self::Float(number) => PredicateValue::Float(*number),
            Self::String(text) => PredicateValue::String(text.clone()),
        };
        Ok(value)
    }
}
/// Converts parsed Cypher statements into the internal query representation.
pub struct CypherConverter {
// Values bound to `$name` parameters, keyed by the name without the `$`.
params: HashMap<String, CypherParameterValue>,
}
impl Default for CypherConverter {
fn default() -> Self {
Self::new()
}
}
impl CypherConverter {
#[must_use]
pub fn new() -> Self {
Self {
params: HashMap::new(),
}
}
#[must_use]
pub fn with_params(params: HashMap<String, CypherParameterValue>) -> Self {
Self { params }
}
/// Lowers a parsed Cypher statement into a [`Query`].
///
/// Operations are appended in this order: pattern scans and traversals,
/// the `WHERE` filter, temporal operations, `WITH ... WHERE` filters,
/// `Distinct`, one `Count` per `count(...)` return item, vector ranking or
/// plain sorting, `Skip`, and finally `Limit`. Query hints are left at their
/// default.
///
/// # Errors
///
/// Propagates any [`CypherError`] produced while converting patterns,
/// predicates, the temporal clause, or `ORDER BY` expressions.
pub fn convert(&self, stmt: CypherStatement) -> Result<Query, CypherError> {
    let CypherStatement::Match {
        pattern,
        where_clause,
        return_clause,
        temporal,
        with_clauses,
        ..
    } = stmt;

    let mut ops = Vec::new();

    for single_pattern in &pattern {
        self.convert_pattern(single_pattern, &mut ops)?;
    }

    if let Some(condition) = where_clause.as_ref() {
        ops.push(QueryOp::Filter(self.convert_expr_to_predicate(condition)?));
    }

    // The temporal clause may append its own operation (e.g. a range scan).
    let temporal_context = temporal
        .as_ref()
        .map(|clause| self.convert_temporal(clause, &mut ops))
        .transpose()?;

    for clause in &with_clauses {
        let Some(condition) = clause.where_clause.as_ref() else {
            continue;
        };
        ops.push(QueryOp::Filter(self.convert_expr_to_predicate(condition)?));
    }

    if return_clause.distinct {
        ops.push(QueryOp::Distinct);
    }

    for item in &return_clause.items {
        let is_count_call = matches!(
            item,
            CypherReturnItem::Expression {
                expr: CypherExpr::FunctionCall { name, .. },
                ..
            } if name.eq_ignore_ascii_case("count")
        );
        if is_count_call {
            ops.push(QueryOp::Count);
        }
    }

    // Ordinary sort keys are only emitted when no vector ranking took over
    // the ORDER BY clause.
    if !self.try_emit_vector_rank(&return_clause, &mut ops)? {
        for order_item in &return_clause.order_by {
            let key = self.convert_order_item_to_sort_key(order_item)?;
            ops.push(QueryOp::Sort {
                key,
                descending: order_item.descending,
            });
        }
    }

    ops.extend(return_clause.skip.map(QueryOp::Skip));
    ops.extend(return_clause.limit.map(QueryOp::Limit));

    Ok(Query {
        ops,
        temporal_context,
        hints: QueryHints::default(),
    })
}
/// Appends the operations for one path pattern to `ops`.
///
/// Only the first node of the pattern produces a `ScanNodes` (using its
/// first label, if any); every node contributes its property filters and
/// every relationship contributes a traversal.
///
/// # Errors
///
/// Propagates errors from converting node or relationship properties.
fn convert_pattern(
    &self,
    pattern: &CypherPattern,
    ops: &mut Vec<QueryOp>,
) -> Result<(), CypherError> {
    let mut scan_pending = true;
    for element in &pattern.elements {
        let node = match element {
            CypherPatternElement::Relationship(rel) => {
                self.convert_relationship(rel, ops)?;
                continue;
            }
            CypherPatternElement::Node(node) => node,
        };
        // `take` yields the old flag and clears it, so the scan is emitted once.
        if std::mem::take(&mut scan_pending) {
            ops.push(QueryOp::ScanNodes {
                label: node.labels.first().cloned(),
            });
        }
        self.convert_node_properties(node, ops)?;
    }
    Ok(())
}
/// Appends one equality filter per inline property of `node`.
///
/// # Errors
///
/// Stops at the first property whose value cannot be converted; filters
/// pushed before that point stay in `ops`.
fn convert_node_properties(
    &self,
    node: &CypherNodePattern,
    ops: &mut Vec<QueryOp>,
) -> Result<(), CypherError> {
    node.properties
        .iter()
        .try_for_each(|(name, literal)| -> Result<(), CypherError> {
            let value = self.convert_value_to_predicate_value(literal)?;
            ops.push(QueryOp::Filter(Predicate::Eq {
                key: name.clone(),
                value,
            }));
            Ok(())
        })
}
/// Appends the traversal for `rel`, followed by its property filters.
///
/// The traversal uses the first relationship type (if any) as its label and
/// the direction of the pattern to pick the traversal operation.
///
/// # Errors
///
/// Propagates errors from converting relationship property values.
fn convert_relationship(
    &self,
    rel: &CypherRelPattern,
    ops: &mut Vec<QueryOp>,
) -> Result<(), CypherError> {
    let label = rel.rel_types.first().cloned();
    let depth = self.convert_depth(&rel.depth);
    ops.push(match rel.direction {
        CypherDirection::Both => QueryOp::TraverseBoth { label, depth },
        CypherDirection::Incoming => QueryOp::TraverseIn { label, depth },
        CypherDirection::Outgoing => QueryOp::TraverseOut { label, depth },
    });

    for (name, literal) in &rel.properties {
        let value = self.convert_value_to_predicate_value(literal)?;
        ops.push(QueryOp::Filter(Predicate::Eq {
            key: name.clone(),
            value,
        }));
    }
    Ok(())
}
/// Maps an optional Cypher depth specifier onto a [`TraversalDepth`].
///
/// A missing specifier means exactly one hop; a lower bound without an
/// upper bound becomes a range capped at `usize::MAX`.
fn convert_depth(&self, depth: &Option<CypherDepth>) -> TraversalDepth {
    let Some(spec) = depth else {
        return TraversalDepth::Exact(1);
    };
    match spec {
        CypherDepth::Exact(hops) => TraversalDepth::Exact(*hops),
        CypherDepth::Max(limit) => TraversalDepth::Max(*limit),
        CypherDepth::Unbounded => TraversalDepth::Variable,
        CypherDepth::Min(floor) => TraversalDepth::Range {
            min: *floor,
            max: usize::MAX,
        },
        CypherDepth::Range { min, max } => TraversalDepth::Range {
            min: *min,
            max: *max,
        },
    }
}
/// Converts a boolean Cypher expression into a [`Predicate`].
///
/// Handles comparisons, `AND`/`OR`/`NOT`, `IS NULL` (equality with null),
/// `IS NOT NULL` (inequality with null), `IN`, `CONTAINS`, `STARTS WITH`,
/// `ENDS WITH`, parenthesised expressions, and boolean literals.
///
/// # Errors
///
/// Returns [`CypherError::UnsupportedFeature`] for any other expression and
/// propagates errors from nested conversions.
fn convert_expr_to_predicate(&self, expr: &CypherExpr) -> Result<Predicate, CypherError> {
    let predicate = match expr {
        // These arms delegate and already yield a complete result.
        CypherExpr::Grouped(inner) => return self.convert_expr_to_predicate(inner),
        CypherExpr::Comparison { left, op, right } => {
            return self.convert_comparison(left, *op, right);
        }
        CypherExpr::Value(CypherValue::Bool(flag)) => {
            if *flag {
                Predicate::True
            } else {
                Predicate::False
            }
        }
        CypherExpr::And(lhs, rhs) => Predicate::And(vec![
            self.convert_expr_to_predicate(lhs)?,
            self.convert_expr_to_predicate(rhs)?,
        ]),
        CypherExpr::Or(lhs, rhs) => Predicate::Or(vec![
            self.convert_expr_to_predicate(lhs)?,
            self.convert_expr_to_predicate(rhs)?,
        ]),
        CypherExpr::Not(inner) => {
            Predicate::Not(Box::new(self.convert_expr_to_predicate(inner)?))
        }
        CypherExpr::IsNull(inner) => Predicate::Eq {
            key: self.extract_property_key(inner)?,
            value: PredicateValue::Null,
        },
        CypherExpr::IsNotNull(inner) => Predicate::Ne {
            key: self.extract_property_key(inner)?,
            value: PredicateValue::Null,
        },
        CypherExpr::In {
            expr: target,
            values,
        } => {
            let key = self.extract_property_key(target)?;
            let values = values
                .iter()
                .map(|candidate| self.convert_expr_to_predicate_value(candidate))
                .collect::<Result<Vec<_>, _>>()?;
            Predicate::In { key, values }
        }
        CypherExpr::Contains {
            expr: target,
            substring,
        } => Predicate::Contains {
            key: self.extract_property_key(target)?,
            substring: substring.clone(),
        },
        CypherExpr::StartsWith {
            expr: target,
            prefix,
        } => Predicate::StartsWith {
            key: self.extract_property_key(target)?,
            prefix: prefix.clone(),
        },
        CypherExpr::EndsWith {
            expr: target,
            suffix,
        } => Predicate::EndsWith {
            key: self.extract_property_key(target)?,
            suffix: suffix.clone(),
        },
        _ => {
            return Err(CypherError::UnsupportedFeature(format!(
                "expression cannot be converted to predicate: {expr:?}"
            )));
        }
    };
    Ok(predicate)
}
fn convert_comparison(
&self,
left: &CypherExpr,
op: CypherCompOp,
right: &CypherExpr,
) -> Result<Predicate, CypherError> {
let key = self.extract_property_key(left)?;
let value = self.convert_expr_to_predicate_value(right)?;
Ok(match op {
CypherCompOp::Eq => Predicate::Eq { key, value },
CypherCompOp::Ne => Predicate::Ne { key, value },
CypherCompOp::Gt => Predicate::Gt { key, value },
CypherCompOp::Ge => Predicate::Gte { key, value },
CypherCompOp::Lt => Predicate::Lt { key, value },
CypherCompOp::Le => Predicate::Lte { key, value },
})
}
/// Returns the predicate key named by `expr`.
///
/// A property access yields the property name; a bare variable yields the
/// variable name.
///
/// # Errors
///
/// Returns [`CypherError::SemanticError`] for any other expression.
fn extract_property_key(&self, expr: &CypherExpr) -> Result<String, CypherError> {
    match expr {
        CypherExpr::Property { property: key, .. } | CypherExpr::Variable(key) => {
            Ok(key.clone())
        }
        _ => Err(CypherError::SemanticError(format!(
            "expected property access or variable, got: {expr:?}"
        ))),
    }
}
/// Converts an expression that must be a literal into a [`PredicateValue`].
///
/// # Errors
///
/// Returns [`CypherError::SemanticError`] when `expr` is not a value
/// expression, and propagates errors from converting the value itself.
fn convert_expr_to_predicate_value(
    &self,
    expr: &CypherExpr,
) -> Result<PredicateValue, CypherError> {
    let CypherExpr::Value(literal) = expr else {
        return Err(CypherError::SemanticError(format!(
            "expected literal value, got: {expr:?}"
        )));
    };
    self.convert_value_to_predicate_value(literal)
}
/// Converts a Cypher literal or parameter reference into a [`PredicateValue`].
///
/// Parameters are looked up in the converter's parameter table and then
/// converted like any other scalar.
///
/// # Errors
///
/// * [`CypherError::ParameterError`] when a parameter is unbound or bound
///   to a value with no predicate form.
/// * [`CypherError::UnsupportedFeature`] for vector literals.
fn convert_value_to_predicate_value(
    &self,
    value: &CypherValue,
) -> Result<PredicateValue, CypherError> {
    let converted = match value {
        CypherValue::Parameter(name) => {
            return match self.params.get(name) {
                Some(bound) => bound.to_predicate_value(),
                None => Err(CypherError::ParameterError(format!(
                    "unbound parameter: ${name}"
                ))),
            };
        }
        CypherValue::Vector(_) => {
            return Err(CypherError::UnsupportedFeature(
                "vector literals in predicate position".to_string(),
            ));
        }
        CypherValue::Null => PredicateValue::Null,
        CypherValue::Bool(flag) => PredicateValue::Bool(*flag),
        CypherValue::Int(number) => PredicateValue::Int(*number),
        CypherValue::Float(number) => PredicateValue::Float(*number),
        CypherValue::String(text) => PredicateValue::String(text.clone()),
    };
    Ok(converted)
}
/// Derives the [`SortKey`] for one `ORDER BY` item.
///
/// Property accesses and bare variables both sort by a property of that
/// name.
///
/// # Errors
///
/// Returns [`CypherError::UnsupportedFeature`] for any other expression.
fn convert_order_item_to_sort_key(
    &self,
    item: &CypherOrderItem,
) -> Result<SortKey, CypherError> {
    match &item.expr {
        CypherExpr::Property { property: key, .. } | CypherExpr::Variable(key) => {
            Ok(SortKey::Property(key.clone()))
        }
        _ => Err(CypherError::UnsupportedFeature(format!(
            "unsupported expression in ORDER BY clause: {:?}",
            item.expr
        ))),
    }
}
/// Builds the [`TemporalContext`] for a temporal clause.
///
/// The point-in-time forms only produce a context. `Between` additionally
/// pushes a `QueryOp::Between` onto `ops` and returns a context with the
/// valid-time range set and history included.
///
/// # Errors
///
/// * Propagates timestamp parsing errors.
/// * [`CypherError::InvalidTemporalClause`] when the `Between` bounds do
///   not form a valid range.
fn convert_temporal(
    &self,
    temporal: &CypherTemporal,
    ops: &mut Vec<QueryOp>,
) -> Result<TemporalContext, CypherError> {
    let context = match temporal {
        CypherTemporal::AsOfTimestamp(raw) => {
            // A single timestamp pins both time axes to the same instant.
            let instant = parse_timestamp_string(raw)?;
            TemporalContext::as_of(instant, instant)
        }
        CypherTemporal::AsOfValidTime(raw) => {
            TemporalContext::as_of_valid_time(parse_timestamp_string(raw)?)
        }
        CypherTemporal::AsOfSystemTime(raw) => {
            TemporalContext::as_of_transaction_time(parse_timestamp_string(raw)?)
        }
        CypherTemporal::BiTemporal {
            valid_time,
            system_time,
        } => {
            let valid = parse_timestamp_string(valid_time)?;
            let system = parse_timestamp_string(system_time)?;
            TemporalContext::as_of(valid, system)
        }
        CypherTemporal::Between { start, end } => {
            let lower = parse_timestamp_string(start)?;
            let upper = parse_timestamp_string(end)?;
            let time_range = match TimeRange::new(lower, upper) {
                Ok(range) => range,
                Err(e) => {
                    return Err(CypherError::InvalidTemporalClause(format!(
                        "invalid time range: {e}"
                    )));
                }
            };
            ops.push(QueryOp::Between { time_range });
            TemporalContext {
                valid_time_between: Some(time_range),
                include_history: true,
                ..Default::default()
            }
        }
    };
    Ok(context)
}
/// Emits a `RankBySimilarity` operation when `ORDER BY` is a vector ranking.
///
/// This applies only when there is exactly one `ORDER BY` item and it is
/// either a direct call to a vector function, or a variable naming the alias
/// of a vector-function call in the `RETURN` items (first match wins). The
/// `LIMIT`, if any, is passed along as `top_k`.
///
/// Returns `Ok(true)` when the operation was pushed, `Ok(false)` when the
/// clause is not a vector ranking and nothing was pushed.
///
/// NOTE(review): the item's `descending` flag is not consulted on this path.
///
/// # Errors
///
/// Propagates errors from validating the vector function's arguments.
fn try_emit_vector_rank(
    &self,
    return_clause: &CypherReturn,
    ops: &mut Vec<QueryOp>,
) -> Result<bool, CypherError> {
    let [order_item] = return_clause.order_by.as_slice() else {
        return Ok(false);
    };

    // Resolve the argument list of the vector call, either inline or via alias.
    let vector_args = match &order_item.expr {
        CypherExpr::FunctionCall { name, args } if is_vector_function(name) => Some(args),
        CypherExpr::Variable(alias_name) => {
            return_clause.items.iter().find_map(|item| match item {
                CypherReturnItem::Expression {
                    expr: CypherExpr::FunctionCall { name, args },
                    alias: Some(alias),
                } if alias == alias_name && is_vector_function(name) => Some(args),
                _ => None,
            })
        }
        _ => None,
    };

    let Some(args) = vector_args else {
        return Ok(false);
    };

    let (property_key, embedding) = self.extract_vector_args(args)?;
    ops.push(QueryOp::RankBySimilarity {
        embedding,
        top_k: return_clause.limit,
        property_key: Some(property_key),
    });
    Ok(true)
}
/// Validates the arguments of a vector function call.
///
/// Expects exactly two arguments: a property access naming the stored
/// embedding, and the query vector given either as a vector literal or as a
/// parameter bound to an `Embedding`. Returns the property name and the
/// query vector.
///
/// # Errors
///
/// * [`CypherError::SemanticError`] for a wrong argument count or argument
///   shape.
/// * [`CypherError::ParameterError`] when the parameter is unbound or is
///   not an `Embedding`.
fn extract_vector_args(
    &self,
    args: &[CypherExpr],
) -> Result<(String, Arc<[f32]>), CypherError> {
    let [target, query] = args else {
        return Err(CypherError::SemanticError(
            "vector similarity function expects exactly 2 arguments".to_string(),
        ));
    };

    let CypherExpr::Property { property, .. } = target else {
        return Err(CypherError::SemanticError(
            "first argument to vector function must be a property access (e.g., d.embedding)"
                .to_string(),
        ));
    };

    let embedding = match query {
        CypherExpr::Value(CypherValue::Vector(v)) => Arc::clone(v),
        CypherExpr::Value(CypherValue::Parameter(param_name)) => {
            match self.params.get(param_name) {
                Some(CypherParameterValue::Embedding(emb)) => Arc::clone(emb),
                Some(param) => {
                    return Err(CypherError::ParameterError(format!(
                        "parameter ${param_name} must be an Embedding, got: {param:?}"
                    )));
                }
                None => {
                    return Err(CypherError::ParameterError(format!(
                        "unbound parameter: ${param_name}"
                    )));
                }
            }
        }
        _ => {
            return Err(CypherError::SemanticError(
                "second argument to vector function must be a parameter or vector literal"
                    .to_string(),
            ));
        }
    };

    Ok((property.clone(), embedding))
}
}
/// Parses the textual timestamp of a temporal clause.
///
/// Surrounding whitespace and single or double quotes are stripped, as is
/// whitespace just inside the quotes. The remaining text is tried, in order,
/// as:
///
/// 1. a plain integer, passed straight to `Timestamp::from` (treated as
///    microseconds, like the `timestamp_micros()` values used below);
/// 2. a date-time with an explicit offset, parsed as `DateTime<Utc>`;
/// 3. a date-time without an offset, interpreted as UTC;
/// 4. a bare date, interpreted as midnight UTC of that day.
///
/// # Errors
///
/// Returns [`CypherError::InvalidTimestamp`] carrying the original input
/// when none of the forms match.
fn parse_timestamp_string(s: &str) -> Result<Timestamp, CypherError> {
    // The final `trim` handles input such as `' 2024-01-01 '`, where the
    // padding sits inside the quotes and would otherwise defeat every parser.
    let text = s.trim().trim_matches('\'').trim_matches('"').trim();

    if let Ok(micros) = text.parse::<i64>() {
        return Ok(Timestamp::from(micros));
    }
    if let Ok(dt) = text.parse::<chrono::DateTime<chrono::Utc>>() {
        return Ok(Timestamp::from(dt.timestamp_micros()));
    }
    // Zone-less date-times are read as UTC, matching how bare dates are
    // handled below.
    if let Ok(naive) = text.parse::<chrono::NaiveDateTime>() {
        return Ok(Timestamp::from(naive.and_utc().timestamp_micros()));
    }
    let midnight = text
        .parse::<chrono::NaiveDate>()
        .ok()
        .and_then(|date| date.and_hms_opt(0, 0, 0));
    if let Some(dt) = midnight {
        return Ok(Timestamp::from(dt.and_utc().timestamp_micros()));
    }
    Err(CypherError::InvalidTimestamp(s.to_string()))
}
/// Reports whether `name` is one of the recognised vector similarity
/// functions.
///
/// The comparison ignores ASCII case, so `VECTOR.COSINE` is recognised just
/// like `vector.cosine`; this matches the case-insensitive handling of
/// `count` elsewhere in the converter.
fn is_vector_function(name: &str) -> bool {
    const VECTOR_FUNCTIONS: [&str; 3] =
        ["vector.similarity", "vector.cosine", "vector.euclidean"];
    VECTOR_FUNCTIONS
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
/// Parses `input` as Cypher and converts it into a [`Query`] without any
/// bound parameters.
///
/// # Errors
///
/// Returns the parser's error for malformed input, or the converter's error
/// when the statement cannot be lowered.
pub fn parse_cypher(input: &str) -> Result<Query, CypherError> {
    let converter = CypherConverter::new();
    converter.convert(CypherParser::parse(input)?)
}
/// Parses `input` as Cypher and converts it into a [`Query`], resolving
/// `$name` parameters from `params`.
///
/// # Errors
///
/// Returns the parser's error for malformed input, or the converter's error
/// when the statement cannot be lowered (including unbound or mistyped
/// parameters).
pub fn parse_cypher_with_params(
    input: &str,
    params: HashMap<String, CypherParameterValue>,
) -> Result<Query, CypherError> {
    let statement = CypherParser::parse(input)?;
    let converter = CypherConverter::with_params(params);
    converter.convert(statement)
}