Skip to main content

lattice_embed/service/
native.rs

1//! Native embedding service using lattice-inference (pure Rust, no C++ FFI).
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::info;
10
11/// Loaded model — either BERT-family (encoder) or Qwen (decoder).
12enum LoadedModel {
13    Bert(Arc<BertModel>),
14    Qwen(Arc<QwenModel>),
15}
16
// SA-161/162: no manual `unsafe impl Send/Sync` is needed for `LoadedModel`.
// `BertModel` and `QwenModel` are `Send + Sync` through auto-traits: `BertModel`
// has no interior mutability, and `QwenModel` guards its mutable state with a
// `Mutex` (which is `Sync` whenever its contents are `Send`). The former manual
// unsafe impls were redundant and have been removed.
22
23impl LoadedModel {
24    fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25        match self {
26            LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27            // For Qwen, use per-item encode() which checks the cache.
28            LoadedModel::Qwen(m) => {
29                let mut results = Vec::with_capacity(texts.len());
30                for text in texts {
31                    results.push(m.encode(text).map_err(|e| e.to_string())?);
32                }
33                Ok(results)
34            }
35        }
36    }
37
38    fn cache_size(&self) -> usize {
39        match self {
40            LoadedModel::Qwen(m) => m.cache_size(),
41            _ => 0,
42        }
43    }
44}
45
46/// **Unstable**: model-loading API still evolving; signature may change as lattice-inference matures.
47///
48/// Pure Rust embedding service backed by lattice-inference.
49///
50/// Uses SIMD-accelerated matrix multiplication and safetensors weight loading.
51/// No ONNX Runtime, no C++ FFI, no fastembed dependency.
52///
53/// Supports both encoder (BERT/BGE) and decoder (Qwen3) architectures.
54///
55/// # Cancellation Safety
56///
57/// Model loading uses `std::sync::OnceLock` + `spawn_blocking` instead of
58/// `tokio::sync::OnceCell`. This is critical because `tokio::sync::OnceCell::
59/// get_or_try_init` resets when the calling future is dropped (e.g., client
60/// disconnect during MCP timeout). With `OnceLock`, the blocking task runs to
61/// completion and stores the result regardless of async cancellation, so the
62/// model only loads once per process lifetime.
63pub struct NativeEmbeddingService {
64    model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
65    model_config: ModelConfig,
66}
67
68impl Default for NativeEmbeddingService {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77    let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78        Ok(raw) if raw.trim().is_empty() => None,
79        Ok(raw) => {
80            let dim = raw.trim().parse::<usize>().map_err(|e| {
81                EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82            })?;
83            Some(dim)
84        }
85        Err(std::env::VarError::NotPresent) => None,
86        Err(e) => {
87            return Err(EmbedError::InvalidInput(format!(
88                "invalid {LATTICE_EMBED_DIM}: {e}"
89            )));
90        }
91    };
92    ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
97    pub fn new() -> Self {
98        Self {
99            model: Arc::new(OnceLock::new()),
100            model_config: ModelConfig::new(EmbeddingModel::default()),
101        }
102    }
103
104    /// **Unstable**: constructor signature may change; use `EmbeddingService` trait for stable API.
105    pub fn with_model(model_type: EmbeddingModel) -> Self {
106        Self {
107            model: Arc::new(OnceLock::new()),
108            model_config: ModelConfig::new(model_type),
109        }
110    }
111
112    /// **Unstable**: create with explicit model config (model + optional MRL truncation dim).
113    pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114        model_config.validate()?;
115        Ok(Self {
116            model: Arc::new(OnceLock::new()),
117            model_config,
118        })
119    }
120
121    /// **Unstable**: create with model config read from `LATTICE_EMBED_DIM` env var.
122    pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123        let config = model_config_from_env(model_type)?;
124        Ok(Self {
125            model: Arc::new(OnceLock::new()),
126            model_config: config,
127        })
128    }
129
130    /// **Unstable**: persistence API may be moved to a separate manager type.
131    pub fn save_cache(&self) -> Result<usize> {
132        let Some(Ok(model)) = self.model.get() else {
133            return Ok(0);
134        };
135        match model {
136            LoadedModel::Qwen(m) => {
137                let model_name = self.model_config.model.to_string();
138                let path = embedding_cache_path(&model_name, m.dimensions());
139                m.cache_save(&path)
140                    .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141            }
142            _ => Ok(0),
143        }
144    }
145
146    /// **Unstable**: internal diagnostic; may be removed or moved to metrics.
147    pub fn cache_size(&self) -> usize {
148        self.model
149            .get()
150            .and_then(|r| r.as_ref().ok())
151            .map(LoadedModel::cache_size)
152            .unwrap_or(0)
153    }
154
155    /// Ensure the model is loaded (cancellation-safe).
156    ///
157    /// Uses `std::sync::OnceLock` so the model loading runs to completion
158    /// inside `spawn_blocking` even if the calling async future is dropped
159    /// (e.g., client disconnect during MCP timeout). The model loads exactly
160    /// once per process lifetime.
161    async fn ensure_model(&self) -> Result<&LoadedModel> {
162        // Fast path: already loaded.
163        if let Some(result) = self.model.get() {
164            return result
165                .as_ref()
166                .map_err(|e| EmbedError::ModelInitialization(e.clone()));
167        }
168
169        // Slow path: load model on blocking thread.
170        // Clone the Arc so spawn_blocking can store the result directly
171        // in the OnceLock, surviving async cancellation.
172        let model_lock = self.model.clone();
173        let model_config = self.model_config;
174
175        tokio::task::spawn_blocking(move || {
176            // OnceLock::get_or_init blocks until init completes.
177            // If another thread is already loading, this waits for it.
178            // This is fine because we're on the blocking thread pool.
179            model_lock.get_or_init(|| load_model_sync(model_config));
180        })
181        .await
182        .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
183
184        self.model
185            .get()
186            .expect("set by spawn_blocking")
187            .as_ref()
188            .map_err(|e| EmbedError::ModelInitialization(e.clone()))
189    }
190}
191
192/// Synchronous model loading (runs on blocking thread pool).
193fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
194    match model_config.model {
195        EmbeddingModel::BgeSmallEnV15
196        | EmbeddingModel::BgeBaseEnV15
197        | EmbeddingModel::BgeLargeEnV15
198        | EmbeddingModel::MultilingualE5Small
199        | EmbeddingModel::MultilingualE5Base
200        | EmbeddingModel::AllMiniLmL6V2
201        | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
202            let model_name = match model_config.model {
203                EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
204                EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
205                EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
206                EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
207                EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
208                EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
209                EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
210                    "paraphrase-multilingual-minilm-l12-v2"
211                }
212                _ => unreachable!(),
213            };
214            info!(model = model_name, "loading native BERT embedding model");
215            let bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
216            Ok(LoadedModel::Bert(Arc::new(bert)))
217        }
218        EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
219            load_qwen_model(model_config)
220        }
221        other => Err(format!("unsupported model: {other:?}")),
222    }
223}
224
225fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
226    model_config.validate().map_err(|e| e.to_string())?;
227    let model_type = model_config.model;
228    let model_name = model_type.to_string();
229    info!(
230        model = %model_name,
231        output_dim = ?model_config.output_dim,
232        "loading Qwen embedding model"
233    );
234    let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
235    let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
236    model.set_output_dim(model_config.output_dim);
237    let cache_path = embedding_cache_path(&model_name, model.dimensions());
238    match model.cache_load(&cache_path) {
239        Ok(n) if n > 0 => {
240            info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
241        }
242        _ => {}
243    }
244    Ok(LoadedModel::Qwen(Arc::new(model)))
245}
246
/// Path for the persistent embedding cache: `$HOME/.lattice/cache/embed_{model}_{dim}d.bin`.
///
/// `dim` is part of the file name, so caches built at different output
/// dimensions never collide.
///
/// Falls back to `/tmp` when `HOME` is unset or empty. NOTE(review): `/tmp` is
/// shared between users; consider a per-user fallback.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    // `var_os` accepts non-UTF-8 home directories. `env::var` would reject them
    // and silently send the cache to /tmp. An empty HOME would otherwise produce
    // a cwd-relative path.
    let home = std::env::var_os("HOME")
        .filter(|h| !h.is_empty())
        .unwrap_or_else(|| "/tmp".into());
    std::path::PathBuf::from(home)
        .join(".lattice")
        .join("cache")
        .join(format!("embed_{model}_{dim}d.bin"))
}
255
256/// Locate Qwen3-Embedding model directory for the given model variant.
257fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
258    // Check env override first — applies to whichever Qwen model is loaded.
259    if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
260        return Ok(std::path::PathBuf::from(dir));
261    }
262
263    let slug = match model_type {
264        EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
265        EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
266        other => {
267            return Err(EmbedError::ModelInitialization(format!(
268                "not a Qwen model: {other}"
269            )));
270        }
271    };
272
273    let home = std::env::var("HOME")
274        .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
275    let dir = std::path::PathBuf::from(home)
276        .join(".lattice")
277        .join("models")
278        .join(slug);
279
280    if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
281        Ok(dir)
282    } else {
283        Err(EmbedError::ModelInitialization(format!(
284            "Qwen3 model not found at {}. Download from {}",
285            dir.display(),
286            model_type.model_id()
287        )))
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Render a cache path as a string for substring checks.
    fn as_text(path: &std::path::Path) -> String {
        path.to_string_lossy().into_owned()
    }

    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let cache = embedding_cache_path("qwen3-embedding-4b", 1024);
        let name = cache
            .file_name()
            .and_then(|n| n.to_str())
            .expect("cache path has a UTF-8 file name");
        assert_eq!(name, "embed_qwen3-embedding-4b_1024d.bin");
    }

    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let small = embedding_cache_path("qwen3-embedding-4b", 1024);
        let large = embedding_cache_path("qwen3-embedding-4b", 2560);
        assert_ne!(small, large);
        assert!(as_text(&small).contains("1024d"));
        assert!(as_text(&large).contains("2560d"));
    }

    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let big_model = embedding_cache_path("qwen3-embedding-4b", 2560);
        let small_model = embedding_cache_path("qwen3-embedding-0.6b", 1024);
        assert_ne!(big_model, small_model);
        assert!(as_text(&big_model).contains("qwen3-embedding-4b"));
        assert!(as_text(&small_model).contains("qwen3-embedding-0.6b"));
    }

    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        let first = embedding_cache_path("qwen3-embedding-4b", 1024);
        let second = embedding_cache_path("qwen3-embedding-4b", 1024);
        assert_eq!(first, second);
    }
}
327
328#[async_trait]
329impl EmbeddingService for NativeEmbeddingService {
330    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
331        if model != self.model_config.model {
332            return Err(EmbedError::InvalidInput(format!(
333                "requested model {:?} but this service is loaded with {:?}",
334                model, self.model_config.model
335            )));
336        }
337        if texts.is_empty() {
338            return Err(EmbedError::InvalidInput("no texts provided".into()));
339        }
340        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
341            return Err(EmbedError::InvalidInput(format!(
342                "batch size {} exceeds maximum {}",
343                texts.len(),
344                DEFAULT_MAX_BATCH_SIZE
345            )));
346        }
347        for text in texts {
348            if text.len() > MAX_TEXT_CHARS {
349                return Err(EmbedError::TextTooLong {
350                    length: text.len(),
351                    max: MAX_TEXT_CHARS,
352                });
353            }
354        }
355
356        let loaded = self.ensure_model().await?;
357        let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
358        loaded
359            .encode_batch(&text_refs)
360            .map_err(EmbedError::InferenceFailed)
361    }
362
363    fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
364        if model == self.model_config.model {
365            self.model_config
366        } else {
367            ModelConfig::new(model)
368        }
369    }
370
371    fn supports_model(&self, model: EmbeddingModel) -> bool {
372        model == self.model_config.model
373    }
374
375    fn name(&self) -> &'static str {
376        "native-bert"
377    }
378}