use async_trait::async_trait;
use crate::error::ProviderError;
/// Abstraction over a backend that turns text into `f32` vectors.
///
/// Implementors must be `Send + Sync`. Only [`embed`](EmbeddingProvider::embed)
/// is required; the other methods have default implementations.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Produces the vector for a single piece of `text`.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the implementation fails to produce a vector.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, ProviderError>;

    /// Produces one vector per entry of `texts`, in the same order.
    ///
    /// The default implementation awaits [`embed`](EmbeddingProvider::embed) for
    /// each text one after another (no concurrency) and stops at the first failure.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `embed`; vectors computed before the
    /// failure are discarded.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts.iter() {
            embeddings.push(self.embed(text.as_str()).await?);
        }
        Ok(embeddings)
    }

    /// Reports the vector dimension, if known.
    ///
    /// The default implementation returns `None`.
    // NOTE(review): presumably this is the length of the vectors returned by
    // `embed` — confirm against implementors.
    fn embedding_dim(&self) -> Option<usize> {
        None
    }
}
/// Computes the cosine similarity of two equally sized vectors.
///
/// The result is the dot product of `a` and `b` divided by the product of their
/// Euclidean norms. If either norm is zero (which includes two empty slices),
/// the function returns `Ok(0.0)` instead of dividing by zero.
///
/// # Errors
///
/// Returns `ProviderError::Message` when `a` and `b` have different lengths.
/// The message text is in Chinese ("vector dimension mismatch") followed by
/// both lengths.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, ProviderError> {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        let message = format!("向量维度不匹配: {} vs {}", len_a, len_b);
        return Err(ProviderError::Message(message));
    }

    // Euclidean (L2) length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    let dot: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    // A zero-length vector has no direction, so the similarity is defined as 0.
    let similarity = if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    };
    Ok(similarity)
}
/// Ranks `candidates` by cosine similarity to `query` and returns the best `top_k`.
///
/// Every candidate is scored with [`cosine_similarity`]. The returned pairs are
/// `(index into candidates, similarity)`, ordered from most to least similar.
/// Candidates with equal scores keep their original relative order (the sort is
/// stable). A score of NaN (possible when an input contains NaN or the
/// computation overflows) is ranked after every real score. At most `top_k`
/// entries are returned; `top_k == 0` yields an empty vector.
///
/// # Errors
///
/// Returns the error from [`cosine_similarity`] as soon as a candidate whose
/// length differs from `query.len()` is encountered. Such candidates are not
/// silently scored as `0.0`, because that would rank a malformed vector above
/// valid candidates with negative similarity.
pub fn vector_search(
    query: &[f32],
    candidates: &[Vec<f32>],
    top_k: usize,
) -> Result<Vec<(usize, f32)>, ProviderError> {
    // Collecting into `Result` short-circuits on the first scoring failure.
    let mut scores = candidates
        .iter()
        .enumerate()
        .map(|(idx, candidate)| cosine_similarity(query, candidate).map(|sim| (idx, sim)))
        .collect::<Result<Vec<(usize, f32)>, ProviderError>>()?;

    // `sort_by` requires a total order. NaN breaks `partial_cmp`, so NaN scores
    // are ordered explicitly (all equal to each other, after every real score).
    scores.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        // Descending by similarity; `partial_cmp` cannot fail for non-NaN values.
        (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
    });
    scores.truncate(top_k);
    Ok(scores)
}