use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::ProviderError;
/// A single embedding entry written to, or read back from, a [`VectorStore`].
///
/// `text` and `metadata` are optional payloads carried alongside the embedding;
/// when they are `None` they are omitted from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorRecord {
/// Identifier of the record.
pub id: String,
/// Embedding values.
///
/// [`VectorRecord::new`] rejects an empty vector, but the derived `Deserialize`
/// does not go through `new`, so a deserialized record may still carry an empty
/// vector. NOTE(review): validate after deserializing if that matters.
pub vector: Vec<f32>,
/// Optional source text associated with the embedding (skipped when `None`).
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Optional arbitrary JSON metadata (skipped when `None`); its schema is not
/// defined here and is presumably backend/application specific.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
impl VectorRecord {
    /// Builds a record from an id and an embedding, with no text and no metadata.
    ///
    /// # Panics
    ///
    /// Panics if `vector` is empty.
    pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
        assert!(!vector.is_empty(), "VectorRecord vector must not be empty");
        let id: String = id.into();
        Self {
            id,
            vector,
            text: None,
            metadata: None,
        }
    }

    /// Returns the record with its `text` payload set, replacing any previous text.
    pub fn with_text(self, text: impl Into<String>) -> Self {
        Self {
            text: Some(text.into()),
            ..self
        }
    }

    /// Returns the record with its `metadata` payload set, replacing any previous metadata.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }
}
/// One hit returned by [`VectorStore::search`].
///
/// Optional fields are omitted from the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Identifier of the matched record.
pub id: String,
/// Match score for this hit. NOTE(review): the scale and direction (similarity
/// where higher is better vs. distance where lower is better) depend on the
/// `VectorStore` implementation — confirm per backend.
pub score: f32,
/// Stored embedding of the match, presumably only populated when the backend
/// returns vectors with results.
#[serde(skip_serializing_if = "Option::is_none")]
pub vector: Option<Vec<f32>>,
/// Text stored with the matched record, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// JSON metadata stored with the matched record, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
/// Parameters for a [`VectorStore::search`] call.
///
/// Build one with [`VectorQuery::new`] and the `with_*` builder methods.
/// [`Default`] produces an empty query vector with the same `top_k` as `new`
/// ([`VectorQuery::DEFAULT_TOP_K`]) and no threshold or filter.
#[derive(Debug, Clone)]
pub struct VectorQuery {
    /// Query embedding to search with.
    pub vector: Vec<f32>,
    /// Maximum number of results to return.
    pub top_k: usize,
    /// Optional score cutoff. How it is applied (minimum similarity vs. maximum
    /// distance) is up to the `VectorStore` implementation.
    pub score_threshold: Option<f32>,
    /// Optional backend-specific filter expression; its schema is not defined here.
    pub filter: Option<serde_json::Value>,
}

impl Default for VectorQuery {
    /// Empty query vector, [`VectorQuery::DEFAULT_TOP_K`] results, no threshold or filter.
    fn default() -> Self {
        // Delegate to `new` so the default `top_k` is never 0: a derived
        // `Default` would yield `top_k == 0`, which `with_top_k` rejects and
        // which would make every search return nothing.
        Self::new(Vec::new())
    }
}

impl VectorQuery {
    /// Number of results requested when `top_k` is not set explicitly.
    pub const DEFAULT_TOP_K: usize = 10;
    /// Largest `top_k` accepted by [`VectorQuery::with_top_k`].
    pub const MAX_TOP_K: usize = 1000;

    /// Creates a query for `vector` returning up to [`Self::DEFAULT_TOP_K`]
    /// results, with no score threshold and no filter.
    #[must_use]
    pub fn new(vector: Vec<f32>) -> Self {
        Self {
            vector,
            top_k: Self::DEFAULT_TOP_K,
            score_threshold: None,
            filter: None,
        }
    }

    /// Sets the maximum number of results to return.
    ///
    /// # Panics
    ///
    /// Panics if `top_k` is 0 or greater than [`Self::MAX_TOP_K`].
    #[must_use]
    pub fn with_top_k(mut self, top_k: usize) -> Self {
        assert!(top_k > 0, "top_k must be greater than 0");
        // Message is derived from the constant so it cannot drift from the check.
        assert!(
            top_k <= Self::MAX_TOP_K,
            "top_k must not exceed {}",
            Self::MAX_TOP_K
        );
        self.top_k = top_k;
        self
    }

    /// Sets the score cutoff passed to the store.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is NaN. Every comparison with NaN is false, so a NaN
    /// cutoff would silently filter out all results (or none, depending on how
    /// the backend compares); this is always a caller bug.
    #[must_use]
    pub fn with_score_threshold(mut self, threshold: f32) -> Self {
        assert!(!threshold.is_nan(), "score_threshold must not be NaN");
        self.score_threshold = Some(threshold);
        self
    }

    /// Sets a backend-specific filter expression, replacing any previous filter.
    #[must_use]
    pub fn with_filter(mut self, filter: serde_json::Value) -> Self {
        self.filter = Some(filter);
        self
    }
}
/// Backend-agnostic interface to a vector store.
///
/// Methods are async (via `async_trait`) and report failures as [`ProviderError`].
/// The `Send + Sync` bound lets a store be shared across threads/tasks.
#[async_trait]
pub trait VectorStore: Send + Sync {
/// Upserts (insert-or-update) the given records. Presumably records are matched
/// by [`VectorRecord::id`] — confirm per implementation.
async fn upsert(&self, records: Vec<VectorRecord>) -> Result<(), ProviderError>;
/// Deletes the records with the given ids. NOTE(review): whether unknown ids are
/// ignored or reported as an error is implementation-defined here.
async fn delete(&self, ids: Vec<String>) -> Result<(), ProviderError>;
/// Fetches the records with the given ids. NOTE(review): handling of missing ids
/// and the order of the returned records are not specified by this trait.
async fn get(&self, ids: Vec<String>) -> Result<Vec<VectorRecord>, ProviderError>;
/// Searches the store using `query` (vector, `top_k`, optional score threshold
/// and filter) and returns the resulting hits.
async fn search(&self, query: VectorQuery) -> Result<Vec<SearchResult>, ProviderError>;
/// Clears the store — presumably removes every record; confirm per implementation.
async fn clear(&self) -> Result<(), ProviderError>;
/// Returns the number of records currently held by the store.
async fn count(&self) -> Result<usize, ProviderError>;
}