use std::collections::HashMap;
use std::sync::Arc;
use zeph_common::memory::EdgeType;
use crate::graph::GraphStore;
/// Maps candidate entities to their causal distance from a goal entity.
///
/// Distances are the depths reported by [`GraphStore::bfs_typed`] when traversing only
/// [`EdgeType::Causal`] edges from the goal, bounded by `max_depth`. Candidates that the
/// traversal does not return are assigned `neutral_distance`.
///
/// The traversal result for the most recently requested goal is cached and reused for as
/// long as the same goal is requested; call [`CausalDistanceComputer::invalidate_cache`]
/// to force a fresh traversal (for example after the graph has been modified).
pub struct CausalDistanceComputer {
    /// Graph backend queried for causal edges.
    graph_store: Arc<GraphStore>,
    /// Depth limit passed to `bfs_typed`.
    max_depth: u32,
    /// Distance reported for candidates absent from the traversal's depth map.
    neutral_distance: u32,
    /// `(goal_id, depth_map)` for the last goal traversed, if any.
    cache: Option<(i64, HashMap<i64, u32>)>,
}

impl CausalDistanceComputer {
    /// Creates a computer with an empty cache.
    ///
    /// * `graph_store` — graph backend used for the causal traversal.
    /// * `max_depth` — depth limit forwarded to `bfs_typed`.
    /// * `neutral_distance` — distance assigned to candidates the traversal did not reach.
    #[must_use]
    pub fn new(graph_store: Arc<GraphStore>, max_depth: u32, neutral_distance: u32) -> Self {
        Self {
            graph_store,
            max_depth,
            neutral_distance,
            cache: None,
        }
    }

    /// Returns a causal distance for every id in `entity_ids`, measured from `goal_entity_id`.
    ///
    /// With no goal (`None`) an empty map is returned without querying the graph store.
    /// Otherwise each candidate maps to its depth in the goal's causal traversal, or to
    /// `neutral_distance` when the traversal did not return it. Duplicate ids in
    /// `entity_ids` collapse into a single map entry.
    ///
    /// # Errors
    ///
    /// Propagates any [`crate::error::MemoryError`] produced by the graph traversal.
    // `goal_entity_id` needs an explicit value here: a bare name inside `fields(...)` is
    // recorded as `Empty` and also suppresses the automatic capture of the argument.
    #[tracing::instrument(
        name = "memory.five_signal.causal_distance.compute",
        skip(self, entity_ids),
        fields(goal_entity_id = ?goal_entity_id, candidate_count = entity_ids.len())
    )]
    pub async fn compute(
        &mut self,
        goal_entity_id: Option<i64>,
        entity_ids: &[i64],
    ) -> Result<HashMap<i64, u32>, crate::error::MemoryError> {
        tracing::debug!("five_signal: computing causal distances");
        let Some(goal_id) = goal_entity_id else {
            return Ok(HashMap::new());
        };
        // Copy out before `ensure_cache` takes the mutable borrow of `self`.
        let neutral = self.neutral_distance;
        let depth_map = self.ensure_cache(goal_id).await?;
        let result = entity_ids
            .iter()
            .map(|&eid| {
                let dist = depth_map.get(&eid).copied().unwrap_or(neutral);
                (eid, dist)
            })
            .collect();
        Ok(result)
    }

    /// Converts a causal distance into a score in `(0, 1]`.
    ///
    /// Distances `0` and `1` both score `1.0`; any larger distance `d` scores `1 / d`.
    /// There is no lower clamp, so the score keeps shrinking as the distance grows.
    #[must_use]
    #[inline]
    pub fn distance_to_score(distance: u32) -> f64 {
        match distance {
            0 => 1.0,
            // For d >= 1, 1/d is already <= 1.0, so no upper clamp is needed.
            d => 1.0_f64 / f64::from(d),
        }
    }

    /// Discards the cached traversal so the next `compute` re-runs `bfs_typed`.
    pub fn invalidate_cache(&mut self) {
        self.cache = None;
    }

    /// Returns the depth map for `goal_id`, running the causal traversal only on a cache miss.
    ///
    /// The cache holds a single goal: requesting a different goal replaces the entry.
    ///
    /// # Errors
    ///
    /// Propagates any error from `bfs_typed`; the cache is left unchanged in that case.
    async fn ensure_cache(
        &mut self,
        goal_id: i64,
    ) -> Result<&HashMap<i64, u32>, crate::error::MemoryError> {
        match &mut self.cache {
            Some((cached_goal, depth_map)) if *cached_goal == goal_id => Ok(&*depth_map),
            slot => {
                // `slot` borrows only `self.cache`; the disjoint fields below stay usable.
                let (_, _, depth_map) = self
                    .graph_store
                    .bfs_typed(goal_id, self.max_depth, &[EdgeType::Causal])
                    .await?;
                Ok(&slot.insert((goal_id, depth_map)).1)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact score for a few representative distances.
    #[test]
    fn distance_to_score_values() {
        assert!((CausalDistanceComputer::distance_to_score(0) - 1.0).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(1) - 1.0).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(2) - 0.5).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(5) - 0.2).abs() < 1e-9);
    }

    /// `distance_to_score` has no lower clamp: past a depth limit (10 here) the score keeps
    /// decaying as `1 / distance`, so deeper nodes always score lower.
    #[test]
    fn distance_to_score_keeps_decaying_beyond_max_depth() {
        let score_at_limit = CausalDistanceComputer::distance_to_score(10);
        let score_beyond = CausalDistanceComputer::distance_to_score(20);
        assert!(score_at_limit <= 1.0);
        assert!(score_beyond <= score_at_limit, "deeper nodes score lower");
        assert!((score_at_limit - 0.1).abs() < 1e-9);
        assert!((score_beyond - 0.05).abs() < 1e-9);
    }

    /// Unreached candidates receive `neutral_distance`, so their score is `1 / neutral`.
    #[test]
    fn neutral_distance_determines_unreachable_score() {
        let neutral = 5_u32;
        let score = CausalDistanceComputer::distance_to_score(neutral);
        assert!((score - 0.2).abs() < 1e-9);
    }

    /// A `None` goal short-circuits before any graph query and yields an empty map.
    #[tokio::test]
    async fn compute_none_goal_returns_empty_map() {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        let graph_store = Arc::new(GraphStore::new(pool));
        let mut computer = CausalDistanceComputer::new(graph_store, 10, 5);
        let result = computer
            .compute(None, &[1, 2, 3])
            .await
            .expect("None goal must not fail");
        assert!(
            result.is_empty(),
            "goal_entity_id=None must return empty map, got: {result:?}"
        );
    }
}