use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::documents::Document;
use crate::error::Result;
use crate::retrievers::BaseRetriever;
/// Search strategy a [`VectorStoreRetriever`] applies when answering a query.
///
/// `PartialEq` is derived so configurations can be compared with `==` and
/// `assert_eq!`; `Eq` is intentionally absent because the variants carry
/// `f32` fields.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum SearchType {
    /// Plain similarity search. This is the `Default` variant.
    #[default]
    Similarity,
    /// Scored similarity search; the retriever keeps only results whose score
    /// is greater than or equal to `score_threshold`.
    SimilarityScoreThreshold {
        /// Inclusive lower bound applied to each result's score.
        score_threshold: f32,
    },
    /// Maximal-marginal-relevance search.
    Mmr {
        /// Forwarded unchanged as the `fetch_k` argument of
        /// `VectorStore::max_marginal_relevance_search`.
        fetch_k: usize,
        /// Forwarded unchanged as the `lambda_mult` argument of
        /// `VectorStore::max_marginal_relevance_search`.
        lambda_mult: f32,
    },
}
/// Asynchronous storage-and-search interface that [`VectorStoreRetriever`] delegates to.
///
/// Implementations must be `Send + Sync`; the retriever in this file holds one
/// as `Arc<dyn VectorStore>`. Every method reports failure through the
/// crate-level `Result`.
#[async_trait]
pub trait VectorStore: Send + Sync {
/// Adds `texts` to the store, with optional `metadatas` and optional
/// caller-supplied `ids`.
///
/// NOTE(review): presumably `metadatas[i]` and `ids[i]` pair with `texts[i]`,
/// and the returned strings are the ids of the stored entries — confirm
/// against the implementations.
async fn add_texts(
&self,
texts: &[String],
metadatas: Option<&[HashMap<String, Value>]>,
ids: Option<&[String]>,
) -> Result<Vec<String>>;
/// Adds already-built `documents` to the store, taking ownership of them,
/// with optional caller-supplied `ids`.
///
/// NOTE(review): presumably returns the ids of the stored documents — confirm.
async fn add_documents(
&self,
documents: Vec<Document>,
ids: Option<Vec<String>>,
) -> Result<Vec<String>>;
/// Deletes entries identified by `ids`.
///
/// NOTE(review): the meaning of `ids == None` and of the returned `bool` is
/// not fixed by this trait — confirm with each implementation.
async fn delete(&self, ids: Option<&[String]>) -> Result<bool>;
/// Fetches the documents stored under `ids`.
///
/// NOTE(review): result ordering and the handling of unknown ids are not
/// specified here — confirm with each implementation.
async fn get_by_ids(&self, ids: &[String]) -> Result<Vec<Document>>;
/// Searches for documents similar to the text `query`; `k` is the requested
/// number of results (presumably an upper bound — confirm).
async fn similarity_search(&self, query: &str, k: usize) -> Result<Vec<Document>>;
/// Same inputs as [`VectorStore::similarity_search`], but each document is
/// returned together with an `f32` score.
///
/// NOTE(review): the trait does not say whether a larger score means "more
/// similar" or "more distant". `VectorStoreRetriever` keeps results with
/// `score >= threshold`, which only makes sense if larger is better.
async fn similarity_search_with_score(
&self,
query: &str,
k: usize,
) -> Result<Vec<(Document, f32)>>;
/// Searches using an already-computed `embedding` vector instead of a text
/// query; `k` is the requested number of results.
async fn similarity_search_by_vector(
&self,
embedding: &[f32],
k: usize,
) -> Result<Vec<Document>>;
/// Maximal-marginal-relevance search for `query`.
///
/// NOTE(review): presumably `fetch_k` candidates are fetched, `k` of them are
/// selected, and `lambda_mult` trades relevance against diversity — the exact
/// contract is implementation-defined; confirm.
async fn max_marginal_relevance_search(
&self,
query: &str,
k: usize,
fetch_k: usize,
lambda_mult: f32,
) -> Result<Vec<Document>>;
}
/// Retriever that answers queries by delegating to a shared [`VectorStore`].
pub struct VectorStoreRetriever {
/// Backing store that every search call is forwarded to.
vectorstore: Arc<dyn VectorStore>,
/// Selects which `VectorStore` search method `get_relevant_documents` calls.
search_type: SearchType,
/// Passed as the `k` argument of whichever search method is called.
k: usize,
}
impl VectorStoreRetriever {
    /// Value of `k` used by [`VectorStoreRetriever::from_vectorstore`].
    const DEFAULT_K: usize = 4;

    /// Creates a retriever over `vectorstore` that uses `search_type` and
    /// passes `k` to every search call.
    pub fn new(vectorstore: Arc<dyn VectorStore>, search_type: SearchType, k: usize) -> Self {
        Self { vectorstore, search_type, k }
    }

    /// Creates a retriever with the default configuration: plain similarity
    /// search and `k = 4`.
    pub fn from_vectorstore(vectorstore: Arc<dyn VectorStore>) -> Self {
        Self::new(vectorstore, SearchType::Similarity, Self::DEFAULT_K)
    }
}
#[async_trait]
impl BaseRetriever for VectorStoreRetriever {
    /// Returns the documents the backing store considers relevant to `query`.
    ///
    /// Dispatches on the configured [`SearchType`]:
    /// - `Similarity` forwards to `similarity_search`.
    /// - `SimilarityScoreThreshold` runs `similarity_search_with_score` and
    ///   keeps only documents whose score is `>= score_threshold`, so fewer
    ///   than `k` documents (possibly none) may come back.
    /// - `Mmr` forwards to `max_marginal_relevance_search`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying store call.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        let store = &self.vectorstore;
        let limit = self.k;

        // Every payload field is `Copy`, so the variants can be matched by value.
        match self.search_type {
            SearchType::Similarity => store.similarity_search(query, limit).await,
            SearchType::SimilarityScoreThreshold { score_threshold } => {
                let scored = store.similarity_search_with_score(query, limit).await?;
                let kept: Vec<Document> = scored
                    .into_iter()
                    .filter_map(|(doc, score)| (score >= score_threshold).then_some(doc))
                    .collect();
                Ok(kept)
            }
            SearchType::Mmr { fetch_k, lambda_mult } => {
                store
                    .max_marginal_relevance_search(query, limit, fetch_k, lambda_mult)
                    .await
            }
        }
    }
}
/// Computes the cosine similarity between `a` and `b`.
///
/// Returns a value in `[-1.0, 1.0]`. The similarity is undefined, and `0.0`
/// is returned, when:
/// - the slices have different lengths (previously the dot product was
///   silently truncated to the shorter slice while the norms used the full
///   slices, giving a meaningless value), or
/// - either vector has zero magnitude (this includes empty slices).
///
/// Sums are accumulated in `f64` so that large components do not overflow
/// `f32` (which used to yield `inf / inf = NaN`) and so that long vectors lose
/// less precision. NaN components still propagate to a NaN result.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms together.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio marginally past +/-1; clamp so callers that
    // derive a distance as `1.0 - similarity` never see a negative distance.
    // The clamped value is within [-1, 1], so narrowing to f32 only rounds.
    (dot / (sq_a.sqrt() * sq_b.sqrt())).clamp(-1.0, 1.0) as f32
}
/// Maps a Euclidean `distance` to a relevance score via `1 - distance / sqrt(2)`.
///
/// A distance of `0` yields `1.0` and a distance of `sqrt(2)` yields `0.0`;
/// larger distances yield negative scores (no clamping is applied).
pub fn euclidean_relevance_score(distance: f32) -> f32 {
    let scaled = distance / std::f32::consts::SQRT_2;
    1.0 - scaled
}
/// Maps a cosine `distance` to a relevance score as `1 - distance`.
///
/// No clamping is applied: distances above `1.0` yield negative scores.
pub fn cosine_relevance_score(distance: f32) -> f32 {
    /// Score assigned to a distance of zero.
    const MAX_SCORE: f32 = 1.0;
    MAX_SCORE - distance
}
/// Maps a max-inner-product `distance` to a relevance score.
///
/// Strictly positive distances map to `1 - distance`; every other input
/// (zero, negative, NaN) maps to `-distance`.
pub fn max_inner_product_relevance_score(distance: f32) -> f32 {
    // Keep the `> 0.0` test as written so that NaN falls through to negation.
    let is_positive = distance > 0.0;
    if is_positive {
        return 1.0 - distance;
    }
    -distance
}