1#![cfg(feature = "candle-llm")]
7
8use std::path::PathBuf;
9use std::sync::Mutex;
10
11use anyhow::{anyhow, Result};
12use candle_core::Device;
13use candle_transformers::generation::LogitsProcessor;
14use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
15use hf_hub::api::sync::Api;
16use memvid_core::enrich::{EnrichmentContext, EnrichmentEngine, EnrichmentResult};
17use memvid_core::types::{MemoryCard, MemoryCardBuilder, MemoryKind, Polarity};
18use tokenizers::Tokenizer;
19use tracing::{debug, info, warn};
20
/// Phi-3 chat-template prompt for structured memory extraction.
/// `{text}` is substituted with the (possibly truncated) input; the model is
/// instructed to reply with `MEMORY_START`/`MEMORY_END` blocks or `MEMORY_NONE`.
const EXTRACTION_PROMPT: &str = r#"<|system|>
You are a memory extraction assistant. Your task is to extract structured facts from text.

For each distinct fact, preference, event, or relationship mentioned, output a memory card in this exact format:
MEMORY_START
kind: <Fact|Preference|Event|Profile|Relationship|Other>
entity: <the main entity this memory is about>
slot: <a short key describing what aspect of the entity>
value: <the actual information>
polarity: <Positive|Negative|Neutral>
MEMORY_END

Only extract information that is explicitly stated. Do not infer or guess.
If there are no clear facts to extract, output MEMORY_NONE.
<|end|>
<|user|>
Extract memories from this text:

{text}
<|end|>
<|assistant|>
"#;

/// Hard cap on the number of tokens sampled for one extraction call.
const MAX_OUTPUT_TOKENS: usize = 1024;

/// Maximum input size before truncation. NOTE(review): this is compared
/// against `str::len()`, so it is a *byte* count, not a character count,
/// despite the name.
const MAX_INPUT_CHARS: usize = 8192;

/// Default Hugging Face repository hosting the quantized Phi-3 weights.
const PHI3_MINI_REPO: &str = "microsoft/Phi-3-mini-4k-instruct-gguf";
/// GGUF file (4-bit quantized) fetched from the repository.
const PHI3_GGUF_FILE: &str = "Phi-3-mini-4k-instruct-q4.gguf";
55
/// Model state that exists only after a successful `init()`: the quantized
/// weights, the tokenizer, and the device they were loaded on.
struct LoadedModel {
    model: Phi3,
    tokenizer: Tokenizer,
    device: Device,
}
62
/// Memory-extraction enrichment engine backed by a quantized Phi-3 model
/// executed locally through candle.
pub struct CandlePhiEngine {
    // Where the GGUF weights and tokenizer come from.
    model_source: ModelSource,
    // Lazily populated model state; `None` until `init()` succeeds.
    loaded: Mutex<Option<LoadedModel>>,
    // Flipped to true by `init()`; reported via `is_ready()`.
    ready: bool,
    // Engine version string reported via `version()`.
    version: String,
}
74
/// Where to obtain the GGUF weights and tokenizer.
enum ModelSource {
    /// Download `file` from the given Hugging Face `repo` (hf-hub cache).
    HuggingFace { repo: String, file: String },
    /// Use an existing local GGUF file; a `tokenizer.json` is expected beside it.
    Local { path: PathBuf },
    /// Install into memvid's managed models directory, downloading on first use.
    MemvidModels { models_dir: PathBuf },
}
84
85impl CandlePhiEngine {
86 pub fn from_hub(repo: Option<&str>) -> Self {
88 Self {
89 model_source: ModelSource::HuggingFace {
90 repo: repo.unwrap_or(PHI3_MINI_REPO).to_string(),
91 file: PHI3_GGUF_FILE.to_string(),
92 },
93 loaded: Mutex::new(None),
94 ready: false,
95 version: "1.0.0".to_string(),
96 }
97 }
98
99 pub fn from_local(path: PathBuf) -> Self {
101 Self {
102 model_source: ModelSource::Local { path },
103 loaded: Mutex::new(None),
104 ready: false,
105 version: "1.0.0".to_string(),
106 }
107 }
108
109 pub fn from_memvid_models(models_dir: PathBuf) -> Self {
112 Self {
113 model_source: ModelSource::MemvidModels { models_dir },
114 loaded: Mutex::new(None),
115 ready: false,
116 version: "1.0.0".to_string(),
117 }
118 }
119
120 fn load_model(&self) -> Result<LoadedModel> {
122 let device = Device::Cpu;
123 info!("Loading quantized Phi-3 model on device: {:?}", device);
124
125 let (gguf_path, tokenizer_path) = match &self.model_source {
126 ModelSource::HuggingFace { repo, file } => {
127 info!(
128 "Downloading GGUF model from Hugging Face: {}/{}",
129 repo, file
130 );
131 let api = Api::new()?;
132 let model_repo = api.model(repo.clone());
133
134 let gguf_path = model_repo.get(file)?;
136 info!("Downloaded GGUF to: {:?}", gguf_path);
137
138 let tokenizer_repo = api.model("microsoft/Phi-3-mini-4k-instruct".to_string());
140 let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
141
142 (gguf_path, tokenizer_path)
143 }
144 ModelSource::Local { path } => {
145 if !path.exists() {
146 return Err(anyhow!("GGUF file not found: {}", path.display()));
147 }
148
149 let tokenizer_path = path
151 .parent()
152 .map(|p| p.join("tokenizer.json"))
153 .ok_or_else(|| anyhow!("Cannot determine tokenizer path"))?;
154
155 (path.clone(), tokenizer_path)
156 }
157 ModelSource::MemvidModels { models_dir } => {
158 let model_dir = models_dir.join("llm").join("phi-3-mini-q4");
160 let gguf_path = model_dir.join(PHI3_GGUF_FILE);
161 let tokenizer_path = model_dir.join("tokenizer.json");
162
163 if gguf_path.exists() && tokenizer_path.exists() {
164 info!("Using existing model from: {:?}", model_dir);
165 } else {
166 info!(
167 "Downloading model to memvid models directory: {:?}",
168 model_dir
169 );
170 std::fs::create_dir_all(&model_dir)?;
171
172 let api = Api::new()?;
174
175 let gguf_repo = api.model(PHI3_MINI_REPO.to_string());
177 let hf_gguf = gguf_repo.get(PHI3_GGUF_FILE)?;
178 if !gguf_path.exists() {
179 info!("Copying GGUF to: {:?}", gguf_path);
180 std::fs::copy(&hf_gguf, &gguf_path)?;
181 }
182
183 let tokenizer_repo = api.model("microsoft/Phi-3-mini-4k-instruct".to_string());
185 let hf_tokenizer = tokenizer_repo.get("tokenizer.json")?;
186 if !tokenizer_path.exists() {
187 info!("Copying tokenizer to: {:?}", tokenizer_path);
188 std::fs::copy(&hf_tokenizer, &tokenizer_path)?;
189 }
190
191 info!("Model installed to: {:?}", model_dir);
192 }
193
194 (gguf_path, tokenizer_path)
195 }
196 };
197
198 let tokenizer = Tokenizer::from_file(&tokenizer_path)
200 .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
201
202 info!("Loading GGUF file: {:?}", gguf_path);
204 let mut file = std::fs::File::open(&gguf_path)?;
205 let content = candle_core::quantized::gguf_file::Content::read(&mut file)
206 .map_err(|e| anyhow!("Failed to read GGUF: {}", e))?;
207
208 let model = Phi3::from_gguf(false, content, &mut file, &device)?;
209 info!("Phi-3 quantized model loaded successfully");
210
211 Ok(LoadedModel {
212 model,
213 tokenizer,
214 device,
215 })
216 }
217
218 fn run_inference(&self, text: &str) -> Result<String> {
220 let mut loaded_guard = self
221 .loaded
222 .lock()
223 .map_err(|_| anyhow!("Model lock poisoned"))?;
224
225 let loaded = loaded_guard
226 .as_mut()
227 .ok_or_else(|| anyhow!("Candle Phi engine not initialized. Call init() first."))?;
228
229 let truncated_text = if text.len() > MAX_INPUT_CHARS {
231 &text[..MAX_INPUT_CHARS]
232 } else {
233 text
234 };
235
236 let prompt = EXTRACTION_PROMPT.replace("{text}", truncated_text);
238
239 let encoding = loaded
241 .tokenizer
242 .encode(prompt.as_str(), true)
243 .map_err(|e| anyhow!("Tokenization failed: {}", e))?;
244
245 let mut tokens: Vec<u32> = encoding.get_ids().to_vec();
246
247 debug!("Input tokens: {}", tokens.len());
248
249 let mut logits_processor = LogitsProcessor::new(42, None, None);
251 let mut generated_tokens = Vec::new();
252 let eos_token = loaded
253 .tokenizer
254 .token_to_id("<|end|>")
255 .or_else(|| loaded.tokenizer.token_to_id("<|endoftext|>"))
256 .unwrap_or(0);
257
258 let input = candle_core::Tensor::new(&tokens[..], &loaded.device)?.unsqueeze(0)?;
260 let logits = loaded.model.forward(&input, 0)?;
261 let logits = logits.squeeze(0)?.squeeze(0)?;
262 let logits = logits.to_dtype(candle_core::DType::F32)?;
263
264 let next_token = logits_processor.sample(&logits)?;
265 generated_tokens.push(next_token);
266 tokens.push(next_token);
267
268 for i in 0..MAX_OUTPUT_TOKENS {
270 if next_token == eos_token {
271 break;
272 }
273
274 let input = candle_core::Tensor::new(&[tokens[tokens.len() - 1]], &loaded.device)?
275 .unsqueeze(0)?;
276
277 let logits = loaded.model.forward(&input, tokens.len() - 1)?;
278 let logits = logits.squeeze(0)?.squeeze(0)?;
279 let logits = logits.to_dtype(candle_core::DType::F32)?;
280
281 let next_token = logits_processor.sample(&logits)?;
282 generated_tokens.push(next_token);
283 tokens.push(next_token);
284
285 if next_token == eos_token || i >= MAX_OUTPUT_TOKENS - 1 {
286 break;
287 }
288 }
289
290 let output = loaded
292 .tokenizer
293 .decode(&generated_tokens, true)
294 .map_err(|e| anyhow!("Decoding failed: {}", e))?;
295
296 Ok(output.trim().to_string())
297 }
298
299 fn parse_output(&self, output: &str, ctx: &EnrichmentContext) -> Vec<MemoryCard> {
301 let mut cards = Vec::new();
302
303 if output.contains("MEMORY_NONE") {
305 return cards;
306 }
307
308 for block in output.split("MEMORY_START") {
310 let block = block.trim();
311 if block.is_empty() || !block.contains("MEMORY_END") {
312 continue;
313 }
314
315 let block = block.split("MEMORY_END").next().unwrap_or("").trim();
316
317 let mut kind = None;
319 let mut entity = None;
320 let mut slot = None;
321 let mut value = None;
322 let mut polarity = Polarity::Neutral;
323
324 for line in block.lines() {
325 let line = line.trim();
326 if let Some(rest) = line.strip_prefix("kind:") {
327 kind = parse_memory_kind(rest.trim());
328 } else if let Some(rest) = line.strip_prefix("entity:") {
329 entity = Some(rest.trim().to_string());
330 } else if let Some(rest) = line.strip_prefix("slot:") {
331 slot = Some(rest.trim().to_string());
332 } else if let Some(rest) = line.strip_prefix("value:") {
333 value = Some(rest.trim().to_string());
334 } else if let Some(rest) = line.strip_prefix("polarity:") {
335 polarity = parse_polarity(rest.trim());
336 }
337 }
338
339 if let (Some(k), Some(e), Some(s), Some(v)) = (kind, entity, slot, value) {
341 if !e.is_empty() && !s.is_empty() && !v.is_empty() {
342 match MemoryCardBuilder::new()
343 .kind(k)
344 .entity(&e)
345 .slot(&s)
346 .value(&v)
347 .polarity(polarity)
348 .source(ctx.frame_id, Some(ctx.uri.clone()))
349 .document_date(ctx.timestamp)
350 .engine("candle:phi-3-mini-q4", "1.0.0")
351 .build(0)
352 {
353 Ok(card) => cards.push(card),
354 Err(err) => {
355 warn!("Failed to build memory card: {}", err);
356 }
357 }
358 }
359 }
360 }
361
362 cards
363 }
364}
365
366fn parse_memory_kind(s: &str) -> Option<MemoryKind> {
368 match s.to_lowercase().as_str() {
369 "fact" => Some(MemoryKind::Fact),
370 "preference" => Some(MemoryKind::Preference),
371 "event" => Some(MemoryKind::Event),
372 "profile" => Some(MemoryKind::Profile),
373 "relationship" => Some(MemoryKind::Relationship),
374 "other" => Some(MemoryKind::Other),
375 _ => None,
376 }
377}
378
379fn parse_polarity(s: &str) -> Polarity {
381 match s.to_lowercase().as_str() {
382 "positive" => Polarity::Positive,
383 "negative" => Polarity::Negative,
384 _ => Polarity::Neutral,
385 }
386}
387
388impl EnrichmentEngine for CandlePhiEngine {
389 fn kind(&self) -> &str {
390 "candle:phi-3-mini-q4"
391 }
392
393 fn version(&self) -> &str {
394 &self.version
395 }
396
397 fn init(&mut self) -> memvid_core::Result<()> {
398 let model = self
399 .load_model()
400 .map_err(|err| memvid_core::MemvidError::EmbeddingFailed {
401 reason: format!("{}", err).into_boxed_str(),
402 })?;
403 *self
404 .loaded
405 .lock()
406 .map_err(|_| memvid_core::MemvidError::EmbeddingFailed {
407 reason: "Model lock poisoned".into(),
408 })? = Some(model);
409 self.ready = true;
410 Ok(())
411 }
412
413 fn is_ready(&self) -> bool {
414 self.ready
415 }
416
417 fn enrich(&self, ctx: &EnrichmentContext) -> EnrichmentResult {
418 if ctx.text.is_empty() {
419 return EnrichmentResult::empty();
420 }
421
422 match self.run_inference(&ctx.text) {
423 Ok(output) => {
424 debug!("Candle Phi-3 output for frame {}: {}", ctx.frame_id, output);
425 let cards = self.parse_output(&output, ctx);
426 EnrichmentResult::success(cards)
427 }
428 Err(err) => EnrichmentResult::failed(format!("Candle inference failed: {}", err)),
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Kind labels parse case-insensitively; unknown labels map to `None`.
    #[test]
    fn test_parse_memory_kind() {
        let cases = [
            ("Fact", Some(MemoryKind::Fact)),
            ("PREFERENCE", Some(MemoryKind::Preference)),
            ("event", Some(MemoryKind::Event)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_memory_kind(input), expected);
        }
    }

    /// Polarity parsing is case-insensitive and defaults to `Neutral`.
    #[test]
    fn test_parse_polarity() {
        let cases = [
            ("Positive", Polarity::Positive),
            ("NEGATIVE", Polarity::Negative),
            ("Neutral", Polarity::Neutral),
            ("unknown", Polarity::Neutral),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_polarity(input), expected);
        }
    }
}
455}