Skip to main content

sc/embeddings/
model2vec.rs

1//! Model2Vec embedding provider.
2//!
3//! Uses local Model2Vec static embeddings for instant embedding generation.
4//! This is the "fast tier" provider in the 2-tier architecture - generates
5//! embeddings in < 1ms for immediate semantic search.
6//!
7//! Model2Vec uses pre-computed word vectors with averaging, not neural inference,
8//! which is why it's 200-800x faster than transformer-based providers.
9
10use crate::error::{Error, Result};
11use model2vec_rs::model::StaticModel;
12use std::sync::Arc;
13
14use super::provider::EmbeddingProvider;
15use super::types::{model2vec_models, ProviderInfo};
16
/// Model2Vec embedding provider for fast embeddings.
///
/// The model is loaded once, when the provider is constructed (see
/// `with_model`), and every later embedding call reuses that loaded instance.
/// Stated typical latency: < 1ms per embedding.
pub struct Model2VecProvider {
    /// The loaded Model2Vec model, held behind an `Arc` (shared ownership of
    /// the single loaded instance).
    model: Arc<StaticModel>,
    /// Model name (e.g., "minishlab/potion-base-8M"); reported back via `info()`.
    model_name: String,
    /// Output dimensions, taken from the model's config entry
    /// (256 for `minishlab/potion-base-8M`).
    dimensions: usize,
    /// Maximum input characters, taken from the model's config entry.
    /// NOTE(review): this file only reports the limit through `info()`; it does
    /// not truncate input itself - presumably callers enforce it, confirm.
    max_chars: usize,
}
31
32impl Model2VecProvider {
33    /// Create a new Model2Vec provider with the default model (potion-base-8M).
34    ///
35    /// # Errors
36    ///
37    /// Returns an error if the model cannot be loaded from HuggingFace Hub.
38    pub fn new() -> Result<Self> {
39        Self::with_model(None)
40    }
41
42    /// Create a new Model2Vec provider with a custom model.
43    ///
44    /// # Arguments
45    ///
46    /// * `model_name` - Optional model name. Defaults to `minishlab/potion-base-8M`.
47    ///
48    /// # Errors
49    ///
50    /// Returns an error if the model cannot be loaded.
51    pub fn with_model(model_name: Option<String>) -> Result<Self> {
52        let model_name = model_name.unwrap_or_else(|| "minishlab/potion-base-8M".to_string());
53        let config = model2vec_models::get_config(&model_name);
54
55        let model = StaticModel::from_pretrained(
56            &model_name,
57            None, // No HF token needed for public models
58            None, // Use default normalization
59            None, // No subfolder
60        )
61        .map_err(|e| Error::Embedding(format!("Failed to load Model2Vec model '{}': {}", model_name, e)))?;
62
63        Ok(Self {
64            model: Arc::new(model),
65            model_name,
66            dimensions: config.dimensions,
67            max_chars: config.max_chars,
68        })
69    }
70
71    /// Try to create a provider, returning None if model loading fails.
72    ///
73    /// Useful for graceful fallback when Model2Vec isn't available.
74    pub fn try_new() -> Option<Self> {
75        Self::new().ok()
76    }
77}
78
79impl EmbeddingProvider for Model2VecProvider {
80    fn info(&self) -> ProviderInfo {
81        ProviderInfo {
82            name: "model2vec".to_string(),
83            model: self.model_name.clone(),
84            dimensions: self.dimensions,
85            max_chars: self.max_chars,
86            available: true, // If constructed, it's available
87        }
88    }
89
90    async fn is_available(&self) -> bool {
91        // Model2Vec is local - if we have the model loaded, it's available
92        true
93    }
94
95    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
96        // Model2Vec encode expects Vec<String>
97        let sentences = vec![text.to_string()];
98        let embeddings = self.model.encode(&sentences);
99
100        embeddings
101            .into_iter()
102            .next()
103            .ok_or_else(|| Error::Embedding("Model2Vec returned no embeddings".into()))
104    }
105
106    async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
107        // Convert to owned strings for Model2Vec
108        let sentences: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
109        Ok(self.model.encode(&sentences))
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// The default potion model's config entry reports 256 dimensions and a
    /// non-zero character limit.
    #[test]
    fn test_model2vec_config() {
        let potion = model2vec_models::get_config("minishlab/potion-base-8M");
        let (dims, char_limit) = (potion.dimensions, potion.max_chars);

        assert_eq!(dims, 256);
        assert!(char_limit > 0);
    }

    // Disabled: this end-to-end test downloads the model, so it needs
    // network access.
    // #[tokio::test]
    // async fn test_model2vec_embedding() {
    //     let provider = Model2VecProvider::new().expect("Failed to load model");
    //     let embedding = provider.generate_embedding("Hello world").await.unwrap();
    //     assert_eq!(embedding.len(), 256);
    // }
}