//! openmemory 0.1.1
//!
//! OpenMemory - Cognitive memory system for AI applications
//!
//! Synthetic (local) embedding provider
//!
//! Generates embeddings locally without external API calls using
//! TF-IDF, n-grams, and positional encoding.

use crate::core::error::Result;
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::EmbeddingProvider;
use crate::utils::text::{add_synonym_tokens, canonical_tokens_from_text};
use async_trait::async_trait;
use std::collections::HashMap;

/// Embedding provider that computes vectors entirely locally.
///
/// No network calls are made; embeddings are derived from token,
/// n-gram, and positional features hashed into a fixed-size vector.
pub struct SyntheticProvider {
    dim: usize,
}

impl SyntheticProvider {
    /// Build a provider that emits vectors with `dim` components.
    pub fn new(dim: usize) -> Self {
        SyntheticProvider { dim }
    }

    /// Compute a synthetic embedding for `text` within the given `sector`.
    pub fn generate(&self, text: &str, sector: &Sector) -> Vec<f32> {
        gen_synthetic_embedding(text, sector, self.dim)
    }
}

#[async_trait]
impl EmbeddingProvider for SyntheticProvider {
    /// Embed a single text. Infallible in practice: everything is computed
    /// locally, so the `Result` exists only to satisfy the trait contract.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        let vector = self.generate(text, sector);
        // Read the length before moving the vector in — avoids cloning it.
        let dim = vector.len();
        Ok(EmbeddingResult {
            sector: *sector,
            vector,
            dim,
        })
    }

    fn dimensions(&self) -> usize {
        self.dim
    }

    fn name(&self) -> &'static str {
        "synthetic"
    }

    fn supports_batch(&self) -> bool {
        true // We can process batches efficiently locally
    }

    /// Embed many texts in one call; items are independent of one another.
    async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
        let results = texts
            .iter()
            .map(|(text, sector)| {
                let vector = self.generate(text, sector);
                // Capture the length first so the vector can be moved, not cloned.
                let dim = vector.len();
                EmbeddingResult {
                    sector: **sector,
                    vector,
                    dim,
                }
            })
            .collect();
        Ok(results)
    }
}

/// Per-sector scaling factor applied to every feature weight.
///
/// Emotional and episodic content is boosted; reflective content is
/// slightly attenuated.
fn sector_weight(sector: &Sector) -> f32 {
    match *sector {
        Sector::Emotional => 1.4,
        Sector::Episodic => 1.3,
        Sector::Procedural => 1.2,
        Sector::Semantic => 1.0,
        Sector::Reflective => 0.9,
    }
}

/// 32-bit FNV-1a hash over the string's bytes, used for feature hashing.
fn hash_fnv1a(s: &str) -> u32 {
    // Standard FNV-1a: XOR each byte in, then multiply by the FNV prime.
    s.bytes()
        .fold(0x811c9dc5u32, |acc, b| (acc ^ u32::from(b)).wrapping_mul(16777619))
}

/// Seeded secondary hash (multiply/xor-shift mix) used for double hashing,
/// so each feature key touches a second, independent slot.
fn hash_secondary(s: &str, seed: u32) -> u32 {
    let mut state = seed;
    for b in s.bytes() {
        state ^= u32::from(b);
        state = state.wrapping_mul(0x5bd1e995);
        state ^= state >> 13;
    }
    state
}

/// Scatter a named feature into `vec` via double hashing.
///
/// The primary hash picks a slot that receives the full signed weight; the
/// secondary hash picks a slot that receives half of it, which dampens the
/// effect of collisions. The sign bit (from the primary hash's low bit)
/// keeps expected dot products between unrelated keys near zero.
fn add_feature(vec: &mut [f32], key: &str, weight: f32) {
    let dim = vec.len();
    if dim == 0 {
        return;
    }

    let h1 = hash_fnv1a(key);
    let h2 = hash_secondary(key, 0xdeadbeef);
    let sign = if h1 & 1 == 0 { 1.0 } else { -1.0 };
    let val = weight * sign;

    // Use a bit-mask instead of modulo when dim is a power of two.
    // (dim > 0 is already guaranteed by the early return above.)
    if dim.is_power_of_two() {
        vec[(h1 as usize) & (dim - 1)] += val;
        vec[(h2 as usize) & (dim - 1)] += val * 0.5;
    } else {
        vec[(h1 as usize) % dim] += val;
        vec[(h2 as usize) % dim] += val * 0.5;
    }
}

/// Mix a sinusoidal positional signal into `vec`.
///
/// The slot is `pos % dim`; that slot gets `weight * sin(angle)` and the
/// next slot (wrapping) gets `weight * cos(angle)`, where the angle falls
/// off with the slot index in transformer-style `10000^(2i/dim)` fashion.
fn add_positional_feature(vec: &mut [f32], pos: usize, weight: f32) {
    let dim = vec.len();
    if dim == 0 {
        return;
    }

    let slot = pos % dim;
    let denom = 10000.0_f32.powf((2 * slot) as f32 / dim as f32);
    let angle = pos as f32 / denom;

    vec[slot] += weight * angle.sin();
    vec[(slot + 1) % dim] += weight * angle.cos();
}

/// Scale `vec` in place to unit L2 length.
///
/// Near-zero vectors (norm <= 1e-10) are left untouched to avoid
/// dividing by (almost) zero.
fn normalize(vec: &mut [f32]) {
    let norm = vec.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
    if norm <= 1e-10 {
        return;
    }
    let scale = norm.recip();
    vec.iter_mut().for_each(|x| *x *= scale);
}

/// Generate a synthetic embedding
///
/// Uses multiple feature types:
/// - Token TF-IDF features
/// - Character 3-grams and 4-grams
/// - Word bigrams and trigrams
/// - Skip-bigrams
/// - Positional encoding
/// - Length and density features
///
/// The output is L2-normalized. Empty (or untokenizable) input returns a
/// uniform unit vector so downstream similarity math stays well-defined.
///
/// NOTE(review): `term_freq` is a `HashMap`, so iteration order — and thus
/// float accumulation order — varies per process. Outputs are equivalent up
/// to float rounding, but not bit-exact across runs; switch to a sorted map
/// if bit-exact determinism is required.
pub fn gen_synthetic_embedding(text: &str, sector: &Sector, dim: usize) -> Vec<f32> {
    let mut vec = vec![0.0f32; dim];

    // Canonical tokens drive every word-level feature below.
    let canonical_tokens = canonical_tokens_from_text(text);

    // Handle empty text: uniform unit vector rather than all zeros.
    if canonical_tokens.is_empty() {
        let default_val = 1.0 / (dim as f32).sqrt();
        return vec![default_val; dim];
    }

    // Expand with synonyms so near-paraphrases share features.
    let expanded_tokens: Vec<String> = add_synonym_tokens(
        canonical_tokens.iter().map(|s| s.as_str()),
    )
    .into_iter()
    .collect();

    // Build term frequency map over the expanded token stream.
    let mut term_freq: HashMap<&str, usize> = HashMap::new();
    for tok in &expanded_tokens {
        *term_freq.entry(tok.as_str()).or_insert(0) += 1;
    }

    let sector_w = sector_weight(sector);
    let doc_length = expanded_tokens.len() as f32;
    let doc_log = (1.0 + doc_length).ln();
    let sector_str = sector.as_str();

    // Token features with TF-IDF weighting
    for (tok, &count) in &term_freq {
        let tf = count as f32 / doc_length;
        let idf = (1.0 + doc_length / count as f32).ln();
        let weight = (tf * idf + 1.0) * sector_w;

        // Token feature
        add_feature(&mut vec, &format!("{}|tok|{}", sector_str, tok), weight);

        // Character n-grams: collect the chars once and slide fixed-size
        // windows over them (`windows` yields nothing for short tokens).
        let chars: Vec<char> = tok.chars().collect();
        for win in chars.windows(3) {
            let trigram: String = win.iter().collect();
            add_feature(
                &mut vec,
                &format!("{}|c3|{}", sector_str, trigram),
                weight * 0.4,
            );
        }
        for win in chars.windows(4) {
            let fourgram: String = win.iter().collect();
            add_feature(
                &mut vec,
                &format!("{}|c4|{}", sector_str, fourgram),
                weight * 0.3,
            );
        }
    }

    // Word bigrams, down-weighted the further into the document they start.
    for i in 0..canonical_tokens.len().saturating_sub(1) {
        let a = &canonical_tokens[i];
        let b = &canonical_tokens[i + 1];
        let position_weight = 1.0 / (1.0 + i as f32 * 0.1);
        add_feature(
            &mut vec,
            &format!("{}|bi|{}_{}", sector_str, a, b),
            1.4 * sector_w * position_weight,
        );
    }

    // Word trigrams
    for i in 0..canonical_tokens.len().saturating_sub(2) {
        let a = &canonical_tokens[i];
        let b = &canonical_tokens[i + 1];
        let c = &canonical_tokens[i + 2];
        add_feature(
            &mut vec,
            &format!("{}|tri|{}_{}_{}", sector_str, a, b, c),
            sector_w,
        );
    }

    // Skip-bigrams (tokens with one word gap), capped at the first 20 starts.
    for i in 0..canonical_tokens.len().saturating_sub(2).min(20) {
        let a = &canonical_tokens[i];
        let c = &canonical_tokens[i + 2];
        add_feature(
            &mut vec,
            &format!("{}|skip|{}_{}", sector_str, a, c),
            0.7 * sector_w,
        );
    }

    // Positional encoding for first 50 tokens, damped by document length.
    for i in 0..canonical_tokens.len().min(50) {
        add_positional_feature(&mut vec, i, 0.5 * sector_w / doc_log);
    }

    // Length bucket feature (log2 buckets, capped at 10).
    let length_bucket = ((doc_length + 1.0).log2() as usize).min(10);
    add_feature(
        &mut vec,
        &format!("{}|len|{}", sector_str, length_bucket),
        0.6 * sector_w,
    );

    // Density feature (unique tokens / total tokens)
    let density = term_freq.len() as f32 / doc_length;
    let density_bucket = (density * 10.0) as usize;
    add_feature(
        &mut vec,
        &format!("{}|dens|{}", sector_str, density_bucket),
        0.5 * sector_w,
    );

    // Normalize the vector to unit L2 length.
    normalize(&mut vec);

    vec
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synthetic_provider() {
        let provider = SyntheticProvider::new(256);
        assert_eq!(provider.dimensions(), 256);
        assert_eq!(provider.name(), "synthetic");
    }

    #[test]
    fn test_gen_embedding() {
        let emb = gen_synthetic_embedding("Hello world", &Sector::Semantic, 256);
        assert_eq!(emb.len(), 256);

        // The generator promises an L2-normalized output.
        let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_text() {
        let emb = gen_synthetic_embedding("", &Sector::Semantic, 256);
        assert_eq!(emb.len(), 256);

        // Empty input falls back to a uniform vector: all slots equal.
        let head = emb[0];
        assert!(emb.iter().all(|&val| (val - head).abs() < 1e-6));
    }

    #[test]
    fn test_different_sectors() {
        let text = "This is a test sentence.";
        let episodic = gen_synthetic_embedding(text, &Sector::Episodic, 256);
        let semantic = gen_synthetic_embedding(text, &Sector::Semantic, 256);

        // Sector-prefixed feature keys should hash to different slots,
        // so the two embeddings must not be (near-)identical.
        let dot: f32 = episodic.iter().zip(&semantic).map(|(a, b)| a * b).sum();
        assert!(dot < 0.99);
    }

    #[test]
    fn test_similar_texts() {
        let e1 = gen_synthetic_embedding("I love programming", &Sector::Semantic, 256);
        let e2 = gen_synthetic_embedding("I enjoy coding", &Sector::Semantic, 256);
        let e3 = gen_synthetic_embedding("The weather is nice", &Sector::Semantic, 256);

        let sim_12: f32 = e1.iter().zip(&e2).map(|(a, b)| a * b).sum();
        let sim_13: f32 = e1.iter().zip(&e3).map(|(a, b)| a * b).sum();

        // "love programming" should relate more to "enjoy coding" than to
        // "weather is nice" — but that depends on the synonym table, so we
        // only require that some signal exists at all.
        assert!(sim_12.abs() > 0.0 || sim_13.abs() > 0.0);
    }

    #[test]
    fn test_hash_functions() {
        // Same input yields the same digest; different input a different one.
        assert_eq!(hash_fnv1a("test"), hash_fnv1a("test"));
        assert_ne!(hash_fnv1a("test"), hash_fnv1a("different"));
    }
}