use std::path::PathBuf;
use std::sync::Arc;
use crate::config::EmbeddingConfig;
use crate::engine::EmbeddingEngine;
use crate::error::EmbeddingError;
use crate::model::EmbeddingGemmaModel;
use crate::task::format_query;
#[derive(Debug, Clone)]
pub struct ModelFiles {
pub gguf_path: PathBuf,
pub dense1_path: PathBuf,
pub dense2_path: PathBuf,
pub tokenizer_path: PathBuf,
}
fn model_files_from_env() -> Result<ModelFiles, EmbeddingError> {
let base_dir =
std::env::var("ERIO_MODEL_DIR").unwrap_or_else(|_| env!("ERIO_MODEL_DIR").to_owned());
let base = PathBuf::from(base_dir);
let files = ModelFiles {
gguf_path: base.join("embeddinggemma-300M-Q8_0.gguf"),
dense1_path: base.join("2_Dense/model.safetensors"),
dense2_path: base.join("3_Dense/model.safetensors"),
tokenizer_path: base.join("tokenizer.json"),
};
for path in [
&files.gguf_path,
&files.dense1_path,
&files.dense2_path,
&files.tokenizer_path,
] {
if !path.exists() {
return Err(EmbeddingError::ModelLoad(format!(
"required model file missing: {}",
path.display()
)));
}
}
Ok(files)
}
pub struct GemmaEmbedding {
config: EmbeddingConfig,
model: Arc<EmbeddingGemmaModel>,
}
impl GemmaEmbedding {
pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> {
let model_files = model_files_from_env()?;
Self::from_files(config, &model_files)
}
pub fn from_files(
config: EmbeddingConfig,
model_files: &ModelFiles,
) -> Result<Self, EmbeddingError> {
let model = EmbeddingGemmaModel::load(
&model_files.gguf_path,
&model_files.dense1_path,
&model_files.dense2_path,
&model_files.tokenizer_path,
)?;
Ok(Self {
config,
model: Arc::new(model),
})
}
}
#[async_trait::async_trait]
impl EmbeddingEngine for GemmaEmbedding {
fn name(&self) -> &'static str {
"gemma"
}
fn dimensions(&self) -> usize {
self.config.dimensions
}
async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
if text.is_empty() {
return Err(EmbeddingError::InvalidInput(
"text must not be empty".into(),
));
}
let prompt = format_query(text, self.config.task_type);
self.model.embed(&prompt)
}
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
if texts.iter().any(|t| t.is_empty()) {
return Err(EmbeddingError::InvalidInput(
"text must not be empty".into(),
));
}
texts
.iter()
.map(|text| {
let prompt = format_query(text, self.config.task_type);
self.model.embed(&prompt)
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use std::sync::{Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH};
#[test]
fn model_files_struct_holds_paths() {
let files = ModelFiles {
gguf_path: PathBuf::from("/tmp/model.gguf"),
dense1_path: PathBuf::from("/tmp/2_Dense/model.safetensors"),
dense2_path: PathBuf::from("/tmp/3_Dense/model.safetensors"),
tokenizer_path: PathBuf::from("/tmp/tokenizer.json"),
};
assert!(files.gguf_path.ends_with("model.gguf"));
assert!(files.dense1_path.ends_with("model.safetensors"));
assert!(files.dense2_path.ends_with("model.safetensors"));
assert!(files.tokenizer_path.ends_with("tokenizer.json"));
}
#[test]
fn gemma_new_is_sync_constructor_signature() {
let constructor: fn(EmbeddingConfig) -> Result<GemmaEmbedding, EmbeddingError> =
GemmaEmbedding::new;
let _ = constructor;
}
#[test]
#[allow(unsafe_code)]
fn model_files_from_env_errors_when_required_files_are_missing() {
static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
let _guard = ENV_LOCK
.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned");
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("clock should be after unix epoch")
.as_nanos();
let temp_dir = std::env::temp_dir().join(format!(
"erio-embedding-missing-model-files-{}-{unique}",
std::process::id()
));
fs::create_dir_all(&temp_dir).expect("failed to create temp model dir");
fs::write(temp_dir.join("tokenizer.json"), b"{}")
.expect("failed to create partial model file");
let previous_model_dir = std::env::var("ERIO_MODEL_DIR").ok();
unsafe { std::env::set_var("ERIO_MODEL_DIR", &temp_dir) };
let result = model_files_from_env();
match previous_model_dir {
Some(value) => unsafe { std::env::set_var("ERIO_MODEL_DIR", value) },
None => unsafe { std::env::remove_var("ERIO_MODEL_DIR") },
}
fs::remove_dir_all(&temp_dir).expect("failed to cleanup temp model dir");
match result {
Err(EmbeddingError::ModelLoad(message)) => {
assert!(
message.contains("required model file missing"),
"unexpected error message: {message}"
);
}
other => panic!("expected EmbeddingError::ModelLoad, got {other:?}"),
}
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_returns_name() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
assert_eq!(engine.name(), "gemma");
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_returns_correct_dimensions() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
assert_eq!(engine.dimensions(), 768);
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_embed_returns_correct_dimensions() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let vec = engine.embed("hello world").await.unwrap();
assert_eq!(vec.len(), 768);
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_embed_rejects_empty_input() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let result = engine.embed("").await;
assert!(matches!(
result.unwrap_err(),
EmbeddingError::InvalidInput(_)
));
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_same_input_same_output() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let v1 = engine.embed("test determinism").await.unwrap();
let v2 = engine.embed("test determinism").await.unwrap();
assert_eq!(v1, v2);
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_different_inputs_different_outputs() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let v1 = engine.embed("hello").await.unwrap();
let v2 = engine.embed("world").await.unwrap();
assert_ne!(v1, v2);
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_normalized_unit_length() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let vec = engine.embed("test normalization").await.unwrap();
let magnitude: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(magnitude - 1.0).abs() < 1e-4,
"Expected unit length, got {magnitude}"
);
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_embed_batch_correct_count() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let results = engine.embed_batch(&["a", "b", "c"]).await.unwrap();
assert_eq!(results.len(), 3);
for vec in &results {
assert_eq!(vec.len(), 768);
}
}
#[tokio::test]
#[ignore = "requires model download"]
async fn gemma_embed_batch_preserves_order() {
let engine = GemmaEmbedding::new(EmbeddingConfig::default()).unwrap();
let v_a = engine.embed("alpha").await.unwrap();
let v_b = engine.embed("beta").await.unwrap();
let batch = engine.embed_batch(&["alpha", "beta"]).await.unwrap();
assert_eq!(batch[0], v_a);
assert_eq!(batch[1], v_b);
}
}