use nodedb_types::error::{NodeDbError, NodeDbResult};
use nodedb_types::filter::MetadataFilter;
use nodedb_types::value::Value;
use crate::remote_parse::{format_vector_array, quote_identifier};
pub(super) fn build_vector_search_sql(
collection: &str,
query: &[f32],
k: usize,
filter: Option<&MetadataFilter>,
) -> NodeDbResult<String> {
let collection = quote_identifier(collection);
let where_clause = match filter {
Some(f) => format!(" WHERE {}", render_metadata_filter(f)?),
None => String::new(),
};
Ok(format!(
"SELECT * FROM {collection}{where_clause} ORDER BY vector_distance({}) LIMIT {k}",
format_vector_array(query),
))
}
pub(super) fn render_metadata_filter_public(filter: &MetadataFilter) -> NodeDbResult<String> {
render_metadata_filter(filter)
}
fn render_metadata_filter(filter: &MetadataFilter) -> NodeDbResult<String> {
match filter {
MetadataFilter::Eq { field, value } => Ok(format!(
"{} = {}",
quote_identifier(field),
render_sql_literal(value)?
)),
MetadataFilter::Ne { field, value } => Ok(format!(
"{} <> {}",
quote_identifier(field),
render_sql_literal(value)?
)),
MetadataFilter::Gt { field, value } => Ok(format!(
"{} > {}",
quote_identifier(field),
render_sql_literal(value)?
)),
MetadataFilter::Gte { field, value } => Ok(format!(
"{} >= {}",
quote_identifier(field),
render_sql_literal(value)?
)),
MetadataFilter::Lt { field, value } => Ok(format!(
"{} < {}",
quote_identifier(field),
render_sql_literal(value)?
)),
MetadataFilter::Lte { field, value } => Ok(format!(
"{} <= {}",
quote_identifier(field),
render_sql_literal(value)?
)),
MetadataFilter::In { field, values } => {
let rendered: NodeDbResult<Vec<_>> = values.iter().map(render_sql_literal).collect();
Ok(format!(
"{} IN ({})",
quote_identifier(field),
rendered?.join(", ")
))
}
MetadataFilter::NotIn { field, values } => {
let rendered: NodeDbResult<Vec<_>> = values.iter().map(render_sql_literal).collect();
Ok(format!(
"{} NOT IN ({})",
quote_identifier(field),
rendered?.join(", ")
))
}
MetadataFilter::And(parts) => {
if parts.is_empty() {
return Ok("TRUE".into());
}
let rendered: NodeDbResult<Vec<_>> = parts.iter().map(render_metadata_filter).collect();
Ok(format!("({})", rendered?.join(" AND ")))
}
MetadataFilter::Or(parts) => {
if parts.is_empty() {
return Ok("FALSE".into());
}
let rendered: NodeDbResult<Vec<_>> = parts.iter().map(render_metadata_filter).collect();
Ok(format!("({})", rendered?.join(" OR ")))
}
MetadataFilter::Not(inner) => Ok(format!("NOT ({})", render_metadata_filter(inner)?)),
_ => Err(NodeDbError::storage(
"metadata filter variant not yet supported by SQL renderer",
)),
}
}
fn render_sql_literal(v: &Value) -> NodeDbResult<String> {
match v {
Value::Null => Ok("NULL".into()),
Value::Bool(b) => Ok(if *b { "TRUE".into() } else { "FALSE".into() }),
Value::Integer(i) => Ok(i.to_string()),
Value::Float(f) => Ok(f.to_string()),
Value::String(s) => Ok(format!("'{}'", s.replace('\'', "''"))),
other => Err(NodeDbError::storage(format!(
"metadata filter value cannot be rendered as a SQL literal: {other:?}"
))),
}
}
pub(super) fn translate_params(
params: &[Value],
) -> NodeDbResult<Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>>> {
params.iter().map(translate_value).collect()
}
fn translate_value(v: &Value) -> NodeDbResult<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> {
match v {
Value::Null => Ok(Box::new(None::<i64>)),
Value::Bool(b) => Ok(Box::new(*b)),
Value::Integer(i) => Ok(Box::new(*i)),
Value::Float(f) => Ok(Box::new(*f)),
Value::String(s) => Ok(Box::new(s.clone())),
Value::Bytes(b) => Ok(Box::new(b.clone())),
other => Err(NodeDbError::storage(format!(
"execute_sql parameter cannot be translated to a pgwire bind: {other:?}"
))),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn vector_search_sql_without_filter_renders_basic_form() {
let sql =
build_vector_search_sql("docs", &[0.1, 0.2, 0.3], 5, None).expect("no-filter is fine");
assert!(sql.contains("SELECT"));
assert!(sql.contains("docs"));
assert!(sql.contains("vector_distance"));
assert!(sql.contains("ARRAY[0.1,0.2,0.3]"));
assert!(sql.contains("LIMIT 5"));
assert!(
!sql.contains(" WHERE "),
"no-filter SQL must not have WHERE; got: {sql}"
);
}
#[test]
fn vector_search_sql_renders_eq_metadata_filter() {
let filter = MetadataFilter::eq("category", Value::String("ai".into()));
let sql = build_vector_search_sql("docs", &[0.1, 0.2], 5, Some(&filter))
.expect("non-None metadata filter must be accepted client-side, not rejected");
assert!(
sql.contains("category"),
"rendered SQL must reference the filtered field name; got: {sql}"
);
assert!(
sql.contains("'ai'"),
"rendered SQL must reference the filtered value; got: {sql}"
);
assert!(
!sql.contains("not yet supported"),
"rendered SQL must not be a rejection message; got: {sql}"
);
let where_idx = sql.find(" WHERE ").expect("WHERE clause emitted");
let order_idx = sql.find(" ORDER BY ").expect("ORDER BY clause emitted");
assert!(
where_idx < order_idx,
"WHERE must precede ORDER BY; got: {sql}"
);
}
#[test]
fn vector_search_sql_renders_compound_metadata_filter() {
let filter = MetadataFilter::and(vec![
MetadataFilter::eq("category", Value::String("ai".into())),
MetadataFilter::Gt {
field: "score".into(),
value: Value::Float(0.5),
},
]);
let sql = build_vector_search_sql("docs", &[0.1], 3, Some(&filter))
.expect("compound metadata filter must be rendered, not rejected");
assert!(
sql.contains("category"),
"missing AND-leaf field; got: {sql}"
);
assert!(sql.contains("score"), "missing AND-leaf field; got: {sql}");
assert!(
sql.contains(" AND "),
"compound must render with AND; got: {sql}"
);
}
#[test]
fn vector_search_sql_renders_in_filter() {
let filter = MetadataFilter::In {
field: "tag".into(),
values: vec![
Value::String("rust".into()),
Value::String("databases".into()),
],
};
let sql = build_vector_search_sql("docs", &[0.0], 1, Some(&filter)).unwrap();
assert!(sql.contains(" IN ("));
assert!(sql.contains("'rust'"));
assert!(sql.contains("'databases'"));
}
#[test]
fn render_sql_literal_escapes_single_quotes() {
let sql = render_sql_literal(&Value::String("o'reilly".into())).unwrap();
assert_eq!(sql, "'o''reilly'");
}
#[test]
fn render_sql_literal_rejects_unsupported_variants() {
let array = Value::Array(vec![Value::Integer(1), Value::Integer(2)]);
assert!(render_sql_literal(&array).is_err());
}
#[test]
fn translate_params_passes_through_empty() {
let translated = translate_params(&[]).expect("empty translate is fine");
assert!(translated.is_empty());
}
#[test]
fn translate_params_accepts_bound_parameters() {
let params = vec![Value::String("alice".into()), Value::Integer(42)];
let translated = translate_params(¶ms)
.expect("non-empty params must translate for the pgwire driver");
assert_eq!(translated.len(), 2);
}
#[test]
fn translate_params_supports_common_value_variants() {
let params = vec![
Value::Null,
Value::Bool(true),
Value::Integer(7),
Value::Float(2.5),
Value::String("hi".into()),
Value::Bytes(vec![1, 2, 3]),
];
let translated = translate_params(¶ms).expect("common variants translate cleanly");
assert_eq!(translated.len(), 6);
}
#[test]
fn translate_params_rejects_unsupported_variants() {
let params = vec![Value::Object(Default::default())];
assert!(translate_params(¶ms).is_err());
}
}