memvid_cli/enrich/
candle_phi.rs

1//! Candle-based Phi-3 enrichment engine using quantized GGUF models.
2//!
3//! This engine uses Hugging Face Candle to run quantized Phi-3 models locally
4//! for extracting structured memory cards from text content.
5
6#![cfg(feature = "candle-llm")]
7
8use std::path::PathBuf;
9use std::sync::Mutex;
10
11use anyhow::{anyhow, Result};
12use candle_core::Device;
13use candle_transformers::generation::LogitsProcessor;
14use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
15use hf_hub::api::sync::Api;
16use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
17use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
18use tokenizers::Tokenizer;
19use tracing::{debug, info, warn};
20
/// The prompt template for Phi-3 memory extraction.
///
/// Uses the Phi-3 chat format (`<|system|>` / `<|user|>` / `<|assistant|>`).
/// `{text}` is substituted with the (possibly truncated) input inside
/// `run_inference`; the model is instructed to reply with
/// `MEMORY_START`/`MEMORY_END` blocks that `parse_output` scans for, or
/// `MEMORY_NONE` when nothing is extractable.
const EXTRACTION_PROMPT: &str = r#"<|system|>
You are a memory extraction assistant. Your task is to extract structured facts from text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.
<|end|>
<|user|>
Extract memories from this text:

{text}
<|end|>
<|assistant|>
"#;

/// Maximum tokens to generate per extraction
const MAX_OUTPUT_TOKENS: usize = 1024;

/// Maximum input text length to process.
/// NOTE(review): despite the name this is compared against `str::len` in
/// `run_inference`, i.e. it is a *byte* budget, not a character count.
const MAX_INPUT_CHARS: usize = 8192;

/// Hugging Face model repository (GGUF quantized version - much smaller)
const PHI3_MINI_REPO: &str = "microsoft/Phi-3-mini-4k-instruct-gguf";
/// GGUF file name (Q4 quantized, ~2.4GB)
const PHI3_GGUF_FILE: &str = "Phi-3-mini-4k-instruct-q4.gguf";
55
/// Loaded Phi-3 model state.
///
/// Bundles everything inference needs so it can live behind the single
/// `Mutex<Option<LoadedModel>>` in [`CandlePhiEngine`].
struct LoadedModel {
    /// Quantized Phi-3 weights read from the GGUF file.
    model: Phi3,
    /// Tokenizer matching the Phi-3 vocabulary (fetched from the non-GGUF repo).
    tokenizer: Tokenizer,
    /// Device the weights were loaded on (always CPU in `load_model`).
    device: Device,
}
62
/// Candle-based Phi-3 enrichment engine using quantized GGUF.
///
/// Construct with one of the `from_*` constructors, then call
/// [`EnrichmentEngine::init`] (which downloads/loads the model) before
/// calling `enrich`.
pub struct CandlePhiEngine {
    /// Hugging Face model repo or local path
    model_source: ModelSource,
    /// Loaded model (lazy initialization; `None` until `init` succeeds)
    loaded: Mutex<Option<LoadedModel>>,
    /// Whether the engine is initialized (set by `init`)
    ready: bool,
    /// Engine version reported via `version()`
    version: String,
}
74
/// Source for the model.
///
/// Determines where `load_model` finds (or downloads) the GGUF weights and
/// the matching `tokenizer.json`.
enum ModelSource {
    /// Load from Hugging Face Hub (downloads to HF cache)
    HuggingFace { repo: String, file: String },
    /// Load from local GGUF file (`tokenizer.json` expected in the same directory)
    Local { path: PathBuf },
    /// Load from memvid models directory (downloads there if needed)
    MemvidModels { models_dir: PathBuf },
}
84
85impl CandlePhiEngine {
86    /// Create a new Candle Phi engine that loads from Hugging Face Hub.
87    pub fn from_hub(repo: Option<&str>) -> Self {
88        Self {
89            model_source: ModelSource::HuggingFace {
90                repo: repo.unwrap_or(PHI3_MINI_REPO).to_string(),
91                file: PHI3_GGUF_FILE.to_string(),
92            },
93            loaded: Mutex::new(None),
94            ready: false,
95            version: "1.0.0".to_string(),
96        }
97    }
98
99    /// Create a new Candle Phi engine that loads from a local GGUF file.
100    pub fn from_local(path: PathBuf) -> Self {
101        Self {
102            model_source: ModelSource::Local { path },
103            loaded: Mutex::new(None),
104            ready: false,
105            version: "1.0.0".to_string(),
106        }
107    }
108
109    /// Create a new Candle Phi engine that uses the memvid models directory.
110    /// Downloads the model to ~/.memvid/models/llm/phi-3-mini-q4/ if not present.
111    pub fn from_memvid_models(models_dir: PathBuf) -> Self {
112        Self {
113            model_source: ModelSource::MemvidModels { models_dir },
114            loaded: Mutex::new(None),
115            ready: false,
116            version: "1.0.0".to_string(),
117        }
118    }
119
    /// Load the model from the configured source.
    ///
    /// Resolves a (GGUF weights file, `tokenizer.json`) pair according to
    /// `self.model_source` — downloading through the Hugging Face Hub API
    /// where needed — then parses the GGUF content and constructs the
    /// quantized Phi-3 weights on CPU.
    ///
    /// # Errors
    ///
    /// Fails on download/network errors, a missing local GGUF file, or a
    /// malformed GGUF / tokenizer file.
    fn load_model(&self) -> Result<LoadedModel> {
        // CPU-only: no GPU device selection is attempted here.
        let device = Device::Cpu;
        info!("Loading quantized Phi-3 model on device: {:?}", device);

        let (gguf_path, tokenizer_path) = match &self.model_source {
            ModelSource::HuggingFace { repo, file } => {
                info!(
                    "Downloading GGUF model from Hugging Face: {}/{}",
                    repo, file
                );
                let api = Api::new()?;
                let model_repo = api.model(repo.clone());

                // Download the GGUF file
                let gguf_path = model_repo.get(file)?;
                info!("Downloaded GGUF to: {:?}", gguf_path);

                // Get tokenizer from the non-GGUF repo (GGUF repo doesn't have tokenizer.json)
                let tokenizer_repo = api.model("microsoft/Phi-3-mini-4k-instruct".to_string());
                let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;

                (gguf_path, tokenizer_path)
            }
            ModelSource::Local { path } => {
                if !path.exists() {
                    return Err(anyhow!("GGUF file not found: {}", path.display()));
                }

                // For local, assume tokenizer is in the same directory
                let tokenizer_path = path
                    .parent()
                    .map(|p| p.join("tokenizer.json"))
                    .ok_or_else(|| anyhow!("Cannot determine tokenizer path"))?;

                (path.clone(), tokenizer_path)
            }
            ModelSource::MemvidModels { models_dir } => {
                // Use ~/.memvid/models/llm/phi-3-mini-q4/ directory
                let model_dir = models_dir.join("llm").join("phi-3-mini-q4");
                let gguf_path = model_dir.join(PHI3_GGUF_FILE);
                let tokenizer_path = model_dir.join("tokenizer.json");

                if gguf_path.exists() && tokenizer_path.exists() {
                    info!("Using existing model from: {:?}", model_dir);
                } else {
                    info!(
                        "Downloading model to memvid models directory: {:?}",
                        model_dir
                    );
                    std::fs::create_dir_all(&model_dir)?;

                    // Download from HuggingFace and copy to our directory.
                    // Each artifact is copied only if missing, so a partially
                    // installed directory is repaired incrementally.
                    let api = Api::new()?;

                    // Download GGUF
                    let gguf_repo = api.model(PHI3_MINI_REPO.to_string());
                    let hf_gguf = gguf_repo.get(PHI3_GGUF_FILE)?;
                    if !gguf_path.exists() {
                        info!("Copying GGUF to: {:?}", gguf_path);
                        std::fs::copy(&hf_gguf, &gguf_path)?;
                    }

                    // Download tokenizer
                    let tokenizer_repo = api.model("microsoft/Phi-3-mini-4k-instruct".to_string());
                    let hf_tokenizer = tokenizer_repo.get("tokenizer.json")?;
                    if !tokenizer_path.exists() {
                        info!("Copying tokenizer to: {:?}", tokenizer_path);
                        std::fs::copy(&hf_tokenizer, &tokenizer_path)?;
                    }

                    info!("Model installed to: {:?}", model_dir);
                }

                (gguf_path, tokenizer_path)
            }
        };

        // Load tokenizer
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;

        // Load GGUF model; the same file handle is re-used to read tensor
        // data after the header/content parse.
        info!("Loading GGUF file: {:?}", gguf_path);
        let mut file = std::fs::File::open(&gguf_path)?;
        let content = candle_core::quantized::gguf_file::Content::read(&mut file)
            .map_err(|e| anyhow!("Failed to read GGUF: {}", e))?;

        // NOTE(review): the leading `false` presumably disables flash
        // attention — confirm against the candle_transformers API docs.
        let model = Phi3::from_gguf(false, content, &mut file, &device)?;
        info!("Phi-3 quantized model loaded successfully");

        Ok(LoadedModel {
            model,
            tokenizer,
            device,
        })
    }
217
218    /// Run inference on the given text and return raw output.
219    fn run_inference(&self, text: &str) -> Result<String> {
220        let mut loaded_guard = self
221            .loaded
222            .lock()
223            .map_err(|_| anyhow!("Model lock poisoned"))?;
224
225        let loaded = loaded_guard
226            .as_mut()
227            .ok_or_else(|| anyhow!("Candle Phi engine not initialized. Call init() first."))?;
228
229        // Truncate input if too long
230        let truncated_text = if text.len() > MAX_INPUT_CHARS {
231            &text[..MAX_INPUT_CHARS]
232        } else {
233            text
234        };
235
236        // Build the prompt
237        let prompt = EXTRACTION_PROMPT.replace("{text}", truncated_text);
238
239        // Tokenize
240        let encoding = loaded
241            .tokenizer
242            .encode(prompt.as_str(), true)
243            .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
244
245        let mut tokens: Vec<u32> = encoding.get_ids().to_vec();
246
247        debug!("Input tokens: {}", tokens.len());
248
249        // Run forward pass and generate tokens
250        let mut logits_processor = LogitsProcessor::new(42, None, None);
251        let mut generated_tokens = Vec::new();
252        let eos_token = loaded
253            .tokenizer
254            .token_to_id("<|end|>")
255            .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"))
256            .unwrap_or(0);
257
258        // Process the initial prompt
259        let input = candle_core::Tensor::new(&tokens[..], &loaded.device)?.unsqueeze(0)?;
260        let logits = loaded.model.forward(&input, 0)?;
261        let logits = logits.squeeze(0)?.squeeze(0)?;
262        let logits = logits.to_dtype(candle_core::DType::F32)?;
263
264        let next_token = logits_processor.sample(&logits)?;
265        generated_tokens.push(next_token);
266        tokens.push(next_token);
267
268        // Generate tokens one at a time
269        for i in 0..MAX_OUTPUT_TOKENS {
270            if next_token == eos_token {
271                break;
272            }
273
274            let input = candle_core::Tensor::new(&[tokens[tokens.len() - 1]], &loaded.device)?
275                .unsqueeze(0)?;
276
277            let logits = loaded.model.forward(&input, tokens.len() - 1)?;
278            let logits = logits.squeeze(0)?.squeeze(0)?;
279            let logits = logits.to_dtype(candle_core::DType::F32)?;
280
281            let next_token = logits_processor.sample(&logits)?;
282            generated_tokens.push(next_token);
283            tokens.push(next_token);
284
285            if next_token == eos_token || i >= MAX_OUTPUT_TOKENS - 1 {
286                break;
287            }
288        }
289
290        // Decode generated tokens
291        let output = loaded
292            .tokenizer
293            .decode(&generated_tokens, true)
294            .map_err(|e| anyhow!("Decoding failed: {}", e))?;
295
296        Ok(output.trim().to_string())
297    }
298
299    /// Parse the LLM output into memory cards.
300    fn parse_output(&self, output: &str, ctx: &EnrichmentContext) -> Vec<MemoryCard> {
301        let mut cards = Vec::new();
302
303        // Check for "no memories" signal
304        if output.contains("MEMORY_NONE") {
305            return cards;
306        }
307
308        // Parse MEMORY_START...MEMORY_END blocks
309        for block in output.split("MEMORY_START") {
310            let block = block.trim();
311            if block.is_empty() || !block.contains("MEMORY_END") {
312                continue;
313            }
314
315            let block = block.split("MEMORY_END").next().unwrap_or("").trim();
316
317            // Parse fields
318            let mut kind = None;
319            let mut entity = None;
320            let mut slot = None;
321            let mut value = None;
322            let mut polarity = Polarity::Neutral;
323
324            for line in block.lines() {
325                let line = line.trim();
326                if let Some(rest) = line.strip_prefix("kind:") {
327                    kind = parse_memory_kind(rest.trim());
328                } else if let Some(rest) = line.strip_prefix("entity:") {
329                    entity = Some(rest.trim().to_string());
330                } else if let Some(rest) = line.strip_prefix("slot:") {
331                    slot = Some(rest.trim().to_string());
332                } else if let Some(rest) = line.strip_prefix("value:") {
333                    value = Some(rest.trim().to_string());
334                } else if let Some(rest) = line.strip_prefix("polarity:") {
335                    polarity = parse_polarity(rest.trim());
336                }
337            }
338
339            // Build memory card if we have required fields
340            if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
341                if !e.is_empty() && !s.is_empty() && !v.is_empty() {
342                    match MemoryCardBuilder::new()
343                        .kind(k)
344                        .entity(&e)
345                        .slot(&s)
346                        .value(&v)
347                        .polarity(polarity)
348                        .source(ctx.frame_id, Some(ctx.uri.clone()))
349                        .document_date(ctx.timestamp)
350                        .engine("candle:phi-3-mini-q4", "1.0.0")
351                        .build(0)
352                    {
353                        Ok(card) => cards.push(card),
354                        Err(err) => {
355                            warn!("Failed to build memory card: {}", err);
356                        }
357                    }
358                }
359            }
360        }
361
362        cards
363    }
364}
365
366/// Parse a memory kind string into the enum.
367fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
368    match s.to_lowercase().as_str() {
369        "fact" => Some(MemoryKind::Fact),
370        "preference" => Some(MemoryKind::Preference),
371        "event" => Some(MemoryKind::Event),
372        "profile" => Some(MemoryKind::Profile),
373        "relationship" => Some(MemoryKind::Relationship),
374        "other" => Some(MemoryKind::Other),
375        _ => None,
376    }
377}
378
379/// Parse a polarity string into the enum.
380fn parse_polarity(s: &str) -> Polarity {
381    match s.to_lowercase().as_str() {
382        "positive" => Polarity::Positive,
383        "negative" => Polarity::Negative,
384        _ => Polarity::Neutral,
385    }
386}
387
388impl EnrichmentEngine for CandlePhiEngine {
    /// Stable engine identifier, recorded on produced memory cards.
    fn kind(&self) -> &str {
        "candle:phi-3-mini-q4"
    }
392
    /// Engine version string (set at construction time).
    fn version(&self) -> &str {
        &self.version
    }
396
397    fn init(&mut self) -> memvid_core::Result<()> {
398        let model = self
399            .load_model()
400            .map_err(|err| memvid_core::MemvidError::EmbeddingFailed {
401                reason: format!("{}", err).into_boxed_str(),
402            })?;
403        *self
404            .loaded
405            .lock()
406            .map_err(|_| memvid_core::MemvidError::EmbeddingFailed {
407                reason: "Model lock poisoned".into(),
408            })? = Some(model);
409        self.ready = true;
410        Ok(())
411    }
412
    /// True once `init` has successfully loaded the model.
    fn is_ready(&self) -> bool {
        self.ready
    }
416
417    fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
418        if ctx.text.is_empty() {
419            return EnrichmentResult::empty();
420        }
421
422        match self.run_inference(&ctx.text) {
423            Ok(output) => {
424                debug!("Candle Phi-3 output for frame {}: {}", ctx.frame_id, output);
425                let cards = self.parse_output(&output, ctx);
426                EnrichmentResult::success(cards)
427            }
428            Err(err) => EnrichmentResult::failed(format!("Candle inference failed: {}", err)),
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Kind labels are matched case-insensitively; unknown labels yield `None`.
    #[test]
    fn test_parse_memory_kind() {
        assert_eq!(Some(MemoryKind::Fact), parse_memory_kind("Fact"));
        assert_eq!(
            Some(MemoryKind::Preference),
            parse_memory_kind("PREFERENCE")
        );
        assert_eq!(Some(MemoryKind::Event), parse_memory_kind("event"));
        // Unknown labels are rejected rather than defaulted.
        assert_eq!(None, parse_memory_kind("invalid"));
    }

    /// Polarity labels are matched case-insensitively; anything else is Neutral.
    #[test]
    fn test_parse_polarity() {
        assert_eq!(Polarity::Positive, parse_polarity("Positive"));
        assert_eq!(Polarity::Negative, parse_polarity("NEGATIVE"));
        assert_eq!(Polarity::Neutral, parse_polarity("Neutral"));
        assert_eq!(Polarity::Neutral, parse_polarity("unknown"));
    }
}