use crate::error::Result;
use crate::graph::links;
use crate::types::*;
use rusqlite::Connection;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
pub fn spread_activation(
conn: &Connection,
seeds: &[NodeRef],
max_depth: u32,
threshold: f32,
decay_per_hop: f32,
) -> Result<HashMap<NodeRef, f32>> {
let mut activation: HashMap<NodeRef, f32> = HashMap::new();
for seed in seeds {
*activation.entry(*seed).or_default() += 1.0;
}
for _ in 0..max_depth {
let mut delta: HashMap<NodeRef, f32> = HashMap::new();
for (node, &act) in &activation {
if act >= threshold {
let outgoing = links::get_links_from(conn, *node)?;
let total_weight: f32 = outgoing.iter().map(|l| l.forward_weight).sum();
if !outgoing.is_empty() && total_weight > 0.0 {
for link in &outgoing {
let spread = act * link.forward_weight * decay_per_hop;
if spread >= threshold * 0.1 {
*delta.entry(link.target).or_default() += spread;
}
}
}
}
}
for (node, extra) in delta {
let entry = activation.entry(node).or_default();
*entry = (*entry + extra).min(2.0); }
}
activation.retain(|_, v| *v >= threshold);
Ok(activation)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::links::create_link;
    use crate::schema::open_memory_db;

    /// A stronger link should deliver more activation than a weaker one in one hop.
    #[test]
    fn test_single_hop_spread() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        let c = NodeRef::Episode(EpisodeId(3));
        create_link(&conn, a, b, LinkType::Topical, 0.8).unwrap();
        create_link(&conn, a, c, LinkType::Topical, 0.2).unwrap();
        let result = spread_activation(&conn, &[a], 1, 0.05, 0.7).unwrap();
        assert!(result.contains_key(&a));
        assert!(result.contains_key(&b));
        assert!(result.get(&b).unwrap_or(&0.0) > result.get(&c).unwrap_or(&0.0));
    }

    /// Activation reaching a node two hops away should be weaker than one hop away.
    #[test]
    fn test_multi_hop_decay() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        let c = NodeRef::Episode(EpisodeId(3));
        create_link(&conn, a, b, LinkType::Temporal, 0.9).unwrap();
        create_link(&conn, b, c, LinkType::Temporal, 0.9).unwrap();
        let result = spread_activation(&conn, &[a], 2, 0.05, 0.6).unwrap();
        let act_b = result.get(&b).unwrap_or(&0.0);
        let act_c = result.get(&c).unwrap_or(&0.0);
        assert!(act_b > act_c, "b ({act_b}) should be > c ({act_c})");
        assert!(
            *act_c > 0.0,
            "c should have nonzero activation from 2-hop spread"
        );
    }

    /// A very weak link must not lift its target above a high threshold.
    #[test]
    fn test_threshold_cutoff() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        create_link(&conn, a, b, LinkType::Topical, 0.01).unwrap();
        let result = spread_activation(&conn, &[a], 1, 0.5, 0.6).unwrap();
        assert!(!result.contains_key(&b));
    }

    /// Links whose forward weight is forced to zero must not carry activation.
    #[test]
    fn test_spread_activation_zero_weight_links() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        create_link(&conn, a, b, LinkType::Topical, 0.0).unwrap();
        conn.execute(
            "UPDATE links SET forward_weight = 0.0, backward_weight = 0.0",
            [],
        )
        .unwrap();
        let result = spread_activation(&conn, &[a], 1, 0.05, 0.6).unwrap();
        assert!(
            !result.contains_key(&b),
            "zero-weight link should not spread activation"
        );
    }

    /// A seed with no links still appears in the result, alone.
    #[test]
    fn test_spread_activation_no_outgoing_links() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let result = spread_activation(&conn, &[a], 1, 0.05, 0.6).unwrap();
        assert!(
            result.contains_key(&a),
            "seed should still be in activation"
        );
        assert_eq!(result.len(), 1, "should only contain the seed");
    }

    /// No seeds means nothing is ever activated.
    #[test]
    fn test_spread_activation_empty_seeds() {
        let conn = open_memory_db().unwrap();
        let result = spread_activation(&conn, &[], 2, 0.05, 0.6).unwrap();
        assert!(
            result.is_empty(),
            "empty seeds should produce empty activation"
        );
    }

    /// Every distinct seed starts at activation 1.0.
    #[test]
    fn test_spread_activation_multiple_seeds() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        let result = spread_activation(&conn, &[a, b], 1, 0.05, 0.6).unwrap();
        assert!(result.contains_key(&a), "seed a should be in activation");
        assert!(result.contains_key(&b), "seed b should be in activation");
        assert_eq!(
            *result.get(&a).unwrap(),
            1.0,
            "seed a should start with activation 1.0"
        );
        assert_eq!(
            *result.get(&b).unwrap(),
            1.0,
            "seed b should start with activation 1.0"
        );
    }

    /// Repeating a seed sums its starting activation (1.0 per occurrence).
    #[test]
    fn test_spread_activation_duplicate_seeds_sum() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let result = spread_activation(&conn, &[a, a], 0, 0.05, 0.6).unwrap();
        let act_a = *result
            .get(&a)
            .expect("duplicated seed should be in activation");
        // Asserting the exact value (not just `<= 2.0`) also catches duplicates being
        // silently ignored; 2.0 is simultaneously the activation cap.
        assert_eq!(
            act_a, 2.0,
            "two copies of a seed should sum to 2.0 (the activation cap)"
        );
    }

    /// A near-zero link must not lift its target above the threshold even over two rounds.
    #[test]
    fn test_spread_activation_below_threshold_not_included() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        create_link(&conn, a, b, LinkType::Topical, 0.001).unwrap();
        let result = spread_activation(&conn, &[a], 2, 0.5, 0.5).unwrap();
        assert!(result.contains_key(&a));
        assert!(
            !result.contains_key(&b),
            "very weak link should not spread above threshold"
        );
    }

    /// A node that holds activation below `threshold` must not fire onward.
    #[test]
    fn test_below_threshold_node_skipped_in_spread() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        let c = NodeRef::Episode(EpisodeId(3));
        let d = NodeRef::Episode(EpisodeId(4));
        create_link(&conn, a, b, LinkType::Topical, 0.9).unwrap();
        create_link(&conn, a, c, LinkType::Topical, 0.01).unwrap();
        // Pin a->c to 0.1. With decay 0.5, c receives 0.05 in round one. That is above
        // the minimum per-link spread (0.1 * 0.1 = 0.01), so c does enter the
        // activation map. It is below the firing threshold (0.1), so c must be skipped
        // when round two spreads. A weight as small as 0.001 would keep c out of the
        // map entirely and never exercise the skip.
        conn.execute(
            "UPDATE links SET forward_weight = 0.1 WHERE target_id = 3",
            [],
        )
        .unwrap();
        create_link(&conn, c, d, LinkType::Topical, 0.9).unwrap();
        let result = spread_activation(&conn, &[a], 2, 0.1, 0.5).unwrap();
        assert!(
            !result.contains_key(&d),
            "d should not be reached through below-threshold node c"
        );
    }

    /// With zero rounds, only the seeds are returned.
    #[test]
    fn test_spread_activation_zero_depth() {
        let conn = open_memory_db().unwrap();
        let a = NodeRef::Episode(EpisodeId(1));
        let b = NodeRef::Episode(EpisodeId(2));
        create_link(&conn, a, b, LinkType::Topical, 0.9).unwrap();
        let result = spread_activation(&conn, &[a], 0, 0.05, 0.6).unwrap();
        assert!(result.contains_key(&a));
        assert!(
            !result.contains_key(&b),
            "zero depth: b should not receive activation"
        );
    }
}