use serde::{Deserialize, Serialize};
use crate::ops::DistanceMetric;
use crate::types::Embedding;
/// A vector-similarity search over a single column, able to render the
/// matching SQL fragments (`WHERE`, `ORDER BY`, `LIMIT`) or a full `SELECT`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct VectorFilter {
    /// Vector column compared against the query; interpolated verbatim into SQL,
    /// so it must be a trusted identifier.
    pub column: String,
    /// Embedding to compare against. It is never written into the SQL text; the
    /// generated expressions reference it through a `$n` positional placeholder.
    pub query_vector: Embedding,
    /// Distance metric; supplies the SQL operator placed between column and placeholder.
    pub metric: DistanceMetric,
    /// Kind of filtering to apply (nearest-k, maximum distance, or distance range).
    pub filter_type: VectorFilterType,
}
/// How a `VectorFilter` restricts the rows it returns.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[non_exhaustive]
pub enum VectorFilterType {
    /// Return the `limit` closest rows; no distance predicate is emitted.
    Nearest { limit: usize },
    /// Keep rows whose distance is strictly less than `max_distance`,
    /// optionally capped at `limit` rows.
    WithinDistance { max_distance: f64, limit: Option<usize> },
    /// Keep rows whose distance lies in `[min_distance, max_distance]`
    /// (inclusive, rendered as SQL `BETWEEN`), optionally capped at `limit` rows.
    DistanceRange {
        min_distance: f64,
        max_distance: f64,
        limit: Option<usize>,
    },
}
impl VectorFilter {
pub fn nearest(
column: impl Into<String>,
query_vector: Embedding,
metric: DistanceMetric,
limit: usize,
) -> Self {
Self {
column: column.into(),
query_vector,
metric,
filter_type: VectorFilterType::Nearest { limit },
}
}
pub fn within_distance(
column: impl Into<String>,
query_vector: Embedding,
metric: DistanceMetric,
max_distance: f64,
) -> Self {
Self {
column: column.into(),
query_vector,
metric,
filter_type: VectorFilterType::WithinDistance {
max_distance,
limit: None,
},
}
}
pub fn distance_range(
column: impl Into<String>,
query_vector: Embedding,
metric: DistanceMetric,
min_distance: f64,
max_distance: f64,
) -> Self {
Self {
column: column.into(),
query_vector,
metric,
filter_type: VectorFilterType::DistanceRange {
min_distance,
max_distance,
limit: None,
},
}
}
pub fn with_limit(mut self, limit: usize) -> Self {
match &mut self.filter_type {
VectorFilterType::Nearest { limit: l } => *l = limit,
VectorFilterType::WithinDistance { limit: l, .. } => *l = Some(limit),
VectorFilterType::DistanceRange { limit: l, .. } => *l = Some(limit),
}
self
}
pub fn distance_expr_sql(&self, param_index: usize) -> String {
format!(
"{} {} ${}",
self.column,
self.metric.operator(),
param_index
)
}
pub fn where_sql(&self, param_index: usize) -> Option<String> {
let distance_expr = self.distance_expr_sql(param_index);
match &self.filter_type {
VectorFilterType::Nearest { .. } => None,
VectorFilterType::WithinDistance { max_distance, .. } => {
Some(format!("{distance_expr} < {max_distance}"))
}
VectorFilterType::DistanceRange {
min_distance,
max_distance,
..
} => Some(format!(
"{distance_expr} BETWEEN {min_distance} AND {max_distance}"
)),
}
}
pub fn order_by_sql(&self, param_index: usize) -> String {
self.distance_expr_sql(param_index)
}
pub fn limit_sql(&self) -> Option<String> {
let limit = match &self.filter_type {
VectorFilterType::Nearest { limit } => Some(*limit),
VectorFilterType::WithinDistance { limit, .. } => *limit,
VectorFilterType::DistanceRange { limit, .. } => *limit,
};
limit.map(|l| format!("LIMIT {l}"))
}
pub fn to_select_sql(
&self,
table: &str,
param_index: usize,
extra_where: Option<&str>,
select_columns: &str,
) -> String {
let distance_expr = self.distance_expr_sql(param_index);
let mut sql = format!(
"SELECT {}, {} AS distance FROM {}",
select_columns, distance_expr, table
);
let mut where_parts = Vec::new();
if let Some(vec_where) = self.where_sql(param_index) {
where_parts.push(vec_where);
}
if let Some(extra) = extra_where {
where_parts.push(extra.to_string());
}
if !where_parts.is_empty() {
sql.push_str(&format!(" WHERE {}", where_parts.join(" AND ")));
}
sql.push_str(&format!(" ORDER BY {}", self.order_by_sql(param_index)));
if let Some(limit) = self.limit_sql() {
sql.push_str(&format!(" {limit}"));
}
sql
}
}
/// Orders query results by vector distance, optionally also selecting that distance.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct VectorOrderBy {
    /// Vector column compared against the query; interpolated verbatim into SQL.
    pub column: String,
    /// Embedding to compare against; referenced in SQL only through a `$n` placeholder.
    pub query_vector: Embedding,
    /// Distance metric; supplies the SQL operator.
    pub metric: DistanceMetric,
    /// When `true`, the distance is selected under `distance_alias` and ordering refers
    /// to that alias; when `false`, ordering repeats the full distance expression.
    pub include_distance: bool,
    /// Alias of the selected distance column (`"distance"` when built with `new`).
    pub distance_alias: String,
}
impl VectorOrderBy {
    /// Creates an ordering by the distance between `column` and `query_vector`
    /// under `metric`. The distance is selected by default, aliased as `distance`.
    pub fn new(column: impl Into<String>, query_vector: Embedding, metric: DistanceMetric) -> Self {
        Self {
            column: column.into(),
            query_vector,
            metric,
            include_distance: true,
            distance_alias: String::from("distance"),
        }
    }

    /// Renames the selected distance column. The alias is interpolated verbatim
    /// into SQL, so it must be a trusted identifier.
    pub fn alias(self, alias: impl Into<String>) -> Self {
        Self {
            distance_alias: alias.into(),
            ..self
        }
    }

    /// Stops selecting the distance. Ordering then repeats the distance expression
    /// instead of referring to the alias.
    pub fn without_distance(self) -> Self {
        Self {
            include_distance: false,
            ..self
        }
    }

    /// Shared `<column> <operator> $<param_index>` distance expression.
    fn distance_expr_sql(&self, param_index: usize) -> String {
        format!("{} {} ${}", self.column, self.metric.operator(), param_index)
    }

    /// Returns `<distance expr> AS <alias>` for the select list, or `None` when the
    /// distance is not being selected.
    pub fn select_distance_sql(&self, param_index: usize) -> Option<String> {
        self.include_distance.then(|| {
            format!(
                "{} AS {}",
                self.distance_expr_sql(param_index),
                self.distance_alias
            )
        })
    }

    /// Returns the `ORDER BY` expression. This is the alias when the distance is
    /// selected, and the distance expression itself otherwise.
    pub fn order_by_sql(&self, param_index: usize) -> String {
        if !self.include_distance {
            return self.distance_expr_sql(param_index);
        }
        self.distance_alias.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-component fixture vector shared by every test below.
    fn test_embedding() -> Embedding {
        Embedding::new(vec![0.1, 0.2, 0.3])
    }

    /// Asserts that `haystack` contains every entry of `needles`, naming the
    /// missing one on failure.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected {needle:?} in generated SQL: {haystack}"
            );
        }
    }

    #[test]
    fn test_nearest_filter() {
        let filter =
            VectorFilter::nearest("embedding", test_embedding(), DistanceMetric::Cosine, 10);

        assert!(filter.where_sql(1).is_none());
        assert_eq!(filter.order_by_sql(1), "embedding <=> $1");
        assert_eq!(filter.limit_sql().as_deref(), Some("LIMIT 10"));
    }

    #[test]
    fn test_within_distance_filter() {
        let clause =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .where_sql(1)
                .expect("within_distance should yield a WHERE predicate");

        assert_contains_all(&clause, &["<->", "< 0.5"]);
    }

    #[test]
    fn test_distance_range_filter() {
        let clause = VectorFilter::distance_range(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.1,
            0.5,
        )
        .where_sql(1)
        .expect("distance_range should yield a WHERE predicate");

        assert_contains_all(&clause, &["BETWEEN", "0.1", "0.5"]);
    }

    #[test]
    fn test_filter_with_limit() {
        let filter =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .with_limit(50);

        assert_eq!(filter.limit_sql().as_deref(), Some("LIMIT 50"));
    }

    #[test]
    fn test_to_select_sql_nearest() {
        let sql = VectorFilter::nearest("embedding", test_embedding(), DistanceMetric::Cosine, 5)
            .to_select_sql("documents", 1, None, "*");

        assert_contains_all(
            &sql,
            &[
                "SELECT *, embedding <=> $1 AS distance",
                "FROM documents",
                "ORDER BY",
                "LIMIT 5",
            ],
        );
        // Nearest-neighbour queries restrict results through LIMIT only.
        assert!(!sql.contains("WHERE"), "unexpected WHERE clause: {sql}");
    }

    #[test]
    fn test_to_select_sql_with_extra_where() {
        let sql =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .with_limit(20)
                .to_select_sql("documents", 1, Some("category = 'tech'"), "*");

        assert_contains_all(&sql, &["WHERE", "< 0.5", "category = 'tech'", "AND"]);
    }

    #[test]
    fn test_vector_order_by() {
        let order = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine);
        assert!(order.include_distance);

        let select = order
            .select_distance_sql(1)
            .expect("distance is selected by default");
        assert_contains_all(&select, &["<=>", "AS distance"]);
        assert_eq!(order.order_by_sql(1), "distance");
    }

    #[test]
    fn test_vector_order_by_without_distance() {
        let order = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::L2)
            .without_distance();

        assert!(order.select_distance_sql(1).is_none());
        assert_contains_all(&order.order_by_sql(1), &["<->"]);
    }

    #[test]
    fn test_vector_order_by_custom_alias() {
        let select = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine)
            .alias("similarity")
            .select_distance_sql(1)
            .expect("distance is selected by default");

        assert_contains_all(&select, &["AS similarity"]);
    }

    #[test]
    fn test_distance_expr_sql() {
        let filter =
            VectorFilter::nearest("emb", test_embedding(), DistanceMetric::InnerProduct, 5);

        assert_eq!(filter.distance_expr_sql(2), "emb <#> $2");
    }
}