use std::sync::Arc;
use crate::index::vector::DistanceMetric;
/// Root node of a parsed query.
///
/// Every query has a [`SourceClause`]; all other clauses are optional and are
/// attached through the `with_*` builder methods, each of which overwrites any
/// value set previously.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryAst {
    /// Optional temporal qualifier.
    pub temporal: Option<TemporalClause>,
    /// Where candidate results come from (pattern match or vector-based lookup).
    pub source: SourceClause,
    /// Optional embedding-based rank clause.
    pub rank: Option<RankClause>,
    /// Optional filter predicate.
    pub where_clause: Option<WhereClause>,
    /// Optional projection.
    pub return_clause: Option<ReturnClause>,
    /// Optional ordering.
    pub order: Option<OrderClause>,
    /// Optional skip count.
    pub skip: Option<usize>,
    /// Optional result limit.
    pub limit: Option<usize>,
}
impl QueryAst {
    /// Starts a query reading from `source`, with every optional clause `None`.
    pub fn new(source: SourceClause) -> Self {
        Self {
            source,
            temporal: None,
            rank: None,
            where_clause: None,
            return_clause: None,
            order: None,
            skip: None,
            limit: None,
        }
    }
    /// Sets the temporal clause, replacing any previous one.
    #[must_use]
    pub fn with_temporal(self, temporal: TemporalClause) -> Self {
        Self {
            temporal: Some(temporal),
            ..self
        }
    }
    /// Sets the rank clause, replacing any previous one.
    #[must_use]
    pub fn with_rank(self, rank: RankClause) -> Self {
        Self {
            rank: Some(rank),
            ..self
        }
    }
    /// Sets the WHERE clause, replacing any previous one.
    #[must_use]
    pub fn with_where(self, where_clause: WhereClause) -> Self {
        Self {
            where_clause: Some(where_clause),
            ..self
        }
    }
    /// Sets the RETURN clause, replacing any previous one.
    #[must_use]
    pub fn with_return(self, return_clause: ReturnClause) -> Self {
        Self {
            return_clause: Some(return_clause),
            ..self
        }
    }
    /// Sets the ordering clause, replacing any previous one.
    #[must_use]
    pub fn with_order(self, order: OrderClause) -> Self {
        Self {
            order: Some(order),
            ..self
        }
    }
    /// Sets the skip count, replacing any previous one.
    #[must_use]
    pub fn with_skip(self, skip: usize) -> Self {
        Self {
            skip: Some(skip),
            ..self
        }
    }
    /// Sets the result limit, replacing any previous one.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }
    /// Returns `true` when a temporal clause is attached.
    pub fn is_temporal(&self) -> bool {
        self.temporal.is_some()
    }
    /// Returns `true` when a rank clause is attached, or when the source is a
    /// `VectorSearch` or `FindSimilar` rather than a pattern `Match`.
    pub fn has_vector_ops(&self) -> bool {
        if self.rank.is_some() {
            return true;
        }
        match self.source {
            SourceClause::VectorSearch { .. } | SourceClause::FindSimilar { .. } => true,
            SourceClause::Match(_) => false,
        }
    }
}
/// Temporal qualifier attached to a query via [`QueryAst::with_temporal`].
#[derive(Debug, Clone, PartialEq)]
pub enum TemporalClause {
    /// A single point in time.
    AsOf {
        /// Valid-time instant.
        valid_time: TimestampLiteral,
        /// Optional transaction-time instant; `None` when not supplied.
        transaction_time: Option<TimestampLiteral>,
    },
    /// A time window delimited by `start` and `end`.
    ///
    /// NOTE(review): nothing here checks `start <= end`, and whether the bounds
    /// are inclusive is not defined in this file — confirm against the executor.
    Between {
        /// Window start.
        start: TimestampLiteral,
        /// Window end.
        end: TimestampLiteral,
    },
}
/// A timestamp kept exactly as written: either a string or an integer.
///
/// No parsing or validation happens at this level. NOTE(review): the integer
/// unit is not fixed here; the tests use values such as `1705315200000`, which
/// look like epoch milliseconds — confirm with the parser/executor.
#[derive(Debug, Clone, PartialEq)]
pub enum TimestampLiteral {
    /// String form, e.g. `"2024-01-15"` or `"2024-01-15T10:00:00Z"` as used in the tests.
    String(String),
    /// Integer form.
    Integer(i64),
}
/// Where a query's candidate results come from.
#[derive(Debug, Clone, PartialEq)]
pub enum SourceClause {
    /// Match a list of graph patterns.
    Match(Vec<Pattern>),
    /// Vector search driven by an embedding.
    VectorSearch {
        /// Query embedding (parameter reference or inline literal).
        embedding: EmbeddingRef,
        /// Distance metric; `None` when the query did not specify one.
        metric: Option<DistanceMetric>,
        /// Requested result count (presumably a top-k bound — confirm with the executor).
        limit: usize,
    },
    /// Similarity lookup seeded by an existing node.
    FindSimilar {
        /// Node to compare against.
        node_ref: NodeRef,
        /// Requested result count (presumably a top-k bound — confirm with the executor).
        limit: usize,
    },
}
/// Embedding operand of a vector operation.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingRef {
    /// Named query parameter holding the embedding.
    Parameter(String),
    /// Inline embedding values; storing them as `Arc<[f32]>` makes cloning the
    /// AST a reference-count bump rather than a copy of the vector.
    Literal(Arc<[f32]>),
}
/// Reference to a single node.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeRef {
    /// Identifier (variable) name.
    Identifier(String),
    /// Numeric node id.
    Id(u64),
    /// Named query parameter.
    Parameter(String),
}
/// A linear path pattern made of node and relationship elements.
///
/// [`Pattern::node`] starts a pattern and [`Pattern::then`] appends one hop
/// (relationship + node), so patterns built only through these helpers always
/// alternate node / relationship and begin and end with a node.
#[derive(Debug, Clone, PartialEq)]
pub struct Pattern {
    /// Elements in path order.
    pub elements: Vec<PatternElement>,
}
impl Pattern {
    /// Starts a pattern consisting of the single node `node`.
    pub fn node(node: NodePattern) -> Self {
        let elements = vec![PatternElement::Node(node)];
        Pattern { elements }
    }
    /// Appends the hop `rel` followed by its far-end `node`.
    #[must_use]
    pub fn then(mut self, rel: RelationshipPattern, node: NodePattern) -> Self {
        let hop = [PatternElement::Relationship(rel), PatternElement::Node(node)];
        self.elements.extend(hop);
        self
    }
}
/// One element of a [`Pattern`].
#[derive(Debug, Clone, PartialEq)]
pub enum PatternElement {
    /// A node position.
    Node(NodePattern),
    /// A relationship between the neighbouring nodes.
    Relationship(RelationshipPattern),
}
/// A node position in a [`Pattern`], with optional variable, label and inline
/// property constraints.
#[derive(Debug, Clone, PartialEq)]
pub struct NodePattern {
    /// Variable bound to the node, if any.
    pub variable: Option<String>,
    /// Node label, if any.
    pub label: Option<String>,
    /// Inline property map, if any.
    pub properties: Option<PropertyMap>,
}
impl NodePattern {
    /// Anonymous node: no variable, label or properties.
    pub fn empty() -> Self {
        Self::unconstrained(None, None)
    }
    /// Node bound to variable `name`, with no label or properties.
    pub fn var(name: impl Into<String>) -> Self {
        Self::unconstrained(Some(name.into()), None)
    }
    /// Node bound to `name` and carrying `label`, with no properties.
    ///
    /// Despite the `with_` prefix this is a constructor, not a builder method.
    pub fn with_label(name: impl Into<String>, label: impl Into<String>) -> Self {
        Self::unconstrained(Some(name.into()), Some(label.into()))
    }
    /// Sets the inline property map, replacing any previous one.
    #[must_use]
    pub fn with_properties(self, properties: PropertyMap) -> Self {
        Self {
            properties: Some(properties),
            ..self
        }
    }
    /// Shared constructor: property-less node with the given variable and label.
    fn unconstrained(variable: Option<String>, label: Option<String>) -> Self {
        Self {
            variable,
            label,
            properties: None,
        }
    }
}
/// A relationship hop inside a [`Pattern`].
#[derive(Debug, Clone, PartialEq)]
pub struct RelationshipPattern {
    /// Variable bound to the relationship, if any.
    pub variable: Option<String>,
    /// Relationship type constraint, if any.
    pub rel_type: Option<String>,
    /// Direction of the hop.
    pub direction: RelationshipDirection,
    /// Depth specification; `None` when none was given.
    pub depth: Option<DepthSpec>,
}
impl RelationshipPattern {
    /// Relationship in `direction` with every other field unset.
    fn directed(direction: RelationshipDirection) -> Self {
        Self {
            variable: None,
            rel_type: None,
            direction,
            depth: None,
        }
    }
    /// Untyped, unbound [`RelationshipDirection::Outgoing`] relationship.
    pub fn outgoing() -> Self {
        Self::directed(RelationshipDirection::Outgoing)
    }
    /// Untyped, unbound [`RelationshipDirection::Incoming`] relationship.
    pub fn incoming() -> Self {
        Self::directed(RelationshipDirection::Incoming)
    }
    /// Untyped, unbound [`RelationshipDirection::Both`] relationship.
    pub fn both() -> Self {
        Self::directed(RelationshipDirection::Both)
    }
    /// Sets the relationship type, replacing any previous one.
    #[must_use]
    pub fn with_type(self, rel_type: impl Into<String>) -> Self {
        Self {
            rel_type: Some(rel_type.into()),
            ..self
        }
    }
    /// Binds the relationship to variable `var`, replacing any previous binding.
    #[must_use]
    pub fn with_variable(self, var: impl Into<String>) -> Self {
        Self {
            variable: Some(var.into()),
            ..self
        }
    }
    /// Sets the depth specification, replacing any previous one.
    #[must_use]
    pub fn with_depth(self, depth: DepthSpec) -> Self {
        Self {
            depth: Some(depth),
            ..self
        }
    }
}
/// Direction of a relationship hop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelationshipDirection {
    /// Outgoing direction.
    Outgoing,
    /// Incoming direction.
    Incoming,
    /// Either direction.
    Both,
}
/// Depth specification for a relationship hop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DepthSpec {
    /// Exact depth `n`.
    Exact(usize),
    /// Upper bound `n` on depth.
    Max(usize),
    /// Depth range; [`DepthSpec::range`] only builds it when `min <= max`,
    /// although the public variant itself can be constructed unchecked.
    Range {
        /// Lower bound.
        min: usize,
        /// Upper bound.
        max: usize,
    },
    /// Variable depth with no bounds recorded in the AST.
    Variable,
}
impl DepthSpec {
    /// Depth of exactly one; same as `DepthSpec::exact(1)`.
    pub fn one() -> Self {
        Self::exact(1)
    }
    /// Exact depth `n`.
    pub fn exact(n: usize) -> Self {
        Self::Exact(n)
    }
    /// Validated depth range `min..=max`-style bounds.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `min > max`; equal bounds are accepted.
    pub fn range(min: usize, max: usize) -> Result<Self, &'static str> {
        if min <= max {
            Ok(Self::Range { min, max })
        } else {
            Err("DepthSpec range min must be <= max")
        }
    }
}
/// Inline property constraints as ordered `(key, value)` pairs.
///
/// Being a `Vec` rather than a map, it keeps insertion order and does not
/// itself reject duplicate keys.
pub type PropertyMap = Vec<(String, PropertyValue)>;
/// Scalar value used in property maps, literal expressions and the value lists
/// of [`PredicateExpr::In`].
#[derive(Debug, Clone, PartialEq)]
pub enum PropertyValue {
    /// Null value.
    Null,
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit float; equality is the derived IEEE comparison, so `NaN != NaN`.
    Float(f64),
    /// String value.
    String(String),
    /// Reference to a named query parameter.
    Parameter(String),
}
impl From<bool> for PropertyValue {
    /// Wraps `flag` as [`PropertyValue::Bool`].
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}
impl From<i64> for PropertyValue {
    /// Wraps `number` as [`PropertyValue::Int`].
    fn from(number: i64) -> Self {
        Self::Int(number)
    }
}
impl From<f64> for PropertyValue {
    /// Wraps `number` as [`PropertyValue::Float`].
    fn from(number: f64) -> Self {
        Self::Float(number)
    }
}
impl From<String> for PropertyValue {
    /// Takes ownership of `text` as [`PropertyValue::String`].
    fn from(text: String) -> Self {
        Self::String(text)
    }
}
impl From<&str> for PropertyValue {
    /// Copies `text` into an owned [`PropertyValue::String`].
    fn from(text: &str) -> Self {
        Self::String(text.to_owned())
    }
}
/// Embedding-based rank clause; its presence makes
/// [`QueryAst::has_vector_ops`] return `true`.
#[derive(Debug, Clone, PartialEq)]
pub struct RankClause {
    /// Embedding to rank against.
    pub embedding: EmbeddingRef,
    /// Optional cut-off count; `None` when not given.
    pub top_k: Option<usize>,
}
/// Filter clause wrapping a single predicate tree.
#[derive(Debug, Clone, PartialEq)]
pub struct WhereClause {
    /// Root of the predicate tree.
    pub predicate: PredicateExpr,
}
/// Boolean predicate tree, e.g. the body of a [`WhereClause`].
#[derive(Debug, Clone, PartialEq)]
// NOTE(review): suppresses `missing_docs` for this enum; with every variant and
// field documented below it may no longer be needed — confirm before removing.
#[allow(missing_docs)]
pub enum PredicateExpr {
    /// Binary comparison `left op right`.
    Comparison {
        /// Left operand.
        left: Expression,
        /// Comparison operator.
        op: ComparisonOp,
        /// Right operand.
        right: Expression,
    },
    /// Existence test on a property.
    Exists(PropertyAccess),
    /// Null test on a property.
    IsNull(PropertyAccess),
    /// Non-null test on a property.
    IsNotNull(PropertyAccess),
    /// Substring test against a literal string.
    Contains {
        /// Property being tested.
        property: PropertyAccess,
        /// Substring to look for.
        substring: String,
    },
    /// Prefix test against a literal string.
    StartsWith {
        /// Property being tested.
        property: PropertyAccess,
        /// Required prefix.
        prefix: String,
    },
    /// Suffix test against a literal string.
    EndsWith {
        /// Property being tested.
        property: PropertyAccess,
        /// Required suffix.
        suffix: String,
    },
    /// Membership test against a list of [`PropertyValue`]s.
    In {
        /// Property being tested.
        property: PropertyAccess,
        /// Candidate values.
        values: Vec<PropertyValue>,
    },
    /// Logical AND of two predicates.
    And(Box<PredicateExpr>, Box<PredicateExpr>),
    /// Logical OR of two predicates.
    Or(Box<PredicateExpr>, Box<PredicateExpr>),
    /// Logical negation.
    Not(Box<PredicateExpr>),
    /// Explicitly grouped (parenthesised) sub-predicate, kept as its own node.
    Grouped(Box<PredicateExpr>),
}
impl PredicateExpr {
pub fn eq(left: Expression, right: Expression) -> Self {
PredicateExpr::Comparison {
left,
op: ComparisonOp::Eq,
right,
}
}
pub fn gt(left: Expression, right: Expression) -> Self {
PredicateExpr::Comparison {
left,
op: ComparisonOp::Gt,
right,
}
}
pub fn and(self, other: PredicateExpr) -> Self {
PredicateExpr::And(Box::new(self), Box::new(other))
}
pub fn or(self, other: PredicateExpr) -> Self {
PredicateExpr::Or(Box::new(self), Box::new(other))
}
pub fn negate(self) -> Self {
PredicateExpr::Not(Box::new(self))
}
}
/// Enables `!predicate` as shorthand for wrapping it in [`PredicateExpr::Not`].
impl std::ops::Not for PredicateExpr {
    type Output = PredicateExpr;

    fn not(self) -> PredicateExpr {
        let operand = Box::new(self);
        PredicateExpr::Not(operand)
    }
}
/// Binary comparison operator used by [`PredicateExpr::Comparison`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComparisonOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
/// Value expression used as comparison operands, return items and order keys.
#[derive(Debug, Clone, PartialEq)]
// NOTE(review): suppresses `missing_docs` (presumably for the `FunctionCall`
// fields); confirm whether it is still needed.
#[allow(missing_docs)]
pub enum Expression {
    /// Property access on a variable, such as `n.age` in the tests.
    Property(PropertyAccess),
    /// Bare identifier.
    Identifier(String),
    /// Literal value.
    Literal(PropertyValue),
    /// Named query parameter.
    Parameter(String),
    /// Call of the function `name` with argument list `args`.
    FunctionCall { name: String, args: Vec<Expression> },
}
impl Expression {
    /// Property access `var.prop`.
    pub fn property(var: impl Into<String>, prop: impl Into<String>) -> Self {
        Expression::Property(PropertyAccess {
            variable: var.into(),
            property: prop.into(),
        })
    }
    /// Wraps an arbitrary [`PropertyValue`] as a literal.
    pub fn literal(value: PropertyValue) -> Self {
        Expression::Literal(value)
    }
    /// Null literal.
    pub fn null() -> Self {
        Self::literal(PropertyValue::Null)
    }
    /// Boolean literal.
    pub fn boolean(value: bool) -> Self {
        Self::literal(PropertyValue::Bool(value))
    }
    /// Integer literal.
    pub fn int(value: i64) -> Self {
        Self::literal(PropertyValue::Int(value))
    }
    /// Floating-point literal.
    pub fn float(value: f64) -> Self {
        Self::literal(PropertyValue::Float(value))
    }
    /// String literal.
    pub fn string(value: impl Into<String>) -> Self {
        Self::literal(PropertyValue::String(value.into()))
    }
    /// Bare identifier.
    pub fn identifier(name: impl Into<String>) -> Self {
        Expression::Identifier(name.into())
    }
    /// Named query parameter.
    pub fn param(name: impl Into<String>) -> Self {
        Expression::Parameter(name.into())
    }
    /// Function call `name(args...)`.
    pub fn call(name: impl Into<String>, args: Vec<Expression>) -> Self {
        Expression::FunctionCall {
            name: name.into(),
            args,
        }
    }
}
/// Access of one named property on a variable.
#[derive(Debug, Clone, PartialEq)]
pub struct PropertyAccess {
    /// Variable whose property is read.
    pub variable: String,
    /// Property name.
    pub property: String,
}
impl PropertyAccess {
    /// Builds an access of `property` on `variable`.
    pub fn new(variable: impl Into<String>, property: impl Into<String>) -> Self {
        let variable = variable.into();
        let property = property.into();
        Self { variable, property }
    }
}
/// Projection of a query: the items to return plus a `distinct` flag.
#[derive(Debug, Clone, PartialEq)]
pub struct ReturnClause {
    /// Returned items, in the order given.
    pub items: Vec<ReturnItem>,
    /// Whether the projection was marked distinct.
    pub distinct: bool,
}
impl ReturnClause {
    /// Non-distinct projection of `items`.
    pub fn new(items: Vec<ReturnItem>) -> Self {
        Self {
            distinct: false,
            items,
        }
    }
    /// Marks the projection as distinct.
    #[must_use]
    pub fn distinct(self) -> Self {
        Self {
            distinct: true,
            ..self
        }
    }
}
/// One projected expression with an optional alias.
#[derive(Debug, Clone, PartialEq)]
pub struct ReturnItem {
    /// Expression to return.
    pub expression: Expression,
    /// Output alias, if any.
    pub alias: Option<String>,
}
impl ReturnItem {
    /// Unaliased item returning `expression`.
    pub fn new(expression: Expression) -> Self {
        Self {
            alias: None,
            expression,
        }
    }
    /// Sets the output alias, replacing any previous one.
    #[must_use]
    pub fn with_alias(self, alias: impl Into<String>) -> Self {
        Self {
            alias: Some(alias.into()),
            ..self
        }
    }
}
/// Ordering specification: a list of sort keys, in the order given.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderClause {
    /// Sort keys.
    pub items: Vec<OrderItem>,
}
impl OrderClause {
    /// Ordering over `items`.
    pub fn new(items: Vec<OrderItem>) -> Self {
        Self { items }
    }
}
/// A single sort key and its direction.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderItem {
    /// Expression to sort by.
    pub expression: Expression,
    /// `true` for descending, `false` for ascending.
    pub descending: bool,
}
impl OrderItem {
    /// Ascending sort key on `expression`.
    pub fn asc(expression: Expression) -> Self {
        Self {
            descending: false,
            expression,
        }
    }
    /// Descending sort key on `expression`.
    pub fn desc(expression: Expression) -> Self {
        Self {
            descending: true,
            expression,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Source matching the single-node pattern `n`, shared by several tests.
    fn single_node_source() -> SourceClause {
        SourceClause::Match(vec![Pattern::node(NodePattern::var("n"))])
    }

    #[test]
    fn test_query_ast_new() {
        let query = QueryAst::new(single_node_source());
        assert_eq!(query.source, single_node_source());
        assert!(query.temporal.is_none());
        assert!(query.rank.is_none());
        assert!(query.where_clause.is_none());
        assert!(query.return_clause.is_none());
        assert!(query.order.is_none());
        assert!(query.skip.is_none());
        assert!(query.limit.is_none());
    }

    #[test]
    fn test_query_ast_with_temporal() {
        let query = QueryAst::new(single_node_source()).with_temporal(TemporalClause::AsOf {
            valid_time: TimestampLiteral::String("2024-01-15".to_string()),
            transaction_time: None,
        });
        assert!(query.is_temporal());
    }

    #[test]
    fn test_query_ast_builders() {
        let query = QueryAst::new(single_node_source())
            .with_where(WhereClause {
                predicate: PredicateExpr::eq(Expression::property("n", "age"), Expression::int(30)),
            })
            .with_return(ReturnClause::new(vec![ReturnItem::new(Expression::property(
                "n", "name",
            ))]))
            .with_order(OrderClause::new(vec![OrderItem::asc(Expression::property(
                "n", "name",
            ))]))
            .with_skip(5)
            .with_limit(10);
        assert!(query.where_clause.is_some());
        assert!(query.return_clause.is_some());
        assert!(query.order.is_some());
        assert_eq!(query.skip, Some(5));
        assert_eq!(query.limit, Some(10));
        assert!(!query.is_temporal());
        assert!(!query.has_vector_ops());
    }

    #[test]
    fn test_query_ast_has_vector_ops() {
        let query = QueryAst::new(SourceClause::VectorSearch {
            embedding: EmbeddingRef::Parameter("emb".to_string()),
            metric: None,
            limit: 10,
        });
        assert!(query.has_vector_ops());
        let query = QueryAst::new(SourceClause::FindSimilar {
            node_ref: NodeRef::Identifier("n".to_string()),
            limit: 10,
        });
        assert!(query.has_vector_ops());
        let query = QueryAst::new(single_node_source()).with_rank(RankClause {
            embedding: EmbeddingRef::Parameter("emb".to_string()),
            top_k: Some(10),
        });
        assert!(query.has_vector_ops());
        let query = QueryAst::new(single_node_source());
        assert!(!query.has_vector_ops());
    }

    #[test]
    fn test_node_pattern_empty() {
        let node = NodePattern::empty();
        assert!(node.variable.is_none());
        assert!(node.label.is_none());
        assert!(node.properties.is_none());
    }

    #[test]
    fn test_node_pattern_var() {
        let node = NodePattern::var("n");
        assert_eq!(node.variable, Some("n".to_string()));
        assert!(node.label.is_none());
        assert!(node.properties.is_none());
    }

    #[test]
    fn test_node_pattern_with_label() {
        let node = NodePattern::with_label("n", "Person");
        assert_eq!(node.variable, Some("n".to_string()));
        assert_eq!(node.label, Some("Person".to_string()));
        assert!(node.properties.is_none());
    }

    #[test]
    fn test_node_pattern_with_properties() {
        let props = vec![(
            "name".to_string(),
            PropertyValue::String("Alice".to_string()),
        )];
        let node = NodePattern::with_label("n", "Person").with_properties(props.clone());
        assert_eq!(node.properties, Some(props));
        assert_eq!(node.label, Some("Person".to_string()));
    }

    #[test]
    fn test_relationship_pattern_outgoing() {
        let rel = RelationshipPattern::outgoing().with_type("KNOWS");
        assert_eq!(rel.direction, RelationshipDirection::Outgoing);
        assert_eq!(rel.rel_type, Some("KNOWS".to_string()));
        assert!(rel.variable.is_none());
        assert!(rel.depth.is_none());
    }

    #[test]
    fn test_relationship_pattern_incoming() {
        let rel = RelationshipPattern::incoming().with_type("FOLLOWS");
        assert_eq!(rel.direction, RelationshipDirection::Incoming);
        assert_eq!(rel.rel_type, Some("FOLLOWS".to_string()));
    }

    #[test]
    fn test_relationship_pattern_both() {
        let rel = RelationshipPattern::both();
        assert_eq!(rel.direction, RelationshipDirection::Both);
        assert!(rel.rel_type.is_none());
    }

    #[test]
    fn test_relationship_pattern_with_depth() {
        let rel = RelationshipPattern::outgoing()
            .with_type("KNOWS")
            .with_variable("r")
            .with_depth(DepthSpec::range(1, 3).unwrap());
        assert_eq!(rel.depth, Some(DepthSpec::Range { min: 1, max: 3 }));
        assert_eq!(rel.variable, Some("r".to_string()));
    }

    #[test]
    fn test_pattern_chain() {
        let pattern = Pattern::node(NodePattern::var("a"))
            .then(
                RelationshipPattern::outgoing().with_type("KNOWS"),
                NodePattern::var("b"),
            )
            .then(
                RelationshipPattern::outgoing().with_type("LIKES"),
                NodePattern::var("c"),
            );
        // node, rel, node, rel, node
        assert_eq!(pattern.elements.len(), 5);
        assert_eq!(
            pattern.elements[0],
            PatternElement::Node(NodePattern::var("a"))
        );
        assert_eq!(
            pattern.elements[1],
            PatternElement::Relationship(RelationshipPattern::outgoing().with_type("KNOWS"))
        );
        assert_eq!(
            pattern.elements[4],
            PatternElement::Node(NodePattern::var("c"))
        );
    }

    #[test]
    fn test_predicate_comparison() {
        let pred = PredicateExpr::eq(Expression::property("n", "age"), Expression::int(30));
        assert_eq!(
            pred,
            PredicateExpr::Comparison {
                left: Expression::property("n", "age"),
                op: ComparisonOp::Eq,
                right: Expression::int(30),
            }
        );
    }

    #[test]
    fn test_predicate_gt() {
        let pred = PredicateExpr::gt(Expression::property("n", "age"), Expression::int(18));
        assert_eq!(
            pred,
            PredicateExpr::Comparison {
                left: Expression::property("n", "age"),
                op: ComparisonOp::Gt,
                right: Expression::int(18),
            }
        );
    }

    #[test]
    fn test_predicate_and() {
        let p1 = PredicateExpr::eq(Expression::property("n", "age"), Expression::int(30));
        let p2 = PredicateExpr::eq(
            Expression::property("n", "name"),
            Expression::string("Alice"),
        );
        let combined = p1.clone().and(p2.clone());
        assert_eq!(combined, PredicateExpr::And(Box::new(p1), Box::new(p2)));
    }

    #[test]
    fn test_predicate_or() {
        let p1 = PredicateExpr::eq(Expression::property("n", "age"), Expression::int(30));
        let p2 = PredicateExpr::eq(Expression::property("n", "age"), Expression::int(40));
        let combined = p1.clone().or(p2.clone());
        assert_eq!(combined, PredicateExpr::Or(Box::new(p1), Box::new(p2)));
    }

    #[test]
    fn test_predicate_not() {
        let pred = PredicateExpr::eq(
            Expression::property("n", "active"),
            Expression::literal(PropertyValue::Bool(true)),
        );
        let negated = !pred.clone();
        assert_eq!(negated, PredicateExpr::Not(Box::new(pred.clone())));
        // `!` and `negate` must build the same tree.
        assert_eq!(negated, pred.negate());
    }

    #[test]
    fn test_property_access_new() {
        let access = PropertyAccess::new("n", String::from("age"));
        assert_eq!(access.variable, "n");
        assert_eq!(access.property, "age");
        assert_eq!(Expression::property("n", "age"), Expression::Property(access));
    }

    #[test]
    fn test_return_clause() {
        let items = vec![
            ReturnItem::new(Expression::property("n", "name")),
            ReturnItem::new(Expression::property("n", "age")).with_alias("years"),
        ];
        let ret = ReturnClause::new(items);
        assert!(!ret.distinct);
        assert_eq!(ret.items.len(), 2);
        assert!(ret.items[0].alias.is_none());
        assert_eq!(ret.items[1].alias, Some("years".to_string()));
    }

    #[test]
    fn test_return_clause_distinct() {
        let items = vec![ReturnItem::new(Expression::property("n", "name"))];
        let ret = ReturnClause::new(items).distinct();
        assert!(ret.distinct);
        assert_eq!(ret.items.len(), 1);
    }

    #[test]
    fn test_order_clause() {
        let items = vec![
            OrderItem::desc(Expression::property("n", "age")),
            OrderItem::asc(Expression::property("n", "name")),
        ];
        let order = OrderClause::new(items);
        assert_eq!(order.items.len(), 2);
        assert!(order.items[0].descending);
        assert!(!order.items[1].descending);
    }

    #[test]
    fn test_property_value_from() {
        assert_eq!(PropertyValue::from(true), PropertyValue::Bool(true));
        assert_eq!(PropertyValue::from(42i64), PropertyValue::Int(42));
        assert_eq!(PropertyValue::from(2.71f64), PropertyValue::Float(2.71));
        assert_eq!(
            PropertyValue::from("hello"),
            PropertyValue::String("hello".to_string())
        );
        assert_eq!(
            PropertyValue::from(String::from("world")),
            PropertyValue::String("world".to_string())
        );
    }

    #[test]
    fn test_temporal_as_of() {
        let temporal = TemporalClause::AsOf {
            valid_time: TimestampLiteral::String("2024-01-15T10:00:00Z".to_string()),
            transaction_time: Some(TimestampLiteral::Integer(1705315200000)),
        };
        // `match` + panic so a wrong variant fails instead of silently passing.
        match temporal {
            TemporalClause::AsOf {
                valid_time,
                transaction_time,
            } => {
                assert_eq!(
                    valid_time,
                    TimestampLiteral::String("2024-01-15T10:00:00Z".to_string())
                );
                assert_eq!(
                    transaction_time,
                    Some(TimestampLiteral::Integer(1705315200000))
                );
            }
            other => panic!("expected TemporalClause::AsOf, got {:?}", other),
        }
    }

    #[test]
    fn test_temporal_between() {
        let temporal = TemporalClause::Between {
            start: TimestampLiteral::String("2024-01-01".to_string()),
            end: TimestampLiteral::String("2024-12-31".to_string()),
        };
        assert!(matches!(temporal, TemporalClause::Between { .. }));
    }

    #[test]
    fn test_embedding_ref_parameter() {
        let emb = EmbeddingRef::Parameter("embedding".to_string());
        assert_eq!(emb, EmbeddingRef::Parameter("embedding".to_string()));
    }

    #[test]
    fn test_embedding_ref_literal() {
        let emb = EmbeddingRef::Literal(Arc::from([0.1f32, 0.2, 0.3].as_slice()));
        // `match` + panic so a wrong variant fails instead of silently passing.
        match emb {
            EmbeddingRef::Literal(values) => {
                assert_eq!(values.len(), 3);
                assert_eq!(&values[..], &[0.1f32, 0.2, 0.3][..]);
            }
            other => panic!("expected EmbeddingRef::Literal, got {:?}", other),
        }
    }

    #[test]
    fn test_depth_spec() {
        assert_eq!(DepthSpec::one(), DepthSpec::Exact(1));
        assert_eq!(DepthSpec::exact(3), DepthSpec::Exact(3));
        assert_eq!(
            DepthSpec::range(1, 5).unwrap(),
            DepthSpec::Range { min: 1, max: 5 }
        );
        assert_eq!(
            DepthSpec::range(2, 2).unwrap(),
            DepthSpec::Range { min: 2, max: 2 }
        );
        assert!(DepthSpec::range(5, 1).is_err());
    }
}