1use std::path::PathBuf;
10use std::sync::Arc;
11
12use crate::config::EmbeddingConfig;
13use crate::engine::EmbeddingEngine;
14use crate::error::EmbeddingError;
15use crate::model::EmbeddingGemmaModel;
16use crate::task::format_query;
17
/// Filesystem locations of the four artifacts needed to load the embedding model.
///
/// The paths are passed, in declaration order, to `EmbeddingGemmaModel::load` by
/// [`GemmaEmbedding::from_files`]. The struct performs no validation of its own.
///
/// Equality is derived so that callers and tests can compare resolved file sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelFiles {
    /// GGUF model file (`embeddinggemma-300M-Q8_0.gguf` in the default layout).
    pub gguf_path: PathBuf,
    /// First dense-layer weights (`2_Dense/model.safetensors` in the default layout).
    pub dense1_path: PathBuf,
    /// Second dense-layer weights (`3_Dense/model.safetensors` in the default layout).
    pub dense2_path: PathBuf,
    /// Tokenizer definition (`tokenizer.json` in the default layout).
    pub tokenizer_path: PathBuf,
}
30
31fn model_files_from_env() -> Result<ModelFiles, EmbeddingError> {
32 let base_dir =
33 std::env::var("ERIO_MODEL_DIR").unwrap_or_else(|_| env!("ERIO_MODEL_DIR").to_owned());
34
35 let base = PathBuf::from(base_dir);
36 let files = ModelFiles {
37 gguf_path: base.join("embeddinggemma-300M-Q8_0.gguf"),
38 dense1_path: base.join("2_Dense/model.safetensors"),
39 dense2_path: base.join("3_Dense/model.safetensors"),
40 tokenizer_path: base.join("tokenizer.json"),
41 };
42
43 for path in [
44 &files.gguf_path,
45 &files.dense1_path,
46 &files.dense2_path,
47 &files.tokenizer_path,
48 ] {
49 if !path.exists() {
50 return Err(EmbeddingError::ModelLoad(format!(
51 "required model file missing: {}",
52 path.display()
53 )));
54 }
55 }
56
57 Ok(files)
58}
59
60pub struct GemmaEmbedding {
62 config: EmbeddingConfig,
63 model: Arc<EmbeddingGemmaModel>,
64}
65
66impl GemmaEmbedding {
67 pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> {
69 let model_files = model_files_from_env()?;
70 Self::from_files(config, &model_files)
71 }
72
73 pub fn from_files(
75 config: EmbeddingConfig,
76 model_files: &ModelFiles,
77 ) -> Result<Self, EmbeddingError> {
78 let model = EmbeddingGemmaModel::load(
79 &model_files.gguf_path,
80 &model_files.dense1_path,
81 &model_files.dense2_path,
82 &model_files.tokenizer_path,
83 )?;
84 Ok(Self {
85 config,
86 model: Arc::new(model),
87 })
88 }
89}
90
91#[async_trait::async_trait]
92impl EmbeddingEngine for GemmaEmbedding {
93 fn name(&self) -> &'static str {
94 "gemma"
95 }
96
97 fn dimensions(&self) -> usize {
98 self.config.dimensions
99 }
100
101 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
102 if text.is_empty() {
103 return Err(EmbeddingError::InvalidInput(
104 "text must not be empty".into(),
105 ));
106 }
107 let prompt = format_query(text, self.config.task_type);
108 self.model.embed(&prompt)
109 }
110
111 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
112 if texts.iter().any(|t| t.is_empty()) {
113 return Err(EmbeddingError::InvalidInput(
114 "text must not be empty".into(),
115 ));
116 }
117 texts
118 .iter()
119 .map(|text| {
120 let prompt = format_query(text, self.config.task_type);
121 self.model.embed(&prompt)
122 })
123 .collect()
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::fs;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Environment variable that selects the model directory at runtime.
    const MODEL_DIR_VAR: &str = "ERIO_MODEL_DIR";

    /// Serializes every test in this module that reads or writes `ERIO_MODEL_DIR`.
    ///
    /// The lock is module-wide (not local to one test) because the model-backed
    /// tests read the variable through `GemmaEmbedding::new`; without a shared lock
    /// they could observe the temporary override when run with `--include-ignored`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the environment lock, recovering from poisoning.
    ///
    /// A poisoned lock only means an earlier test panicked; the override guard has
    /// already restored the variable by then, so continuing is safe and avoids
    /// cascading failures.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Temporarily points `ERIO_MODEL_DIR` at another directory and restores the
    /// previous state (value or absence) on drop, including during a panic unwind.
    struct ModelDirOverride {
        previous: Option<OsString>,
    }

    // The helper contains `unsafe` blocks: `set_var`/`remove_var` are `unsafe fn`
    // in the 2024 edition because they race with concurrent environment access.
    #[allow(unsafe_code)]
    impl ModelDirOverride {
        /// Applies the override. The guard parameter is proof that the caller holds
        /// the environment lock for at least the duration of the call.
        fn set(_env: &MutexGuard<'_, ()>, dir: &std::path::Path) -> Self {
            let previous = std::env::var_os(MODEL_DIR_VAR);
            // SAFETY: the caller holds `ENV_LOCK`, which serializes all environment
            // access made by this module. NOTE(review): this assumes no test outside
            // this module touches the process environment concurrently — confirm.
            unsafe { std::env::set_var(MODEL_DIR_VAR, dir) };
            Self { previous }
        }
    }

    #[allow(unsafe_code)]
    impl Drop for ModelDirOverride {
        fn drop(&mut self) {
            // SAFETY: the override is always dropped before the `ENV_LOCK` guard it
            // was created under (it is declared later / in an inner scope), so the
            // same serialization argument as in `set` applies.
            match self.previous.take() {
                Some(value) => unsafe { std::env::set_var(MODEL_DIR_VAR, value) },
                None => unsafe { std::env::remove_var(MODEL_DIR_VAR) },
            }
        }
    }

    /// Loads an engine with the default configuration.
    ///
    /// The environment lock is held only while the constructor resolves the model
    /// directory, so it is never held across an `.await`.
    fn load_engine() -> GemmaEmbedding {
        let _env = lock_env();
        GemmaEmbedding::new(EmbeddingConfig::default()).expect("failed to load gemma model")
    }

    #[test]
    fn model_files_struct_holds_paths() {
        let files = ModelFiles {
            gguf_path: PathBuf::from("/tmp/model.gguf"),
            dense1_path: PathBuf::from("/tmp/2_Dense/model.safetensors"),
            dense2_path: PathBuf::from("/tmp/3_Dense/model.safetensors"),
            tokenizer_path: PathBuf::from("/tmp/tokenizer.json"),
        };
        assert!(files.gguf_path.ends_with("model.gguf"));
        assert!(files.dense1_path.ends_with("model.safetensors"));
        assert!(files.dense2_path.ends_with("model.safetensors"));
        assert!(files.tokenizer_path.ends_with("tokenizer.json"));
    }

    /// Compile-time check that `GemmaEmbedding::new` stays a plain synchronous fn.
    #[test]
    fn gemma_new_is_sync_constructor_signature() {
        let constructor: fn(EmbeddingConfig) -> Result<GemmaEmbedding, EmbeddingError> =
            GemmaEmbedding::new;
        let _ = constructor;
    }

    #[test]
    fn model_files_from_env_errors_when_required_files_are_missing() {
        let env = lock_env();

        // pid + nanosecond timestamp keeps the directory unique across parallel runs.
        let unique = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock should be after unix epoch")
            .as_nanos();
        let temp_dir = std::env::temp_dir().join(format!(
            "erio-embedding-missing-model-files-{}-{unique}",
            std::process::id()
        ));
        fs::create_dir_all(&temp_dir).expect("failed to create temp model dir");
        // Only the tokenizer exists, so the lookup must fail on one of the other files.
        fs::write(temp_dir.join("tokenizer.json"), b"{}")
            .expect("failed to create partial model file");

        let result = {
            let _override = ModelDirOverride::set(&env, &temp_dir);
            model_files_from_env()
        };
        drop(env);

        fs::remove_dir_all(&temp_dir).expect("failed to cleanup temp model dir");

        match result {
            Err(EmbeddingError::ModelLoad(message)) => {
                assert!(
                    message.contains("required model file missing"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("expected EmbeddingError::ModelLoad, got {other:?}"),
        }
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_returns_name() {
        let engine = load_engine();
        assert_eq!(engine.name(), "gemma");
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_returns_correct_dimensions() {
        let engine = load_engine();
        assert_eq!(engine.dimensions(), 768);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_returns_correct_dimensions() {
        let engine = load_engine();
        let vec = engine.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 768);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_rejects_empty_input() {
        let engine = load_engine();
        let result = engine.embed("").await;
        assert!(matches!(
            result.unwrap_err(),
            EmbeddingError::InvalidInput(_)
        ));
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_batch_rejects_empty_input() {
        let engine = load_engine();
        let result = engine.embed_batch(&["ok", ""]).await;
        assert!(matches!(
            result.unwrap_err(),
            EmbeddingError::InvalidInput(_)
        ));
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_same_input_same_output() {
        let engine = load_engine();
        let v1 = engine.embed("test determinism").await.unwrap();
        let v2 = engine.embed("test determinism").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_different_inputs_different_outputs() {
        let engine = load_engine();
        let v1 = engine.embed("hello").await.unwrap();
        let v2 = engine.embed("world").await.unwrap();
        assert_ne!(v1, v2);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_normalized_unit_length() {
        let engine = load_engine();
        let vec = engine.embed("test normalization").await.unwrap();
        let magnitude: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (magnitude - 1.0).abs() < 1e-4,
            "Expected unit length, got {magnitude}"
        );
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_batch_correct_count() {
        let engine = load_engine();
        let results = engine.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for vec in &results {
            assert_eq!(vec.len(), 768);
        }
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_batch_preserves_order() {
        let engine = load_engine();
        let v_a = engine.embed("alpha").await.unwrap();
        let v_b = engine.embed("beta").await.unwrap();
        let batch = engine.embed_batch(&["alpha", "beta"]).await.unwrap();
        assert_eq!(batch[0], v_a);
        assert_eq!(batch[1], v_b);
    }
}