1use std::fs;
14use std::io::Write;
15use std::path::Path;
16use std::time::{SystemTime, UNIX_EPOCH};
17
/// How many memory layers take part in attention weighting; every per-layer
/// weight array in this module has exactly this many entries.
pub const NUM_LAYERS: usize = 7;

/// Human-readable names of the layers. Presumably position `i` names the layer
/// whose weight sits at index `i` of the per-layer weight arrays — the index
/// boosts in `compute_attention` line up with this ordering.
pub const LAYER_NAMES: [&str; NUM_LAYERS] = [
    "Hebbian",         // 0
    "Mirror",          // 1
    "Resonance",       // 2
    "Archetype",       // 3
    "Emotional",       // 4
    "ThoughtGraph",    // 5
    "PredictiveCache", // 6
];
31
/// Upper bound on retained outcomes; older entries are dropped first once the
/// history grows past this length.
const MAX_HISTORY: usize = 200;

/// Step size applied to the good-minus-bad weight difference on every
/// learned-weight update.
const LEARN_RATE: f32 = 0.05;

/// An outcome whose quality is at least this value counts as "good".
const QUALITY_GOOD_THRESHOLD: f32 = 0.7;

/// An outcome whose quality is at most this value counts as "bad".
const QUALITY_BAD_THRESHOLD: f32 = 0.3;
36
/// Gap between consecutive recalls (ms) at or beyond which the earlier recall
/// is inferred to have been fully satisfying (quality 1.0).
const SATISFIED_MS: u64 = 60_000;

/// Gap between consecutive recalls (ms) at or below which the earlier recall
/// is inferred to have been unsatisfying (quality 0.2).
const UNSATISFIED_MS: u64 = 5_000;

/// Per-query signals used by `AttentionState::compute_attention` to bias the
/// per-layer weights.
///
/// `Default` yields all-zero signals, i.e. a short query (`query_length == 0`)
/// with no other boosts.
#[derive(Clone, Debug, Default)]
pub struct AttentionSignals {
    /// Query length; values `<= 10` are treated as a short query.
    /// NOTE(review): unit (chars vs tokens) is decided by the caller — confirm.
    pub query_length: usize,
    /// Boosts the Emotional layer by `1 + value`; capped at 2.0 from above only.
    pub emotional_energy: f32,
    /// Session depth; its effect saturates once it reaches 5.
    pub session_depth: usize,
    /// Boosts the ThoughtGraph layer by `1 + value` (not clamped).
    pub pattern_confidence: f32,
    /// Boosts the PredictiveCache layer by `1 + value` (not clamped).
    pub cache_hit_rate: f32,
    /// Boosts the Archetype layer by `1 + value`; capped at 2.0 from above only.
    pub archetype_match_score: f32,
}
52
53#[derive(Clone, Debug)]
57pub struct AttentionVector {
58 pub weights: [f32; NUM_LAYERS],
59}
60
61impl AttentionVector {
62 pub fn weight(&self, layer: usize) -> f32 {
64 if layer < NUM_LAYERS {
65 self.weights[layer]
66 } else {
67 1.0
68 }
69 }
70}
71
72#[derive(Clone, Debug)]
76pub struct AttentionOutcome {
77 pub weights: [f32; NUM_LAYERS],
78 pub timestamp_ms: u64,
79 pub quality: f32,
80}
81
82const OUTCOME_BYTES: usize = NUM_LAYERS * 4 + 8 + 4; pub struct AttentionState {
87 pub learned_weights: [f32; NUM_LAYERS],
89 pub history: Vec<AttentionOutcome>,
91 pub last_recall_ms: u64,
93 pub total_recalls: u32,
95}
96
97impl AttentionState {
98 pub fn load_or_init(output_dir: &Path) -> Self {
99 let path = output_dir.join("attention.bin");
100 if path.exists() {
101 load_attention(&path)
102 } else {
103 Self {
104 learned_weights: [1.0; NUM_LAYERS],
105 history: Vec::new(),
106 last_recall_ms: 0,
107 total_recalls: 0,
108 }
109 }
110 }
111
112 pub fn compute_attention(&self, signals: &AttentionSignals) -> AttentionVector {
114 let mut raw = [1.0f32; NUM_LAYERS];
115
116 if signals.query_length <= 10 {
118 raw[0] *= 1.5; raw[3] *= 1.3; } else {
121 raw[5] *= 1.3; }
123
124 raw[4] *= 1.0 + signals.emotional_energy.min(2.0);
126
127 let depth_factor = (signals.session_depth as f32 / 5.0).min(1.0);
129 raw[5] *= 1.0 + depth_factor * 0.5; raw[6] *= 1.0 + depth_factor * 0.3; raw[5] *= 1.0 + signals.pattern_confidence;
134
135 raw[6] *= 1.0 + signals.cache_hit_rate;
137
138 raw[3] *= 1.0 + signals.archetype_match_score.min(2.0);
140
141 for (i, w) in raw.iter_mut().enumerate() {
143 *w = *w * 0.8 + self.learned_weights[i] * 0.2;
144 }
145
146 let sum: f32 = raw.iter().sum();
148 if sum > 0.0 {
149 let scale = NUM_LAYERS as f32 / sum;
150 for w in &mut raw {
151 *w *= scale;
152 }
153 }
154
155 AttentionVector { weights: raw }
156 }
157
158 pub fn infer_quality(&self) -> f32 {
161 if self.last_recall_ms == 0 {
162 return 0.5; }
164 let now = now_epoch_ms();
165 let gap = now.saturating_sub(self.last_recall_ms);
166
167 if gap >= SATISFIED_MS {
168 1.0
169 } else if gap <= UNSATISFIED_MS {
170 0.2
171 } else {
172 let t = (gap - UNSATISFIED_MS) as f32 / (SATISFIED_MS - UNSATISFIED_MS) as f32;
174 0.2 + t * 0.8
175 }
176 }
177
178 pub fn record_outcome(&mut self, quality: f32, weights: &[f32; NUM_LAYERS]) {
180 self.history.push(AttentionOutcome {
181 weights: *weights,
182 timestamp_ms: now_epoch_ms(),
183 quality,
184 });
185
186 if self.history.len() > MAX_HISTORY {
187 self.history.drain(0..(self.history.len() - MAX_HISTORY));
188 }
189
190 self.update_learned_weights();
191 }
192
193 pub fn mark_recall(&mut self) {
195 self.last_recall_ms = now_epoch_ms();
196 self.total_recalls += 1;
197 }
198
199 fn update_learned_weights(&mut self) {
201 let good: Vec<&AttentionOutcome> = self
202 .history
203 .iter()
204 .filter(|o| o.quality >= QUALITY_GOOD_THRESHOLD)
205 .collect();
206 let bad: Vec<&AttentionOutcome> = self
207 .history
208 .iter()
209 .filter(|o| o.quality <= QUALITY_BAD_THRESHOLD)
210 .collect();
211
212 if good.is_empty() && bad.is_empty() {
213 return;
214 }
215
216 for i in 0..NUM_LAYERS {
217 let good_avg = if good.is_empty() {
218 self.learned_weights[i]
219 } else {
220 good.iter().map(|o| o.weights[i]).sum::<f32>() / good.len() as f32
221 };
222 let bad_avg = if bad.is_empty() {
223 self.learned_weights[i]
224 } else {
225 bad.iter().map(|o| o.weights[i]).sum::<f32>() / bad.len() as f32
226 };
227
228 let delta = good_avg - bad_avg;
229 self.learned_weights[i] += delta * LEARN_RATE;
230 self.learned_weights[i] = self.learned_weights[i].clamp(0.1, 3.0);
231 }
232 }
233
234 pub fn save(&self, output_dir: &Path) -> Result<(), String> {
236 save_attention(&output_dir.join("attention.bin"), self)
237 }
238}
239
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error; the
/// millisecond count is narrowed from `u128` to `u64` with `as`.
fn now_epoch_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
248
249fn save_attention(path: &Path, state: &AttentionState) -> Result<(), String> {
250 let mut buf = Vec::with_capacity(48 + state.history.len() * OUTCOME_BYTES);
251
252 buf.write_all(b"ATT1").map_err(|e| e.to_string())?;
254 buf.write_all(&state.total_recalls.to_le_bytes())
255 .map_err(|e| e.to_string())?;
256 buf.write_all(&state.last_recall_ms.to_le_bytes())
257 .map_err(|e| e.to_string())?;
258
259 for &w in &state.learned_weights {
261 buf.write_all(&w.to_le_bytes()).map_err(|e| e.to_string())?;
262 }
263
264 buf.write_all(&(state.history.len() as u32).to_le_bytes())
266 .map_err(|e| e.to_string())?;
267
268 for outcome in &state.history {
269 for &w in &outcome.weights {
270 buf.write_all(&w.to_le_bytes()).map_err(|e| e.to_string())?;
271 }
272 buf.write_all(&outcome.timestamp_ms.to_le_bytes())
273 .map_err(|e| e.to_string())?;
274 buf.write_all(&outcome.quality.to_le_bytes())
275 .map_err(|e| e.to_string())?;
276 }
277
278 fs::write(path, &buf).map_err(|e| e.to_string())
279}
280
281fn load_attention(path: &Path) -> AttentionState {
282 let data = match fs::read(path) {
283 Ok(d) => d,
284 Err(_) => {
285 return AttentionState {
286 learned_weights: [1.0; NUM_LAYERS],
287 history: Vec::new(),
288 last_recall_ms: 0,
289 total_recalls: 0,
290 }
291 }
292 };
293
294 if data.len() < 48 || &data[0..4] != b"ATT1" {
296 return AttentionState {
297 learned_weights: [1.0; NUM_LAYERS],
298 history: Vec::new(),
299 last_recall_ms: 0,
300 total_recalls: 0,
301 };
302 }
303
304 let total_recalls = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
305 let last_recall_ms = u64::from_le_bytes([
306 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
307 ]);
308
309 let mut learned_weights = [0.0f32; NUM_LAYERS];
310 let mut offset = 16;
311 for w in &mut learned_weights {
312 *w = f32::from_le_bytes([
313 data[offset],
314 data[offset + 1],
315 data[offset + 2],
316 data[offset + 3],
317 ]);
318 offset += 4;
319 }
320
321 let history_count = u32::from_le_bytes([
322 data[offset],
323 data[offset + 1],
324 data[offset + 2],
325 data[offset + 3],
326 ]) as usize;
327 offset += 4;
328
329 let mut history = Vec::with_capacity(history_count);
330 for _ in 0..history_count {
331 if offset + OUTCOME_BYTES > data.len() {
332 break;
333 }
334
335 let mut weights = [0.0f32; NUM_LAYERS];
336 for w in &mut weights {
337 *w = f32::from_le_bytes([
338 data[offset],
339 data[offset + 1],
340 data[offset + 2],
341 data[offset + 3],
342 ]);
343 offset += 4;
344 }
345
346 let timestamp_ms = u64::from_le_bytes([
347 data[offset],
348 data[offset + 1],
349 data[offset + 2],
350 data[offset + 3],
351 data[offset + 4],
352 data[offset + 5],
353 data[offset + 6],
354 data[offset + 7],
355 ]);
356 offset += 8;
357
358 let quality = f32::from_le_bytes([
359 data[offset],
360 data[offset + 1],
361 data[offset + 2],
362 data[offset + 3],
363 ]);
364 offset += 4;
365
366 history.push(AttentionOutcome {
367 weights,
368 timestamp_ms,
369 quality,
370 });
371 }
372
373 AttentionState {
374 learned_weights,
375 history,
376 last_recall_ms,
377 total_recalls,
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Neutral state: unit learned weights, empty history, no recorded recall.
    fn blank_state() -> AttentionState {
        AttentionState {
            learned_weights: [1.0; NUM_LAYERS],
            history: Vec::new(),
            last_recall_ms: 0,
            total_recalls: 0,
        }
    }

    /// A long (non-short) query with every other signal at zero.
    fn default_signals() -> AttentionSignals {
        AttentionSignals {
            query_length: 20,
            emotional_energy: 0.0,
            session_depth: 0,
            pattern_confidence: 0.0,
            cache_hit_rate: 0.0,
            archetype_match_score: 0.0,
        }
    }

    #[test]
    fn test_compute_attention_short_query() {
        let signals = AttentionSignals {
            query_length: 5,
            ..default_signals()
        };
        let attn = blank_state().compute_attention(&signals);
        // Short queries lift Hebbian (0) and Archetype (3) above Mirror (1).
        assert!(attn.weights[0] > attn.weights[1]);
        assert!(attn.weights[3] > attn.weights[1]);
    }

    #[test]
    fn test_compute_attention_long_query() {
        let signals = AttentionSignals {
            query_length: 50,
            ..default_signals()
        };
        let attn = blank_state().compute_attention(&signals);
        // Long queries lift ThoughtGraph (5) above Mirror (1).
        assert!(attn.weights[5] > attn.weights[1]);
    }

    #[test]
    fn test_compute_attention_high_emotion() {
        let signals = AttentionSignals {
            emotional_energy: 2.0,
            ..default_signals()
        };
        let attn = blank_state().compute_attention(&signals);
        assert!(attn.weights[4] > attn.weights[0]);
    }

    #[test]
    fn test_compute_attention_deep_session() {
        let signals = AttentionSignals {
            session_depth: 10,
            ..default_signals()
        };
        let attn = blank_state().compute_attention(&signals);
        assert!(attn.weights[5] > attn.weights[1]);
        assert!(attn.weights[6] > attn.weights[1]);
    }

    #[test]
    fn test_normalization() {
        let attn = blank_state().compute_attention(&default_signals());
        let total: f32 = attn.weights.iter().sum();
        assert!((total - NUM_LAYERS as f32).abs() < 0.01);
    }

    #[test]
    fn test_quality_inference_fast_requery() {
        let state = AttentionState {
            last_recall_ms: now_epoch_ms() - 2_000,
            total_recalls: 1,
            ..blank_state()
        };
        // A re-query within UNSATISFIED_MS scores as a poor result.
        assert!(state.infer_quality() < 0.3);
    }

    #[test]
    fn test_quality_inference_satisfied() {
        let state = AttentionState {
            last_recall_ms: now_epoch_ms() - 120_000,
            total_recalls: 1,
            ..blank_state()
        };
        assert!((state.infer_quality() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_learned_weights_update() {
        let mut state = blank_state();

        let mut good_weights = [1.0f32; NUM_LAYERS];
        good_weights[0] = 2.0;
        (0..5).for_each(|_| state.record_outcome(0.9, &good_weights));

        let mut bad_weights = [1.0f32; NUM_LAYERS];
        bad_weights[1] = 2.0;
        (0..5).for_each(|_| state.record_outcome(0.1, &bad_weights));

        assert!(state.learned_weights[0] > state.learned_weights[1]);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let state = AttentionState {
            learned_weights: [1.1, 0.9, 1.0, 1.3, 0.8, 1.2, 0.7],
            history: vec![AttentionOutcome {
                weights: [1.0; NUM_LAYERS],
                timestamp_ms: 999,
                quality: 0.75,
            }],
            last_recall_ms: 12345678,
            total_recalls: 42,
        };

        state.save(dir.path()).unwrap();
        let loaded = AttentionState::load_or_init(dir.path());

        assert_eq!(loaded.total_recalls, 42);
        assert_eq!(loaded.last_recall_ms, 12345678);
        assert!((loaded.learned_weights[0] - 1.1).abs() < 0.001);
        assert!((loaded.learned_weights[6] - 0.7).abs() < 0.001);
        assert_eq!(loaded.history.len(), 1);
        assert!((loaded.history[0].quality - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_history_cap() {
        let mut state = blank_state();
        let weights = [1.0; NUM_LAYERS];
        for _ in 0..MAX_HISTORY + 50 {
            state.record_outcome(0.5, &weights);
        }
        assert!(state.history.len() <= MAX_HISTORY);
    }
}
570}