use super::Embedder;
use crate::error::{AgentRootError, Result};
use async_trait::async_trait;
use llama_cpp_2::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, LlamaModel},
};
use std::path::Path;
use std::sync::Mutex;
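
/// File name of the default GGUF embedding model, expected under the
/// platform data directory (see [`LlamaEmbedder::from_default`]).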
pub const DEFAULT_EMBED_MODEL: &str = "nomic-embed-text-v1.5.Q4_K_M.gguf";
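
/// Local embedding backend built on llama.cpp via the `llama_cpp_2` bindings.
///
/// A minimal usage sketch (the surrounding module layout is assumed, so the
/// example is not compiled as a doc test):
///
/// ```ignore
/// let embedder = LlamaEmbedder::from_default()?;
/// let vector = embedder.embed("hello world").await?;
/// assert_eq!(vector.len(), embedder.dimensions());
/// ```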
pub struct LlamaEmbedder {
    // Field order matters: struct fields drop in declaration order, and the
    // context's real lifetime is tied to `model` (see the transmute in
    // `new`), so it must be declared first to drop before the model, which
    // in turn drops before the backend.
    context: Mutex<LlamaEmbedderContext>,
    model: LlamaModel,
    #[allow(dead_code)]
    backend: LlamaBackend,
    model_name: String,
    dimensions: usize,
}
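
/// Owns the llama.cpp context. Its true lifetime is tied to the `LlamaModel`
/// stored alongside it in [`LlamaEmbedder`]; the `'static` below is an
/// erasure performed in [`LlamaEmbedder::new`].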
struct LlamaEmbedderContext {
    ctx: llama_cpp_2::context::LlamaContext<'static>,
}
// SAFETY: the raw llama.cpp context is not thread-safe on its own, but it is
// only ever reached through the `Mutex` in `LlamaEmbedder`, which serializes
// all access across threads.
unsafe impl Send for LlamaEmbedderContext {}
unsafe impl Sync for LlamaEmbedderContext {}

impl LlamaEmbedder {
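    /// Loads a GGUF model from `model_path` and builds a reusable llama.cpp
    /// context with embeddings enabled.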
    pub fn new(model_path: impl AsRef<Path>) -> Result<Self> {
        let model_path = model_path.as_ref();
        let model_name = model_path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("unknown")
            .to_string();
        let mut backend = LlamaBackend::init()
            .map_err(|e| AgentRootError::Llm(format!("Failed to init backend: {}", e)))?;
        // Silence llama.cpp's logging output.
        backend.void_logs();
        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to load model: {}", e)))?;
        let dimensions = model.n_embd() as usize;
        // 2048-token window; inputs that tokenize longer than this will fail
        // to encode. Batch sizes match the context so a full input fits in a
        // single batch.
        let ctx_size = std::num::NonZeroU32::new(2048).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_embeddings(true)
            .with_n_ctx(Some(ctx_size))
            .with_n_batch(ctx_size.get())
            .with_n_ubatch(ctx_size.get());
        let ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to create context: {}", e)))?;
        // SAFETY: the context borrows `model`, which lives in the same struct
        // it is stored in. The lifetime is erased to 'static so both can be
        // owned together; field order on `LlamaEmbedder` guarantees the
        // context drops before the model it points into.
        let ctx: llama_cpp_2::context::LlamaContext<'static> = unsafe { std::mem::transmute(ctx) };
        Ok(Self {
            context: Mutex::new(LlamaEmbedderContext { ctx }),
            model,
            backend,
            model_name,
            dimensions,
        })
    }
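
    /// Loads [`DEFAULT_EMBED_MODEL`] from the platform data directory,
    /// e.g. `~/.local/share/agentroot/models` on Linux.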
    pub fn from_default() -> Result<Self> {
        let model_dir = dirs::data_dir()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join("agentroot")
            .join("models");
        let model_path = model_dir.join(DEFAULT_EMBED_MODEL);
        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(format!(
                "Model not found at {}. Download an embedding model (e.g., nomic-embed-text) to this location.",
                model_path.display()
            )));
        }
        Self::new(model_path)
    }
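
    /// Tokenizes `text`, encodes it, and returns the L2-normalized pooled
    /// sequence embedding. CPU-bound and blocking: the context mutex is held
    /// for the whole call.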
    fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
        let mut ctx_guard = self
            .context
            .lock()
            .map_err(|e| AgentRootError::Llm(format!("Lock error: {}", e)))?;
        let tokens = self
            .model
            .str_to_token(text, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| AgentRootError::Llm(format!("Tokenization error: {}", e)))?;
        // Nothing to encode; return a zero vector of the right width.
        if tokens.is_empty() {
            return Ok(vec![0.0; self.dimensions]);
        }
        // One sequence (id 0); only the final token requests output, and the
        // pooled sequence embedding is read back after encoding.
        let mut batch = LlamaBatch::new(tokens.len(), 1);
        for (i, token) in tokens.iter().enumerate() {
            batch
                .add(*token, i as i32, &[0], i == tokens.len() - 1)
                .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;
        }
        ctx_guard
            .ctx
            .encode(&mut batch)
            .map_err(|e| AgentRootError::Llm(format!("Encode error: {}", e)))?;
        let embeddings = ctx_guard
            .ctx
            .embeddings_seq_ith(0)
            .map_err(|e| AgentRootError::Llm(format!("Embeddings error: {}", e)))?;
        // L2-normalize so cosine similarity between stored vectors reduces to
        // a plain dot product; guard against a zero vector.
        let norm: f32 = embeddings.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalized: Vec<f32> = if norm > 0.0 {
            embeddings.iter().map(|x| x / norm).collect()
        } else {
            embeddings.to_vec()
        };
        Ok(normalized)
    }
}

#[async_trait]
impl Embedder for LlamaEmbedder {
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // embed_sync blocks the executor thread; on a busy multi-threaded
        // runtime, callers may want to wrap calls in spawn_blocking.
        self.embed_sync(text)
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Sequential by design: a single context is shared behind the mutex.
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed_sync(text)?);
        }
        Ok(results)
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_model_path() {
        let model_dir = dirs::data_dir()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join("agentroot")
            .join("models");
        let model_path = model_dir.join(DEFAULT_EMBED_MODEL);
        assert!(model_path.ends_with(DEFAULT_EMBED_MODEL));
        println!("Expected model path: {}", model_path.display());
    }
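
    // Hedged end-to-end check: only meaningful once the default model file has
    // been downloaded, so it is `#[ignore]`d and returns early if the model is
    // absent. Run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_embed_if_model_present() {
        let embedder = match LlamaEmbedder::from_default() {
            Ok(e) => e,
            Err(_) => return, // model not installed; nothing to verify
        };
        let v = embedder.embed_sync("hello world").unwrap();
        assert_eq!(v.len(), embedder.dimensions);
        // Non-empty inputs are L2-normalized, so the norm should be ~1.
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3);
    }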
}