use super::expr::{Expr, QualifiedName};
use std::fmt;
/// A vector similarity search: `vector_expr` is compared with `query_expr` under `metric`.
///
/// Both operand expressions are stored boxed.
#[derive(Clone, Debug, PartialEq)]
pub struct VectorSearch {
    /// Expression supplying the vectors being searched.
    pub vector_expr: Box<Expr>,
    /// Expression supplying the query vector.
    pub query_expr: Box<Expr>,
    /// Distance metric used for the comparison.
    pub metric: DistanceMetric,
}

impl VectorSearch {
    /// Builds a search from its two operand expressions and an explicit metric.
    #[must_use]
    pub fn new(vector_expr: Expr, query_expr: Expr, metric: DistanceMetric) -> Self {
        let vector_expr = Box::new(vector_expr);
        let query_expr = Box::new(query_expr);
        Self { vector_expr, query_expr, metric }
    }

    /// Shorthand for [`VectorSearch::new`] with [`DistanceMetric::Euclidean`].
    #[must_use]
    pub fn euclidean(vector_expr: Expr, query_expr: Expr) -> Self {
        let metric = DistanceMetric::Euclidean;
        Self::new(vector_expr, query_expr, metric)
    }

    /// Shorthand for [`VectorSearch::new`] with [`DistanceMetric::Cosine`].
    #[must_use]
    pub fn cosine(vector_expr: Expr, query_expr: Expr) -> Self {
        let metric = DistanceMetric::Cosine;
        Self::new(vector_expr, query_expr, metric)
    }

    /// Shorthand for [`VectorSearch::new`] with [`DistanceMetric::InnerProduct`].
    #[must_use]
    pub fn inner_product(vector_expr: Expr, query_expr: Expr) -> Self {
        let metric = DistanceMetric::InnerProduct;
        Self::new(vector_expr, query_expr, metric)
    }
}

impl fmt::Display for VectorSearch {
    /// Writes a compact summary containing only the metric (via its `Display`);
    /// the operand expressions are elided as `...`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let metric = self.metric;
        write!(f, "vector_search({metric}, ...)")
    }
}
/// Distance function used to compare vectors in a similarity search.
///
/// Every metric has an infix operator spelling ([`DistanceMetric::operator`]) and a
/// function-call spelling ([`DistanceMetric::function_name`]). `Display` writes the operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Operator `<->`, function `euclidean_distance`.
    Euclidean,
    /// Operator `<=>`, function `cosine_distance`.
    Cosine,
    /// Operator `<#>`, function `inner_product`.
    InnerProduct,
    /// Operator `<~>`, function `manhattan_distance`.
    Manhattan,
    /// Operator `<%>`, function `hamming_distance`.
    Hamming,
}

impl DistanceMetric {
    /// Returns the infix operator token for this metric (e.g. `<->` for [`Self::Euclidean`]).
    #[must_use]
    pub const fn operator(&self) -> &'static str {
        match self {
            Self::Euclidean => "<->",
            Self::Cosine => "<=>",
            Self::InnerProduct => "<#>",
            // NOTE(review): the three tokens above match pgvector, but pgvector spells L1
            // distance `<+>` and uses `<~>` for Hamming and `<%>` for Jaccard. Confirm the two
            // tokens below against this dialect's parser before assuming pgvector compatibility.
            Self::Manhattan => "<~>",
            Self::Hamming => "<%>",
        }
    }

    /// Returns the function-call name for this metric (e.g. `cosine_distance`).
    #[must_use]
    pub const fn function_name(&self) -> &'static str {
        match self {
            Self::Euclidean => "euclidean_distance",
            Self::Cosine => "cosine_distance",
            Self::InnerProduct => "inner_product",
            Self::Manhattan => "manhattan_distance",
            Self::Hamming => "hamming_distance",
        }
    }
}

impl fmt::Display for DistanceMetric {
    /// Writes the metric's operator token.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies the caller's width/fill/alignment/precision (like `str`'s Display);
        // `write!(f, "{}", ..)` would silently ignore them, e.g. in `format!("{:>5}", m)`.
        f.pad(self.operator())
    }
}
/// Hint that names a specific vector index together with the search parameters to pair with it.
#[derive(Clone, Debug, PartialEq)]
pub struct VectorIndexHint {
    /// Qualified name of the hinted index.
    pub index_name: QualifiedName,
    /// Search parameters carried alongside the hint.
    pub params: VectorSearchParams,
}
/// Optional tuning parameters attached to a vector search.
///
/// Every field starts out `None` (unset). Use [`VectorSearchParams::new`] and the
/// `with_*` builder methods to fill in only the ones you need.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct VectorSearchParams {
    /// Result limit, if specified.
    pub limit: Option<u32>,
    /// `ef_search` value, if specified (presumably the HNSW search-breadth knob — confirm).
    pub ef_search: Option<u32>,
    /// `n_probe` value, if specified (presumably the IVF probe count — confirm).
    pub n_probe: Option<u32>,
    /// Distance threshold, if specified.
    pub distance_threshold: Option<f32>,
}

impl VectorSearchParams {
    /// Returns a parameter set with every field unset.
    ///
    /// Same value as `Default::default()`, but usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            distance_threshold: None,
            n_probe: None,
            ef_search: None,
            limit: None,
        }
    }

    /// Returns `self` with `limit` set, overwriting any previous value.
    #[must_use]
    pub const fn with_limit(self, limit: u32) -> Self {
        Self { limit: Some(limit), ..self }
    }

    /// Returns `self` with `ef_search` set, overwriting any previous value.
    #[must_use]
    pub const fn with_ef_search(self, ef: u32) -> Self {
        Self { ef_search: Some(ef), ..self }
    }

    /// Returns `self` with `n_probe` set, overwriting any previous value.
    #[must_use]
    pub const fn with_n_probe(self, n: u32) -> Self {
        Self { n_probe: Some(n), ..self }
    }

    /// Returns `self` with `distance_threshold` set, overwriting any previous value.
    #[must_use]
    pub const fn with_distance_threshold(self, threshold: f32) -> Self {
        Self { distance_threshold: Some(threshold), ..self }
    }
}
/// Aggregate operation applied over a set of vectors.
///
/// Its `Display` form is the upper-case name (`VECTOR_AVG`, `VECTOR_SUM`, `VECTOR_CENTROID`).
// `Hash` added for parity with `DistanceMetric`, so both op enums can key hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VectorAggregateOp {
    /// Rendered as `VECTOR_AVG`.
    Avg,
    /// Rendered as `VECTOR_SUM`.
    Sum,
    /// Rendered as `VECTOR_CENTROID`.
    Centroid,
}

impl fmt::Display for VectorAggregateOp {
    /// Writes the upper-case aggregate name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Avg => "VECTOR_AVG",
            Self::Sum => "VECTOR_SUM",
            Self::Centroid => "VECTOR_CENTROID",
        };
        // `pad` honours the caller's width/fill/alignment (e.g. `{:>12}`); `write!` ignores them.
        f.pad(name)
    }
}
/// A vector aggregate: operation `op` applied to the (boxed) expression `expr`.
#[derive(Clone, Debug, PartialEq)]
pub struct VectorAggregate {
    /// Which aggregate to compute.
    pub op: VectorAggregateOp,
    /// Expression being aggregated.
    pub expr: Box<Expr>,
}

impl VectorAggregate {
    /// Builds an aggregate node, boxing `expr`.
    #[must_use]
    pub fn new(op: VectorAggregateOp, expr: Expr) -> Self {
        let expr = Box::new(expr);
        Self { op, expr }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DistanceMetric` variant, so table-driven tests cover new variants too.
    const ALL_METRICS: [DistanceMetric; 5] = [
        DistanceMetric::Euclidean,
        DistanceMetric::Cosine,
        DistanceMetric::InnerProduct,
        DistanceMetric::Manhattan,
        DistanceMetric::Hamming,
    ];

    #[test]
    fn distance_metric_operators() {
        assert_eq!(DistanceMetric::Euclidean.operator(), "<->");
        assert_eq!(DistanceMetric::Cosine.operator(), "<=>");
        assert_eq!(DistanceMetric::InnerProduct.operator(), "<#>");
    }

    #[test]
    fn distance_metric_operators_are_unique() {
        // Two metrics sharing a token would make the operator spelling ambiguous.
        for (i, a) in ALL_METRICS.iter().enumerate() {
            for b in &ALL_METRICS[i + 1..] {
                assert_ne!(a.operator(), b.operator(), "{a:?} and {b:?} share an operator");
            }
        }
    }

    #[test]
    fn distance_metric_display_matches_operator() {
        for m in ALL_METRICS.iter() {
            assert_eq!(m.to_string(), m.operator());
        }
    }

    #[test]
    fn distance_metric_function_names() {
        assert_eq!(DistanceMetric::Euclidean.function_name(), "euclidean_distance");
        assert_eq!(DistanceMetric::Cosine.function_name(), "cosine_distance");
        assert_eq!(DistanceMetric::InnerProduct.function_name(), "inner_product");
        assert_eq!(DistanceMetric::Manhattan.function_name(), "manhattan_distance");
        assert_eq!(DistanceMetric::Hamming.function_name(), "hamming_distance");
    }

    #[test]
    fn vector_search_params_builder() {
        let params = VectorSearchParams::new()
            .with_limit(10)
            .with_ef_search(100)
            .with_n_probe(4)
            .with_distance_threshold(0.5);
        assert_eq!(params.limit, Some(10));
        assert_eq!(params.ef_search, Some(100));
        assert_eq!(params.n_probe, Some(4));
        assert_eq!(params.distance_threshold, Some(0.5));
    }

    #[test]
    fn vector_search_params_new_matches_default() {
        assert_eq!(VectorSearchParams::new(), VectorSearchParams::default());
    }

    #[test]
    fn vector_aggregate_display() {
        assert_eq!(VectorAggregateOp::Avg.to_string(), "VECTOR_AVG");
        assert_eq!(VectorAggregateOp::Sum.to_string(), "VECTOR_SUM");
        assert_eq!(VectorAggregateOp::Centroid.to_string(), "VECTOR_CENTROID");
    }
}