mod config;
#[cfg(feature = "embed")]
mod download;
mod onnx;
pub use config::{EmbeddingModelConfig, EmbeddingOptions};
pub use onnx::OnnxEmbeddingModel;
use grafeo_common::utils::error::Result;
/// A text-embedding backend: turns strings into `f32` vectors.
///
/// The `Send + Sync` supertrait bounds allow an implementation to be shared
/// across threads, for example behind an `Arc<dyn EmbeddingModel>` as the
/// tests in this file do.
pub trait EmbeddingModel: Send + Sync {
/// Embeds the strings in `texts` and returns the resulting vectors.
///
/// NOTE(review): the tests in this file expect one vector per input text,
/// each of length [`dimensions`](Self::dimensions), and an empty result for
/// an empty `texts` slice. Presumably every implementation should honour
/// that; confirm against the production implementations.
///
/// # Errors
///
/// Returns an error when the implementation cannot produce embeddings; the
/// concrete failure modes are implementation-specific.
fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
/// Returns the dimensionality reported by this model.
///
/// The tests in this file assert that embedded vectors have this length.
fn dimensions(&self) -> usize;
/// Returns a name for this model (the test mock returns `"mock-embedding"`).
fn name(&self) -> &str;
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
/// Test double for [`EmbeddingModel`] that needs no model files.
///
/// Its vectors are derived purely from the input text, so results are
/// reproducible from one call to the next.
struct MockEmbeddingModel {
    /// Number of components in every vector the mock emits.
    dims: usize,
}

impl MockEmbeddingModel {
    /// Creates a mock that produces vectors with `dims` components.
    fn new(dims: usize) -> MockEmbeddingModel {
        MockEmbeddingModel { dims }
    }
}
impl EmbeddingModel for MockEmbeddingModel {
    /// Produces one pseudo-embedding per input text.
    ///
    /// Each text is reduced to the sum of its UTF-8 byte values; component `i`
    /// of the vector is then `sin((sum + i) * 0.01)`. The same text therefore
    /// always maps to the same vector. This implementation never returns an
    /// error.
    fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            // `u8 -> f32` is lossless, so `f32::from` matches an `as` cast.
            let byte_sum: f32 = text.bytes().map(f32::from).sum();
            let mut vector = Vec::with_capacity(self.dims);
            for idx in 0..self.dims {
                vector.push(((byte_sum + idx as f32) * 0.01).sin());
            }
            embeddings.push(vector);
        }
        Ok(embeddings)
    }

    /// Returns the vector length this mock was constructed with.
    fn dimensions(&self) -> usize {
        self.dims
    }

    /// Returns the fixed name `"mock-embedding"`.
    fn name(&self) -> &str {
        "mock-embedding"
    }
}
/// A single input yields exactly one vector, and that vector's length equals
/// the dimensionality the model reports.
#[test]
fn trait_returns_correct_dimensions() {
    const DIMS: usize = 384;
    let model = MockEmbeddingModel::new(DIMS);
    let embeddings = model.embed(&["hello"]).unwrap();
    assert_eq!(embeddings.len(), 1);
    assert_eq!(embeddings[0].len(), DIMS);
    assert_eq!(model.dimensions(), DIMS);
}
/// Embedding a batch returns one vector per input text, each with the
/// configured number of components.
#[test]
fn trait_batch_returns_one_vec_per_input() {
    let model = MockEmbeddingModel::new(128);
    let inputs = ["one", "two", "three", "four", "five"];
    let embeddings = model.embed(&inputs).unwrap();
    assert_eq!(embeddings.len(), 5);
    embeddings
        .iter()
        .for_each(|embedding| assert_eq!(embedding.len(), 128));
}
/// An empty batch is not an error: it yields an empty list of vectors.
#[test]
fn trait_empty_input_returns_empty() {
    let model = MockEmbeddingModel::new(64);
    let no_texts: [&str; 0] = [];
    let embeddings = model.embed(&no_texts).unwrap();
    assert!(embeddings.is_empty());
}
/// Two distinct texts ("hello" and "world") must not map to the same vector.
#[test]
fn trait_different_texts_produce_different_vectors() {
    let model = MockEmbeddingModel::new(32);
    let embeddings = model.embed(&["hello", "world"]).unwrap();
    let hello = &embeddings[0];
    let world = &embeddings[1];
    assert_ne!(hello, world);
}
/// Embedding the same text twice gives identical vectors.
#[test]
fn trait_same_text_produces_same_vector() {
    let model = MockEmbeddingModel::new(32);
    let embed_hello = || model.embed(&["hello"]).unwrap();
    let first_run = embed_hello();
    let second_run = embed_hello();
    assert_eq!(first_run[0], second_run[0]);
}
/// A model behind `Arc<dyn EmbeddingModel>` can be handed to another thread
/// and used there while the original handle stays usable.
#[test]
fn trait_is_send_sync() {
    let shared: Arc<dyn EmbeddingModel> = Arc::new(MockEmbeddingModel::new(64));
    let for_worker = Arc::clone(&shared);
    let worker = std::thread::spawn(move || {
        for_worker.embed(&["from another thread"]).unwrap()
    });
    let embeddings = worker.join().unwrap();
    assert_eq!(embeddings.len(), 1);
    assert_eq!(embeddings[0].len(), shared.dimensions());
}
/// Registers the mock on an in-memory database under the name "test-model"
/// and checks that embedding two texts through the database yields two
/// 128-component vectors.
#[test]
fn database_register_and_embed() {
    use crate::{Config, GrafeoDB};

    let db = GrafeoDB::with_config(Config::in_memory()).unwrap();
    // The explicit annotation performs the unsizing coercion to a trait object.
    let model: Arc<dyn EmbeddingModel> = Arc::new(MockEmbeddingModel::new(128));
    db.register_embedding_model("test-model", model);

    let embeddings = db.embed_text("test-model", &["hello", "world"]).unwrap();
    assert_eq!(embeddings.len(), 2);
    for embedding in &embeddings {
        assert_eq!(embedding.len(), 128);
    }
}
/// Embedding through a model name that was never registered must fail, and
/// the error text must say the model is "not registered".
#[test]
fn database_embed_model_not_found() {
    use crate::{Config, GrafeoDB};

    const EXPECTED_FRAGMENT: &str = "not registered";

    let db = GrafeoDB::with_config(Config::in_memory()).unwrap();
    let lookup = db.embed_text("nonexistent", &["hello"]);
    assert!(lookup.is_err());

    let err = lookup.unwrap_err().to_string();
    assert!(
        err.contains(EXPECTED_FRAGMENT),
        "Error should mention model not registered: {err}"
    );
}
}