use rusqlite::{Connection, params_from_iter};
use solo_core::{Cluster, Embedding, EmbeddingDtype, Error, MemoryId, Result};
use std::str::FromStr;
/// Counters describing what the existing-cluster merge planner would do if it
/// were run over the clusters currently stored in the database.
///
/// Produced by `count_existing_merge_candidates`, which only reads from the
/// database and never applies the planned merges. `Default` yields all-zero
/// counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MergeCandidateStats {
    /// Number of clusters handed to the merge planner: clusters with an `f32`
    /// centroid of the expected dimension and at least one parseable episode.
    pub clusters_examined: usize,
    /// Number of merge operations in the resulting plan (`plan.merges.len()`).
    pub merge_ops: usize,
    /// Value reported by `plan.absorbed()`. Presumably this is the number of
    /// clusters that would be folded into another cluster; confirm against
    /// `solo_steward::cluster`.
    pub clusters_would_absorb: usize,
}
pub fn count_existing_merge_candidates(
conn: &Connection,
expected_dim: usize,
config: &solo_steward::StewardConfig,
) -> Result<MergeCandidateStats> {
let sql = "SELECT c.cluster_id, c.centroid, c.centroid_dtype, c.centroid_dim,
c.coherence, ce.memory_id
FROM clusters c
JOIN cluster_episodes ce ON ce.cluster_id = c.cluster_id
WHERE c.centroid IS NOT NULL
AND c.centroid_dtype = 'f32'
AND c.centroid_dim = ?1
ORDER BY c.cluster_id, ce.memory_id";
let params: Vec<rusqlite::types::Value> =
vec![(expected_dim as i64).into()];
let mut stmt = conn.prepare(sql).map_err(|e| {
Error::storage(format!("prepare merge-candidate clusters: {e}"))
})?;
let rows = stmt
.query_map(params_from_iter(¶ms), |r| {
Ok((
r.get::<_, String>(0)?, r.get::<_, Vec<u8>>(1)?, r.get::<_, String>(2)?, r.get::<_, i64>(3)?, r.get::<_, f32>(4)?, r.get::<_, String>(5)?, ))
})
.map_err(|e| {
Error::storage(format!("query_map merge-candidate clusters: {e}"))
})?;
let mut clusters: Vec<Cluster> = Vec::new();
for row in rows {
let (cid_s, centroid_bytes, dtype_s, dim_i, coherence, memid_s) =
row.map_err(|e| {
Error::storage(format!("merge-candidate row decode: {e}"))
})?;
if dtype_s != "f32" || (dim_i as usize) != expected_dim {
continue;
}
let cluster_id = match MemoryId::from_str(&cid_s) {
Ok(id) => id,
Err(e) => {
tracing::warn!(
cluster_id = %cid_s,
error = %e,
"skipping cluster with unparseable cluster_id (merge-candidate count)"
);
continue;
}
};
let memory_id = match MemoryId::from_str(&memid_s) {
Ok(id) => id,
Err(e) => {
tracing::warn!(
memory_id = %memid_s,
error = %e,
"skipping cluster_episodes row with unparseable memory_id (merge-candidate count)"
);
continue;
}
};
if clusters.last().map(|c| c.cluster_id) == Some(cluster_id) {
clusters.last_mut().unwrap().episode_ids.push(memory_id);
} else {
let centroid = Embedding {
dtype: EmbeddingDtype::F32,
dim: dim_i as usize,
data: centroid_bytes,
};
clusters.push(Cluster {
cluster_id,
episode_ids: vec![memory_id],
centroid: Some(centroid),
coherence,
});
}
}
clusters.retain(|c| !c.episode_ids.is_empty());
if clusters.len() < 2 {
return Ok(MergeCandidateStats {
clusters_examined: clusters.len(),
merge_ops: 0,
clusters_would_absorb: 0,
});
}
let plan =
solo_steward::cluster::plan_existing_merges(&clusters, config)
.map_err(|e| {
Error::storage(format!("plan_existing_merges (count): {e}"))
})?;
Ok(MergeCandidateStats {
clusters_examined: clusters.len(),
merge_ops: plan.merges.len(),
clusters_would_absorb: plan.absorbed(),
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::open_test_db;

    /// On the database returned by `open_test_db` (expected to be empty), the
    /// merge-candidate count must equal the all-zero default stats.
    #[test]
    fn empty_db_returns_zero_stats() {
        // Keep `_tmp` as a named binding (not `_`) so it lives to the end of
        // the test. Presumably it owns the temporary database location.
        let (conn, _tmp) = open_test_db();
        let steward_cfg = solo_steward::StewardConfig::default();
        let expected = MergeCandidateStats::default();

        let observed = count_existing_merge_candidates(&conn, 384, &steward_cfg)
            .expect("merge-candidate count on an empty database");

        assert_eq!(observed, expected);
    }
}