use std::collections::HashMap;
use std::sync::OnceLock;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
#[allow(unused_imports)]
use zeph_db::sql;
use crate::embedding_store::EmbeddingStore;
use crate::error::MemoryError;
use crate::graph::store::GraphStore;
use crate::graph::types::{Edge, EdgeType, edge_type_weight, evolved_weight};
/// A graph node reached by spreading activation, with its final score.
#[derive(Debug, Clone)]
pub struct ActivatedNode {
    /// Row id of the entity in the graph store.
    pub entity_id: i64,
    /// Final activation score; capped at 1.0 during accumulation.
    pub activation: f32,
    /// Hop count from the nearest seed (0 = the seed itself).
    pub depth: u32,
}
/// An edge whose endpoints both ended up activated above threshold,
/// scored with the stronger endpoint's activation.
#[derive(Debug, Clone)]
pub struct ActivatedFact {
    /// The graph edge carrying the fact text.
    pub edge: Edge,
    /// `max(source activation, target activation)` at the hop where the
    /// edge qualified.
    pub activation_score: f32,
}
/// Tuning knobs for [`SpreadingActivation::spread`].
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    /// Per-hop decay multiplier applied to every propagated score.
    pub decay_lambda: f32,
    /// Maximum number of propagation hops.
    pub max_hops: u32,
    /// Minimum score for a node to seed, propagate, or appear in results.
    pub activation_threshold: f32,
    /// Nodes at or above this score stop receiving further input (inhibition).
    pub inhibition_threshold: f32,
    /// Hard cap on the activation map; lowest-scored nodes are pruned past it.
    pub max_activated_nodes: usize,
    /// Per-day hyperbolic decay rate for edge recency; `<= 0.0` disables it.
    pub temporal_decay_rate: f64,
    /// Not referenced in this file — presumably consumed by seed selection
    /// elsewhere; TODO confirm at the call site.
    pub seed_structural_weight: f32,
    /// Not referenced in this file — presumably a per-community seed cap
    /// applied elsewhere; TODO confirm at the call site.
    pub seed_community_cap: usize,
}
/// A fact returned by [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaFact {
    /// The edge that led into the scored node (synthetic for the
    /// anchor-only fallback, where `edge.id == 0`).
    pub edge: Edge,
    /// `path_weight * clamped cosine` (or the raw anchor ANN score in the
    /// fallback case).
    pub score: f32,
    /// Hop distance from the anchor entity (0 for the fallback).
    pub depth: u32,
    /// Product of edge weights along the best path from the anchor.
    pub path_weight: f32,
    /// Query↔node cosine similarity clamped to `[0, 1]`, when available.
    pub cosine: Option<f32>,
}
/// Tuning knobs for [`hela_spreading_recall`].
#[derive(Debug, Clone)]
pub struct HelaSpreadParams {
    /// Requested hop depth; clamped to `1..=6` at run time.
    pub spread_depth: u32,
    /// Edge-type filter passed to the store's edge fetch; presumably an
    /// empty list means "all types" — verify against `edges_for_entities`.
    pub edge_types: Vec<EdgeType>,
    /// Hard cap on the number of visited nodes.
    pub max_visited: usize,
    /// Per-stage time budget (anchor ANN, edge fetch, vector fetch);
    /// exceeding it aborts the whole recall with an empty result.
    pub step_budget: Option<std::time::Duration>,
}
impl Default for HelaSpreadParams {
    /// Conservative defaults: two hops, an empty edge-type list, at most
    /// 200 visited nodes, and an 8 ms budget per pipeline stage.
    fn default() -> Self {
        let per_stage_budget = std::time::Duration::from_millis(8);
        Self {
            spread_depth: 2,
            edge_types: Vec::new(),
            max_visited: 200,
            step_budget: Some(per_stage_budget),
        }
    }
}
/// One-shot sentinel: set to the collection name after a vector-dimension
/// mismatch is detected, so subsequent Hela recalls skip the ANN entirely
/// instead of failing the same way on every query.
static HELA_DIM_MISMATCH: OnceLock<String> = OnceLock::new();
/// Cosine similarity between two vectors.
///
/// The dot product runs over the zipped (shorter) length, while each norm
/// is taken over its full slice. The denominator is floored at
/// `f32::EPSILON`, so a zero-norm input yields a finite value, never NaN.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
    dot / (norm(a) * norm(b)).max(f32::EPSILON)
}
/// HL-F5 "Hela" recall: anchor the query on its nearest entity embedding,
/// spread over graph edges keeping the best multiplicative path weight per
/// node, then score each reached node as `path_weight * cos(query, node)`.
///
/// Returns at most `limit` facts sorted by descending score. Each pipeline
/// stage (anchor ANN, per-hop edge fetch, batch vector fetch) is checked
/// against `params.step_budget`; going over budget aborts the whole recall
/// with `Ok(vec![])` rather than returning partial results.
///
/// # Errors
/// Propagates embedding-provider and store errors, except a vector
/// dimension mismatch, which is latched in `HELA_DIM_MISMATCH` and turned
/// into a silent empty result for all later calls.
#[tracing::instrument(
    name = "memory.graph.hela_spread",
    skip_all,
    fields(
        depth = params.spread_depth,
        limit,
        anchor_id = tracing::field::Empty,
        visited = tracing::field::Empty,
        scored = tracing::field::Empty,
        fallback = tracing::field::Empty,
    )
)]
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
pub async fn hela_spreading_recall(
    store: &GraphStore,
    embeddings: &EmbeddingStore,
    provider: &zeph_llm::any::AnyProvider,
    query: &str,
    limit: usize,
    params: &HelaSpreadParams,
    hebbian_enabled: bool,
    hebbian_lr: f32,
) -> Result<Vec<HelaFact>, MemoryError> {
    use zeph_llm::LlmProvider as _;
    const ENTITY_COLLECTION: &str = "zeph_graph_entities";
    if limit == 0 {
        return Ok(Vec::new());
    }
    // Short-circuit if a dimension mismatch was latched on a previous call.
    if HELA_DIM_MISMATCH.get().map(String::as_str) == Some(ENTITY_COLLECTION) {
        tracing::debug!("hela: dim mismatch previously detected for collection, skipping");
        return Ok(Vec::new());
    }
    let q_vec = provider.embed(query).await?;
    let t_anchor = Instant::now();
    // Stage 1: ANN search for the single nearest entity (the anchor).
    let anchor_results = match embeddings
        .search_collection(ENTITY_COLLECTION, &q_vec, 1, None)
        .await
    {
        Ok(r) => r,
        Err(e) => {
            // Heuristic string match on the error: a dimension mismatch is
            // permanent for this collection, so latch it and go quiet.
            let msg = e.to_string();
            if msg.contains("wrong vector dimension")
                || msg.contains("InvalidArgument")
                || msg.contains("dimension")
            {
                // `set` may race with another task; losing the race is fine.
                let _ = HELA_DIM_MISMATCH.set(ENTITY_COLLECTION.to_owned());
                tracing::warn!(
                    collection = ENTITY_COLLECTION,
                    error = %e,
                    "hela: vector dimension mismatch — HL-F5 disabled for this collection"
                );
                return Ok(Vec::new());
            }
            return Err(e);
        }
    };
    if params.step_budget.is_some_and(|b| t_anchor.elapsed() > b) {
        tracing::warn!(
            elapsed_ms = t_anchor.elapsed().as_millis(),
            "hela: anchor ANN over budget"
        );
        return Ok(Vec::new());
    }
    let Some(anchor_point) = anchor_results.first() else {
        tracing::debug!("hela: no anchor found, returning empty");
        return Ok(Vec::new());
    };
    let Some(anchor_entity_id) = anchor_point
        .payload
        .get("entity_id")
        .and_then(serde_json::Value::as_i64)
    else {
        tracing::warn!("hela: anchor point missing entity_id payload");
        return Ok(Vec::new());
    };
    let anchor_cosine = anchor_point.score;
    tracing::Span::current().record("anchor_id", anchor_entity_id);
    tracing::debug!(anchor_entity_id, anchor_cosine, "hela: anchor resolved");
    let spread_depth = params.spread_depth.clamp(1, 6);
    // visited: entity_id -> (depth, best path weight so far, incoming edge id).
    // The anchor has path weight 1.0 and no incoming edge.
    let mut visited: HashMap<i64, (u32, f32, Option<i64>)> = HashMap::new();
    visited.insert(anchor_entity_id, (0, 1.0, None));
    let mut edge_cache: HashMap<i64, Edge> = HashMap::new();
    let mut frontier: Vec<i64> = vec![anchor_entity_id];
    // Stage 2: breadth-first spread, relaxing the best path weight per node.
    for hop in 0..spread_depth {
        if frontier.is_empty() {
            break;
        }
        tracing::debug!(hop, frontier_size = frontier.len(), "hela: starting hop");
        let t_step = Instant::now();
        let edges = store
            .edges_for_entities(&frontier, &params.edge_types)
            .await?;
        if params.step_budget.is_some_and(|b| t_step.elapsed() > b) {
            tracing::warn!(
                hop,
                elapsed_ms = t_step.elapsed().as_millis(),
                "hela: edge-fetch over budget"
            );
            return Ok(Vec::new());
        }
        let mut next_frontier: Vec<i64> = Vec::new();
        for edge in &edges {
            edge_cache.entry(edge.id).or_insert_with(|| edge.clone());
            // Edges are undirected for traversal: follow from whichever
            // endpoint is in the frontier.
            for &src_id in &frontier {
                let neighbor = if edge.source_entity_id == src_id {
                    edge.target_entity_id
                } else if edge.target_entity_id == src_id {
                    edge.source_entity_id
                } else {
                    continue;
                };
                let parent_pw = visited.get(&src_id).map_or(1.0, |&(_, pw, _)| pw);
                // Path weight is multiplicative along the path.
                let new_pw = parent_pw * edge.weight;
                let entry = visited
                    .entry(neighbor)
                    .or_insert((hop + 1, 0.0_f32, Some(edge.id)));
                // Relax: keep the strictly better weight, or the same
                // weight at a strictly shallower depth.
                if new_pw > entry.1
                    || ((new_pw - entry.1).abs() < f32::EPSILON && hop + 1 < entry.0)
                {
                    *entry = (hop + 1, new_pw, Some(edge.id));
                    if !next_frontier.contains(&neighbor) {
                        next_frontier.push(neighbor);
                    }
                }
                if visited.len() >= params.max_visited {
                    break;
                }
            }
            if visited.len() >= params.max_visited {
                break;
            }
        }
        tracing::debug!(
            hop,
            edges_fetched = edges.len(),
            visited = visited.len(),
            next_frontier = next_frontier.len(),
            "hela: hop complete"
        );
        frontier = next_frontier;
        if visited.len() >= params.max_visited {
            break;
        }
    }
    // Only the anchor was visited: the node is isolated in the graph, so
    // fall back to a single synthetic fact carrying the raw ANN score.
    if visited.len() == 1 {
        tracing::Span::current().record("fallback", true);
        tracing::debug!(
            anchor_entity_id,
            anchor_cosine,
            "hela: anchor isolated, falling back to pure ANN"
        );
        let fact = HelaFact {
            edge: Edge::synthetic_anchor(anchor_entity_id),
            score: anchor_cosine,
            depth: 0,
            path_weight: 1.0,
            cosine: Some(anchor_cosine.clamp(0.0, 1.0)),
        };
        return Ok(vec![fact]);
    }
    // Stage 3: batch-fetch embeddings for every visited entity.
    let entity_ids: Vec<i64> = visited.keys().copied().collect();
    let point_id_map = store.qdrant_point_ids_for_entities(&entity_ids).await?;
    let point_ids: Vec<String> = point_id_map.values().cloned().collect();
    let t_vec = Instant::now();
    let vec_map = embeddings
        .get_vectors_from_collection(ENTITY_COLLECTION, &point_ids)
        .await?;
    if params.step_budget.is_some_and(|b| t_vec.elapsed() > b) {
        tracing::warn!(
            elapsed_ms = t_vec.elapsed().as_millis(),
            "hela: vectors-batch over budget"
        );
        return Ok(Vec::new());
    }
    // Stage 4: score each reached node. Nodes missing an edge, point id,
    // vector, or with a mismatched dimension are silently skipped.
    let mut facts: Vec<HelaFact> = Vec::with_capacity(visited.len().saturating_sub(1));
    for (&entity_id, &(depth, path_weight, edge_id_opt)) in &visited {
        if entity_id == anchor_entity_id {
            continue;
        }
        let Some(edge_id) = edge_id_opt else {
            continue;
        };
        let Some(point_id) = point_id_map.get(&entity_id) else {
            continue;
        };
        let Some(node_vec) = vec_map.get(point_id) else {
            continue;
        };
        if node_vec.len() != q_vec.len() {
            continue;
        }
        // Negative cosine is clamped to zero so anti-correlated nodes
        // contribute a zero score instead of a negative one.
        let cosine_clamped = cosine(&q_vec, node_vec).max(0.0);
        let fact_score = path_weight * cosine_clamped;
        let Some(edge) = edge_cache.get(&edge_id).cloned() else {
            continue;
        };
        facts.push(HelaFact {
            edge,
            score: fact_score,
            depth,
            path_weight,
            cosine: Some(cosine_clamped),
        });
    }
    facts.sort_by(|a, b| b.score.total_cmp(&a.score));
    facts.truncate(limit);
    // Stage 5 (optional): Hebbian reinforcement for the returned edges.
    // Edge id 0 marks the synthetic anchor edge and is excluded.
    if hebbian_enabled {
        let edge_ids: Vec<i64> = facts
            .iter()
            .map(|f| f.edge.id)
            .filter(|&id| id != 0) .collect();
        // Let-chain (`if … && let`): reinforcement failure is logged, not fatal.
        if !edge_ids.is_empty()
            && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
        {
            tracing::warn!(error = %e, "hela: hebbian increment failed");
        }
    }
    tracing::Span::current().record("visited", visited.len());
    tracing::Span::current().record("scored", facts.len());
    Ok(facts)
}
/// Spreading-activation engine parameterized by [`SpreadingActivationParams`].
pub struct SpreadingActivation {
    // Immutable configuration used by `spread` and `recency_weight`.
    params: SpreadingActivationParams,
}
impl SpreadingActivation {
    /// Builds an engine with the given parameters.
    #[must_use]
    pub fn new(params: SpreadingActivationParams) -> Self {
        Self { params }
    }
    /// Spreads activation from `seeds` (entity id → initial score) over the
    /// graph, up to `params.max_hops` hops, restricted to `edge_types`.
    ///
    /// Per hop, each active node pushes
    /// `score * decay_lambda * evolved_weight * recency * type_weight`
    /// to its neighbors; contributions accumulate and are capped at 1.0.
    /// Nodes already at or above `inhibition_threshold` receive no further
    /// input, and the map is pruned to `max_activated_nodes` keeping the
    /// highest scores.
    ///
    /// Returns the activated nodes (sorted by descending activation) and
    /// the facts whose both endpoints are activated above threshold.
    ///
    /// # Errors
    /// Propagates store errors from the per-hop edge fetch.
    #[allow(clippy::too_many_lines)]
    pub async fn spread(
        &self,
        store: &GraphStore,
        seeds: HashMap<i64, f32>,
        edge_types: &[EdgeType],
    ) -> Result<(Vec<ActivatedNode>, Vec<ActivatedFact>), MemoryError> {
        if seeds.is_empty() {
            return Ok((Vec::new(), Vec::new()));
        }
        // Wall-clock "now" for recency weighting; clock failure → epoch 0.
        let now_secs: i64 = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |d| d.as_secs().cast_signed());
        // activation: entity id -> (score, depth of first activation).
        let mut activation: HashMap<i64, (f32, u32)> = HashMap::new();
        let mut seed_count = 0usize;
        for (entity_id, match_score) in &seeds {
            // Seeds below the activation threshold never enter the map.
            if *match_score < self.params.activation_threshold {
                tracing::debug!(
                    entity_id,
                    score = match_score,
                    threshold = self.params.activation_threshold,
                    "spreading activation: seed below threshold, skipping"
                );
                continue;
            }
            activation.insert(*entity_id, (*match_score, 0));
            seed_count += 1;
        }
        tracing::debug!(
            seeds = seed_count,
            "spreading activation: initialized seeds"
        );
        let mut activated_facts: Vec<ActivatedFact> = Vec::new();
        for hop in 0..self.params.max_hops {
            // Snapshot the currently-active nodes for this hop.
            let active_nodes: Vec<(i64, f32)> = activation
                .iter()
                .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
                .map(|(&id, &(score, _))| (id, score))
                .collect();
            if active_nodes.is_empty() {
                break;
            }
            let node_ids: Vec<i64> = active_nodes.iter().map(|(id, _)| *id).collect();
            let edges = store.edges_for_entities(&node_ids, edge_types).await?;
            let edge_count = edges.len();
            // New contributions accumulate here and merge after the hop, so
            // propagation within a hop only sees the previous hop's scores.
            let mut next_activation: HashMap<i64, (f32, u32)> = HashMap::new();
            for edge in &edges {
                for &(active_id, node_score) in &active_nodes {
                    // Traverse edges in both directions.
                    let neighbor = if edge.source_entity_id == active_id {
                        edge.target_entity_id
                    } else if edge.target_entity_id == active_id {
                        edge.source_entity_id
                    } else {
                        continue;
                    };
                    let current_score = activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
                    let next_score = next_activation.get(&neighbor).map_or(0.0_f32, |&(s, _)| s);
                    // Inhibition: saturated nodes take no further input.
                    if current_score >= self.params.inhibition_threshold
                        || next_score >= self.params.inhibition_threshold
                    {
                        continue;
                    }
                    let recency = self.recency_weight(&edge.valid_from, now_secs);
                    let edge_weight = evolved_weight(edge.retrieval_count, edge.confidence);
                    let type_w = edge_type_weight(edge.edge_type);
                    let spread_value =
                        node_score * self.params.decay_lambda * edge_weight * recency * type_w;
                    // Sub-threshold contributions are dropped entirely.
                    if spread_value < self.params.activation_threshold {
                        continue;
                    }
                    let depth_at_max = hop + 1;
                    let entry = next_activation
                        .entry(neighbor)
                        .or_insert((0.0, depth_at_max));
                    // Accumulate contributions, saturating at 1.0.
                    let new_score = (entry.0 + spread_value).min(1.0);
                    if new_score > entry.0 {
                        entry.0 = new_score;
                        entry.1 = depth_at_max;
                    }
                }
            }
            // Merge the hop's contributions; a node keeps its best score.
            for (node_id, (new_score, new_depth)) in next_activation {
                let entry = activation.entry(node_id).or_insert((0.0, new_depth));
                if new_score > entry.0 {
                    entry.0 = new_score;
                    entry.1 = new_depth;
                }
            }
            // Prune to the cap, keeping the highest-scored nodes.
            let pruned_count = if activation.len() > self.params.max_activated_nodes {
                let before = activation.len();
                let mut entries: Vec<(i64, (f32, u32))> = activation.drain().collect();
                entries.sort_by(|(_, (a, _)), (_, (b, _))| b.total_cmp(a));
                entries.truncate(self.params.max_activated_nodes);
                activation = entries.into_iter().collect();
                before - self.params.max_activated_nodes
            } else {
                0
            };
            tracing::debug!(
                hop,
                active_nodes = active_nodes.len(),
                edges_fetched = edge_count,
                after_merge = activation.len(),
                pruned = pruned_count,
                "spreading activation: hop complete"
            );
            // Collect facts whose both endpoints are activated.
            // NOTE(review): the same edge can be re-fetched and re-pushed on
            // later hops while both endpoints stay active, so
            // `activated_facts` may contain duplicates — confirm whether a
            // downstream consumer dedupes, or dedupe by edge id here.
            for edge in edges {
                let src_score = activation
                    .get(&edge.source_entity_id)
                    .map_or(0.0, |&(s, _)| s);
                let tgt_score = activation
                    .get(&edge.target_entity_id)
                    .map_or(0.0, |&(s, _)| s);
                if src_score >= self.params.activation_threshold
                    && tgt_score >= self.params.activation_threshold
                {
                    let activation_score = src_score.max(tgt_score);
                    activated_facts.push(ActivatedFact {
                        edge,
                        activation_score,
                    });
                }
            }
        }
        let mut result: Vec<ActivatedNode> = activation
            .into_iter()
            .filter(|(_, (score, _))| *score >= self.params.activation_threshold)
            .map(|(entity_id, (activation, depth))| ActivatedNode {
                entity_id,
                activation,
                depth,
            })
            .collect();
        result.sort_by(|a, b| b.activation.total_cmp(&a.activation));
        tracing::info!(
            activated = result.len(),
            facts = activated_facts.len(),
            "spreading activation: complete"
        );
        Ok((result, activated_facts))
    }
    /// Hyperbolic recency weight for an edge: `1 / (1 + age_days * rate)`.
    ///
    /// Returns 1.0 when temporal decay is disabled or when `valid_from`
    /// cannot be parsed; future timestamps clamp to age 0.
    #[allow(clippy::cast_precision_loss)]
    fn recency_weight(&self, valid_from: &str, now_secs: i64) -> f32 {
        if self.params.temporal_decay_rate <= 0.0 {
            return 1.0;
        }
        let Some(valid_from_secs) = parse_sqlite_datetime_to_unix(valid_from) else {
            return 1.0;
        };
        let age_secs = (now_secs - valid_from_secs).max(0);
        let age_days = age_secs as f64 / 86_400.0;
        let weight = 1.0_f64 / (1.0 + age_days * self.params.temporal_decay_rate);
        #[allow(clippy::cast_possible_truncation)]
        let w = weight as f32;
        w
    }
}
/// Parses a SQLite `YYYY-MM-DD HH:MM:SS` prefix into Unix seconds (UTC).
///
/// Uses Howard Hinnant's days-from-civil algorithm: the year is shifted to
/// start in March so the leap day falls at the end of the shifted year.
/// Returns `None` for inputs that are too short or fail to parse.
///
/// Fix: field extraction uses `str::get` instead of direct `s[a..b]`
/// indexing — the `len() >= 19` guard counts *bytes*, so a multi-byte
/// UTF-8 input could previously make the slice panic on a non-character
/// boundary; now such input yields `None`.
#[must_use]
fn parse_sqlite_datetime_to_unix(s: &str) -> Option<i64> {
    if s.len() < 19 {
        return None;
    }
    let year: i64 = s.get(0..4)?.parse().ok()?;
    let month: i64 = s.get(5..7)?.parse().ok()?;
    let day: i64 = s.get(8..10)?.parse().ok()?;
    let hour: i64 = s.get(11..13)?.parse().ok()?;
    let min: i64 = s.get(14..16)?.parse().ok()?;
    let sec: i64 = s.get(17..19)?.parse().ok()?;
    // Shift to a March-based year: Jan/Feb belong to the previous year.
    let (y, m) = if month <= 2 {
        (year - 1, month + 9)
    } else {
        (year, month - 3)
    };
    // 400-year eras; div_euclid keeps the math correct for years < 0.
    let era = y.div_euclid(400);
    let yoe = y - era * 400;
    // Day of the shifted year; (153 * m + 2) / 5 encodes the month lengths.
    let doy = (153 * m + 2) / 5 + day - 1;
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    // 719_468 days from 0000-03-01 to 1970-01-01.
    let days = era * 146_097 + doe - 719_468;
    Some(days * 86_400 + hour * 3_600 + min * 60 + sec)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    // Fresh in-memory SQLite-backed graph store per test.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }
    // Baseline parameters shared by the spreading-activation tests.
    fn default_params() -> SpreadingActivationParams {
        SpreadingActivationParams {
            decay_lambda: 0.85,
            max_hops: 3,
            activation_threshold: 0.1,
            inhibition_threshold: 0.8,
            max_activated_nodes: 50,
            temporal_decay_rate: 0.0,
            seed_structural_weight: 0.4,
            seed_community_cap: 3,
        }
    }
    // --- SpreadingActivation::spread ---
    #[tokio::test]
    async fn spread_empty_graph_no_edges_no_facts() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(1_i64, 1.0_f32)]);
        let (nodes, facts) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1, "seed must be in activated nodes");
        assert_eq!(nodes[0].entity_id, 1);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
        assert!(
            facts.is_empty(),
            "expected no activated facts on empty graph"
        );
    }
    #[tokio::test]
    async fn spread_empty_seeds_returns_empty() {
        let store = setup_store().await;
        let sa = SpreadingActivation::new(default_params());
        let (nodes, facts) = sa.spread(&store, HashMap::new(), &[]).await.unwrap();
        assert!(nodes.is_empty());
        assert!(facts.is_empty());
    }
    #[tokio::test]
    async fn spread_single_seed_no_edges_returns_seed() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let sa = SpreadingActivation::new(default_params());
        let seeds = HashMap::from([(alice, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].entity_id, alice);
        assert_eq!(nodes[0].depth, 0);
        assert!((nodes[0].activation - 1.0).abs() < 1e-6);
    }
    // A -> B -> C chain: activation reaches all nodes, decaying per hop.
    #[tokio::test]
    async fn spread_linear_chain_all_activated_with_decay() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();
        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A (seed) must be activated");
        assert!(ids.contains(&b), "B (hop 1) must be activated");
        assert!(ids.contains(&c), "C (hop 2) must be activated");
        let score_a = nodes.iter().find(|n| n.entity_id == a).unwrap().activation;
        let score_b = nodes.iter().find(|n| n.entity_id == b).unwrap().activation;
        let score_c = nodes.iter().find(|n| n.entity_id == c).unwrap().activation;
        assert!(
            score_a > score_b,
            "seed A should have higher activation than hop-1 B"
        );
        assert!(
            score_b > score_c,
            "hop-1 B should have higher activation than hop-2 C"
        );
    }
    // With max_hops=1 the chain's second hop must stay unreached.
    #[tokio::test]
    async fn spread_linear_chain_max_hops_limits_reach() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "A knows B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "B knows C", 1.0, None)
            .await
            .unwrap();
        let mut cfg = default_params();
        cfg.max_hops = 1;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&a), "A must be activated (seed)");
        assert!(ids.contains(&b), "B must be activated (hop 1)");
        assert!(!ids.contains(&c), "C must NOT be activated with max_hops=1");
    }
    // Diamond A->{B,C}->D: D is reached at depth 2 via converging paths.
    #[tokio::test]
    async fn spread_diamond_graph_convergence() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("B", "B", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("C", "C", EntityType::Person, None)
            .await
            .unwrap();
        let d = store
            .upsert_entity("D", "D", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "rel", "A-B", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(a, c, "rel", "A-C", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(b, d, "rel", "B-D", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(c, d, "rel", "C-D", 1.0, None)
            .await
            .unwrap();
        let mut cfg = default_params();
        cfg.max_hops = 3;
        cfg.decay_lambda = 0.9;
        // High inhibition so converging input is not suppressed.
        cfg.inhibition_threshold = 0.95; let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(ids.contains(&d), "D must be activated via diamond paths");
        let node_d = nodes.iter().find(|n| n.entity_id == d).unwrap();
        assert_eq!(node_d.depth, 2, "D should be at depth 2");
    }
    // Hub with bidirectional leaves: inhibition must keep scores <= 1.0.
    #[tokio::test]
    async fn spread_inhibition_prevents_runaway() {
        let store = setup_store().await;
        let hub = store
            .upsert_entity("Hub", "Hub", EntityType::Concept, None)
            .await
            .unwrap();
        for i in 0..5 {
            let leaf = store
                .upsert_entity(
                    &format!("Leaf{i}"),
                    &format!("Leaf{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(hub, leaf, "has", &format!("Hub has Leaf{i}"), 1.0, None)
                .await
                .unwrap();
            store
                .insert_edge(
                    leaf,
                    hub,
                    "part_of",
                    &format!("Leaf{i} part_of Hub"),
                    1.0,
                    None,
                )
                .await
                .unwrap();
        }
        let mut cfg = default_params();
        cfg.inhibition_threshold = 0.8;
        cfg.max_hops = 3;
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(hub, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        let hub_node = nodes.iter().find(|n| n.entity_id == hub);
        assert!(hub_node.is_some(), "hub must be in results");
        assert!(
            hub_node.unwrap().activation <= 1.0,
            "activation must not exceed 1.0"
        );
    }
    #[tokio::test]
    async fn spread_max_activated_nodes_cap_enforced() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        for i in 0..20 {
            let leaf = store
                .upsert_entity(
                    &format!("Node{i}"),
                    &format!("Node{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(root, leaf, "has", &format!("Root has Node{i}"), 0.9, None)
                .await
                .unwrap();
        }
        let max_nodes = 5;
        let cfg = SpreadingActivationParams {
            max_activated_nodes: max_nodes,
            max_hops: 2,
            ..default_params()
        };
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(root, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        assert!(
            nodes.len() <= max_nodes,
            "activation must be capped at {max_nodes} nodes, got {}",
            nodes.len()
        );
    }
    // Old valid_from (epoch) vs fresh edge: temporal decay lowers the old one.
    #[tokio::test]
    async fn spread_temporal_decay_recency_effect() {
        let store = setup_store().await;
        let src = store
            .upsert_entity("Src", "Src", EntityType::Person, None)
            .await
            .unwrap();
        let recent = store
            .upsert_entity("Recent", "Recent", EntityType::Tool, None)
            .await
            .unwrap();
        let old = store
            .upsert_entity("Old", "Old", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(src, recent, "uses", "Src uses Recent", 1.0, None)
            .await
            .unwrap();
        // Raw insert so valid_from can be forced to the Unix epoch.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
VALUES (?1, ?2, 'uses', 'Src uses Old', 1.0, '1970-01-01 00:00:00')"),
        )
        .bind(src)
        .bind(old)
        .execute(store.pool())
        .await
        .unwrap();
        let mut cfg = default_params();
        cfg.max_hops = 2;
        let sa = SpreadingActivation::new(SpreadingActivationParams {
            temporal_decay_rate: 0.5,
            ..cfg
        });
        let seeds = HashMap::from([(src, 1.0_f32)]);
        let (nodes, _) = sa.spread(&store, seeds, &[]).await.unwrap();
        let score_recent = nodes
            .iter()
            .find(|n| n.entity_id == recent)
            .map_or(0.0, |n| n.activation);
        let score_old = nodes
            .iter()
            .find(|n| n.entity_id == old)
            .map_or(0.0, |n| n.activation);
        assert!(
            score_recent > score_old,
            "recent edge ({score_recent}) must produce higher activation than old edge ({score_old})"
        );
    }
    #[tokio::test]
    async fn spread_edge_type_filter_excludes_other_types() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("A", "A", EntityType::Person, None)
            .await
            .unwrap();
        let b_semantic = store
            .upsert_entity("BSemantic", "BSemantic", EntityType::Tool, None)
            .await
            .unwrap();
        let c_causal = store
            .upsert_entity("CCausal", "CCausal", EntityType::Concept, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b_semantic, "uses", "A uses BSemantic", 1.0, None)
            .await
            .unwrap();
        // Raw insert so edge_type can be forced to 'causal'.
        zeph_db::query(
            sql!("INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, edge_type)
VALUES (?1, ?2, 'caused', 'A caused CCausal', 1.0, datetime('now'), 'causal')"),
        )
        .bind(a)
        .bind(c_causal)
        .execute(store.pool())
        .await
        .unwrap();
        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let seeds = HashMap::from([(a, 1.0_f32)]);
        let (nodes, _) = sa
            .spread(&store, seeds, &[EdgeType::Semantic])
            .await
            .unwrap();
        let ids: Vec<i64> = nodes.iter().map(|n| n.entity_id).collect();
        assert!(
            ids.contains(&b_semantic),
            "BSemantic must be activated via semantic edge"
        );
        assert!(
            !ids.contains(&c_causal),
            "CCausal must NOT be activated when filtering to semantic only"
        );
    }
    #[tokio::test]
    async fn spread_large_seed_list() {
        let store = setup_store().await;
        let mut seeds = HashMap::new();
        for i in 0..100i64 {
            let id = store
                .upsert_entity(
                    &format!("Entity{i}"),
                    &format!("entity{i}"),
                    EntityType::Concept,
                    None,
                )
                .await
                .unwrap();
            seeds.insert(id, 1.0_f32);
        }
        let cfg = default_params();
        let sa = SpreadingActivation::new(cfg);
        let result = sa.spread(&store, seeds, &[]).await;
        assert!(
            result.is_ok(),
            "large seed list must not error: {:?}",
            result.err()
        );
    }
    // --- cosine helper ---
    #[test]
    fn hela_cosine_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!(
            (cosine(&v, &v) - 1.0).abs() < 1e-6,
            "identical vectors → cosine 1.0"
        );
    }
    #[test]
    fn hela_cosine_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!(
            cosine(&a, &b).abs() < 1e-6,
            "orthogonal vectors → cosine 0.0"
        );
    }
    #[test]
    fn hela_cosine_anti_correlated() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![-1.0_f32, 0.0];
        assert!(
            cosine(&a, &b) < 0.0,
            "anti-correlated vectors → negative cosine"
        );
    }
    #[test]
    fn hela_cosine_zero_vector_no_panic() {
        let a = vec![0.0_f32, 0.0];
        let b = vec![1.0_f32, 0.0];
        let result = cosine(&a, &b);
        assert!(
            result.is_finite(),
            "zero-norm vector must yield finite cosine"
        );
    }
    // --- Hela parameters and scoring invariants ---
    #[test]
    fn hela_spread_params_default_depth_is_two() {
        let p = HelaSpreadParams::default();
        assert_eq!(p.spread_depth, 2);
        assert!(p.step_budget.is_some());
        assert!(p.edge_types.is_empty());
        assert_eq!(p.max_visited, 200);
    }
    #[test]
    fn hela_synthetic_anchor_edge_id_is_zero() {
        let edge = Edge::synthetic_anchor(42);
        assert_eq!(
            edge.id, 0,
            "synthetic anchor must have id = 0 to be excluded from Hebbian"
        );
        assert_eq!(edge.source_entity_id, 42);
        assert_eq!(edge.target_entity_id, 42);
    }
    #[test]
    fn hela_negative_cosine_clamped_to_zero_in_score() {
        let anti = vec![-1.0_f32, 0.0];
        let query = vec![1.0_f32, 0.0];
        let cosine_raw = cosine(&query, &anti);
        assert!(cosine_raw < 0.0);
        let clamped = cosine_raw.max(0.0);
        let fact_score = 0.9_f32 * clamped;
        assert!(
            fact_score < f32::EPSILON,
            "anti-correlated score must be 0.0"
        );
    }
    #[test]
    fn hela_path_weight_multiplicative() {
        let w1 = 0.8_f32;
        let w2 = 0.5_f32;
        let expected = w1 * w2;
        assert!((expected - 0.4).abs() < 1e-6);
    }
    #[test]
    fn hela_max_path_weight_on_multipath() {
        let pw_a = 0.9_f32; let pw_b = 0.3_f32; let kept = pw_a.max(pw_b);
        assert!(
            (kept - 0.9).abs() < 1e-6,
            "multi-path resolution must keep maximum path_weight"
        );
    }
    #[test]
    fn hela_fact_score_formula() {
        let path_weight = 0.8_f32;
        let cosine_clamped = 0.75_f32;
        let expected = path_weight * cosine_clamped;
        assert!((expected - 0.6).abs() < 1e-5);
    }
}