use std::collections::HashMap;
use std::sync::Arc;
use sqlparser::ast::{
BinaryOperator, Expr, OrderByExpr, Query as SqlQuery, SelectItem, SetExpr, Statement,
TableFactor, TableWithJoins, Value,
};
use crate::index::vector::DistanceMetric;
use crate::query::builder::Query;
use crate::query::ir::{Predicate, PredicateValue, QueryOp, SortKey};
use crate::query::plan::QueryHints;
use super::error::SqlError;
use super::match_parser;
use super::parser::SqlParser;
use super::temporal_parser;
use super::vector_parser::{self, EmbeddingRef, VectorOp};
// Value of `k` passed to the `QueryOp::VectorSearch` that is generated for a
// `VectorOp::SimilarToFilter`; a score-threshold `Filter` op is placed right
// after that search.
const SIMILAR_TO_DEFAULT_K: usize = 100;
/// A value bound to a named parameter of a [`SqlConverter`].
#[derive(Debug, Clone)]
pub enum SqlParameterValue {
/// Scalar value; substituted when an identifier in value position names the
/// parameter.
Scalar(PredicateValue),
/// Embedding vector; substituted when a vector clause references the
/// parameter.
Embedding(Arc<[f32]>),
}
/// Converts SQL text or parsed `sqlparser` statements into the internal
/// [`Query`] representation, resolving named parameters along the way.
pub struct SqlConverter {
// Parameter values keyed by name. Embedding lookups strip one leading `$`
// from the referenced name before consulting this map.
parameters: HashMap<String, SqlParameterValue>,
}
impl SqlConverter {
pub fn new() -> Self {
SqlConverter {
parameters: HashMap::new(),
}
}
pub fn with_parameters(parameters: HashMap<String, SqlParameterValue>) -> Self {
SqlConverter { parameters }
}
/// Binds `value` to the parameter `name` and returns `self` so calls can be
/// chained.
pub fn bind(&mut self, name: impl Into<String>, value: SqlParameterValue) -> &mut Self {
    let key: String = name.into();
    // `HashMap::insert` replaces any value previously stored under `key`.
    self.parameters.insert(key, value);
    self
}
pub fn convert_sql(&self, sql: &str) -> Result<Query, SqlError> {
let extracted_temporal = temporal_parser::extract_temporal_clauses(sql)?;
let extracted_match = match_parser::extract_match_clauses(&extracted_temporal.cleaned_sql)?;
let extracted_vector = vector_parser::extract_vector_clauses(&extracted_match.cleaned_sql)?;
let stmt = SqlParser::parse(&extracted_vector.cleaned_sql)?;
let mut query = self.convert(&stmt)?;
query.temporal_context = extracted_temporal.to_temporal_context()?;
for pattern in extracted_match.patterns.iter().rev() {
let insert_pos = query
.ops
.iter()
.position(|op| !matches!(op, QueryOp::ScanNodes { .. } | QueryOp::ScanEdges { .. }))
.unwrap_or(query.ops.len());
query.ops.insert(insert_pos, pattern.to_query_op());
}
for vector_op in &extracted_vector.vector_ops {
self.apply_vector_op(&mut query, vector_op)?;
}
Ok(query)
}
/// Converts an already parsed statement into a [`Query`].
///
/// # Errors
///
/// Returns [`SqlError::UnsupportedFeature`] for anything other than
/// `Statement::Query`, and propagates errors from query conversion.
pub fn convert(&self, stmt: &Statement) -> Result<Query, SqlError> {
    if let Statement::Query(select_query) = stmt {
        return self.convert_query(select_query);
    }
    let message = format!("Only SELECT queries are supported, got: {:?}", stmt);
    Err(SqlError::UnsupportedFeature(message))
}
/// Builds the op list for a plain `SELECT` query.
///
/// Ops are emitted in this order: scan (FROM), filter (WHERE), projection,
/// one sort per ORDER BY term, skip (OFFSET), limit (LIMIT). The returned
/// query has no temporal context and default hints.
///
/// NOTE(review): every ORDER BY term becomes its own `Sort` op in source
/// order; confirm the executor interprets consecutive `Sort` ops as a
/// multi-key sort.
///
/// # Errors
///
/// Returns [`SqlError::UnsupportedFeature`] when the query body is not a
/// simple `SELECT`, and propagates errors from the individual clause
/// converters.
fn convert_query(&self, query: &SqlQuery) -> Result<Query, SqlError> {
    let select = if let SetExpr::Select(select) = query.body.as_ref() {
        select
    } else {
        return Err(SqlError::UnsupportedFeature(
            "Only simple SELECT queries are supported".to_string(),
        ));
    };

    let mut ops = Vec::new();
    self.convert_from(&select.from, &mut ops)?;

    if let Some(condition) = select.selection.as_ref() {
        let predicate = self.convert_expr_to_predicate(condition)?;
        ops.push(QueryOp::Filter(predicate));
    }

    self.convert_projection(&select.projection, &mut ops)?;

    for term in &query.order_by {
        self.convert_order_by(term, &mut ops)?;
    }

    if let Some(offset) = query.offset.as_ref() {
        let skip = self.expr_to_usize(&offset.value)?;
        ops.push(QueryOp::Skip(skip));
    }

    if let Some(limit) = query.limit.as_ref() {
        let count = self.expr_to_usize(limit)?;
        ops.push(QueryOp::Limit(count));
    }

    Ok(Query {
        ops,
        temporal_context: None,
        hints: QueryHints::default(),
    })
}
/// Translates the FROM clause into a single scan op appended to `ops`.
///
/// The lower-cased table name selects the scan: `nodes` scans all nodes,
/// `edges` scans all edges, and any other name scans nodes with that name
/// as the label.
///
/// # Errors
///
/// * [`SqlError::MissingClause`] when there is no FROM clause.
/// * [`SqlError::UnsupportedFeature`] for multiple tables, JOINs, or a
///   relation that is not a plain table.
fn convert_from(
    &self,
    from: &[TableWithJoins],
    ops: &mut Vec<QueryOp>,
) -> Result<(), SqlError> {
    let source = match from {
        [] => {
            return Err(SqlError::MissingClause(
                "FROM clause is required".to_string(),
            ));
        }
        [only] => only,
        _ => {
            return Err(SqlError::UnsupportedFeature(
                "Multiple tables (joins) not yet supported".to_string(),
            ));
        }
    };

    if !source.joins.is_empty() {
        return Err(SqlError::UnsupportedFeature(
            "JOIN clauses are not yet supported. Use MATCH for graph traversal: MATCH (source)-[:EDGE_TYPE]->(target)"
                .to_string(),
        ));
    }

    let scan = match &source.relation {
        TableFactor::Table { name, .. } => {
            let table_name = name.to_string().to_lowercase();
            if table_name == "nodes" {
                QueryOp::ScanNodes { label: None }
            } else if table_name == "edges" {
                QueryOp::ScanEdges { edge_type: None }
            } else {
                QueryOp::ScanNodes {
                    label: Some(table_name),
                }
            }
        }
        _ => {
            return Err(SqlError::UnsupportedFeature(
                "Complex table expressions not supported".to_string(),
            ));
        }
    };

    ops.push(scan);
    Ok(())
}
/// Translates the SELECT list into an optional `Project` op.
///
/// No op is emitted when the list contains a wildcard (plain or qualified)
/// or when no item resolves to a column name. Items that are not simple
/// column references are skipped without an error.
///
/// NOTE(review): for `expr AS alias` the alias, not the underlying column
/// name, is what gets projected — confirm that is what `Project` expects.
fn convert_projection(
    &self,
    projection: &[SelectItem],
    ops: &mut Vec<QueryOp>,
) -> Result<(), SqlError> {
    let mut columns = Vec::new();

    for item in projection {
        let column = match item {
            // Any wildcard means "everything": nothing to project.
            SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => return Ok(()),
            SelectItem::UnnamedExpr(expr) => self.expr_to_column_name(expr),
            SelectItem::ExprWithAlias { expr, alias } => self
                .expr_to_column_name(expr)
                .map(|_| alias.value.clone()),
        };
        columns.extend(column);
    }

    if !columns.is_empty() {
        ops.push(QueryOp::Project(columns));
    }
    Ok(())
}
/// Translates one ORDER BY term into a `Sort` op appended to `ops`.
///
/// A bare identifier equal to `score` or `timestamp` (compared after
/// lower-casing) maps to the dedicated sort keys; any other identifier, and
/// the last segment of a compound identifier, sorts by that property. The
/// sort is descending only when the term is explicitly marked as not
/// ascending.
///
/// # Errors
///
/// * [`SqlError::InvalidColumn`] for an empty compound identifier.
/// * [`SqlError::UnsupportedFeature`] for any other expression.
fn convert_order_by(
    &self,
    order_by: &OrderByExpr,
    ops: &mut Vec<QueryOp>,
) -> Result<(), SqlError> {
    let key = match &order_by.expr {
        Expr::Identifier(ident) => match ident.value.to_lowercase().as_str() {
            "score" => SortKey::Score,
            "timestamp" => SortKey::Timestamp,
            _ => SortKey::Property(ident.value.clone()),
        },
        Expr::CompoundIdentifier(parts) => match parts.last() {
            Some(last) => SortKey::Property(last.value.clone()),
            None => {
                return Err(SqlError::InvalidColumn(
                    "Empty compound identifier in ORDER BY".to_string(),
                ));
            }
        },
        _ => {
            return Err(SqlError::UnsupportedFeature(
                "Complex ORDER BY expressions not yet supported. Use simple column names (e.g., ORDER BY name DESC)".to_string(),
            ));
        }
    };

    let descending = matches!(order_by.asc, Some(false));
    ops.push(QueryOp::Sort { key, descending });
    Ok(())
}
/// Converts a WHERE-clause expression into a [`Predicate`].
///
/// Supported forms: binary operators (delegated to `convert_binary_op`),
/// parenthesised expressions, `IS NULL` / `IS NOT NULL` (expressed as
/// equality / inequality with `PredicateValue::Null`), `[NOT] IN (...)` and
/// `[NOT] LIKE`.
///
/// LIKE patterns are classified only by a leading and/or trailing `%`:
///
/// * `%text%`, `%%` and a lone `%` become `Contains`; the lone `%` uses an
///   empty substring, exactly like `%%`.
/// * `text%` becomes `StartsWith`, `%text` becomes `EndsWith`.
/// * anything else becomes an equality test against the pattern text.
///
/// `_` and any `%` in the interior of the pattern are not interpreted as
/// wildcards; they stay verbatim in the resulting text. Fields of
/// `Expr::Like` other than `expr`, `pattern` and `negated` are ignored.
///
/// # Errors
///
/// Returns [`SqlError::UnsupportedFeature`] for any other expression kind
/// and propagates errors from key, value and string extraction.
fn convert_expr_to_predicate(&self, expr: &Expr) -> Result<Predicate, SqlError> {
    match expr {
        Expr::BinaryOp { left, op, right } => self.convert_binary_op(left, op, right),
        Expr::Nested(inner) => self.convert_expr_to_predicate(inner),
        Expr::IsNull(inner) => {
            let key = self.expr_to_property_key(inner)?;
            Ok(Predicate::Eq {
                key,
                value: PredicateValue::Null,
            })
        }
        Expr::IsNotNull(inner) => {
            let key = self.expr_to_property_key(inner)?;
            Ok(Predicate::Ne {
                key,
                value: PredicateValue::Null,
            })
        }
        Expr::InList {
            expr,
            list,
            negated,
        } => {
            let key = self.expr_to_property_key(expr)?;
            let values = list
                .iter()
                .map(|item| self.expr_to_value(item))
                .collect::<Result<Vec<_>, _>>()?;
            let pred = Predicate::In { key, values };
            if *negated { Ok(!pred) } else { Ok(pred) }
        }
        Expr::Like {
            expr,
            pattern,
            negated,
            ..
        } => {
            let key = self.expr_to_property_key(expr)?;
            let pattern_str = self.expr_to_string(pattern)?;
            let leading = pattern_str.starts_with('%');
            let trailing = pattern_str.ends_with('%');
            let pred = match (leading, trailing) {
                (true, true) => {
                    // For a lone `%` the single character is both the
                    // leading and the trailing wildcard: after stripping
                    // the prefix nothing is left to strip, so fall back to
                    // an empty substring instead of treating `%` literally.
                    let substring = pattern_str
                        .strip_prefix('%')
                        .and_then(|rest| rest.strip_suffix('%'))
                        .unwrap_or("")
                        .to_string();
                    Predicate::Contains { key, substring }
                }
                (false, true) => {
                    let prefix = pattern_str
                        .strip_suffix('%')
                        .unwrap_or("")
                        .to_string();
                    Predicate::StartsWith { key, prefix }
                }
                (true, false) => {
                    let suffix = pattern_str
                        .strip_prefix('%')
                        .unwrap_or("")
                        .to_string();
                    Predicate::EndsWith { key, suffix }
                }
                (false, false) => Predicate::Eq {
                    key,
                    value: PredicateValue::String(pattern_str),
                },
            };
            if *negated { Ok(!pred) } else { Ok(pred) }
        }
        _ => Err(SqlError::UnsupportedFeature(format!(
            "Expression type not supported in WHERE: {:?}",
            expr
        ))),
    }
}
/// Converts a binary expression into a [`Predicate`].
///
/// `AND` / `OR` convert both operands recursively (left first) and combine
/// them. Every other operator treats the left operand as a property key and
/// the right operand as a value; both are extracted before the operator
/// itself is checked.
///
/// # Errors
///
/// Returns [`SqlError::UnsupportedFeature`] for operators other than the
/// six comparisons and propagates errors from operand conversion.
fn convert_binary_op(
    &self,
    left: &Expr,
    op: &BinaryOperator,
    right: &Expr,
) -> Result<Predicate, SqlError> {
    match op {
        BinaryOperator::And => {
            let lhs = self.convert_expr_to_predicate(left)?;
            let rhs = self.convert_expr_to_predicate(right)?;
            Ok(lhs.and(rhs))
        }
        BinaryOperator::Or => {
            let lhs = self.convert_expr_to_predicate(left)?;
            let rhs = self.convert_expr_to_predicate(right)?;
            Ok(lhs.or(rhs))
        }
        comparison => {
            let key = self.expr_to_property_key(left)?;
            let value = self.expr_to_value(right)?;
            let predicate = match comparison {
                BinaryOperator::Eq => Predicate::Eq { key, value },
                BinaryOperator::NotEq => Predicate::Ne { key, value },
                BinaryOperator::Lt => Predicate::Lt { key, value },
                BinaryOperator::LtEq => Predicate::Lte { key, value },
                BinaryOperator::Gt => Predicate::Gt { key, value },
                BinaryOperator::GtEq => Predicate::Gte { key, value },
                unsupported => {
                    return Err(SqlError::UnsupportedFeature(format!(
                        "Operator not supported: {:?}",
                        unsupported
                    )));
                }
            };
            Ok(predicate)
        }
    }
}
/// Extracts the property key named by `expr`.
///
/// A bare identifier yields its own name; a compound identifier yields its
/// last segment.
///
/// # Errors
///
/// Returns [`SqlError::InvalidColumn`] for an empty compound identifier or
/// for any other expression kind.
fn expr_to_property_key(&self, expr: &Expr) -> Result<String, SqlError> {
    let ident = match expr {
        Expr::Identifier(ident) => ident,
        Expr::CompoundIdentifier(parts) => match parts.last() {
            Some(last) => last,
            None => {
                return Err(SqlError::InvalidColumn(
                    "Empty compound identifier".to_string(),
                ));
            }
        },
        other => {
            return Err(SqlError::InvalidColumn(format!(
                "Cannot extract property key from: {:?}",
                other
            )));
        }
    };
    Ok(ident.value.clone())
}
/// Converts an expression in value position into a [`PredicateValue`].
///
/// * Literals are converted by `value_to_predicate_value`.
/// * A bare identifier is looked up as a bound parameter name and must be
///   bound to a scalar.
/// * Unary minus negates an integer or float operand.
///
/// # Errors
///
/// * [`SqlError::ParameterError`] when an identifier names no bound
///   parameter.
/// * [`SqlError::TypeError`] when the parameter is not a scalar, when the
///   negated operand is not numeric, or when negating an integer would
///   overflow.
/// * [`SqlError::UnsupportedFeature`] for other unary operators and other
///   expression kinds.
fn expr_to_value(&self, expr: &Expr) -> Result<PredicateValue, SqlError> {
    match expr {
        Expr::Value(value) => self.value_to_predicate_value(value),
        Expr::Identifier(ident) => match self.parameters.get(&ident.value) {
            Some(SqlParameterValue::Scalar(v)) => Ok(v.clone()),
            Some(_) => Err(SqlError::TypeError(
                "Expected scalar parameter value".to_string(),
            )),
            None => Err(SqlError::ParameterError(format!(
                "Unknown parameter: {}",
                ident.value
            ))),
        },
        Expr::UnaryOp { op, expr } => match op {
            sqlparser::ast::UnaryOperator::Minus => match self.expr_to_value(expr)? {
                // The operand may come from a bound parameter, so it can be
                // `i64::MIN`, whose negation is not representable. Report
                // that as an error instead of panicking or wrapping.
                PredicateValue::Int(n) => n
                    .checked_neg()
                    .map(PredicateValue::Int)
                    .ok_or_else(|| {
                        SqlError::TypeError(format!(
                            "Integer overflow while negating value: {}",
                            n
                        ))
                    }),
                PredicateValue::Float(n) => Ok(PredicateValue::Float(-n)),
                _ => Err(SqlError::TypeError(
                    "Cannot negate non-numeric value".to_string(),
                )),
            },
            _ => Err(SqlError::UnsupportedFeature(format!(
                "Unary operator not supported: {:?}",
                op
            ))),
        },
        _ => Err(SqlError::UnsupportedFeature(format!(
            "Expression type not supported as value: {:?}",
            expr
        ))),
    }
}
/// Converts a SQL literal into a [`PredicateValue`].
///
/// Numbers are parsed as `i64` first and fall back to `f64` when that
/// fails; single- and double-quoted strings both become string values.
///
/// # Errors
///
/// * [`SqlError::TypeError`] when a number parses as neither `i64` nor
///   `f64`.
/// * [`SqlError::UnsupportedFeature`] for any other literal kind.
fn value_to_predicate_value(&self, value: &Value) -> Result<PredicateValue, SqlError> {
    let converted = match value {
        Value::Null => PredicateValue::Null,
        Value::Boolean(flag) => PredicateValue::Bool(*flag),
        Value::Number(digits, _) => digits
            .parse::<i64>()
            .map(PredicateValue::Int)
            .or_else(|_| digits.parse::<f64>().map(PredicateValue::Float))
            .map_err(|_| SqlError::TypeError(format!("Invalid number: {}", digits)))?,
        Value::SingleQuotedString(text) | Value::DoubleQuotedString(text) => {
            PredicateValue::String(text.clone())
        }
        other => {
            return Err(SqlError::UnsupportedFeature(format!(
                "Value type not supported: {:?}",
                other
            )));
        }
    };
    Ok(converted)
}
/// Returns the column name referenced by `expr`, if it is a simple column
/// reference.
///
/// A bare identifier yields its own name and a compound identifier yields
/// its last segment; every other expression (and an empty compound
/// identifier) yields `None`.
fn expr_to_column_name(&self, expr: &Expr) -> Option<String> {
    let ident = match expr {
        Expr::Identifier(ident) => Some(ident),
        Expr::CompoundIdentifier(parts) => parts.last(),
        _ => None,
    };
    ident.map(|found| found.value.clone())
}
/// Reads a numeric literal as a `usize` (used for LIMIT and OFFSET).
///
/// # Errors
///
/// Returns [`SqlError::TypeError`] when `expr` is not a numeric literal or
/// when the literal does not parse as `usize`.
fn expr_to_usize(&self, expr: &Expr) -> Result<usize, SqlError> {
    if let Expr::Value(Value::Number(digits, _)) = expr {
        return digits.parse::<usize>().map_err(|_| {
            SqlError::TypeError(format!("Expected positive integer, got: {}", digits))
        });
    }
    Err(SqlError::TypeError(format!(
        "Expected integer literal, got: {:?}",
        expr
    )))
}
/// Reads a single- or double-quoted string literal and returns its text.
///
/// # Errors
///
/// Returns [`SqlError::TypeError`] when `expr` is anything else.
fn expr_to_string(&self, expr: &Expr) -> Result<String, SqlError> {
    if let Expr::Value(Value::SingleQuotedString(text) | Value::DoubleQuotedString(text)) = expr
    {
        return Ok(text.clone());
    }
    Err(SqlError::TypeError(format!(
        "Expected string literal, got: {:?}",
        expr
    )))
}
/// Resolves an embedding reference to the actual vector.
///
/// A literal is copied into a shared slice. A parameter reference is looked
/// up by name after one leading `$` (if present) has been removed.
///
/// # Errors
///
/// * [`SqlError::TypeError`] when the parameter is bound to a non-embedding
///   value.
/// * [`SqlError::ParameterError`] when the parameter is not bound at all.
fn resolve_embedding(&self, r: &EmbeddingRef) -> Result<Arc<[f32]>, SqlError> {
    let name = match r {
        EmbeddingRef::Literal(values) => return Ok(values.clone().into()),
        EmbeddingRef::Parameter(name) => name,
    };

    let clean_name = name.strip_prefix('$').unwrap_or(name);
    match self.parameters.get(clean_name) {
        Some(SqlParameterValue::Embedding(vector)) => Ok(vector.clone()),
        Some(_) => Err(SqlError::TypeError(format!(
            "Parameter '{}' is not an embedding",
            clean_name
        ))),
        None => Err(SqlError::ParameterError(format!(
            "Unbound embedding parameter: {}",
            clean_name
        ))),
    }
}
fn apply_vector_op(&self, query: &mut Query, op: &VectorOp) -> Result<(), SqlError> {
match op {
VectorOp::KnnOrderBy {
property_key,
embedding_ref,
metric,
} => {
let embedding = self.resolve_embedding(embedding_ref)?;
let has_traversal = query.ops.iter().any(|op| {
matches!(
op,
QueryOp::TraverseOut { .. }
| QueryOp::TraverseIn { .. }
| QueryOp::TraverseBoth { .. }
)
});
let k = query.ops.iter().find_map(|op| {
if let QueryOp::Limit(n) = op {
Some(*n)
} else {
None
}
});
if has_traversal {
let pos = query
.ops
.iter()
.position(|op| matches!(op, QueryOp::Limit(_)))
.unwrap_or(query.ops.len());
query.ops.insert(
pos,
QueryOp::RankBySimilarity {
embedding,
top_k: k,
property_key: Some(property_key.clone()),
},
);
} else {
let limit = k.unwrap_or(10);
let offset = query.ops.iter().find_map(|op| {
if let QueryOp::Skip(n) = op {
Some(*n)
} else {
None
}
});
let effective_k = limit + offset.unwrap_or(0);
query.ops.retain(|op| !matches!(op, QueryOp::Limit(_)));
let pos = query
.ops
.iter()
.position(|op| {
!matches!(op, QueryOp::ScanNodes { .. } | QueryOp::ScanEdges { .. })
})
.unwrap_or(query.ops.len());
query.ops.insert(
pos,
QueryOp::VectorSearch {
embedding,
k: effective_k,
metric: *metric,
property_key: Some(property_key.clone()),
},
);
if offset.is_some() {
query.ops.push(QueryOp::Limit(limit));
}
}
}
VectorOp::KnnFunction {
property_key,
embedding_ref,
k,
..
} => {
let embedding = self.resolve_embedding(embedding_ref)?;
let label = query.ops.iter().find_map(|op| {
if let QueryOp::ScanNodes { label: Some(l) } = op {
Some(l.clone())
} else {
None
}
});
query
.ops
.retain(|op| !matches!(op, QueryOp::ScanNodes { .. }));
query.ops.insert(
0,
QueryOp::VectorSearch {
embedding,
k: *k,
metric: DistanceMetric::Cosine,
property_key: Some(property_key.clone()),
},
);
if let Some(l) = label {
query.ops.insert(1, QueryOp::FilterLabel(l));
}
}
VectorOp::SimilarToFilter {
property_key,
embedding_ref,
threshold,
} => {
let embedding = self.resolve_embedding(embedding_ref)?;
let k = SIMILAR_TO_DEFAULT_K;
let pos = query
.ops
.iter()
.position(|op| {
!matches!(op, QueryOp::ScanNodes { .. } | QueryOp::ScanEdges { .. })
})
.unwrap_or(query.ops.len());
query.ops.insert(
pos,
QueryOp::VectorSearch {
embedding,
k,
metric: DistanceMetric::Cosine,
property_key: Some(property_key.clone()),
},
);
query.ops.insert(
pos + 1,
QueryOp::Filter(Predicate::Gte {
key: "score".to_string(),
value: PredicateValue::Float(*threshold),
}),
);
}
}
Ok(())
}
}
impl Default for SqlConverter {
fn default() -> Self {
Self::new()
}
}
/// Converts `sql` into a [`Query`] using a converter with no bound
/// parameters.
///
/// # Errors
///
/// Propagates any [`SqlError`] from `SqlConverter::convert_sql`.
pub fn parse_sql(sql: &str) -> Result<Query, SqlError> {
    SqlConverter::new().convert_sql(sql)
}
/// Converts `sql` into a [`Query`], resolving parameters from `params`.
///
/// # Errors
///
/// Propagates any [`SqlError`] from `SqlConverter::convert_sql`.
pub fn parse_sql_with_params(
    sql: &str,
    params: HashMap<String, SqlParameterValue>,
) -> Result<Query, SqlError> {
    SqlConverter::with_parameters(params).convert_sql(sql)
}