use super::*;
use async_trait::async_trait;
#[derive(Debug, Clone)]
pub struct ForgettingResult {
pub records_decayed: usize,
pub records_archived: usize,
pub records_purged: usize,
pub edges_pruned: usize,
pub execution_time_ms: f64,
}
#[async_trait]
trait ForgettingGraphMaintenance: Send + Sync {
async fn maintenance_all_edges(&self) -> HirnResult<Vec<crate::graph::GraphEdge>>;
async fn maintenance_update_edge_weight(
&self,
edge_id: crate::graph::EdgeId,
new_weight: f32,
co_retrieval_count: Option<u64>,
) -> HirnResult<()>;
async fn maintenance_remove_edge(&self, edge_id: crate::graph::EdgeId) -> HirnResult<()>;
}
#[async_trait]
impl<T> ForgettingGraphMaintenance for T
where
T: crate::graph_store::GraphStore + ?Sized,
{
async fn maintenance_all_edges(&self) -> HirnResult<Vec<crate::graph::GraphEdge>> {
crate::graph_store::GraphStore::all_edges(self).await
}
async fn maintenance_update_edge_weight(
&self,
edge_id: crate::graph::EdgeId,
new_weight: f32,
co_retrieval_count: Option<u64>,
) -> HirnResult<()> {
crate::graph_store::GraphStore::update_edge_weight(
self,
edge_id,
new_weight,
co_retrieval_count,
)
.await
}
async fn maintenance_remove_edge(&self, edge_id: crate::graph::EdgeId) -> HirnResult<()> {
crate::graph_store::GraphStore::remove_edge(self, edge_id).await
}
}
pub fn retention_score(hours_since_access: f64, stability: f32, rehearsal_count: u64) -> f32 {
let effective_stability = stability as f64 * (1.0 + 0.5 * (rehearsal_count.max(1) as f64).ln());
let effective_stability = effective_stability.max(f64::EPSILON);
(-hours_since_access / effective_stability).exp() as f32
}
async fn decay_and_prune_hebbian_edges(
graph: &(impl ForgettingGraphMaintenance + ?Sized),
decay_lambda: f64,
prune_threshold: f32,
now_dt: chrono::DateTime<chrono::Utc>,
) -> HirnResult<usize> {
let all_edges = graph.maintenance_all_edges().await?;
let hebbian_edges: Vec<_> = all_edges
.into_iter()
.filter(|e| e.co_retrieval_count > 0)
.collect();
let mut pruned = 0;
for edge in hebbian_edges {
let hours_since_update = now_dt
.signed_duration_since(edge.updated_at.as_datetime())
.num_seconds()
.max(0) as f64
/ 3600.0;
let new_weight = if hours_since_update > 0.0 {
let time_decay = (-decay_lambda * hours_since_update).exp() as f32;
(edge.weight * time_decay).max(0.01)
} else {
edge.weight
};
if new_weight < prune_threshold {
graph.maintenance_remove_edge(edge.id).await?;
pruned += 1;
} else if (edge.weight - new_weight).abs() > 0.001 {
graph
.maintenance_update_edge_weight(edge.id, new_weight, None)
.await?;
}
}
Ok(pruned)
}
pub async fn run_forgetting(
db: &HirnDB,
config: &ConsolidationConfig,
) -> HirnResult<ForgettingResult> {
let start = Instant::now();
let hirn_config = db.config();
let decay_lambda = config
.decay_rate_override
.unwrap_or(hirn_config.decay_lambda);
let archive_threshold = hirn_config.archive_threshold;
let purge_threshold = hirn_config.purge_threshold;
let grace_period_hours = 168.0;
let filter = crate::db::EpisodicFilter {
include_archived: true,
..Default::default()
};
let episodes = db.list_episodes(&filter).await?;
let mut records_decayed = 0;
let mut records_archived = 0;
let mut records_purged = 0;
let now = Timestamp::now();
let now_dt = now.as_datetime();
for ep in &episodes {
let hours_since_access = now_dt
.signed_duration_since(ep.last_accessed.as_datetime())
.num_hours() as f64;
let hours_since_creation = now_dt
.signed_duration_since(ep.timestamp.as_datetime())
.num_hours() as f64;
if hours_since_creation < 1.0 {
continue;
}
let retention = retention_score(hours_since_access, ep.stability, ep.access_count);
let effective_importance = ep.importance * retention;
if ep.archived && effective_importance < purge_threshold {
if hours_since_access > grace_period_hours {
db.delete_episode(ep.id).await?;
records_purged += 1;
continue;
}
}
if !ep.archived && effective_importance < archive_threshold {
db.archive_episode(ep.id).await?;
records_archived += 1;
records_decayed += 1;
} else if effective_importance < ep.importance * 0.999 {
records_decayed += 1;
}
}
let now_dt = hirn_core::timestamp::Timestamp::now().as_datetime();
let edges_pruned = decay_and_prune_hebbian_edges(
db.graph_store(),
decay_lambda,
config.edge_prune_threshold,
now_dt,
)
.await?;
let execution_time_ms = start.elapsed().as_secs_f64() * 1000.0;
Ok(ForgettingResult {
records_decayed,
records_archived,
records_purged,
edges_pruned,
execution_time_ms,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::graph_store::GraphStore;
use std::sync::{Arc, Mutex};
use hirn_core::HirnConfig;
use hirn_core::HirnError;
use hirn_core::episodic::EpisodicRecord;
use hirn_core::metadata::Metadata;
use hirn_core::types::{AgentId, EdgeRelation, EventType, Namespace};
use hirn_storage::memory_store::MemoryStore;
struct FakeForgettingGraph {
edges: Vec<crate::graph::GraphEdge>,
fail_remove: Option<crate::graph::EdgeId>,
fail_update: Option<crate::graph::EdgeId>,
removed: Mutex<Vec<crate::graph::EdgeId>>,
updated: Mutex<Vec<crate::graph::EdgeId>>,
}
#[async_trait]
impl ForgettingGraphMaintenance for FakeForgettingGraph {
async fn maintenance_all_edges(&self) -> HirnResult<Vec<crate::graph::GraphEdge>> {
Ok(self.edges.clone())
}
async fn maintenance_update_edge_weight(
&self,
edge_id: crate::graph::EdgeId,
_new_weight: f32,
_co_retrieval_count: Option<u64>,
) -> HirnResult<()> {
if self.fail_update == Some(edge_id) {
return Err(HirnError::Unsupported(format!(
"simulated edge update failure for {edge_id}"
)));
}
self.updated.lock().unwrap().push(edge_id);
Ok(())
}
async fn maintenance_remove_edge(&self, edge_id: crate::graph::EdgeId) -> HirnResult<()> {
if self.fail_remove == Some(edge_id) {
return Err(HirnError::Unsupported(format!(
"simulated edge removal failure for {edge_id}"
)));
}
self.removed.lock().unwrap().push(edge_id);
Ok(())
}
}
fn test_hebbian_edge(weight: f32, updated_at: Timestamp) -> crate::graph::GraphEdge {
crate::graph::GraphEdge {
id: MemoryId::new(),
source: MemoryId::new(),
target: MemoryId::new(),
relation: EdgeRelation::RelatedTo,
weight,
co_retrieval_count: 1,
created_at: updated_at,
updated_at,
valid_from: None,
valid_until: None,
metadata: Metadata::new(),
resolved: false,
namespace: Namespace::default(),
causal: None,
}
}
async fn temp_db() -> (HirnDB, tempfile::TempDir) {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("forgetting-tests");
let config = HirnConfig::builder()
.db_path(&path)
.embedding_dimensions(4)
.working_memory_token_limit(1000)
.build()
.unwrap();
let db = HirnDB::open_with_config(config, Arc::new(MemoryStore::new()))
.await
.unwrap();
(db, dir)
}
#[test]
fn retention_just_accessed_is_one() {
let r = retention_score(0.0, 24.0, 0);
assert!((r - 1.0).abs() < 1e-6);
}
#[test]
fn retention_approaches_zero() {
let r = retention_score(100_000.0, 24.0, 0);
assert!(r < 0.001);
}
#[test]
fn retention_decays_with_time() {
let r1 = retention_score(12.0, 24.0, 0);
let r2 = retention_score(48.0, 24.0, 0);
assert!(r1 > r2, "12h retention {r1} should be > 48h retention {r2}");
}
#[test]
fn rehearsal_slows_decay() {
let r_unrehearsed = retention_score(48.0, 24.0, 0);
let r_rehearsed = retention_score(48.0, 24.0, 10);
assert!(
r_rehearsed > r_unrehearsed,
"rehearsed {r_rehearsed} should retain better than unrehearsed {r_unrehearsed}"
);
}
#[test]
fn higher_stability_slows_decay() {
let r_low = retention_score(48.0, 12.0, 0);
let r_high = retention_score(48.0, 48.0, 0);
assert!(
r_high > r_low,
"high stability {r_high} should retain better than low stability {r_low}"
);
}
#[test]
fn spaced_rehearsal_more_effective() {
let _r_retrieved_5d = retention_score(120.0, 24.0, 1);
let _r_never_3d = retention_score(72.0, 24.0, 0);
let _r_retrieved_5d_grown = retention_score(120.0, 26.4, 1);
let r_10x_5d = retention_score(120.0, 62.2, 10);
let r_0x_3d = retention_score(72.0, 24.0, 0);
assert!(
r_10x_5d > r_0x_3d,
"10x rehearsed 5 days {r_10x_5d} should retain better than unretrieved 3 days {r_0x_3d}"
);
}
#[test]
fn retention_score_formula_matches_rfc() {
let r = retention_score(24.0, 24.0, 5);
let expected_s = 24.0 * (1.0 + 0.5 * 5.0_f64.ln());
let expected_r = (-24.0 / expected_s).exp() as f32;
assert!(
(r - expected_r).abs() < 1e-5,
"got {r}, expected {expected_r}"
);
}
#[tokio::test(flavor = "multi_thread")]
async fn decay_and_prune_fails_on_weight_update_error() {
let edge = test_hebbian_edge(
0.8,
Timestamp::from_datetime(chrono::Utc::now() - chrono::Duration::hours(4)),
);
let graph = FakeForgettingGraph {
fail_remove: None,
fail_update: Some(edge.id),
edges: vec![edge],
removed: Mutex::new(Vec::new()),
updated: Mutex::new(Vec::new()),
};
let error = decay_and_prune_hebbian_edges(&graph, 0.1, 0.0, chrono::Utc::now())
.await
.expect_err("edge weight update failure should abort forgetting maintenance");
assert!(matches!(error, HirnError::Unsupported(_)));
assert!(graph.updated.lock().unwrap().is_empty());
}
#[tokio::test(flavor = "multi_thread")]
async fn decay_and_prune_fails_on_remove_error() {
let edge = test_hebbian_edge(0.05, Timestamp::now());
let graph = FakeForgettingGraph {
fail_remove: Some(edge.id),
fail_update: None,
edges: vec![edge],
removed: Mutex::new(Vec::new()),
updated: Mutex::new(Vec::new()),
};
let error = decay_and_prune_hebbian_edges(&graph, 0.0, 0.1, chrono::Utc::now())
.await
.expect_err("edge removal failure should abort forgetting pruning");
assert!(matches!(error, HirnError::Unsupported(_)));
assert!(graph.removed.lock().unwrap().is_empty());
}
#[tokio::test(flavor = "multi_thread")]
async fn decay_and_prune_removes_weak_edges_from_both_tiers() {
let (db, _dir) = temp_db().await;
let source_id = db
.remember(
EpisodicRecord::builder()
.event_type(EventType::Observation)
.content("source event")
.summary("source event")
.embedding(vec![1.0, 0.0, 0.0, 0.0])
.importance(0.9)
.agent_id(AgentId::new("forgetting-test").unwrap())
.build()
.unwrap(),
)
.await
.unwrap();
let target_id = db
.remember(
EpisodicRecord::builder()
.event_type(EventType::Observation)
.content("target event")
.summary("target event")
.embedding(vec![0.0, 1.0, 0.0, 0.0])
.importance(0.8)
.agent_id(AgentId::new("forgetting-test").unwrap())
.build()
.unwrap(),
)
.await
.unwrap();
let edge_id = db
.cached_graph()
.add_edge(
source_id,
target_id,
EdgeRelation::Causes,
0.1,
Metadata::new(),
)
.await
.unwrap();
db.cached_graph()
.update_edge_weight(edge_id, 0.1, Some(1))
.await
.unwrap();
let pruned = decay_and_prune_hebbian_edges(db.graph_store(), 0.0, 0.2, chrono::Utc::now())
.await
.unwrap();
assert_eq!(pruned, 1);
assert!(
db.cached_graph()
.get_edges(source_id)
.await
.unwrap()
.iter()
.all(|edge| edge.id != edge_id)
);
assert!(
db.persistent_graph()
.get_edges(source_id)
.await
.unwrap()
.iter()
.all(|edge| edge.id != edge_id)
);
}
}