use serde::{Deserialize, Serialize};
use crate::schema::{DistanceMetric, VectorConfig};
/// A value bound to a placeholder in SQL produced by [`VectorQueryBuilder`].
///
/// `PartialEq` is derived so generated parameter lists can be compared directly; `Eq` is not
/// possible because of the `f32` payload.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum VectorParam {
    /// An embedding; the builder casts its placeholder with `::vector`.
    Vector(Vec<f32>),
    /// A 64-bit integer (the builder uses it for `LIMIT` / `OFFSET`).
    Int(i64),
    /// A text value.
    String(String),
    /// A JSON document; rendered with a `::jsonb` cast by `to_sql_literal`.
    Json(serde_json::Value),
}
impl VectorParam {
#[must_use]
pub fn to_sql_literal(&self) -> String {
match self {
VectorParam::Vector(v) => {
let values: Vec<String> = v.iter().map(std::string::ToString::to_string).collect();
format!("'[{}]'::vector", values.join(","))
},
VectorParam::Int(i) => i.to_string(),
VectorParam::String(s) => format!("'{}'", s.replace('\'', "''")),
VectorParam::Json(j) => format!("'{j}'::jsonb"),
}
}
}
/// Parameters for a nearest-neighbour search built by
/// [`VectorQueryBuilder::similarity_search`].
///
/// Construct with [`VectorSearchQuery::new`] and the `with_*` methods. The string fields are
/// interpolated into the SQL text as-is by `similarity_search`, so they should not carry
/// untrusted input.
#[derive(Debug, Clone)]
pub struct VectorSearchQuery {
    /// Table name placed after `FROM`.
    pub table: String,
    /// Column holding the stored embeddings; compared against the query embedding in
    /// `ORDER BY` (and in the `distance` column when `include_distance` is set).
    pub embedding_column: String,
    /// Columns to select; an empty list selects `*`.
    pub select_columns: Vec<String>,
    /// Metric whose `operator()` supplies the distance operator.
    pub distance_metric: DistanceMetric,
    /// Maximum number of rows, bound as the `LIMIT` parameter.
    pub limit: u32,
    /// Raw SQL condition inserted after `WHERE`; not escaped or parameterized.
    pub where_clause: Option<String>,
    /// NOTE(review): not read by any code in this file (`similarity_search` always orders by
    /// distance) — confirm whether it is used elsewhere or can be removed.
    pub order_by: Option<String>,
    /// When true, the SELECT list also contains `(<distance expression>) AS distance`.
    pub include_distance: bool,
    /// Rows to skip, bound as the `OFFSET` parameter when set.
    pub offset: Option<u32>,
}
impl Default for VectorSearchQuery {
    /// Returns a query template with:
    /// - an empty table name
    /// - the `embedding` column, with all columns selected
    /// - cosine distance and a limit of 10
    /// - no `where_clause`, `order_by`, distance column, or `offset`
    fn default() -> Self {
        let embedding_column = String::from("embedding");
        let distance_metric = DistanceMetric::Cosine;
        Self {
            limit: 10,
            include_distance: false,
            embedding_column,
            distance_metric,
            table: String::default(),
            select_columns: vec![],
            where_clause: None,
            order_by: None,
            offset: None,
        }
    }
}
impl VectorSearchQuery {
    /// Starts a search against `table`; every other setting comes from [`Default`].
    #[must_use]
    pub fn new(table: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            ..Default::default()
        }
    }

    /// Sets the column that holds the stored embeddings.
    #[must_use]
    pub fn with_embedding_column(mut self, column: impl Into<String>) -> Self {
        self.embedding_column = column.into();
        self
    }

    /// Replaces the list of selected columns; an empty list selects `*`.
    #[must_use]
    pub fn with_select_columns(mut self, columns: Vec<String>) -> Self {
        self.select_columns = columns;
        self
    }

    /// Sets the distance metric used to compare embeddings.
    #[must_use]
    pub const fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
        self.distance_metric = metric;
        self
    }

    /// Sets the maximum number of rows returned.
    #[must_use]
    pub const fn with_limit(mut self, limit: u32) -> Self {
        self.limit = limit;
        self
    }

    /// Sets a raw SQL filter placed after `WHERE`.
    ///
    /// The clause is inserted verbatim (not escaped or parameterized), so it must not be
    /// built from untrusted input.
    #[must_use]
    pub fn with_where(mut self, clause: impl Into<String>) -> Self {
        self.where_clause = Some(clause.into());
        self
    }

    /// Adds the computed distance to the result set as a `distance` column.
    #[must_use]
    pub const fn with_distance_score(mut self) -> Self {
        self.include_distance = true;
        self
    }

    /// Sets the number of rows to skip before returning results.
    #[must_use]
    pub const fn with_offset(mut self, offset: u32) -> Self {
        self.offset = Some(offset);
        self
    }
}
/// Parameters for an `INSERT` built by [`VectorQueryBuilder::insert_one`] or
/// [`VectorQueryBuilder::insert_batch`].
///
/// Construct with [`VectorInsertQuery::new`] and the `with_*` methods. Table and column names
/// are interpolated into the SQL text as-is.
#[derive(Debug, Clone)]
pub struct VectorInsertQuery {
    /// Target table name, placed after `INSERT INTO`.
    pub table: String,
    /// Insert column list; values passed to the builder are matched to it positionally.
    pub columns: Vec<String>,
    /// NOTE(review): not read by any code in this file — `::vector` casts are chosen from the
    /// `VectorParam::Vector` variant instead. Confirm whether this field is still needed.
    pub vector_column: String,
    /// When true, `insert_one` appends an `ON CONFLICT` clause (upsert).
    pub upsert: bool,
    /// Conflict target columns for the upsert (defaults to `["id"]`).
    pub conflict_columns: Vec<String>,
    /// Columns overwritten from `EXCLUDED` on conflict. When empty, `insert_one` uses every
    /// entry of `columns` that is not a conflict column.
    pub update_columns: Vec<String>,
    /// Expression emitted as `RETURNING ...`; `None` omits the clause (defaults to `id`).
    pub returning: Option<String>,
}
impl Default for VectorInsertQuery {
    /// Returns a plain (non-upsert) insert template: empty table and column list, `embedding`
    /// as the vector column, `id` as both the conflict target and the `RETURNING` column, and
    /// no explicit update columns.
    fn default() -> Self {
        let id = || String::from("id");
        Self {
            returning: Some(id()),
            conflict_columns: vec![id()],
            upsert: false,
            vector_column: String::from("embedding"),
            update_columns: vec![],
            columns: vec![],
            table: String::default(),
        }
    }
}
impl VectorInsertQuery {
    /// Starts an insert into `table`; every other setting comes from [`Default`].
    #[must_use]
    pub fn new(table: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            ..Default::default()
        }
    }

    /// Replaces the insert column list; values are matched to it positionally.
    #[must_use]
    pub fn with_columns(mut self, columns: Vec<String>) -> Self {
        self.columns = columns;
        self
    }

    /// Sets the `vector_column` field.
    #[must_use]
    pub fn with_vector_column(mut self, column: impl Into<String>) -> Self {
        self.vector_column = column.into();
        self
    }

    /// Turns the insert into an upsert whose conflict target is `conflict_columns`.
    #[must_use]
    pub fn with_upsert(mut self, conflict_columns: Vec<String>) -> Self {
        self.upsert = true;
        self.conflict_columns = conflict_columns;
        self
    }

    /// Sets the columns overwritten from `EXCLUDED` when an upsert hits a conflict.
    #[must_use]
    pub fn with_update_columns(mut self, columns: Vec<String>) -> Self {
        self.update_columns = columns;
        self
    }

    /// Sets the expression emitted in the `RETURNING` clause.
    #[must_use]
    pub fn with_returning(mut self, column: impl Into<String>) -> Self {
        self.returning = Some(column.into());
        self
    }
}
/// Generates parameterized SQL for vector similarity search, inserts and index creation.
///
/// The builder is stateless apart from the placeholder style; each query method returns the
/// SQL text together with the values to bind.
// There is no `build` method on this type; the lint message points at the methods that
// actually produce output.
#[must_use = "a VectorQueryBuilder produces nothing until a query method such as `similarity_search` is called"]
#[derive(Debug, Clone, Default)]
pub struct VectorQueryBuilder {
    /// How bind-parameter placeholders are rendered in the generated SQL.
    placeholder_style: PlaceholderStyle,
}
/// How bind-parameter placeholders are rendered in generated SQL.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare or key on the style.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum PlaceholderStyle {
    /// Numbered placeholders: `$1`, `$2`, ... (the default).
    #[default]
    Dollar,
    /// Bare `?` placeholders; the parameter index is not rendered.
    QuestionMark,
}
impl VectorQueryBuilder {
pub fn new() -> Self {
Self::default()
}
pub const fn with_question_marks() -> Self {
Self {
placeholder_style: PlaceholderStyle::QuestionMark,
}
}
fn placeholder(&self, index: usize) -> String {
match self.placeholder_style {
PlaceholderStyle::Dollar => format!("${index}"),
PlaceholderStyle::QuestionMark => "?".to_string(),
}
}
#[must_use]
pub fn similarity_search(
&self,
query: &VectorSearchQuery,
query_embedding: &[f32],
) -> (String, Vec<VectorParam>) {
let mut params = Vec::new();
let mut param_idx = 1;
params.push(VectorParam::Vector(query_embedding.to_vec()));
let embedding_placeholder = format!("{}::vector", self.placeholder(param_idx));
param_idx += 1;
let distance_op = query.distance_metric.operator();
let select_clause = if query.select_columns.is_empty() {
if query.include_distance {
format!(
"*, ({} {} {}) AS distance",
query.embedding_column, distance_op, embedding_placeholder
)
} else {
"*".to_string()
}
} else {
let cols = query.select_columns.join(", ");
if query.include_distance {
format!(
"{}, ({} {} {}) AS distance",
cols, query.embedding_column, distance_op, embedding_placeholder
)
} else {
cols
}
};
let where_clause = if let Some(ref clause) = query.where_clause {
format!("\nWHERE {clause}")
} else {
String::new()
};
let order_clause = format!(
"\nORDER BY {} {} {}",
query.embedding_column, distance_op, embedding_placeholder
);
let limit_clause = format!("\nLIMIT {}", self.placeholder(param_idx));
params.push(VectorParam::Int(i64::from(query.limit)));
param_idx += 1;
let offset_clause = if let Some(offset) = query.offset {
let clause = format!("\nOFFSET {}", self.placeholder(param_idx));
params.push(VectorParam::Int(i64::from(offset)));
clause
} else {
String::new()
};
let sql = format!(
"SELECT {}\nFROM {}{}{}{}{}",
select_clause, query.table, where_clause, order_clause, limit_clause, offset_clause
);
(sql, params)
}
#[must_use]
pub fn insert_one(
&self,
query: &VectorInsertQuery,
values: &[VectorParam],
) -> (String, Vec<VectorParam>) {
let columns = query.columns.join(", ");
let placeholders: Vec<String> = values
.iter()
.enumerate()
.map(|(i, v)| {
let ph = self.placeholder(i + 1);
if matches!(v, VectorParam::Vector(_)) {
format!("{ph}::vector")
} else {
ph
}
})
.collect();
let values_clause = placeholders.join(", ");
let returning_clause = if let Some(ref col) = query.returning {
format!("\nRETURNING {col}")
} else {
String::new()
};
let sql = if query.upsert {
let conflict_cols = query.conflict_columns.join(", ");
let update_cols: Vec<&String> = if query.update_columns.is_empty() {
query.columns.iter().filter(|c| !query.conflict_columns.contains(c)).collect()
} else {
query.update_columns.iter().collect()
};
let update_clause: String = update_cols
.iter()
.map(|c| format!("{c} = EXCLUDED.{c}"))
.collect::<Vec<_>>()
.join(", ");
format!(
"INSERT INTO {} ({})\nVALUES ({})\nON CONFLICT ({}) DO UPDATE SET {}{}",
query.table, columns, values_clause, conflict_cols, update_clause, returning_clause
)
} else {
format!(
"INSERT INTO {} ({})\nVALUES ({}){}",
query.table, columns, values_clause, returning_clause
)
};
(sql, values.to_vec())
}
#[must_use]
pub fn insert_batch(
&self,
query: &VectorInsertQuery,
rows: &[Vec<VectorParam>],
) -> (String, Vec<VectorParam>) {
if rows.is_empty() {
return (String::new(), Vec::new());
}
let columns = query.columns.join(", ");
let cols_per_row = query.columns.len();
let mut all_params = Vec::new();
let mut values_clauses = Vec::new();
for (row_idx, row) in rows.iter().enumerate() {
let base_idx = row_idx * cols_per_row + 1;
let placeholders: Vec<String> = row
.iter()
.enumerate()
.map(|(i, v)| {
let ph = self.placeholder(base_idx + i);
if matches!(v, VectorParam::Vector(_)) {
format!("{ph}::vector")
} else {
ph
}
})
.collect();
values_clauses.push(format!("({})", placeholders.join(", ")));
all_params.extend(row.clone());
}
let returning_clause = if let Some(ref col) = query.returning {
format!("\nRETURNING {col}")
} else {
String::new()
};
let sql = format!(
"INSERT INTO {} ({})\nVALUES\n {}{}",
query.table,
columns,
values_clauses.join(",\n "),
returning_clause
);
(sql, all_params)
}
#[must_use]
pub fn create_index(&self, config: &VectorConfig, table: &str, column: &str) -> Option<String> {
config.index_type.index_sql(table, column, config.distance_metric)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string slices into the owned column list the query types expect.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    /// Asserts that every fragment in `needles` occurs somewhere in `sql`.
    fn assert_contains_all(sql: &str, needles: &[&str]) {
        for needle in needles {
            assert!(sql.contains(*needle), "expected {needle:?} in:\n{sql}");
        }
    }

    #[test]
    fn test_similarity_search_basic() {
        let qb = VectorQueryBuilder::new();
        let search = VectorSearchQuery::new("documents")
            .with_embedding_column("embedding")
            .with_limit(10);
        let (sql, params) = qb.similarity_search(&search, &[0.1, 0.2, 0.3]);
        assert_contains_all(
            &sql,
            &["SELECT *", "FROM documents", "ORDER BY embedding <=>", "$1::vector", "LIMIT $2"],
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_columns() {
        let qb = VectorQueryBuilder::new();
        let search = VectorSearchQuery::new("docs")
            .with_select_columns(cols(&["id", "content"]))
            .with_distance_score()
            .with_limit(5);
        let (sql, params) = qb.similarity_search(&search, &[0.1, 0.2]);
        assert_contains_all(&sql, &["SELECT id, content,", "AS distance", "LIMIT $2"]);
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_where() {
        let qb = VectorQueryBuilder::new();
        let search = VectorSearchQuery::new("documents")
            .with_where("metadata->>'type' = 'article'")
            .with_limit(10);
        let (sql, _params) = qb.similarity_search(&search, &[0.1, 0.2, 0.3]);
        assert!(sql.contains("WHERE metadata->>'type' = 'article'"));
    }

    #[test]
    fn test_similarity_search_with_offset() {
        let qb = VectorQueryBuilder::new();
        let search = VectorSearchQuery::new("documents").with_limit(10).with_offset(20);
        let (sql, params) = qb.similarity_search(&search, &[0.1, 0.2]);
        assert!(sql.contains("OFFSET $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_similarity_search_l2_distance() {
        let search = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::L2)
            .with_limit(5);
        let (sql, _params) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2]);
        assert!(sql.contains("<->"));
    }

    #[test]
    fn test_similarity_search_inner_product() {
        let search = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::InnerProduct)
            .with_limit(5);
        let (sql, _params) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2]);
        assert!(sql.contains("<#>"));
    }

    #[test]
    fn test_insert_one_basic() {
        let qb = VectorQueryBuilder::new();
        let insert =
            VectorInsertQuery::new("documents").with_columns(cols(&["id", "content", "embedding"]));
        let row = [
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ];
        let (sql, params) = qb.insert_one(&insert, &row);
        assert_contains_all(
            &sql,
            &[
                "INSERT INTO documents (id, content, embedding)",
                "VALUES ($1, $2, $3::vector)",
                "RETURNING id",
            ],
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_insert_upsert() {
        let qb = VectorQueryBuilder::new();
        let insert = VectorInsertQuery::new("documents")
            .with_columns(cols(&["id", "content", "embedding"]))
            .with_upsert(cols(&["id"]));
        let row = [
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ];
        let (sql, _params) = qb.insert_one(&insert, &row);
        assert_contains_all(
            &sql,
            &[
                "ON CONFLICT (id) DO UPDATE SET",
                "content = EXCLUDED.content",
                "embedding = EXCLUDED.embedding",
            ],
        );
    }

    #[test]
    fn test_insert_batch() {
        let qb = VectorQueryBuilder::new();
        let insert = VectorInsertQuery::new("documents").with_columns(cols(&["id", "embedding"]));
        let row = |id: &str, embedding: [f32; 2]| {
            vec![VectorParam::String(id.to_string()), VectorParam::Vector(embedding.to_vec())]
        };
        let rows = [row("doc1", [0.1, 0.2]), row("doc2", [0.3, 0.4])];
        let (sql, params) = qb.insert_batch(&insert, &rows);
        assert_contains_all(
            &sql,
            &["INSERT INTO documents (id, embedding)", "($1, $2::vector)", "($3, $4::vector)"],
        );
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_insert_batch_empty() {
        let qb = VectorQueryBuilder::new();
        let insert = VectorInsertQuery::new("documents").with_columns(cols(&["id"]));
        let (sql, params) = qb.insert_batch(&insert, &[]);
        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_index_hnsw() {
        let config = VectorConfig::openai();
        let index_sql = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");
        assert_eq!(
            index_sql,
            Some("CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)".to_string())
        );
    }

    #[test]
    fn test_create_index_ivfflat() {
        let config = VectorConfig::new(1536)
            .with_index(crate::schema::VectorIndexType::IvfFlat)
            .with_distance(DistanceMetric::L2);
        let index_sql = VectorQueryBuilder::new().create_index(&config, "docs", "vec");
        assert_eq!(
            index_sql,
            Some("CREATE INDEX ON docs USING ivfflat (vec vector_l2_ops)".to_string())
        );
    }

    #[test]
    fn test_create_index_none() {
        let config = VectorConfig::new(1536).with_index(crate::schema::VectorIndexType::None);
        let index_sql = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");
        assert_eq!(index_sql, None);
    }

    #[test]
    fn test_vector_param_to_sql_literal() {
        let cases = [
            (VectorParam::Vector(vec![0.1, 0.2, 0.3]), "'[0.1,0.2,0.3]'::vector"),
            (VectorParam::Int(42), "42"),
            (VectorParam::String("hello".to_string()), "'hello'"),
            (VectorParam::String("it's a test".to_string()), "'it''s a test'"),
        ];
        for (param, expected) in cases {
            assert_eq!(param.to_sql_literal(), expected);
        }
    }

    #[test]
    fn test_question_mark_placeholders() {
        let qb = VectorQueryBuilder::with_question_marks();
        let search = VectorSearchQuery::new("docs").with_limit(10);
        let (sql, _params) = qb.similarity_search(&search, &[0.1, 0.2]);
        assert!(sql.contains("?::vector"));
        assert!(!sql.contains("$1"));
    }
}