1use std::path::PathBuf;
7use std::sync::Mutex;
8
9use anyhow::{anyhow, Result};
10use llama_cpp::standard_sampler::StandardSampler;
11use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
12use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
13use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
14use tokio::runtime::Runtime;
15use tracing::{debug, warn};
16
/// Chat-format extraction prompt using Phi-3.5-style `<|system|>` /
/// `<|user|>` / `<|assistant|>` role markers.
///
/// Instructs the model to emit zero or more `MEMORY_START` / `MEMORY_END`
/// blocks (or the literal `MEMORY_NONE`), which `parse_output` then parses.
/// The literal `{text}` placeholder is substituted with the input text at
/// inference time — keep it exactly as written.
const EXTRACTION_PROMPT: &str = r#"<|system|>
You are a memory extraction assistant. Your task is to extract structured facts from text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.
<|end|>
<|user|>
Extract memories from this text:

{text}
<|end|>
<|assistant|>
"#;
40
/// Context window size (in tokens) requested for each inference session.
const MAX_CONTEXT_TOKENS: u32 = 4096;

/// Maximum number of tokens the model may generate per completion.
const MAX_OUTPUT_TOKENS: usize = 1024;

/// Input longer than this many bytes (`str::len` is bytes, not chars,
/// despite the name) is truncated before being spliced into the prompt.
const MAX_INPUT_CHARS: usize = 8192;
49
/// Enrichment engine that extracts structured memory cards from text by
/// prompting a local GGUF model (Phi-3.5-mini) through llama.cpp.
pub struct LlmEngine {
    /// Filesystem path to the GGUF model file.
    model_path: PathBuf,
    /// Lazily loaded model; `None` until `init()` succeeds.
    model: Mutex<Option<LlamaModel>>,
    /// Set once `init()` has loaded the model successfully.
    ready: bool,
    /// Version string reported via `EnrichmentEngine::version()`.
    version: String,
}
61
62impl LlmEngine {
63 pub fn new(model_path: PathBuf) -> Self {
65 Self {
66 model_path,
67 model: Mutex::new(None),
68 ready: false,
69 version: "1.0.0".to_string(),
70 }
71 }
72
73 fn load_model(&self) -> Result<LlamaModel> {
75 if !self.model_path.exists() {
76 return Err(anyhow!(
77 "Model file not found: {}",
78 self.model_path.display()
79 ));
80 }
81
82 unsafe {
84 std::env::set_var("GGML_LOG_LEVEL", "ERROR");
85 std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
86 }
87
88 debug!("Loading model from {}", self.model_path.display());
89
90 LlamaModel::load_from_file(&self.model_path, LlamaParams::default())
91 .map_err(|err| anyhow!("Failed to load Phi-3.5 model: {}", err))
92 }
93
94 fn run_inference(&self, text: &str) -> Result<String> {
96 let model_guard = self
97 .model
98 .lock()
99 .map_err(|_| anyhow!("Model lock poisoned"))?;
100
101 let model = model_guard
102 .as_ref()
103 .ok_or_else(|| anyhow!("LLM engine not initialized. Call init() first."))?;
104
105 let truncated_text = if text.len() > MAX_INPUT_CHARS {
107 &text[..MAX_INPUT_CHARS]
108 } else {
109 text
110 };
111
112 let prompt = EXTRACTION_PROMPT.replace("{text}", truncated_text);
114
115 let mut session_params = SessionParams::default();
117 session_params.n_ctx = MAX_CONTEXT_TOKENS;
118 session_params.n_batch = 512.min(MAX_CONTEXT_TOKENS);
119 if session_params.n_ubatch == 0 {
120 session_params.n_ubatch = 512;
121 }
122
123 let mut session = model
124 .create_session(session_params)
125 .map_err(|err| anyhow!("Failed to create LLM session: {}", err))?;
126
127 let mut tokens = model
129 .tokenize_bytes(prompt.as_bytes(), true, true)
130 .map_err(|err| anyhow!("Failed to tokenize prompt: {}", err))?;
131
132 let max_tokens = MAX_CONTEXT_TOKENS as usize;
134 let reserved = MAX_OUTPUT_TOKENS + 64;
135 if tokens.len() >= max_tokens.saturating_sub(reserved) {
136 let target = max_tokens.saturating_sub(reserved).max(1);
137 let tail_start = tokens.len().saturating_sub(target);
138 tokens = tokens.split_off(tail_start);
139 }
140
141 session
143 .advance_context_with_tokens(&tokens)
144 .map_err(|err| anyhow!("Failed to prime LLM context: {}", err))?;
145
146 let handle = session
148 .start_completing_with(StandardSampler::default(), MAX_OUTPUT_TOKENS)
149 .map_err(|err| anyhow!("Failed to start LLM completion: {}", err))?;
150
151 let runtime =
153 Runtime::new().map_err(|err| anyhow!("Failed to create tokio runtime: {}", err))?;
154
155 let generated = runtime.block_on(async { handle.into_string_async().await });
156
157 Ok(generated.trim().to_string())
158 }
159
160 fn parse_output(&self, output: &str, ctx: &EnrichmentContext) -> Vec<MemoryCard> {
162 let mut cards = Vec::new();
163
164 if output.contains("MEMORY_NONE") {
166 return cards;
167 }
168
169 for block in output.split("MEMORY_START") {
171 let block = block.trim();
172 if block.is_empty() || !block.contains("MEMORY_END") {
173 continue;
174 }
175
176 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
177
178 let mut kind = None;
180 let mut entity = None;
181 let mut slot = None;
182 let mut value = None;
183 let mut polarity = Polarity::Neutral;
184
185 for line in block.lines() {
186 let line = line.trim();
187 if let Some(rest) = line.strip_prefix("kind:") {
188 kind = parse_memory_kind(rest.trim());
189 } else if let Some(rest) = line.strip_prefix("entity:") {
190 entity = Some(rest.trim().to_string());
191 } else if let Some(rest) = line.strip_prefix("slot:") {
192 slot = Some(rest.trim().to_string());
193 } else if let Some(rest) = line.strip_prefix("value:") {
194 value = Some(rest.trim().to_string());
195 } else if let Some(rest) = line.strip_prefix("polarity:") {
196 polarity = parse_polarity(rest.trim());
197 }
198 }
199
200 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
202 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
203 match MemoryCardBuilder::new()
205 .kind(k)
206 .entity(&e)
207 .slot(&s)
208 .value(&v)
209 .polarity(polarity)
210 .source(ctx.frame_id, Some(ctx.uri.clone()))
211 .document_date(ctx.timestamp)
212 .engine("llm:phi-3.5-mini", "1.0.0")
213 .build(0)
214 {
215 Ok(card) => cards.push(card),
216 Err(err) => {
217 warn!("Failed to build memory card: {}", err);
218 }
219 }
220 }
221 }
222 }
223
224 cards
225 }
226}
227
228fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
230 match s.to_lowercase().as_str() {
231 "fact" => Some(MemoryKind::Fact),
232 "preference" => Some(MemoryKind::Preference),
233 "event" => Some(MemoryKind::Event),
234 "profile" => Some(MemoryKind::Profile),
235 "relationship" => Some(MemoryKind::Relationship),
236 "other" => Some(MemoryKind::Other),
237 _ => None,
238 }
239}
240
241fn parse_polarity(s: &str) -> Polarity {
243 match s.to_lowercase().as_str() {
244 "positive" => Polarity::Positive,
245 "negative" => Polarity::Negative,
246 _ => Polarity::Neutral,
247 }
248}
249
250impl EnrichmentEngine for LlmEngine {
251 fn kind(&self) -> &str {
252 "llm:phi-3.5-mini"
253 }
254
255 fn version(&self) -> &str {
256 &self.version
257 }
258
259 fn init(&mut self) -> memvid_core::Result<()> {
260 let model = self
261 .load_model()
262 .map_err(|err| memvid_core::MemvidError::EmbeddingFailed {
263 reason: format!("{}", err).into_boxed_str(),
264 })?;
265 *self
266 .model
267 .lock()
268 .map_err(|_| memvid_core::MemvidError::EmbeddingFailed {
269 reason: "Model lock poisoned".into(),
270 })? = Some(model);
271 self.ready = true;
272 Ok(())
273 }
274
275 fn is_ready(&self) -> bool {
276 self.ready
277 }
278
279 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
280 if ctx.text.is_empty() {
281 return EnrichmentResult::empty();
282 }
283
284 match self.run_inference(&ctx.text) {
285 Ok(output) => {
286 debug!("LLM output for frame {}: {}", ctx.frame_id, output);
287 let cards = self.parse_output(&output, ctx);
288 EnrichmentResult::success(cards)
289 }
290 Err(err) => EnrichmentResult::failed(format!("LLM inference failed: {}", err)),
291 }
292 }
293}
294
295#[cfg(test)]
296mod tests {
297 use super::*;
298
299 #[test]
300 fn test_parse_memory_kind() {
301 assert_eq!(parse_memory_kind("Fact"), Some(MemoryKind::Fact));
302 assert_eq!(
303 parse_memory_kind("PREFERENCE"),
304 Some(MemoryKind::Preference)
305 );
306 assert_eq!(parse_memory_kind("event"), Some(MemoryKind::Event));
307 assert_eq!(parse_memory_kind("invalid"), None);
308 }
309
310 #[test]
311 fn test_parse_polarity() {
312 assert_eq!(parse_polarity("Positive"), Polarity::Positive);
313 assert_eq!(parse_polarity("NEGATIVE"), Polarity::Negative);
314 assert_eq!(parse_polarity("Neutral"), Polarity::Neutral);
315 assert_eq!(parse_polarity("unknown"), Polarity::Neutral);
316 }
317
318 #[test]
319 fn test_parse_output() {
320 let engine = LlmEngine::new(PathBuf::from("/tmp/test.gguf"));
321 let ctx = EnrichmentContext::new(
322 1,
323 "mv2://test/1".to_string(),
324 "Test text".to_string(),
325 None,
326 1700000000,
327 None,
328 );
329
330 let output = r#"
332MEMORY_START
333kind: Fact
334entity: John
335slot: employer
336value: Anthropic
337polarity: Positive
338MEMORY_END
339"#;
340 let cards = engine.parse_output(output, &ctx);
341 assert_eq!(cards.len(), 1);
342 assert_eq!(cards[0].entity, "John");
343 assert_eq!(cards[0].slot, "employer");
344 assert_eq!(cards[0].value, "Anthropic");
345 assert_eq!(cards[0].kind, MemoryKind::Fact);
346 }
347
348 #[test]
349 fn test_parse_output_none() {
350 let engine = LlmEngine::new(PathBuf::from("/tmp/test.gguf"));
351 let ctx = EnrichmentContext::new(
352 1,
353 "mv2://test/1".to_string(),
354 "Test text".to_string(),
355 None,
356 1700000000,
357 None,
358 );
359
360 let output = "MEMORY_NONE";
361 let cards = engine.parse_output(output, &ctx);
362 assert!(cards.is_empty());
363 }
364
365 #[test]
366 fn test_parse_output_multiple() {
367 let engine = LlmEngine::new(PathBuf::from("/tmp/test.gguf"));
368 let ctx = EnrichmentContext::new(
369 1,
370 "mv2://test/1".to_string(),
371 "Test text".to_string(),
372 None,
373 1700000000,
374 None,
375 );
376
377 let output = r#"
378MEMORY_START
379kind: Fact
380entity: Alice
381slot: role
382value: Engineer
383polarity: Neutral
384MEMORY_END
385
386MEMORY_START
387kind: Preference
388entity: Bob
389slot: drink
390value: Coffee
391polarity: Positive
392MEMORY_END
393"#;
394 let cards = engine.parse_output(output, &ctx);
395 assert_eq!(cards.len(), 2);
396 assert_eq!(cards[0].entity, "Alice");
397 assert_eq!(cards[1].entity, "Bob");
398 }
399}