1use std::fs;
13use std::io::Write;
14use std::path::Path;
15use std::time::{SystemTime, UNIX_EPOCH};
16
17use crate::thought_graph::{ThoughtGraphState, PATTERN_BOOST_WEIGHT};
18
/// Upper bound on stored predictions; `predict_next` keeps the most confident ones.
const MAX_PREDICTIONS: usize = 50;
/// Maximum number of a pattern's result blocks copied into a single prediction.
const MAX_BLOCKS_PER_PREDICTION: usize = 30;

/// Pattern-strength reward for a hit (scaled by the overlap fraction for a partial hit).
const HIT_REWARD: f32 = 0.3;
/// Pattern-strength penalty for a miss.
const MISS_PENALTY: f32 = 0.05;
/// Multiplier applied to every prediction's confidence on each `predict_next` call.
const PREDICTION_DECAY: f32 = 0.98;
/// Confidence below which `check` ignores a prediction and `predict_next` drops it.
const MIN_CONFIDENCE: f32 = 0.1;
/// Patterns seen fewer times than this are not used to make predictions.
const MIN_PATTERN_FREQ: u32 = 3;

/// A guess that the query with `predicted_query_hash` is coming next, together with
/// the blocks it is expected to return.
#[derive(Clone, Debug)]
pub struct Prediction {
    /// Hash of the anticipated query.
    pub predicted_query_hash: u64,
    /// Block ids expected in that query's results (copied from the source pattern).
    pub blocks: Vec<u32>,
    /// Confidence in this prediction; `check` ignores it below `MIN_CONFIDENCE`.
    pub confidence: f32,
    /// Id of the thought-graph pattern that produced this prediction.
    pub pattern_id: u32,
    /// Creation time in milliseconds since the Unix epoch.
    pub created_ms: u64,
}
40
/// Running counters for prediction outcomes, plus snapshot figures for the
/// predictions currently held by the cache.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Predictions created so far (including counts merged from remote peers).
    pub total_predictions: u32,
    /// Evaluations classified as a full hit.
    pub total_hits: u32,
    /// Evaluations in which no predicted block was returned.
    pub total_misses: u32,
    /// Evaluations with some overlap, but not enough to count as a hit.
    pub total_partial_hits: u32,
    /// Number of predictions held, as last recorded by the cache.
    pub current_predictions: usize,
    /// Mean confidence of the held predictions, as last recorded by the cache.
    pub avg_confidence: f32,
}

impl CacheStats {
    /// Fraction of evaluations that paid off, counting a partial hit as half a hit.
    ///
    /// Returns `0.0` when nothing has been evaluated yet.
    pub fn hit_rate(&self) -> f32 {
        // Sum in u64: three u32 counters (possibly inflated by merged remote stats)
        // can exceed u32::MAX together, which would panic in debug builds.
        let total = u64::from(self.total_hits)
            + u64::from(self.total_misses)
            + u64::from(self.total_partial_hits);
        if total == 0 {
            return 0.0;
        }
        (self.total_hits as f32 + self.total_partial_hits as f32 * 0.5) / total as f32
    }
}
62
63pub struct PredictiveCache {
66 pub predictions: Vec<Prediction>,
67 pub stats: CacheStats,
68}
69
70impl PredictiveCache {
71 pub fn load_or_init(output_dir: &Path) -> Self {
72 let path = output_dir.join("predictive_cache.bin");
73 if path.exists() {
74 load_cache(&path)
75 } else {
76 Self {
77 predictions: Vec::new(),
78 stats: CacheStats::default(),
79 }
80 }
81 }
82
83 pub fn check(&self, query_hash: u64) -> Option<(Vec<u32>, f32)> {
86 self.predictions
87 .iter()
88 .find(|p| p.predicted_query_hash == query_hash && p.confidence >= MIN_CONFIDENCE)
89 .map(|p| (p.blocks.clone(), p.confidence))
90 }
91
92 pub fn evaluate(
96 &mut self,
97 query_hash: u64,
98 actual_results: &[u32],
99 thought_graph: &mut ThoughtGraphState,
100 ) -> (&'static str, usize) {
101 let prediction = self
102 .predictions
103 .iter()
104 .find(|p| p.predicted_query_hash == query_hash);
105
106 let prediction = match prediction {
107 Some(p) => p.clone(),
108 None => return ("none", 0),
109 };
110
111 let overlap = prediction
113 .blocks
114 .iter()
115 .filter(|b| actual_results.contains(b))
116 .count();
117
118 let hit_type = if overlap == 0 {
119 self.stats.total_misses += 1;
121 if let Some(pattern) = thought_graph
123 .patterns
124 .iter_mut()
125 .find(|p| p.id == prediction.pattern_id)
126 {
127 pattern.strength = (pattern.strength - MISS_PENALTY).max(0.0);
128 }
129 if let Some(pred) = self
131 .predictions
132 .iter_mut()
133 .find(|p| p.predicted_query_hash == query_hash)
134 {
135 pred.confidence *= 0.5; }
137 "miss"
138 } else if overlap >= prediction.blocks.len() / 2 || overlap >= 3 {
139 self.stats.total_hits += 1;
141 if let Some(pattern) = thought_graph
143 .patterns
144 .iter_mut()
145 .find(|p| p.id == prediction.pattern_id)
146 {
147 pattern.strength = (pattern.strength + HIT_REWARD).min(5.0);
148 }
149 if let Some(pred) = self
151 .predictions
152 .iter_mut()
153 .find(|p| p.predicted_query_hash == query_hash)
154 {
155 pred.confidence = (pred.confidence + 0.2).min(1.0);
156 }
157 "hit"
158 } else {
159 self.stats.total_partial_hits += 1;
161 let reward = HIT_REWARD * (overlap as f32 / prediction.blocks.len() as f32);
162 if let Some(pattern) = thought_graph
163 .patterns
164 .iter_mut()
165 .find(|p| p.id == prediction.pattern_id)
166 {
167 pattern.strength = (pattern.strength + reward).min(5.0);
168 }
169 "partial"
170 };
171
172 (hit_type, overlap)
173 }
174
175 pub fn predict_next(&mut self, thought_graph: &ThoughtGraphState) {
178 for pred in &mut self.predictions {
180 pred.confidence *= PREDICTION_DECAY;
181 }
182 self.predictions.retain(|p| p.confidence >= MIN_CONFIDENCE);
183
184 let session_hashes: Vec<u64> = thought_graph
185 .nodes
186 .iter()
187 .filter(|n| n.session_id == thought_graph.current_session_id)
188 .map(|n| n.query_hash)
189 .collect();
190
191 if session_hashes.is_empty() {
192 return;
193 }
194
195 let now_ms = now_epoch_ms();
196
197 for pattern in &thought_graph.patterns {
198 if pattern.frequency < MIN_PATTERN_FREQ {
199 continue;
200 }
201 if pattern.result_blocks.is_empty() {
202 continue;
203 }
204
205 let seq = &pattern.sequence;
206
207 for prefix_len in 1..seq.len() {
210 if session_hashes.len() < prefix_len {
211 continue;
212 }
213
214 let trail_start = session_hashes.len() - prefix_len;
215 let trail = &session_hashes[trail_start..];
216
217 if trail == &seq[..prefix_len] {
218 let predicted_hash = seq[prefix_len];
219
220 if self
222 .predictions
223 .iter()
224 .any(|p| p.predicted_query_hash == predicted_hash)
225 {
226 continue;
227 }
228
229 let confidence = pattern.strength
230 * PATTERN_BOOST_WEIGHT
231 * (prefix_len as f32 / seq.len() as f32);
232
233 let blocks: Vec<u32> = pattern
234 .result_blocks
235 .iter()
236 .take(MAX_BLOCKS_PER_PREDICTION)
237 .copied()
238 .collect();
239
240 self.predictions.push(Prediction {
241 predicted_query_hash: predicted_hash,
242 blocks,
243 confidence: confidence.min(1.0),
244 pattern_id: pattern.id,
245 created_ms: now_ms,
246 });
247
248 self.stats.total_predictions += 1;
249 }
250 }
251 }
252
253 if self.predictions.len() > MAX_PREDICTIONS {
255 self.predictions
256 .sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
257 self.predictions.truncate(MAX_PREDICTIONS);
258 }
259
260 if !self.predictions.is_empty() {
262 self.stats.avg_confidence = self.predictions.iter().map(|p| p.confidence).sum::<f32>()
263 / self.predictions.len() as f32;
264 }
265 self.stats.current_predictions = self.predictions.len();
266 }
267
268 pub fn export_stats(&self) -> (u32, u32, u32, f32) {
270 (
271 self.stats.total_predictions,
272 self.stats.total_hits,
273 self.stats.total_misses,
274 self.stats.hit_rate(),
275 )
276 }
277
278 pub fn merge_stats(&mut self, remote_predictions: u32, remote_hits: u32, remote_misses: u32) {
280 self.stats.total_predictions += remote_predictions;
281 self.stats.total_hits += remote_hits;
282 self.stats.total_misses += remote_misses;
283 }
284
285 pub fn dream_cleanup(&mut self) {
287 self.predictions.retain(|p| p.confidence > 0.1);
288 self.stats.current_predictions = self.predictions.len();
289 }
290
291 pub fn save(&self, output_dir: &Path) -> Result<(), String> {
293 save_cache(&output_dir.join("predictive_cache.bin"), self)
294 }
295}
296
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reads earlier than the epoch.
fn now_epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
305
306fn save_cache(path: &Path, cache: &PredictiveCache) -> Result<(), String> {
307 let mut buf = Vec::with_capacity(256);
308
309 buf.write_all(b"PRC1").map_err(|e| e.to_string())?;
311 buf.write_all(&(cache.predictions.len() as u32).to_le_bytes())
312 .map_err(|e| e.to_string())?;
313
314 buf.write_all(&cache.stats.total_predictions.to_le_bytes())
316 .map_err(|e| e.to_string())?;
317 buf.write_all(&cache.stats.total_hits.to_le_bytes())
318 .map_err(|e| e.to_string())?;
319 buf.write_all(&cache.stats.total_misses.to_le_bytes())
320 .map_err(|e| e.to_string())?;
321 buf.write_all(&cache.stats.total_partial_hits.to_le_bytes())
322 .map_err(|e| e.to_string())?;
323
324 for p in &cache.predictions {
326 buf.write_all(&p.predicted_query_hash.to_le_bytes())
327 .map_err(|e| e.to_string())?;
328 buf.write_all(&p.confidence.to_le_bytes())
329 .map_err(|e| e.to_string())?;
330 buf.write_all(&p.pattern_id.to_le_bytes())
331 .map_err(|e| e.to_string())?;
332 buf.write_all(&p.created_ms.to_le_bytes())
333 .map_err(|e| e.to_string())?;
334 buf.write_all(&(p.blocks.len() as u16).to_le_bytes())
335 .map_err(|e| e.to_string())?;
336 for &b in &p.blocks {
337 buf.write_all(&b.to_le_bytes()).map_err(|e| e.to_string())?;
338 }
339 }
340
341 fs::write(path, &buf).map_err(|e| e.to_string())
342}
343
344fn load_cache(path: &Path) -> PredictiveCache {
345 let data = match fs::read(path) {
346 Ok(d) => d,
347 Err(_) => {
348 return PredictiveCache {
349 predictions: Vec::new(),
350 stats: CacheStats::default(),
351 }
352 }
353 };
354
355 if data.len() < 24 || &data[0..4] != b"PRC1" {
356 return PredictiveCache {
357 predictions: Vec::new(),
358 stats: CacheStats::default(),
359 };
360 }
361
362 let pred_count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
363
364 let total_predictions = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
365 let total_hits = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
366 let total_misses = u32::from_le_bytes([data[16], data[17], data[18], data[19]]);
367 let total_partial_hits = u32::from_le_bytes([data[20], data[21], data[22], data[23]]);
368
369 let mut offset = 24;
370 let mut predictions = Vec::with_capacity(pred_count);
371
372 for _ in 0..pred_count {
373 if offset + 22 > data.len() {
374 break;
375 }
376
377 let predicted_query_hash = u64::from_le_bytes([
378 data[offset],
379 data[offset + 1],
380 data[offset + 2],
381 data[offset + 3],
382 data[offset + 4],
383 data[offset + 5],
384 data[offset + 6],
385 data[offset + 7],
386 ]);
387 offset += 8;
388
389 let confidence = f32::from_le_bytes([
390 data[offset],
391 data[offset + 1],
392 data[offset + 2],
393 data[offset + 3],
394 ]);
395 offset += 4;
396
397 let pattern_id = u32::from_le_bytes([
398 data[offset],
399 data[offset + 1],
400 data[offset + 2],
401 data[offset + 3],
402 ]);
403 offset += 4;
404
405 let created_ms = u64::from_le_bytes([
406 data[offset],
407 data[offset + 1],
408 data[offset + 2],
409 data[offset + 3],
410 data[offset + 4],
411 data[offset + 5],
412 data[offset + 6],
413 data[offset + 7],
414 ]);
415 offset += 8;
416
417 if offset + 2 > data.len() {
418 break;
419 }
420 let block_count = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
421 offset += 2;
422
423 if offset + block_count * 4 > data.len() {
424 break;
425 }
426 let mut blocks = Vec::with_capacity(block_count);
427 for _ in 0..block_count {
428 let b = u32::from_le_bytes([
429 data[offset],
430 data[offset + 1],
431 data[offset + 2],
432 data[offset + 3],
433 ]);
434 blocks.push(b);
435 offset += 4;
436 }
437
438 predictions.push(Prediction {
439 predicted_query_hash,
440 blocks,
441 confidence,
442 pattern_id,
443 created_ms,
444 });
445 }
446
447 let current_predictions = predictions.len();
448 let avg_confidence = if predictions.is_empty() {
449 0.0
450 } else {
451 predictions.iter().map(|p| p.confidence).sum::<f32>() / predictions.len() as f32
452 };
453
454 PredictiveCache {
455 predictions,
456 stats: CacheStats {
457 total_predictions,
458 total_hits,
459 total_misses,
460 total_partial_hits,
461 current_predictions,
462 avg_confidence,
463 },
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thought_graph::{ThoughtGraphState, ThoughtPattern};

    /// Thought graph loaded from a directory that does not exist
    /// (presumably the freshly initialised state).
    fn make_tg() -> ThoughtGraphState {
        ThoughtGraphState::load_or_init(Path::new("/nonexistent"))
    }

    /// Cache holding `predictions`, with zeroed statistics.
    fn cache_with(predictions: Vec<Prediction>) -> PredictiveCache {
        PredictiveCache {
            predictions,
            stats: CacheStats::default(),
        }
    }

    /// Prediction attributed to pattern 0, created at time 0.
    fn prediction(hash: u64, blocks: Vec<u32>, confidence: f32) -> Prediction {
        Prediction {
            predicted_query_hash: hash,
            blocks,
            confidence,
            pattern_id: 0,
            created_ms: 0,
        }
    }

    /// Pattern 0, seen 5 times, whose recorded results are blocks 10, 20 and 30.
    fn pattern(sequence: Vec<u64>, strength: f32) -> ThoughtPattern {
        ThoughtPattern {
            id: 0,
            sequence,
            frequency: 5,
            strength,
            last_seen_ms: 0,
            result_blocks: vec![10, 20, 30],
        }
    }

    #[test]
    fn test_check_empty() {
        let cache = cache_with(Vec::new());
        assert!(cache.check(0xAA).is_none());
    }

    #[test]
    fn test_check_hit() {
        let cache = cache_with(vec![prediction(0xAA, vec![10, 20, 30], 0.8)]);
        let (blocks, conf) = cache
            .check(0xAA)
            .expect("a prediction for 0xAA should be found");
        assert_eq!(blocks, vec![10, 20, 30]);
        assert!((conf - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_evaluate_hit() {
        let mut cache = cache_with(vec![prediction(0xAA, vec![10, 20, 30], 0.5)]);
        let mut tg = make_tg();
        tg.patterns.push(pattern(vec![0xBB, 0xAA], 1.0));

        // Every predicted block is among the actual results.
        let (hit_type, overlap) = cache.evaluate(0xAA, &[10, 20, 30, 40], &mut tg);

        assert_eq!(hit_type, "hit");
        assert_eq!(overlap, 3);
        assert_eq!(cache.stats.total_hits, 1);
        assert!(tg.patterns[0].strength > 1.0);
    }

    #[test]
    fn test_evaluate_miss() {
        let mut cache = cache_with(vec![prediction(0xAA, vec![10, 20, 30], 0.5)]);
        let mut tg = make_tg();
        tg.patterns.push(pattern(vec![0xBB, 0xAA], 1.0));

        // None of the predicted blocks shows up.
        let (hit_type, overlap) = cache.evaluate(0xAA, &[100, 200, 300], &mut tg);

        assert_eq!(hit_type, "miss");
        assert_eq!(overlap, 0);
        assert_eq!(cache.stats.total_misses, 1);
        assert!(tg.patterns[0].strength < 1.0);
    }

    #[test]
    fn test_evaluate_no_prediction() {
        let mut cache = cache_with(Vec::new());
        let mut tg = make_tg();
        let (hit_type, _) = cache.evaluate(0xAA, &[10, 20], &mut tg);
        assert_eq!(hit_type, "none");
    }

    #[test]
    fn test_predict_next() {
        let mut cache = cache_with(Vec::new());
        let mut tg = make_tg();

        // The session has issued 0xAA; the learned pattern says 0xBB follows.
        tg.current_session_id = 1;
        tg.nodes.push(crate::thought_graph::ThoughtNode {
            timestamp_ms: 1000,
            query_hash: 0xAA,
            session_id: 1,
            result_count: 3,
            dominant_layer: 1,
            centroid_hash: 0,
        });
        let mut learned = pattern(vec![0xAA, 0xBB], 2.0);
        learned.last_seen_ms = 1000;
        tg.patterns.push(learned);

        cache.predict_next(&tg);

        assert_eq!(cache.predictions.len(), 1);
        let only = &cache.predictions[0];
        assert_eq!(only.predicted_query_hash, 0xBB);
        assert_eq!(only.blocks, vec![10, 20, 30]);
        assert!(only.confidence > 0.0);
    }

    #[test]
    fn test_predict_decay() {
        let mut cache = cache_with(vec![prediction(0xAA, vec![10], MIN_CONFIDENCE + 0.01)]);
        let tg = make_tg();

        // Repeated decay must push the prediction under MIN_CONFIDENCE.
        (0..20).for_each(|_| cache.predict_next(&tg));

        assert!(cache.predictions.is_empty());
    }

    #[test]
    fn test_hit_rate() {
        let mut stats = CacheStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.total_hits = 7;
        stats.total_misses = 3;
        assert!((stats.hit_rate() - 0.7).abs() < 0.001);

        // Partial hits weigh half: (7 + 2 * 0.5) / 12.
        stats.total_partial_hits = 2;
        assert!((stats.hit_rate() - 0.6667).abs() < 0.01);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();

        let original = PredictiveCache {
            predictions: vec![
                Prediction {
                    pattern_id: 1,
                    created_ms: 12345,
                    ..prediction(0xAA, vec![10, 20], 0.75)
                },
                Prediction {
                    pattern_id: 2,
                    created_ms: 67890,
                    ..prediction(0xBB, vec![30, 40, 50], 0.5)
                },
            ],
            stats: CacheStats {
                total_predictions: 10,
                total_hits: 5,
                total_misses: 3,
                total_partial_hits: 2,
                current_predictions: 2,
                avg_confidence: 0.625,
            },
        };

        original.save(dir.path()).unwrap();
        let loaded = PredictiveCache::load_or_init(dir.path());

        assert_eq!(loaded.predictions.len(), 2);
        assert_eq!(loaded.predictions[0].predicted_query_hash, 0xAA);
        assert_eq!(loaded.predictions[0].blocks, vec![10, 20]);
        assert!((loaded.predictions[0].confidence - 0.75).abs() < 0.001);
        assert_eq!(loaded.predictions[1].blocks, vec![30, 40, 50]);
        assert_eq!(loaded.stats.total_hits, 5);
        assert_eq!(loaded.stats.total_misses, 3);
        assert_eq!(loaded.stats.total_partial_hits, 2);
    }
}