use crate::error::Result;
use crate::graph::links;
use crate::store::{categories, embeddings, implicit};
use crate::types::*;
use rusqlite::Connection;
use std::collections::HashMap;
// Tuning knobs for the maintenance ("transform") pass.

/// Implicit impressions older than this (90 days) are pruned.
const MAX_IMPRESSION_AGE_SECS: i64 = 90 * 24 * 3600;
/// Half-life (30 days) used when decaying implicit preference confidence.
const PREFERENCE_HALF_LIFE_SECS: i64 = 30 * 24 * 3600;
/// Links whose weight falls below this are deleted.
const LINK_PRUNE_THRESHOLD: f32 = 0.02;
/// Implicit preferences whose confidence falls below this are deleted.
const MIN_PREFERENCE_CONFIDENCE: f32 = 0.05;
/// Cosine similarity at or above which two semantic nodes are merged as duplicates.
const DEDUP_SIMILARITY_THRESHOLD: f32 = 0.95;
/// Cosine similarity at or above which two nodes fall into the same cluster.
const CATEGORY_CLUSTER_THRESHOLD: f32 = 0.7;
/// Minimum cluster size for a cluster to become (or remain) a category.
const MIN_CLUSTER_SIZE: usize = 3;
/// Categories whose centroids exceed this similarity are merged into one.
const CATEGORY_MERGE_THRESHOLD: f32 = 0.85;
/// Categories with stability in (0, this) are dissolved.
const CATEGORY_DISSOLVE_THRESHOLD: f32 = 0.1;
/// Only categories with at least this many embedded members are split candidates.
const SPLIT_MIN_MEMBERS: usize = 8;
/// Mean member-to-centroid similarity below this marks a category as incoherent.
const SPLIT_COHERENCE_THRESHOLD: f32 = 0.6;
/// Runs the full maintenance pass over the memory graph and returns a
/// report of everything that changed.
///
/// The steps run in a fixed order because each one observes the output of
/// the previous: duplicates are merged first so link decay/pruning and
/// category work operate on unique nodes; preference decay runs before the
/// weak-preference prune; category discovery precedes merge/dissolve and
/// splitting.
pub fn transform(conn: &Connection) -> Result<TransformationReport> {
    // Graph hygiene: dedup semantic nodes, then decay and prune links.
    let duplicates_merged = dedup_semantic_nodes(conn)?;
    let links_decayed = links::decay_links(conn, 0.95)? as u32;
    let links_pruned = links::prune_weak_links(conn, LINK_PRUNE_THRESHOLD)? as u32;

    // Implicit-feedback hygiene. Note: the report folds decayed *and*
    // pruned preferences into a single counter.
    let now = crate::db::now();
    let prefs_decayed = implicit::decay_preferences(conn, now, PREFERENCE_HALF_LIFE_SECS)? as u32;
    let prefs_pruned = implicit::prune_weak_preferences(conn, MIN_PREFERENCE_CONFIDENCE)? as u32;
    let impressions_pruned = implicit::prune_old_impressions(conn, MAX_IMPRESSION_AGE_SECS)? as u32;

    // Category lifecycle: discover new ones, merge/dissolve existing ones,
    // then split oversized incoherent ones.
    let categories_discovered = discover_categories(conn)?;
    let (categories_merged, categories_dissolved) = maintain_categories(conn)?;
    let categories_split = split_large_categories(conn)?;

    Ok(TransformationReport {
        duplicates_merged,
        links_decayed,
        links_pruned,
        preferences_decayed: prefs_decayed + prefs_pruned,
        impressions_pruned,
        categories_discovered,
        categories_merged,
        categories_dissolved,
        categories_split,
        ..Default::default()
    })
}
/// Merges near-duplicate semantic nodes into a single survivor.
///
/// Every pair of semantic-node embeddings is compared; when cosine
/// similarity reaches `DEDUP_SIMILARITY_THRESHOLD` the later node of the
/// pair is folded into the earlier one: its links are re-pointed at the
/// survivor, the survivor's corroboration_count is bumped by one, and the
/// duplicate is deleted via `store::semantic::delete_node`.
///
/// Returns the number of nodes merged away. O(n^2) in the number of
/// embedded semantic nodes — acceptable for modest stores.
fn dedup_semantic_nodes(conn: &Connection) -> Result<u32> {
let mut stmt =
conn.prepare("SELECT node_id, embedding FROM embeddings WHERE node_type = 'semantic'")?;
// Materialize (id, embedding) pairs up front so the pairwise scan does
// not hold a live query. Rows that fail to map are silently skipped.
let nodes: Vec<(i64, Vec<f32>)> = stmt
.query_map([], |row| {
let id: i64 = row.get(0)?;
let blob: Vec<u8> = row.get(1)?;
Ok((id, embeddings::deserialize_embedding(&blob)))
})?
.filter_map(|r| r.ok())
.collect();
let mut merged = 0u32;
// Nodes already merged away in this pass; skipped in later comparisons
// so a node is never merged twice.
let mut deleted_ids: std::collections::HashSet<i64> = std::collections::HashSet::new();
for i in 0..nodes.len() {
if deleted_ids.contains(&nodes[i].0) {
continue;
}
for j in (i + 1)..nodes.len() {
if deleted_ids.contains(&nodes[j].0) {
continue;
}
let sim = embeddings::cosine_similarity(&nodes[i].1, &nodes[j].1);
if sim >= DEDUP_SIMILARITY_THRESHOLD {
// Re-point the duplicate's links at the survivor. `OR IGNORE`
// drops any re-point that would collide with an existing row;
// the follow-up DELETEs then remove those unmovable leftovers.
// The statement order matters: relink first, sweep second.
conn.execute(
"UPDATE OR IGNORE links SET source_id = ?1 WHERE source_type = 'semantic' AND source_id = ?2",
[nodes[i].0, nodes[j].0],
)?;
conn.execute(
"UPDATE OR IGNORE links SET target_id = ?1 WHERE target_type = 'semantic' AND target_id = ?2",
[nodes[i].0, nodes[j].0],
)?;
conn.execute(
"DELETE FROM links WHERE source_type = 'semantic' AND source_id = ?1",
[nodes[j].0],
)?;
conn.execute(
"DELETE FROM links WHERE target_type = 'semantic' AND target_id = ?1",
[nodes[j].0],
)?;
// The survivor gains exactly one corroboration per duplicate
// absorbed (not the duplicate's own corroboration count).
conn.execute(
"UPDATE semantic_nodes SET corroboration_count = corroboration_count + 1 WHERE id = ?1",
[nodes[i].0],
)?;
// NOTE(review): presumably delete_node also removes the
// duplicate's embedding row — confirm in store::semantic.
crate::store::semantic::delete_node(conn, NodeId(nodes[j].0))?;
deleted_ids.insert(nodes[j].0);
merged += 1;
}
}
}
Ok(merged)
}
/// Clusters uncategorized semantic nodes by embedding similarity and
/// creates a category for every cluster of `MIN_CLUSTER_SIZE` or more.
///
/// Clustering is single-link: any pair with cosine similarity at or above
/// `CATEGORY_CLUSTER_THRESHOLD` is unioned. Each new category uses the
/// most-corroborated member as its prototype, the member mean as its
/// centroid, and the prototype's first three words as its label (with a
/// `cluster-N` fallback for empty content).
///
/// Returns the number of categories created.
fn discover_categories(conn: &Connection) -> Result<u32> {
    let candidates = categories::get_uncategorized_node_ids(conn)?;
    if candidates.len() < MIN_CLUSTER_SIZE {
        return Ok(0);
    }
    // Keep only candidates that actually carry an embedding.
    let mut embedded: Vec<(NodeId, Vec<f32>)> = Vec::new();
    for id in &candidates {
        if let Some(vector) = embeddings::get_embedding(conn, "semantic", id.0)? {
            embedded.push((*id, vector));
        }
    }
    if embedded.len() < MIN_CLUSTER_SIZE {
        return Ok(0);
    }
    let count = embedded.len();
    // Union-find over node indices, with iterative path-halving.
    let mut parents: Vec<usize> = (0..count).collect();
    fn root_of(parents: &mut Vec<usize>, mut i: usize) -> usize {
        while parents[i] != i {
            parents[i] = parents[parents[i]];
            i = parents[i];
        }
        i
    }
    for a in 0..count {
        for b in (a + 1)..count {
            let sim = embeddings::cosine_similarity(&embedded[a].1, &embedded[b].1);
            if sim >= CATEGORY_CLUSTER_THRESHOLD {
                let ra = root_of(&mut parents, a);
                let rb = root_of(&mut parents, b);
                if ra != rb {
                    parents[rb] = ra;
                }
            }
        }
    }
    // Group member indices by their representative root.
    let mut by_root: HashMap<usize, Vec<usize>> = HashMap::new();
    for idx in 0..count {
        let root = root_of(&mut parents, idx);
        by_root.entry(root).or_default().push(idx);
    }
    let mut created = 0u32;
    for members in by_root.values() {
        if members.len() < MIN_CLUSTER_SIZE {
            continue;
        }
        // Prototype = most-corroborated member (first member wins ties).
        let mut proto = members[0];
        let mut proto_corr: i64 = 0;
        for &m in members {
            let corr: i64 = conn
                .query_row(
                    "SELECT COALESCE(corroboration_count, 0) FROM semantic_nodes WHERE id = ?1",
                    [embedded[m].0 .0],
                    |row| row.get(0),
                )
                .unwrap_or(0);
            if corr > proto_corr {
                proto_corr = corr;
                proto = m;
            }
        }
        let prototype_id = embedded[proto].0;
        // Centroid = arithmetic mean of all member embeddings.
        let dim = embedded[members[0]].1.len();
        let mut centroid = vec![0.0f32; dim];
        for &m in members {
            for (d, component) in embedded[m].1.iter().enumerate() {
                centroid[d] += component;
            }
        }
        let cluster_size = members.len() as f32;
        for component in centroid.iter_mut() {
            *component /= cluster_size;
        }
        // Label = first three words of the prototype's content.
        let content: String = conn
            .query_row(
                "SELECT content FROM semantic_nodes WHERE id = ?1",
                [prototype_id.0],
                |row| row.get::<_, String>(0),
            )
            .unwrap_or_default();
        let mut label = content
            .split_whitespace()
            .take(3)
            .collect::<Vec<&str>>()
            .join(" ");
        if label.is_empty() {
            label = format!("cluster-{created}");
        }
        let cat_id = categories::store_category(conn, &label, prototype_id, Some(&centroid), None)?;
        for &m in members {
            categories::assign_node_to_category(conn, embedded[m].0, cat_id)?;
        }
        created += 1;
    }
    Ok(created)
}
/// Category lifecycle maintenance: stability bookkeeping, garbage
/// collection of empty categories, merging of converging categories, and
/// dissolution of unstable ones.
///
/// Returns `(merged_count, dissolved_count)`. The passes below run in a
/// fixed order and each re-lists categories so it sees the previous
/// pass's effects.
fn maintain_categories(conn: &Connection) -> Result<(u32, u32)> {
let mut merged_count = 0u32;
let mut dissolved_count = 0u32;
// Pass 1: reward every non-empty category with a stability increment.
let all_cats = categories::list_categories(conn, None)?;
for cat in &all_cats {
if cat.member_count > 0 {
categories::increment_stability(conn, cat.id)?;
}
}
// Pass 2: GC categories that have no members at all.
let all_cats = categories::list_categories(conn, None)?;
for cat in &all_cats {
if cat.member_count == 0 {
categories::delete_category(conn, cat.id)?;
dissolved_count += 1;
}
}
// Pass 3: merge pairs whose centroids are nearly identical. The earlier
// list entry of a pair survives. NOTE(review): this presumes
// list_categories orders by decreasing stability so the more stable
// category is the one kept — confirm against store::categories.
let cats = categories::list_categories(conn, None)?;
let mut deleted: std::collections::HashSet<i64> = std::collections::HashSet::new();
let len = cats.len();
for i in 0..len {
if deleted.contains(&cats[i].id.0) {
continue;
}
for j in (i + 1)..len {
if !deleted.contains(&cats[j].id.0) {
// Only categories that both carry a centroid are comparable.
if let (Some(ref ci), Some(ref cj)) =
(&cats[i].centroid_embedding, &cats[j].centroid_embedding)
{
let sim = embeddings::cosine_similarity(ci, cj);
if sim > CATEGORY_MERGE_THRESHOLD {
let keep_id = cats[i].id;
let lose_id = cats[j].id;
// Move every member of the losing category over.
conn.execute(
"UPDATE semantic_nodes SET category_id = ?1 WHERE category_id = ?2",
[keep_id.0, lose_id.0],
)?;
// Refresh the survivor's cached member_count from the table.
let total: i64 = conn.query_row(
"SELECT COUNT(*) FROM semantic_nodes WHERE category_id = ?1",
[keep_id.0],
|row| row.get(0),
)?;
conn.execute(
"UPDATE categories SET member_count = ?1 WHERE id = ?2",
rusqlite::params![total, keep_id.0],
)?;
// Recompute the survivor's centroid as the mean embedding of
// its now-combined membership.
let mut stmt = conn.prepare(
"SELECT e.embedding FROM embeddings e
INNER JOIN semantic_nodes sn ON sn.id = e.node_id AND e.node_type = 'semantic'
WHERE sn.category_id = ?1",
)?;
let embs: Vec<Vec<f32>> = stmt
.query_map([keep_id.0], |row| {
let blob: Vec<u8> = row.get(0)?;
Ok(embeddings::deserialize_embedding(&blob))
})?
.filter_map(|r| r.ok())
.collect();
if !embs.is_empty() {
let dim = embs[0].len();
let mut new_centroid = vec![0.0f32; dim];
for emb in &embs {
for (d, val) in emb.iter().enumerate() {
new_centroid[d] += val;
}
}
let c = embs.len() as f32;
for val in &mut new_centroid {
*val /= c;
}
// NOTE(review): only the DB is updated here; cats[i] in this
// loop still holds the pre-merge centroid, so later pair
// comparisons against cats[i] use stale data.
categories::update_centroid(conn, keep_id, &new_centroid)?;
}
// Re-point membership links at the survivor; OR IGNORE drops
// collisions and the follow-up DELETE sweeps the leftovers.
conn.execute(
"UPDATE OR IGNORE links SET target_id = ?1 WHERE target_type = 'category' AND target_id = ?2 AND link_type = 'member_of'",
[keep_id.0, lose_id.0],
)?;
conn.execute(
"UPDATE OR IGNORE links SET source_id = ?1 WHERE source_type = 'category' AND source_id = ?2 AND link_type = 'member_of'",
[keep_id.0, lose_id.0],
)?;
conn.execute(
"DELETE FROM links WHERE link_type = 'member_of' AND ((target_type = 'category' AND target_id = ?1) OR (source_type = 'category' AND source_id = ?1))",
[lose_id.0],
)?;
conn.execute("DELETE FROM categories WHERE id = ?1", [lose_id.0])?;
deleted.insert(lose_id.0);
merged_count += 1;
}
}
}
}
}
// Pass 4: dissolve categories whose stability is positive but below the
// dissolve threshold, skipping ones already merged away above.
let cats = categories::list_categories(conn, None)?;
for cat in &cats {
if !deleted.contains(&cat.id.0)
&& cat.stability < CATEGORY_DISSOLVE_THRESHOLD
&& cat.stability > 0.0
{
categories::delete_category(conn, cat.id)?;
dissolved_count += 1;
}
}
Ok((merged_count, dissolved_count))
}
/// Splits oversized, incoherent categories into tighter sub-categories.
///
/// A category is a split candidate when it has a centroid, at least
/// `SPLIT_MIN_MEMBERS` embedded members, and a mean member-to-centroid
/// cosine similarity below `SPLIT_COHERENCE_THRESHOLD`. Its members are
/// re-clustered single-link at `CATEGORY_CLUSTER_THRESHOLD`; when at least
/// two clusters of `MIN_CLUSTER_SIZE`+ emerge, each becomes a sub-category
/// (parented to the original) and the members are reassigned to it.
///
/// Returns the number of categories that were split.
fn split_large_categories(conn: &Connection) -> Result<u32> {
    let all_cats = categories::list_categories(conn, None)?;
    let mut splits = 0u32;
    for cat in &all_cats {
        // Only large categories are worth splitting.
        if (cat.member_count as usize) < SPLIT_MIN_MEMBERS {
            continue;
        }
        // Without a stored centroid, coherence cannot be measured.
        let centroid = match &cat.centroid_embedding {
            Some(c) => c,
            None => continue,
        };
        let member_ids: Vec<NodeId> = conn
            .prepare("SELECT id FROM semantic_nodes WHERE category_id = ?1")?
            .query_map([cat.id.0], |row| Ok(NodeId(row.get(0)?)))?
            .filter_map(|r| r.ok())
            .collect();
        let mut members_with_emb: Vec<(NodeId, Vec<f32>)> = Vec::new();
        for nid in &member_ids {
            if let Some(emb) = embeddings::get_embedding(conn, "semantic", nid.0)? {
                members_with_emb.push((*nid, emb));
            }
        }
        if members_with_emb.len() < SPLIT_MIN_MEMBERS {
            continue;
        }
        // Coherence = mean cosine similarity of members to the centroid.
        let total_sim: f32 = members_with_emb
            .iter()
            .map(|(_, emb)| embeddings::cosine_similarity(emb, centroid))
            .sum();
        let coherence = total_sim / members_with_emb.len() as f32;
        if coherence >= SPLIT_COHERENCE_THRESHOLD {
            continue;
        }
        // Single-link clustering of members via union-find.
        let n = members_with_emb.len();
        let mut parent_uf: Vec<usize> = (0..n).collect();
        fn find(parent: &mut Vec<usize>, i: usize) -> usize {
            if parent[i] != i {
                parent[i] = find(parent, parent[i]);
            }
            parent[i]
        }
        fn union(parent: &mut Vec<usize>, a: usize, b: usize) {
            let ra = find(parent, a);
            let rb = find(parent, b);
            if ra != rb {
                parent[rb] = ra;
            }
        }
        for i in 0..n {
            for j in (i + 1)..n {
                let sim =
                    embeddings::cosine_similarity(&members_with_emb[i].1, &members_with_emb[j].1);
                if sim >= CATEGORY_CLUSTER_THRESHOLD {
                    union(&mut parent_uf, i, j);
                }
            }
        }
        let mut clusters: HashMap<usize, Vec<usize>> = HashMap::new();
        for i in 0..n {
            let root = find(&mut parent_uf, i);
            clusters.entry(root).or_default().push(i);
        }
        let valid_clusters: Vec<&Vec<usize>> = clusters
            .values()
            .filter(|c| c.len() >= MIN_CLUSTER_SIZE)
            .collect();
        // A split only makes sense when at least two sub-groups emerge.
        if valid_clusters.len() < 2 {
            continue;
        }
        for cluster in valid_clusters {
            // First cluster member serves as the sub-category prototype.
            let proto_idx = cluster[0];
            let proto_node = members_with_emb[proto_idx].0;
            let label_content: String = conn
                .query_row(
                    "SELECT content FROM semantic_nodes WHERE id = ?1",
                    [proto_node.0],
                    |row| row.get(0),
                )
                .unwrap_or_else(|_| "sub-category".to_string());
            // Truncate the label to at most 40 bytes without slicing
            // through a UTF-8 sequence: a plain `&s[..40]` panics when
            // byte 40 is not a char boundary (multi-byte content).
            let mut end = label_content.len().min(40);
            while !label_content.is_char_boundary(end) {
                end -= 1;
            }
            let sub_label = &label_content[..end];
            // Sub-centroid = mean of this cluster's member embeddings.
            let dim = members_with_emb[0].1.len();
            let mut sub_centroid = vec![0.0f32; dim];
            for &idx in cluster {
                for (d, val) in members_with_emb[idx].1.iter().enumerate() {
                    sub_centroid[d] += val;
                }
            }
            let cluster_size = cluster.len() as f32;
            for val in &mut sub_centroid {
                *val /= cluster_size;
            }
            let sub_id = categories::store_category(
                conn,
                sub_label,
                proto_node,
                Some(&sub_centroid),
                Some(cat.id),
            )?;
            for &idx in cluster {
                categories::assign_node_to_category(conn, members_with_emb[idx].0, sub_id)?;
            }
        }
        splits += 1;
    }
    Ok(splits)
}
// Unit tests for the transformation pass, run against an in-memory SQLite DB.
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::open_memory_db;
// A transform over an empty database should be a no-op with a zeroed report.
#[test]
fn test_transform_empty_db() {
let conn = open_memory_db().unwrap();
let report = transform(&conn).unwrap();
assert_eq!(report.duplicates_merged, 0);
assert_eq!(report.links_pruned, 0);
}
// Link weights above the prune threshold survive the pass but are decayed.
#[test]
fn test_transform_decays_link_weights() {
let conn = open_memory_db().unwrap();
links::create_link(
&conn,
NodeRef::Episode(EpisodeId(1)),
NodeRef::Episode(EpisodeId(2)),
LinkType::CoRetrieval,
0.5,
)
.unwrap();
let report = transform(&conn).unwrap();
assert!(report.links_decayed > 0, "should report decayed links");
let remaining = links::get_links_from(&conn, NodeRef::Episode(EpisodeId(1))).unwrap();
assert_eq!(remaining.len(), 1);
assert!(
remaining[0].forward_weight < 0.5,
"weight should have decreased from 0.5, got {}",
remaining[0].forward_weight
);
}
// A link created below LINK_PRUNE_THRESHOLD (0.02) is removed entirely.
#[test]
fn test_transform_prunes_weak_links() {
let conn = open_memory_db().unwrap();
links::create_link(
&conn,
NodeRef::Episode(EpisodeId(1)),
NodeRef::Episode(EpisodeId(2)),
LinkType::Temporal,
0.01,
)
.unwrap();
let report = transform(&conn).unwrap();
assert_eq!(report.links_pruned, 1);
}
// Four mutually-similar embeddings should cluster into at least one
// category with MIN_CLUSTER_SIZE (3) or more members.
#[test]
fn test_transform_discovers_categories() {
let conn = open_memory_db().unwrap();
let test_embs: Vec<Vec<f32>> = vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.8, 0.5, 0.0, 0.0],
vec![0.7, 0.3, 0.5, 0.0],
vec![0.9, 0.2, 0.1, 0.3],
];
for (i, emb) in test_embs.iter().enumerate() {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("cooking recipe {i}")],
).unwrap();
let node_id: i64 = conn
.query_row("SELECT last_insert_rowid()", [], |r| r.get(0))
.unwrap();
embeddings::store_embedding(&conn, "semantic", node_id, emb, "").unwrap();
}
let report = transform(&conn).unwrap();
assert!(
report.categories_discovered >= 1,
"should discover at least 1 category from 4 similar nodes, got {}",
report.categories_discovered
);
let cats = categories::list_categories(&conn, None).unwrap();
assert!(!cats.is_empty(), "should have created categories");
assert!(
cats[0].member_count >= 3,
"category should have at least 3 members, got {}",
cats[0].member_count
);
}
// Fewer than MIN_CLUSTER_SIZE nodes can never form a category.
#[test]
fn test_transform_no_categories_with_few_nodes() {
let conn = open_memory_db().unwrap();
for i in 0..2 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated)
VALUES (?1, 'fact', 0.8, 1000, 1000)",
[format!("node {i}")],
).unwrap();
let node_id: i64 = conn
.query_row("SELECT last_insert_rowid()", [], |r| r.get(0))
.unwrap();
embeddings::store_embedding(&conn, "semantic", node_id, &[0.9, 0.1, 0.0], "").unwrap();
}
let report = transform(&conn).unwrap();
assert_eq!(report.categories_discovered, 0);
}
// A category with zero members is garbage-collected by maintain_categories.
#[test]
fn test_transform_gc_empty_categories() {
let conn = open_memory_db().unwrap();
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated)
VALUES ('dummy', 'fact', 0.5, 1000, 1000)",
[],
).unwrap();
categories::store_category(&conn, "empty-cat", NodeId(1), None, None).unwrap();
assert_eq!(categories::count_categories(&conn).unwrap(), 1);
let report = transform(&conn).unwrap();
assert_eq!(
categories::count_categories(&conn).unwrap(),
0,
"empty category should have been garbage-collected"
);
assert!(
report.categories_dissolved >= 1,
"should report at least 1 dissolved category"
);
}
// Discovery should create member_of links in both directions for every
// clustered node (3 nodes -> 6 link rows).
#[test]
fn test_discover_categories_creates_member_of_links() {
let conn = open_memory_db().unwrap();
for i in 0..3 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("topic alpha {i}")],
).unwrap();
let node_id: i64 = conn
.query_row("SELECT last_insert_rowid()", [], |r| r.get(0))
.unwrap();
embeddings::store_embedding(&conn, "semantic", node_id, &[1.0, 0.0, 0.0], "").unwrap();
}
let created = discover_categories(&conn).unwrap();
assert_eq!(created, 1);
let link_count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM links WHERE link_type = 'member_of'",
[],
|r| r.get(0),
)
.unwrap();
assert_eq!(
link_count, 6,
"should have 6 bidirectional MemberOf links (2 per node)"
);
}
// Ten members in two distinct embedding groups, with a centroid far from
// all of them (low coherence), must trigger a split into >= 2 sub-cats.
#[test]
fn test_split_triggers_when_large_and_incoherent() {
let conn = open_memory_db().unwrap();
let test_embs: Vec<Vec<f32>> = vec![
vec![
0.45, 0.45, 0.45, 0.45, 0.45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
],
vec![
0.45, 0.45, 0.45, 0.45, 0.0, 0.45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
],
vec![
0.45, 0.45, 0.45, 0.45, 0.0, 0.0, 0.45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
],
vec![
0.45, 0.45, 0.45, 0.45, 0.0, 0.0, 0.0, 0.45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
],
vec![
0.45, 0.45, 0.45, 0.45, 0.31, 0.31, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0,
],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.45, 0.45, 0.45, 0.45, 0.45, 0.0, 0.0, 0.0,
],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.45, 0.45, 0.45, 0.45, 0.0, 0.45, 0.0, 0.0,
],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.45, 0.45, 0.45, 0.45, 0.0, 0.0, 0.45, 0.0,
],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.45, 0.45, 0.45, 0.45, 0.0, 0.0, 0.0, 0.45,
],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.45, 0.45, 0.45, 0.45, 0.31, 0.31, 0.0,
0.0,
],
];
let mut node_ids = Vec::new();
for (i, emb) in test_embs.iter().enumerate() {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
rusqlite::params![format!("split node {i}")],
)
.unwrap();
let nid = NodeId(conn.last_insert_rowid());
embeddings::store_embedding(&conn, "semantic", nid.0, emb, "").unwrap();
node_ids.push(nid);
}
let cat_id = categories::store_category(&conn, "broad", node_ids[0], None, None).unwrap();
// Centroid deliberately orthogonal to every member => coherence ~0.
let centroid = vec![
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.7,
];
categories::update_centroid(&conn, cat_id, &centroid).unwrap();
for &nid in &node_ids {
categories::assign_node_to_category(&conn, nid, cat_id).unwrap();
}
let report = transform(&conn).unwrap();
assert!(
report.categories_split > 0,
"should have split the broad category"
);
let subs = categories::get_subcategories(&conn, cat_id).unwrap();
assert!(
subs.len() >= 2,
"should have at least 2 sub-categories, got {}",
subs.len()
);
}
// Members tightly packed around the centroid => coherence above the
// threshold => no split.
#[test]
fn test_no_split_when_coherent() {
let conn = open_memory_db().unwrap();
let coherent_embs: Vec<Vec<f32>> = vec![
vec![0.9, 0.3, 0.1, 0.0],
vec![0.85, 0.35, 0.15, 0.0],
vec![0.88, 0.28, 0.12, 0.05],
vec![0.92, 0.25, 0.08, 0.0],
vec![0.87, 0.32, 0.14, 0.02],
vec![0.91, 0.27, 0.11, 0.03],
vec![0.86, 0.34, 0.13, 0.01],
vec![0.89, 0.30, 0.10, 0.04],
];
let mut node_ids = Vec::new();
for (i, emb) in coherent_embs.iter().enumerate() {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
rusqlite::params![format!("coherent node {i}")],
)
.unwrap();
let nid = NodeId(conn.last_insert_rowid());
embeddings::store_embedding(&conn, "semantic", nid.0, emb, "").unwrap();
node_ids.push(nid);
}
let cat_id = categories::store_category(&conn, "tight", node_ids[0], None, None).unwrap();
let centroid = vec![0.885, 0.305, 0.116, 0.019]; categories::update_centroid(&conn, cat_id, &centroid).unwrap();
for &nid in &node_ids {
categories::assign_node_to_category(&conn, nid, cat_id).unwrap();
}
let report = transform(&conn).unwrap();
assert_eq!(
report.categories_split, 0,
"should NOT split coherent category"
);
}
// Fewer than SPLIT_MIN_MEMBERS (8) members => never a split candidate,
// regardless of coherence.
#[test]
fn test_no_split_when_small() {
let conn = open_memory_db().unwrap();
let mut node_ids = Vec::new();
for i in 0..5 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
rusqlite::params![format!("small node {}", i)],
)
.unwrap();
let nid = NodeId(conn.last_insert_rowid());
let emb = if i < 3 {
vec![1.0, 0.0, 0.0, 0.0]
} else {
vec![0.0, 0.0, 1.0, 0.0]
};
embeddings::store_embedding(&conn, "semantic", nid.0, &emb, "").unwrap();
node_ids.push(nid);
}
let cat_id = categories::store_category(&conn, "small", node_ids[0], None, None).unwrap();
let centroid = vec![0.5, 0.0, 0.5, 0.0];
categories::update_centroid(&conn, cat_id, &centroid).unwrap();
for &nid in &node_ids {
categories::assign_node_to_category(&conn, nid, cat_id).unwrap();
}
let report = transform(&conn).unwrap();
assert_eq!(
report.categories_split, 0,
"should NOT split small category"
);
}
// Two identical embeddings merge into one node whose corroboration_count
// absorbs exactly one increment.
#[test]
fn test_dedup_merges_near_identical_nodes() {
let conn = open_memory_db().unwrap();
for i in 0..2 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("duplicate node {i}")],
)
.unwrap();
let node_id: i64 = conn
.query_row("SELECT last_insert_rowid()", [], |r| r.get(0))
.unwrap();
embeddings::store_embedding(&conn, "semantic", node_id, &[1.0, 0.0, 0.0], "").unwrap();
}
let merged = dedup_semantic_nodes(&conn).unwrap();
assert_eq!(merged, 1, "should have merged the duplicate");
let count: i64 = conn
.query_row("SELECT COUNT(*) FROM semantic_nodes", [], |r| r.get(0))
.unwrap();
assert_eq!(count, 1, "should have 1 node after dedup");
let corr: i64 = conn
.query_row(
"SELECT corroboration_count FROM semantic_nodes WHERE id = 1",
[],
|r| r.get(0),
)
.unwrap();
assert_eq!(corr, 2, "kept node should have corroboration_count = 2");
}
// Three identical nodes: the second and third both merge into the first;
// the already-deleted set prevents re-merging a removed node.
#[test]
fn test_dedup_skips_already_deleted() {
let conn = open_memory_db().unwrap();
for i in 0..3 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("triple dup {i}")],
)
.unwrap();
let node_id: i64 = conn
.query_row("SELECT last_insert_rowid()", [], |r| r.get(0))
.unwrap();
embeddings::store_embedding(&conn, "semantic", node_id, &[1.0, 0.0, 0.0], "").unwrap();
}
let merged = dedup_semantic_nodes(&conn).unwrap();
assert_eq!(merged, 2, "should have merged both duplicates");
let count: i64 = conn
.query_row("SELECT COUNT(*) FROM semantic_nodes", [], |r| r.get(0))
.unwrap();
assert_eq!(
count, 1,
"should have 1 node after dedup of 3 identical nodes"
);
}
// Two categories with near-identical centroids collapse into one.
#[test]
fn test_maintain_categories_merges_converging() {
let conn = open_memory_db().unwrap();
for i in 0..2 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated)
VALUES (?1, 'fact', 0.8, 1000, 1000)",
[format!("merge-node {i}")],
).unwrap();
}
let c1 =
categories::store_category(&conn, "cat-a", NodeId(1), Some(&[1.0, 0.0, 0.0]), None)
.unwrap();
let c2 =
categories::store_category(&conn, "cat-b", NodeId(2), Some(&[0.99, 0.01, 0.0]), None)
.unwrap();
categories::assign_node_to_category(&conn, NodeId(1), c1).unwrap();
categories::assign_node_to_category(&conn, NodeId(2), c2).unwrap();
embeddings::store_embedding(&conn, "semantic", 1, &[1.0, 0.0, 0.0], "").unwrap();
embeddings::store_embedding(&conn, "semantic", 2, &[0.99, 0.01, 0.0], "").unwrap();
let (merged, _dissolved) = maintain_categories(&conn).unwrap();
assert!(
merged >= 1,
"should have merged converging categories, got {merged}",
);
assert_eq!(
categories::count_categories(&conn).unwrap(),
1,
"should have 1 category after merge"
);
}
// Five nodes exist but only two carry embeddings, which is below
// MIN_CLUSTER_SIZE — discovery must bail out.
#[test]
fn test_discover_categories_few_embeddings() {
let conn = open_memory_db().unwrap();
for i in 0..5 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("sparse embed node {i}")],
).unwrap();
}
embeddings::store_embedding(&conn, "semantic", 1, &[1.0, 0.0, 0.0], "").unwrap();
embeddings::store_embedding(&conn, "semantic", 2, &[0.9, 0.1, 0.0], "").unwrap();
let discovered = discover_categories(&conn).unwrap();
assert_eq!(
discovered, 0,
"should not discover categories with < 3 embedded nodes"
);
}
// One tight 3-node cluster plus two dissimilar outliers: exactly one
// category is created and the outliers remain uncategorized.
#[test]
fn test_discover_categories_small_cluster_skipped() {
let conn = open_memory_db().unwrap();
for i in 0..3 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("cluster node {i}")],
).unwrap();
let nid = conn.last_insert_rowid();
embeddings::store_embedding(
&conn,
"semantic",
nid,
&[1.0, 0.0 + (i as f32) * 0.01, 0.0],
"",
)
.unwrap();
}
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES ('outlier a', 'fact', 0.8, 1000, 1000, 1)",
[],
).unwrap();
embeddings::store_embedding(
&conn,
"semantic",
conn.last_insert_rowid(),
&[0.0, 1.0, 0.0],
"",
)
.unwrap();
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES ('outlier b', 'fact', 0.8, 1000, 1000, 1)",
[],
).unwrap();
embeddings::store_embedding(
&conn,
"semantic",
conn.last_insert_rowid(),
&[0.0, 0.0, 1.0],
"",
)
.unwrap();
let discovered = discover_categories(&conn).unwrap();
assert_eq!(
discovered, 1,
"should create 1 category from tight cluster, skip outliers"
);
}
// Prototype content is empty, so the label falls back to "cluster-N".
#[test]
fn test_discover_categories_empty_label_fallback() {
let conn = open_memory_db().unwrap();
for i in 0..3 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES ('', 'fact', 0.8, 1000, 1000, 1)",
[],
).unwrap();
let nid = conn.last_insert_rowid();
embeddings::store_embedding(
&conn,
"semantic",
nid,
&[1.0, 0.0 + (i as f32) * 0.01, 0.0],
"",
)
.unwrap();
}
let discovered = discover_categories(&conn).unwrap();
assert_eq!(discovered, 1);
let cats = categories::list_categories(&conn, None).unwrap();
assert_eq!(cats.len(), 1);
assert!(
cats[0].label.starts_with("cluster-"),
"empty content should produce fallback label, got: {}",
cats[0].label
);
}
// When two near-identical-centroid categories merge, the higher-stability
// one survives (depends on list_categories ordering).
#[test]
fn test_maintain_categories_merge_lower_stability_wins() {
let conn = open_memory_db().unwrap();
for i in 0..2 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated)
VALUES (?1, 'fact', 0.8, 1000, 1000)",
[format!("stability-node {i}")],
).unwrap();
}
let c1 = categories::store_category(
&conn,
"low-stability",
NodeId(1),
Some(&[1.0, 0.0, 0.0]),
None,
)
.unwrap();
let c2 = categories::store_category(
&conn,
"high-stability",
NodeId(2),
Some(&[0.99, 0.01, 0.0]),
None,
)
.unwrap();
conn.execute(
"UPDATE categories SET stability = 0.3 WHERE id = ?1",
[c1.0],
)
.unwrap();
conn.execute(
"UPDATE categories SET stability = 0.9 WHERE id = ?1",
[c2.0],
)
.unwrap();
categories::assign_node_to_category(&conn, NodeId(1), c1).unwrap();
categories::assign_node_to_category(&conn, NodeId(2), c2).unwrap();
embeddings::store_embedding(&conn, "semantic", 1, &[1.0, 0.0, 0.0], "").unwrap();
embeddings::store_embedding(&conn, "semantic", 2, &[0.99, 0.01, 0.0], "").unwrap();
let (merged, _dissolved) = maintain_categories(&conn).unwrap();
assert!(merged >= 1, "should merge converging categories");
let remaining = categories::list_categories(&conn, None).unwrap();
assert_eq!(remaining.len(), 1);
assert_eq!(
remaining[0].label, "high-stability",
"should keep the higher-stability category"
);
}
// A large category with no stored centroid cannot be coherence-checked
// and must be skipped by the splitter.
#[test]
fn test_split_no_centroid_skips() {
let conn = open_memory_db().unwrap();
for i in 0..10 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("no-centroid node {i}")],
).unwrap();
}
let cat_id =
categories::store_category(&conn, "no-centroid-cat", NodeId(1), None, None).unwrap();
for i in 1..=10 {
categories::assign_node_to_category(&conn, NodeId(i), cat_id).unwrap();
embeddings::store_embedding(&conn, "semantic", i, &[1.0, 0.0, 0.0], "").unwrap();
}
let splits = split_large_categories(&conn).unwrap();
assert_eq!(splits, 0, "should skip category without centroid");
}
// Stability forced to -0.1 so that after the non-empty increment it lands
// in (0, CATEGORY_DISSOLVE_THRESHOLD) and the category dissolves.
#[test]
fn test_dissolve_unstable_category_negative_initial_stability() {
let conn = open_memory_db().unwrap();
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES ('node', 'fact', 0.8, 1000, 1000, 1)",
[],
)
.unwrap();
let cat_id =
categories::store_category(&conn, "will-dissolve", NodeId(1), None, None).unwrap();
conn.execute(
"UPDATE categories SET stability = -0.1 WHERE id = ?1",
[cat_id.0],
)
.unwrap();
categories::assign_node_to_category(&conn, NodeId(1), cat_id).unwrap();
let (_, dissolved) = maintain_categories(&conn).unwrap();
assert_eq!(
dissolved, 1,
"category with post-increment stability 0.01 should be dissolved"
);
let remaining = categories::list_categories(&conn, None).unwrap();
assert!(remaining.is_empty(), "dissolved category should be removed");
}
// Three identical-centroid categories: the merge loop must skip entries
// it has already deleted and leave exactly one survivor.
#[test]
fn test_merge_skips_already_deleted_inner() {
let conn = open_memory_db().unwrap();
for i in 1..=3 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("triple-merge node {i}")],
)
.unwrap();
}
let emb = vec![1.0f32, 0.0, 0.0];
let c1 = categories::store_category(&conn, "cat-1", NodeId(1), Some(&emb), None).unwrap();
let c2 = categories::store_category(&conn, "cat-2", NodeId(2), Some(&emb), None).unwrap();
let c3 = categories::store_category(&conn, "cat-3", NodeId(3), Some(&emb), None).unwrap();
conn.execute(
"UPDATE categories SET stability = 0.5 WHERE id IN (?1, ?2, ?3)",
rusqlite::params![c1.0, c2.0, c3.0],
)
.unwrap();
categories::assign_node_to_category(&conn, NodeId(1), c1).unwrap();
categories::assign_node_to_category(&conn, NodeId(2), c2).unwrap();
categories::assign_node_to_category(&conn, NodeId(3), c3).unwrap();
embeddings::store_embedding(&conn, "semantic", 1, &[1.0, 0.0, 0.0], "").unwrap();
embeddings::store_embedding(&conn, "semantic", 2, &[1.0, 0.0, 0.0], "").unwrap();
embeddings::store_embedding(&conn, "semantic", 3, &[1.0, 0.0, 0.0], "").unwrap();
let (merged, _dissolved) = maintain_categories(&conn).unwrap();
assert!(
merged >= 2,
"should merge at least 2 pairs: merged={merged}"
);
let remaining = categories::list_categories(&conn, None).unwrap();
assert_eq!(
remaining.len(),
1,
"only 1 category should survive after triple merge"
);
}
// A low-stability category that was merged away must not also be counted
// by the dissolve pass (the `deleted` set guards this).
#[test]
fn test_dissolve_skips_merged_category() {
let conn = open_memory_db().unwrap();
for i in 1..=2 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("merge-dissolve node {i}")],
)
.unwrap();
}
let emb = vec![1.0f32, 0.0, 0.0];
let c1 =
categories::store_category(&conn, "will-merge-low-stab", NodeId(1), Some(&emb), None)
.unwrap();
let c2 = categories::store_category(&conn, "keep-high-stab", NodeId(2), Some(&emb), None)
.unwrap();
conn.execute(
"UPDATE categories SET stability = 0.05 WHERE id = ?1",
[c1.0],
)
.unwrap();
conn.execute(
"UPDATE categories SET stability = 0.9 WHERE id = ?1",
[c2.0],
)
.unwrap();
categories::assign_node_to_category(&conn, NodeId(1), c1).unwrap();
categories::assign_node_to_category(&conn, NodeId(2), c2).unwrap();
embeddings::store_embedding(&conn, "semantic", 1, &[1.0, 0.0, 0.0], "").unwrap();
embeddings::store_embedding(&conn, "semantic", 2, &[1.0, 0.0, 0.0], "").unwrap();
let (merged, _dissolved) = maintain_categories(&conn).unwrap();
assert!(merged >= 1, "should merge the similar categories");
let remaining = categories::list_categories(&conn, None).unwrap();
assert_eq!(remaining.len(), 1);
}
// Identical member embeddings matching the centroid => coherence 1.0 => no split.
#[test]
fn test_split_coherent_category_skipped() {
let conn = open_memory_db().unwrap();
for i in 1..=10 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("coherent node {i}")],
)
.unwrap();
}
let centroid = vec![1.0f32, 0.0, 0.0];
let cat_id =
categories::store_category(&conn, "coherent-cat", NodeId(1), Some(&centroid), None)
.unwrap();
for i in 1..=10 {
categories::assign_node_to_category(&conn, NodeId(i), cat_id).unwrap();
embeddings::store_embedding(&conn, "semantic", i, &[1.0, 0.0, 0.0], "").unwrap();
}
let splits = split_large_categories(&conn).unwrap();
assert_eq!(splits, 0, "coherent category should not be split");
}
// Incoherent, but the 2-node minority group is below MIN_CLUSTER_SIZE, so
// only one valid cluster exists and no split happens.
#[test]
fn test_split_no_meaningful_subclusters() {
let conn = open_memory_db().unwrap();
for i in 1..=10 {
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[format!("cluster node {i}")],
)
.unwrap();
}
let centroid = vec![0.5f32, 0.5, 0.5];
let cat_id =
categories::store_category(&conn, "no-split-cat", NodeId(1), Some(&centroid), None)
.unwrap();
for i in 1..=8 {
categories::assign_node_to_category(&conn, NodeId(i), cat_id).unwrap();
embeddings::store_embedding(&conn, "semantic", i, &[1.0, 0.0, 0.0], "").unwrap();
}
for i in 9..=10 {
categories::assign_node_to_category(&conn, NodeId(i), cat_id).unwrap();
embeddings::store_embedding(&conn, "semantic", i, &[0.0, 0.0, 1.0], "").unwrap();
}
let splits = split_large_categories(&conn).unwrap();
assert_eq!(
splits, 0,
"should not split when only 1 valid cluster exists"
);
}
// Sub-category labels derived from long prototype content are truncated
// to at most 40 bytes.
#[test]
fn test_split_label_truncation() {
let conn = open_memory_db().unwrap();
for i in 1..=16 {
let long_content = format!(
"This is a very long content string that exceeds forty characters for node number {i}"
);
conn.execute(
"INSERT INTO semantic_nodes (content, node_type, confidence, created_at, last_corroborated, corroboration_count)
VALUES (?1, 'fact', 0.8, 1000, 1000, 1)",
[long_content],
)
.unwrap();
}
let centroid = vec![0.0f32, 0.0, 0.0, 1.0];
let cat_id =
categories::store_category(&conn, "split-me", NodeId(1), Some(&centroid), None)
.unwrap();
for i in 1..=8 {
categories::assign_node_to_category(&conn, NodeId(i), cat_id).unwrap();
embeddings::store_embedding(&conn, "semantic", i, &[0.95, 0.05, 0.0, 0.0], "").unwrap();
}
for i in 9..=16 {
categories::assign_node_to_category(&conn, NodeId(i), cat_id).unwrap();
embeddings::store_embedding(&conn, "semantic", i, &[0.0, 0.0, 0.95, 0.05], "").unwrap();
}
let splits = split_large_categories(&conn).unwrap();
assert!(
splits >= 1,
"should split incoherent category: splits={splits}"
);
let all_cats = categories::list_categories(&conn, None).unwrap();
let sub_cats: Vec<_> = all_cats.iter().filter(|c| c.id != cat_id).collect();
assert!(!sub_cats.is_empty(), "should have created sub-categories");
for cat in &sub_cats {
assert!(
cat.label.len() <= 40,
"sub-category label should be <= 40 chars: '{}'",
cat.label
);
}
}
}