agentroot_core/llm/
llama.rs1use super::Embedder;
4use crate::error::{AgentRootError, Result};
5use async_trait::async_trait;
6use llama_cpp_2::{
7 context::params::LlamaContextParams,
8 llama_backend::LlamaBackend,
9 llama_batch::LlamaBatch,
10 model::{params::LlamaModelParams, LlamaModel},
11};
12use std::path::Path;
13use std::sync::Mutex;
14
/// File name of the default GGUF embedding model, resolved under
/// `<data_dir>/agentroot/models/` by `LlamaEmbedder::from_default`.
pub const DEFAULT_EMBED_MODEL: &str = "nomic-embed-text-v1.5.Q4_K_M.gguf";
17
18pub struct LlamaEmbedder {
20 #[allow(dead_code)]
21 backend: LlamaBackend,
22 model: LlamaModel,
23 context: Mutex<LlamaEmbedderContext>,
24 model_name: String,
25 dimensions: usize,
26}
27
/// Wrapper owning the llama.cpp inference context.
///
/// The `'static` lifetime here is a lie told to the compiler: the inner
/// context really borrows the `LlamaModel` it was created from, and the
/// lifetime is erased via an unsafe `transmute` in `LlamaEmbedder::new`.
/// This value must never outlive that model.
struct LlamaEmbedderContext {
    // Lifetime-erased; soundness depends on LlamaEmbedder's drop order.
    ctx: llama_cpp_2::context::LlamaContext<'static>,
}
31
// SAFETY: `LlamaContext` wraps raw llama.cpp state and is not Send/Sync by
// default. In this module it is only ever reached through the `Mutex` in
// `LlamaEmbedder`, so access is serialized to one thread at a time.
// NOTE(review): this assumes a llama.cpp context may be *used* from
// different threads as long as uses never overlap — confirm against
// llama.cpp's threading guarantees before relying on this across threads.
unsafe impl Send for LlamaEmbedderContext {}
unsafe impl Sync for LlamaEmbedderContext {}
34
impl LlamaEmbedder {
    /// Loads a GGUF embedding model from `model_path` and builds one
    /// reusable llama.cpp context configured for embedding extraction.
    ///
    /// # Errors
    /// Returns `AgentRootError::Llm` if backend initialization, model
    /// loading, or context creation fails.
    pub fn new(model_path: impl AsRef<Path>) -> Result<Self> {
        let model_path = model_path.as_ref();
        // Reported later by `model_name()`: file name minus extension,
        // falling back to "unknown" for non-UTF-8 or extension-less paths.
        let model_name = model_path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("unknown")
            .to_string();

        let mut backend = LlamaBackend::init()
            .map_err(|e| AgentRootError::Llm(format!("Failed to init backend: {}", e)))?;
        // Suppress llama.cpp's own log output.
        backend.void_logs();

        let model_params = LlamaModelParams::default();
        let model = LlamaModel::load_from_file(&backend, model_path, &model_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to load model: {}", e)))?;

        // Embedding width comes from the model itself.
        let dimensions = model.n_embd() as usize;

        // One shared cap (2048) for context window, batch, and micro-batch:
        // inputs that tokenize past this will fail at encode time rather
        // than being truncated here.
        let ctx_size = std::num::NonZeroU32::new(2048).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_embeddings(true)
            .with_n_ctx(Some(ctx_size))
            .with_n_batch(ctx_size.get())
            .with_n_ubatch(ctx_size.get());

        let ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| AgentRootError::Llm(format!("Failed to create context: {}", e)))?;

        // SAFETY: `ctx` genuinely borrows `model`, but both move into the
        // same struct below, so the borrow cannot be expressed in the type.
        // Erasing the lifetime to 'static is sound only while the context
        // is guaranteed to be dropped before the model it references —
        // that guarantee must come from `LlamaEmbedder`'s field declaration
        // order (struct fields drop in declaration order). Verify the
        // struct declares `context` before `model`.
        let ctx: llama_cpp_2::context::LlamaContext<'static> = unsafe { std::mem::transmute(ctx) };

        Ok(Self {
            backend,
            model,
            context: Mutex::new(LlamaEmbedderContext { ctx }),
            model_name,
            dimensions,
        })
    }

    /// Loads the default embedding model from
    /// `<data_dir>/agentroot/models/<DEFAULT_EMBED_MODEL>` (falling back
    /// to the current directory if no platform data dir exists).
    ///
    /// # Errors
    /// Returns `AgentRootError::ModelNotFound` if the file is absent,
    /// otherwise whatever `Self::new` returns.
    pub fn from_default() -> Result<Self> {
        let model_dir = dirs::data_dir()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join("agentroot")
            .join("models");

        let model_path = model_dir.join(DEFAULT_EMBED_MODEL);

        if !model_path.exists() {
            return Err(AgentRootError::ModelNotFound(format!(
                "Model not found at {}. Download an embedding model (e.g., nomic-embed-text) to this location.",
                model_path.display()
            )));
        }

        Self::new(model_path)
    }

    /// Tokenizes `text`, runs a single encode pass, and returns the
    /// L2-normalized sequence embedding (length `self.dimensions`).
    ///
    /// Blocking: holds the context mutex and calls into llama.cpp FFI.
    ///
    /// # Errors
    /// Returns `AgentRootError::Llm` on lock poisoning, tokenization,
    /// batching, encoding, or embedding-extraction failures.
    fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
        // A poisoned mutex is surfaced as an Llm error instead of panicking.
        let mut ctx_guard = self
            .context
            .lock()
            .map_err(|e| AgentRootError::Llm(format!("Lock error: {}", e)))?;

        let tokens = self
            .model
            .str_to_token(text, llama_cpp_2::model::AddBos::Always)
            .map_err(|e| AgentRootError::Llm(format!("Tokenization error: {}", e)))?;

        // Empty input: return an all-zero vector of the right width rather
        // than encoding an empty batch.
        if tokens.is_empty() {
            return Ok(vec![0.0; self.dimensions]);
        }

        // Batch sized exactly to the token count, single sequence (id 0).
        // NOTE(review): token counts above the 2048 cap set in `new` will
        // presumably fail inside encode() — confirm, or truncate upstream.
        let mut batch = LlamaBatch::new(tokens.len(), 1);

        for (i, token) in tokens.iter().enumerate() {
            // Positions are 0-based; output ("logits") is requested only
            // for the final token of the sequence.
            batch
                .add(*token, i as i32, &[0], i == tokens.len() - 1)
                .map_err(|e| AgentRootError::Llm(format!("Batch error: {}", e)))?;
        }

        // NOTE(review): the context is reused across calls with positions
        // restarting at 0 and no explicit cache clear — confirm llama.cpp
        // resets per-sequence state between encode() calls in embedding mode.
        ctx_guard
            .ctx
            .encode(&mut batch)
            .map_err(|e| AgentRootError::Llm(format!("Encode error: {}", e)))?;

        // Pooled embedding for sequence id 0 (the only sequence batched).
        let embeddings = ctx_guard
            .ctx
            .embeddings_seq_ith(0)
            .map_err(|e| AgentRootError::Llm(format!("Embeddings error: {}", e)))?;

        // L2-normalize; guard against a zero vector to avoid dividing by 0.
        let norm: f32 = embeddings.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalized: Vec<f32> = if norm > 0.0 {
            embeddings.iter().map(|x| x / norm).collect()
        } else {
            embeddings.to_vec()
        };

        Ok(normalized)
    }
}
150
151#[async_trait]
152impl Embedder for LlamaEmbedder {
153 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
154 self.embed_sync(text)
156 }
157
158 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
159 let mut results = Vec::with_capacity(texts.len());
161 for text in texts {
162 results.push(self.embed_sync(text)?);
163 }
164 Ok(results)
165 }
166
167 fn dimensions(&self) -> usize {
168 self.dimensions
169 }
170
171 fn model_name(&self) -> &str {
172 &self.model_name
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// The default model path must end in the expected GGUF file name and
    /// live directly under the `agentroot/models` directory.
    ///
    /// The original test only printed the path and could never fail; it
    /// now pins the invariants it was eyeballing.
    #[test]
    fn test_default_model_path() {
        let model_dir = dirs::data_dir()
            .unwrap_or_else(|| std::path::PathBuf::from("."))
            .join("agentroot")
            .join("models");
        let model_path = model_dir.join(DEFAULT_EMBED_MODEL);

        // Final component is the model file name.
        assert!(model_path.ends_with(DEFAULT_EMBED_MODEL));
        // The file sits directly inside the models directory.
        assert_eq!(model_path.parent(), Some(model_dir.as_path()));
        // The default model is a GGUF file.
        assert!(DEFAULT_EMBED_MODEL.ends_with(".gguf"));

        println!("Expected model path: {}", model_path.display());
    }
}