use serde::{Deserialize, Serialize};
/// Comparison operator applied to a single metadata field.
///
/// Each variant renders to the SQL text returned by [`FilterOp::as_sql`].
/// `IsNull` and `IsNotNull` are unary and never take a bound parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `LIKE` pattern match.
    Like,
    /// `ILIKE` pattern match.
    ILike,
    /// `IS NULL` (unary).
    IsNull,
    /// `IS NOT NULL` (unary).
    IsNotNull,
}

impl FilterOp {
    /// Returns the SQL operator text for this comparison.
    pub fn as_sql(self) -> &'static str {
        let operator = match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
            Self::Like => "LIKE",
            Self::ILike => "ILIKE",
            Self::IsNull => "IS NULL",
            Self::IsNotNull => "IS NOT NULL",
        };
        operator
    }
}
/// A single predicate on one key of the `metadata` JSON column.
///
/// Build one with `MetadataFilter::new` or a shorthand constructor
/// (`eq`, `ne`, `gt`, `lt`, `like`, `is_null`, `is_not_null`), then render it
/// with `to_sql_with_param`.
#[derive(Debug, Clone)]
pub struct MetadataFilter {
/// Raw metadata key as supplied by the caller; it is sanitized before being
/// embedded in SQL.
field: String,
/// Comparison to apply.
op: FilterOp,
/// Right-hand value; `None` when built via `is_null` / `is_not_null`.
value: Option<FilterValue>,
}
/// A value bound as a query parameter on the right-hand side of a metadata filter.
///
/// `PartialEq` is derived (not `Eq`, because of the `f64` payload) so callers
/// and tests can compare the parameters produced by the SQL builders.
#[derive(Debug, Clone, PartialEq)]
pub enum FilterValue {
    /// Text value.
    String(String),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit floating-point value.
    Float(f64),
    /// Boolean value.
    Bool(bool),
}
impl MetadataFilter {
    /// Creates a filter comparing `field` against `value` using `op`.
    ///
    /// If `op` is `IsNull` / `IsNotNull`, the value is stored but never bound:
    /// [`MetadataFilter::to_sql_with_param`] returns `None` for those operators.
    pub fn new(field: impl Into<String>, op: FilterOp, value: FilterValue) -> Self {
        Self {
            field: field.into(),
            op,
            value: Some(value),
        }
    }

    /// Creates a value-less filter for a unary operator (`IS NULL` / `IS NOT NULL`).
    fn unary(field: impl Into<String>, op: FilterOp) -> Self {
        Self {
            field: field.into(),
            op,
            value: None,
        }
    }

    /// `field = value`
    pub fn eq(field: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        Self::new(field, FilterOp::Eq, value.into())
    }

    /// `field != value`
    pub fn ne(field: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        Self::new(field, FilterOp::Ne, value.into())
    }

    /// `field > value`
    pub fn gt(field: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        Self::new(field, FilterOp::Gt, value.into())
    }

    /// `field >= value`
    pub fn gte(field: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        Self::new(field, FilterOp::Gte, value.into())
    }

    /// `field < value`
    pub fn lt(field: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        Self::new(field, FilterOp::Lt, value.into())
    }

    /// `field <= value`
    pub fn lte(field: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        Self::new(field, FilterOp::Lte, value.into())
    }

    /// `field IS NULL`
    pub fn is_null(field: impl Into<String>) -> Self {
        Self::unary(field, FilterOp::IsNull)
    }

    /// `field IS NOT NULL`
    pub fn is_not_null(field: impl Into<String>) -> Self {
        Self::unary(field, FilterOp::IsNotNull)
    }

    /// `field LIKE pattern`
    pub fn like(field: impl Into<String>, pattern: impl Into<String>) -> Self {
        Self::new(field, FilterOp::Like, FilterValue::String(pattern.into()))
    }

    /// `field ILIKE pattern`
    pub fn ilike(field: impl Into<String>, pattern: impl Into<String>) -> Self {
        Self::new(field, FilterOp::ILike, FilterValue::String(pattern.into()))
    }

    /// Returns the field name reduced to characters that are safe inside the
    /// single-quoted key literal built by `to_sql_with_param`.
    ///
    /// Keeps only `char::is_alphanumeric` characters (Unicode-aware), `_`, `-`
    /// and `>`. Quote characters are therefore stripped and cannot terminate
    /// the literal. Falls back to `"invalid_field"` when nothing survives, or
    /// when the result does not start with a letter or `_`.
    fn sanitize_field(&self) -> String {
        // NOTE(review): `-` and `>` look intended for `a->b` style paths, but the
        // result is still looked up as a single literal top-level key — confirm.
        let sanitized: String = self
            .field
            .chars()
            .filter(|c| c.is_alphanumeric() || matches!(c, '_' | '-' | '>'))
            .collect();
        let valid_start = sanitized
            .chars()
            .next()
            .map_or(false, |c| c.is_alphabetic() || c == '_');
        if valid_start {
            sanitized
        } else {
            "invalid_field".to_string()
        }
    }

    /// Renders this filter as a SQL condition over `metadata->>'<field>'`.
    ///
    /// Binary operators reference the positional placeholder `$param_index` and
    /// return a clone of the value to bind. Unary operators (`IS NULL` /
    /// `IS NOT NULL`) use no placeholder and return `None`; the caller should then
    /// not advance its parameter counter.
    pub fn to_sql_with_param(&self, param_index: usize) -> (String, Option<FilterValue>) {
        let jsonb_field = format!("metadata->>'{}'", self.sanitize_field());
        match self.op {
            FilterOp::IsNull | FilterOp::IsNotNull => {
                (format!("{} {}", jsonb_field, self.op.as_sql()), None)
            }
            _ => {
                let sql = format!("{} {} ${}", jsonb_field, self.op.as_sql(), param_index);
                (sql, self.value.clone())
            }
        }
    }
}
impl From<String> for FilterValue {
fn from(s: String) -> Self {
FilterValue::String(s)
}
}
impl From<&str> for FilterValue {
fn from(s: &str) -> Self {
FilterValue::String(s.to_string())
}
}
impl From<i64> for FilterValue {
fn from(i: i64) -> Self {
FilterValue::Int(i)
}
}
impl From<i32> for FilterValue {
fn from(i: i32) -> Self {
FilterValue::Int(i as i64)
}
}
impl From<f64> for FilterValue {
fn from(f: f64) -> Self {
FilterValue::Float(f)
}
}
impl From<bool> for FilterValue {
fn from(b: bool) -> Self {
FilterValue::Bool(b)
}
}
/// A group of metadata filters joined with either `AND` or `OR`.
///
/// [`MetadataFilters::new`] and `MetadataFilters::default()` both produce an
/// empty `AND` group; use [`MetadataFilters::or`] for an `OR` group.
#[derive(Debug, Clone)]
pub struct MetadataFilters {
    filters: Vec<MetadataFilter>,
    /// `true` joins conditions with `AND`, `false` with `OR`.
    use_and: bool,
}

impl Default for MetadataFilters {
    /// Same as [`MetadataFilters::new`]: an empty `AND` group.
    ///
    /// Implemented by hand because a derived `Default` would set `use_and` to
    /// `false`. That would silently turn groups created via `unwrap_or_default()`
    /// into `OR` groups.
    fn default() -> Self {
        Self::new()
    }
}

impl MetadataFilters {
    /// Creates an empty group whose conditions are joined with `AND`.
    pub fn new() -> Self {
        Self {
            filters: Vec::new(),
            use_and: true,
        }
    }

    /// Creates an empty group whose conditions are joined with `OR`.
    pub fn or() -> Self {
        Self {
            filters: Vec::new(),
            use_and: false,
        }
    }

    /// Appends `filter` to the group (builder style).
    pub fn add(mut self, filter: MetadataFilter) -> Self {
        self.filters.push(filter);
        self
    }

    /// Returns `true` when the group contains no filters.
    pub fn is_empty(&self) -> bool {
        self.filters.is_empty()
    }

    /// Renders the group as one SQL condition plus the values to bind.
    ///
    /// Placeholders are numbered consecutively starting at `start_param_index`.
    /// Unary filters (which return no parameter) do not consume a number. Two or
    /// more conditions are wrapped in parentheses so the group composes safely
    /// with other predicates; a single condition is returned bare. An empty group
    /// yields an empty string and no parameters.
    pub fn to_sql_with_params(&self, start_param_index: usize) -> (String, Vec<FilterValue>) {
        if self.filters.is_empty() {
            return (String::new(), Vec::new());
        }
        let mut conditions = Vec::with_capacity(self.filters.len());
        let mut params = Vec::new();
        let mut param_idx = start_param_index;
        for filter in &self.filters {
            let (sql, param) = filter.to_sql_with_param(param_idx);
            conditions.push(sql);
            if let Some(p) = param {
                params.push(p);
                param_idx += 1;
            }
        }
        let joiner = if self.use_and { " AND " } else { " OR " };
        let sql = if conditions.len() == 1 {
            // Non-empty was checked above, so index 0 exists.
            conditions.swap_remove(0)
        } else {
            format!("({})", conditions.join(joiner))
        };
        (sql, params)
    }
}
/// Known embedding models, plus [`EmbeddingModel::Custom`] for anything else.
///
/// Defaults to [`EmbeddingModel::OpenAISmall`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EmbeddingModel {
    #[default]
    OpenAISmall,
    OpenAILarge,
    OpenAIAda002,
    CohereEnglishV3,
    CohereMultilingualV3,
    VoyageAI2,
    VoyageAILarge2,
    GoogleGecko,
    GoogleGeckoLatest,
    MiniLMv2,
    MPNetBase,
    BGESmall,
    BGEBase,
    BGELarge,
    /// A model not listed above, identified only by its output dimension.
    Custom(usize),
}

impl EmbeddingModel {
    /// Number of components in the vectors this model produces.
    pub fn dimension(&self) -> usize {
        match self {
            Self::OpenAISmall => 1536,
            Self::OpenAILarge => 3072,
            Self::OpenAIAda002 => 1536,
            Self::CohereEnglishV3 => 1024,
            Self::CohereMultilingualV3 => 1024,
            Self::VoyageAI2 => 1024,
            Self::VoyageAILarge2 => 1536,
            Self::GoogleGecko => 768,
            Self::GoogleGeckoLatest => 768,
            Self::MiniLMv2 => 384,
            Self::MPNetBase => 768,
            Self::BGESmall => 384,
            Self::BGEBase => 768,
            Self::BGELarge => 1024,
            Self::Custom(dim) => *dim,
        }
    }

    /// Canonical model identifier; `"custom"` for [`EmbeddingModel::Custom`].
    ///
    /// Every value is a string literal, so the result is `'static` and does not
    /// borrow `self`.
    pub fn name(&self) -> &'static str {
        match self {
            Self::OpenAISmall => "text-embedding-3-small",
            Self::OpenAILarge => "text-embedding-3-large",
            Self::OpenAIAda002 => "text-embedding-ada-002",
            Self::CohereEnglishV3 => "embed-english-v3.0",
            Self::CohereMultilingualV3 => "embed-multilingual-v3.0",
            Self::VoyageAI2 => "voyage-2",
            Self::VoyageAILarge2 => "voyage-large-2",
            Self::GoogleGecko => "textembedding-gecko",
            Self::GoogleGeckoLatest => "textembedding-gecko@latest",
            Self::MiniLMv2 => "all-MiniLM-L6-v2",
            Self::MPNetBase => "all-mpnet-base-v2",
            Self::BGESmall => "bge-small-en-v1.5",
            Self::BGEBase => "bge-base-en-v1.5",
            Self::BGELarge => "bge-large-en-v1.5",
            Self::Custom(_) => "custom",
        }
    }

    /// Short provider/family label for the model (`'static`, like [`Self::name`]).
    pub fn provider(&self) -> &'static str {
        match self {
            Self::OpenAISmall | Self::OpenAILarge | Self::OpenAIAda002 => "openai",
            Self::CohereEnglishV3 | Self::CohereMultilingualV3 => "cohere",
            Self::VoyageAI2 | Self::VoyageAILarge2 => "voyage",
            Self::GoogleGecko | Self::GoogleGeckoLatest => "google",
            Self::MiniLMv2 | Self::MPNetBase => "sentence-transformers",
            Self::BGESmall | Self::BGEBase | Self::BGELarge => "bge",
            Self::Custom(_) => "custom",
        }
    }

    /// Parses a model from its canonical name or a short alias.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Returns
    /// `None` for unknown names. `Custom` is never produced because a name alone
    /// carries no dimension. As a result, `from_name(m.name()) == Some(m)` holds
    /// for every variant except `Custom`.
    pub fn from_name(name: &str) -> Option<Self> {
        let model = match name.trim().to_lowercase().as_str() {
            "text-embedding-3-small" | "openai-small" => Self::OpenAISmall,
            "text-embedding-3-large" | "openai-large" => Self::OpenAILarge,
            "text-embedding-ada-002" | "ada-002" | "ada002" => Self::OpenAIAda002,
            "embed-english-v3.0" | "cohere-english-v3" => Self::CohereEnglishV3,
            "embed-multilingual-v3.0" | "cohere-multilingual-v3" => Self::CohereMultilingualV3,
            "voyage-2" | "voyage2" => Self::VoyageAI2,
            "voyage-large-2" | "voyage-large2" => Self::VoyageAILarge2,
            "textembedding-gecko" | "gecko" => Self::GoogleGecko,
            "textembedding-gecko@latest" | "gecko-latest" => Self::GoogleGeckoLatest,
            "all-minilm-l6-v2" | "minilm" | "minilm-v2" => Self::MiniLMv2,
            "all-mpnet-base-v2" | "mpnet" => Self::MPNetBase,
            "bge-small-en-v1.5" | "bge-small" => Self::BGESmall,
            "bge-base-en-v1.5" | "bge-base" => Self::BGEBase,
            "bge-large-en-v1.5" | "bge-large" => Self::BGELarge,
            _ => return None,
        };
        Some(model)
    }
}
/// A vector to store, together with its id and optional payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
/// Caller-chosen identifier for the point.
pub id: String,
/// Embedding components; its length is reported by `dimension()`.
pub vector: Vec<f32>,
/// Arbitrary JSON metadata. Metadata filters query it via `metadata->>'key'`.
pub metadata: Option<serde_json::Value>,
/// Optional text associated with the point (presumably the text that was
/// embedded — confirm with callers).
pub content: Option<String>,
}
impl VectorPoint {
    /// Creates a point with the given id and vector and no metadata or content.
    pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
        let id = id.into();
        Self {
            id,
            vector,
            content: None,
            metadata: None,
        }
    }

    /// Builder-style setter that attaches JSON metadata, replacing any previous value.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Builder-style setter that attaches text content, replacing any previous value.
    pub fn with_content(self, content: impl Into<String>) -> Self {
        Self {
            content: Some(content.into()),
            ..self
        }
    }

    /// Number of components in `vector`.
    pub fn dimension(&self) -> usize {
        let Self { vector, .. } = self;
        vector.len()
    }
}
/// One hit returned by a similarity search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Id of the matched point.
pub id: String,
/// Relevance score of the match. NOTE(review): whether higher is better
/// (similarity) or lower is better (distance) depends on the backend — confirm.
pub score: f32,
/// JSON metadata of the matched point, if it was fetched.
pub metadata: Option<serde_json::Value>,
/// Text content of the matched point, if it was fetched.
pub content: Option<String>,
/// Stored vector of the matched point; presumably populated only when the
/// search asked for vectors (`SearchOptions::include_vector`) — confirm.
pub vector: Option<Vec<f32>>,
}
impl SearchResult {
    /// Creates a hit with only `id` and `score` populated. The optional payload
    /// fields (`metadata`, `content`, `vector`) start out as `None`.
    pub fn new(id: impl Into<String>, score: f32) -> Self {
        let (metadata, content, vector) = (None, None, None);
        Self {
            id: id.into(),
            score,
            metadata,
            content,
            vector,
        }
    }
}
/// Summary statistics for a vector table.
///
/// `Default` yields zero counts and `None` for every optional field.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TableStats {
/// Number of stored vectors.
pub total_vectors: u64,
/// Vector dimension of the table.
pub dimension: usize,
/// Name/kind of the vector index, if any (format depends on the backend — TODO confirm).
pub index_type: Option<String>,
/// Table size in bytes, when the backend reports it.
pub table_size_bytes: Option<u64>,
/// Index size in bytes, when the backend reports it.
pub index_size_bytes: Option<u64>,
}
/// Parameters for a similarity search.
///
/// Prefer `SearchOptions::new(limit)`. The derived `Default` sets `limit` to 0
/// and every `include_*` flag to `false`, whereas `new` turns on
/// `include_metadata` and `include_content`.
#[derive(Debug, Clone, Default)]
pub struct SearchOptions {
/// Maximum number of results to return.
pub limit: usize,
/// Optional score cutoff; how it is applied is up to the backend (TODO confirm).
pub threshold: Option<f32>,
/// Whether to return each hit's stored vector.
pub include_vector: bool,
/// Whether to return each hit's JSON metadata.
pub include_metadata: bool,
/// Whether to return each hit's text content.
pub include_content: bool,
/// Optional metadata filter group applied to the search.
pub metadata_filters: Option<MetadataFilters>,
}
impl SearchOptions {
pub fn new(limit: usize) -> Self {
Self {
limit,
include_metadata: true,
include_content: true,
..Default::default()
}
}
pub fn with_threshold(mut self, threshold: f32) -> Self {
self.threshold = Some(threshold);
self
}
pub fn with_vector(mut self) -> Self {
self.include_vector = true;
self
}
pub fn with_filters(mut self, filters: MetadataFilters) -> Self {
self.metadata_filters = Some(filters);
self
}
pub fn with_filter_eq(
mut self,
field: impl Into<String>,
value: impl Into<FilterValue>,
) -> Self {
let filter = MetadataFilter::eq(field, value);
self.metadata_filters = Some(self.metadata_filters.unwrap_or_default().add(filter));
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_model_dimensions() {
        assert_eq!(EmbeddingModel::OpenAISmall.dimension(), 1536);
        assert_eq!(EmbeddingModel::OpenAILarge.dimension(), 3072);
        assert_eq!(EmbeddingModel::CohereEnglishV3.dimension(), 1024);
        assert_eq!(EmbeddingModel::MiniLMv2.dimension(), 384);
        assert_eq!(EmbeddingModel::Custom(512).dimension(), 512);
    }

    #[test]
    fn test_embedding_model_from_name() {
        assert_eq!(
            EmbeddingModel::from_name("text-embedding-3-small"),
            Some(EmbeddingModel::OpenAISmall)
        );
        assert_eq!(
            EmbeddingModel::from_name("ada-002"),
            Some(EmbeddingModel::OpenAIAda002)
        );
        assert_eq!(EmbeddingModel::from_name("unknown-model"), None);
    }

    /// Every non-custom model must be recoverable from its canonical name.
    #[test]
    fn test_embedding_model_name_round_trip() {
        let all = [
            EmbeddingModel::OpenAISmall,
            EmbeddingModel::OpenAILarge,
            EmbeddingModel::OpenAIAda002,
            EmbeddingModel::CohereEnglishV3,
            EmbeddingModel::CohereMultilingualV3,
            EmbeddingModel::VoyageAI2,
            EmbeddingModel::VoyageAILarge2,
            EmbeddingModel::GoogleGecko,
            EmbeddingModel::GoogleGeckoLatest,
            EmbeddingModel::MiniLMv2,
            EmbeddingModel::MPNetBase,
            EmbeddingModel::BGESmall,
            EmbeddingModel::BGEBase,
            EmbeddingModel::BGELarge,
        ];
        for model in all {
            assert_eq!(EmbeddingModel::from_name(model.name()), Some(model));
        }
    }

    #[test]
    fn test_vector_point_creation() {
        let point = VectorPoint::new("test-id", vec![0.1, 0.2, 0.3])
            .with_metadata(serde_json::json!({"key": "value"}))
            .with_content("test content");
        assert_eq!(point.id, "test-id");
        assert_eq!(point.dimension(), 3);
        assert!(point.metadata.is_some());
        assert!(point.content.is_some());
    }

    #[test]
    fn test_search_options() {
        let options = SearchOptions::new(10)
            .with_threshold(0.8)
            .with_vector()
            .with_filter_eq("type", "document");
        assert_eq!(options.limit, 10);
        assert_eq!(options.threshold, Some(0.8));
        assert!(options.include_vector);
        assert!(options.metadata_filters.is_some());
    }

    #[test]
    fn test_embedding_model_provider() {
        assert_eq!(EmbeddingModel::OpenAISmall.provider(), "openai");
        assert_eq!(EmbeddingModel::CohereEnglishV3.provider(), "cohere");
        assert_eq!(EmbeddingModel::VoyageAI2.provider(), "voyage");
        assert_eq!(EmbeddingModel::Custom(512).provider(), "custom");
    }

    /// Placeholders are numbered consecutively; unary filters consume no number.
    #[test]
    fn test_metadata_filters_sql_and_param_numbering() {
        let filters = MetadataFilters::new()
            .add(MetadataFilter::eq("type", "doc"))
            .add(MetadataFilter::is_null("deleted"))
            .add(MetadataFilter::gt("size", 10i64));
        let (sql, params) = filters.to_sql_with_params(3);
        assert_eq!(
            sql,
            "(metadata->>'type' = $3 AND metadata->>'deleted' IS NULL AND metadata->>'size' > $4)"
        );
        assert_eq!(params.len(), 2);
        assert!(matches!(&params[0], FilterValue::String(s) if s == "doc"));
        assert!(matches!(params[1], FilterValue::Int(10)));
    }

    #[test]
    fn test_metadata_filters_or_single_and_empty() {
        let (sql, _) = MetadataFilters::or()
            .add(MetadataFilter::eq("a", 1i64))
            .add(MetadataFilter::ne("b", true))
            .to_sql_with_params(1);
        assert_eq!(sql, "(metadata->>'a' = $1 OR metadata->>'b' != $2)");

        let (sql, params) = MetadataFilters::new()
            .add(MetadataFilter::lt("n", 5i64))
            .to_sql_with_params(1);
        assert_eq!(sql, "metadata->>'n' < $1");
        assert_eq!(params.len(), 1);

        let empty = MetadataFilters::new();
        assert!(empty.is_empty());
        let (sql, params) = empty.to_sql_with_params(1);
        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    /// Field names are embedded in SQL, so quotes and separators must be stripped.
    #[test]
    fn test_metadata_filter_field_sanitization() {
        let (sql, _) = MetadataFilter::eq("bad'; DROP TABLE x; --", "v").to_sql_with_param(1);
        assert_eq!(sql, "metadata->>'badDROPTABLEx--' = $1");

        let (sql, _) = MetadataFilter::eq("9lives", "v").to_sql_with_param(1);
        assert_eq!(sql, "metadata->>'invalid_field' = $1");

        let (sql, param) = MetadataFilter::is_not_null("_tag").to_sql_with_param(1);
        assert_eq!(sql, "metadata->>'_tag' IS NOT NULL");
        assert!(param.is_none());
    }
}