use super::error::SqlError;
use crate::index::vector::DistanceMetric;
/// Outcome of running the vector-clause extraction passes over a SQL string.
#[derive(Debug)]
pub struct ExtractedVector {
    /// SQL text that remains once the recognised vector syntax has been
    /// removed or replaced.
    pub cleaned_sql: String,
    /// Vector operations found in the statement, in the order the extraction
    /// passes pushed them.
    pub vector_ops: Vec<VectorOp>,
}
/// A vector-search construct recognised inside a SQL statement.
#[derive(Debug)]
pub enum VectorOp {
    /// `ORDER BY <property> <=> <embedding>` (or `<->`).
    KnnOrderBy {
        /// Text between `ORDER BY` and the distance operator, reduced to the
        /// part after the last `.` when one is present.
        property_key: String,
        /// Embedding that follows the distance operator.
        embedding_ref: EmbeddingRef,
        /// Metric selected by the operator: `<=>` is cosine, `<->` is Euclidean.
        metric: DistanceMetric,
    },
    /// `KNN(label, property, embedding, k)`.
    KnnFunction {
        /// First argument with its quotes removed.
        #[allow(dead_code)]
        label: String,
        /// Second argument with its quotes removed.
        property_key: String,
        /// Third argument, parsed as a parameter or a literal vector.
        embedding_ref: EmbeddingRef,
        /// Fourth argument, parsed as an unsigned integer.
        k: usize,
    },
    /// `SIMILAR_TO(property, embedding, threshold)`.
    SimilarToFilter {
        /// First argument, trimmed but otherwise taken verbatim.
        property_key: String,
        /// Second argument, parsed as a parameter or a literal vector.
        embedding_ref: EmbeddingRef,
        /// Third argument, parsed as a floating-point number.
        threshold: f64,
    },
}
/// Where the query embedding of a vector operation comes from.
#[derive(Debug, Clone)]
pub enum EmbeddingRef {
    /// A named parameter, stored with its leading `$`.
    Parameter(String),
    /// A vector written inline in the SQL text, e.g. `'[0.1, 0.2]'::vector`.
    Literal(Vec<f32>),
}
pub fn extract_vector_clauses(sql: &str) -> Result<ExtractedVector, SqlError> {
let mut cleaned = sql.to_string();
let mut vector_ops = Vec::new();
extract_knn_function(&mut cleaned, &mut vector_ops)?;
extract_similar_to(&mut cleaned, &mut vector_ops)?;
extract_order_by_distance(&mut cleaned, &mut vector_ops)?;
Ok(ExtractedVector {
cleaned_sql: cleaned,
vector_ops,
})
}
/// Handles the first `KNN(label, property, embedding, k)` call found outside
/// string literals.
///
/// On success a `VectorOp::KnnFunction` is pushed to `ops` and the whole call
/// in `sql` is replaced by the lower-cased label. When no `KNN(` is present
/// the function returns `Ok(())` without touching anything.
///
/// # Errors
///
/// Returns `SqlError::ParseError` for an unmatched parenthesis, an argument
/// count other than four, an unquoted label/property, an unparsable
/// embedding, or a `k` that is not an unsigned integer.
fn extract_knn_function(sql: &mut String, ops: &mut Vec<VectorOp>) -> Result<(), SqlError> {
    let knn_pos = match find_outside_strings(sql, "KNN(") {
        Some(pos) => pos,
        None => return Ok(()),
    };

    // The needle is "KNN(", so the '(' sits three bytes after the match start.
    let open_paren = knn_pos + 3;
    let close_paren = match find_matching_paren(sql, open_paren) {
        Some(pos) => pos,
        None => {
            return Err(SqlError::ParseError(
                "Unmatched parenthesis in KNN() function".to_string(),
            ))
        }
    };

    let args = split_args(&sql[open_paren + 1..close_paren])?;
    let [label_arg, property_arg, embedding_arg, k_arg] = args.as_slice() else {
        return Err(SqlError::ParseError(format!(
            "KNN() requires 4 arguments (label, property, embedding, k), got {}",
            args.len()
        )));
    };

    // Arguments are validated left to right so the first bad one is reported.
    let label = unquote_string(label_arg)?;
    let property_key = unquote_string(property_arg)?;
    let embedding_ref = parse_embedding_ref(embedding_arg)?;
    let k_text = k_arg.trim();
    let k = match k_text.parse::<usize>() {
        Ok(value) => value,
        Err(_) => {
            return Err(SqlError::ParseError(format!(
                "Invalid k value in KNN(): '{}'",
                k_text
            )))
        }
    };

    // The call is swapped for the lower-cased label in the cleaned SQL.
    let replacement = label.to_lowercase();
    ops.push(VectorOp::KnnFunction {
        label,
        property_key,
        embedding_ref,
        k,
    });
    sql.replace_range(knn_pos..=close_paren, &replacement);
    Ok(())
}
/// Handles the first `SIMILAR_TO(property, embedding, threshold)` call found
/// outside string literals.
///
/// On success a `VectorOp::SimilarToFilter` is pushed to `ops`, the call in
/// `sql` is replaced by `TRUE`, and `cleanup_where_true` is run on the
/// result. When no `SIMILAR_TO(` is present the function returns `Ok(())`
/// without touching anything.
///
/// # Errors
///
/// Returns `SqlError::ParseError` for an unmatched parenthesis, an argument
/// count other than three, an unparsable embedding, or a threshold that is
/// not a floating-point number.
fn extract_similar_to(sql: &mut String, ops: &mut Vec<VectorOp>) -> Result<(), SqlError> {
    let st_pos = match find_outside_strings(sql, "SIMILAR_TO(") {
        Some(pos) => pos,
        None => return Ok(()),
    };

    // The needle is "SIMILAR_TO(", so the '(' sits ten bytes after the match start.
    let open_paren = st_pos + 10;
    let close_paren = match find_matching_paren(sql, open_paren) {
        Some(pos) => pos,
        None => {
            return Err(SqlError::ParseError(
                "Unmatched parenthesis in SIMILAR_TO() function".to_string(),
            ))
        }
    };

    let args = split_args(&sql[open_paren + 1..close_paren])?;
    let [property_arg, embedding_arg, threshold_arg] = args.as_slice() else {
        return Err(SqlError::ParseError(format!(
            "SIMILAR_TO() requires 3 arguments (property, embedding, threshold), got {}",
            args.len()
        )));
    };

    let property_key = property_arg.trim().to_string();
    let embedding_ref = parse_embedding_ref(embedding_arg)?;
    let threshold_text = threshold_arg.trim();
    let threshold = match threshold_text.parse::<f64>() {
        Ok(value) => value,
        Err(_) => {
            return Err(SqlError::ParseError(format!(
                "Invalid threshold in SIMILAR_TO(): '{}'",
                threshold_text
            )))
        }
    };

    ops.push(VectorOp::SimilarToFilter {
        property_key,
        embedding_ref,
        threshold,
    });

    // Leave a neutral predicate behind, then tidy up a bare `WHERE TRUE`.
    sql.replace_range(st_pos..=close_paren, "TRUE");
    cleanup_where_true(sql);
    Ok(())
}
/// Handles an `ORDER BY <property> <=> <embedding>` (cosine) or
/// `ORDER BY <property> <-> <embedding>` (Euclidean) clause.
///
/// `<=>` is searched for first; `<->` is only considered when no `<=>`
/// exists outside string literals. On success a `VectorOp::KnnOrderBy` is
/// pushed to `ops` and everything from `ORDER BY` up to the end of the
/// embedding reference is cut out of `sql`; the text after the cut has its
/// leading whitespace removed before being re-attached.
///
/// # Errors
///
/// Returns `SqlError::ParseError` when the operator is not preceded by
/// `ORDER BY`, when the property is missing, or when the embedding reference
/// cannot be parsed.
fn extract_order_by_distance(sql: &mut String, ops: &mut Vec<VectorOp>) -> Result<(), SqlError> {
    const OPERATOR_LEN: usize = 3;
    const ORDER_BY: &str = "ORDER BY";

    let text = sql.as_str();

    let located = find_outside_strings(text, "<=>")
        .map(|pos| (pos, DistanceMetric::Cosine))
        .or_else(|| find_outside_strings(text, "<->").map(|pos| (pos, DistanceMetric::Euclidean)));
    let (op_pos, metric) = match located {
        Some(hit) => hit,
        None => return Ok(()),
    };

    // The nearest ORDER BY to the left of the operator owns it.
    let order_by_pos = match rfind_ascii_ci(&text[..op_pos], ORDER_BY) {
        Some(pos) => pos,
        None => {
            return Err(SqlError::ParseError(
                "Distance operator <=> / <-> must appear in an ORDER BY clause".to_string(),
            ))
        }
    };

    let sort_expr = text[order_by_pos + ORDER_BY.len()..op_pos].trim();
    let property_key = extract_property_from_order_by(sort_expr)?;
    let (embedding_ref, ref_end) = extract_vector_ref_from_sql(text, op_pos + OPERATOR_LEN)?;

    ops.push(VectorOp::KnnOrderBy {
        property_key,
        embedding_ref,
        metric,
    });

    // Stitch the text before ORDER BY to whatever follows the embedding.
    let rewritten = [&text[..order_by_pos], text[ref_end..].trim_start()].concat();
    *sql = rewritten;
    Ok(())
}
/// Returns the text after `keyword` when `text` starts with it as a whole
/// word, compared ASCII case-insensitively.
///
/// A match is rejected when the keyword is immediately followed by an
/// identifier character, so `TRUE` does not match the start of `trueorder`.
fn strip_keyword_ci<'a>(text: &'a str, keyword: &str) -> Option<&'a str> {
    let head = text.get(..keyword.len())?;
    if !head.eq_ignore_ascii_case(keyword) {
        return None;
    }
    let tail = &text[keyword.len()..];
    match tail.chars().next() {
        Some(c) if c.is_alphanumeric() || c == '_' => None,
        _ => Some(tail),
    }
}

/// Simplifies a `WHERE TRUE` left behind after a predicate was replaced by
/// `TRUE`.
///
/// Only the first `WHERE` found outside string literals is inspected, and
/// only when `TRUE` is the first token after it:
///
/// * `WHERE TRUE` followed by nothing, `;`, `ORDER`, `LIMIT` or `OFFSET`:
///   the `WHERE TRUE` is removed and the whole statement is trimmed.
/// * `WHERE TRUE AND <rest>`: the `TRUE AND ` part is removed, leaving
///   `WHERE <rest>`.
/// * anything else: the statement is left unchanged.
///
/// Keywords are matched as whole words and ASCII case-insensitively.
fn cleanup_where_true(sql: &mut String) {
    let Some(where_pos) = find_outside_strings(sql, "WHERE") else {
        return;
    };
    let predicate_start = where_pos + "WHERE".len();
    // `get` rather than indexing: never panic on an unexpected offset.
    let Some(after_where) = sql.get(predicate_start..) else {
        return;
    };
    let true_pos = predicate_start + (after_where.len() - after_where.trim_start().len());
    let Some(after_true) = strip_keyword_ci(&sql[true_pos..], "TRUE") else {
        return;
    };
    let true_end = true_pos + "TRUE".len();
    let rest = after_true.trim_start();
    let gap = after_true.len() - rest.len();

    let ends_clause = rest.is_empty()
        || rest.starts_with(';')
        || ["ORDER", "LIMIT", "OFFSET"]
            .iter()
            .any(|keyword| strip_keyword_ci(rest, keyword).is_some());

    if ends_clause {
        // Swallow the whitespace after TRUE as well, so the text on both
        // sides stays separated by exactly the whitespace that preceded
        // WHERE. This is done locally; whitespace elsewhere in the statement
        // (including inside string literals) is never touched.
        let preceded_by_space = sql[..where_pos]
            .chars()
            .next_back()
            .map_or(false, char::is_whitespace);
        let cut_end = if preceded_by_space {
            true_end + gap
        } else {
            true_end
        };
        sql.replace_range(where_pos..cut_end, "");
        let trimmed = sql.trim().to_string();
        *sql = trimmed;
    } else if let Some(after_and) = strip_keyword_ci(rest, "AND") {
        // Drop "TRUE AND " so the remaining predicate follows WHERE directly.
        let and_end = true_end + gap + "AND".len();
        let cut_end = and_end + (after_and.len() - after_and.trim_start().len());
        sql.replace_range(true_pos..cut_end, "");
    }
}
/// Finds the last occurrence of `needle` in `haystack`, ignoring ASCII case.
///
/// Returns the byte offset of the match start. Only offsets that begin a
/// character are considered. An empty needle yields `Some(haystack.len())`.
/// `needle` is expected to be ASCII (checked in debug builds).
fn rfind_ascii_ci(haystack: &str, needle: &str) -> Option<usize> {
    debug_assert!(needle.is_ascii(), "rfind_ascii_ci: needle must be ASCII");
    if needle.is_empty() {
        return Some(haystack.len());
    }

    let pattern = needle.as_bytes();
    let text = haystack.as_bytes();

    // Walk character starts from the back; the first hit is the last match.
    haystack
        .char_indices()
        .rev()
        .map(|(start, _)| start)
        .find(|&start| {
            text.get(start..start + pattern.len())
                .map_or(false, |window| window.eq_ignore_ascii_case(pattern))
        })
}
/// Finds the first occurrence of `needle` in `sql` that is not inside a
/// single-quoted string literal, returning its byte offset.
///
/// Inside a literal, a doubled quote (`''`) is treated as an escaped quote.
/// An unterminated literal swallows the rest of the input.
///
/// Needles that start with `<` or `>` (operators) are compared exactly. All
/// other needles are compared ASCII case-insensitively. Unicode case mapping
/// is deliberately not applied: callers add ASCII byte lengths to the
/// returned offset, so a match must occupy exactly `needle.len()` bytes.
fn find_outside_strings(sql: &str, needle: &str) -> Option<usize> {
    let haystack = sql.as_bytes();
    let pattern = needle.as_bytes();
    let exact_only = needle.starts_with('<') || needle.starts_with('>');

    let mut pos = 0;
    while pos < haystack.len() {
        if haystack[pos] == b'\'' {
            // Skip the literal. The byte 0x27 never occurs inside a
            // multi-byte UTF-8 sequence, so scanning bytes is safe.
            pos += 1;
            while pos < haystack.len() {
                if haystack[pos] != b'\'' {
                    pos += 1;
                } else if haystack.get(pos + 1) == Some(&b'\'') {
                    // Escaped quote: stay inside the literal.
                    pos += 2;
                } else {
                    pos += 1;
                    break;
                }
            }
            continue;
        }

        // Only consider positions that start a character.
        if sql.is_char_boundary(pos) {
            if let Some(window) = haystack.get(pos..pos + pattern.len()) {
                let hit = if exact_only {
                    window == pattern
                } else {
                    window.eq_ignore_ascii_case(pattern)
                };
                if hit {
                    return Some(pos);
                }
            }
        }
        pos += 1;
    }
    None
}
/// Returns the byte offset of the `)` that closes the `(` at `open_pos`.
///
/// Parentheses inside single-quoted string literals are ignored; within a
/// literal a doubled quote (`''`) is treated as an escaped quote. Returns
/// `None` when `open_pos` does not hold a `(` or no matching `)` exists.
fn find_matching_paren(sql: &str, open_pos: usize) -> Option<usize> {
    let bytes = sql.as_bytes();
    if bytes.get(open_pos).copied() != Some(b'(') {
        return None;
    }

    let mut nesting = 1;
    let mut quoted = false;
    let mut cursor = open_pos + 1;

    while let Some(&byte) = bytes.get(cursor) {
        match (quoted, byte) {
            (true, b'\'') => {
                if bytes.get(cursor + 1) == Some(&b'\'') {
                    // Escaped quote: skip both bytes and stay in the literal.
                    cursor += 2;
                    continue;
                }
                quoted = false;
            }
            (true, _) => {}
            (false, b'\'') => quoted = true,
            (false, b'(') => nesting += 1,
            (false, b')') => {
                nesting -= 1;
                if nesting == 0 {
                    return Some(cursor);
                }
            }
            (false, _) => {}
        }
        cursor += 1;
    }
    None
}
/// Splits a function argument list on its top-level commas.
///
/// Commas inside single-quoted strings or nested parentheses do not split.
/// Each piece is trimmed. A trailing piece that is empty after trimming is
/// dropped, while empty pieces in the middle are kept. The function as
/// written always returns `Ok`.
fn split_args(args: &str) -> Result<Vec<String>, SqlError> {
    let mut pieces: Vec<String> = Vec::new();
    let mut buf = String::new();
    let mut nesting = 0;
    let mut quoted = false;

    for c in args.chars() {
        // A separator only counts outside quotes and outside parentheses.
        if c == ',' && !quoted && nesting == 0 {
            pieces.push(buf.trim().to_string());
            buf.clear();
            continue;
        }

        if quoted {
            quoted = c != '\'';
        } else {
            match c {
                '\'' => quoted = true,
                '(' => nesting += 1,
                ')' => nesting -= 1,
                _ => {}
            }
        }
        buf.push(c);
    }

    let tail = buf.trim();
    if !tail.is_empty() {
        pieces.push(tail.to_string());
    }
    Ok(pieces)
}
/// Removes one pair of matching quotes (`'...'` or `"..."`) from around `s`.
///
/// Surrounding whitespace is ignored. The text between the quotes is
/// returned verbatim; doubled quotes inside it are not unescaped.
///
/// # Errors
///
/// Returns `SqlError::ParseError` when `s` is not wrapped in a matching pair
/// of quotes. A lone quote character is an error as well: a single character
/// cannot be both the opening and the closing quote.
fn unquote_string(s: &str) -> Result<String, SqlError> {
    let s = s.trim();
    ['\'', '"']
        .iter()
        .find_map(|&quote| {
            // strip_prefix + strip_suffix needs two distinct quote
            // characters, so a lone quote can never slice out of range.
            s.strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote))
        })
        .map(str::to_string)
        .ok_or_else(|| SqlError::ParseError(format!("Expected quoted string, got: '{}'", s)))
}
/// Interprets a function argument as an embedding reference.
///
/// * `$name`: a parameter, stored as written.
/// * `'[...]'`, optionally followed by a `::type` cast: a literal vector.
/// * a bare identifier (alphanumerics and `_`): a parameter, stored with a
///   `$` prepended.
///
/// # Errors
///
/// Returns `SqlError::ParseError` when the argument matches none of the
/// forms above or the literal vector is malformed. An empty argument and a
/// bare `$` are rejected too, since both would otherwise produce the
/// nameless parameter `"$"`.
fn parse_embedding_ref(s: &str) -> Result<EmbeddingRef, SqlError> {
    let s = s.trim();

    if let Some(name) = s.strip_prefix('$') {
        if !name.is_empty() {
            return Ok(EmbeddingRef::Parameter(s.to_string()));
        }
    } else if s.starts_with('\'') {
        return parse_vector_literal(s).map(EmbeddingRef::Literal);
    } else if !s.is_empty() && s.chars().all(|c| c.is_alphanumeric() || c == '_') {
        // `all` is vacuously true for "", hence the explicit emptiness check.
        return Ok(EmbeddingRef::Parameter(format!("${}", s)));
    }

    Err(SqlError::ParseError(format!(
        "Cannot parse embedding reference: '{}'",
        s
    )))
}
/// Parses a vector literal such as `'[0.1, 0.2]'::vector` into its values.
///
/// Everything from the first `::` onward is discarded, one pair of
/// surrounding single quotes is removed when present, and the remaining
/// text must be a `[...]` list of comma-separated `f32` values.
///
/// # Errors
///
/// Returns `SqlError::ParseError` when the brackets are missing or an
/// element is not a valid `f32`. Degenerate input such as `'` or
/// `'::vector'` is reported as an error as well.
fn parse_vector_literal(s: &str) -> Result<Vec<f32>, SqlError> {
    let s = s.trim();
    let s = s.find("::").map_or(s, |cast_pos| &s[..cast_pos]);
    // Removing the quotes via strip_prefix + strip_suffix needs two distinct
    // quote characters, so a lone `'` is left alone and fails the bracket
    // check below instead of slicing out of range.
    let s = s
        .strip_prefix('\'')
        .and_then(|inner| inner.strip_suffix('\''))
        .unwrap_or(s);
    let s = s.trim();
    let s = s
        .strip_prefix('[')
        .and_then(|s| s.strip_suffix(']'))
        .ok_or_else(|| {
            SqlError::ParseError(format!("Vector literal must be enclosed in []: '{}'", s))
        })?;

    s.split(',')
        .map(|v| v.trim().parse::<f32>())
        .collect::<Result<Vec<f32>, _>>()
        .map_err(|e| SqlError::ParseError(format!("Invalid vector literal: {}", e)))
}
/// Derives the property name from the text between `ORDER BY` and the
/// distance operator.
///
/// The part after the last `.` is returned (trimmed) when it is non-empty;
/// otherwise the whole trimmed text is returned.
///
/// # Errors
///
/// Returns `SqlError::ParseError` when the text is empty after trimming.
fn extract_property_from_order_by(text: &str) -> Result<String, SqlError> {
    let expr = text.trim();
    if expr.is_empty() {
        return Err(SqlError::ParseError(
            "Missing property name in ORDER BY distance clause".to_string(),
        ));
    }

    // `doc.embedding` -> `embedding`; a dangling `doc.` is kept as written.
    let column = match expr.rsplit_once('.') {
        Some((_, tail)) if !tail.trim().is_empty() => tail.trim(),
        _ => expr,
    };
    Ok(column.to_string())
}
/// Reads the embedding reference that starts at byte offset `start` of `sql`
/// (after optional whitespace).
///
/// Two forms are accepted:
///
/// * `$name`: runs up to the first character that is not alphanumeric,
///   `_` or `$`.
/// * `'...'` with an optional `::type` cast directly after the closing
///   quote; the type name runs over alphanumerics and `_`.
///
/// Returns the parsed reference together with the byte offset in `sql` just
/// past it.
///
/// # Errors
///
/// Returns `SqlError::ParseError` when neither form is present, the quoted
/// literal is unterminated, or the literal cannot be parsed as a vector.
fn extract_vector_ref_from_sql(sql: &str, start: usize) -> Result<(EmbeddingRef, usize), SqlError> {
    let tail = sql[start..].trim_start();
    // Offset of `tail` within `sql`, i.e. `start` plus the skipped whitespace.
    let base = sql.len() - tail.len();
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';

    if tail.starts_with('$') {
        let len = tail
            .find(|c: char| !(is_ident(c) || c == '$'))
            .unwrap_or(tail.len());
        let reference = EmbeddingRef::Parameter(tail[..len].to_string());
        return Ok((reference, base + len));
    }

    let Some(body) = tail.strip_prefix('\'') else {
        return Err(SqlError::ParseError(
            "Expected vector literal ('[...]'::vector) or parameter ($name) after distance operator"
                .to_string(),
        ));
    };
    let Some(inner_len) = body.find('\'') else {
        return Err(SqlError::ParseError(
            "Unterminated vector literal in ORDER BY".to_string(),
        ));
    };

    // Opening quote + contents + closing quote.
    let mut end = inner_len + 2;
    if tail[end..].starts_with("::") {
        let type_start = end + 2;
        end = tail[type_start..]
            .find(|c: char| !is_ident(c))
            .map_or(tail.len(), |offset| type_start + offset);
    }

    let values = parse_vector_literal(&tail[..end])?;
    Ok((EmbeddingRef::Literal(values), base + end))
}
// Unit tests for the vector-clause extraction helpers and the public
// `extract_vector_clauses` entry point.
#[cfg(test)]
mod tests {
use super::*;
// --- parse_vector_literal ---
// A quoted literal followed by a `::vector` cast parses to its values.
#[test]
fn test_parse_vector_literal_basic() {
let values = parse_vector_literal("'[0.1, 0.2, 0.3]'::vector").unwrap();
assert_eq!(values, vec![0.1, 0.2, 0.3]);
}
// The cast suffix is optional.
#[test]
fn test_parse_vector_literal_without_cast() {
let values = parse_vector_literal("'[1.0, 2.0]'").unwrap();
assert_eq!(values, vec![1.0, 2.0]);
}
// Text without `[...]` brackets is rejected.
#[test]
fn test_parse_vector_literal_invalid() {
assert!(parse_vector_literal("'not a vector'").is_err());
}
// --- parse_embedding_ref ---
// A `$name` argument is kept as a parameter, including the `$`.
#[test]
fn test_parse_embedding_ref_parameter() {
let r = parse_embedding_ref("$query_embedding").unwrap();
assert!(matches!(r, EmbeddingRef::Parameter(ref s) if s == "$query_embedding"));
}
// A quoted literal becomes `EmbeddingRef::Literal` with one entry per value.
#[test]
fn test_parse_embedding_ref_literal() {
let r = parse_embedding_ref("'[0.1, 0.2]'::vector").unwrap();
assert!(matches!(r, EmbeddingRef::Literal(ref v) if v.len() == 2));
}
// --- ORDER BY <property> <=> / <-> <embedding> ---
// `<=>` yields a cosine KnnOrderBy; the operator is removed from the cleaned
// SQL while the trailing LIMIT is preserved.
#[test]
fn test_extract_order_by_cosine_literal() {
let result = extract_vector_clauses(
"SELECT * FROM Documents ORDER BY embedding <=> '[0.1, 0.2, 0.3]'::vector LIMIT 10",
)
.unwrap();
assert_eq!(result.vector_ops.len(), 1);
assert!(matches!(
&result.vector_ops[0],
VectorOp::KnnOrderBy {
property_key,
metric: DistanceMetric::Cosine,
..
} if property_key == "embedding"
));
assert!(!result.cleaned_sql.contains("<=>"));
assert!(result.cleaned_sql.to_uppercase().contains("LIMIT"));
}
// The right-hand side may be a `$parameter` instead of a literal.
#[test]
fn test_extract_order_by_cosine_parameter() {
let result = extract_vector_clauses(
"SELECT * FROM Documents ORDER BY embedding <=> $query_embedding LIMIT 10",
)
.unwrap();
assert_eq!(result.vector_ops.len(), 1);
assert!(matches!(
&result.vector_ops[0],
VectorOp::KnnOrderBy {
embedding_ref: EmbeddingRef::Parameter(p),
..
} if p == "$query_embedding"
));
}
// `<->` selects the Euclidean metric.
#[test]
fn test_extract_order_by_l2_distance() {
let result = extract_vector_clauses(
"SELECT * FROM Documents ORDER BY embedding <-> '[0.1, 0.2]'::vector LIMIT 5",
)
.unwrap();
assert!(matches!(
&result.vector_ops[0],
VectorOp::KnnOrderBy {
metric: DistanceMetric::Euclidean,
..
}
));
}
// A qualified `doc.embedding` is reduced to the part after the dot.
#[test]
fn test_extract_order_by_compound_property() {
let result = extract_vector_clauses(
"SELECT * FROM Documents ORDER BY doc.embedding <=> $query LIMIT 10",
)
.unwrap();
assert!(matches!(
&result.vector_ops[0],
VectorOp::KnnOrderBy {
property_key,
..
} if property_key == "embedding"
));
}
// The clause is recognised when nothing follows the embedding.
#[test]
fn test_extract_order_by_no_limit() {
let result = extract_vector_clauses(
"SELECT * FROM Documents ORDER BY embedding <=> '[0.1, 0.2]'::vector",
)
.unwrap();
assert_eq!(result.vector_ops.len(), 1);
}
// --- KNN(label, property, embedding, k) ---
// The call is recorded and replaced by the lower-cased label.
#[test]
fn test_extract_knn_function() {
let result =
extract_vector_clauses("SELECT * FROM KNN('Documents', 'embedding', $query, 10)")
.unwrap();
assert_eq!(result.vector_ops.len(), 1);
assert!(matches!(
&result.vector_ops[0],
VectorOp::KnnFunction {
label,
property_key,
k: 10,
..
} if label == "Documents" && property_key == "embedding"
));
assert!(result.cleaned_sql.contains("documents"));
assert!(!result.cleaned_sql.to_uppercase().contains("KNN("));
}
// Three arguments instead of four is a parse error.
#[test]
fn test_extract_knn_function_wrong_arg_count() {
let result = extract_vector_clauses("SELECT * FROM KNN('Documents', 'embedding', $query)");
assert!(result.is_err());
}
// --- SIMILAR_TO(property, embedding, threshold) ---
// When SIMILAR_TO is the only predicate, the whole WHERE clause disappears.
#[test]
fn test_extract_similar_to() {
let result = extract_vector_clauses(
"SELECT * FROM Documents WHERE SIMILAR_TO(embedding, $query_embedding, 0.8)",
)
.unwrap();
assert_eq!(result.vector_ops.len(), 1);
assert!(matches!(
&result.vector_ops[0],
VectorOp::SimilarToFilter {
property_key,
threshold,
..
} if property_key == "embedding" && (*threshold - 0.8).abs() < f64::EPSILON
));
assert!(!result.cleaned_sql.to_uppercase().contains("SIMILAR_TO"));
assert!(
!result.cleaned_sql.to_uppercase().contains("WHERE"),
"WHERE clause should be removed when SIMILAR_TO was the only predicate, got: '{}'",
result.cleaned_sql
);
}
// Two arguments instead of three is a parse error.
#[test]
fn test_extract_similar_to_wrong_arg_count() {
let result =
extract_vector_clauses("SELECT * FROM Documents WHERE SIMILAR_TO(embedding, $query)");
assert!(result.is_err());
}
// --- statements without vector syntax ---
// Plain SQL passes through unchanged and yields no operations.
#[test]
fn test_extract_no_vector_clauses() {
let sql = "SELECT * FROM nodes WHERE label = 'Person' ORDER BY name LIMIT 10";
let result = extract_vector_clauses(sql).unwrap();
assert!(result.vector_ops.is_empty());
assert_eq!(result.cleaned_sql, sql);
}
// A predicate joined with AND after SIMILAR_TO survives the clean-up.
#[test]
fn test_extract_similar_to_with_and_clause() {
let result = extract_vector_clauses(
"SELECT * FROM Documents WHERE SIMILAR_TO(embedding, $query, 0.8) AND price < 100",
)
.unwrap();
assert_eq!(result.vector_ops.len(), 1);
assert!(result.cleaned_sql.contains("price < 100"));
}
}