mod openai;
mod mock;
pub use openai::{OpenAIEmbeddings, OpenAIEmbeddingsConfig};
pub use mock::MockEmbeddings;
use async_trait::async_trait;
use std::error::Error;
/// Errors that can occur while producing embeddings.
///
/// The `Display` output is in Chinese; variants that carry a `String` append it
/// as the detail message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingError {
    /// An HTTP-level failure, with a detail message.
    HttpError(String),
    /// An error reported by the embedding API, with a detail message.
    ApiError(String),
    /// A response or payload could not be parsed, with a detail message.
    ParseError(String),
    /// The input to embed was empty.
    EmptyInput,
}

impl std::fmt::Display for EmbeddingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::HttpError(msg) => write!(f, "HTTP 错误: {}", msg),
            Self::ApiError(msg) => write!(f, "API 错误: {}", msg),
            Self::ParseError(msg) => write!(f, "解析错误: {}", msg),
            Self::EmptyInput => f.write_str("输入为空"),
        }
    }
}

// No underlying cause is stored, so the default `source()` (None) is correct.
impl Error for EmbeddingError {}
/// A provider that turns text into embedding vectors of `f32`.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Embeddings: Send + Sync {
    /// Embeds a single piece of text into one vector.
    ///
    /// # Errors
    /// Returns an [`EmbeddingError`] if the provider fails to produce an embedding.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;

    /// Embeds a batch of texts, returning one vector per input in input order.
    ///
    /// The default implementation calls [`Embeddings::embed_query`] once per
    /// text, sequentially. Providers with a native batch endpoint may override it.
    ///
    /// # Errors
    /// Stops at and returns the first error from `embed_query`; vectors computed
    /// before the failure are discarded.
    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        // Exactly one vector is produced per input, so reserve the full size up front.
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed_query(text).await?);
        }
        Ok(embeddings)
    }

    /// Number of components each vector from this provider is expected to have.
    fn dimension(&self) -> usize;

    /// Name of the underlying embedding model.
    fn model_name(&self) -> &str;
}
/// Computes the cosine similarity of two vectors.
///
/// Returns a value clamped to `[-1.0, 1.0]`. Returns `0.0` when the slices have
/// different lengths, are empty, or either one has zero magnitude (the
/// similarity is undefined in those cases). NaN components yield a NaN result.
///
/// Sums are accumulated in `f64`, for two reasons:
/// - Squaring an `f32` component with |x| > ~1.8e19 overflows `f32` to infinity,
///   which would turn the result into NaN.
/// - Long vectors lose noticeable precision when accumulated in `f32`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms together.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    // A zero squared norm means a zero vector (also covers empty input).
    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (sq_a.sqrt() * sq_b.sqrt());
    // Rounding can push the ratio a hair outside [-1, 1]; clamp so callers
    // (e.g. `acos`) get a value in the mathematically valid range.
    similarity.clamp(-1.0, 1.0) as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing floating-point similarities.
    const EPS: f32 = 1e-4;

    /// True when `actual` is within `EPS` of `expected`.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_cosine_similarity() {
        let x: [f32; 3] = [1.0, 0.0, 0.0];
        // Identical, orthogonal, and opposite directions.
        assert!(approx_eq(cosine_similarity(&x, &[1.0, 0.0, 0.0]), 1.0));
        assert!(approx_eq(cosine_similarity(&x, &[0.0, 1.0, 0.0]), 0.0));
        assert!(approx_eq(cosine_similarity(&x, &[-1.0, 0.0, 0.0]), -1.0));
    }

    #[test]
    fn test_cosine_similarity_is_scale_invariant() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [2.0, 4.0, 6.0];
        assert!(approx_eq(cosine_similarity(&a, &b), 1.0));
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a: [f32; 2] = [1.0, 0.0];
        let b: [f32; 3] = [1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let zero: [f32; 3] = [0.0, 0.0, 0.0];
        let v: [f32; 3] = [1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&zero, &v), 0.0);
        assert_eq!(cosine_similarity(&v, &zero), 0.0);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
    }
}