// memvid_cli/enrich/llm.rs
//! LLM-based enrichment engine using Phi-3.5 Mini.
//!
//! This engine uses a local Phi-3.5 model to extract structured memory cards
//! from text content through prompted inference.

use std::path::PathBuf;
use std::sync::Mutex;

use anyhow::{anyhow, Result};
use llama_cpp::standard_sampler::StandardSampler;
use llama_cpp::{LlamaModel, LlamaParams, SessionParams};
use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
use tokio::runtime::Runtime;
use tracing::{debug, warn};
/// Prompt template for Phi-3.5 memory extraction, in the Phi chat format
/// (`<|system|>` / `<|user|>` / `<|assistant|>` markers).
///
/// `{text}` is replaced with the (possibly truncated) input text before
/// inference; the expected reply is a sequence of `MEMORY_START` ...
/// `MEMORY_END` blocks, or the literal `MEMORY_NONE`.
const EXTRACTION_PROMPT: &str = r#"<|system|>
You are a memory extraction assistant. Your task is to extract structured facts from text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.
<|end|>
<|user|>
Extract memories from this text:

{text}
<|end|>
<|assistant|>
"#;

/// Maximum context window size for Phi-3.5, in tokens.
const MAX_CONTEXT_TOKENS: u32 = 4096;

/// Maximum tokens to generate per extraction.
const MAX_OUTPUT_TOKENS: usize = 1024;

/// Maximum input length to process before truncation.
/// NOTE(review): despite the name, this is compared against `text.len()`,
/// i.e. it is a *byte* count, not a character count.
const MAX_INPUT_CHARS: usize = 8192;
49
/// LLM enrichment engine using Phi-3.5 Mini.
pub struct LlmEngine {
    /// Path to the GGUF model file.
    model_path: PathBuf,
    /// Loaded model; `None` until `init()` succeeds (lazy initialization).
    model: Mutex<Option<LlamaModel>>,
    /// Whether `init()` has completed successfully.
    ready: bool,
    /// Engine version string, reported via the `version()` trait method.
    version: String,
}
61
62impl LlmEngine {
63    /// Create a new LLM engine with the specified model path.
64    pub fn new(model_path: PathBuf) -> Self {
65        Self {
66            model_path,
67            model: Mutex::new(None),
68            ready: false,
69            version: "1.0.0".to_string(),
70        }
71    }
72
73    /// Load the model from disk.
74    fn load_model(&self) -> Result<LlamaModel> {
75        if !self.model_path.exists() {
76            return Err(anyhow!(
77                "Model file not found: {}",
78                self.model_path.display()
79            ));
80        }
81
82        // Suppress llama.cpp logging
83        unsafe {
84            std::env::set_var("GGML_LOG_LEVEL", "ERROR");
85            std::env::set_var("LLAMA_LOG_LEVEL", "ERROR");
86        }
87
88        debug!("Loading model from {}", self.model_path.display());
89
90        LlamaModel::load_from_file(&self.model_path, LlamaParams::default())
91            .map_err(|err| anyhow!("Failed to load Phi-3.5 model: {}", err))
92    }
93
94    /// Run inference on the given text and return raw output.
95    fn run_inference(&self, text: &str) -> Result<String> {
96        let model_guard = self
97            .model
98            .lock()
99            .map_err(|_| anyhow!("Model lock poisoned"))?;
100
101        let model = model_guard
102            .as_ref()
103            .ok_or_else(|| anyhow!("LLM engine not initialized. Call init() first."))?;
104
105        // Truncate input if too long
106        let truncated_text = if text.len() > MAX_INPUT_CHARS {
107            &text[..MAX_INPUT_CHARS]
108        } else {
109            text
110        };
111
112        // Build the prompt
113        let prompt = EXTRACTION_PROMPT.replace("{text}", truncated_text);
114
115        // Create session
116        let mut session_params = SessionParams::default();
117        session_params.n_ctx = MAX_CONTEXT_TOKENS;
118        session_params.n_batch = 512.min(MAX_CONTEXT_TOKENS);
119        if session_params.n_ubatch == 0 {
120            session_params.n_ubatch = 512;
121        }
122
123        let mut session = model
124            .create_session(session_params)
125            .map_err(|err| anyhow!("Failed to create LLM session: {}", err))?;
126
127        // Tokenize the prompt
128        let mut tokens = model
129            .tokenize_bytes(prompt.as_bytes(), true, true)
130            .map_err(|err| anyhow!("Failed to tokenize prompt: {}", err))?;
131
132        // Ensure we have room for output
133        let max_tokens = MAX_CONTEXT_TOKENS as usize;
134        let reserved = MAX_OUTPUT_TOKENS + 64;
135        if tokens.len() >= max_tokens.saturating_sub(reserved) {
136            let target = max_tokens.saturating_sub(reserved).max(1);
137            let tail_start = tokens.len().saturating_sub(target);
138            tokens = tokens.split_off(tail_start);
139        }
140
141        // Prime the context
142        session
143            .advance_context_with_tokens(&tokens)
144            .map_err(|err| anyhow!("Failed to prime LLM context: {}", err))?;
145
146        // Generate completion
147        let handle = session
148            .start_completing_with(StandardSampler::default(), MAX_OUTPUT_TOKENS)
149            .map_err(|err| anyhow!("Failed to start LLM completion: {}", err))?;
150
151        // Run async generation synchronously
152        let runtime =
153            Runtime::new().map_err(|err| anyhow!("Failed to create tokio runtime: {}", err))?;
154
155        let generated = runtime.block_on(async { handle.into_string_async().await });
156
157        Ok(generated.trim().to_string())
158    }
159
160    /// Parse the LLM output into memory cards.
161    fn parse_output(&self, output: &str, ctx: &EnrichmentContext) -> Vec<MemoryCard> {
162        let mut cards = Vec::new();
163
164        // Check for "no memories" signal
165        if output.contains("MEMORY_NONE") {
166            return cards;
167        }
168
169        // Parse MEMORY_START...MEMORY_END blocks
170        for block in output.split("MEMORY_START") {
171            let block = block.trim();
172            if block.is_empty() || !block.contains("MEMORY_END") {
173                continue;
174            }
175
176            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
177
178            // Parse fields
179            let mut kind = None;
180            let mut entity = None;
181            let mut slot = None;
182            let mut value = None;
183            let mut polarity = Polarity::Neutral;
184
185            for line in block.lines() {
186                let line = line.trim();
187                if let Some(rest) = line.strip_prefix("kind:") {
188                    kind = parse_memory_kind(rest.trim());
189                } else if let Some(rest) = line.strip_prefix("entity:") {
190                    entity = Some(rest.trim().to_string());
191                } else if let Some(rest) = line.strip_prefix("slot:") {
192                    slot = Some(rest.trim().to_string());
193                } else if let Some(rest) = line.strip_prefix("value:") {
194                    value = Some(rest.trim().to_string());
195                } else if let Some(rest) = line.strip_prefix("polarity:") {
196                    polarity = parse_polarity(rest.trim());
197                }
198            }
199
200            // Build memory card if we have required fields
201            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
202                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
203                    // ID 0 will be reassigned when added to MemoriesTrack
204                    match MemoryCardBuilder::new()
205                        .kind(k)
206                        .entity(&e)
207                        .slot(&s)
208                        .value(&v)
209                        .polarity(polarity)
210                        .source(ctx.frame_id, Some(ctx.uri.clone()))
211                        .document_date(ctx.timestamp)
212                        .engine("llm:phi-3.5-mini", "1.0.0")
213                        .build(0)
214                    {
215                        Ok(card) => cards.push(card),
216                        Err(err) => {
217                            warn!("Failed to build memory card: {}", err);
218                        }
219                    }
220                }
221            }
222        }
223
224        cards
225    }
226}
227
228/// Parse a memory kind string into the enum.
229fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
230    match s.to_lowercase().as_str() {
231        "fact" => Some(MemoryKind::Fact),
232        "preference" => Some(MemoryKind::Preference),
233        "event" => Some(MemoryKind::Event),
234        "profile" => Some(MemoryKind::Profile),
235        "relationship" => Some(MemoryKind::Relationship),
236        "other" => Some(MemoryKind::Other),
237        _ => None,
238    }
239}
240
241/// Parse a polarity string into the enum.
242fn parse_polarity(s: &str) -> Polarity {
243    match s.to_lowercase().as_str() {
244        "positive" => Polarity::Positive,
245        "negative" => Polarity::Negative,
246        _ => Polarity::Neutral,
247    }
248}
249
250impl EnrichmentEngine for LlmEngine {
251    fn kind(&self) -> &str {
252        "llm:phi-3.5-mini"
253    }
254
255    fn version(&self) -> &str {
256        &self.version
257    }
258
259    fn init(&mut self) -> memvid_core::Result<()> {
260        let model = self
261            .load_model()
262            .map_err(|err| memvid_core::MemvidError::EmbeddingFailed {
263                reason: format!("{}", err).into_boxed_str(),
264            })?;
265        *self
266            .model
267            .lock()
268            .map_err(|_| memvid_core::MemvidError::EmbeddingFailed {
269                reason: "Model lock poisoned".into(),
270            })? = Some(model);
271        self.ready = true;
272        Ok(())
273    }
274
275    fn is_ready(&self) -> bool {
276        self.ready
277    }
278
279    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
280        if ctx.text.is_empty() {
281            return EnrichmentResult::empty();
282        }
283
284        match self.run_inference(&ctx.text) {
285            Ok(output) => {
286                debug!("LLM output for frame {}: {}", ctx.frame_id, output);
287                let cards = self.parse_output(&output, ctx);
288                EnrichmentResult::success(cards)
289            }
290            Err(err) => EnrichmentResult::failed(format!("LLM inference failed: {}", err)),
291        }
292    }
293}
294
295#[cfg(test)]
296mod tests {
297    use super::*;
298
    #[test]
    fn test_parse_memory_kind() {
        // Kind parsing is case-insensitive; unknown strings yield None.
        assert_eq!(parse_memory_kind("Fact"), Some(MemoryKind::Fact));
        assert_eq!(
            parse_memory_kind("PREFERENCE"),
            Some(MemoryKind::Preference)
        );
        assert_eq!(parse_memory_kind("event"), Some(MemoryKind::Event));
        assert_eq!(parse_memory_kind("invalid"), None);
    }
309
    #[test]
    fn test_parse_polarity() {
        // Polarity parsing is case-insensitive and defaults to Neutral.
        assert_eq!(parse_polarity("Positive"), Polarity::Positive);
        assert_eq!(parse_polarity("NEGATIVE"), Polarity::Negative);
        assert_eq!(parse_polarity("Neutral"), Polarity::Neutral);
        assert_eq!(parse_polarity("unknown"), Polarity::Neutral);
    }
317
    #[test]
    fn test_parse_output() {
        // The model file never needs to exist: parse_output does not
        // touch the model, only string parsing.
        let engine = LlmEngine::new(PathBuf::from("/tmp/test.gguf"));
        let ctx = EnrichmentContext::new(
            1,
            "mv2://test/1".to_string(),
            "Test text".to_string(),
            None,
            1700000000,
            None,
        );

        // A single well-formed MEMORY_START..MEMORY_END block yields one card.
        let output = r#"
MEMORY_START
kind: Fact
entity: John
slot: employer
value: Anthropic
polarity: Positive
MEMORY_END
"#;
        let cards = engine.parse_output(output, &ctx);
        assert_eq!(cards.len(), 1);
        assert_eq!(cards[0].entity, "John");
        assert_eq!(cards[0].slot, "employer");
        assert_eq!(cards[0].value, "Anthropic");
        assert_eq!(cards[0].kind, MemoryKind::Fact);
    }
347
    #[test]
    fn test_parse_output_none() {
        let engine = LlmEngine::new(PathBuf::from("/tmp/test.gguf"));
        let ctx = EnrichmentContext::new(
            1,
            "mv2://test/1".to_string(),
            "Test text".to_string(),
            None,
            1700000000,
            None,
        );

        // The explicit "no memories" marker produces no cards.
        let output = "MEMORY_NONE";
        let cards = engine.parse_output(output, &ctx);
        assert!(cards.is_empty());
    }
364
    #[test]
    fn test_parse_output_multiple() {
        let engine = LlmEngine::new(PathBuf::from("/tmp/test.gguf"));
        let ctx = EnrichmentContext::new(
            1,
            "mv2://test/1".to_string(),
            "Test text".to_string(),
            None,
            1700000000,
            None,
        );

        // Two blocks in one output produce two cards, in output order.
        let output = r#"
MEMORY_START
kind: Fact
entity: Alice
slot: role
value: Engineer
polarity: Neutral
MEMORY_END

MEMORY_START
kind: Preference
entity: Bob
slot: drink
value: Coffee
polarity: Positive
MEMORY_END
"#;
        let cards = engine.parse_output(output, &ctx);
        assert_eq!(cards.len(), 2);
        assert_eq!(cards[0].entity, "Alice");
        assert_eq!(cards[1].entity, "Bob");
    }
399}