use std::collections::{HashMap, HashSet};
use hirn_core::HirnResult;
use hirn_core::id::MemoryId;
use hirn_core::timestamp::Timestamp;
use hirn_core::types::EdgeRelation;
use hirn_graph::graph::EdgeId;
use hirn_graph::hebbian::{HebbianConfig, HebbianUpdateResult};
use crate::persistent_graph::PersistentGraph;
fn decay_multiplier_for_relation(relation: &EdgeRelation) -> f64 {
match relation {
EdgeRelation::Causes | EdgeRelation::CausedBy | EdgeRelation::DerivedFrom => 0.2,
EdgeRelation::TemporalNext => 0.3,
EdgeRelation::SimilarTo => 0.5,
EdgeRelation::Contradicts => 0.1,
EdgeRelation::Supports | EdgeRelation::PartOf | EdgeRelation::InstanceOf => 0.4,
EdgeRelation::Inhibits => 0.6,
EdgeRelation::ParticipatesIn => 0.4,
EdgeRelation::RelatedTo => 1.0,
}
}
pub async fn hebbian_update(
graph: &PersistentGraph,
retrieved_ids: &[MemoryId],
config: &HebbianConfig,
) -> HirnResult<HebbianUpdateResult> {
hebbian_update_batch(graph, &[retrieved_ids.to_vec()], config).await
}
pub async fn hebbian_update_batch(
graph: &PersistentGraph,
events: &[Vec<MemoryId>],
config: &HebbianConfig,
) -> HirnResult<HebbianUpdateResult> {
if events.is_empty() {
return Ok(HebbianUpdateResult {
strengthened: 0,
decayed: 0,
});
}
let unique_node_ids: Vec<MemoryId> = events
.iter()
.flat_map(|ids| ids.iter().copied())
.collect::<HashSet<_>>()
.into_iter()
.collect();
let incident_edges = graph.get_edges_for_nodes(&unique_node_ids).await?;
if incident_edges.is_empty() {
return Ok(HebbianUpdateResult {
strengthened: 0,
decayed: 0,
});
}
let mut edges_by_id = HashMap::with_capacity(incident_edges.len());
let mut adjacency = HashMap::<MemoryId, Vec<EdgeId>>::new();
for edge in incident_edges {
let edge_id = edge.id;
adjacency.entry(edge.source).or_default().push(edge_id);
if edge.target != edge.source {
adjacency.entry(edge.target).or_default().push(edge_id);
}
edges_by_id.insert(edge_id, edge);
}
let mut strengthened = 0;
let mut decayed = 0;
let mut touched_edge_ids = HashSet::new();
let eta = config.learning_rate;
let min_w = config.min_weight;
for retrieved_ids in events {
let retrieved_set: HashSet<MemoryId> = retrieved_ids.iter().copied().collect();
let mut co_retrieval_edges = Vec::new();
let mut decay_edges = Vec::new();
for &node_id in retrieved_ids {
let Some(edge_ids) = adjacency.get(&node_id) else {
continue;
};
for &edge_id in edge_ids {
let Some(edge) = edges_by_id.get(&edge_id) else {
continue;
};
let partner = if edge.source == node_id {
edge.target
} else {
edge.source
};
if retrieved_set.contains(&partner) {
co_retrieval_edges.push(edge_id);
} else {
decay_edges.push(edge_id);
}
}
}
co_retrieval_edges.sort();
co_retrieval_edges.dedup();
decay_edges.sort();
decay_edges.dedup();
let co_retrieval_set: HashSet<EdgeId> = co_retrieval_edges.iter().copied().collect();
decay_edges.retain(|edge_id| !co_retrieval_set.contains(edge_id));
let updated_at = Timestamp::now();
for edge_id in co_retrieval_edges {
let Some(edge) = edges_by_id.get_mut(&edge_id) else {
continue;
};
edge.weight = (edge.weight as f64 + eta).min(1.0) as f32;
edge.co_retrieval_count += 1;
edge.updated_at = updated_at;
touched_edge_ids.insert(edge_id);
strengthened += 1;
}
for edge_id in decay_edges {
let Some(edge) = edges_by_id.get_mut(&edge_id) else {
continue;
};
let relation_multiplier = decay_multiplier_for_relation(&edge.relation);
let lambda = config.decay_rate * relation_multiplier;
edge.weight = (edge.weight as f64 * (1.0 - lambda)).max(min_w as f64) as f32;
edge.updated_at = updated_at;
touched_edge_ids.insert(edge_id);
decayed += 1;
}
}
if touched_edge_ids.is_empty() {
return Ok(HebbianUpdateResult {
strengthened,
decayed,
});
}
let mut updated_edges = Vec::with_capacity(touched_edge_ids.len());
for edge_id in touched_edge_ids {
if let Some(edge) = edges_by_id.remove(&edge_id) {
updated_edges.push(edge);
}
}
graph.upsert_edges(&updated_edges).await?;
Ok(HebbianUpdateResult {
strengthened,
decayed,
})
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use hirn_core::metadata::Metadata;
use hirn_core::timestamp::Timestamp;
use hirn_core::types::{Layer, Namespace};
use hirn_graph::graph::EdgeId;
use hirn_storage::PhysicalStore;
async fn temp_graph() -> (PersistentGraph, tempfile::TempDir) {
let dir = tempfile::tempdir().unwrap();
let lance_path = dir.path().join("lance_hebb");
let config = hirn_storage::HirnDbConfig::local(lance_path.to_str().unwrap());
let backend = hirn_storage::HirnDb::open(config.clone()).await.unwrap();
let storage: Arc<dyn PhysicalStore> = backend.store_arc();
let pg = PersistentGraph::open(storage).await.unwrap();
(pg, dir)
}
fn ns() -> Namespace {
Namespace::shared()
}
async fn setup_triangle(
pg: &PersistentGraph,
) -> (MemoryId, MemoryId, MemoryId, EdgeId, EdgeId, EdgeId) {
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
for id in [a, b, c] {
pg.add_node(id, Layer::Episodic, 0.5, Timestamp::now(), ns())
.await
.unwrap();
}
let e_ab = pg
.add_edge(a, b, EdgeRelation::SimilarTo, 0.5, Metadata::new())
.await
.unwrap();
let e_bc = pg
.add_edge(b, c, EdgeRelation::SimilarTo, 0.5, Metadata::new())
.await
.unwrap();
let e_ac = pg
.add_edge(a, c, EdgeRelation::RelatedTo, 0.5, Metadata::new())
.await
.unwrap();
(a, b, c, e_ab, e_bc, e_ac)
}
#[tokio::test]
async fn co_retrieval_strengthens_edges() {
let (pg, _dir) = temp_graph().await;
let (a, b, _c, e_ab, _e_bc, _e_ac) = setup_triangle(&pg).await;
let cfg = HebbianConfig::default();
let result = hebbian_update(&pg, &[a, b], &cfg).await.unwrap();
assert!(result.strengthened > 0);
let edge = pg.get_edge(e_ab).await.unwrap().unwrap();
assert!(
edge.weight > 0.5,
"edge should be strengthened: {}",
edge.weight
);
assert_eq!(edge.co_retrieval_count, 1);
}
#[tokio::test]
async fn solo_retrieval_decays_edges() {
let (pg, _dir) = temp_graph().await;
let (a, _b, _c, _e_ab, _e_bc, e_ac) = setup_triangle(&pg).await;
let cfg = HebbianConfig::default();
let result = hebbian_update(&pg, &[a], &cfg).await.unwrap();
assert!(result.decayed > 0);
let edge = pg.get_edge(e_ac).await.unwrap().unwrap();
assert!(edge.weight < 0.5, "edge should be decayed: {}", edge.weight);
}
#[tokio::test]
async fn weight_stays_in_bounds() {
let (pg, _dir) = temp_graph().await;
let (a, b, _c, e_ab, _, _) = setup_triangle(&pg).await;
let cfg = HebbianConfig {
learning_rate: 0.5,
..Default::default()
};
for _ in 0..20 {
hebbian_update(&pg, &[a, b], &cfg).await.unwrap();
}
let edge = pg.get_edge(e_ab).await.unwrap().unwrap();
assert!(edge.weight <= 1.0);
}
#[tokio::test]
async fn co_retrieval_count_accurate() {
let (pg, _dir) = temp_graph().await;
let (a, b, _c, e_ab, _, _) = setup_triangle(&pg).await;
let cfg = HebbianConfig::default();
for _ in 0..5 {
hebbian_update(&pg, &[a, b], &cfg).await.unwrap();
}
let edge = pg.get_edge(e_ab).await.unwrap().unwrap();
assert_eq!(edge.co_retrieval_count, 5);
}
#[tokio::test]
async fn batched_events_match_sequential_updates() {
let (pg_seq, _dir_seq) = temp_graph().await;
let (a_seq, b_seq, c_seq, e_ab_seq, e_bc_seq, e_ac_seq) = setup_triangle(&pg_seq).await;
let (pg_batch, _dir_batch) = temp_graph().await;
let (a_batch, b_batch, c_batch, e_ab_batch, e_bc_batch, e_ac_batch) =
setup_triangle(&pg_batch).await;
let cfg = HebbianConfig::default();
for event in &[vec![a_seq, b_seq], vec![a_seq], vec![b_seq, c_seq]] {
hebbian_update(&pg_seq, event, &cfg).await.unwrap();
}
hebbian_update_batch(
&pg_batch,
&[
vec![a_batch, b_batch],
vec![a_batch],
vec![b_batch, c_batch],
],
&cfg,
)
.await
.unwrap();
for (seq_id, batch_id) in [
(e_ab_seq, e_ab_batch),
(e_bc_seq, e_bc_batch),
(e_ac_seq, e_ac_batch),
] {
let seq_edge = pg_seq.get_edge(seq_id).await.unwrap().unwrap();
let batch_edge = pg_batch.get_edge(batch_id).await.unwrap().unwrap();
assert_eq!(
(seq_edge.weight, seq_edge.co_retrieval_count),
(batch_edge.weight, batch_edge.co_retrieval_count)
);
}
}
#[tokio::test]
async fn relation_specific_decay_rates() {
let (pg, _dir) = temp_graph().await;
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
for id in [a, b, c] {
pg.add_node(id, Layer::Episodic, 0.5, Timestamp::now(), ns())
.await
.unwrap();
}
let e_causal = pg
.add_edge(a, b, EdgeRelation::Causes, 0.5, Metadata::new())
.await
.unwrap();
let e_generic = pg
.add_edge(a, c, EdgeRelation::RelatedTo, 0.5, Metadata::new())
.await
.unwrap();
let cfg = HebbianConfig {
decay_rate: 0.1,
..Default::default()
};
hebbian_update(&pg, &[a], &cfg).await.unwrap();
let causal = pg.get_edge(e_causal).await.unwrap().unwrap();
let generic = pg.get_edge(e_generic).await.unwrap().unwrap();
assert!(
causal.weight > generic.weight,
"causal {} should decay slower than generic {}",
causal.weight,
generic.weight
);
}
}