lattice_embed/service/
native.rs1use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::info;
10
11enum LoadedModel {
13 Bert(Arc<BertModel>),
14 Qwen(Arc<QwenModel>),
15}
16
17impl LoadedModel {
24 fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25 match self {
26 LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27 LoadedModel::Qwen(m) => {
29 let mut results = Vec::with_capacity(texts.len());
30 for text in texts {
31 results.push(m.encode(text).map_err(|e| e.to_string())?);
32 }
33 Ok(results)
34 }
35 }
36 }
37
38 fn cache_size(&self) -> usize {
39 match self {
40 LoadedModel::Qwen(m) => m.cache_size(),
41 _ => 0,
42 }
43 }
44}
45
46pub struct NativeEmbeddingService {
64 model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
65 model_config: ModelConfig,
66}
67
68impl Default for NativeEmbeddingService {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77 let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78 Ok(raw) if raw.trim().is_empty() => None,
79 Ok(raw) => {
80 let dim = raw.trim().parse::<usize>().map_err(|e| {
81 EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82 })?;
83 Some(dim)
84 }
85 Err(std::env::VarError::NotPresent) => None,
86 Err(e) => {
87 return Err(EmbedError::InvalidInput(format!(
88 "invalid {LATTICE_EMBED_DIM}: {e}"
89 )));
90 }
91 };
92 ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96 pub fn new() -> Self {
98 Self {
99 model: Arc::new(OnceLock::new()),
100 model_config: ModelConfig::new(EmbeddingModel::default()),
101 }
102 }
103
104 pub fn with_model(model_type: EmbeddingModel) -> Self {
106 Self {
107 model: Arc::new(OnceLock::new()),
108 model_config: ModelConfig::new(model_type),
109 }
110 }
111
112 pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114 model_config.validate()?;
115 Ok(Self {
116 model: Arc::new(OnceLock::new()),
117 model_config,
118 })
119 }
120
121 pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123 let config = model_config_from_env(model_type)?;
124 Ok(Self {
125 model: Arc::new(OnceLock::new()),
126 model_config: config,
127 })
128 }
129
130 pub fn save_cache(&self) -> Result<usize> {
132 let Some(Ok(model)) = self.model.get() else {
133 return Ok(0);
134 };
135 match model {
136 LoadedModel::Qwen(m) => {
137 let model_name = self.model_config.model.to_string();
138 let path = embedding_cache_path(&model_name, m.dimensions());
139 m.cache_save(&path)
140 .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141 }
142 _ => Ok(0),
143 }
144 }
145
146 pub fn cache_size(&self) -> usize {
148 self.model
149 .get()
150 .and_then(|r| r.as_ref().ok())
151 .map(LoadedModel::cache_size)
152 .unwrap_or(0)
153 }
154
155 async fn ensure_model(&self) -> Result<&LoadedModel> {
162 if let Some(result) = self.model.get() {
164 return result
165 .as_ref()
166 .map_err(|e| EmbedError::ModelInitialization(e.clone()));
167 }
168
169 let model_lock = self.model.clone();
173 let model_config = self.model_config;
174
175 tokio::task::spawn_blocking(move || {
176 model_lock.get_or_init(|| load_model_sync(model_config));
180 })
181 .await
182 .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
183
184 self.model
185 .get()
186 .expect("set by spawn_blocking")
187 .as_ref()
188 .map_err(|e| EmbedError::ModelInitialization(e.clone()))
189 }
190}
191
192fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
194 match model_config.model {
195 EmbeddingModel::BgeSmallEnV15
196 | EmbeddingModel::BgeBaseEnV15
197 | EmbeddingModel::BgeLargeEnV15
198 | EmbeddingModel::MultilingualE5Small
199 | EmbeddingModel::MultilingualE5Base
200 | EmbeddingModel::AllMiniLmL6V2
201 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
202 let model_name = match model_config.model {
203 EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
204 EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
205 EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
206 EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
207 EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
208 EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
209 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
210 "paraphrase-multilingual-minilm-l12-v2"
211 }
212 _ => unreachable!(),
213 };
214 info!(model = model_name, "loading native BERT embedding model");
215 let mut bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
216 if let Some(pooling) = model_config.model.bert_pooling() {
219 bert.set_pooling(pooling);
220 }
221 Ok(LoadedModel::Bert(Arc::new(bert)))
222 }
223 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
224 load_qwen_model(model_config)
225 }
226 other => Err(format!("unsupported model: {other:?}")),
227 }
228}
229
230fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
231 model_config.validate().map_err(|e| e.to_string())?;
232 let model_type = model_config.model;
233 let model_name = model_type.to_string();
234 info!(
235 model = %model_name,
236 output_dim = ?model_config.output_dim,
237 "loading Qwen embedding model"
238 );
239 let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
240 let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
241 model.set_output_dim(model_config.output_dim);
242 let cache_path = embedding_cache_path(&model_name, model.dimensions());
243 match model.cache_load(&cache_path) {
244 Ok(n) if n > 0 => {
245 info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
246 }
247 _ => {}
248 }
249 Ok(LoadedModel::Qwen(Arc::new(model)))
250}
251
/// Returns the on-disk location of the embedding cache for `model` at
/// dimension `dim`: `<home>/.lattice/cache/embed_<model>_<dim>d.bin`.
///
/// `<home>` is the value of the `HOME` environment variable, or `/tmp` when
/// `HOME` is not set. `HOME` is read as an OS string, so a home directory
/// whose path is not valid UTF-8 is used as-is instead of being treated as
/// missing.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    let home = std::env::var_os("HOME")
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("/tmp"));
    home.join(".lattice")
        .join("cache")
        .join(format!("embed_{model}_{dim}d.bin"))
}
260
261fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
263 if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
265 return Ok(std::path::PathBuf::from(dir));
266 }
267
268 let slug = match model_type {
269 EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
270 EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
271 other => {
272 return Err(EmbedError::ModelInitialization(format!(
273 "not a Qwen model: {other}"
274 )));
275 }
276 };
277
278 let home = std::env::var("HOME")
279 .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
280 let dir = std::path::PathBuf::from(home)
281 .join(".lattice")
282 .join("models")
283 .join(slug);
284
285 if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
286 Ok(dir)
287 } else {
288 Err(EmbedError::ModelInitialization(format!(
289 "Qwen3 model not found at {}. Download from {}",
290 dir.display(),
291 model_type.model_id()
292 )))
293 }
294}
295
296#[async_trait]
297impl EmbeddingService for NativeEmbeddingService {
298 async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
299 if model != self.model_config.model {
300 return Err(EmbedError::InvalidInput(format!(
301 "requested model {:?} but this service is loaded with {:?}",
302 model, self.model_config.model
303 )));
304 }
305 if texts.is_empty() {
306 return Err(EmbedError::InvalidInput("no texts provided".into()));
307 }
308 if texts.len() > DEFAULT_MAX_BATCH_SIZE {
309 return Err(EmbedError::InvalidInput(format!(
310 "batch size {} exceeds maximum {}",
311 texts.len(),
312 DEFAULT_MAX_BATCH_SIZE
313 )));
314 }
315 for text in texts {
316 if text.len() > MAX_TEXT_CHARS {
317 return Err(EmbedError::TextTooLong {
318 length: text.len(),
319 max: MAX_TEXT_CHARS,
320 });
321 }
322 }
323
324 let loaded = self.ensure_model().await?;
325 let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
326 loaded
327 .encode_batch(&text_refs)
328 .map_err(EmbedError::InferenceFailed)
329 }
330
331 fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
332 if model == self.model_config.model {
333 self.model_config
334 } else {
335 ModelConfig::new(model)
336 }
337 }
338
339 fn supports_model(&self, model: EmbeddingModel) -> bool {
340 model == self.model_config.model
341 }
342
343 fn name(&self) -> &'static str {
344 "native-bert"
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// The file name encodes both the model name and the dimension.
    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let cache_file = embedding_cache_path("qwen3-embedding-4b", 1024);
        let name = cache_file
            .file_name()
            .and_then(|n| n.to_str())
            .expect("cache path has a UTF-8 file name");
        assert_eq!(name, "embed_qwen3-embedding-4b_1024d.bin");
    }

    /// The same model at two dimensions must map to two distinct files.
    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let small = embedding_cache_path("qwen3-embedding-4b", 1024);
        let large = embedding_cache_path("qwen3-embedding-4b", 2560);
        assert_ne!(small, large);
        for (cache_file, marker) in [(&small, "1024d"), (&large, "2560d")] {
            assert!(cache_file.to_string_lossy().contains(marker));
        }
    }

    /// Different model variants must map to distinct files that carry the
    /// variant's name.
    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let large_variant = embedding_cache_path("qwen3-embedding-4b", 2560);
        let small_variant = embedding_cache_path("qwen3-embedding-0.6b", 1024);
        assert_ne!(large_variant, small_variant);
        for (cache_file, slug) in [
            (&large_variant, "qwen3-embedding-4b"),
            (&small_variant, "qwen3-embedding-0.6b"),
        ] {
            assert!(cache_file.to_string_lossy().contains(slug));
        }
    }

    /// Identical inputs always produce the identical path.
    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        let first = embedding_cache_path("qwen3-embedding-4b", 1024);
        let second = embedding_cache_path("qwen3-embedding-4b", 1024);
        assert_eq!(first, second);
    }
}