1use std::sync::atomic::{AtomicU64, Ordering};
14
15use crossbeam_queue::SegQueue;
16
17use hirn_core::id::MemoryId;
18use hirn_core::timestamp::Timestamp;
19
20use crate::graph::PropertyGraph;
21
/// Tunable parameters for [`hebbian_update`].
#[derive(Clone, Debug)]
pub struct HebbianConfig {
    /// Amount added to an edge's weight each time both of its endpoints appear in the
    /// same retrieval. The strengthened weight is capped at `1.0`.
    pub learning_rate: f64,
    /// Base fraction of its weight an edge loses when only one of its endpoints is
    /// retrieved. The effective rate is this value scaled by a per-relation multiplier.
    pub decay_rate: f64,
    /// Floor applied to an edge's weight after a decay step.
    pub min_weight: f32,
}
32
33impl Default for HebbianConfig {
34 fn default() -> Self {
35 Self {
36 learning_rate: 0.05,
37 decay_rate: 0.01,
38 min_weight: 0.01,
39 }
40 }
41}
42
/// Tally of the edges touched by one Hebbian update (or by a whole buffer flush).
///
/// `Default` yields the all-zero tally, and `PartialEq`/`Eq` let callers and tests
/// compare whole results instead of individual fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HebbianUpdateResult {
    /// Number of edges whose weight was increased because both endpoints were retrieved.
    pub strengthened: usize,
    /// Number of edges that went through a decay step because only one endpoint was
    /// retrieved.
    pub decayed: usize,
}
51
52pub fn hebbian_update(
57 graph: &mut PropertyGraph,
58 retrieved_ids: &[MemoryId],
59 config: &HebbianConfig,
60) -> HebbianUpdateResult {
61 let mut strengthened = 0;
62 let mut decayed = 0;
63
64 let retrieved_set: std::collections::HashSet<MemoryId> =
65 retrieved_ids.iter().copied().collect();
66
67 let now = Timestamp::now();
68
69 let mut co_retrieval_edges = Vec::new();
71 let mut decay_edges = Vec::new();
72
73 for &node_id in retrieved_ids {
74 for edge in graph.get_edges(node_id) {
75 let partner = if edge.source == node_id {
76 edge.target
77 } else {
78 edge.source
79 };
80
81 if retrieved_set.contains(&partner) {
82 co_retrieval_edges.push(edge.id);
84 } else {
85 decay_edges.push(edge.id);
87 }
88 }
89 }
90
91 co_retrieval_edges.sort();
93 co_retrieval_edges.dedup();
94 decay_edges.sort();
95 decay_edges.dedup();
96
97 let co_retrieval_set: std::collections::HashSet<crate::graph::EdgeId> =
100 co_retrieval_edges.iter().copied().collect();
101 decay_edges.retain(|eid| !co_retrieval_set.contains(eid));
102
103 let eta = config.learning_rate;
105 for eid in co_retrieval_edges {
106 if let Some(edge) = graph.edge_mut(eid) {
107 let delta = 1.0; let new_weight = eta.mul_add(delta, f64::from(edge.weight)).min(1.0);
109 edge.weight = new_weight as f32;
110 edge.co_retrieval_count += 1;
111 edge.updated_at = now;
112 strengthened += 1;
113 }
114 }
115
116 let base_lambda = config.decay_rate;
120 let min_w = config.min_weight;
121 for eid in decay_edges {
122 if let Some(edge) = graph.edge_mut(eid) {
123 let relation_multiplier = decay_multiplier_for_relation(edge.relation);
124 let lambda = base_lambda * relation_multiplier;
125 let new_weight = (f64::from(edge.weight) * (1.0 - lambda)).max(f64::from(min_w));
126 edge.weight = new_weight as f32;
127 edge.updated_at = now;
128 decayed += 1;
129 }
130 }
131
132 HebbianUpdateResult {
133 strengthened,
134 decayed,
135 }
136}
137
138const fn decay_multiplier_for_relation(relation: hirn_core::types::EdgeRelation) -> f64 {
141 use hirn_core::types::EdgeRelation;
142 match relation {
143 EdgeRelation::Causes | EdgeRelation::CausedBy | EdgeRelation::DerivedFrom => 0.2,
145 EdgeRelation::TemporalNext => 0.3,
147 EdgeRelation::SimilarTo => 0.5,
149 EdgeRelation::Contradicts => 0.1,
151 EdgeRelation::Supports
153 | EdgeRelation::PartOf
154 | EdgeRelation::InstanceOf
155 | EdgeRelation::ParticipatesIn => 0.4,
156 EdgeRelation::Inhibits => 0.6,
158 EdgeRelation::RelatedTo => 1.0,
160 }
161}
162
/// Number of pushes after which a buffer built with `HebbianBuffer::new` starts
/// reporting from `push` that a flush is due.
const DEFAULT_FLUSH_THRESHOLD: u64 = 16;
167
168pub struct HebbianBuffer {
174 queue: SegQueue<Vec<MemoryId>>,
175 push_count: AtomicU64,
176 flush_threshold: u64,
177}
178
179impl HebbianBuffer {
180 #[must_use]
182 pub fn new() -> Self {
183 Self {
184 queue: SegQueue::new(),
185 push_count: AtomicU64::new(0),
186 flush_threshold: DEFAULT_FLUSH_THRESHOLD,
187 }
188 }
189
190 #[must_use]
192 pub fn with_threshold(threshold: u64) -> Self {
193 Self {
194 queue: SegQueue::new(),
195 push_count: AtomicU64::new(0),
196 flush_threshold: threshold,
197 }
198 }
199
200 pub fn push(&self, retrieved_ids: Vec<MemoryId>) -> bool {
205 self.queue.push(retrieved_ids);
206 let count = self.push_count.fetch_add(1, Ordering::Relaxed) + 1;
207 count >= self.flush_threshold
208 }
209
210 pub fn flush(&self, graph: &mut PropertyGraph, config: &HebbianConfig) -> HebbianUpdateResult {
214 self.push_count.store(0, Ordering::Relaxed);
215
216 let mut total = HebbianUpdateResult {
217 strengthened: 0,
218 decayed: 0,
219 };
220
221 while let Some(ids) = self.queue.pop() {
222 let result = hebbian_update(graph, &ids, config);
223 total.strengthened += result.strengthened;
224 total.decayed += result.decayed;
225 }
226
227 total
228 }
229
230 pub fn pending_count(&self) -> u64 {
232 self.push_count.load(Ordering::Relaxed)
233 }
234
235 pub fn pop(&self) -> Option<Vec<MemoryId>> {
237 self.queue.pop()
238 }
239
240 pub fn reset_counter(&self) {
242 self.push_count.store(0, Ordering::Relaxed);
243 }
244}
245
246impl Default for HebbianBuffer {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
#[cfg(test)]
mod tests {
    //! Unit tests for `hebbian_update` and `HebbianBuffer`.

    use super::*;
    use hirn_core::metadata::Metadata;
    use hirn_core::timestamp::Timestamp;
    use hirn_core::types::{EdgeRelation, Layer};

    /// Inserts a fresh episodic node into `pg` and returns its id.
    fn make_node(pg: &mut PropertyGraph) -> MemoryId {
        let id = MemoryId::new();
        pg.add_node(id, Layer::Episodic, 0.5, Timestamp::now());
        id
    }

    /// Builds a graph holding two nodes and no edges.
    fn two_nodes() -> (PropertyGraph, MemoryId, MemoryId) {
        let mut pg = PropertyGraph::new();
        let a = make_node(&mut pg);
        let b = make_node(&mut pg);
        (pg, a, b)
    }

    /// Connects every node in `from` to every node in `to` with a `Causes` edge of
    /// weight 0.5, panicking if any insertion fails so a broken fixture is reported at
    /// its source rather than by a later assertion.
    fn link_all(pg: &mut PropertyGraph, from: &[MemoryId], to: &[MemoryId]) {
        for &src in from {
            for &dst in to {
                pg.add_edge(src, dst, EdgeRelation::Causes, 0.5, Metadata::new())
                    .unwrap();
            }
        }
    }

    #[test]
    fn co_retrieval_strengthens_edge() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();

        let initial_weight = pg.get_edges(a)[0].weight;

        for _ in 0..10 {
            hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());
        }

        let final_weight = pg.get_edges(a)[0].weight;
        assert!(
            final_weight > initial_weight,
            "co-retrieval should strengthen: initial={initial_weight}, final={final_weight}"
        );
    }

    #[test]
    fn solo_retrieval_decays_edge() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();

        let initial_weight = pg.get_edges(a)[0].weight;

        // Only `a` is retrieved, so the a-b edge is on the decay path every round.
        for _ in 0..100 {
            hebbian_update(&mut pg, &[a], &HebbianConfig::default());
        }

        let final_weight = pg.get_edges(a)[0].weight;
        assert!(
            final_weight < initial_weight,
            "solo retrieval should decay: initial={initial_weight}, final={final_weight}"
        );
    }

    #[test]
    fn co_retrieval_count_incremented() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();

        hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());
        hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());
        hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());

        let count = pg.get_edges(a)[0].co_retrieval_count;
        assert_eq!(count, 3, "co_retrieval_count should be 3, got {count}");
    }

    #[test]
    fn weight_never_exceeds_one() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.95, Metadata::new())
            .unwrap();

        // A large step from a weight near the cap would overshoot without clamping.
        let cfg = HebbianConfig {
            learning_rate: 0.5,
            ..Default::default()
        };

        for _ in 0..1000 {
            hebbian_update(&mut pg, &[a, b], &cfg);
        }

        let w = pg.get_edges(a)[0].weight;
        assert!(w <= 1.0, "weight exceeded 1.0: {w}");
    }

    #[test]
    fn weight_never_below_min() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.1, Metadata::new())
            .unwrap();

        // An aggressive decay rate drives the weight to the floor quickly.
        let cfg = HebbianConfig {
            decay_rate: 0.5,
            min_weight: 0.01,
            ..Default::default()
        };

        for _ in 0..1000 {
            hebbian_update(&mut pg, &[a], &cfg);
        }

        let w = pg.get_edges(a)[0].weight;
        assert!(w >= 0.01, "weight fell below min_weight 0.01: {w}");
    }

    #[test]
    // The cluster_a..cluster_d and weight_ab/weight_ac names are intentionally parallel.
    #[allow(clippy::similar_names)]
    fn self_organizing_clusters() {
        let mut pg = PropertyGraph::new();

        let cluster_a: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();
        let cluster_b: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();
        let cluster_c: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();
        let cluster_d: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();

        // A-B and C-D are the pairs that will be retrieved together; A-C is the
        // cross-link that is only ever retrieved from one side.
        link_all(&mut pg, &cluster_a, &cluster_b);
        link_all(&mut pg, &cluster_c, &cluster_d);
        link_all(&mut pg, &cluster_a, &cluster_c);

        let cfg = HebbianConfig {
            learning_rate: 0.05,
            decay_rate: 0.01,
            ..Default::default()
        };

        // The retrieval sets never change, so build them once outside the loop.
        let ab_ids: Vec<MemoryId> = cluster_a.iter().chain(&cluster_b).copied().collect();
        let cd_ids: Vec<MemoryId> = cluster_c.iter().chain(&cluster_d).copied().collect();

        for _ in 0..100 {
            hebbian_update(&mut pg, &ab_ids, &cfg);
            hebbian_update(&mut pg, &cd_ids, &cfg);
        }

        let edges_between_ab = pg.get_edges_between(cluster_a[0], cluster_b[0]);
        assert!(
            !edges_between_ab.is_empty(),
            "cluster A↔B edges should exist"
        );
        let weight_ab = edges_between_ab[0].weight;
        assert!(
            weight_ab > 0.7,
            "A↔B edges should be strong after co-retrieval: {weight_ab}"
        );

        let edges_between_ac = pg.get_edges_between(cluster_a[0], cluster_c[0]);
        assert!(
            !edges_between_ac.is_empty(),
            "cluster A↔C edges should exist"
        );
        let weight_ac = edges_between_ac[0].weight;
        assert!(
            weight_ac < weight_ab,
            "A↔C edges should be weaker than A↔B: ac={weight_ac}, ab={weight_ab}"
        );
        assert!(
            weight_ac < 0.4,
            "A↔C edges should have decayed from 0.5: {weight_ac}"
        );
    }

    #[test]
    fn no_new_edges_from_co_retrieval() {
        let (mut pg, a, b) = two_nodes();

        let result = hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());

        assert_eq!(result.strengthened, 0);
        assert_eq!(result.decayed, 0);
        assert_eq!(pg.edge_count(), 0, "no new edges created");
    }

    #[test]
    fn update_result_counts() {
        let (mut pg, a, b) = two_nodes();
        let c = make_node(&mut pg);
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();
        pg.add_edge(a, c, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();

        let result = hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());

        assert_eq!(result.strengthened, 1, "A-B edge strengthened");
        assert_eq!(result.decayed, 1, "A-C edge decayed (A retrieved, C not)");
    }

    #[test]
    fn buffer_push_signals_threshold() {
        let buf = HebbianBuffer::with_threshold(3);
        assert!(!buf.push(vec![MemoryId::new()]));
        assert!(!buf.push(vec![MemoryId::new()]));
        assert!(
            buf.push(vec![MemoryId::new()]),
            "third push should signal flush"
        );
        assert_eq!(buf.pending_count(), 3);
    }

    #[test]
    fn buffer_flush_applies_updates() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();

        let initial_weight = pg.get_edges(a)[0].weight;

        // Threshold is set high so only the explicit flush below applies the updates.
        let buf = HebbianBuffer::with_threshold(100);
        for _ in 0..10 {
            buf.push(vec![a, b]);
        }

        let result = buf.flush(&mut pg, &HebbianConfig::default());
        assert_eq!(result.strengthened, 10);
        assert_eq!(buf.pending_count(), 0);

        let final_weight = pg.get_edges(a)[0].weight;
        assert!(
            final_weight > initial_weight,
            "flush should strengthen: initial={initial_weight}, final={final_weight}"
        );
    }

    #[test]
    fn buffer_flush_empty_is_noop() {
        let mut pg = PropertyGraph::new();
        let buf = HebbianBuffer::new();
        let result = buf.flush(&mut pg, &HebbianConfig::default());
        assert_eq!(result.strengthened, 0);
        assert_eq!(result.decayed, 0);
    }

    #[test]
    fn buffer_concurrent_push() {
        use std::sync::Arc;
        use std::thread;

        let buf = Arc::new(HebbianBuffer::with_threshold(u64::MAX));

        // Four writers, 250 pushes each.
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let buf = Arc::clone(&buf);
                thread::spawn(move || {
                    for _ in 0..250 {
                        buf.push(vec![MemoryId::new(), MemoryId::new()]);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(buf.pending_count(), 1000);

        // The ids were never added to this graph, so flushing touches no edges.
        let mut pg = PropertyGraph::new();
        let result = buf.flush(&mut pg, &HebbianConfig::default());
        assert_eq!(result.strengthened, 0);
        assert_eq!(result.decayed, 0);
        assert_eq!(buf.pending_count(), 0);
    }

    #[test]
    fn buffer_default_threshold_is_16() {
        let buf = HebbianBuffer::new();
        assert_eq!(buf.flush_threshold, DEFAULT_FLUSH_THRESHOLD);
        assert_eq!(buf.flush_threshold, 16);
    }
}