use std::collections::HashMap;
use solo_core::{Cluster, Embedding, EmbeddingDtype, Episode, Error, MemoryId, Result};
use crate::StewardConfig;
/// Milliseconds in one 24-hour day; `ts_ms.div_euclid(MS_PER_DAY)` yields the day bucket an
/// episode falls into (buckets are counted from `ts_ms == 0`).
const MS_PER_DAY: i64 = 86_400_000;
/// Groups episodes into clusters of semantically similar items that fall in the same day bucket.
///
/// Episodes are bucketed by `ts_ms.div_euclid(MS_PER_DAY)` (fixed 24h windows counted from
/// `ts_ms == 0`; presumably UTC days if `ts_ms` is Unix-epoch milliseconds — confirm). Within a
/// bucket, every pair whose embedding cosine similarity is at least
/// `config.cluster_cosine_threshold` is unioned, so grouping is transitive. Groups with fewer
/// than `config.cluster_min_size` members are dropped (buckets smaller than that are skipped
/// outright).
///
/// The returned clusters are ordered by their earliest member `(ts_ms, memory_id)`, which makes
/// the output order fully deterministic even when two clusters start at the same timestamp.
///
/// # Errors
///
/// Returns a steward error if any embedding is not F32, has a dim different from `inputs[0]`,
/// fails `Embedding::validate`, or cannot be viewed as `&[f32]`.
pub fn cluster_episodes(
    inputs: &[(Episode, Embedding)],
    config: &StewardConfig,
) -> Result<Vec<Cluster>> {
    if inputs.is_empty() {
        return Ok(Vec::new());
    }
    let dim = inputs[0].1.dim;
    let dtype = inputs[0].1.dtype;
    if dtype != EmbeddingDtype::F32 {
        return Err(Error::steward(format!(
            "cluster_episodes requires F32 embeddings; got {dtype:?}"
        )));
    }
    for (i, (_, emb)) in inputs.iter().enumerate() {
        if emb.dim != dim {
            return Err(Error::steward(format!(
                "cluster_episodes: embedding dim mismatch at index {i}: {} vs {dim}",
                emb.dim
            )));
        }
        if emb.dtype != EmbeddingDtype::F32 {
            return Err(Error::steward(format!(
                "cluster_episodes: embedding dtype mismatch at index {i}: {:?} vs F32",
                emb.dtype
            )));
        }
        emb.validate()?;
    }
    let f32_views: Vec<&[f32]> = inputs
        .iter()
        .enumerate()
        .map(|(i, (_, emb))| {
            emb.as_f32_slice().ok_or_else(|| {
                Error::steward(format!(
                    "cluster_episodes: embedding at index {i} failed F32 cast"
                ))
            })
        })
        .collect::<Result<Vec<_>>>()?;

    // Bucket input indices by day; div_euclid keeps pre-epoch timestamps in the right bucket.
    let mut buckets: HashMap<i64, Vec<usize>> = HashMap::new();
    for (i, (ep, _)) in inputs.iter().enumerate() {
        buckets
            .entry(ep.ts_ms.div_euclid(MS_PER_DAY))
            .or_default()
            .push(i);
    }
    let mut day_keys: Vec<i64> = buckets.keys().copied().collect();
    day_keys.sort_unstable();

    // Each cluster is paired with the (ts_ms, memory_id) of its earliest member, computed once,
    // so the final sort is cheap and ties never depend on HashMap iteration order.
    let mut keyed: Vec<((i64, MemoryId), Cluster)> = Vec::new();
    for day in day_keys {
        let indices = &buckets[&day];
        if indices.len() < config.cluster_min_size {
            continue;
        }
        let mut uf = UnionFind::new(indices.len());
        for a_pos in 0..indices.len() {
            for b_pos in (a_pos + 1)..indices.len() {
                let sim =
                    cosine_similarity_f32(f32_views[indices[a_pos]], f32_views[indices[b_pos]]);
                if sim >= config.cluster_cosine_threshold {
                    uf.union(a_pos, b_pos);
                }
            }
        }
        let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
        for pos in 0..indices.len() {
            groups.entry(uf.find(pos)).or_default().push(pos);
        }
        for members_pos in groups.into_values() {
            if members_pos.len() < config.cluster_min_size {
                continue;
            }
            let global_indices: Vec<usize> = members_pos.iter().map(|&p| indices[p]).collect();
            let earliest = global_indices
                .iter()
                .map(|&i| (inputs[i].0.ts_ms, inputs[i].0.memory_id))
                .min()
                .expect("union-find groups are never empty");
            keyed.push((
                earliest,
                build_cluster(&global_indices, inputs, &f32_views, dim),
            ));
        }
    }
    keyed.sort_by_key(|(earliest, _)| *earliest);
    Ok(keyed.into_iter().map(|(_, cluster)| cluster).collect())
}
/// Builds a [`Cluster`] from the episodes at `global_indices` (indices into `inputs` and
/// `f32_views`, which must be parallel).
///
/// * `episode_ids` are ordered by `(ts_ms, memory_id)`, so the first id is the earliest episode.
/// * The centroid is the mean of the member embeddings, L2-normalized (left as all zeros if the
///   mean has zero norm), stored as an F32 [`Embedding`] of length `dim`.
/// * `coherence` is the mean pairwise cosine similarity between members; a single-member
///   cluster (reachable when `cluster_min_size <= 1`) gets `1.0`.
/// * `cluster_id` is a freshly generated [`MemoryId`].
fn build_cluster(
    global_indices: &[usize],
    inputs: &[(Episode, Embedding)],
    f32_views: &[&[f32]],
    dim: usize,
) -> Cluster {
    // Singletons are handled below (coherence 1.0); only an empty group is a caller bug.
    debug_assert!(!global_indices.is_empty());
    let mut sorted: Vec<usize> = global_indices.to_vec();
    sorted.sort_by_key(|&i| {
        let ep = &inputs[i].0;
        (ep.ts_ms, ep.memory_id)
    });
    let episode_ids: Vec<MemoryId> = sorted.iter().map(|&i| inputs[i].0.memory_id).collect();

    let mut sum = vec![0.0f32; dim];
    for &i in &sorted {
        for (acc, &x) in sum.iter_mut().zip(f32_views[i]) {
            *acc += x;
        }
    }
    let n = sorted.len() as f32;
    for v in sum.iter_mut() {
        *v /= n;
    }
    let norm = sum.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for v in sum.iter_mut() {
            *v /= norm;
        }
    }
    let centroid = Embedding {
        dtype: EmbeddingDtype::F32,
        dim,
        data: bytemuck::cast_slice(&sum).to_vec(),
    };

    // O(n^2) pairs: count with usize (a u32 overflows past ~92k members) and accumulate in f64
    // so a long f32 running sum does not drift.
    let mut total = 0.0f64;
    let mut pairs = 0usize;
    for (rank, &a) in sorted.iter().enumerate() {
        for &b in &sorted[rank + 1..] {
            total += f64::from(cosine_similarity_f32(f32_views[a], f32_views[b]));
            pairs += 1;
        }
    }
    let coherence = if pairs > 0 {
        (total / pairs as f64) as f32
    } else {
        1.0
    };
    Cluster {
        cluster_id: MemoryId::new(),
        episode_ids,
        centroid: Some(centroid),
        coherence,
    }
}
/// Cosine similarity of two equal-length F32 vectors, clamped to `[-1.0, 1.0]`.
///
/// Divides the dot product by both norms, so the result is a true cosine even when an input is
/// not unit-normalized (unit norm is not checked by the callers in this module; that guarantee,
/// if any, would live in `Embedding::validate` — confirm). Returns `0.0` when either vector has
/// zero norm (e.g. an all-zero centroid). Callers validate dims beforehand; in release builds a
/// length mismatch compares only the common prefix.
fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom > 0.0 {
        // Clamp away rounding excursions just outside the mathematical range.
        (dot / denom).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
/// One planned merge of a transitively-similar group of clusters into a single survivor.
#[derive(Debug, Clone)]
pub struct MergeOp {
    /// Id of the cluster that is kept and receives the merged contents.
    pub survivor_id: MemoryId,
    /// Ids of the other clusters in the group; they are removed when the op is applied.
    pub loser_ids: Vec<MemoryId>,
    /// Union of every member's episode ids, sorted by `MemoryId` and deduplicated.
    pub merged_episode_ids: Vec<MemoryId>,
    /// Episode-count-weighted mean of the member centroids, L2-normalized when non-zero (F32).
    pub merged_centroid: Embedding,
    /// Episode-count-weighted mean of the member coherences.
    pub merged_coherence: f32,
}
/// Merge operations computed for a set of clusters without mutating them
/// (see `plan_existing_merges`).
#[derive(Debug, Clone, Default)]
pub struct MergePlan {
    /// One op per group of two or more linked clusters, ordered by union-find root index.
    pub merges: Vec<MergeOp>,
}
impl MergePlan {
    /// Number of clusters this plan folds into a survivor, i.e. the total size of all
    /// `loser_ids` lists.
    pub fn absorbed(&self) -> usize {
        self.merges
            .iter()
            .fold(0, |running, op| running + op.loser_ids.len())
    }
}
/// Computes the centroid-merge plan for `clusters` without modifying them.
///
/// This is the read-only counterpart of `merge_clusters_by_centroid`; both run the same
/// planner, so applying the returned ops yields the same clusters as the in-place merge.
///
/// # Errors
///
/// Propagates the planner's validation errors (missing, non-F32 or mismatched-dim centroids).
pub fn plan_existing_merges(
    clusters: &[Cluster],
    config: &StewardConfig,
) -> Result<MergePlan> {
    compute_merge_plan(clusters, config).map(|merges| MergePlan { merges })
}
pub fn merge_clusters_by_centroid(
clusters: &mut Vec<Cluster>,
config: &StewardConfig,
) -> Result<usize> {
let merges = compute_merge_plan(clusters, config)?;
if merges.is_empty() {
return Ok(0);
}
apply_merge_plan_in_place(clusters, &merges);
Ok(merges.iter().map(|m| m.loser_ids.len()).sum())
}
fn apply_merge_plan_in_place(clusters: &mut Vec<Cluster>, merges: &[MergeOp]) {
use std::collections::HashSet;
let losers: HashSet<MemoryId> = merges
.iter()
.flat_map(|m| m.loser_ids.iter().copied())
.collect();
let by_survivor: HashMap<MemoryId, &MergeOp> =
merges.iter().map(|m| (m.survivor_id, m)).collect();
let mut out: Vec<Cluster> = Vec::with_capacity(clusters.len());
for c in clusters.iter() {
if losers.contains(&c.cluster_id) {
continue;
}
if let Some(op) = by_survivor.get(&c.cluster_id) {
out.push(Cluster {
cluster_id: op.survivor_id,
episode_ids: op.merged_episode_ids.clone(),
centroid: Some(op.merged_centroid.clone()),
coherence: op.merged_coherence,
});
} else {
out.push(c.clone());
}
}
*clusters = out;
}
/// Computes, without applying, the merge operations for `clusters`.
///
/// Clusters whose centroids have cosine similarity >= `config.cluster_cosine_threshold` are
/// unioned. Grouping is transitive: A~B and B~C puts A, B and C in one group even if A and C
/// are dissimilar. Each group of two or more clusters yields one [`MergeOp`]:
///
/// * survivor: the member with the most episodes; ties go to the smallest `cluster_id`.
/// * `merged_episode_ids`: union of member episode ids, sorted by `MemoryId`, deduplicated.
/// * `merged_centroid`: episode-count-weighted mean of member centroids, L2-normalized.
/// * `merged_coherence`: episode-count-weighted mean of member coherences (1.0 if the total
///   weight is zero).
///
/// Ops are emitted in ascending order of their union-find root index. Fewer than two clusters
/// yields an empty plan.
///
/// # Errors
///
/// Returns a steward error if any cluster has no centroid, a non-F32 centroid, a centroid dim
/// different from `clusters[0]`, or a centroid that cannot be viewed as `&[f32]`.
fn compute_merge_plan(
    clusters: &[Cluster],
    config: &StewardConfig,
) -> Result<Vec<MergeOp>> {
    if clusters.len() < 2 {
        return Ok(Vec::new());
    }
    // Reference dim comes from the first cluster; every cluster is checked against it below.
    let dim = match clusters[0].centroid.as_ref() {
        Some(c) => c.dim,
        None => {
            return Err(Error::steward(
                "compute_merge_plan: cluster[0] has no centroid".to_string(),
            ));
        }
    };
    for (i, c) in clusters.iter().enumerate() {
        let centroid = c.centroid.as_ref().ok_or_else(|| {
            Error::steward(format!("compute_merge_plan: cluster[{i}] has no centroid"))
        })?;
        if centroid.dtype != EmbeddingDtype::F32 {
            return Err(Error::steward(format!(
                "compute_merge_plan: cluster[{i}] centroid dtype is {:?}, want F32",
                centroid.dtype
            )));
        }
        if centroid.dim != dim {
            return Err(Error::steward(format!(
                "compute_merge_plan: cluster[{i}] centroid dim {} != {dim}",
                centroid.dim
            )));
        }
    }
    // The unwrap cannot fire: the loop above already rejected clusters without a centroid.
    let centroid_views: Vec<&[f32]> = clusters
        .iter()
        .enumerate()
        .map(|(i, c)| {
            c.centroid.as_ref().unwrap().as_f32_slice().ok_or_else(|| {
                Error::steward(format!(
                    "compute_merge_plan: cluster[{i}] centroid F32 cast failed"
                ))
            })
        })
        .collect::<Result<Vec<_>>>()?;
    // Transitive grouping: every pair at or above the threshold is unioned, so chains merge.
    let n = clusters.len();
    let mut uf = UnionFind::new(n);
    for a in 0..n {
        for b in (a + 1)..n {
            let sim = cosine_similarity_f32(centroid_views[a], centroid_views[b]);
            if sim >= config.cluster_cosine_threshold {
                uf.union(a, b);
            }
        }
    }
    let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
    for i in 0..n {
        let root = uf.find(i);
        groups.entry(root).or_default().push(i);
    }
    // Sort roots so the op order does not depend on HashMap iteration order.
    let mut roots: Vec<usize> = groups.keys().copied().collect();
    roots.sort();
    let mut merges: Vec<MergeOp> = Vec::new();
    for root in roots {
        let members = &groups[&root];
        if members.len() == 1 {
            continue;
        }
        // Largest cluster wins. On equal size the *reversed* id comparison makes the smallest
        // cluster_id compare as the maximum, so it becomes the survivor.
        let survivor_pos = *members
            .iter()
            .max_by(|&&a, &&b| {
                let len_a = clusters[a].episode_ids.len();
                let len_b = clusters[b].episode_ids.len();
                len_a
                    .cmp(&len_b)
                    .then_with(|| clusters[b].cluster_id.cmp(&clusters[a].cluster_id))
            })
            .expect("members non-empty");
        let survivor_id = clusters[survivor_pos].cluster_id;
        let loser_ids: Vec<MemoryId> = members
            .iter()
            .filter(|&&m| m != survivor_pos)
            .map(|&m| clusters[m].cluster_id)
            .collect();
        // Ordered by MemoryId (no timestamps are available here); dedup guards against
        // episodes that appear in more than one member cluster.
        let mut merged_episode_ids: Vec<MemoryId> = members
            .iter()
            .flat_map(|&m| clusters[m].episode_ids.iter().copied())
            .collect();
        merged_episode_ids.sort();
        merged_episode_ids.dedup();
        // Centroid: episode-count-weighted mean of member centroids, then L2-normalized.
        let mut sum = vec![0.0f32; dim];
        let mut total_weight: f32 = 0.0;
        for &m in members {
            let w = clusters[m].episode_ids.len() as f32;
            for (j, &x) in centroid_views[m].iter().enumerate() {
                sum[j] += x * w;
            }
            total_weight += w;
        }
        if total_weight > 0.0 {
            for v in sum.iter_mut() {
                *v /= total_weight;
            }
        }
        let norm = sum.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in sum.iter_mut() {
                *v /= norm;
            }
        }
        let centroid_bytes: Vec<u8> = bytemuck::cast_slice(&sum).to_vec();
        let merged_centroid = Embedding {
            dtype: EmbeddingDtype::F32,
            dim,
            data: centroid_bytes,
        };
        // NOTE(review): weighted mean of the members' own coherences; similarity *between*
        // the merged clusters is not reflected in the result.
        let mut coh_sum = 0.0f32;
        let mut coh_weight = 0.0f32;
        for &m in members {
            let w = clusters[m].episode_ids.len() as f32;
            coh_sum += clusters[m].coherence * w;
            coh_weight += w;
        }
        let merged_coherence = if coh_weight > 0.0 {
            coh_sum / coh_weight
        } else {
            1.0
        };
        merges.push(MergeOp {
            survivor_id,
            loser_ids,
            merged_episode_ids,
            merged_centroid,
            merged_coherence,
        });
    }
    Ok(merges)
}
/// Minimal view of an already-existing cluster, used as an absorption target by
/// `absorb_into_existing`.
///
/// NOTE(review): presumably built by the caller from persisted clusters — confirm.
#[derive(Debug, Clone)]
pub struct ExistingClusterSummary {
    /// Id of the existing cluster.
    pub cluster_id: MemoryId,
    /// Cluster centroid; must be F32 and share its dim with every other summary in the call.
    pub centroid: Embedding,
    /// Current coherence; blended (episode-count-weighted) when a new cluster is absorbed.
    pub coherence: f32,
    /// Number of episodes in the cluster; the weight of this centroid when blending.
    pub episode_count: usize,
}
/// One planned absorption of a new cluster into an existing cluster.
#[derive(Debug, Clone)]
pub struct AbsorbedCluster {
    /// Id of the new cluster being folded in.
    pub new_cluster_id: MemoryId,
    /// Id of the existing cluster receiving the episodes.
    pub existing_cluster_id: MemoryId,
    /// Existing centroid *after* this absorption. When several new clusters target the same
    /// existing cluster the values are cumulative, so the last such entry holds the final state.
    pub merged_centroid: Embedding,
    /// Existing coherence after this absorption (same cumulative semantics as the centroid).
    pub merged_coherence: f32,
    /// Episode ids of the new cluster that move into the existing cluster.
    pub absorbed_episode_ids: Vec<MemoryId>,
}
/// Result of `absorb_into_existing`: absorptions listed in the order of the `new_clusters`
/// input; new clusters with no qualifying target do not appear.
#[derive(Debug, Clone, Default)]
pub struct AbsorbPlan {
    /// Planned absorptions, in `new_clusters` order.
    pub absorptions: Vec<AbsorbedCluster>,
}
impl AbsorbPlan {
    /// Returns `true` if `new_cluster_id` is the source of any planned absorption.
    pub fn was_absorbed(&self, new_cluster_id: MemoryId) -> bool {
        self.absorptions
            .iter()
            .map(|entry| entry.new_cluster_id)
            .any(|id| id == new_cluster_id)
    }

    /// Distinct ids of the existing clusters touched by the plan, in ascending order.
    pub fn modified_existing_ids(&self) -> Vec<MemoryId> {
        let mut touched = Vec::with_capacity(self.absorptions.len());
        for entry in &self.absorptions {
            touched.push(entry.existing_cluster_id);
        }
        touched.sort_unstable();
        touched.dedup();
        touched
    }
}
/// Plans folding freshly built clusters into existing clusters.
///
/// New clusters are processed in input order. For each one, the target is the existing cluster
/// with the highest centroid cosine similarity at or above `config.cluster_cosine_threshold`.
/// Exact similarity ties prefer the existing cluster with more episodes, then the smaller
/// cluster id, so the choice does not depend on the order of `existing`. Absorbing updates the
/// target's working centroid (episode-count-weighted blend, L2-normalized), coherence and
/// episode count. Later new clusters are therefore compared against the post-absorption state.
/// New clusters with no qualifying target are omitted from the plan.
///
/// # Errors
///
/// Returns a steward error if any existing or new centroid is not F32, has a dim different from
/// `existing[0]`, cannot be viewed as `&[f32]` of that dim, or (new clusters) is missing.
pub fn absorb_into_existing(
    new_clusters: &[Cluster],
    existing: &[ExistingClusterSummary],
    config: &StewardConfig,
) -> Result<AbsorbPlan> {
    // Mutable per-existing-cluster state, updated as new clusters are folded in.
    struct Target {
        id: MemoryId,
        centroid: Vec<f32>,
        coherence: f32,
        count: usize,
    }

    if new_clusters.is_empty() || existing.is_empty() {
        return Ok(AbsorbPlan::default());
    }
    let dim = existing[0].centroid.dim;

    let mut working: Vec<Target> = Vec::with_capacity(existing.len());
    for (i, s) in existing.iter().enumerate() {
        if s.centroid.dtype != EmbeddingDtype::F32 {
            return Err(Error::steward(format!(
                "absorb_into_existing: existing[{i}] dtype is {:?}, want F32",
                s.centroid.dtype
            )));
        }
        if s.centroid.dim != dim {
            return Err(Error::steward(format!(
                "absorb_into_existing: existing[{i}] dim {} != {dim}",
                s.centroid.dim
            )));
        }
        let view = s.centroid.as_f32_slice().ok_or_else(|| {
            Error::steward(format!(
                "absorb_into_existing: existing[{i}] centroid F32 cast failed"
            ))
        })?;
        if view.len() != dim {
            return Err(Error::steward(format!(
                "absorb_into_existing: existing[{i}] centroid has {} f32 values, want {dim}",
                view.len()
            )));
        }
        working.push(Target {
            id: s.cluster_id,
            centroid: view.to_vec(),
            coherence: s.coherence,
            count: s.episode_count,
        });
    }

    // Validate *every* new centroid up front. Checking only new_clusters[0] let a later
    // cluster with another dim index out of bounds while blending (or compare truncated
    // vectors), and its dtype was never checked.
    let mut new_views: Vec<&[f32]> = Vec::with_capacity(new_clusters.len());
    for (i, n) in new_clusters.iter().enumerate() {
        let centroid = n.centroid.as_ref().ok_or_else(|| {
            Error::steward(format!(
                "absorb_into_existing: new_clusters[{i}] has no centroid"
            ))
        })?;
        if centroid.dtype != EmbeddingDtype::F32 {
            return Err(Error::steward(format!(
                "absorb_into_existing: new_clusters[{i}] dtype is {:?}, want F32",
                centroid.dtype
            )));
        }
        if centroid.dim != dim {
            return Err(Error::steward(format!(
                "absorb_into_existing: new_clusters[{i}] dim {} != existing dim {dim}",
                centroid.dim
            )));
        }
        let view = centroid.as_f32_slice().ok_or_else(|| {
            Error::steward(format!(
                "absorb_into_existing: new cluster {} centroid F32 cast failed",
                n.cluster_id
            ))
        })?;
        if view.len() != dim {
            return Err(Error::steward(format!(
                "absorb_into_existing: new_clusters[{i}] centroid has {} f32 values, want {dim}",
                view.len()
            )));
        }
        new_views.push(view);
    }

    let threshold = config.cluster_cosine_threshold;
    let mut plan = AbsorbPlan::default();
    for (n, &n_view) in new_clusters.iter().zip(&new_views) {
        // Similarity decides first; episode count and id only break exact similarity ties.
        let mut best: Option<(usize, f32)> = None;
        for (i, t) in working.iter().enumerate() {
            let sim = cosine_similarity_f32(n_view, &t.centroid);
            if sim.is_nan() || sim < threshold {
                continue;
            }
            let better = match best {
                None => true,
                Some((prev, prev_sim)) => {
                    let p = &working[prev];
                    sim > prev_sim
                        || (sim == prev_sim
                            && (t.count > p.count || (t.count == p.count && t.id < p.id)))
                }
            };
            if better {
                best = Some((i, sim));
            }
        }
        let target_idx = match best {
            Some((i, _)) => i,
            None => continue,
        };

        let target = &mut working[target_idx];
        let n_w = n.episode_ids.len() as f32;
        let e_w = target.count as f32;
        let total_w = n_w + e_w;
        let mut blended = vec![0.0f32; dim];
        if total_w > 0.0 {
            for ((out, &e), &x) in blended.iter_mut().zip(&target.centroid).zip(n_view) {
                *out = (e * e_w + x * n_w) / total_w;
            }
        }
        let norm = blended.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in blended.iter_mut() {
                *v /= norm;
            }
        }
        let merged_coh = (target.coherence * e_w + n.coherence * n_w) / total_w.max(1.0);
        let merged_centroid = Embedding {
            dtype: EmbeddingDtype::F32,
            dim,
            data: bytemuck::cast_slice(&blended).to_vec(),
        };
        target.centroid = blended;
        target.coherence = merged_coh;
        target.count += n.episode_ids.len();
        plan.absorptions.push(AbsorbedCluster {
            new_cluster_id: n.cluster_id,
            existing_cluster_id: target.id,
            merged_centroid,
            merged_coherence: merged_coh,
            absorbed_episode_ids: n.episode_ids.clone(),
        });
    }
    Ok(plan)
}
/// Disjoint-set forest over `0..n` with union by rank and path halving.
struct UnionFind {
    parent: Vec<usize>,
    rank: Vec<u8>,
}

impl UnionFind {
    /// Creates `n` singleton sets, each element its own root with rank 0.
    fn new(n: usize) -> Self {
        Self {
            rank: vec![0u8; n],
            parent: (0..n).collect::<Vec<usize>>(),
        }
    }

    /// Returns the root of `x`'s set, pointing each visited node at its grandparent on the way.
    fn find(&mut self, mut x: usize) -> usize {
        loop {
            let parent = self.parent[x];
            if parent == x {
                return x;
            }
            let grandparent = self.parent[parent];
            self.parent[x] = grandparent;
            x = grandparent;
        }
    }

    /// Merges the sets containing `a` and `b`; the lower-rank root is attached beneath the
    /// other, and on equal rank `b`'s root goes under `a`'s root, whose rank then grows by one.
    fn union(&mut self, a: usize, b: usize) {
        let root_a = self.find(a);
        let root_b = self.find(b);
        if root_a == root_b {
            return;
        }
        match self.rank[root_a].cmp(&self.rank[root_b]) {
            std::cmp::Ordering::Less => self.parent[root_a] = root_b,
            std::cmp::Ordering::Greater => self.parent[root_b] = root_a,
            std::cmp::Ordering::Equal => {
                self.parent[root_b] = root_a;
                self.rank[root_a] += 1;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for day-bucketed clustering, centroid merging of clusters, and absorption of
    //! new clusters into existing ones.
    use std::str::FromStr;

    use super::*;
    use solo_core::{Confidence, EncodingContext, Tier};

    /// Builds an F32 embedding of length `dim` with the given `(index, value)` components set,
    /// L2-normalized (left as zeros if every component is zero).
    fn unit_emb(dim: usize, components: &[(usize, f32)]) -> Embedding {
        let mut v = vec![0.0f32; dim];
        for &(i, x) in components {
            v[i] = x;
        }
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
        Embedding {
            dtype: EmbeddingDtype::F32,
            dim,
            data: bytemuck::cast_slice(&v).to_vec(),
        }
    }

    /// Builds a hot-tier `user_message` episode with a fresh id at `ts_ms`.
    fn ep(ts_ms: i64, content: &str) -> Episode {
        Episode {
            memory_id: MemoryId::new(),
            ts_ms,
            source_type: "user_message".into(),
            source_id: None,
            content: content.into(),
            encoding_context: EncodingContext::default(),
            provenance: None,
            confidence: Confidence::new(0.9).unwrap(),
            strength: 0.5,
            salience: 0.5,
            tier: Tier::Hot,
        }
    }

    /// Default steward config with the two clustering knobs overridden.
    fn cfg(min: usize, threshold: f32) -> StewardConfig {
        StewardConfig {
            cluster_min_size: min,
            cluster_cosine_threshold: threshold,
            ..StewardConfig::default()
        }
    }

    #[test]
    fn empty_input_returns_empty() {
        let r = cluster_episodes(&[], &StewardConfig::default()).unwrap();
        assert!(r.is_empty());
    }

    #[test]
    fn rejects_non_f32_embedding() {
        let bad = Embedding {
            dtype: EmbeddingDtype::F16,
            dim: 4,
            data: vec![0u8; 8],
        };
        let inputs = vec![(ep(0, "x"), bad)];
        let err = cluster_episodes(&inputs, &StewardConfig::default()).unwrap_err();
        assert!(err.to_string().contains("F32"), "got: {err}");
    }

    #[test]
    fn rejects_dim_mismatch() {
        let inputs = vec![
            (ep(0, "a"), unit_emb(4, &[(0, 1.0)])),
            (ep(0, "b"), unit_emb(8, &[(0, 1.0)])),
        ];
        let err = cluster_episodes(&inputs, &StewardConfig::default()).unwrap_err();
        assert!(err.to_string().contains("dim"), "got: {err}");
    }

    #[test]
    fn one_cluster_three_similar_one_outlier() {
        // Arbitrary fixed timestamp; the +N ms offsets below stay within the same day bucket.
        let day_a = 1_700_000_000_000i64;
        let inputs = vec![
            (ep(day_a, "a1"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 1000, "a2"), unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            (ep(day_a + 2000, "a3"), unit_emb(4, &[(0, 0.98), (1, 0.02)])),
            (ep(day_a + 3000, "outlier"), unit_emb(4, &[(2, 1.0)])),
        ];
        let r = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].episode_ids.len(), 3);
        assert!(r[0].coherence > 0.95, "coherence: {}", r[0].coherence);
        let outlier_id = inputs[3].0.memory_id;
        for c in &r {
            assert!(!c.episode_ids.contains(&outlier_id));
        }
    }

    #[test]
    fn below_min_size_yields_no_cluster() {
        let day_a = 1_700_000_000_000i64;
        let inputs = vec![
            (ep(day_a, "a1"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 1000, "a2"), unit_emb(4, &[(0, 1.0)])),
        ];
        let r = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        assert!(r.is_empty(), "expected no cluster: {r:?}");
    }

    #[test]
    fn same_content_different_days_dont_cluster() {
        let day_a = 1_700_000_000_000i64;
        let day_b = day_a + MS_PER_DAY * 3;
        let inputs = vec![
            (ep(day_a, "today"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_b, "three days later"), unit_emb(4, &[(0, 1.0)])),
        ];
        let r = cluster_episodes(&inputs, &cfg(2, 0.85)).unwrap();
        assert!(r.is_empty(), "should NOT cross-bucket: {r:?}");
    }

    #[test]
    fn two_clusters_per_bucket_when_two_themes() {
        let day_a = 1_700_000_000_000i64;
        let inputs = vec![
            (ep(day_a, "a1"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 1000, "a2"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 2000, "a3"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 3000, "b1"), unit_emb(4, &[(1, 1.0)])),
            (ep(day_a + 4000, "b2"), unit_emb(4, &[(1, 1.0)])),
            (ep(day_a + 5000, "b3"), unit_emb(4, &[(1, 1.0)])),
        ];
        let r = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        assert_eq!(r.len(), 2);
        let total: usize = r.iter().map(|c| c.episode_ids.len()).sum();
        assert_eq!(total, 6);
        // Clusters come out ordered by their earliest episode.
        assert!(r[0].episode_ids[0] == inputs[0].0.memory_id);
        assert!(r[1].episode_ids[0] == inputs[3].0.memory_id);
    }

    #[test]
    fn transitive_cluster_via_union_find() {
        let day_a = 1_700_000_000_000i64;
        // After normalization: cos(a, b) ≈ 0.93, cos(b, c) ≈ 0.37, cos(a, c) = 0.
        let a = unit_emb(4, &[(0, 1.0)]);
        let b = unit_emb(4, &[(0, 0.93), (1, 0.37)]);
        let c = unit_emb(4, &[(1, 1.0)]);
        let inputs = vec![
            (ep(day_a, "a"), a),
            (ep(day_a + 1000, "b"), b),
            (ep(day_a + 2000, "c"), c),
        ];
        // At 0.9 only a~b links, leaving a group of two (< min size 3).
        let r = cluster_episodes(&inputs, &cfg(3, 0.9)).unwrap();
        assert!(r.is_empty(), "no cluster expected at threshold 0.9");
        // At 0.3, a~b and b~c link, so a and c join transitively via b.
        let r2 = cluster_episodes(
            &[
                (ep(day_a, "a"), unit_emb(4, &[(0, 1.0)])),
                (ep(day_a + 1000, "b"), unit_emb(4, &[(0, 0.93), (1, 0.37)])),
                (ep(day_a + 2000, "c"), unit_emb(4, &[(1, 1.0)])),
            ],
            &cfg(3, 0.3),
        )
        .unwrap();
        assert_eq!(r2.len(), 1);
        assert_eq!(r2[0].episode_ids.len(), 3);
    }

    #[test]
    fn output_is_deterministic_modulo_cluster_id() {
        let day_a = 1_700_000_000_000i64;
        let inputs: Vec<(Episode, Embedding)> = vec![
            (ep(day_a, "a1"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 1000, "a2"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 2000, "a3"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 3000, "b1"), unit_emb(4, &[(1, 1.0)])),
            (ep(day_a + 4000, "b2"), unit_emb(4, &[(1, 1.0)])),
            (ep(day_a + 5000, "b3"), unit_emb(4, &[(1, 1.0)])),
        ];
        let r1 = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        let r2 = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        assert_eq!(r1.len(), r2.len());
        for (a, b) in r1.iter().zip(r2.iter()) {
            assert_eq!(a.episode_ids, b.episode_ids);
            assert_eq!(a.coherence.to_bits(), b.coherence.to_bits());
        }
    }

    /// Builds a cluster with a fresh id and a normalized centroid from `centroid_components`.
    fn cluster_with(
        episode_ids: Vec<MemoryId>,
        centroid_components: &[(usize, f32)],
        coherence: f32,
        dim: usize,
    ) -> Cluster {
        Cluster {
            cluster_id: MemoryId::new(),
            episode_ids,
            centroid: Some(unit_emb(dim, centroid_components)),
            coherence,
        }
    }

    #[test]
    fn merge_empty_or_singleton_is_noop() {
        let mut empty: Vec<Cluster> = Vec::new();
        let n = merge_clusters_by_centroid(&mut empty, &cfg(3, 0.85)).unwrap();
        assert_eq!(n, 0);
        assert!(empty.is_empty());
        let mut one = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 1.0)],
            0.95,
            4,
        )];
        let n = merge_clusters_by_centroid(&mut one, &cfg(3, 0.85)).unwrap();
        assert_eq!(n, 0);
        assert_eq!(one.len(), 1);
    }

    #[test]
    fn merge_unrelated_clusters_no_op() {
        let mut clusters = vec![
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(0, 1.0)],
                0.95,
                4,
            ),
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(2, 1.0)],
                0.92,
                4,
            ),
        ];
        let n = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(n, 0);
        assert_eq!(clusters.len(), 2);
    }

    #[test]
    fn merge_two_similar_clusters_into_one() {
        let big_ids = vec![MemoryId::new(), MemoryId::new(), MemoryId::new(), MemoryId::new()];
        let small_ids = vec![MemoryId::new(), MemoryId::new(), MemoryId::new()];
        let big_centroid = &[(0, 1.0)];
        let small_centroid = &[(0, 0.99), (1, 0.01)];
        let mut clusters = vec![
            cluster_with(small_ids.clone(), small_centroid, 0.93, 4),
            cluster_with(big_ids.clone(), big_centroid, 0.97, 4),
        ];
        let absorbed = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(absorbed, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].episode_ids.len(), big_ids.len() + small_ids.len());
        let merged_set: std::collections::HashSet<_> =
            clusters[0].episode_ids.iter().copied().collect();
        for id in big_ids.iter().chain(small_ids.iter()) {
            assert!(merged_set.contains(id), "missing {id}");
        }
        // Weighted coherence lies strictly between the two inputs.
        assert!(clusters[0].coherence < 0.97);
        assert!(clusters[0].coherence > 0.93);
        let v = clusters[0].centroid.as_ref().unwrap().as_f32_slice().unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "centroid norm: {norm}");
    }

    #[test]
    fn merge_transitive_three_way() {
        let a_ids = vec![MemoryId::new(), MemoryId::new(), MemoryId::new()];
        let b_ids = vec![MemoryId::new(), MemoryId::new(), MemoryId::new()];
        let c_ids = vec![MemoryId::new(), MemoryId::new(), MemoryId::new()];
        let mut clusters = vec![
            cluster_with(a_ids.clone(), &[(0, 1.0)], 0.95, 4),
            cluster_with(b_ids.clone(), &[(0, 0.93), (1, 0.37)], 0.94, 4),
            cluster_with(c_ids.clone(), &[(1, 1.0)], 0.95, 4),
        ];
        let absorbed = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.3)).unwrap();
        assert_eq!(absorbed, 2);
        assert_eq!(clusters.len(), 1);
        assert_eq!(
            clusters[0].episode_ids.len(),
            a_ids.len() + b_ids.len() + c_ids.len()
        );
    }

    #[test]
    fn merge_below_threshold_keeps_separate() {
        let mut clusters = vec![
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(0, 1.0)],
                0.95,
                4,
            ),
            // cos ≈ 0.707 with the first centroid, below the 0.85 threshold.
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(0, 0.7), (1, 0.7)],
                0.92,
                4,
            ),
        ];
        let absorbed = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(absorbed, 0);
        assert_eq!(clusters.len(), 2);
    }

    #[test]
    fn merge_survivor_tie_breaks_to_smallest_cluster_id() {
        // Both clusters have three episodes, so the survivor is decided by the id tie-break,
        // which prefers the smallest cluster id.
        let small_id_str = "00000000-0000-0000-0000-000000000001";
        let big_id_str = "ffffffff-ffff-ffff-ffff-ffffffffffff";
        let mut clusters = vec![
            Cluster {
                cluster_id: MemoryId::from_str(big_id_str).unwrap(),
                episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                centroid: Some(unit_emb(4, &[(0, 1.0)])),
                coherence: 0.95,
            },
            Cluster {
                cluster_id: MemoryId::from_str(small_id_str).unwrap(),
                episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                centroid: Some(unit_emb(4, &[(0, 0.99), (1, 0.01)])),
                coherence: 0.95,
            },
        ];
        let absorbed = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(absorbed, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].cluster_id.to_string(), small_id_str);
    }

    #[test]
    fn merge_rejects_centroid_dim_mismatch() {
        let mut clusters = vec![
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(0, 1.0)],
                0.95,
                4,
            ),
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(0, 1.0)],
                0.95,
                8, // deliberately different dim
            ),
        ];
        let err = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.85)).unwrap_err();
        assert!(err.to_string().contains("dim"), "got: {err}");
    }

    #[test]
    fn merge_collapses_cross_day_clusters() {
        let day_a = 1_700_000_000_000i64;
        let day_b = day_a + MS_PER_DAY;
        let inputs = vec![
            (ep(day_a, "pa1"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_a + 1000, "pa2"), unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            (ep(day_a + 2000, "pa3"), unit_emb(4, &[(0, 0.98), (1, 0.02)])),
            (ep(day_b, "pb1"), unit_emb(4, &[(0, 1.0)])),
            (ep(day_b + 1000, "pb2"), unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            (ep(day_b + 2000, "pb3"), unit_emb(4, &[(0, 0.98), (1, 0.02)])),
        ];
        // Day bucketing splits the theme in two; the centroid merge rejoins it.
        let mut clusters = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        assert_eq!(clusters.len(), 2, "expected one cluster per day pre-merge");
        let absorbed = merge_clusters_by_centroid(&mut clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(absorbed, 1);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].episode_ids.len(), 6);
    }

    #[test]
    fn plan_empty_or_singleton_yields_empty_plan() {
        let p = plan_existing_merges(&[], &cfg(3, 0.85)).unwrap();
        assert!(p.merges.is_empty());
        assert_eq!(p.absorbed(), 0);
        let one = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 1.0)],
            0.95,
            4,
        )];
        let p = plan_existing_merges(&one, &cfg(3, 0.85)).unwrap();
        assert!(p.merges.is_empty());
    }

    #[test]
    fn plan_two_similar_yields_one_merge_op() {
        let big_id = MemoryId::new();
        let small_id = MemoryId::new();
        let big_eps: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        let small_eps: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let clusters = vec![
            Cluster {
                cluster_id: small_id,
                episode_ids: small_eps.clone(),
                centroid: Some(unit_emb(4, &[(0, 0.99), (1, 0.01)])),
                coherence: 0.93,
            },
            Cluster {
                cluster_id: big_id,
                episode_ids: big_eps.clone(),
                centroid: Some(unit_emb(4, &[(0, 1.0)])),
                coherence: 0.97,
            },
        ];
        let plan = plan_existing_merges(&clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.merges.len(), 1);
        assert_eq!(plan.absorbed(), 1);
        let op = &plan.merges[0];
        assert_eq!(op.survivor_id, big_id);
        assert_eq!(op.loser_ids, vec![small_id]);
        assert_eq!(op.merged_episode_ids.len(), big_eps.len() + small_eps.len());
        // Planning must not mutate the input.
        assert_eq!(clusters.len(), 2);
        assert_eq!(clusters[0].cluster_id, small_id);
        assert_eq!(clusters[1].cluster_id, big_id);
    }

    #[test]
    fn plan_unrelated_yields_empty() {
        let clusters = vec![
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(0, 1.0)],
                0.95,
                4,
            ),
            cluster_with(
                vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
                &[(2, 1.0)],
                0.95,
                4,
            ),
        ];
        let plan = plan_existing_merges(&clusters, &cfg(3, 0.85)).unwrap();
        assert!(plan.merges.is_empty());
    }

    #[test]
    fn plan_three_way_transitive_one_op_two_losers() {
        let a_eps: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let b_eps: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let c_eps: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let clusters = vec![
            cluster_with(a_eps.clone(), &[(0, 1.0)], 0.95, 4),
            cluster_with(b_eps.clone(), &[(0, 0.93), (1, 0.37)], 0.94, 4),
            cluster_with(c_eps.clone(), &[(1, 1.0)], 0.95, 4),
        ];
        let plan = plan_existing_merges(&clusters, &cfg(3, 0.3)).unwrap();
        assert_eq!(plan.merges.len(), 1, "transitive group → one op");
        assert_eq!(plan.absorbed(), 2);
        let op = &plan.merges[0];
        assert_eq!(op.loser_ids.len(), 2);
        assert_eq!(op.merged_episode_ids.len(), 9);
    }

    #[test]
    fn plan_preserves_in_memory_mutation_equivalence() {
        // The read-only plan and the in-place merge must produce bit-identical results.
        let big_eps: Vec<MemoryId> = (0..4).map(|_| MemoryId::new()).collect();
        let small_eps: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let big = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: big_eps.clone(),
            centroid: Some(unit_emb(4, &[(0, 1.0)])),
            coherence: 0.97,
        };
        let small = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: small_eps.clone(),
            centroid: Some(unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            coherence: 0.93,
        };
        let plan = plan_existing_merges(&[small.clone(), big.clone()], &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.merges.len(), 1);
        let op = &plan.merges[0];
        let mut mut_clusters = vec![small.clone(), big.clone()];
        let absorbed = merge_clusters_by_centroid(&mut mut_clusters, &cfg(3, 0.85)).unwrap();
        assert_eq!(absorbed, 1);
        assert_eq!(mut_clusters.len(), 1);
        let post = &mut_clusters[0];
        assert_eq!(op.survivor_id, post.cluster_id);
        let mut a = op.merged_episode_ids.clone();
        let mut b = post.episode_ids.clone();
        a.sort();
        b.sort();
        assert_eq!(a, b);
        let post_centroid = post.centroid.as_ref().unwrap();
        assert_eq!(op.merged_centroid.data, post_centroid.data);
        assert_eq!(op.merged_centroid.dim, post_centroid.dim);
        assert_eq!(op.merged_coherence.to_bits(), post.coherence.to_bits());
    }

    /// Builds an `ExistingClusterSummary` with a normalized centroid.
    fn summary(
        cluster_id: MemoryId,
        centroid: &[(usize, f32)],
        coherence: f32,
        count: usize,
        dim: usize,
    ) -> ExistingClusterSummary {
        ExistingClusterSummary {
            cluster_id,
            centroid: unit_emb(dim, centroid),
            coherence,
            episode_count: count,
        }
    }

    #[test]
    fn absorb_empty_inputs_yield_empty_plan() {
        let plan = absorb_into_existing(&[], &[], &cfg(3, 0.85)).unwrap();
        assert!(plan.absorptions.is_empty());
        let only_new = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 1.0)],
            0.95,
            4,
        )];
        let plan = absorb_into_existing(&only_new, &[], &cfg(3, 0.85)).unwrap();
        assert!(plan.absorptions.is_empty());
        let only_existing = vec![summary(MemoryId::new(), &[(0, 1.0)], 0.95, 3, 4)];
        let plan = absorb_into_existing(&[], &only_existing, &cfg(3, 0.85)).unwrap();
        assert!(plan.absorptions.is_empty());
    }

    #[test]
    fn absorb_below_threshold_no_op() {
        let new = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 1.0)],
            0.95,
            4,
        )];
        let existing = vec![summary(MemoryId::new(), &[(0, 0.7), (1, 0.7)], 0.92, 3, 4)];
        let plan = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap();
        assert!(plan.absorptions.is_empty());
    }

    #[test]
    fn absorb_above_threshold_folds_into_existing() {
        let new_id = MemoryId::new();
        let existing_id = MemoryId::new();
        let new_episode_ids = vec![MemoryId::new(), MemoryId::new(), MemoryId::new()];
        let new = vec![Cluster {
            cluster_id: new_id,
            episode_ids: new_episode_ids.clone(),
            centroid: Some(unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            coherence: 0.94,
        }];
        let existing = vec![summary(existing_id, &[(0, 1.0)], 0.97, 5, 4)];
        let plan = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.absorptions.len(), 1);
        let a = &plan.absorptions[0];
        assert_eq!(a.new_cluster_id, new_id);
        assert_eq!(a.existing_cluster_id, existing_id);
        assert_eq!(a.absorbed_episode_ids, new_episode_ids);
        let v = a.merged_centroid.as_f32_slice().unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "centroid norm: {norm}");
        assert!(a.merged_coherence >= 0.94 && a.merged_coherence <= 0.97);
        assert!(plan.was_absorbed(new_id));
        assert_eq!(plan.modified_existing_ids(), vec![existing_id]);
    }

    #[test]
    fn absorb_picks_largest_existing_on_tie() {
        // Identical existing centroids give identical similarities; the larger cluster wins.
        let new = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 0.99), (1, 0.01)],
            0.95,
            4,
        )];
        let small_id = MemoryId::new();
        let big_id = MemoryId::new();
        let existing = vec![
            summary(small_id, &[(0, 1.0)], 0.95, 3, 4),
            summary(big_id, &[(0, 1.0)], 0.95, 7, 4),
        ];
        let plan = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.absorptions.len(), 1);
        assert_eq!(plan.absorptions[0].existing_cluster_id, big_id);
    }

    #[test]
    fn absorb_tie_break_smallest_cluster_id() {
        // Same similarity and same episode count: the smallest cluster id wins.
        let small_id = MemoryId::from_str("00000000-0000-0000-0000-000000000001").unwrap();
        let big_id = MemoryId::from_str("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap();
        let new = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 1.0)],
            0.95,
            4,
        )];
        let existing = vec![
            summary(big_id, &[(0, 1.0)], 0.95, 3, 4),
            summary(small_id, &[(0, 1.0)], 0.95, 3, 4),
        ];
        let plan = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.absorptions.len(), 1);
        assert_eq!(plan.absorptions[0].existing_cluster_id, small_id);
    }

    #[test]
    fn absorb_multiple_new_into_same_existing_updates_state() {
        let existing_id = MemoryId::new();
        let n1 = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            centroid: Some(unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            coherence: 0.94,
        };
        let n2 = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            centroid: Some(unit_emb(4, &[(0, 0.97), (1, 0.03)])),
            coherence: 0.93,
        };
        let new = vec![n1.clone(), n2.clone()];
        let existing = vec![summary(existing_id, &[(0, 1.0)], 0.97, 5, 4)];
        let plan = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.absorptions.len(), 2);
        assert_eq!(plan.absorptions[0].new_cluster_id, n1.cluster_id);
        assert_eq!(plan.absorptions[1].new_cluster_id, n2.cluster_id);
        assert_eq!(plan.absorptions[0].existing_cluster_id, existing_id);
        assert_eq!(plan.absorptions[1].existing_cluster_id, existing_id);
        assert_eq!(plan.modified_existing_ids(), vec![existing_id]);
    }

    #[test]
    fn absorb_partial_some_match_some_dont() {
        let existing_id = MemoryId::new();
        let n_match = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            centroid: Some(unit_emb(4, &[(0, 0.99), (1, 0.01)])),
            coherence: 0.94,
        };
        let n_orth1 = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            centroid: Some(unit_emb(4, &[(2, 1.0)])),
            coherence: 0.95,
        };
        let n_orth2 = Cluster {
            cluster_id: MemoryId::new(),
            episode_ids: vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            centroid: Some(unit_emb(4, &[(3, 1.0)])),
            coherence: 0.95,
        };
        let new = vec![n_match.clone(), n_orth1.clone(), n_orth2.clone()];
        let existing = vec![summary(existing_id, &[(0, 1.0)], 0.97, 5, 4)];
        let plan = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap();
        assert_eq!(plan.absorptions.len(), 1);
        assert_eq!(plan.absorptions[0].new_cluster_id, n_match.cluster_id);
        assert!(!plan.was_absorbed(n_orth1.cluster_id));
        assert!(!plan.was_absorbed(n_orth2.cluster_id));
    }

    #[test]
    fn absorb_rejects_dim_mismatch_in_existing() {
        let new = vec![cluster_with(
            vec![MemoryId::new(), MemoryId::new(), MemoryId::new()],
            &[(0, 1.0)],
            0.95,
            4,
        )];
        let existing = vec![
            summary(MemoryId::new(), &[(0, 1.0)], 0.95, 3, 4),
            summary(MemoryId::new(), &[(0, 1.0)], 0.95, 3, 8), // deliberately different dim
        ];
        let err = absorb_into_existing(&new, &existing, &cfg(3, 0.85)).unwrap_err();
        assert!(err.to_string().contains("dim"), "got: {err}");
    }

    #[test]
    fn centroid_is_unit_length() {
        let day_a = 1_700_000_000_000i64;
        let inputs = vec![
            (ep(day_a, "a1"), unit_emb(8, &[(0, 1.0)])),
            (ep(day_a + 1000, "a2"), unit_emb(8, &[(0, 1.0)])),
            (ep(day_a + 2000, "a3"), unit_emb(8, &[(0, 1.0)])),
        ];
        let r = cluster_episodes(&inputs, &cfg(3, 0.85)).unwrap();
        assert_eq!(r.len(), 1);
        let c = r[0].centroid.as_ref().unwrap();
        let v = c.as_f32_slice().unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "centroid norm: {norm}");
    }
}