lattice_embed/service/
native.rs1use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::{EmbedError, Result};
5use crate::model::{EmbeddingModel, ModelConfig};
6use async_trait::async_trait;
7use lattice_inference::{BertModel, QwenModel};
8use std::sync::{Arc, OnceLock};
9use tracing::info;
10
11enum LoadedModel {
13 Bert(Arc<BertModel>),
14 Qwen(Arc<QwenModel>),
15}
16
17impl LoadedModel {
24 fn encode_batch(&self, texts: &[&str]) -> std::result::Result<Vec<Vec<f32>>, String> {
25 match self {
26 LoadedModel::Bert(m) => m.encode_batch(texts).map_err(|e| e.to_string()),
27 LoadedModel::Qwen(m) => {
29 let mut results = Vec::with_capacity(texts.len());
30 for text in texts {
31 results.push(m.encode(text).map_err(|e| e.to_string())?);
32 }
33 Ok(results)
34 }
35 }
36 }
37
38 fn cache_size(&self) -> usize {
39 match self {
40 LoadedModel::Qwen(m) => m.cache_size(),
41 _ => 0,
42 }
43 }
44}
45
46pub struct NativeEmbeddingService {
64 model: Arc<OnceLock<std::result::Result<LoadedModel, String>>>,
65 model_config: ModelConfig,
66}
67
68impl Default for NativeEmbeddingService {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
/// Name of the environment variable that overrides the embedding output dimension.
/// A blank value is treated as unset; anything else must parse as a `usize`.
const LATTICE_EMBED_DIM: &str = "LATTICE_EMBED_DIM";
75
76fn model_config_from_env(model: EmbeddingModel) -> Result<ModelConfig> {
77 let output_dim = match std::env::var(LATTICE_EMBED_DIM) {
78 Ok(raw) if raw.trim().is_empty() => None,
79 Ok(raw) => {
80 let dim = raw.trim().parse::<usize>().map_err(|e| {
81 EmbedError::InvalidInput(format!("invalid {LATTICE_EMBED_DIM}={raw:?}: {e}"))
82 })?;
83 Some(dim)
84 }
85 Err(std::env::VarError::NotPresent) => None,
86 Err(e) => {
87 return Err(EmbedError::InvalidInput(format!(
88 "invalid {LATTICE_EMBED_DIM}: {e}"
89 )));
90 }
91 };
92 ModelConfig::try_new(model, output_dim)
93}
94
95impl NativeEmbeddingService {
96 pub fn new() -> Self {
98 Self {
99 model: Arc::new(OnceLock::new()),
100 model_config: ModelConfig::new(EmbeddingModel::default()),
101 }
102 }
103
104 pub fn with_model(model_type: EmbeddingModel) -> Self {
106 Self {
107 model: Arc::new(OnceLock::new()),
108 model_config: ModelConfig::new(model_type),
109 }
110 }
111
112 pub fn with_model_config(model_config: ModelConfig) -> Result<Self> {
114 model_config.validate()?;
115 Ok(Self {
116 model: Arc::new(OnceLock::new()),
117 model_config,
118 })
119 }
120
121 pub fn with_model_from_env(model_type: EmbeddingModel) -> Result<Self> {
123 let config = model_config_from_env(model_type)?;
124 Ok(Self {
125 model: Arc::new(OnceLock::new()),
126 model_config: config,
127 })
128 }
129
130 pub fn save_cache(&self) -> Result<usize> {
132 let Some(Ok(model)) = self.model.get() else {
133 return Ok(0);
134 };
135 match model {
136 LoadedModel::Qwen(m) => {
137 let model_name = self.model_config.model.to_string();
138 let path = embedding_cache_path(&model_name, m.dimensions());
139 m.cache_save(&path)
140 .map_err(|e| EmbedError::InferenceFailed(e.to_string()))
141 }
142 _ => Ok(0),
143 }
144 }
145
146 pub fn cache_size(&self) -> usize {
148 self.model
149 .get()
150 .and_then(|r| r.as_ref().ok())
151 .map(LoadedModel::cache_size)
152 .unwrap_or(0)
153 }
154
155 async fn ensure_model(&self) -> Result<&LoadedModel> {
162 if let Some(result) = self.model.get() {
164 return result
165 .as_ref()
166 .map_err(|e| EmbedError::ModelInitialization(e.clone()));
167 }
168
169 let model_lock = self.model.clone();
173 let model_config = self.model_config;
174
175 tokio::task::spawn_blocking(move || {
176 model_lock.get_or_init(|| load_model_sync(model_config));
180 })
181 .await
182 .map_err(|e| EmbedError::ModelInitialization(e.to_string()))?;
183
184 self.model
185 .get()
186 .expect("set by spawn_blocking")
187 .as_ref()
188 .map_err(|e| EmbedError::ModelInitialization(e.clone()))
189 }
190}
191
192fn load_model_sync(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
194 match model_config.model {
195 EmbeddingModel::BgeSmallEnV15
196 | EmbeddingModel::BgeBaseEnV15
197 | EmbeddingModel::BgeLargeEnV15
198 | EmbeddingModel::MultilingualE5Small
199 | EmbeddingModel::MultilingualE5Base
200 | EmbeddingModel::AllMiniLmL6V2
201 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
202 let model_name = match model_config.model {
203 EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
204 EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
205 EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
206 EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
207 EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
208 EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
209 EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
210 "paraphrase-multilingual-minilm-l12-v2"
211 }
212 _ => unreachable!(),
213 };
214 info!(model = model_name, "loading native BERT embedding model");
215 let bert = BertModel::from_pretrained(model_name).map_err(|e| e.to_string())?;
216 Ok(LoadedModel::Bert(Arc::new(bert)))
217 }
218 EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => {
219 load_qwen_model(model_config)
220 }
221 other => Err(format!("unsupported model: {other:?}")),
222 }
223}
224
225fn load_qwen_model(model_config: ModelConfig) -> std::result::Result<LoadedModel, String> {
226 model_config.validate().map_err(|e| e.to_string())?;
227 let model_type = model_config.model;
228 let model_name = model_type.to_string();
229 info!(
230 model = %model_name,
231 output_dim = ?model_config.output_dim,
232 "loading Qwen embedding model"
233 );
234 let model_dir = qwen_model_dir(model_type).map_err(|e| e.to_string())?;
235 let mut model = QwenModel::from_directory(&model_dir).map_err(|e| e.to_string())?;
236 model.set_output_dim(model_config.output_dim);
237 let cache_path = embedding_cache_path(&model_name, model.dimensions());
238 match model.cache_load(&cache_path) {
239 Ok(n) if n > 0 => {
240 info!(entries = n, path = %cache_path.display(), "loaded embedding cache")
241 }
242 _ => {}
243 }
244 Ok(LoadedModel::Qwen(Arc::new(model)))
245}
246
/// Location of the on-disk embedding cache for `model` at dimension `dim`:
/// `$HOME/.lattice/cache/embed_<model>_<dim>d.bin`.
///
/// Falls back to `/tmp` as the base directory when `HOME` is unset or not valid
/// Unicode. NOTE(review): `/tmp` is shared between users; consider a per-user
/// fallback location.
fn embedding_cache_path(model: &str, dim: usize) -> std::path::PathBuf {
    let base = match std::env::var("HOME") {
        Ok(home) => home,
        Err(_) => String::from("/tmp"),
    };
    let mut path = std::path::PathBuf::from(base);
    path.push(".lattice");
    path.push("cache");
    path.push(format!("embed_{model}_{dim}d.bin"));
    path
}
255
256fn qwen_model_dir(model_type: EmbeddingModel) -> Result<std::path::PathBuf> {
258 if let Ok(dir) = std::env::var("LATTICE_QWEN_MODEL_DIR") {
260 return Ok(std::path::PathBuf::from(dir));
261 }
262
263 let slug = match model_type {
264 EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
265 EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
266 other => {
267 return Err(EmbedError::ModelInitialization(format!(
268 "not a Qwen model: {other}"
269 )));
270 }
271 };
272
273 let home = std::env::var("HOME")
274 .map_err(|_| EmbedError::ModelInitialization("HOME not set".into()))?;
275 let dir = std::path::PathBuf::from(home)
276 .join(".lattice")
277 .join("models")
278 .join(slug);
279
280 if dir.join("model.safetensors").exists() || dir.join("model.safetensors.index.json").exists() {
281 Ok(dir)
282 } else {
283 Err(EmbedError::ModelInitialization(format!(
284 "Qwen3 model not found at {}. Download from {}",
285 dir.display(),
286 model_type.model_id()
287 )))
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// The cache file name is `embed_<model>_<dim>d.bin`.
    #[test]
    fn test_cache_path_contains_dim_in_filename() {
        let cache_file = embedding_cache_path("qwen3-embedding-4b", 1024);
        let name = cache_file
            .file_name()
            .and_then(|n| n.to_str())
            .expect("cache path should end in a UTF-8 file name");
        assert_eq!(name, "embed_qwen3-embedding-4b_1024d.bin");
    }

    /// Changing only the dimension must change the path.
    #[test]
    fn test_cache_path_different_dims_produce_different_paths() {
        let [small, large] =
            [1024, 2560].map(|dim| embedding_cache_path("qwen3-embedding-4b", dim));
        assert_ne!(small, large);
        assert!(small.to_string_lossy().contains("1024d"));
        assert!(large.to_string_lossy().contains("2560d"));
    }

    /// Different model slugs must produce different paths that contain the slug.
    #[test]
    fn test_cache_path_model_slug_differentiates_variants() {
        let four_b = embedding_cache_path("qwen3-embedding-4b", 2560);
        let zero_six_b = embedding_cache_path("qwen3-embedding-0.6b", 1024);
        assert_ne!(four_b, zero_six_b);
        let expectations = [
            (&four_b, "qwen3-embedding-4b"),
            (&zero_six_b, "qwen3-embedding-0.6b"),
        ];
        for (path, slug) in expectations {
            assert!(path.to_string_lossy().contains(slug));
        }
    }

    /// The path is a pure function of (model, dim).
    #[test]
    fn test_cache_path_same_model_same_dim_same_path() {
        assert_eq!(
            embedding_cache_path("qwen3-embedding-4b", 1024),
            embedding_cache_path("qwen3-embedding-4b", 1024)
        );
    }
}
327
328#[async_trait]
329impl EmbeddingService for NativeEmbeddingService {
330 async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
331 if model != self.model_config.model {
332 return Err(EmbedError::InvalidInput(format!(
333 "requested model {:?} but this service is loaded with {:?}",
334 model, self.model_config.model
335 )));
336 }
337 if texts.is_empty() {
338 return Err(EmbedError::InvalidInput("no texts provided".into()));
339 }
340 if texts.len() > DEFAULT_MAX_BATCH_SIZE {
341 return Err(EmbedError::InvalidInput(format!(
342 "batch size {} exceeds maximum {}",
343 texts.len(),
344 DEFAULT_MAX_BATCH_SIZE
345 )));
346 }
347 for text in texts {
348 if text.len() > MAX_TEXT_CHARS {
349 return Err(EmbedError::TextTooLong {
350 length: text.len(),
351 max: MAX_TEXT_CHARS,
352 });
353 }
354 }
355
356 let loaded = self.ensure_model().await?;
357 let text_refs = texts.iter().map(String::as_str).collect::<Vec<_>>();
358 loaded
359 .encode_batch(&text_refs)
360 .map_err(EmbedError::InferenceFailed)
361 }
362
363 fn model_config(&self, model: EmbeddingModel) -> ModelConfig {
364 if model == self.model_config.model {
365 self.model_config
366 } else {
367 ModelConfig::new(model)
368 }
369 }
370
371 fn supports_model(&self, model: EmbeddingModel) -> bool {
372 model == self.model_config.model
373 }
374
375 fn name(&self) -> &'static str {
376 "native-bert"
377 }
378}