use std::collections::{HashMap, HashSet, VecDeque};
use crate::core::types::EntityId;
/// Tuning knobs for [`spread_activation`].
#[derive(Debug, Clone)]
pub struct SpreadingParams {
    /// Maximum associative strength `S`. A node with fan `f` spreads with strength
    /// `s_max - ln(f)`, which becomes negative (inhibitory) once `ln(f) > s_max`.
    pub s_max: f64,
    /// Total source weight `W`, split evenly across the entries of the source list.
    pub w_total: f64,
    /// Nodes at this breadth-first depth or deeper are not expanded further
    /// (sources sit at depth 0).
    pub max_depth: u8,
    /// A node whose outgoing amount has a magnitude below this value does not spread.
    pub cutoff: f64,
}

impl Default for SpreadingParams {
    /// Returns `s_max = 1.6`, `w_total = 1.0`, `max_depth = 3`, `cutoff = 0.01`.
    fn default() -> Self {
        SpreadingParams {
            max_depth: 3,
            cutoff: 0.01,
            s_max: 1.6,
            w_total: 1.0,
        }
    }
}
/// Spreads activation breadth-first from `sources` over the graph described by
/// `get_neighbors`, returning the accumulated activation of every node touched.
///
/// When a node `j` with fan `f` (its number of neighbours) is expanded, each
/// neighbour receives `W_j * S_ji * sign(incoming_j)`, where
/// `S_ji = params.s_max - ln(f)` (the fan-effect form of associative strength)
/// and `W_j = params.w_total / sources.len()`. When `ln(f) > s_max`, `S_ji` is
/// negative and the neighbours are inhibited.
///
/// Traversal rules:
/// - Each source starts with its own `weight` as activation. Duplicate source
///   ids accumulate, but each distinct id is expanded only once, at depth 0.
///   Duplicates still count towards `sources.len()` when computing `W_j`.
/// - A node is queued for expansion at most once, the first time it is reached.
///   Because the queue is FIFO, that is at its minimal depth. Later arrivals
///   still add to its activation. This includes spread flowing back into nodes
///   that were already expanded, such as the sources themselves.
/// - Only the *sign* of a node's incoming activation is carried forward, not
///   its magnitude. Note that `f64::signum(0.0)` is `1.0`, so a source with
///   weight `0.0` spreads as if it were positive.
/// - Nodes at depth `>= params.max_depth` receive activation but are not
///   expanded. Nodes with no neighbours, or whose outgoing amount has
///   magnitude below `params.cutoff`, do not spread.
///
/// # Arguments
/// * `sources` - `(node, weight)` pairs seeding the spread.
/// * `get_neighbors` - adjacency function, called once per expanded node.
/// * `params` - strength, weighting, depth and cutoff settings.
///
/// # Returns
/// A map from each source and each node that received spread to its total
/// activation. Totals may be zero or negative.
pub fn spread_activation<F>(
    sources: &[(EntityId, f64)],
    get_neighbors: F,
    params: &SpreadingParams,
) -> HashMap<EntityId, f64>
where
    F: Fn(EntityId) -> Vec<EntityId>,
{
    let mut activations: HashMap<EntityId, f64> = HashMap::new();
    let mut visited: HashSet<EntityId> = HashSet::new();
    let mut queue: VecDeque<(EntityId, f64, u8)> = VecDeque::new();

    // W_j depends only on the number of sources, so compute it once instead of per node.
    let n_sources = sources.len().max(1) as f64;
    let w_j = params.w_total / n_sources;

    for &(id, weight) in sources {
        *activations.entry(id).or_default() += weight;
        if visited.insert(id) {
            queue.push_back((id, weight, 0));
        }
    }

    while let Some((node, incoming, depth)) = queue.pop_front() {
        if depth >= params.max_depth {
            continue;
        }
        let neighbors = get_neighbors(node);
        if neighbors.is_empty() {
            continue;
        }
        // Associative strength; negative once ln(fan) exceeds s_max (fan-effect inhibition).
        let s_ji = params.s_max - (neighbors.len() as f64).ln();
        // Only the sign of `incoming` is carried forward; the magnitude resets at every hop.
        let outgoing = w_j * s_ji * incoming.signum();
        if outgoing.abs() < params.cutoff {
            continue;
        }
        for neighbor in neighbors {
            *activations.entry(neighbor).or_default() += outgoing;
            // Expand each node at most once; this also guarantees termination on cyclic graphs.
            if visited.insert(neighbor) {
                queue.push_back((neighbor, outgoing, depth + 1));
            }
        }
    }

    activations
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an undirected adjacency function from an edge list: each `(a, b)`
    /// makes `a` and `b` neighbours of each other.
    fn make_graph(edges: &[(u64, u64)]) -> impl Fn(EntityId) -> Vec<EntityId> + '_ {
        move |id: EntityId| {
            edges
                .iter()
                .filter_map(|&(a, b)| {
                    if a == id.0 {
                        Some(EntityId(b))
                    } else if b == id.0 {
                        Some(EntityId(a))
                    } else {
                        None
                    }
                })
                .collect()
        }
    }

    #[test]
    fn test_simple_spread_a_to_b() {
        let graph = make_graph(&[(1, 2)]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        assert!(result.contains_key(&EntityId(2)));
        let b_act = result[&EntityId(2)];
        assert!(
            b_act > 0.0,
            "B should have positive activation, got {b_act}"
        );
    }

    #[test]
    fn test_exact_values_single_edge() {
        // A (fan 1) sends s_max * W = 1.6 to B; B (fan 1) echoes 1.6 back to A.
        let graph = make_graph(&[(1, 2)]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        assert_eq!(result.len(), 2);
        assert!((result[&EntityId(2)] - 1.6).abs() < 1e-12);
        assert!((result[&EntityId(1)] - 2.6).abs() < 1e-12);
    }

    #[test]
    fn test_fan_effect_positive() {
        let graph = make_graph(&[(1, 2), (1, 3), (1, 4)]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        assert!(result[&EntityId(2)] > 0.0);
        assert!(result[&EntityId(3)] > 0.0);
        assert!(result[&EntityId(4)] > 0.0);
    }

    #[test]
    fn test_fan_effect_inhibition() {
        // Fan of 10: ln(10) ≈ 2.30 exceeds s_max = 1.6, so S_ji is negative.
        let edges: Vec<(u64, u64)> = (2..=11).map(|i| (1, i)).collect();
        let graph = make_graph(&edges);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        for i in 2..=11 {
            let act = result[&EntityId(i)];
            assert!(
                act < 0.0,
                "Entity {i} should have negative activation (inhibition), got {act}"
            );
        }
    }

    #[test]
    fn test_cycle_no_infinite_loop() {
        // Triangle 1-2-3-1: every node is reachable along two distinct paths.
        let graph = make_graph(&[(1, 2), (2, 3), (3, 1)]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        assert_eq!(result.len(), 3);
        assert!(result.contains_key(&EntityId(1)));
        assert!(result.contains_key(&EntityId(2)));
        assert!(result.contains_key(&EntityId(3)));
    }

    #[test]
    fn test_depth_limit() {
        let graph = make_graph(&[(1, 2), (2, 3), (3, 4)]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams {
            max_depth: 2,
            ..Default::default()
        };
        let result = spread_activation(&sources, graph, &params);
        assert!(result.contains_key(&EntityId(1)));
        assert!(result.contains_key(&EntityId(2)));
        assert!(result.contains_key(&EntityId(3)));
        let d_act = result.get(&EntityId(4)).copied().unwrap_or(0.0);
        assert!(
            d_act.abs() < f64::EPSILON,
            "D should have no activation at depth 2, got {d_act}"
        );
    }

    #[test]
    fn test_no_neighbors_stops() {
        let graph = make_graph(&[]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        assert_eq!(result.len(), 1);
        assert!((result[&EntityId(1)] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_multiple_sources() {
        let graph = make_graph(&[(1, 3), (2, 3)]);
        let sources = vec![(EntityId(1), 1.0), (EntityId(2), 1.0)];
        let params = SpreadingParams::default();
        let result = spread_activation(&sources, graph, &params);
        let c_act = result[&EntityId(3)];
        assert!(
            c_act > 0.0,
            "C should have positive activation from 2 sources"
        );
    }

    #[test]
    fn test_cutoff_stops_weak_signals() {
        let graph = make_graph(&[(1, 2), (2, 3), (3, 4)]);
        let sources = vec![(EntityId(1), 1.0)];
        let params = SpreadingParams {
            cutoff: 10.0,
            ..Default::default()
        };
        let result = spread_activation(&sources, graph, &params);
        assert_eq!(result.len(), 1);
    }
}