openserve 2.0.3

A modern, high-performance, AI-enhanced file server built in Rust
Documentation
//! Embeddings utilities for semantic search

use anyhow::{Context, Result};
use async_openai::{
    types::{CreateEmbeddingRequestArgs, EmbeddingInput},
    Client,
};
use moka::future::Cache;
use std::sync::Arc;

/// A high-performance, concurrent cache for storing vector embeddings.
///
/// Embeddings are stored behind an [`Arc`], so [`EmbeddingCache::get_shared`]
/// can return a cached vector without copying it. Capacity is bounded by the
/// entry count passed to [`EmbeddingCache::new`].
pub struct EmbeddingCache {
    /// Maps a caller-chosen key to its shared embedding vector.
    cache: Cache<String, Arc<Vec<f32>>>,
}

impl EmbeddingCache {
    /// Creates a new `EmbeddingCache` that holds at most `max_size` embeddings.
    ///
    /// # Arguments
    ///
    /// * `max_size` - The maximum number of embeddings to store in the cache.
    pub fn new(max_size: usize) -> Self {
        // `usize` is at most 64 bits on every supported target, so this never
        // saturates in practice; `try_from` just avoids a silent `as` cast.
        let max_capacity = u64::try_from(max_size).unwrap_or(u64::MAX);
        Self {
            cache: Cache::new(max_capacity),
        }
    }

    /// Retrieves a shared handle to a cached embedding without copying it.
    ///
    /// Cloning the returned [`Arc`] only bumps a reference count, so this is the
    /// cheap way to read an embedding that may be large.
    ///
    /// # Arguments
    ///
    /// * `key` - The key associated with the embedding.
    pub async fn get_shared(&self, key: &str) -> Option<Arc<Vec<f32>>> {
        self.cache.get(key).await
    }

    /// Retrieves an owned copy of a cached embedding.
    ///
    /// This copies the whole vector; use [`EmbeddingCache::get_shared`] when a
    /// borrowed view is enough.
    ///
    /// # Arguments
    ///
    /// * `key` - The key associated with the embedding.
    pub async fn get(&self, key: &str) -> Option<Vec<f32>> {
        self.get_shared(key)
            .await
            .map(|shared| shared.as_ref().clone())
    }

    /// Inserts an embedding into the cache under `key`.
    ///
    /// # Arguments
    ///
    /// * `key` - The key to associate with the embedding.
    /// * `embedding` - The vector embedding to store.
    pub async fn set(&self, key: String, embedding: Vec<f32>) {
        self.cache.insert(key, Arc::new(embedding)).await;
    }
}

/// Generates an embedding vector for `text` using the embedding `model`.
///
/// Sends a single-input embedding request through `client` and returns the
/// first embedding contained in the response.
///
/// # Errors
///
/// Returns an error if the request cannot be built, if the API call fails, or
/// if the response contains no embeddings ("No embedding returned").
pub async fn generate_embedding(
    client: &Client<async_openai::config::OpenAIConfig>,
    text: &str,
    model: &str,
) -> Result<Vec<f32>> {
    let request = CreateEmbeddingRequestArgs::default()
        .model(model)
        .input(EmbeddingInput::String(text.to_string()))
        .build()
        .context("failed to build embedding request")?;

    let response = client
        .embeddings()
        .create(request)
        .await
        .with_context(|| format!("embedding request for model `{model}` failed"))?;

    // The response is owned, so move the vector out instead of cloning it.
    response
        .data
        .into_iter()
        .next()
        .map(|item| item.embedding)
        .context("No embedding returned")
}

/// Calculates the cosine similarity between two embeddings.
///
/// Returns a value in `[-1.0, 1.0]`: `1.0` for vectors pointing the same way,
/// `0.0` for orthogonal vectors, `-1.0` for opposite vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (including empty slices), because the similarity is
/// undefined in those cases.
///
/// Sums are accumulated in `f64` in a single pass. Squaring large `f32`
/// components would otherwise overflow to infinity (producing NaN), and
/// squaring tiny ones would underflow to zero.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Take the roots separately so the product of the squared norms cannot overflow.
    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // Rounding can land a hair outside the mathematical range; clamp it back.
    (similarity as f32).clamp(-1.0, 1.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors point in the same direction.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        // Orthogonal vectors.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);

        // Opposite vectors.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        // Mismatched lengths, zero-magnitude and empty inputs all yield 0.0.
        assert_eq!(cosine_similarity(&[1.0, 2.0], &[1.0, 2.0, 3.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    #[tokio::test]
    async fn test_embedding_cache() {
        let cache = EmbeddingCache::new(10);
        let embedding = vec![1.0, 2.0, 3.0];

        assert_eq!(cache.get("test").await, None);

        cache.set("test".to_string(), embedding.clone()).await;
        // moka makes an inserted entry readable as soon as `insert` returns,
        // so no delay is needed before reading it back.
        assert_eq!(cache.get("test").await, Some(embedding));
    }
}