Skip to main content

lattice_embed/service/
native.rs

1//! Native embedding service using lattice-inference (pure Rust, no C++ FFI).
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::{info, warn};
10
11/// Loaded model — either BERT-family (encoder) or Qwen (decoder).
12enum LoadedModel {
13    Bert(Arc<BertModel>),
14    Qwen(Arc<QwenModel>),
15}
16
17// SA-161/162: Both BertModel and QwenModel derive Send + Sync automatically:
18// BertModel has no interior mutability; QwenModel wraps mutable state in Mutex
19// which is itself Send + Sync. The manual unsafe impls are therefore redundant
20// and have been removed to prevent a stale "read-only" comment from misleading
21// future readers.
22
23impl LoadedModel {
24    fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25        match self {
26            LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27            // For Qwen, use per-item encode() which checks the cache.
28            LoadedModel::Qwen(m) => {
29                let mut results = Vec::with_capacity(texts.len());
30                for text in texts {
31                    results.push(m.encode(text).map_err(|e| e.to_string())?);
32                }
33                Ok(results)
34            }
35        }
36    }
37
38    fn cache_size(&self) -> usize {
39        match self {
40            LoadedModel::Qwen(m) => m.cache_size(),
41            _ => 0,
42        }
43    }
44}
45
46/// **Unstable**: model-loading API still evolving; signature may change as lattice-inference matures.
47///
48/// Pure Rust embedding service backed by lattice-inference.
49///
50/// Uses SIMD-accelerated matrix multiplication and safetensors weight loading.
51/// No ONNX Runtime, no C++ FFI, no fastembed dependency.
52///
53/// Supports both encoder (BERT/BGE) and decoder (Qwen3) architectures.
54///
55/// # Cancellation Safety
56///
57/// Model loading uses `std::sync::OnceLock` + `spawn_blocking` instead of
58/// `tokio::sync::OnceCell`. This is critical because `tokio::sync::OnceCell::
59/// get_or_try_init` resets when the calling future is dropped (e.g., client
60/// disconnect during MCP timeout). With `OnceLock`, the blocking task runs to
61/// completion and stores the result regardless of async cancellation, so the
62/// model only loads once per process lifetime.
63pub struct NativeEmbeddingService {
64    model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
65    model_config: ModelConfig,
66}
67
68impl Default for NativeEmbeddingService {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77    let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78        Ok(raw) if raw.trim().is_empty() => None,
79        Ok(raw) => {
80            let dim = raw.trim().parse::<usize>().map_err(|e| {
81                EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82            })?;
83            Some(dim)
84        }
85        Err(std::env::VarError::NotPresent) => None,
86        Err(e) => {
87            return Err(EmbedError::InvalidInput(format!(
88                "invalid {LATTICE_EMBED_DIM}: {e}"
89            )));
90        }
91    };
92    ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
97    pub fn new() -> Self {
98        Self {
99            model: Arc::new(OnceLock::new()),
100            model_config: ModelConfig::new(EmbeddingModel::default()),
101        }
102    }
103
104    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
105    pub fn with_model(model_type: EmbeddingModel) -> Self {
106        Self {
107            model: Arc::new(OnceLock::new()),
108            model_config: ModelConfig::new(model_type),
109        }
110    }
111
112    /// **Unstable**: create with explicit model config (model + optional MRL truncation dim).
113    pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114        model_config.validate()?;
115        Ok(Self {
116            model: Arc::new(OnceLock::new()),
117            model_config,
118        })
119    }
120
121    /// **Unstable**: create with model config read from `LATTICE_EMBED_DIM` env var.
122    pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123        let config = model_config_from_env(model_type)?;
124        Ok(Self {
125            model: Arc::new(OnceLock::new()),
126            model_config: config,
127        })
128    }
129
130    /// **Unstable**: persistence API may be moved to a separate manager type.
131    pub fn save_cache(&self) -> Result<usize> {
132        let Some(Ok(model)) = self.model.get() else {
133            return Ok(0);
134        };
135        match model {
136            LoadedModel::Qwen(m) => {
137                let model_name = self.model_config.model.to_string();
138                let path = embedding_cache_path(&model_name, m.dimensions());
139                m.cache_save(&path)
140                    .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141            }
142            _ => Ok(0),
143        }
144    }
145
146    /// **Unstable**: internal diagnostic; may be removed or moved to metrics.
147    pub fn cache_size(&self) -> usize {
148        self.model
149            .get()
150            .and_then(|r| r.as_ref().ok())
151            .map(LoadedModel::cache_size)
152            .unwrap_or(0)
153    }
154
155    /// **Unstable**: download and load the model without producing any embeddings.
156    ///
157    /// Performs the same download + checksum-verify + model-load sequence as the
158    /// first call to `embed`, then returns `Ok(())` without running an encode pass.
159    /// Intended for use by the `--download-only` CLI flag so that the model is
160    /// warmed into the file-system cache without wasting a forward pass.
161    ///
162    /// Errors are the same as those from `embed`: network failures, checksum
163    /// mismatches, and unsupported model variants surface as `EmbedError`.
164    pub async fn ensure_loaded(&self) -> Result<()> {
165        self.ensure_model().await.map(|_| ())
166    }
167
168    /// Ensure the model is loaded (cancellation-safe).
169    ///
170    /// Uses `std::sync::OnceLock` so the model loading runs to completion
171    /// inside `spawn_blocking` even if the calling async future is dropped
172    /// (e.g., client disconnect during MCP timeout). The model loads exactly
173    /// once per process lifetime.
174    async fn ensure_model(&self) -> Result<&LoadedModel> {
175        // Fast path: already loaded.
176        if let Some(result) = self.model.get() {
177            return result
178                .as_ref()
179                .map_err(|e| EmbedError::ModelInitialization(e.clone()));
180        }
181
182        // Slow path: load model on blocking thread.
183        // Clone the Arc so spawn_blocking can store the result directly
184        // in the OnceLock, surviving async cancellation.
185        let model_lock = self.model.clone();
186        let model_config = self.model_config;
187
188        tokio::task::spawn_blocking(move || {
189            // OnceLock::get_or_init blocks until init completes.
190            // If another thread is already loading, this waits for it.
191            // This is fine because we're on the blocking thread pool.
192            model_lock.get_or_init(|| load_model_sync(model_config));
193        })
194        .await
195        .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
196
197        self.model
198            .get()
199            .expect("set by spawn_blocking")
200            .as_ref()
201            .map_err(|e| EmbedError::ModelInitialization(e.clone()))
202    }
203}
204
205/// Synchronous model loading (runs on blocking thread pool).
206fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
207    match model_config.model {
208        EmbeddingModel::BgeSmallEnV15
209        | EmbeddingModel::BgeBaseEnV15
210        | EmbeddingModel::BgeLargeEnV15
211        | EmbeddingModel::MultilingualE5Small
212        | EmbeddingModel::MultilingualE5Base
213        | EmbeddingModel::AllMiniLmL6V2
214        | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
215            let model_name = match model_config.model {
216                EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
217                EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
218                EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
219                EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
220                EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
221                EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
222                EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
223                    "paraphrase-multilingual-minilm-l12-v2"
224                }
225                _ => unreachable!(),
226            };
227            info!(model = model_name, "loading native BERT embedding model");
228            let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
229            // Route each model family through its correct pooling strategy.
230            // BGE uses CLS pooling; E5 and MiniLM use mean pooling.
231            if let Some(pooling) = model_config.model.bert_pooling() {
232                bert.set_pooling(pooling);
233            }
234            Ok(LoadedModel::Bert(Arc::new(bert)))
235        }
236        EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
237            load_qwen_model(model_config)
238        }
239        other => Err(format!("unsupported model: {other:?}")),
240    }
241}
242
243fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
244    model_config.validate().map_err(|e| e.to_string())?;
245    let model_type = model_config.model;
246    let model_name = model_type.to_string();
247    info!(
248        model = %model_name,
249        output_dim = ?model_config.output_dim,
250        "loading Qwen embedding model"
251    );
252    let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
253    let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
254    model.set_output_dim(model_config.output_dim);
255    let cache_path = embedding_cache_path(&model_name, model.dimensions());
256    match model.cache_load(&cache_path) {
257        Ok(n) if n > 0 => {
258            info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
259        }
260        Ok(_) => {}
261        Err(e) => {
262            warn!(
263                path = %cache_path.display(),
264                error = %e,
265                "embedding cache failed integrity check, ignoring (will regenerate on next save)"
266            )
267        }
268    }
269    Ok(LoadedModel::Qwen(Arc::new(model)))
270}
271
/// Path of the persistent embedding cache for `model` at `dim` dimensions:
/// `$HOME/.lattice/cache/embed_{model}_{dim}d.bin`.
///
/// The dimension is part of the file name, so caches for different output
/// sizes of the same model never collide. When `HOME` is unset or empty, the
/// base directory falls back to `/tmp`.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    // `var_os` keeps non-UTF-8 home directories (legal on Unix); `var` would
    // silently discard them in favour of /tmp. An empty HOME would otherwise
    // produce a path relative to the current working directory.
    let home = std::env::var_os("HOME")
        .filter(|h| !h.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("/tmp"));
    home.join(".lattice")
        .join("cache")
        .join(format!("embed_{model}_{dim}d.bin"))
}
280
281/// Locate Qwen3-Embedding model directory for the given model variant.
282fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
283    // Check env override first — applies to whichever Qwen model is loaded.
284    if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
285        return Ok(std::path::PathBuf::from(dir));
286    }
287
288    let slug = match model_type {
289        EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
290        EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
291        other => {
292            return Err(EmbedError::ModelInitialization(format!(
293                "not a Qwen model: {other}"
294            )));
295        }
296    };
297
298    let home = std::env::var("HOME")
299        .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
300    let dir = std::path::PathBuf::from(home)
301        .join(".lattice")
302        .join("models")
303        .join(slug);
304
305    if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
306        Ok(dir)
307    } else {
308        Err(EmbedError::ModelInitialization(format!(
309            "Qwen3 model not found at {}. Download from {}",
310            dir.display(),
311            model_type.model_id()
312        )))
313    }
314}
315
316#[async_trait]
317impl EmbeddingService for NativeEmbeddingService {
318    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
319        if model != self.model_config.model {
320            return Err(EmbedError::InvalidInput(format!(
321                "requested model {:?} but this service is loaded with {:?}",
322                model, self.model_config.model
323            )));
324        }
325        if texts.is_empty() {
326            return Err(EmbedError::InvalidInput("no texts provided".into()));
327        }
328        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
329            return Err(EmbedError::InvalidInput(format!(
330                "batch size {} exceeds maximum {}",
331                texts.len(),
332                DEFAULT_MAX_BATCH_SIZE
333            )));
334        }
335        for text in texts {
336            if text.len() > MAX_TEXT_CHARS {
337                return Err(EmbedError::TextTooLong {
338                    length: text.len(),
339                    max: MAX_TEXT_CHARS,
340                });
341            }
342        }
343
344        let loaded = self.ensure_model().await?;
345        let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
346        loaded
347            .encode_batch(&text_refs)
348            .map_err(EmbedError::InferenceFailed)
349    }
350
351    fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
352        if model == self.model_config.model {
353            self.model_config
354        } else {
355            ModelConfig::new(model)
356        }
357    }
358
359    fn supports_model(&self, model: EmbeddingModel) -> bool {
360        model == self.model_config.model
361    }
362
363    fn name(&self) -> &'static str {
364        "native-bert"
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::embedding_cache_path;

    // Only the file name and embedded components are checked, so these tests
    // do not depend on the value of HOME.

    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let cache = embedding_cache_path("qwen3-embedding-4b", 1024);
        let name = cache
            .file_name()
            .and_then(|n| n.to_str())
            .expect("cache path has a UTF-8 file name");
        assert_eq!(name, "embed_qwen3-embedding-4b_1024d.bin");
    }

    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let (small, large) = (
            embedding_cache_path("qwen3-embedding-4b", 1024),
            embedding_cache_path("qwen3-embedding-4b", 2560),
        );
        assert_ne!(small, large);
        for (path, marker) in [(&small, "1024d"), (&large, "2560d")] {
            assert!(path.to_string_lossy().contains(marker));
        }
    }

    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let big = embedding_cache_path("qwen3-embedding-4b", 2560);
        let tiny = embedding_cache_path("qwen3-embedding-0.6b", 1024);
        assert_ne!(big, tiny);
        assert!(big.to_string_lossy().contains("qwen3-embedding-4b"));
        assert!(tiny.to_string_lossy().contains("qwen3-embedding-0.6b"));
    }

    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        assert_eq!(
            embedding_cache_path("qwen3-embedding-4b", 1024),
            embedding_cache_path("qwen3-embedding-4b", 1024)
        );
    }
}