Skip to main content

erio_embedding/
gemma.rs

1//! Local embedding engine using the GGUF-quantized `EmbeddingGemma` model.
2//!
3//! This engine never downloads model files at runtime.
4//!
5//! `GemmaEmbedding::new` loads from a local model directory specified by `ERIO_MODEL_DIR`.
6//! The embedding crate's `build.rs` populates this automatically at build time by downloading
7//! public GitHub Release assets, or you can set `ERIO_MODEL_DIR` manually for offline builds.
8
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use crate::config::EmbeddingConfig;
13use crate::engine::EmbeddingEngine;
14use crate::error::EmbeddingError;
15use crate::model::EmbeddingGemmaModel;
16use crate::task::format_query;
17
/// Paths to all files needed to load the `EmbeddingGemma` model.
///
/// Build one by hand and pass it to [`GemmaEmbedding::from_files`] when the files live
/// somewhere other than the directory named by `ERIO_MODEL_DIR`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFiles {
    /// Path to the GGUF backbone model file.
    pub gguf_path: PathBuf,
    /// Path to the first dense layer safetensors (`2_Dense/model.safetensors`).
    pub dense1_path: PathBuf,
    /// Path to the second dense layer safetensors (`3_Dense/model.safetensors`).
    pub dense2_path: PathBuf,
    /// Path to the tokenizer file (`tokenizer.json`).
    pub tokenizer_path: PathBuf,
}
30
31fn model_files_from_env() -> Result<ModelFiles, EmbeddingError> {
32    let base_dir =
33        std::env::var("ERIO_MODEL_DIR").unwrap_or_else(|_| env!("ERIO_MODEL_DIR").to_owned());
34
35    let base = PathBuf::from(base_dir);
36    let files = ModelFiles {
37        gguf_path: base.join("embeddinggemma-300M-Q8_0.gguf"),
38        dense1_path: base.join("2_Dense/model.safetensors"),
39        dense2_path: base.join("3_Dense/model.safetensors"),
40        tokenizer_path: base.join("tokenizer.json"),
41    };
42
43    for path in [
44        &files.gguf_path,
45        &files.dense1_path,
46        &files.dense2_path,
47        &files.tokenizer_path,
48    ] {
49        if !path.exists() {
50            return Err(EmbeddingError::ModelLoad(format!(
51                "required model file missing: {}",
52                path.display()
53            )));
54        }
55    }
56
57    Ok(files)
58}
59
60/// Local embedding engine using quantized `EmbeddingGemma` model via candle.
61pub struct GemmaEmbedding {
62    config: EmbeddingConfig,
63    model: Arc<EmbeddingGemmaModel>,
64}
65
66impl GemmaEmbedding {
67    /// Creates a new `GemmaEmbedding` from build-time prepared model files.
68    pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> {
69        let model_files = model_files_from_env()?;
70        Self::from_files(config, &model_files)
71    }
72
73    /// Creates a `GemmaEmbedding` from pre-downloaded model files.
74    pub fn from_files(
75        config: EmbeddingConfig,
76        model_files: &ModelFiles,
77    ) -> Result<Self, EmbeddingError> {
78        let model = EmbeddingGemmaModel::load(
79            &model_files.gguf_path,
80            &model_files.dense1_path,
81            &model_files.dense2_path,
82            &model_files.tokenizer_path,
83        )?;
84        Ok(Self {
85            config,
86            model: Arc::new(model),
87        })
88    }
89}
90
91#[async_trait::async_trait]
92impl EmbeddingEngine for GemmaEmbedding {
93    fn name(&self) -> &'static str {
94        "gemma"
95    }
96
97    fn dimensions(&self) -> usize {
98        self.config.dimensions
99    }
100
101    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
102        if text.is_empty() {
103            return Err(EmbeddingError::InvalidInput(
104                "text must not be empty".into(),
105            ));
106        }
107        let prompt = format_query(text, self.config.task_type);
108        self.model.embed(&prompt)
109    }
110
111    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
112        if texts.iter().any(|t| t.is_empty()) {
113            return Err(EmbeddingError::InvalidInput(
114                "text must not be empty".into(),
115            ));
116        }
117        texts
118            .iter()
119            .map(|text| {
120                let prompt = format_query(text, self.config.task_type);
121                self.model.embed(&prompt)
122            })
123            .collect()
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::fs;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Tests that need the real model files are #[ignore]d. The files come from build.rs or
    // from a directory supplied through ERIO_MODEL_DIR.
    // Run with: cargo test -p erio-embedding -- --ignored

    /// Serializes every test in this module that reads or writes `ERIO_MODEL_DIR`.
    ///
    /// Module-level (not function-local) so that model-backed tests, which read the variable
    /// through `GemmaEmbedding::new`, cannot observe another test's temporary override.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning.
    fn lock_env() -> MutexGuard<'static, ()> {
        // `ModelDirOverride` restores the environment even when a test panics, so the state
        // guarded by a poisoned lock is still consistent.
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Builds an engine with the default config while holding the environment lock.
    fn default_engine() -> GemmaEmbedding {
        let _env = lock_env();
        GemmaEmbedding::new(EmbeddingConfig::default()).expect("model files should be available")
    }

    /// Points `ERIO_MODEL_DIR` at a newly created scratch directory while alive.
    ///
    /// Dropping it restores the previous value and removes the directory, so cleanup also
    /// runs when an assertion panics. Borrowing the [`ENV_LOCK`] guard ensures the lock is
    /// held for the override's whole lifetime.
    struct ModelDirOverride<'a> {
        _env: &'a MutexGuard<'static, ()>,
        previous: Option<OsString>,
        dir: PathBuf,
    }

    impl<'a> ModelDirOverride<'a> {
        // `std::env::set_var` is `unsafe`; soundness is argued in the SAFETY comment below.
        #[allow(unsafe_code)]
        fn new(env: &'a MutexGuard<'static, ()>) -> Self {
            let unique = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("clock should be after unix epoch")
                .as_nanos();
            let dir = std::env::temp_dir().join(format!(
                "erio-embedding-model-dir-{}-{unique}",
                std::process::id()
            ));
            fs::create_dir_all(&dir).expect("failed to create temp model dir");
            let previous = std::env::var_os("ERIO_MODEL_DIR");
            // SAFETY: `env` proves ENV_LOCK is held, so no other test in this module reads or
            // writes the environment concurrently. Tests outside this module do not take the
            // lock and must not touch ERIO_MODEL_DIR.
            unsafe { std::env::set_var("ERIO_MODEL_DIR", &dir) };
            Self {
                _env: env,
                previous,
                dir,
            }
        }
    }

    impl Drop for ModelDirOverride<'_> {
        // Restoring the variable needs the `unsafe` `set_var`/`remove_var`.
        #[allow(unsafe_code)]
        fn drop(&mut self) {
            // SAFETY: the borrowed `_env` guard keeps ENV_LOCK held until after this runs;
            // see `ModelDirOverride::new`.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var("ERIO_MODEL_DIR", value),
                    None => std::env::remove_var("ERIO_MODEL_DIR"),
                }
            }
            // Best-effort: panicking here while already unwinding would abort the test binary.
            let _ = fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn model_files_struct_holds_paths() {
        let files = ModelFiles {
            gguf_path: PathBuf::from("/tmp/model.gguf"),
            dense1_path: PathBuf::from("/tmp/2_Dense/model.safetensors"),
            dense2_path: PathBuf::from("/tmp/3_Dense/model.safetensors"),
            tokenizer_path: PathBuf::from("/tmp/tokenizer.json"),
        };
        assert!(files.gguf_path.ends_with("model.gguf"));
        assert!(files.dense1_path.ends_with("model.safetensors"));
        assert!(files.dense2_path.ends_with("model.safetensors"));
        assert!(files.tokenizer_path.ends_with("tokenizer.json"));
    }

    #[test]
    fn gemma_new_is_sync_constructor_signature() {
        let constructor: fn(EmbeddingConfig) -> Result<GemmaEmbedding, EmbeddingError> =
            GemmaEmbedding::new;
        let _ = constructor;
    }

    #[test]
    fn model_files_from_env_errors_when_required_files_are_missing() {
        let env = lock_env();
        let model_dir = ModelDirOverride::new(&env);
        // Only one of the four required files is present.
        fs::write(model_dir.dir.join("tokenizer.json"), b"{}")
            .expect("failed to create partial model file");

        match model_files_from_env() {
            Err(EmbeddingError::ModelLoad(message)) => {
                assert!(
                    message.contains("required model file missing"),
                    "unexpected error message: {message}"
                );
            }
            other => panic!("expected EmbeddingError::ModelLoad, got {other:?}"),
        }
    }

    #[test]
    fn model_files_from_env_accepts_complete_model_dir() {
        let env = lock_env();
        let model_dir = ModelDirOverride::new(&env);
        for relative in [
            "embeddinggemma-300M-Q8_0.gguf",
            "2_Dense/model.safetensors",
            "3_Dense/model.safetensors",
            "tokenizer.json",
        ] {
            let path = model_dir.dir.join(relative);
            fs::create_dir_all(path.parent().expect("joined path has a parent"))
                .expect("failed to create model subdirectory");
            fs::write(&path, b"").expect("failed to create model file");
        }

        let files = model_files_from_env().expect("all required files exist");

        assert_eq!(
            files.gguf_path,
            model_dir.dir.join("embeddinggemma-300M-Q8_0.gguf")
        );
        assert_eq!(
            files.dense1_path,
            model_dir.dir.join("2_Dense/model.safetensors")
        );
        assert_eq!(
            files.dense2_path,
            model_dir.dir.join("3_Dense/model.safetensors")
        );
        assert_eq!(files.tokenizer_path, model_dir.dir.join("tokenizer.json"));
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_returns_name() {
        let engine = default_engine();
        assert_eq!(engine.name(), "gemma");
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_returns_correct_dimensions() {
        let engine = default_engine();
        assert_eq!(engine.dimensions(), 768);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_returns_correct_dimensions() {
        let engine = default_engine();
        let vec = engine.embed("hello world").await.unwrap();
        assert_eq!(vec.len(), 768);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_rejects_empty_input() {
        let engine = default_engine();
        let result = engine.embed("").await;
        assert!(matches!(
            result.unwrap_err(),
            EmbeddingError::InvalidInput(_)
        ));
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_batch_rejects_empty_entry() {
        let engine = default_engine();
        let result = engine.embed_batch(&["ok", ""]).await;
        assert!(matches!(result, Err(EmbeddingError::InvalidInput(_))));
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_same_input_same_output() {
        let engine = default_engine();
        let v1 = engine.embed("test determinism").await.unwrap();
        let v2 = engine.embed("test determinism").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_different_inputs_different_outputs() {
        let engine = default_engine();
        let v1 = engine.embed("hello").await.unwrap();
        let v2 = engine.embed("world").await.unwrap();
        assert_ne!(v1, v2);
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_normalized_unit_length() {
        let engine = default_engine();
        let vec = engine.embed("test normalization").await.unwrap();
        let magnitude: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (magnitude - 1.0).abs() < 1e-4,
            "Expected unit length, got {magnitude}"
        );
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_batch_correct_count() {
        let engine = default_engine();
        let results = engine.embed_batch(&["a", "b", "c"]).await.unwrap();
        assert_eq!(results.len(), 3);
        for vec in &results {
            assert_eq!(vec.len(), 768);
        }
    }

    #[tokio::test]
    #[ignore = "requires model download"]
    async fn gemma_embed_batch_preserves_order() {
        let engine = default_engine();
        let v_a = engine.embed("alpha").await.unwrap();
        let v_b = engine.embed("beta").await.unwrap();
        let batch = engine.embed_batch(&["alpha", "beta"]).await.unwrap();
        assert_eq!(batch[0], v_a);
        assert_eq!(batch[1], v_b);
    }
}