#![allow(missing_docs)]
pub mod loader;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use llama_gguf::{
backend::cpu::CpuBackend,
gguf::GgufFile,
model::{load_llama_model, EmbeddingConfig, EmbeddingExtractor, LlamaModel, PoolingStrategy},
tokenizer::Tokenizer,
HfClient,
};
use parking_lot::Mutex;
use crate::embedding::{EmbeddingProvider, EmbeddingVector};
pub use self::loader::GgufModelLoader;
pub use self::loader::{MODEL_DISPLAY_NAME, MODEL_SIZE_MB};
/// Output width of the embeddings produced by [`GgufEmbeddingProvider`].
///
/// The provider keeps only the first [`EmbeddingDimension::size`] components of the model's
/// embedding and then re-normalizes them. In serde form the variants are written in
/// snake_case (`dim128`, `dim256`, `dim512`, `dim768`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EmbeddingDimension {
    /// 128 components.
    Dim128,
    /// 256 components (the default).
    Dim256,
    /// 512 components.
    Dim512,
    /// 768 components.
    Dim768,
}
impl EmbeddingDimension {
pub fn size(&self) -> usize {
match self {
Self::Dim128 => 128,
Self::Dim256 => 256,
Self::Dim512 => 512,
Self::Dim768 => 768,
}
}
}
impl Default for EmbeddingDimension {
fn default() -> Self {
Self::Dim256
}
}
/// Everything needed to embed text once the GGUF model has been loaded into memory.
struct LoadedModel {
    /// Model weights loaded from the GGUF file via `load_llama_model`.
    model: LlamaModel,
    /// Tokenizer built from the same GGUF file.
    tokenizer: Tokenizer,
    /// Produces one vector per input text; configured at load time for mean pooling,
    /// normalization and a 512-token maximum length.
    extractor: EmbeddingExtractor,
    /// Moment this model finished loading.
    loaded_at: Instant,
}
/// Embedding provider that runs an EmbeddingGemma GGUF model on the CPU.
///
/// The model is fetched into `model_dir` if necessary and loaded lazily on first use.
/// It can be released again with [`GgufEmbeddingProvider::maybe_unload`].
pub struct GgufEmbeddingProvider {
    /// Directory the GGUF model is downloaded to and loaded from.
    model_dir: PathBuf,
    /// Number of leading components kept from each embedding before re-normalization.
    dimension: EmbeddingDimension,
    /// The loaded model; `None` until first use and after an unload.
    inner: Mutex<Option<LoadedModel>>,
    /// Time-to-live consulted by `maybe_unload` when deciding whether to drop the model.
    model_ttl: Duration,
    /// Time of the most recent `embed` call (initialized to the construction time).
    last_used: Mutex<Instant>,
}
impl std::fmt::Debug for GgufEmbeddingProvider {
    /// Prints the provider's configuration only.
    ///
    /// `inner` and `last_used` sit behind mutexes and are deliberately left out, so
    /// formatting never has to take a lock. `finish_non_exhaustive` renders a trailing `..`
    /// so readers can see that fields were omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GgufEmbeddingProvider")
            .field("model_dir", &self.model_dir)
            .field("dimension", &self.dimension)
            .field("model_ttl", &self.model_ttl)
            .finish_non_exhaustive()
    }
}
impl GgufEmbeddingProvider {
pub fn new(model_dir: PathBuf, dimension: EmbeddingDimension, model_ttl_secs: u64) -> Self {
Self {
model_dir,
dimension,
inner: Mutex::new(None),
model_ttl: Duration::from_secs(model_ttl_secs),
last_used: Mutex::new(Instant::now()),
}
}
pub fn with_defaults(model_dir: PathBuf) -> Self {
Self::new(model_dir, EmbeddingDimension::default(), 300)
}
fn ensure_loaded(&self) -> Result<()> {
{
let inner = self.inner.lock();
if inner.is_some() {
return Ok(());
}
}
let gguf_path = GgufModelLoader::ensure_model(&self.model_dir)
.context("Failed to download EmbeddingGemma GGUF model")?;
let gguf = GgufFile::open(&gguf_path)
.with_context(|| format!("Failed to open GGUF file: {}", gguf_path.display()))?;
let model = load_llama_model(&gguf_path)
.with_context(|| format!("Failed to load model from: {}", gguf_path.display()))?;
let tokenizer =
Tokenizer::from_gguf(&gguf).context("Failed to load tokenizer from GGUF")?;
let embed_config = EmbeddingConfig {
layer: -1, pooling: PoolingStrategy::Mean,
normalize: true,
max_length: 512,
..EmbeddingConfig::default()
};
let extractor = EmbeddingExtractor::new(embed_config, model.config());
let mut inner = self.inner.lock();
*inner = Some(LoadedModel {
model,
tokenizer,
extractor,
loaded_at: Instant::now(),
});
tracing::info!(
dir = %self.model_dir.display(),
dim = self.dimension.size(),
"GGUF EmbeddingGemma model loaded"
);
Ok(())
}
fn encode(&self, text: &str) -> Result<Vec<f32>> {
let mut inner = self.inner.lock();
let loaded = inner
.as_mut()
.ok_or_else(|| anyhow::anyhow!("Model not loaded"))?;
let backend = CpuBackend::new();
let mut ctx = loaded.model.create_context(Arc::new(backend));
let embedding = loaded
.extractor
.embed_text(&loaded.model, &loaded.tokenizer, &mut ctx, text)
.map_err(|e| anyhow::anyhow!("Embedding extraction failed: {}", e))?;
let dim = self.dimension.size();
let truncated = if embedding.len() > dim {
embedding[..dim].to_vec()
} else {
embedding
};
let norm: f32 = truncated.iter().map(|x| x * x).sum::<f32>().sqrt();
let result = if norm > 1e-10 {
truncated.iter().map(|x| x / norm).collect()
} else {
truncated
};
Ok(result)
}
pub fn maybe_unload(&self) {
let mut inner = self.inner.lock();
if let Some(ref loaded) = *inner {
if loaded.loaded_at.elapsed() > self.model_ttl {
*inner = None;
tracing::debug!("GGUF embedding model unloaded (TTL expired)");
}
}
}
pub fn dimension(&self) -> usize {
self.dimension.size()
}
pub fn model_dir(&self) -> &PathBuf {
&self.model_dir
}
}
#[async_trait::async_trait]
impl EmbeddingProvider for GgufEmbeddingProvider {
    /// Embeds `text` as a dense `f32` vector, loading the model first if necessary.
    ///
    /// Everything in this body runs synchronously (there is no `.await`), including any
    /// first-use model download and the inference itself.
    /// NOTE(review): consider `tokio::task::spawn_blocking` if this stalls the runtime.
    async fn embed(&self, text: &str) -> Result<EmbeddingVector> {
        self.ensure_loaded()?;
        *self.last_used.lock() = Instant::now();
        self.encode(text).map(EmbeddingVector::DenseF32)
    }

    /// Identifier reported for this provider.
    fn name(&self) -> &str {
        "gguf-embeddinggemma-300m"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model cache directory used by the `#[ignore]`d integration tests below.
    fn default_model_dir() -> PathBuf {
        dirs::home_dir()
            .expect("home directory is required to locate the model cache")
            .join(".oxios")
            .join("models")
            .join("embeddinggemma-300m")
    }

    /// Extracts the dense vector, failing the test for any other variant.
    fn expect_dense(vec: EmbeddingVector) -> Vec<f32> {
        match vec {
            EmbeddingVector::DenseF32(v) => v,
            _ => panic!("Expected DenseF32"),
        }
    }

    #[test]
    fn test_embedding_dimension_sizes() {
        assert_eq!(EmbeddingDimension::Dim128.size(), 128);
        assert_eq!(EmbeddingDimension::Dim256.size(), 256);
        assert_eq!(EmbeddingDimension::Dim512.size(), 512);
        assert_eq!(EmbeddingDimension::Dim768.size(), 768);
    }

    #[test]
    fn test_default_dimension() {
        assert_eq!(EmbeddingDimension::default().size(), 256);
    }

    #[test]
    fn test_provider_creation() {
        let provider =
            GgufEmbeddingProvider::with_defaults(PathBuf::from("/tmp/test-models/embedding"));
        assert_eq!(provider.dimension(), 256);
        assert_eq!(provider.name(), "gguf-embeddinggemma-300m");
    }

    #[test]
    fn test_provider_debug() {
        let provider = GgufEmbeddingProvider::with_defaults(PathBuf::from("/tmp/test"));
        let debug_str = format!("{:?}", provider);
        assert!(debug_str.contains("GgufEmbeddingProvider"));
        assert!(debug_str.contains("Dim256"));
    }

    #[test]
    fn test_maybe_unload_noop_when_not_loaded() {
        let provider = GgufEmbeddingProvider::with_defaults(PathBuf::from("/tmp/test"));
        provider.maybe_unload();
        assert!(provider.inner.lock().is_none(), "No model should have been loaded");
    }

    #[tokio::test]
    #[ignore = "requires model download (~329MB)"]
    async fn test_embed_produces_dense_vector() {
        let provider =
            GgufEmbeddingProvider::new(default_model_dir(), EmbeddingDimension::Dim256, 300);
        let v = expect_dense(provider.embed("Rust programming language").await.unwrap());
        assert_eq!(v.len(), 256, "Should produce 256-dim vector");
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01, "Should be L2 normalized");
    }

    #[tokio::test]
    #[ignore = "requires model download (~329MB)"]
    async fn test_embed_korean() {
        let provider =
            GgufEmbeddingProvider::new(default_model_dir(), EmbeddingDimension::Dim256, 300);
        // Non-ASCII (Korean) input must still yield a full, non-degenerate vector.
        let v = expect_dense(provider.embed("한국어 임베딩 테스트").await.unwrap());
        assert_eq!(v.len(), 256);
        assert!(v.iter().any(|&x| x != 0.0), "Should not be all zeros");
    }
}