use crate::causal::CausalGraph;
use crate::long_term::LongTermStore;
use crate::short_term::ShortTermBuffer;
use crate::types::{SubstrateTime, TemporalPattern};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Tunable parameters for the short-term to long-term consolidation pass.
///
/// The four `w_*` fields are the weights `compute_salience` applies to its
/// frequency, recency, causal and surprise components before summing them.
#[derive(Debug, Clone)]
pub struct ConsolidationConfig {
    /// Minimum salience (inclusive) a pattern needs for `consolidate` to hand
    /// it to long-term storage.
    pub salience_threshold: f32,
    /// Weight of the access-frequency component.
    pub w_frequency: f32,
    /// Weight of the recency component.
    pub w_recency: f32,
    /// Weight of the causal (out-degree) component.
    pub w_causal: f32,
    /// Weight of the surprise (novelty) component.
    pub w_surprise: f32,
}

impl Default for ConsolidationConfig {
    /// Builds the stock configuration: a threshold of `0.5` and component
    /// weights of `0.3 / 0.2 / 0.3 / 0.2`, which add up to `1.0`.
    fn default() -> Self {
        // Weights in the order frequency, recency, causal, surprise.
        let (w_frequency, w_recency, w_causal, w_surprise) = (0.3, 0.2, 0.3, 0.2);

        ConsolidationConfig {
            salience_threshold: 0.5,
            w_frequency,
            w_recency,
            w_causal,
            w_surprise,
        }
    }
}
/// Scores how worth keeping a pattern is, as a value in `[0.0, 1.0]`.
///
/// The score is the weighted sum (weights taken from `config`) of four terms:
///
/// * **frequency** – `ln(1 + access_count) / 10`
/// * **recency** – `1 / (1 + elapsed_seconds / 3600)`, where the elapsed time
///   since `last_accessed` is floored at one second
/// * **causal** – `ln(1 + out_degree) / 5` for the pattern's node in
///   `causal_graph`
/// * **surprise** – the novelty estimate from `compute_surprise`
///
/// The sum is finally limited to the unit interval.
pub fn compute_salience(
    temporal_pattern: &TemporalPattern,
    causal_graph: &CausalGraph,
    long_term: &LongTermStore,
    config: &ConsolidationConfig,
) -> f32 {
    // Recency: decays towards 0 as the time since the last access grows.
    let recency_term = {
        let elapsed = (SubstrateTime::now() - temporal_pattern.last_accessed).abs();
        // NOTE(review): the raw tick count is divided by 1e9, so it is
        // presumably nanoseconds — confirm against `SubstrateTime`.
        let elapsed_secs = (elapsed.0 / 1_000_000_000).max(1) as f32;
        1.0 / (1.0 + elapsed_secs / 3600.0)
    };

    // Frequency: log-damped so very hot patterns do not dominate.
    let frequency_term = (temporal_pattern.access_count as f32).ln_1p() / 10.0;

    // Causal importance: log-damped number of outgoing edges.
    let out_degree = causal_graph.out_degree(temporal_pattern.pattern.id) as f32;
    let causal_term = out_degree.ln_1p() / 5.0;

    // Novelty relative to what long-term storage already holds.
    let surprise_term = compute_surprise(&temporal_pattern.pattern, long_term);

    let weighted_sum = config.w_frequency * frequency_term
        + config.w_recency * recency_term
        + config.w_causal * causal_term
        + config.w_surprise * surprise_term;

    weighted_sum.max(0.0).min(1.0)
}
fn compute_surprise(pattern: &exo_core::Pattern, long_term: &LongTermStore) -> f32 {
const SAMPLE_SIZE: usize = 50;
if long_term.is_empty() {
return 1.0; }
let all_patterns = long_term.all();
let total = all_patterns.len();
if total <= SAMPLE_SIZE {
let mut max_similarity = 0.0f32;
for existing in all_patterns {
let sim = cosine_similarity_simd(&pattern.embedding, &existing.pattern.embedding);
max_similarity = max_similarity.max(sim);
}
return (1.0 - max_similarity).max(0.0);
}
let step = total / SAMPLE_SIZE;
let mut max_similarity = 0.0f32;
for i in (0..total).step_by(step.max(1)) {
let existing = &all_patterns[i];
let sim = cosine_similarity_simd(&pattern.embedding, &existing.pattern.embedding);
max_similarity = max_similarity.max(sim);
if max_similarity > 0.95 {
return 0.05; }
}
(1.0 - max_similarity).max(0.0)
}
/// Computes the salience of every pattern in `patterns`.
///
/// Each entry is scored independently with `compute_salience`; the returned
/// vector has one score per input pattern, in the same order.
pub fn compute_salience_batch(
    patterns: &[TemporalPattern],
    causal_graph: &CausalGraph,
    long_term: &LongTermStore,
    config: &ConsolidationConfig,
) -> Vec<f32> {
    let mut scores = Vec::with_capacity(patterns.len());
    for temporal_pattern in patterns {
        scores.push(compute_salience(
            temporal_pattern,
            causal_graph,
            long_term,
            config,
        ));
    }
    scores
}
/// Drains the short-term buffer and promotes the salient patterns.
///
/// Every drained pattern is scored with `compute_salience` and the score is
/// written back into `pattern.salience`. Patterns scoring at or above
/// `config.salience_threshold` are passed to `long_term.integrate`; the rest
/// are dropped. Patterns are handled one at a time, so each score is computed
/// against the store as it stands after the earlier promotions.
///
/// Returns how many patterns were promoted and how many were forgotten.
pub fn consolidate(
    short_term: &ShortTermBuffer,
    long_term: &LongTermStore,
    causal_graph: &CausalGraph,
    config: &ConsolidationConfig,
) -> ConsolidationResult {
    let mut outcome = ConsolidationResult {
        num_consolidated: 0,
        num_forgotten: 0,
    };

    for mut candidate in short_term.drain() {
        let score = compute_salience(&candidate, causal_graph, long_term, config);
        candidate.pattern.salience = score;

        if score >= config.salience_threshold {
            long_term.integrate(candidate);
            outcome.num_consolidated += 1;
        } else {
            // Below threshold: the pattern is simply dropped here.
            outcome.num_forgotten += 1;
        }
    }

    outcome
}
/// Outcome counts of a single `consolidate` pass.
#[derive(Debug, Clone)]
pub struct ConsolidationResult {
/// Number of patterns whose salience met the threshold and were handed to
/// long-term storage.
pub num_consolidated: usize,
/// Number of patterns that scored below the threshold and were dropped.
pub num_forgotten: usize,
}
/// Cosine similarity of two embeddings, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the similarity is undefined: the slices differ in
/// length, are empty, either vector has zero magnitude, or the magnitude is
/// not finite (NaN/infinite components or an overflowing sum of squares).
///
/// The sums are accumulated in `f32`, four lanes per iteration, so the
/// compiler can vectorise the loop. The final normalisation is done in `f64`:
/// the product of two squared magnitudes easily leaves the `f32` range (it
/// underflows to zero for small embeddings and overflows to infinity for large
/// ones), which would otherwise turn a perfect match into a similarity of `0.0`.
#[inline]
fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut mag_a = 0.0f32;
    let mut mag_b = 0.0f32;

    // `chunks_exact` guarantees four elements per chunk, so no `unsafe`
    // unchecked indexing is needed to keep the loop free of bounds checks.
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    for (ca, cb) in a_chunks.zip(b_chunks) {
        let (a0, a1, a2, a3) = (ca[0], ca[1], ca[2], ca[3]);
        let (b0, b1, b2, b3) = (cb[0], cb[1], cb[2], cb[3]);
        dot += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3;
        mag_a += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3;
        mag_b += b0 * b0 + b1 * b1 + b2 * b2 + b3 * b3;
    }

    // Up to three trailing elements that did not fill a whole chunk.
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot += ai * bi;
        mag_a += ai * ai;
        mag_b += bi * bi;
    }

    // The product of two `f32` values is always representable in `f64`.
    let mag = (f64::from(mag_a) * f64::from(mag_b)).sqrt();
    if mag == 0.0 || !mag.is_finite() {
        return 0.0;
    }

    // Rounding in the accumulators can push the ratio marginally past +/-1.
    ((f64::from(dot) / mag) as f32).clamp(-1.0, 1.0)
}
/// Alias for `cosine_similarity_simd`; forwards both slices unchanged and
/// returns its result.
// Not called anywhere in the visible part of this file, hence the
// `dead_code` allowance below.
#[allow(dead_code)]
#[inline]
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
cosine_similarity_simd(a, b)
}
/// Running totals across consolidation passes.
///
/// The counters are atomics, so they can be updated through a shared
/// reference.
#[derive(Debug, Default)]
pub struct ConsolidationStats {
    /// Total number of patterns seen (consolidated plus forgotten).
    pub total_processed: AtomicUsize,
    /// Total number of patterns moved to long-term storage.
    pub total_consolidated: AtomicUsize,
    /// Total number of patterns dropped.
    pub total_forgotten: AtomicUsize,
}

impl Clone for ConsolidationStats {
    /// Produces an independent copy holding the current counter values.
    ///
    /// Each counter is read with its own relaxed load, so the copy is not an
    /// atomic snapshot of all three taken together.
    fn clone(&self) -> Self {
        let snapshot = |counter: &AtomicUsize| AtomicUsize::new(counter.load(Ordering::Relaxed));

        ConsolidationStats {
            total_processed: snapshot(&self.total_processed),
            total_consolidated: snapshot(&self.total_consolidated),
            total_forgotten: snapshot(&self.total_forgotten),
        }
    }
}
impl ConsolidationStats {
    /// Creates a statistics tracker with all counters at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds the counts of one consolidation pass to the running totals.
    ///
    /// The processed total grows by the sum of consolidated and forgotten
    /// patterns. All updates use relaxed atomic additions.
    pub fn record(&self, result: &ConsolidationResult) {
        let kept = result.num_consolidated;
        let dropped = result.num_forgotten;

        self.total_processed
            .fetch_add(kept + dropped, Ordering::Relaxed);
        self.total_consolidated.fetch_add(kept, Ordering::Relaxed);
        self.total_forgotten.fetch_add(dropped, Ordering::Relaxed);
    }

    /// Fraction of processed patterns that were consolidated.
    ///
    /// Returns `0.0` while nothing has been processed. The two counters are
    /// read with separate relaxed loads.
    pub fn consolidation_rate(&self) -> f32 {
        let processed = self.total_processed.load(Ordering::Relaxed);
        let kept = self.total_consolidated.load(Ordering::Relaxed);

        match processed {
            0 => 0.0,
            total => kept as f32 / total as f32,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    /// Salience of a freshly created pattern must fall inside the unit interval.
    #[test]
    fn test_compute_salience() {
        let causal_graph = CausalGraph::new();
        let long_term = LongTermStore::default();
        let config = ConsolidationConfig::default();

        let mut temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        temporal_pattern.access_count = 10;

        let salience = compute_salience(&temporal_pattern, &causal_graph, &long_term, &config);
        assert!((0.0..=1.0).contains(&salience));
    }

    /// A consolidation pass must promote at least one pattern and leave the
    /// short-term buffer empty.
    #[test]
    fn test_consolidation() {
        let short_term = ShortTermBuffer::default();
        let long_term = LongTermStore::default();
        let causal_graph = CausalGraph::new();
        let config = ConsolidationConfig::default();

        // Frequently accessed pattern.
        let mut frequent = TemporalPattern::from_embedding(vec![1.0, 0.0, 0.0], Metadata::new());
        frequent.access_count = 100;
        short_term.insert(frequent);

        // Pattern left with its initial access count.
        let untouched = TemporalPattern::from_embedding(vec![0.0, 1.0, 0.0], Metadata::new());
        short_term.insert(untouched);

        let result = consolidate(&short_term, &long_term, &causal_graph, &config);

        assert!(result.num_consolidated > 0);
        assert!(short_term.is_empty());
    }
}