vyctor 0.1.0

A fast CLI tool for semantic file search using vector embeddings
Documentation
//! Local embedding provider using Candle
//! Only compiled when the "local-embeddings" feature is enabled

use super::provider::{EmbeddingProvider, EmbeddingResult};
use anyhow::{Context, Result};
use async_trait::async_trait;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use std::io::Write;
use std::path::PathBuf;
use std::str::FromStr;
use tokenizers::Tokenizer;

/// Base URL of the HuggingFace Hub. Files are fetched from
/// `{HF_BASE_URL}/{model_id}/resolve/main/{filename}`.
const HF_BASE_URL: &str = "https://huggingface.co";

// Config and tokenizer for all-MiniLM-L6-v2, embedded at compile time so the default model
// needs no network request for these small files.
const MINILM_CONFIG: &str = include_str!("../../models/all-MiniLM-L6-v2/config.json");
const MINILM_TOKENIZER: &str = include_str!("../../models/all-MiniLM-L6-v2/tokenizer.json");
/// HuggingFace Hub id of the default (bundled) model.
const MINILM_MODEL_ID: &str = "sentence-transformers/all-MiniLM-L6-v2";

// Model weights, embedded only when the `bundled-weights` feature is enabled.
// This adds the whole safetensors file to the binary.
#[cfg(feature = "bundled-weights")]
const MINILM_WEIGHTS: &[u8] = include_bytes!("../../models/all-MiniLM-L6-v2/model.safetensors");

/// Local embedding provider using sentence-transformers models
///
/// Runs a BERT-style model in-process with Candle. Texts are tokenized and passed through the
/// model, then mean-pooled over non-padding tokens and L2-normalized.
pub struct LocalEmbedder {
    /// The loaded BERT encoder.
    model: BertModel,
    /// Tokenizer built from the model's `tokenizer.json`.
    tokenizer: Tokenizer,
    /// Device that input tensors are created on (`new` always uses the CPU).
    device: Device,
    /// Embedding size, taken from the model config's `hidden_size`.
    dimensions: usize,
    /// The model id exactly as passed to `new` (may be the short alias).
    model_name: String,
}

impl std::fmt::Debug for LocalEmbedder {
    /// Formats the embedder without dumping model weights or the tokenizer vocabulary.
    ///
    /// Uses `finish_non_exhaustive` so the output (`LocalEmbedder { .., .. }`) makes it explicit
    /// that some fields (the tokenizer) are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalEmbedder")
            .field("model", &"<BertModel>")
            .field("device", &self.device)
            .field("dimensions", &self.dimensions)
            .field("model_name", &self.model_name)
            .finish_non_exhaustive()
    }
}

/// Download a file from HuggingFace if not already cached
fn download_hf_file(
    model_id: &str,
    filename: &str,
    cache_dir: &std::path::Path,
) -> Result<PathBuf> {
    let safe_model_id = model_id.replace('/', "--");
    let model_cache = cache_dir.join(&safe_model_id);
    std::fs::create_dir_all(&model_cache)?;

    let file_path = model_cache.join(filename);

    if file_path.exists() {
        println!("  Using cached: {}", filename);
        return Ok(file_path);
    }

    let url = format!("{}/{}/resolve/main/{}", HF_BASE_URL, model_id, filename);
    println!("  Downloading: {}", filename);

    let response =
        reqwest::blocking::get(&url).with_context(|| format!("Failed to download {}", url))?;

    if !response.status().is_success() {
        anyhow::bail!(
            "Failed to download {}: HTTP {}",
            filename,
            response.status()
        );
    }

    let bytes = response.bytes()?;
    let mut file = std::fs::File::create(&file_path)?;
    file.write_all(&bytes)?;

    Ok(file_path)
}

impl LocalEmbedder {
    /// Create a new local embedder, downloading the model if necessary
    pub fn new(model_id: &str, cache_dir: &str, verbose: bool) -> Result<Self> {
        use std::time::Instant;

        let total_start = Instant::now();
        let device = Device::Cpu;
        let cache_path = PathBuf::from(cache_dir);
        std::fs::create_dir_all(&cache_path)?;

        if verbose {
            println!("Loading model: {}", model_id);
        }

        // Check if this is the bundled model
        let is_bundled = model_id == MINILM_MODEL_ID || model_id == "all-MiniLM-L6-v2";

        // Load config and tokenizer (always bundled for default model)
        let (config_str, tokenizer_str) = if is_bundled {
            if verbose {
                println!("  Using bundled config and tokenizer");
            }
            (MINILM_CONFIG.to_string(), MINILM_TOKENIZER.to_string())
        } else {
            let config_path = download_hf_file(model_id, "config.json", &cache_path)?;
            let tokenizer_path = download_hf_file(model_id, "tokenizer.json", &cache_path)?;
            let config_str = std::fs::read_to_string(&config_path)?;
            let tokenizer_str = std::fs::read_to_string(&tokenizer_path)?;
            (config_str, tokenizer_str)
        };

        // Load config
        let config: Config = serde_json::from_str(&config_str)?;
        let dimensions = config.hidden_size;

        // Load tokenizer
        let tokenizer_start = Instant::now();
        let tokenizer = Tokenizer::from_str(&tokenizer_str)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
        if verbose {
            println!("  Tokenizer loaded in {:?}", tokenizer_start.elapsed());
        }

        // Load model weights - use bundled if available, otherwise download
        let weights_start = Instant::now();
        #[cfg(feature = "bundled-weights")]
        let vb = if is_bundled {
            if verbose {
                println!("  Loading bundled weights into memory...");
            }
            VarBuilder::from_buffered_safetensors(
                MINILM_WEIGHTS.to_vec(),
                candle_core::DType::F32,
                &device,
            )?
        } else {
            let weights_path = download_hf_file(model_id, "model.safetensors", &cache_path)
                .or_else(|_| download_hf_file(model_id, "pytorch_model.bin", &cache_path))
                .context("Failed to download model weights")?;
            if weights_path
                .extension()
                .map(|e| e == "safetensors")
                .unwrap_or(false)
            {
                unsafe {
                    VarBuilder::from_mmaped_safetensors(
                        &[weights_path],
                        candle_core::DType::F32,
                        &device,
                    )?
                }
            } else {
                VarBuilder::from_pth(&weights_path, candle_core::DType::F32, &device)?
            }
        };

        #[cfg(not(feature = "bundled-weights"))]
        let vb = {
            let weights_path = if is_bundled {
                download_hf_file(MINILM_MODEL_ID, "model.safetensors", &cache_path)
                    .context("Failed to download model weights")?
            } else {
                download_hf_file(model_id, "model.safetensors", &cache_path)
                    .or_else(|_| download_hf_file(model_id, "pytorch_model.bin", &cache_path))
                    .context("Failed to download model weights")?
            };
            if verbose {
                println!("  Loading weights into memory...");
            }
            if weights_path
                .extension()
                .map(|e| e == "safetensors")
                .unwrap_or(false)
            {
                unsafe {
                    VarBuilder::from_mmaped_safetensors(
                        &[weights_path],
                        candle_core::DType::F32,
                        &device,
                    )?
                }
            } else {
                VarBuilder::from_pth(&weights_path, candle_core::DType::F32, &device)?
            }
        };
        if verbose {
            println!("  Weights loaded in {:?}", weights_start.elapsed());
        }

        let model_start = Instant::now();
        let model = BertModel::load(vb, &config)?;
        if verbose {
            println!("  Model initialized in {:?}", model_start.elapsed());
            println!("  Total model load time: {:?}", total_start.elapsed());
        }

        Ok(Self {
            model,
            tokenizer,
            device,
            dimensions,
            model_name: model_id.to_string(),
        })
    }

    /// Mean pooling over the sequence dimension
    fn mean_pooling(&self, embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let mask = attention_mask
            .unsqueeze(2)?
            .to_dtype(candle_core::DType::F32)?;
        let masked = embeddings.broadcast_mul(&mask)?;
        let sum = masked.sum(1)?;
        let count = mask.sum(1)?.clamp(1e-9, f64::MAX)?;
        Ok(sum.broadcast_div(&count)?)
    }

    /// Normalize embeddings to unit length
    fn normalize(&self, embeddings: &Tensor) -> Result<Tensor> {
        let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
        Ok(embeddings.broadcast_div(&norm)?)
    }
}

#[async_trait]
impl EmbeddingProvider for LocalEmbedder {
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Embed a single text by running [`Self::embed_batch`] on a one-element batch.
    async fn embed(&self, text: &str) -> Result<EmbeddingResult> {
        let results = self.embed_batch(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .context("Empty result from local model")
    }

    /// Embed a batch of texts in a single forward pass.
    ///
    /// Processing steps:
    /// 1. Tokenize each text with special tokens.
    /// 2. Right-pad every sequence with id 0 to the longest one in the batch.
    /// 3. Run the model, mean-pool over unmasked positions, and L2-normalize.
    ///
    /// Results are returned in input order. `token_count` is the number of unmasked tokens of
    /// each text, which includes any special tokens the tokenizer adds and excludes padding.
    ///
    /// NOTE(review): no explicit truncation happens here. Inputs longer than the model's
    /// position-embedding limit rely on truncation configured in `tokenizer.json`; confirm this
    /// holds for every supported model.
    /// NOTE(review): inference is CPU-bound and runs on the polling thread, so large batches
    /// block the async executor while they run.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingResult>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

        let batch_size = texts.len();
        let max_len = encodings.iter().map(|e| e.len()).max().unwrap_or(0);

        // Build row-major (batch_size, max_len) buffers directly, resizing each row to exactly
        // `max_len`. Padded positions get mask 0, so pooling ignores them. Pad id 0 is
        // presumably [PAD] in BERT vocabularies (TODO: confirm for non-default models).
        let mut input_ids: Vec<i64> = Vec::with_capacity(batch_size * max_len);
        let mut attention_mask: Vec<i64> = Vec::with_capacity(batch_size * max_len);
        for encoding in &encodings {
            let row_end = input_ids.len() + max_len;
            input_ids.extend(encoding.get_ids().iter().map(|&id| i64::from(id)));
            input_ids.resize(row_end, 0);

            let row_end = attention_mask.len() + max_len;
            attention_mask.extend(encoding.get_attention_mask().iter().map(|&m| i64::from(m)));
            attention_mask.resize(row_end, 0);
        }

        let input_ids = Tensor::from_vec(input_ids, (batch_size, max_len), &self.device)?;
        let attention_mask =
            Tensor::from_vec(attention_mask, (batch_size, max_len), &self.device)?;
        // Single-segment input: every token belongs to segment 0.
        let token_type_ids = Tensor::zeros_like(&input_ids)?;

        let hidden_states = self
            .model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
        let pooled = self.mean_pooling(&hidden_states, &attention_mask)?;
        // One row per input text, in input order.
        let rows = self.normalize(&pooled)?.to_vec2::<f32>()?;

        let results = rows
            .into_iter()
            .zip(&encodings)
            .map(|(embedding, encoding)| EmbeddingResult {
                embedding,
                // Count unmasked tokens: `encoding.len()` would also count padding if the
                // tokenizer has padding configured.
                token_count: Some(
                    encoding
                        .get_attention_mask()
                        .iter()
                        .filter(|&&m| m != 0)
                        .count(),
                ),
            })
            .collect();

        Ok(results)
    }
}

#[cfg(test)]
mod tests {
    //! Offline tests. End-to-end inference needs model weights (downloaded or bundled), so
    //! these only cover pieces that work without network access.

    use super::*;

    #[test]
    fn bundled_config_parses() {
        let config: Config =
            serde_json::from_str(MINILM_CONFIG).expect("bundled config.json should parse");
        // all-MiniLM-L6-v2 produces 384-dimensional embeddings.
        assert_eq!(config.hidden_size, 384);
    }

    #[test]
    fn bundled_tokenizer_loads_and_encodes() {
        let tokenizer =
            Tokenizer::from_str(MINILM_TOKENIZER).expect("bundled tokenizer.json should load");
        let encoding = tokenizer
            .encode("hello world", true)
            .expect("encoding should succeed");
        assert!(!encoding.get_ids().is_empty());
        assert_eq!(encoding.get_ids().len(), encoding.get_attention_mask().len());
    }

    #[test]
    fn download_hf_file_returns_cached_file_without_network() {
        let cache_dir =
            std::env::temp_dir().join(format!("vyctor-local-test-{}", std::process::id()));
        let model_dir = cache_dir.join("org--model");
        std::fs::create_dir_all(&model_dir).unwrap();
        let cached = model_dir.join("config.json");
        std::fs::write(&cached, "{}").unwrap();

        let path = download_hf_file("org/model", "config.json", &cache_dir).unwrap();
        assert_eq!(path, cached);

        std::fs::remove_dir_all(&cache_dir).unwrap();
    }
}