lattice_embed/service/
native.rs1use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::info;
10
11enum LoadedModel {
13 Bert(Arc<BertModel>),
14 Qwen(Arc<QwenModel>),
15}
16
17impl LoadedModel {
24 fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25 match self {
26 LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27 LoadedModel::Qwen(m) => {
29 let mut results = Vec::with_capacity(texts.len());
30 for text in texts {
31 results.push(m.encode(text).map_err(|e| e.to_string())?);
32 }
33 Ok(results)
34 }
35 }
36 }
37
38 fn cache_size(&self) -> usize {
39 match self {
40 LoadedModel::Qwen(m) => m.cache_size(),
41 _ => 0,
42 }
43 }
44}
45
46pub struct NativeEmbeddingService {
64 model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
65 model_config: ModelConfig,
66}
67
68impl Default for NativeEmbeddingService {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77 let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78 Ok(raw) if raw.trim().is_empty() => None,
79 Ok(raw) => {
80 let dim = raw.trim().parse::<usize>().map_err(|e| {
81 EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82 })?;
83 Some(dim)
84 }
85 Err(std::env::VarError::NotPresent) => None,
86 Err(e) => {
87 return Err(EmbedError::InvalidInput(format!(
88 "invalid {LATTICE_EMBED_DIM}: {e}"
89 )));
90 }
91 };
92 ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96 pub fn new() -> Self {
98 Self {
99 model: Arc::new(OnceLock::new()),
100 model_config: ModelConfig::new(EmbeddingModel::default()),
101 }
102 }
103
104 pub fn with_model(model_type: EmbeddingModel) -> Self {
106 Self {
107 model: Arc::new(OnceLock::new()),
108 model_config: ModelConfig::new(model_type),
109 }
110 }
111
112 pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114 model_config.validate()?;
115 Ok(Self {
116 model: Arc::new(OnceLock::new()),
117 model_config,
118 })
119 }
120
121 pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123 let config = model_config_from_env(model_type)?;
124 Ok(Self {
125 model: Arc::new(OnceLock::new()),
126 model_config: config,
127 })
128 }
129
130 pub fn save_cache(&self) -> Result<usize> {
132 let Some(Ok(model)) = self.model.get() else {
133 return Ok(0);
134 };
135 match model {
136 LoadedModel::Qwen(m) => {
137 let model_name = self.model_config.model.to_string();
138 let path = embedding_cache_path(&model_name, m.dimensions());
139 m.cache_save(&path)
140 .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141 }
142 _ => Ok(0),
143 }
144 }
145
146 pub fn cache_size(&self) -> usize {
148 self.model
149 .get()
150 .and_then(|r| r.as_ref().ok())
151 .map(LoadedModel::cache_size)
152 .unwrap_or(0)
153 }
154
155 pub async fn ensure_loaded(&self) -> Result<()> {
165 self.ensure_model().await.map(|_| ())
166 }
167
168 async fn ensure_model(&self) -> Result<&LoadedModel> {
175 if let Some(result) = self.model.get() {
177 return result
178 .as_ref()
179 .map_err(|e| EmbedError::ModelInitialization(e.clone()));
180 }
181
182 let model_lock = self.model.clone();
186 let model_config = self.model_config;
187
188 tokio::task::spawn_blocking(move || {
189 model_lock.get_or_init(|| load_model_sync(model_config));
193 })
194 .await
195 .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
196
197 self.model
198 .get()
199 .expect("set by spawn_blocking")
200 .as_ref()
201 .map_err(|e| EmbedError::ModelInitialization(e.clone()))
202 }
203}
204
205fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
207 match model_config.model {
208 EmbeddingModel::BgeSmallEnV15
209 | EmbeddingModel::BgeBaseEnV15
210 | EmbeddingModel::BgeLargeEnV15
211 | EmbeddingModel::MultilingualE5Small
212 | EmbeddingModel::MultilingualE5Base
213 | EmbeddingModel::AllMiniLmL6V2
214 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
215 let model_name = match model_config.model {
216 EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
217 EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
218 EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
219 EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
220 EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
221 EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
222 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
223 "paraphrase-multilingual-minilm-l12-v2"
224 }
225 _ => unreachable!(),
226 };
227 info!(model = model_name, "loading native BERT embedding model");
228 let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
229 if let Some(pooling) = model_config.model.bert_pooling() {
232 bert.set_pooling(pooling);
233 }
234 Ok(LoadedModel::Bert(Arc::new(bert)))
235 }
236 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
237 load_qwen_model(model_config)
238 }
239 other => Err(format!("unsupported model: {other:?}")),
240 }
241}
242
243fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
244 model_config.validate().map_err(|e| e.to_string())?;
245 let model_type = model_config.model;
246 let model_name = model_type.to_string();
247 info!(
248 model = %model_name,
249 output_dim = ?model_config.output_dim,
250 "loading Qwen embedding model"
251 );
252 let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
253 let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
254 model.set_output_dim(model_config.output_dim);
255 let cache_path = embedding_cache_path(&model_name, model.dimensions());
256 match model.cache_load(&cache_path) {
257 Ok(n) if n > 0 => {
258 info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
259 }
260 _ => {}
261 }
262 Ok(LoadedModel::Qwen(Arc::new(model)))
263}
264
/// Location of the persisted embedding cache for `model` at `dim` output dimensions.
///
/// The layout is `<base>/.lattice/cache/embed_{model}_{dim}d.bin`. `<base>` is `$HOME`, or
/// `/tmp` when `HOME` is unset or not valid Unicode. Putting the dimension in the file name
/// keeps caches for different output sizes of the same model apart.
// NOTE(review): `/tmp` is shared on multi-user hosts; consider whether the fallback should be
// a per-user location instead.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    let base = match std::env::var("HOME") {
        Ok(home) => home,
        Err(_) => String::from("/tmp"),
    };
    let mut path = std::path::PathBuf::from(base);
    path.push(".lattice");
    path.push("cache");
    path.push(format!("embed_{model}_{dim}d.bin"));
    path
}
273
274fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
276 if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
278 return Ok(std::path::PathBuf::from(dir));
279 }
280
281 let slug = match model_type {
282 EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
283 EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
284 other => {
285 return Err(EmbedError::ModelInitialization(format!(
286 "not a Qwen model: {other}"
287 )));
288 }
289 };
290
291 let home = std::env::var("HOME")
292 .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
293 let dir = std::path::PathBuf::from(home)
294 .join(".lattice")
295 .join("models")
296 .join(slug);
297
298 if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
299 Ok(dir)
300 } else {
301 Err(EmbedError::ModelInitialization(format!(
302 "Qwen3 model not found at {}. Download from {}",
303 dir.display(),
304 model_type.model_id()
305 )))
306 }
307}
308
309#[async_trait]
310impl EmbeddingService for NativeEmbeddingService {
311 async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
312 if model != self.model_config.model {
313 return Err(EmbedError::InvalidInput(format!(
314 "requested model {:?} but this service is loaded with {:?}",
315 model, self.model_config.model
316 )));
317 }
318 if texts.is_empty() {
319 return Err(EmbedError::InvalidInput("no texts provided".into()));
320 }
321 if texts.len() > DEFAULT_MAX_BATCH_SIZE {
322 return Err(EmbedError::InvalidInput(format!(
323 "batch size {} exceeds maximum {}",
324 texts.len(),
325 DEFAULT_MAX_BATCH_SIZE
326 )));
327 }
328 for text in texts {
329 if text.len() > MAX_TEXT_CHARS {
330 return Err(EmbedError::TextTooLong {
331 length: text.len(),
332 max: MAX_TEXT_CHARS,
333 });
334 }
335 }
336
337 let loaded = self.ensure_model().await?;
338 let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
339 loaded
340 .encode_batch(&text_refs)
341 .map_err(EmbedError::InferenceFailed)
342 }
343
344 fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
345 if model == self.model_config.model {
346 self.model_config
347 } else {
348 ModelConfig::new(model)
349 }
350 }
351
352 fn supports_model(&self, model: EmbeddingModel) -> bool {
353 model == self.model_config.model
354 }
355
356 fn name(&self) -> &'static str {
357 "native-bert"
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Final component of `path`, which these tests expect to be valid UTF-8.
    fn file_name_of(path: &std::path::Path) -> &str {
        path.file_name()
            .and_then(std::ffi::OsStr::to_str)
            .expect("cache path should end in a UTF-8 file name")
    }

    /// Asserts that the lossy string form of `path` contains `needle`.
    fn assert_path_contains(path: &std::path::Path, needle: &str) {
        assert!(
            path.to_string_lossy().contains(needle),
            "{} should contain {needle:?}",
            path.display()
        );
    }

    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let path = embedding_cache_path("qwen3-embedding-4b", 1024);
        assert_eq!(file_name_of(&path), "embed_qwen3-embedding-4b_1024d.bin");
    }

    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let narrow = embedding_cache_path("qwen3-embedding-4b", 1024);
        let wide = embedding_cache_path("qwen3-embedding-4b", 2560);
        assert_ne!(narrow, wide);
        assert_path_contains(&narrow, "1024d");
        assert_path_contains(&wide, "2560d");
    }

    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let large_variant = embedding_cache_path("qwen3-embedding-4b", 2560);
        let small_variant = embedding_cache_path("qwen3-embedding-0.6b", 1024);
        assert_ne!(large_variant, small_variant);
        assert_path_contains(&large_variant, "qwen3-embedding-4b");
        assert_path_contains(&small_variant, "qwen3-embedding-0.6b");
    }

    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        assert_eq!(
            embedding_cache_path("qwen3-embedding-4b", 1024),
            embedding_cache_path("qwen3-embedding-4b", 1024)
        );
    }
}