use std::collections::{HashMap, HashSet};
use super::VectorStore;
use crate::error::Result;
use crate::model::Memory;
/// Tuning parameters for a `consolidate` pass.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ConsolidationParams {
/// Age limit in days. Entries whose timestamp is older than this have the
/// decay factor applied to their importance.
pub max_age_days: u32,
/// Multiplier applied to the importance of an entry that exceeds the age limit.
pub decay_factor: f64,
/// Cosine-similarity cut-off for the merge step: neighbour pairs scoring
/// below this value are left alone.
pub merge_threshold: f64,
/// Importance floor that both the decay step and the prune step compare
/// an entry's importance against.
pub min_importance: f64,
}
impl Default for ConsolidationParams {
fn default() -> Self {
Self {
max_age_days: 90,
decay_factor: 0.95,
merge_threshold: 0.95,
min_importance: 0.1,
}
}
}
/// Summary of what a single `consolidate` pass did.
///
/// The all-zero `Default` value is what `consolidate` returns when the store
/// yields no entries.
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)]
pub struct ConsolidationStats {
/// Number of entries loaded from the store for this pass.
pub scanned: usize,
/// Number of entries that exceeded the age limit and had decay applied.
pub decayed: usize,
/// Number of entries queued for deletion by the merge step.
pub merged: usize,
/// Number of entries queued for deletion by the prune step.
pub pruned: usize,
/// Elapsed seconds, measured from just after the entries were loaded
/// until the stats were assembled.
pub duration_secs: f64,
}
/// Runs one consolidation pass over the entries stored under `namespace`.
///
/// The pass has four stages:
///
/// 1. Load every entry together with its vector.
/// 2. Decay the importance of old entries (`apply_decay`). Entries that drop
///    below `params.min_importance` are marked dead so they take no part in
///    merging.
/// 3. Run a neighbour search for every live entry (at most eight searches in
///    flight at once) and drop the less important member of each
///    near-duplicate pair (`merge_similar`).
/// 4. Prune every entry whose importance is below `params.min_importance`
///    and that the merge step has not already queued for deletion.
///
/// All deletions are sent to the store in a single `delete_entries` call.
/// When the store yields no entries, the default (all-zero) stats are
/// returned and nothing else happens.
///
/// # Errors
///
/// Propagates any error returned by the store while loading entries, running
/// a neighbour search, or deleting entries.
pub async fn consolidate(
    store: &VectorStore,
    params: &ConsolidationParams,
    namespace: Option<&str>,
) -> Result<ConsolidationStats> {
    let pairs = store.get_all_with_vectors(namespace).await?;
    if pairs.is_empty() {
        return Ok(ConsolidationStats::default());
    }
    let t0 = std::time::Instant::now();
    let (entries, vectors): (Vec<Memory>, Vec<Vec<f32>>) = pairs.into_iter().unzip();
    let n = entries.len();

    // Every entry starts the pass with full importance; decay lowers it.
    let mut importance: Vec<f64> = vec![1.0; n];
    let mut dead = HashSet::new();
    let decayed = apply_decay(&entries, params, &mut importance, &mut dead);

    // Number of neighbours requested per live entry.
    let merge_k = 5;
    let id_to_idx: HashMap<uuid::Uuid, usize> =
        entries.iter().enumerate().map(|(i, e)| (e.id, i)).collect();
    let live_indices: Vec<usize> = (0..n).filter(|i| !dead.contains(i)).collect();

    // Bound the number of neighbour searches that run concurrently.
    let max_ann_workers = 8;
    let semaphore = tokio::sync::Semaphore::new(max_ann_workers);
    let vectors_ref = &vectors;
    let ann_futures = live_indices.iter().map(|&i| {
        let sem = &semaphore;
        async move {
            // The semaphore is never closed, so the acquire result is only
            // held to keep the permit alive for the duration of the search.
            let _permit = sem.acquire().await;
            store
                .semantic_search(
                    vectors_ref.get(i).unwrap_or(&Vec::new()),
                    merge_k,
                    namespace,
                )
                .await
                .map(|neighbors| (i, neighbors))
        }
    });
    let all_neighbors: Vec<(usize, Vec<Memory>)> =
        futures::future::try_join_all(ann_futures).await?;

    let (merge_deletes, merged) = merge_similar(
        &all_neighbors,
        &entries,
        &vectors,
        &id_to_idx,
        &importance,
        params.merge_threshold,
        &mut dead,
    );

    // `dead` also contains the entries that decay pushed below the floor, so
    // it cannot be used to decide what still needs pruning: doing so would
    // skip exactly the entries that are meant to be pruned. Only merge losers
    // are already queued for deletion, so only those are skipped here.
    let merged_ids: HashSet<uuid::Uuid> = merge_deletes.iter().copied().collect();
    let mut ids_to_delete: Vec<uuid::Uuid> = merge_deletes;
    let mut pruned = 0usize;
    for (entry, &imp) in entries.iter().zip(&importance) {
        if imp < params.min_importance && !merged_ids.contains(&entry.id) {
            ids_to_delete.push(entry.id);
            pruned += 1;
        }
    }

    if !ids_to_delete.is_empty() {
        store.delete_entries(&ids_to_delete).await?;
    }
    let stats = ConsolidationStats {
        scanned: n,
        decayed,
        merged,
        pruned,
        duration_secs: t0.elapsed().as_secs_f64(),
    };
    tracing::info!(?stats, "consolidation complete");
    Ok(stats)
}
/// Applies a single decay step to every entry older than the configured age.
///
/// An entry qualifies when it has a timestamp lying more than
/// `params.max_age_days` days before the current UTC time. For each such
/// entry the matching slot of `importance` is multiplied by
/// `params.decay_factor`; if the result is below `params.min_importance`
/// (a missing slot counts as `0.0`), the entry's index is added to `dead`.
/// Entries without a timestamp are left untouched.
///
/// Returns how many entries qualified for decay.
fn apply_decay(
    entries: &[Memory],
    params: &ConsolidationParams,
    importance: &mut [f64],
    dead: &mut HashSet<usize>,
) -> usize {
    let cutoff_secs = f64::from(params.max_age_days) * 86400.0;
    let now = chrono::Utc::now();
    let mut touched = 0usize;

    for (idx, memory) in entries.iter().enumerate() {
        let age_secs = match memory.timestamp {
            Some(stamp) => (now - stamp).num_seconds() as f64,
            None => continue,
        };
        if age_secs <= cutoff_secs {
            continue;
        }

        // Decay in place and read the new value back in one step.
        let current = importance.get_mut(idx).map_or(0.0, |slot| {
            *slot *= params.decay_factor;
            *slot
        });
        if current < params.min_importance {
            dead.insert(idx);
        }
        touched += 1;
    }

    touched
}
/// Picks the entries to delete because a near-duplicate with at least as
/// much importance exists.
///
/// `all_neighbors` pairs an entry index with the neighbours found for it.
/// For each pair (entry `i`, neighbour `j`) where both are still alive and
/// the cosine similarity of their vectors is at least `threshold`, the one
/// with lower importance is marked dead and queued for deletion; on a tie
/// the neighbour loses. Neighbours whose id is not in `id_to_idx`, and
/// indices with no entry or vector, are skipped.
///
/// Once entry `i` itself loses a comparison, the rest of its neighbours are
/// not examined: a deleted entry must not cause further deletions.
///
/// Returns the ids queued for deletion and how many there are. `dead` is
/// updated with every loser's index.
fn merge_similar(
    all_neighbors: &[(usize, Vec<Memory>)],
    entries: &[Memory],
    vectors: &[Vec<f32>],
    id_to_idx: &HashMap<uuid::Uuid, usize>,
    importance: &[f64],
    threshold: f64,
    dead: &mut HashSet<usize>,
) -> (Vec<uuid::Uuid>, usize) {
    let mut ids_to_delete = Vec::new();
    let mut merged = 0usize;

    for (i, neighbors) in all_neighbors {
        let i = *i;
        if dead.contains(&i) {
            continue;
        }
        // These lookups do not depend on the neighbour, so do them once.
        let (Some(entry_i), Some(vi)) = (entries.get(i), vectors.get(i)) else {
            continue;
        };
        let imp_i = importance.get(i).copied().unwrap_or(0.0);

        for neighbor in neighbors {
            // A search result normally includes the query entry itself.
            if neighbor.id == entry_i.id {
                continue;
            }
            let Some(&j) = id_to_idx.get(&neighbor.id) else {
                continue;
            };
            if dead.contains(&j) {
                continue;
            }
            let Some(vj) = vectors.get(j) else {
                continue;
            };

            let sim = cosine_similarity(vi, vj);
            // NaN compares false against everything, so it must be rejected
            // explicitly or it would be treated as a duplicate.
            if sim.is_nan() || sim < threshold {
                continue;
            }

            let imp_j = importance.get(j).copied().unwrap_or(0.0);
            let loser = if imp_i >= imp_j { j } else { i };
            if !dead.insert(loser) {
                continue;
            }
            if let Some(e) = entries.get(loser) {
                ids_to_delete.push(e.id);
            }
            merged += 1;

            // `i` is now dead; letting it keep eliminating neighbours could
            // delete entries that have no surviving near-duplicate.
            if loser == i {
                break;
            }
        }
    }

    (ids_to_delete, merged)
}
/// Computes the cosine similarity of `a` and `b` in `f64` precision.
///
/// The dot product runs over the common prefix of the two slices while each
/// magnitude uses the full slice, so a shorter vector behaves as if it were
/// zero-padded to the longer one's length.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when either vector has
/// zero magnitude (including empty slices) or when non-finite components
/// (NaN or infinity) make the similarity undefined, so callers never have to
/// deal with NaN.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| f64::from(*x) * f64::from(*y))
        .sum();
    let mag_a = a.iter().map(|x| f64::from(*x).powi(2)).sum::<f64>().sqrt();
    let mag_b = b.iter().map(|x| f64::from(*x).powi(2)).sum::<f64>().sqrt();
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let sim = dot / (mag_a * mag_b);
    if sim.is_finite() {
        // Rounding can push the quotient a hair outside the valid range.
        sim.clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point similarities.
    const EPS: f64 = 1e-9;

    /// Asserts that the similarity of `a` and `b` is within `EPS` of `expected`.
    fn assert_similarity(a: &[f32], b: &[f32], expected: f64) {
        let actual = cosine_similarity(a, b);
        assert!((actual - expected).abs() < EPS);
    }

    /// Identical vectors have similarity 1.
    #[test]
    fn cosine_similarity_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        assert_similarity(&v, &v, 1.0);
    }

    /// Orthogonal vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal() {
        assert_similarity(&[1.0, 0.0], &[0.0, 1.0], 0.0);
    }

    /// Opposite vectors have similarity -1.
    #[test]
    fn cosine_similarity_opposite() {
        assert_similarity(&[1.0, 0.0], &[-1.0, 0.0], -1.0);
    }

    /// A zero-magnitude vector yields 0 rather than a division by zero.
    #[test]
    fn cosine_similarity_zero_vector() {
        assert_similarity(&[0.0, 0.0], &[1.0, 2.0], 0.0);
    }

    /// Checks a general case against the formula worked out by hand.
    #[test]
    fn cosine_similarity_arbitrary() {
        let dot = 3.0f64.mul_add(6.0, 2.0f64.mul_add(5.0, 1.0 * 4.0));
        let norm_a = (1.0_f64 + 4.0 + 9.0).sqrt();
        let norm_b = (16.0_f64 + 25.0 + 36.0).sqrt();
        assert_similarity(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], dot / (norm_a * norm_b));
    }
}