Skip to main content

lattice_embed/service/
native.rs

1//! Native embedding service using lattice-inference (pure Rust, no C++ FFI).
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::info;
10
11/// Loaded model — either BERT-family (encoder) or Qwen (decoder).
12enum LoadedModel {
13    Bert(Arc<BertModel>),
14    Qwen(Arc<QwenModel>),
15}
16
// SA-161/162: BertModel and QwenModel are both Send + Sync through the
// compiler's auto-trait rules (nothing is derived): BertModel has no interior
// mutability, and QwenModel keeps its mutable state behind a Mutex, which is
// Send + Sync whenever its contents are Send. The former manual
// `unsafe impl Send/Sync` blocks were therefore redundant and have been removed,
// together with their stale "read-only" justification.
22
23impl LoadedModel {
24    fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25        match self {
26            LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27            // For Qwen, use per-item encode() which checks the cache.
28            LoadedModel::Qwen(m) => {
29                let mut results = Vec::with_capacity(texts.len());
30                for text in texts {
31                    results.push(m.encode(text).map_err(|e| e.to_string())?);
32                }
33                Ok(results)
34            }
35        }
36    }
37
38    fn cache_size(&self) -> usize {
39        match self {
40            LoadedModel::Qwen(m) => m.cache_size(),
41            _ => 0,
42        }
43    }
44}
45
46/// **Unstable**: model-loading API still evolving; signature may change as lattice-inference matures.
47///
48/// Pure Rust embedding service backed by lattice-inference.
49///
50/// Uses SIMD-accelerated matrix multiplication and safetensors weight loading.
51/// No ONNX Runtime, no C++ FFI, no fastembed dependency.
52///
53/// Supports both encoder (BERT/BGE) and decoder (Qwen3) architectures.
54///
55/// # Cancellation Safety
56///
57/// Model loading uses `std::sync::OnceLock` + `spawn_blocking` instead of
58/// `tokio::sync::OnceCell`. This is critical because `tokio::sync::OnceCell::
59/// get_or_try_init` resets when the calling future is dropped (e.g., client
60/// disconnect during MCP timeout). With `OnceLock`, the blocking task runs to
61/// completion and stores the result regardless of async cancellation, so the
62/// model only loads once per process lifetime.
63pub struct NativeEmbeddingService {
64    model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
65    model_config: ModelConfig,
66}
67
68impl Default for NativeEmbeddingService {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77    let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78        Ok(raw) if raw.trim().is_empty() => None,
79        Ok(raw) => {
80            let dim = raw.trim().parse::<usize>().map_err(|e| {
81                EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82            })?;
83            Some(dim)
84        }
85        Err(std::env::VarError::NotPresent) => None,
86        Err(e) => {
87            return Err(EmbedError::InvalidInput(format!(
88                "invalid {LATTICE_EMBED_DIM}: {e}"
89            )));
90        }
91    };
92    ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
97    pub fn new() -> Self {
98        Self {
99            model: Arc::new(OnceLock::new()),
100            model_config: ModelConfig::new(EmbeddingModel::default()),
101        }
102    }
103
104    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
105    pub fn with_model(model_type: EmbeddingModel) -> Self {
106        Self {
107            model: Arc::new(OnceLock::new()),
108            model_config: ModelConfig::new(model_type),
109        }
110    }
111
112    /// **Unstable**: create with explicit model config (model + optional MRL truncation dim).
113    pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114        model_config.validate()?;
115        Ok(Self {
116            model: Arc::new(OnceLock::new()),
117            model_config,
118        })
119    }
120
121    /// **Unstable**: create with model config read from `LATTICE_EMBED_DIM` env var.
122    pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123        let config = model_config_from_env(model_type)?;
124        Ok(Self {
125            model: Arc::new(OnceLock::new()),
126            model_config: config,
127        })
128    }
129
130    /// **Unstable**: persistence API may be moved to a separate manager type.
131    pub fn save_cache(&self) -> Result<usize> {
132        let Some(Ok(model)) = self.model.get() else {
133            return Ok(0);
134        };
135        match model {
136            LoadedModel::Qwen(m) => {
137                let model_name = self.model_config.model.to_string();
138                let path = embedding_cache_path(&model_name, m.dimensions());
139                m.cache_save(&path)
140                    .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141            }
142            _ => Ok(0),
143        }
144    }
145
146    /// **Unstable**: internal diagnostic; may be removed or moved to metrics.
147    pub fn cache_size(&self) -> usize {
148        self.model
149            .get()
150            .and_then(|r| r.as_ref().ok())
151            .map(LoadedModel::cache_size)
152            .unwrap_or(0)
153    }
154
155    /// **Unstable**: download and load the model without producing any embeddings.
156    ///
157    /// Performs the same download + checksum-verify + model-load sequence as the
158    /// first call to `embed`, then returns `Ok(())` without running an encode pass.
159    /// Intended for use by the `--download-only` CLI flag so that the model is
160    /// warmed into the file-system cache without wasting a forward pass.
161    ///
162    /// Errors are the same as those from `embed`: network failures, checksum
163    /// mismatches, and unsupported model variants surface as `EmbedError`.
164    pub async fn ensure_loaded(&self) -> Result<()> {
165        self.ensure_model().await.map(|_| ())
166    }
167
168    /// Ensure the model is loaded (cancellation-safe).
169    ///
170    /// Uses `std::sync::OnceLock` so the model loading runs to completion
171    /// inside `spawn_blocking` even if the calling async future is dropped
172    /// (e.g., client disconnect during MCP timeout). The model loads exactly
173    /// once per process lifetime.
174    async fn ensure_model(&self) -> Result<&LoadedModel> {
175        // Fast path: already loaded.
176        if let Some(result) = self.model.get() {
177            return result
178                .as_ref()
179                .map_err(|e| EmbedError::ModelInitialization(e.clone()));
180        }
181
182        // Slow path: load model on blocking thread.
183        // Clone the Arc so spawn_blocking can store the result directly
184        // in the OnceLock, surviving async cancellation.
185        let model_lock = self.model.clone();
186        let model_config = self.model_config;
187
188        tokio::task::spawn_blocking(move || {
189            // OnceLock::get_or_init blocks until init completes.
190            // If another thread is already loading, this waits for it.
191            // This is fine because we're on the blocking thread pool.
192            model_lock.get_or_init(|| load_model_sync(model_config));
193        })
194        .await
195        .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
196
197        self.model
198            .get()
199            .expect("set by spawn_blocking")
200            .as_ref()
201            .map_err(|e| EmbedError::ModelInitialization(e.clone()))
202    }
203}
204
205/// Synchronous model loading (runs on blocking thread pool).
206fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
207    match model_config.model {
208        EmbeddingModel::BgeSmallEnV15
209        | EmbeddingModel::BgeBaseEnV15
210        | EmbeddingModel::BgeLargeEnV15
211        | EmbeddingModel::MultilingualE5Small
212        | EmbeddingModel::MultilingualE5Base
213        | EmbeddingModel::AllMiniLmL6V2
214        | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
215            let model_name = match model_config.model {
216                EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
217                EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
218                EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
219                EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
220                EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
221                EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
222                EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
223                    "paraphrase-multilingual-minilm-l12-v2"
224                }
225                _ => unreachable!(),
226            };
227            info!(model = model_name, "loading native BERT embedding model");
228            let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
229            // Route each model family through its correct pooling strategy.
230            // BGE uses CLS pooling; E5 and MiniLM use mean pooling.
231            if let Some(pooling) = model_config.model.bert_pooling() {
232                bert.set_pooling(pooling);
233            }
234            Ok(LoadedModel::Bert(Arc::new(bert)))
235        }
236        EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
237            load_qwen_model(model_config)
238        }
239        other => Err(format!("unsupported model: {other:?}")),
240    }
241}
242
243fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
244    model_config.validate().map_err(|e| e.to_string())?;
245    let model_type = model_config.model;
246    let model_name = model_type.to_string();
247    info!(
248        model = %model_name,
249        output_dim = ?model_config.output_dim,
250        "loading Qwen embedding model"
251    );
252    let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
253    let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
254    model.set_output_dim(model_config.output_dim);
255    let cache_path = embedding_cache_path(&model_name, model.dimensions());
256    match model.cache_load(&cache_path) {
257        Ok(n) if n > 0 => {
258            info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
259        }
260        _ => {}
261    }
262    Ok(LoadedModel::Qwen(Arc::new(model)))
263}
264
/// Path for persistent embedding cache: `$HOME/.lattice/cache/embed_{model}_{dim}d.bin`.
///
/// The file name encodes both the model name and the output dimension, so caches
/// produced with different truncation dims never collide.
///
/// Falls back to `/tmp` when `HOME` is unset, not valid Unicode, or empty. An
/// empty `HOME` is treated as unset because `PathBuf::from("")` would produce a
/// *relative* path, silently writing the cache into the current working directory.
// NOTE(review): `/tmp` is typically shared between users, so a cache file there
// could be pre-seeded by someone else; consider disabling persistence when no
// home directory is available.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    let home = std::env::var("HOME")
        .ok()
        .filter(|home| !home.is_empty())
        .unwrap_or_else(|| "/tmp".into());
    std::path::PathBuf::from(home)
        .join(".lattice")
        .join("cache")
        .join(format!("embed_{model}_{dim}d.bin"))
}
273
274/// Locate Qwen3-Embedding model directory for the given model variant.
275fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
276    // Check env override first — applies to whichever Qwen model is loaded.
277    if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
278        return Ok(std::path::PathBuf::from(dir));
279    }
280
281    let slug = match model_type {
282        EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
283        EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
284        other => {
285            return Err(EmbedError::ModelInitialization(format!(
286                "not a Qwen model: {other}"
287            )));
288        }
289    };
290
291    let home = std::env::var("HOME")
292        .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
293    let dir = std::path::PathBuf::from(home)
294        .join(".lattice")
295        .join("models")
296        .join(slug);
297
298    if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
299        Ok(dir)
300    } else {
301        Err(EmbedError::ModelInitialization(format!(
302            "Qwen3 model not found at {}. Download from {}",
303            dir.display(),
304            model_type.model_id()
305        )))
306    }
307}
308
309#[async_trait]
310impl EmbeddingService for NativeEmbeddingService {
311    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
312        if model != self.model_config.model {
313            return Err(EmbedError::InvalidInput(format!(
314                "requested model {:?} but this service is loaded with {:?}",
315                model, self.model_config.model
316            )));
317        }
318        if texts.is_empty() {
319            return Err(EmbedError::InvalidInput("no texts provided".into()));
320        }
321        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
322            return Err(EmbedError::InvalidInput(format!(
323                "batch size {} exceeds maximum {}",
324                texts.len(),
325                DEFAULT_MAX_BATCH_SIZE
326            )));
327        }
328        for text in texts {
329            if text.len() > MAX_TEXT_CHARS {
330                return Err(EmbedError::TextTooLong {
331                    length: text.len(),
332                    max: MAX_TEXT_CHARS,
333                });
334            }
335        }
336
337        let loaded = self.ensure_model().await?;
338        let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
339        loaded
340            .encode_batch(&text_refs)
341            .map_err(EmbedError::InferenceFailed)
342    }
343
344    fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
345        if model == self.model_config.model {
346            self.model_config
347        } else {
348            ModelConfig::new(model)
349        }
350    }
351
352    fn supports_model(&self, model: EmbeddingModel) -> bool {
353        model == self.model_config.model
354    }
355
356    fn name(&self) -> &'static str {
357        "native-bert"
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Lossy string form of the cache path for `(model, dim)`, for substring checks.
    fn cache_path_text(model: &str, dim: usize) -> String {
        embedding_cache_path(model, dim)
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let path = embedding_cache_path("qwen3-embedding-4b", 1024);
        let filename = path
            .file_name()
            .and_then(std::ffi::OsStr::to_str)
            .expect("cache path has a UTF-8 file name");
        assert_eq!(filename, "embed_qwen3-embedding-4b_1024d.bin");
    }

    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let (small, large) = (
            embedding_cache_path("qwen3-embedding-4b", 1024),
            embedding_cache_path("qwen3-embedding-4b", 2560),
        );
        assert_ne!(small, large);
        assert!(small.to_string_lossy().contains("1024d"));
        assert!(large.to_string_lossy().contains("2560d"));
    }

    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let big = cache_path_text("qwen3-embedding-4b", 2560);
        let tiny = cache_path_text("qwen3-embedding-0.6b", 1024);
        assert_ne!(big, tiny);
        assert!(big.contains("qwen3-embedding-4b"));
        assert!(tiny.contains("qwen3-embedding-0.6b"));
    }

    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        let first = embedding_cache_path("qwen3-embedding-4b", 1024);
        let second = embedding_cache_path("qwen3-embedding-4b", 1024);
        assert_eq!(first, second);
    }
}