use anyhow::Result;
use async_openai::{
types::{CreateEmbeddingRequestArgs, EmbeddingInput},
Client,
};
use moka::future::Cache;
use std::sync::Arc;
/// In-memory, bounded cache mapping text keys to embedding vectors.
///
/// Vectors are stored behind `Arc` so lookups can hand out cheap clones of
/// the handle without deep-copying the underlying `Vec<f32>` inside the cache.
pub struct EmbeddingCache {
    // moka async cache, bounded by entry count (set in `new`).
    cache: Cache<String, Arc<Vec<f32>>>,
}
impl EmbeddingCache {
    /// Build a cache that holds at most `max_size` entries.
    pub fn new(max_size: usize) -> Self {
        let cache = Cache::new(max_size as u64);
        Self { cache }
    }

    /// Look up `key`, returning an owned copy of the cached embedding.
    ///
    /// Returns `None` on a cache miss.
    pub async fn get(&self, key: &str) -> Option<Vec<f32>> {
        let hit = self.cache.get(key).await?;
        Some(hit.as_ref().clone())
    }

    /// Insert `embedding` under `key`, wrapping it in an `Arc` for storage.
    pub async fn set(&self, key: String, embedding: Vec<f32>) {
        let value = Arc::new(embedding);
        self.cache.insert(key, value).await;
    }
}
/// Request an embedding for `text` from the OpenAI embeddings endpoint.
///
/// # Errors
/// Fails if the request cannot be built, the API call itself fails, or the
/// response contains no embedding data.
pub async fn generate_embedding(
    client: &Client<async_openai::config::OpenAIConfig>,
    text: &str,
    model: &str,
) -> Result<Vec<f32>> {
    let request = CreateEmbeddingRequestArgs::default()
        .model(model)
        .input(EmbeddingInput::String(text.to_string()))
        .build()?;
    let response = client.embeddings().create(request).await?;
    // Consume the response and move the first embedding out, rather than
    // borrowing via `first()` and deep-cloning the whole vector.
    response
        .data
        .into_iter()
        .next()
        .map(|item| item.embedding)
        .ok_or_else(|| anyhow::anyhow!("No embedding returned"))
}
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices have different lengths or when either
/// vector has zero norm (including empty inputs).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass: accumulate the dot product and both squared norms
    // together. Accumulation order per sum matches the original, so the
    // floating-point result is identical.
    let (mut dot, mut sq_a, mut sq_b) = (0.0_f32, 0.0_f32, 0.0_f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity checks for `cosine_similarity`: identical unit vectors score
    /// 1.0 and orthogonal vectors score 0.0, within float tolerance.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
    }

    /// Round-trips an embedding through the cache: set, then get, expecting
    /// the identical vector back.
    #[tokio::test]
    async fn test_embedding_cache() {
        let cache = EmbeddingCache::new(10);
        let embedding = vec![1.0, 2.0, 3.0];
        cache.set("test".to_string(), embedding.clone()).await;
        // NOTE(review): this sleep presumably exists because moka may apply
        // writes asynchronously, making an immediate `get` racy. A fixed
        // 10 ms delay is a timing-based workaround; `Cache::run_pending_tasks`
        // would likely be deterministic — confirm against moka docs.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let retrieved = cache.get("test").await;
        assert_eq!(retrieved, Some(embedding));
    }
}