// microscope_memory/attention.rs
1//! Attention Mechanism — dynamic weighting of consciousness layers.
2//!
3//! Computes an attention vector over 7 layers based on query context:
4//! query length, emotional energy, session depth, pattern confidence,
5//! cache hit rate, archetype match score.
6//!
7//! Tracks which attention distributions lead to good results (user stops
8//! searching) vs bad results (user immediately re-queries). Over time,
9//! learns which layer weights work best.
10//!
11//! Binary format: attention.bin (ATT1)
12
13use std::fs;
14use std::io::Write;
15use std::path::Path;
16use std::time::{SystemTime, UNIX_EPOCH};
17
// ─── Constants ──────────────────────────────────────

/// Layer indices: Hebbian, Mirror, Resonance, Archetype, Emotional, ThoughtGraph, PredictiveCache
pub const NUM_LAYERS: usize = 7;
/// Human-readable layer names, indexed in step with every weight array.
pub const LAYER_NAMES: [&str; NUM_LAYERS] = [
    "Hebbian",
    "Mirror",
    "Resonance",
    "Archetype",
    "Emotional",
    "ThoughtGraph",
    "PredictiveCache",
];

/// Maximum retained outcome records; oldest entries are dropped beyond this.
const MAX_HISTORY: usize = 200;
/// Step size for nudging learned weights toward good-outcome averages.
const LEARN_RATE: f32 = 0.05;
/// Outcomes with quality at or above this count as "good" for learning.
const QUALITY_GOOD_THRESHOLD: f32 = 0.7;
/// Outcomes with quality at or below this count as "bad" for learning.
const QUALITY_BAD_THRESHOLD: f32 = 0.3;

/// Time thresholds for quality inference (ms).
const SATISFIED_MS: u64 = 60_000; // >60s gap = satisfied
const UNSATISFIED_MS: u64 = 5_000; // <5s gap = unsatisfied
40
41// ─── AttentionSignals ───────────────────────────────
42
/// Input signals for computing attention weights.
pub struct AttentionSignals {
    /// Length of the query; <=10 is treated as a short/familiar lookup
    /// (units — chars vs. tokens — set by the caller; TODO confirm).
    pub query_length: usize,
    /// Emotional intensity of the context; boosts the Emotional layer.
    pub emotional_energy: f32,
    /// How deep into the session we are; saturates at 5 in weighting.
    pub session_depth: usize,
    /// Confidence of detected patterns; boosts ThoughtGraph.
    pub pattern_confidence: f32,
    /// Recent cache hit rate; boosts PredictiveCache.
    pub cache_hit_rate: f32,
    /// Strength of archetype match; boosts the Archetype layer.
    pub archetype_match_score: f32,
}
52
53// ─── AttentionVector ────────────────────────────────
54
/// Computed attention weights for a single recall.
#[derive(Clone, Debug)]
pub struct AttentionVector {
    /// One multiplicative weight per layer (see LAYER_NAMES for order);
    /// normalized so the average weight is 1.0.
    pub weights: [f32; NUM_LAYERS],
}
60
61impl AttentionVector {
62    /// Get weight for a specific layer.
63    pub fn weight(&self, layer: usize) -> f32 {
64        if layer < NUM_LAYERS {
65            self.weights[layer]
66        } else {
67            1.0
68        }
69    }
70}
71
72// ─── AttentionOutcome ───────────────────────────────
73
/// Recorded outcome for learning.
#[derive(Clone, Debug)]
pub struct AttentionOutcome {
    /// Attention weights that were in effect for the recall.
    pub weights: [f32; NUM_LAYERS],
    /// When the outcome was recorded (ms since Unix epoch).
    pub timestamp_ms: u64,
    /// Inferred result quality in roughly [0, 1] (0 = bad, 1 = good).
    pub quality: f32,
}

/// Serialized size of one outcome record in attention.bin.
const OUTCOME_BYTES: usize = NUM_LAYERS * 4 + 8 + 4; // 28 + 8 + 4 = 40
83
84// ─── AttentionState ─────────────────────────────────
85
pub struct AttentionState {
    /// Learned optimal weights (running average from good outcomes).
    pub learned_weights: [f32; NUM_LAYERS],
    /// Recent outcome history.
    pub history: Vec<AttentionOutcome>,
    /// Last recall timestamp for quality inference.
    pub last_recall_ms: u64,
    /// Total recalls tracked.
    pub total_recalls: u32,
}
96
97impl AttentionState {
98    pub fn load_or_init(output_dir: &Path) -> Self {
99        let path = output_dir.join("attention.bin");
100        if path.exists() {
101            load_attention(&path)
102        } else {
103            Self {
104                learned_weights: [1.0; NUM_LAYERS],
105                history: Vec::new(),
106                last_recall_ms: 0,
107                total_recalls: 0,
108            }
109        }
110    }
111
112    /// Compute attention vector from context signals.
113    pub fn compute_attention(&self, signals: &AttentionSignals) -> AttentionVector {
114        let mut raw = [1.0f32; NUM_LAYERS];
115
116        // Factor 1: query length
117        if signals.query_length <= 10 {
118            raw[0] *= 1.5; // Hebbian: short = familiar territory
119            raw[3] *= 1.3; // Archetype: short = concept lookup
120        } else {
121            raw[5] *= 1.3; // ThoughtGraph: long = complex reasoning chain
122        }
123
124        // Factor 2: emotional energy
125        raw[4] *= 1.0 + signals.emotional_energy.min(2.0);
126
127        // Factor 3: session depth
128        let depth_factor = (signals.session_depth as f32 / 5.0).min(1.0);
129        raw[5] *= 1.0 + depth_factor * 0.5; // ThoughtGraph benefits from deep sessions
130        raw[6] *= 1.0 + depth_factor * 0.3; // PredictiveCache needs session context
131
132        // Factor 4: pattern confidence
133        raw[5] *= 1.0 + signals.pattern_confidence;
134
135        // Factor 5: cache hit rate
136        raw[6] *= 1.0 + signals.cache_hit_rate;
137
138        // Factor 6: archetype match
139        raw[3] *= 1.0 + signals.archetype_match_score.min(2.0);
140
141        // Blend with learned weights (80% computed, 20% learned)
142        for (i, w) in raw.iter_mut().enumerate() {
143            *w = *w * 0.8 + self.learned_weights[i] * 0.2;
144        }
145
146        // Normalize so average weight = 1.0
147        let sum: f32 = raw.iter().sum();
148        if sum > 0.0 {
149            let scale = NUM_LAYERS as f32 / sum;
150            for w in &mut raw {
151                *w *= scale;
152            }
153        }
154
155        AttentionVector { weights: raw }
156    }
157
158    /// Infer quality of last recall from time gap.
159    /// Returns quality score (0.0 = bad, 1.0 = good).
160    pub fn infer_quality(&self) -> f32 {
161        if self.last_recall_ms == 0 {
162            return 0.5; // no data
163        }
164        let now = now_epoch_ms();
165        let gap = now.saturating_sub(self.last_recall_ms);
166
167        if gap >= SATISFIED_MS {
168            1.0
169        } else if gap <= UNSATISFIED_MS {
170            0.2
171        } else {
172            // Linear interpolation
173            let t = (gap - UNSATISFIED_MS) as f32 / (SATISFIED_MS - UNSATISFIED_MS) as f32;
174            0.2 + t * 0.8
175        }
176    }
177
178    /// Record the outcome of a recall (the quality applies to the PREVIOUS recall).
179    pub fn record_outcome(&mut self, quality: f32, weights: &[f32; NUM_LAYERS]) {
180        self.history.push(AttentionOutcome {
181            weights: *weights,
182            timestamp_ms: now_epoch_ms(),
183            quality,
184        });
185
186        if self.history.len() > MAX_HISTORY {
187            self.history.drain(0..(self.history.len() - MAX_HISTORY));
188        }
189
190        self.update_learned_weights();
191    }
192
193    /// Mark that a recall just happened (for next quality inference).
194    pub fn mark_recall(&mut self) {
195        self.last_recall_ms = now_epoch_ms();
196        self.total_recalls += 1;
197    }
198
199    /// Update learned weights from outcome history using EMA.
200    fn update_learned_weights(&mut self) {
201        let good: Vec<&AttentionOutcome> = self
202            .history
203            .iter()
204            .filter(|o| o.quality >= QUALITY_GOOD_THRESHOLD)
205            .collect();
206        let bad: Vec<&AttentionOutcome> = self
207            .history
208            .iter()
209            .filter(|o| o.quality <= QUALITY_BAD_THRESHOLD)
210            .collect();
211
212        if good.is_empty() && bad.is_empty() {
213            return;
214        }
215
216        for i in 0..NUM_LAYERS {
217            let good_avg = if good.is_empty() {
218                self.learned_weights[i]
219            } else {
220                good.iter().map(|o| o.weights[i]).sum::<f32>() / good.len() as f32
221            };
222            let bad_avg = if bad.is_empty() {
223                self.learned_weights[i]
224            } else {
225                bad.iter().map(|o| o.weights[i]).sum::<f32>() / bad.len() as f32
226            };
227
228            let delta = good_avg - bad_avg;
229            self.learned_weights[i] += delta * LEARN_RATE;
230            self.learned_weights[i] = self.learned_weights[i].clamp(0.1, 3.0);
231        }
232    }
233
234    /// Save to binary.
235    pub fn save(&self, output_dir: &Path) -> Result<(), String> {
236        save_attention(&output_dir.join("attention.bin"), self)
237    }
238}
239
240// ─── Binary I/O ─────────────────────────────────────
241
/// Current wall-clock time as milliseconds since the Unix epoch;
/// 0 if the system clock reads earlier than the epoch.
fn now_epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
248
249fn save_attention(path: &Path, state: &AttentionState) -> Result<(), String> {
250    let mut buf = Vec::with_capacity(48 + state.history.len() * OUTCOME_BYTES);
251
252    // Header
253    buf.write_all(b"ATT1").map_err(|e| e.to_string())?;
254    buf.write_all(&state.total_recalls.to_le_bytes())
255        .map_err(|e| e.to_string())?;
256    buf.write_all(&state.last_recall_ms.to_le_bytes())
257        .map_err(|e| e.to_string())?;
258
259    // Learned weights
260    for &w in &state.learned_weights {
261        buf.write_all(&w.to_le_bytes()).map_err(|e| e.to_string())?;
262    }
263
264    // History
265    buf.write_all(&(state.history.len() as u32).to_le_bytes())
266        .map_err(|e| e.to_string())?;
267
268    for outcome in &state.history {
269        for &w in &outcome.weights {
270            buf.write_all(&w.to_le_bytes()).map_err(|e| e.to_string())?;
271        }
272        buf.write_all(&outcome.timestamp_ms.to_le_bytes())
273            .map_err(|e| e.to_string())?;
274        buf.write_all(&outcome.quality.to_le_bytes())
275            .map_err(|e| e.to_string())?;
276    }
277
278    fs::write(path, &buf).map_err(|e| e.to_string())
279}
280
281fn load_attention(path: &Path) -> AttentionState {
282    let data = match fs::read(path) {
283        Ok(d) => d,
284        Err(_) => {
285            return AttentionState {
286                learned_weights: [1.0; NUM_LAYERS],
287                history: Vec::new(),
288                last_recall_ms: 0,
289                total_recalls: 0,
290            }
291        }
292    };
293
294    // Header: 4 (magic) + 4 (total_recalls) + 8 (last_recall_ms) + 28 (learned_weights) + 4 (history_count) = 48
295    if data.len() < 48 || &data[0..4] != b"ATT1" {
296        return AttentionState {
297            learned_weights: [1.0; NUM_LAYERS],
298            history: Vec::new(),
299            last_recall_ms: 0,
300            total_recalls: 0,
301        };
302    }
303
304    let total_recalls = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
305    let last_recall_ms = u64::from_le_bytes([
306        data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
307    ]);
308
309    let mut learned_weights = [0.0f32; NUM_LAYERS];
310    let mut offset = 16;
311    for w in &mut learned_weights {
312        *w = f32::from_le_bytes([
313            data[offset],
314            data[offset + 1],
315            data[offset + 2],
316            data[offset + 3],
317        ]);
318        offset += 4;
319    }
320
321    let history_count = u32::from_le_bytes([
322        data[offset],
323        data[offset + 1],
324        data[offset + 2],
325        data[offset + 3],
326    ]) as usize;
327    offset += 4;
328
329    let mut history = Vec::with_capacity(history_count);
330    for _ in 0..history_count {
331        if offset + OUTCOME_BYTES > data.len() {
332            break;
333        }
334
335        let mut weights = [0.0f32; NUM_LAYERS];
336        for w in &mut weights {
337            *w = f32::from_le_bytes([
338                data[offset],
339                data[offset + 1],
340                data[offset + 2],
341                data[offset + 3],
342            ]);
343            offset += 4;
344        }
345
346        let timestamp_ms = u64::from_le_bytes([
347            data[offset],
348            data[offset + 1],
349            data[offset + 2],
350            data[offset + 3],
351            data[offset + 4],
352            data[offset + 5],
353            data[offset + 6],
354            data[offset + 7],
355        ]);
356        offset += 8;
357
358        let quality = f32::from_le_bytes([
359            data[offset],
360            data[offset + 1],
361            data[offset + 2],
362            data[offset + 3],
363        ]);
364        offset += 4;
365
366        history.push(AttentionOutcome {
367            weights,
368            timestamp_ms,
369            quality,
370        });
371    }
372
373    AttentionState {
374        learned_weights,
375        history,
376        last_recall_ms,
377        total_recalls,
378    }
379}
380
381// ─── Tests ──────────────────────────────────────────
382
#[cfg(test)]
mod tests {
    use super::*;

    // Neutral baseline: long-ish query, all boost signals zeroed.
    fn default_signals() -> AttentionSignals {
        AttentionSignals {
            query_length: 20,
            emotional_energy: 0.0,
            session_depth: 0,
            pattern_confidence: 0.0,
            cache_hit_rate: 0.0,
            archetype_match_score: 0.0,
        }
    }

    #[test]
    fn test_compute_attention_short_query() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        };
        let mut signals = default_signals();
        signals.query_length = 5; // short

        let attn = state.compute_attention(&signals);
        // Hebbian (0) and Archetype (3) should be elevated
        assert!(attn.weights[0] > attn.weights[1]); // Hebbian > Mirror
        assert!(attn.weights[3] > attn.weights[1]); // Archetype > Mirror
    }

    #[test]
    fn test_compute_attention_long_query() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        };
        let mut signals = default_signals();
        signals.query_length = 50; // long

        let attn = state.compute_attention(&signals);
        // ThoughtGraph (5) should be elevated
        assert!(attn.weights[5] > attn.weights[1]); // ThoughtGraph > Mirror
    }

    #[test]
    fn test_compute_attention_high_emotion() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        };
        let mut signals = default_signals();
        signals.emotional_energy = 2.0;

        let attn = state.compute_attention(&signals);
        // Emotional (4) should be highest
        assert!(attn.weights[4] > attn.weights[0]);
    }

    #[test]
    fn test_compute_attention_deep_session() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        };
        let mut signals = default_signals();
        signals.session_depth = 10;

        let attn = state.compute_attention(&signals);
        // ThoughtGraph and PredictiveCache should be elevated
        assert!(attn.weights[5] > attn.weights[1]);
        assert!(attn.weights[6] > attn.weights[1]);
    }

    // Weights must always sum to NUM_LAYERS (average of 1.0 per layer).
    #[test]
    fn test_normalization() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recals: 0,
        };
        let signals = default_signals();
        let attn = state.compute_attention(&signals);

        let sum: f32 = attn.weights.iter().sum();
        assert!((sum - NUM_LAYERS as f32).abs() < 0.01);
    }

    // A re-query 2s after a recall is well inside UNSATISFIED_MS.
    #[test]
    fn test_quality_inference_fast_requery() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: now_epoch_ms() - 2_000, // 2s ago
            total_recalls: 1,
        };
        let q = state.infer_quality();
        assert!(q < 0.3); // unsatisfied
    }

    // A 2-minute gap exceeds SATISFIED_MS, so quality saturates at 1.0.
    #[test]
    fn test_quality_inference_satisfied() {
        let state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: now_epoch_ms() - 120_000, // 2 min ago
            total_recalls: 1,
        };
        let q = state.infer_quality();
        assert!((q - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_learned_weights_update() {
        let mut state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        };

        // Record some good outcomes with high Hebbian weight
        let mut good_weights = [1.0f32; NUM_LAYERS];
        good_weights[0] = 2.0; // Hebbian high
        for _ in 0..5 {
            state.record_outcome(0.9, &good_weights);
        }

        // Record some bad outcomes with high Mirror weight
        let mut bad_weights = [1.0f32; NUM_LAYERS];
        bad_weights[1] = 2.0; // Mirror high
        for _ in 0..5 {
            state.record_outcome(0.1, &bad_weights);
        }

        // Learned weights should now favor Hebbian over Mirror
        assert!(state.learned_weights[0] > state.learned_weights[1]);
    }

    // Round-trips header fields, learned weights, and one history record
    // through save_attention/load_attention via the public API.
    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let mut state = AttentionState {
            learned_weights: [1.1, 0.9, 1.0, 1.3, 0.8, 1.2, 0.7],
            history: Vec::new(),
            last_recall_ms: 12345678,
            total_recalls: 42,
        };
        state.history.push(AttentionOutcome {
            weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            timestamp_ms: 999,
            quality: 0.75,
        });

        state.save(dir.path()).unwrap();
        let loaded = AttentionState::load_or_init(dir.path());

        assert_eq!(loaded.total_recalls, 42);
        assert_eq!(loaded.last_recall_ms, 12345678);
        assert!((loaded.learned_weights[0] - 1.1).abs() < 0.001);
        assert!((loaded.learned_weights[6] - 0.7).abs() < 0.001);
        assert_eq!(loaded.history.len(), 1);
        assert!((loaded.history[0].quality - 0.75).abs() < 0.001);
    }

    // History must never grow beyond MAX_HISTORY entries.
    #[test]
    fn test_history_cap() {
        let mut state = AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        };
        let weights = [1.0; NUM_LAYERS];
        for _ in 0..MAX_HISTORY + 50 {
            state.record_outcome(0.5, &weights);
        }
        assert!(state.history.len() <= MAX_HISTORY);
    }
}