// openmemory 0.1.1
//
// OpenMemory - Cognitive memory system for AI applications
// Documentation
//! Embedding providers for generating vector representations
//!
//! This module provides a trait-based abstraction for different embedding providers,
//! including synthetic (local), OpenAI, Gemini, Ollama, and AWS Bedrock.

pub mod gemini;
pub mod openai;
pub mod ollama;
pub mod synthetic;

pub mod bedrock;

use crate::core::config::Config;
use crate::core::error::Result;
use crate::core::types::{EmbeddingKind, EmbeddingResult, Sector};
use async_trait::async_trait;

/// Trait for embedding providers
///
/// Implementations generate vector embeddings from text for semantic search.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for a single text
    ///
    /// # Errors
    /// Returns an error if the underlying provider fails to produce an
    /// embedding for `text`.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult>;

    /// Generate embeddings for multiple texts (batch)
    ///
    /// Default implementation calls embed() sequentially. Providers with a
    /// native batch API should override this together with `supports_batch`.
    ///
    /// # Errors
    /// Fails fast: the first `embed` error aborts the batch and is returned;
    /// earlier results are discarded.
    async fn embed_batch(
        &self,
        texts: &[(&str, &Sector)],
    ) -> Result<Vec<EmbeddingResult>> {
        // Pre-size the output; one embed() round-trip per input pair.
        let mut results = Vec::with_capacity(texts.len());
        for (text, sector) in texts {
            results.push(self.embed(text, sector).await?);
        }
        Ok(results)
    }

    /// Get the vector dimensions for this provider
    fn dimensions(&self) -> usize;

    /// Get the provider name
    fn name(&self) -> &'static str;

    /// Check if the provider supports batch operations
    ///
    /// Defaults to `false` (sequential `embed_batch`); override alongside a
    /// native `embed_batch` implementation.
    fn supports_batch(&self) -> bool {
        false
    }
}

/// Create an embedding provider based on configuration
/// Create an embedding provider based on configuration
///
/// Dispatches on `config.embedding_kind` and returns the matching boxed
/// provider. The synthetic provider only needs the target vector dimension;
/// the remote providers receive the full `Config` (their constructors read
/// whatever credentials/endpoints they need from it).
///
/// # Panics
/// Panics if `EmbeddingKind::Bedrock` is requested but the crate was built
/// without the `aws` feature enabled.
pub fn create_provider(config: &Config) -> Box<dyn EmbeddingProvider> {
    match config.embedding_kind {
        EmbeddingKind::Synthetic => {
            Box::new(synthetic::SyntheticProvider::new(config.vec_dim))
        }
        EmbeddingKind::OpenAI => {
            Box::new(openai::OpenAIProvider::new(config))
        }
        EmbeddingKind::Ollama => {
            Box::new(ollama::OllamaProvider::new(config))
        }
        EmbeddingKind::Gemini => {
            Box::new(gemini::GeminiProvider::new(config))
        }
        // Bedrock is only compiled in when the `aws` feature is on.
        #[cfg(feature = "aws")]
        EmbeddingKind::Bedrock => {
            Box::new(bedrock::BedrockProvider::new(config))
        }
        // Without the feature, the variant still exists in the enum, so we
        // must handle it — but there is no provider to construct.
        #[cfg(not(feature = "aws"))]
        EmbeddingKind::Bedrock => {
            panic!("AWS Bedrock support requires the 'aws' feature to be enabled");
        }
    }
}

/// Compress a vector to target dimension by averaging bins
/// Compress a vector to target dimension by averaging bins
///
/// Partitions `v` into `target_dim` roughly equal bins and emits the mean of
/// each bin, then L2-normalizes the result. Inputs that already fit
/// (`v.len() <= target_dim`) are returned as an unmodified copy, without
/// normalization.
pub fn compress_vector(v: &[f32], target_dim: usize) -> Vec<f32> {
    // Nothing to compress — hand back an untouched copy.
    if v.len() <= target_dim {
        return v.to_vec();
    }

    // Fractional bin width; v.len() > target_dim guarantees it is >= 1.
    let bin_width = v.len() as f32 / target_dim as f32;

    let mut compressed: Vec<f32> = (0..target_dim)
        .map(|i| {
            let lo = (i as f32 * bin_width) as usize;
            let hi = (((i + 1) as f32 * bin_width) as usize).min(v.len());
            let bin = &v[lo..hi];
            // An empty bin (possible via float rounding) contributes zero.
            if bin.is_empty() {
                0.0
            } else {
                bin.iter().sum::<f32>() / bin.len() as f32
            }
        })
        .collect();

    // L2 normalize; skip near-zero vectors to avoid dividing by ~0.
    let norm = compressed.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in compressed.iter_mut() {
            *x /= norm;
        }
    }

    compressed
}

/// Resize vector to target dimension (truncate or pad with zeros)
/// Resize vector to target dimension (truncate or pad with zeros)
///
/// Returns an owned vector of exactly `target_dim` elements: trailing extra
/// elements are dropped, missing ones are filled with `0.0`.
pub fn resize_vector(v: &[f32], target_dim: usize) -> Vec<f32> {
    // Copy at most `target_dim` elements, then zero-pad up to the target.
    // This single path covers truncation, padding, and the equal-length case.
    let keep = v.len().min(target_dim);
    let mut resized = v[..keep].to_vec();
    resized.resize(target_dim, 0.0);
    resized
}

/// Fuse synthetic and semantic vectors
///
/// Combines vectors with weights (0.6 synthetic, 0.4 semantic) and normalizes.
/// Fuse synthetic and semantic vectors
///
/// Concatenates the synthetic vector scaled by 0.6 with the semantic vector
/// scaled by 0.4, then L2-normalizes the combined vector. The output length
/// is `synthetic.len() + semantic.len()`.
pub fn fuse_vectors(synthetic: &[f32], semantic: &[f32]) -> Vec<f32> {
    // Weighted concatenation: synthetic components first, semantic after.
    let mut fused: Vec<f32> = synthetic
        .iter()
        .map(|&x| x * 0.6)
        .chain(semantic.iter().map(|&x| x * 0.4))
        .collect();

    // L2 normalize; skip near-zero vectors to avoid dividing by ~0.
    let norm = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in fused.iter_mut() {
            *x /= norm;
        }
    }

    fused
}

#[cfg(test)]
mod tests {
    use super::*;

    // compress_vector: output has the requested dimension and unit L2 norm.
    #[test]
    fn test_compress_vector() {
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let compressed = compress_vector(&v, 4);
        assert_eq!(compressed.len(), 4);

        // Check it's normalized
        let norm: f32 = compressed.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // resize_vector: longer input is truncated to the target length.
    #[test]
    fn test_resize_vector_truncate() {
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let resized = resize_vector(&v, 2);
        assert_eq!(resized, vec![1.0, 2.0]);
    }

    // resize_vector: shorter input is zero-padded to the target length.
    #[test]
    fn test_resize_vector_pad() {
        let v = vec![1.0, 2.0];
        let resized = resize_vector(&v, 4);
        assert_eq!(resized, vec![1.0, 2.0, 0.0, 0.0]);
    }

    // fuse_vectors: output length is the sum of inputs and is L2-normalized.
    #[test]
    fn test_fuse_vectors() {
        let syn = vec![1.0, 0.0];
        let sem = vec![0.0, 1.0];
        let fused = fuse_vectors(&syn, &sem);

        assert_eq!(fused.len(), 4);

        // Check normalized
        let norm: f32 = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }
}