agentroot_core/llm/llama.rs

//! LLaMA-based embedder using llama-cpp-2

use super::Embedder;
use crate::error::{AgentRootError, Result};
use async_trait::async_trait;
use llama_cpp_2::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, LlamaModel},
};
use std::path::Path;
use std::sync::Mutex;

/// Default embedding model (nomic-embed-text or similar)
pub const DEFAULT_EMBED_MODEL: &str = "nomic-embed-text-v1.5.Q4_K_M.gguf";

/// LLaMA-based embedder
pub struct LlamaEmbedder {
    // Fields drop in declaration order: the context borrows from `model` (see
    // the transmute in `new`), so it must be declared before the model, and
    // the model before the backend.
    context: Mutex<LlamaEmbedderContext>,
    model: LlamaModel,
    #[allow(dead_code)]
    backend: LlamaBackend,
    model_name: String,
    dimensions: usize,
}

struct LlamaEmbedderContext {
    ctx: llama_cpp_2::context::LlamaContext<'static>,
}

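// SAFETY: `LlamaContext` holds raw pointers and is not automatically Send/Sync.
// All access goes through the `Mutex` in `LlamaEmbedder`, so calls into the
// context are serialized across threads.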
unsafe impl Send for LlamaEmbedderContext {}
unsafe impl Sync for LlamaEmbedderContext {}

impl LlamaEmbedder {
    /// Create a new LlamaEmbedder from a GGUF model file
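    ///
    /// # Example
    ///
    /// Illustrative sketch only; the model path below is an assumption, so
    /// point it at whatever GGUF embedding model you have on disk.
    ///
    /// ```ignore
    /// let embedder = LlamaEmbedder::new("models/nomic-embed-text-v1.5.Q4_K_M.gguf")?;
    /// let vector = embedder.embed("hello world").await?; // via the Embedder trait
    /// assert_eq!(vector.len(), embedder.dimensions());
    /// ```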
    pub fn new(model_path: impl AsRef<Path>) -> Result<Self> {
        let model_path = model_path.as_ref();
        let model_name = model_path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("unknown")
            .to_string();

        // Initialize backend and suppress verbose output
        let mut backend = LlamaBackend::init()
            .map_err(|e| AgentRootError::Llm(format!("Failed to init backend: {}", e)))?;
        backend.void_logs();

        // Load model
        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to load model: {}", e)))?;

        let dimensions = model.n_embd() as usize;

        // Create context with embeddings enabled
        // n_batch and n_ubatch must be >= n_tokens for encoder models
        let ctx_size = std::num::NonZeroU32::new(2048).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_embeddings(true)
            .with_n_ctx(Some(ctx_size))
            .with_n_batch(ctx_size.get())
            .with_n_ubatch(ctx_size.get());

        let ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to create context: {}", e)))?;

        // SAFETY: the context borrows from `model`, which lives in the same
        // struct. Field declaration order (`context` before `model` before
        // `backend`) ensures the context is dropped before the model it
        // borrows from, so extending the borrow to 'static is sound here.
        let ctx: llama_cpp_2::context::LlamaContext<'static> = unsafe { std::mem::transmute(ctx) };

        Ok(Self {
            backend,
            model,
            context: Mutex::new(LlamaEmbedderContext { ctx }),
            model_name,
            dimensions,
        })
    }

    /// Create from default model location
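    /// (the platform data directory, e.g. `~/.local/share/agentroot/models/` on Linux)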
    pub fn from_default() -> Result<Self> {
        let model_dir = dirs::data_dir()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join("agentroot")
            .join("models");

        let model_path = model_dir.join(DEFAULT_EMBED_MODEL);

        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(format!(
                "Model not found at {}. Download an embedding model (e.g., nomic-embed-text) to this location.",
                model_path.display()
            )));
        }

        Self::new(model_path)
    }

    fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
        let mut ctx_guard = self
            .context
            .lock()
            .map_err(|e| AgentRootError::Llm(format!("Lock error: {}", e)))?;

        // Tokenize
        let tokens = self
            .model
            .str_to_token(text, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| AgentRootError::Llm(format!("Tokenization error: {}", e)))?;

        if tokens.is_empty() {
            return Ok(vec![0.0; self.dimensions]);
        }

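        // The context was created with n_batch = n_ubatch = 2048, so inputs
        // longer than 2048 tokens will make encode() below return an error.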
        // Build a single-sequence batch; logits are requested only for the last token
        let mut batch = LlamaBatch::new(tokens.len(), 1);

        for (i, token) in tokens.iter().enumerate() {
            batch
                .add(*token, i as i32, &[0], i == tokens.len() - 1)
                .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;
        }

        // Encode (for embeddings)
        ctx_guard
            .ctx
            .encode(&mut batch)
            .map_err(|e| AgentRootError::Llm(format!("Encode error: {}", e)))?;

        // Get embeddings (sequence-level pooled embedding)
        let embeddings = ctx_guard
            .ctx
            .embeddings_seq_ith(0)
            .map_err(|e| AgentRootError::Llm(format!("Embeddings error: {}", e)))?;

        // L2-normalize the embedding (guarding against an all-zero vector)
        let norm: f32 = embeddings.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalized: Vec<f32> = if norm > 0.0 {
            embeddings.iter().map(|x| x / norm).collect()
        } else {
            embeddings.to_vec()
        };

        Ok(normalized)
    }
}

#[async_trait]
impl Embedder for LlamaEmbedder {
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Run synchronously since the llama-cpp context is not async;
        // this blocks the current task for the duration of the call
        self.embed_sync(text)
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Process sequentially; the single context is serialized behind the Mutex
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed_sync(text)?);
        }
        Ok(results)
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model_name
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_model_path() {
        let model_dir = dirs::data_dir()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join("agentroot")
            .join("models");
        let model_path = model_dir.join(DEFAULT_EMBED_MODEL);
        println!("Expected model path: {}", model_path.display());
    }
}