cerebro 0.1.3

Blazing-fast, storage-agnostic semantic memory engine for AI Agents — written in pure Rust
use async_trait::async_trait;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use hf_hub::{api::tokio::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use std::sync::Mutex;
use crate::traits::{CerebroError, Embedder, Result};

/// Sentence embedder that runs a BERT model in-process through Candle.
///
/// [`LocalEmbedder::new`] fetches the tokenizer, config and weights from the
/// Hugging Face Hub and loads the model onto [`Device::Cpu`]; every forward
/// pass then runs locally on that device.
pub struct LocalEmbedder {
    /// BERT encoder. The mutex serializes forward passes between concurrent callers.
    model: Mutex<BertModel>,
    /// Tokenizer built from the model repository's `tokenizer.json`.
    tokenizer: Mutex<Tokenizer>,
    /// Device that input tensors are created on; it is the same device the weights
    /// were loaded onto in `new`.
    device: Device,
}

impl LocalEmbedder {
    pub async fn new() -> Result<Self> {
        let api = Api::new().map_err(|e| CerebroError::EmbeddingError(format!("HF API Error: {}", e)))?;
        let repo = api.repo(Repo::with_revision(
            "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            RepoType::Model,
            "refs/pr/21".to_string(), // Stable ONNX/Safetensors weights
        ));

        let tokenizer_filename = repo.get("tokenizer.json").await
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
        let config_filename = repo.get("config.json").await
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
        let weights_filename = repo.get("model.safetensors").await
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;

        let tokenizer = Tokenizer::from_file(&tokenizer_filename)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;

        let config_str = std::fs::read_to_string(config_filename).unwrap();
        let config: Config = serde_json::from_str(&config_str).unwrap();

        let device = Device::Cpu; 
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) }
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;

        let model = BertModel::load(vb, &config)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;

        Ok(Self {
            model: Mutex::new(model),
            tokenizer: Mutex::new(tokenizer),
            device,
        })
    }
}

#[async_trait]
impl Embedder for LocalEmbedder {
    async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let tokenizer = self.tokenizer.lock().unwrap();
        
        let tokens = tokenizer.encode_batch(texts.to_vec(), true)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
            
        let mut token_ids: Vec<Vec<u32>> = tokens.iter().map(|t| t.get_ids().to_vec()).collect();
        
        // Find max sequence length for padding
        let max_len = token_ids.iter().map(|v| v.len()).max().unwrap_or(0);
        
        // Apply 0-padding
        for ids in token_ids.iter_mut() {
            ids.resize(max_len, 0);
        }

        let n_sentences = texts.len();
        let token_ids_flat: Vec<u32> = token_ids.into_iter().flatten().collect();
        let token_ids_tensor = Tensor::from_vec(token_ids_flat, (n_sentences, max_len), &self.device)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;

        let token_type_ids = Tensor::zeros((n_sentences, max_len), candle_core::DType::U32, &self.device)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;

        let model = self.model.lock().unwrap();
        
        // Run forward pass
        let embeddings = model.forward(&token_ids_tensor, &token_type_ids, None)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
            
        // Embeddings block shape: (n_sentences, max_len, 384)
        // We do mean pooling across the sequence length (axis 1)
        
        let pooled_embeddings = embeddings
            .sum(1)
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?
            .broadcast_div(&Tensor::new(max_len as f32, &self.device).unwrap())
            .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
            
        let mut results = Vec::with_capacity(n_sentences);
        for i in 0..n_sentences {
            let row = pooled_embeddings
                .get(i)
                .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?
                .to_vec1::<f32>()
                .map_err(|e| CerebroError::EmbeddingError(e.to_string()))?;
            results.push(row);
        }
        
        Ok(results)
    }
}