Skip to main content

hirn_graph/
hebbian.rs

1//! Hebbian learning: co-retrieval strengthens edge weights, solo retrieval decays them.
2//!
3//! Implements CONCEPT.md §6.5:
4//! - Co-retrieval: `weight = min(1.0, weight + η × Δ)`
5//! - Solo retrieval: `weight = max(0.01, weight × (1 - λ_decay))`
6//!
7//! ## HebbianBuffer
8//!
9//! A lock-free buffer that collects co-retrieval pairs via [`crossbeam_queue::SegQueue`].
10//! Push operations never block; the flush operation drains the queue and applies
11//! batch weight updates to the graph.
12
13use std::sync::atomic::{AtomicU64, Ordering};
14
15use crossbeam_queue::SegQueue;
16
17use hirn_core::id::MemoryId;
18use hirn_core::timestamp::Timestamp;
19
20use crate::graph::PropertyGraph;
21
/// Tunable parameters for Hebbian edge-weight updates.
///
/// All three values have documented defaults, available through
/// [`Default::default`]; override individual fields with struct-update syntax,
/// e.g. `HebbianConfig { learning_rate: 0.1, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct HebbianConfig {
    /// Learning rate η: amount added to an edge's weight per co-retrieval
    /// (default `0.05`).
    pub learning_rate: f64,
    /// Decay rate `λ_decay`: base fraction of an edge's weight removed per
    /// solo retrieval (default `0.01`).
    pub decay_rate: f64,
    /// Floor used when decaying an edge (default `0.01`).
    pub min_weight: f32,
}

impl HebbianConfig {
    /// Default learning rate η.
    const DEFAULT_LEARNING_RATE: f64 = 0.05;
    /// Default decay rate `λ_decay`.
    const DEFAULT_DECAY_RATE: f64 = 0.01;
    /// Default decay floor.
    const DEFAULT_MIN_WEIGHT: f32 = 0.01;
}

impl Default for HebbianConfig {
    /// η = 0.05, `λ_decay` = 0.01, minimum weight = 0.01.
    fn default() -> Self {
        Self {
            learning_rate: Self::DEFAULT_LEARNING_RATE,
            decay_rate: Self::DEFAULT_DECAY_RATE,
            min_weight: Self::DEFAULT_MIN_WEIGHT,
        }
    }
}
42
/// Counts of edges touched by a Hebbian update step.
///
/// `Default` yields all-zero counts. `PartialEq`/`Eq` allow whole results to
/// be compared, e.g. `assert_eq!(result, HebbianUpdateResult::default())`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HebbianUpdateResult {
    /// Number of edges strengthened (both endpoints were retrieved).
    pub strengthened: usize,
    /// Number of edges decayed (only one endpoint was retrieved).
    pub decayed: usize,
}
51
52/// Apply Hebbian learning updates to the graph based on co-retrieved node IDs.
53///
54/// - Edges between co-retrieved nodes are **strengthened**.
55/// - Edges from co-retrieved nodes to non-retrieved neighbors are **decayed**.
56pub fn hebbian_update(
57    graph: &mut PropertyGraph,
58    retrieved_ids: &[MemoryId],
59    config: &HebbianConfig,
60) -> HebbianUpdateResult {
61    let mut strengthened = 0;
62    let mut decayed = 0;
63
64    let retrieved_set: std::collections::HashSet<MemoryId> =
65        retrieved_ids.iter().copied().collect();
66
67    let now = Timestamp::now();
68
69    // Collect all edges for retrieved nodes (we need the IDs before mutating).
70    let mut co_retrieval_edges = Vec::new();
71    let mut decay_edges = Vec::new();
72
73    for &node_id in retrieved_ids {
74        for edge in graph.get_edges(node_id) {
75            let partner = if edge.source == node_id {
76                edge.target
77            } else {
78                edge.source
79            };
80
81            if retrieved_set.contains(&partner) {
82                // Both endpoints retrieved → co-retrieval.
83                co_retrieval_edges.push(edge.id);
84            } else {
85                // Only one endpoint retrieved → decay.
86                decay_edges.push(edge.id);
87            }
88        }
89    }
90
91    // Deduplicate (edges may be seen from both endpoints).
92    co_retrieval_edges.sort();
93    co_retrieval_edges.dedup();
94    decay_edges.sort();
95    decay_edges.dedup();
96
97    // Remove edges from decay list that are also in co-retrieval list.
98    // Use a HashSet for O(1) lookups instead of O(N) linear scan.
99    let co_retrieval_set: std::collections::HashSet<crate::graph::EdgeId> =
100        co_retrieval_edges.iter().copied().collect();
101    decay_edges.retain(|eid| !co_retrieval_set.contains(eid));
102
103    // Strengthen co-retrieved edges.
104    let eta = config.learning_rate;
105    for eid in co_retrieval_edges {
106        if let Some(edge) = graph.edge_mut(eid) {
107            let delta = 1.0; // Δ = 1.0 per co-retrieval event.
108            let new_weight = eta.mul_add(delta, f64::from(edge.weight)).min(1.0);
109            edge.weight = new_weight as f32;
110            edge.co_retrieval_count += 1;
111            edge.updated_at = now;
112            strengthened += 1;
113        }
114    }
115
116    // Decay solo-retrieved edges.
117    // F-35: Per-relation decay multipliers — causal/provenance edges decay
118    // slower than generic associations, reflecting their structural importance.
119    let base_lambda = config.decay_rate;
120    let min_w = config.min_weight;
121    for eid in decay_edges {
122        if let Some(edge) = graph.edge_mut(eid) {
123            let relation_multiplier = decay_multiplier_for_relation(edge.relation);
124            let lambda = base_lambda * relation_multiplier;
125            let new_weight = (f64::from(edge.weight) * (1.0 - lambda)).max(f64::from(min_w));
126            edge.weight = new_weight as f32;
127            edge.updated_at = now;
128            decayed += 1;
129        }
130    }
131
132    HebbianUpdateResult {
133        strengthened,
134        decayed,
135    }
136}
137
138/// F-35: Relation-type-specific decay multipliers.
139/// Structural/causal edges decay slower than generic associations.
140const fn decay_multiplier_for_relation(relation: hirn_core::types::EdgeRelation) -> f64 {
141    use hirn_core::types::EdgeRelation;
142    match relation {
143        // Causal and provenance edges are structurally important — decay very slowly.
144        EdgeRelation::Causes | EdgeRelation::CausedBy | EdgeRelation::DerivedFrom => 0.2,
145        // Temporal adjacency is important for episode chains — decay slowly.
146        EdgeRelation::TemporalNext => 0.3,
147        // Similarity edges are the backbone — moderate decay.
148        EdgeRelation::SimilarTo => 0.5,
149        // Contradiction edges should persist — very slow decay.
150        EdgeRelation::Contradicts => 0.1,
151        // Evidential/structural edges — slow decay.
152        EdgeRelation::Supports
153        | EdgeRelation::PartOf
154        | EdgeRelation::InstanceOf
155        | EdgeRelation::ParticipatesIn => 0.4,
156        // Inhibition edges — moderate decay.
157        EdgeRelation::Inhibits => 0.6,
158        // Generic associations — full decay rate.
159        EdgeRelation::RelatedTo => 1.0,
160    }
161}
162
// ── Lock-free Hebbian buffer ─────────────────────────────────────────────

/// Number of pushes after which a buffer built with `HebbianBuffer::new`
/// starts reporting from `push` that a flush is due.
const DEFAULT_FLUSH_THRESHOLD: u64 = 16;
167
168/// Lock-free buffer for co-retrieval events.
169///
170/// Push operations use [`SegQueue`] and never block. The [`flush`](Self::flush)
171/// method drains the queue and applies all accumulated co-retrieval + decay
172/// updates to the graph in a single batch.
173pub struct HebbianBuffer {
174    queue: SegQueue<Vec<MemoryId>>,
175    push_count: AtomicU64,
176    flush_threshold: u64,
177}
178
179impl HebbianBuffer {
180    /// Create a new buffer with the default flush threshold (16).
181    #[must_use]
182    pub fn new() -> Self {
183        Self {
184            queue: SegQueue::new(),
185            push_count: AtomicU64::new(0),
186            flush_threshold: DEFAULT_FLUSH_THRESHOLD,
187        }
188    }
189
190    /// Create a new buffer with a custom flush threshold.
191    #[must_use]
192    pub fn with_threshold(threshold: u64) -> Self {
193        Self {
194            queue: SegQueue::new(),
195            push_count: AtomicU64::new(0),
196            flush_threshold: threshold,
197        }
198    }
199
200    /// Push a set of co-retrieved IDs into the buffer. Never blocks.
201    ///
202    /// Returns `true` if the push count has reached the flush threshold,
203    /// signaling that the caller should call [`flush`](Self::flush).
204    pub fn push(&self, retrieved_ids: Vec<MemoryId>) -> bool {
205        self.queue.push(retrieved_ids);
206        let count = self.push_count.fetch_add(1, Ordering::Relaxed) + 1;
207        count >= self.flush_threshold
208    }
209
210    /// Drain all buffered events and apply Hebbian updates to the graph.
211    ///
212    /// Returns the aggregate update result. Resets the push counter.
213    pub fn flush(&self, graph: &mut PropertyGraph, config: &HebbianConfig) -> HebbianUpdateResult {
214        self.push_count.store(0, Ordering::Relaxed);
215
216        let mut total = HebbianUpdateResult {
217            strengthened: 0,
218            decayed: 0,
219        };
220
221        while let Some(ids) = self.queue.pop() {
222            let result = hebbian_update(graph, &ids, config);
223            total.strengthened += result.strengthened;
224            total.decayed += result.decayed;
225        }
226
227        total
228    }
229
230    /// Number of pushes since last flush. Approximate under concurrency.
231    pub fn pending_count(&self) -> u64 {
232        self.push_count.load(Ordering::Relaxed)
233    }
234
235    /// Pop a single event from the queue, for callers that drain manually.
236    pub fn pop(&self) -> Option<Vec<MemoryId>> {
237        self.queue.pop()
238    }
239
240    /// Reset the push counter to zero (e.g. before manual drain).
241    pub fn reset_counter(&self) {
242        self.push_count.store(0, Ordering::Relaxed);
243    }
244}
245
246impl Default for HebbianBuffer {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
// ── Tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::metadata::Metadata;
    use hirn_core::timestamp::Timestamp;
    use hirn_core::types::{EdgeRelation, Layer};

    /// Insert a fresh `Layer::Episodic` node and return its ID.
    fn make_node(pg: &mut PropertyGraph) -> MemoryId {
        let id = MemoryId::new();
        pg.add_node(id, Layer::Episodic, 0.5, Timestamp::now());
        id
    }

    /// A new graph holding two unconnected nodes, returned as `(graph, a, b)`.
    fn two_nodes() -> (PropertyGraph, MemoryId, MemoryId) {
        let mut pg = PropertyGraph::new();
        let a = make_node(&mut pg);
        let b = make_node(&mut pg);
        (pg, a, b)
    }

    /// Weight of the first edge the graph reports for `id`.
    fn first_edge_weight(pg: &PropertyGraph, id: MemoryId) -> f32 {
        pg.get_edges(id)[0].weight
    }

    /// Add a weight-0.5 `Causes` edge from every node in `from` to every node
    /// in `to`, ignoring insertion errors.
    fn connect_all(pg: &mut PropertyGraph, from: &[MemoryId], to: &[MemoryId]) {
        for &src in from {
            for &dst in to {
                let _ = pg.add_edge(src, dst, EdgeRelation::Causes, 0.5, Metadata::new());
            }
        }
    }

    #[test]
    fn co_retrieval_strengthens_edge() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();
        let initial_weight = first_edge_weight(&pg, a);

        let cfg = HebbianConfig::default();
        for _ in 0..10 {
            hebbian_update(&mut pg, &[a, b], &cfg);
        }

        let final_weight = first_edge_weight(&pg, a);
        assert!(
            final_weight > initial_weight,
            "co-retrieval should strengthen: initial={initial_weight}, final={final_weight}"
        );
    }

    #[test]
    fn solo_retrieval_decays_edge() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();
        let initial_weight = first_edge_weight(&pg, a);

        // Only A is retrieved, 100 times in a row.
        let cfg = HebbianConfig::default();
        for _ in 0..100 {
            hebbian_update(&mut pg, &[a], &cfg);
        }

        let final_weight = first_edge_weight(&pg, a);
        assert!(
            final_weight < initial_weight,
            "solo retrieval should decay: initial={initial_weight}, final={final_weight}"
        );
    }

    #[test]
    fn co_retrieval_count_incremented() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();

        let cfg = HebbianConfig::default();
        for _ in 0..3 {
            hebbian_update(&mut pg, &[a, b], &cfg);
        }

        let count = pg.get_edges(a)[0].co_retrieval_count;
        assert_eq!(count, 3, "co_retrieval_count should be 3, got {count}");
    }

    #[test]
    fn weight_never_exceeds_one() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.95, Metadata::new())
            .unwrap();

        // Aggressive learning rate so the 1.0 ceiling is hit immediately.
        let cfg = HebbianConfig {
            learning_rate: 0.5,
            ..Default::default()
        };
        for _ in 0..1000 {
            hebbian_update(&mut pg, &[a, b], &cfg);
        }

        let w = first_edge_weight(&pg, a);
        assert!(w <= 1.0, "weight exceeded 1.0: {w}");
    }

    #[test]
    fn weight_never_below_min() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.1, Metadata::new())
            .unwrap();

        // Aggressive decay so the floor is reached quickly.
        let cfg = HebbianConfig {
            decay_rate: 0.5,
            min_weight: 0.01,
            ..Default::default()
        };
        for _ in 0..1000 {
            hebbian_update(&mut pg, &[a], &cfg);
        }

        let w = first_edge_weight(&pg, a);
        assert!(w >= 0.01, "weight fell below min_weight 0.01: {w}");
    }

    #[test]
    // cluster_a..cluster_d and weight_ab/weight_ac are deliberately parallel names.
    #[allow(clippy::similar_names)]
    fn self_organizing_clusters() {
        let mut pg = PropertyGraph::new();

        // Four clusters of three nodes each.
        let cluster_a: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();
        let cluster_b: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();
        let cluster_c: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();
        let cluster_d: Vec<MemoryId> = (0..3).map(|_| make_node(&mut pg)).collect();

        // Edges inside each group ({A,B} and {C,D}), then the A↔C bridge,
        // all starting at weight 0.5.
        connect_all(&mut pg, &cluster_a, &cluster_b);
        connect_all(&mut pg, &cluster_c, &cluster_d);
        connect_all(&mut pg, &cluster_a, &cluster_c);

        let cfg = HebbianConfig {
            learning_rate: 0.05,
            decay_rate: 0.01,
            ..Default::default()
        };

        // 100 rounds, each co-retrieving {A,B} and then {C,D}.
        let ab_ids: Vec<MemoryId> = cluster_a.iter().chain(&cluster_b).copied().collect();
        let cd_ids: Vec<MemoryId> = cluster_c.iter().chain(&cluster_d).copied().collect();
        for _ in 0..100 {
            hebbian_update(&mut pg, &ab_ids, &cfg);
            hebbian_update(&mut pg, &cd_ids, &cfg);
        }

        // A↔B edges should have been strengthened.
        let edges_between_ab = pg.get_edges_between(cluster_a[0], cluster_b[0]);
        assert!(
            !edges_between_ab.is_empty(),
            "cluster A↔B edges should exist"
        );
        let weight_ab = edges_between_ab[0].weight;
        assert!(
            weight_ab > 0.7,
            "A↔B edges should be strong after co-retrieval: {weight_ab}"
        );

        // A↔C edges are only ever decayed, never co-retrieved.
        // F-35: Causes edges decay at 0.2× base rate, so after 200 decay events
        // from start=0.5: 0.5 * (1 - 0.01*0.2)^200 ≈ 0.335
        let edges_between_ac = pg.get_edges_between(cluster_a[0], cluster_c[0]);
        assert!(
            !edges_between_ac.is_empty(),
            "cluster A↔C edges should exist"
        );
        let weight_ac = edges_between_ac[0].weight;
        assert!(
            weight_ac < weight_ab,
            "A↔C edges should be weaker than A↔B: ac={weight_ac}, ab={weight_ab}"
        );
        assert!(
            weight_ac < 0.4,
            "A↔C edges should have decayed from 0.5: {weight_ac}"
        );
    }

    #[test]
    fn no_new_edges_from_co_retrieval() {
        // Two nodes with no edge between them.
        let (mut pg, a, b) = two_nodes();

        let result = hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());
        assert_eq!(result.strengthened, 0);
        assert_eq!(result.decayed, 0);
        assert_eq!(pg.edge_count(), 0, "no new edges created");
    }

    #[test]
    fn update_result_counts() {
        let (mut pg, a, b) = two_nodes();
        let c = make_node(&mut pg);
        for target in [b, c] {
            pg.add_edge(a, target, EdgeRelation::Causes, 0.5, Metadata::new())
                .unwrap();
        }

        // A and B are co-retrieved; C is not.
        let result = hebbian_update(&mut pg, &[a, b], &HebbianConfig::default());
        assert_eq!(result.strengthened, 1, "A-B edge strengthened");
        assert_eq!(result.decayed, 1, "A-C edge decayed (A retrieved, C not)");
    }

    // ── HebbianBuffer tests ──────────────────────────────────────────

    #[test]
    fn buffer_push_signals_threshold() {
        let buf = HebbianBuffer::with_threshold(3);
        let signals: Vec<bool> = (0..3).map(|_| buf.push(vec![MemoryId::new()])).collect();
        assert_eq!(signals, [false, false, true], "third push should signal flush");
        assert_eq!(buf.pending_count(), 3);
    }

    #[test]
    fn buffer_flush_applies_updates() {
        let (mut pg, a, b) = two_nodes();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
            .unwrap();
        let initial_weight = first_edge_weight(&pg, a);

        // Threshold well above the push count: the test drives the flush itself.
        let buf = HebbianBuffer::with_threshold(100);
        for _ in 0..10 {
            buf.push(vec![a, b]);
        }

        let result = buf.flush(&mut pg, &HebbianConfig::default());
        assert_eq!(result.strengthened, 10);
        assert_eq!(buf.pending_count(), 0);

        let final_weight = first_edge_weight(&pg, a);
        assert!(
            final_weight > initial_weight,
            "flush should strengthen: initial={initial_weight}, final={final_weight}"
        );
    }

    #[test]
    fn buffer_flush_empty_is_noop() {
        let mut pg = PropertyGraph::new();
        let result = HebbianBuffer::new().flush(&mut pg, &HebbianConfig::default());
        assert_eq!(result.strengthened, 0);
        assert_eq!(result.decayed, 0);
    }

    #[test]
    fn buffer_concurrent_push() {
        use std::sync::Arc;
        use std::thread;

        const THREADS: usize = 4;
        const PUSHES_PER_THREAD: usize = 250;

        // Threshold u64::MAX: `push` never signals, so the test controls the flush.
        let buf = Arc::new(HebbianBuffer::with_threshold(u64::MAX));

        let mut handles = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let worker_buf = Arc::clone(&buf);
            handles.push(thread::spawn(move || {
                for _ in 0..PUSHES_PER_THREAD {
                    worker_buf.push(vec![MemoryId::new(), MemoryId::new()]);
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(buf.pending_count(), 1000);

        // An empty graph has nothing to strengthen or decay; the flush only drains.
        let mut pg = PropertyGraph::new();
        let result = buf.flush(&mut pg, &HebbianConfig::default());
        assert_eq!(result.strengthened, 0);
        assert_eq!(result.decayed, 0);
        assert_eq!(buf.pending_count(), 0);
    }

    #[test]
    fn buffer_default_threshold_is_16() {
        let buf = HebbianBuffer::new();
        assert_eq!(buf.flush_threshold, DEFAULT_FLUSH_THRESHOLD);
        assert_eq!(buf.flush_threshold, 16);
    }
}