Skip to main content

lattice_embed/service/
native.rs

1//! Native embedding service using lattice-inference (pure Rust, no C++ FFI).
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::info;
10
/// Loaded model — either BERT-family (encoder) or Qwen (decoder).
///
/// Each variant keeps its model behind an `Arc`.
enum LoadedModel {
    /// Encoder model; whole batches are forwarded to its `encode_batch`.
    Bert(Arc<BertModel>),
    /// Decoder model; texts are embedded one at a time through its `encode`.
    Qwen(Arc<QwenModel>),
}
16
// SA-161/162: Both BertModel and QwenModel are Send + Sync automatically (these
// are auto traits): BertModel has no interior mutability, and QwenModel wraps
// its mutable state in a Mutex, which is itself Send + Sync. The manual unsafe
// impls were therefore redundant and have been removed, so that a stale
// "read-only" comment cannot mislead future readers.
22
23impl LoadedModel {
24    fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25        match self {
26            LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27            // For Qwen, use per-item encode() which checks the cache.
28            LoadedModel::Qwen(m) => {
29                let mut results = Vec::with_capacity(texts.len());
30                for text in texts {
31                    results.push(m.encode(text).map_err(|e| e.to_string())?);
32                }
33                Ok(results)
34            }
35        }
36    }
37
38    fn cache_size(&self) -> usize {
39        match self {
40            LoadedModel::Qwen(m) => m.cache_size(),
41            _ => 0,
42        }
43    }
44}
45
/// **Unstable**: model-loading API still evolving; signature may change as lattice-inference matures.
///
/// Pure Rust embedding service backed by lattice-inference.
///
/// Uses SIMD-accelerated matrix multiplication and safetensors weight loading.
/// No ONNX Runtime, no C++ FFI, no fastembed dependency.
///
/// Supports both encoder (BERT/BGE) and decoder (Qwen3) architectures.
///
/// # Cancellation Safety
///
/// Model loading uses `std::sync::OnceLock` + `spawn_blocking` instead of
/// `tokio::sync::OnceCell`. This is critical because
/// `tokio::sync::OnceCell::get_or_try_init` resets when the calling future is
/// dropped (e.g., client disconnect during MCP timeout). With `OnceLock`, the
/// blocking task runs to completion and stores the result regardless of async
/// cancellation, so the model is loaded at most once per service instance.
pub struct NativeEmbeddingService {
    /// Write-once slot holding the outcome of the model load. A failed load is
    /// stored as its error message, so the failure is also kept rather than
    /// retried. The `Arc` lets the blocking loader task hold its own handle to
    /// the slot.
    model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
    /// Model variant (plus optional output dimension) this service was built for.
    model_config: ModelConfig,
}
67
68impl Default for NativeEmbeddingService {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
/// Name of the environment variable that `model_config_from_env` reads for the
/// optional output dimension.
const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77    let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78        Ok(raw) if raw.trim().is_empty() => None,
79        Ok(raw) => {
80            let dim = raw.trim().parse::<usize>().map_err(|e| {
81                EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82            })?;
83            Some(dim)
84        }
85        Err(std::env::VarError::NotPresent) => None,
86        Err(e) => {
87            return Err(EmbedError::InvalidInput(format!(
88                "invalid {LATTICE_EMBED_DIM}: {e}"
89            )));
90        }
91    };
92    ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
97    pub fn new() -> Self {
98        Self {
99            model: Arc::new(OnceLock::new()),
100            model_config: ModelConfig::new(EmbeddingModel::default()),
101        }
102    }
103
104    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
105    pub fn with_model(model_type: EmbeddingModel) -> Self {
106        Self {
107            model: Arc::new(OnceLock::new()),
108            model_config: ModelConfig::new(model_type),
109        }
110    }
111
112    /// **Unstable**: create with explicit model config (model + optional MRL truncation dim).
113    pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114        model_config.validate()?;
115        Ok(Self {
116            model: Arc::new(OnceLock::new()),
117            model_config,
118        })
119    }
120
121    /// **Unstable**: create with model config read from `LATTICE_EMBED_DIM` env var.
122    pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123        let config = model_config_from_env(model_type)?;
124        Ok(Self {
125            model: Arc::new(OnceLock::new()),
126            model_config: config,
127        })
128    }
129
130    /// **Unstable**: persistence API may be moved to a separate manager type.
131    pub fn save_cache(&self) -> Result<usize> {
132        let Some(Ok(model)) = self.model.get() else {
133            return Ok(0);
134        };
135        match model {
136            LoadedModel::Qwen(m) => {
137                let model_name = self.model_config.model.to_string();
138                let path = embedding_cache_path(&model_name, m.dimensions());
139                m.cache_save(&path)
140                    .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141            }
142            _ => Ok(0),
143        }
144    }
145
146    /// **Unstable**: internal diagnostic; may be removed or moved to metrics.
147    pub fn cache_size(&self) -> usize {
148        self.model
149            .get()
150            .and_then(|r| r.as_ref().ok())
151            .map(LoadedModel::cache_size)
152            .unwrap_or(0)
153    }
154
155    /// Ensure the model is loaded (cancellation-safe).
156    ///
157    /// Uses `std::sync::OnceLock` so the model loading runs to completion
158    /// inside `spawn_blocking` even if the calling async future is dropped
159    /// (e.g., client disconnect during MCP timeout). The model loads exactly
160    /// once per process lifetime.
161    async fn ensure_model(&self) -> Result<&LoadedModel> {
162        // Fast path: already loaded.
163        if let Some(result) = self.model.get() {
164            return result
165                .as_ref()
166                .map_err(|e| EmbedError::ModelInitialization(e.clone()));
167        }
168
169        // Slow path: load model on blocking thread.
170        // Clone the Arc so spawn_blocking can store the result directly
171        // in the OnceLock, surviving async cancellation.
172        let model_lock = self.model.clone();
173        let model_config = self.model_config;
174
175        tokio::task::spawn_blocking(move || {
176            // OnceLock::get_or_init blocks until init completes.
177            // If another thread is already loading, this waits for it.
178            // This is fine because we're on the blocking thread pool.
179            model_lock.get_or_init(|| load_model_sync(model_config));
180        })
181        .await
182        .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
183
184        self.model
185            .get()
186            .expect("set by spawn_blocking")
187            .as_ref()
188            .map_err(|e| EmbedError::ModelInitialization(e.clone()))
189    }
190}
191
192/// Synchronous model loading (runs on blocking thread pool).
193fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
194    match model_config.model {
195        EmbeddingModel::BgeSmallEnV15
196        | EmbeddingModel::BgeBaseEnV15
197        | EmbeddingModel::BgeLargeEnV15
198        | EmbeddingModel::MultilingualE5Small
199        | EmbeddingModel::MultilingualE5Base
200        | EmbeddingModel::AllMiniLmL6V2
201        | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
202            let model_name = match model_config.model {
203                EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
204                EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
205                EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
206                EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
207                EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
208                EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
209                EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
210                    "paraphrase-multilingual-minilm-l12-v2"
211                }
212                _ => unreachable!(),
213            };
214            info!(model = model_name, "loading native BERT embedding model");
215            let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
216            // Route each model family through its correct pooling strategy.
217            // BGE uses CLS pooling; E5 and MiniLM use mean pooling.
218            if let Some(pooling) = model_config.model.bert_pooling() {
219                bert.set_pooling(pooling);
220            }
221            Ok(LoadedModel::Bert(Arc::new(bert)))
222        }
223        EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
224            load_qwen_model(model_config)
225        }
226        other => Err(format!("unsupported model: {other:?}")),
227    }
228}
229
230fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
231    model_config.validate().map_err(|e| e.to_string())?;
232    let model_type = model_config.model;
233    let model_name = model_type.to_string();
234    info!(
235        model = %model_name,
236        output_dim = ?model_config.output_dim,
237        "loading Qwen embedding model"
238    );
239    let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
240    let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
241    model.set_output_dim(model_config.output_dim);
242    let cache_path = embedding_cache_path(&model_name, model.dimensions());
243    match model.cache_load(&cache_path) {
244        Ok(n) if n > 0 => {
245            info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
246        }
247        _ => {}
248    }
249    Ok(LoadedModel::Qwen(Arc::new(model)))
250}
251
/// Path for persistent embedding cache: ~/.lattice/cache/embed_{model}_{dim}d.bin
///
/// The base directory is `$HOME`; when `HOME` cannot be read, `/tmp` is used
/// instead. Nothing is created on disk — this only computes the path.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    let file_name = format!("embed_{model}_{dim}d.bin");
    let mut path = match std::env::var("HOME") {
        Ok(home) => std::path::PathBuf::from(home),
        Err(_) => std::path::PathBuf::from("/tmp"),
    };
    path.push(".lattice");
    path.push("cache");
    path.push(file_name);
    path
}
260
261/// Locate Qwen3-Embedding model directory for the given model variant.
262fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
263    // Check env override first — applies to whichever Qwen model is loaded.
264    if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
265        return Ok(std::path::PathBuf::from(dir));
266    }
267
268    let slug = match model_type {
269        EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
270        EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
271        other => {
272            return Err(EmbedError::ModelInitialization(format!(
273                "not a Qwen model: {other}"
274            )));
275        }
276    };
277
278    let home = std::env::var("HOME")
279        .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
280    let dir = std::path::PathBuf::from(home)
281        .join(".lattice")
282        .join("models")
283        .join(slug);
284
285    if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
286        Ok(dir)
287    } else {
288        Err(EmbedError::ModelInitialization(format!(
289            "Qwen3 model not found at {}. Download from {}",
290            dir.display(),
291            model_type.model_id()
292        )))
293    }
294}
295
296#[async_trait]
297impl EmbeddingService for NativeEmbeddingService {
298    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
299        if model != self.model_config.model {
300            return Err(EmbedError::InvalidInput(format!(
301                "requested model {:?} but this service is loaded with {:?}",
302                model, self.model_config.model
303            )));
304        }
305        if texts.is_empty() {
306            return Err(EmbedError::InvalidInput("no texts provided".into()));
307        }
308        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
309            return Err(EmbedError::InvalidInput(format!(
310                "batch size {} exceeds maximum {}",
311                texts.len(),
312                DEFAULT_MAX_BATCH_SIZE
313            )));
314        }
315        for text in texts {
316            if text.len() > MAX_TEXT_CHARS {
317                return Err(EmbedError::TextTooLong {
318                    length: text.len(),
319                    max: MAX_TEXT_CHARS,
320                });
321            }
322        }
323
324        let loaded = self.ensure_model().await?;
325        let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
326        loaded
327            .encode_batch(&text_refs)
328            .map_err(EmbedError::InferenceFailed)
329    }
330
331    fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
332        if model == self.model_config.model {
333            self.model_config
334        } else {
335            ModelConfig::new(model)
336        }
337    }
338
339    fn supports_model(&self, model: EmbeddingModel) -> bool {
340        model == self.model_config.model
341    }
342
343    fn name(&self) -> &'static str {
344        "native-bert"
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// The file name encodes both the model slug and the dimension.
    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let cache_file = embedding_cache_path("qwen3-embedding-4b", 1024);
        let name = cache_file.file_name().unwrap().to_str().unwrap();
        assert_eq!(name, "embed_qwen3-embedding-4b_1024d.bin");
    }

    /// Changing only the dimension must change the path.
    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let small = embedding_cache_path("qwen3-embedding-4b", 1024);
        let large = embedding_cache_path("qwen3-embedding-4b", 2560);
        assert_ne!(small, large);
        for (path, marker) in [(&small, "1024d"), (&large, "2560d")] {
            assert!(path.to_string_lossy().contains(marker));
        }
    }

    /// Different model variants never share a cache file.
    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let big = embedding_cache_path("qwen3-embedding-4b", 2560);
        let little = embedding_cache_path("qwen3-embedding-0.6b", 1024);
        assert_ne!(big, little);
        assert!(big.to_string_lossy().contains("qwen3-embedding-4b"));
        assert!(little.to_string_lossy().contains("qwen3-embedding-0.6b"));
    }

    /// The path is a pure function of its arguments (within one environment).
    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        let first = embedding_cache_path("qwen3-embedding-4b", 1024);
        let second = embedding_cache_path("qwen3-embedding-4b", 1024);
        assert_eq!(first, second);
    }
}
383}