use fxhash::FxHashMap;
use crate::core::atom_memory::AtomMemory;
use crate::core::entangled::EntangledHVec;
use crate::core::triple_store::TripleStore;
/// Outcome summary of one `DistributionalRefiner::refine` pass.
///
/// The `Default` value (all counters zero, average `0.0`) is what an empty
/// triple store produces.
#[derive(Debug, Default, Clone)]
pub struct RefinementReport {
    /// Number of atoms whose stored vector was replaced by a refined one.
    pub atoms_refined: usize,
    /// Number of atoms that were examined but left untouched: too few
    /// relations, no relation with enough peers, or a refined vector that is
    /// practically identical to the original.
    pub atoms_skipped: usize,
    /// Accumulated count of contributing relations divided by
    /// `atoms_refined`; remains `0.0` when no atom was refined.
    pub avg_context_depth: f64,
}
/// Tuning parameters for `DistributionalRefiner::refine`.
#[derive(Debug, Clone)]
pub struct RefinerConfig {
    /// Fraction of the target index count that is handed to context
    /// (peer-derived) indices; the remainder keeps original indices.
    pub alpha: f64,
    /// Atoms that appear as subject of fewer distinct relations than this are
    /// skipped.
    pub min_context_relations: usize,
    /// A relation only contributes context when at least this many *other*
    /// subjects share it.
    pub min_peers_per_relation: usize,
}

impl Default for RefinerConfig {
    /// Conservative defaults: 15% context share, and at least two relations
    /// and two peers per relation before an atom is touched.
    fn default() -> Self {
        let alpha = 0.15;
        let min_context_relations = 2;
        let min_peers_per_relation = 2;
        RefinerConfig {
            alpha,
            min_context_relations,
            min_peers_per_relation,
        }
    }
}
pub struct DistributionalRefiner;
impl DistributionalRefiner {
pub fn refine(
atom_memory: &AtomMemory,
triple_store: &TripleStore,
config: &RefinerConfig,
) -> RefinementReport {
let snapshot = triple_store.snapshot();
if snapshot.is_empty() {
return RefinementReport::default();
}
let mut atom_relations: FxHashMap<String, Vec<String>> = FxHashMap::default();
for t in &snapshot {
atom_relations
.entry(t.subject_id.clone())
.or_default()
.push(t.relation_id.clone());
}
for rels in atom_relations.values_mut() {
rels.sort();
rels.dedup();
}
let mut relation_subjects: FxHashMap<String, Vec<String>> = FxHashMap::default();
for t in &snapshot {
relation_subjects
.entry(t.relation_id.clone())
.or_default()
.push(t.subject_id.clone());
}
for subjects in relation_subjects.values_mut() {
subjects.sort();
subjects.dedup();
}
let mut report = RefinementReport::default();
let mut total_context_depth: usize = 0;
for (atom_id, relations) in &atom_relations {
if relations.len() < config.min_context_relations {
report.atoms_skipped += 1;
continue;
}
let original = match atom_memory.get(atom_id) {
Some(v) => v,
None => continue,
};
let mut peer_index_freq: FxHashMap<u32, f64> = FxHashMap::default();
let mut contributing_relations = 0usize;
for rel in relations {
let peers = match relation_subjects.get(rel) {
Some(p) => p,
None => continue,
};
let other_peers: Vec<&String> = peers.iter().filter(|p| *p != atom_id).collect();
if other_peers.len() < config.min_peers_per_relation {
continue;
}
contributing_relations += 1;
let weight = 1.0 / other_peers.len() as f64;
for peer_id in &other_peers {
if let Some(peer_vec) = atom_memory.get(peer_id) {
for &idx in peer_vec.indices() {
*peer_index_freq.entry(idx).or_insert(0.0) += weight;
}
}
}
}
if contributing_relations == 0 {
report.atoms_skipped += 1;
continue;
}
total_context_depth += contributing_relations;
let dim = original.dim;
let target_count = (dim / 256).max(1);
let ctx_slots = ((target_count as f64) * config.alpha).round() as usize;
let orig_slots = target_count.saturating_sub(ctx_slots);
let orig_set: fxhash::FxHashSet<u32> = original.indices().iter().copied().collect();
let mut ctx_scored: Vec<(u32, f64)> = peer_index_freq
.iter()
.filter(|(idx, _)| !orig_set.contains(idx))
.map(|(&idx, &freq)| (idx, freq))
.collect();
ctx_scored.sort_unstable_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
ctx_scored.truncate(ctx_slots);
let mut orig_indices: Vec<u32> = original.indices().to_vec();
orig_indices.truncate(orig_slots);
let mut new_indices: Vec<u32> = orig_indices;
new_indices.extend(ctx_scored.iter().map(|(idx, _)| *idx));
new_indices.sort_unstable();
new_indices.dedup();
new_indices.truncate(target_count);
let refined = EntangledHVec::from_indices(new_indices, dim);
if original.similarity(&refined) < 0.999 {
atom_memory.delete(atom_id);
atom_memory.load_atom(atom_id.clone(), refined);
report.atoms_refined += 1;
} else {
report.atoms_skipped += 1;
}
}
if report.atoms_refined > 0 {
report.avg_context_depth = total_context_depth as f64 / report.atoms_refined as f64;
atom_memory.rebuild_indices();
}
report
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensionality used by every test vector.
    const DIM: usize = 16384;

    /// Creates an empty atom memory and an empty triple store.
    fn fresh_stores() -> (AtomMemory, TripleStore) {
        (AtomMemory::new(DIM, 3.0), TripleStore::new())
    }

    /// Config with the given context share and both minimums set to two.
    fn config_with_alpha(alpha: f64) -> RefinerConfig {
        RefinerConfig {
            alpha,
            min_context_relations: 2,
            min_peers_per_relation: 2,
        }
    }

    /// Two capitals sharing relations must move closer to each other.
    #[test]
    fn test_refine_converges_cooccurring_atoms() {
        let (atom_mem, triple_store) = fresh_stores();
        for name in [
            "paris", "berlin", "tokyo", "france", "germany", "japan", "europe",
        ] {
            atom_mem.get_or_insert(name);
        }
        for (subject, relation, object, ctx) in [
            ("paris", "capital_of", "france", "c1"),
            ("berlin", "capital_of", "germany", "c2"),
            ("tokyo", "capital_of", "japan", "c3"),
            ("paris", "located_in", "europe", "c4"),
            ("berlin", "located_in", "europe", "c5"),
        ] {
            triple_store.add(subject, relation, object, ctx);
        }

        let paris_berlin_similarity = || {
            let paris = atom_mem.get("paris").unwrap();
            let berlin = atom_mem.get("berlin").unwrap();
            paris.similarity(&berlin)
        };

        let sim_before = paris_berlin_similarity();
        let config = config_with_alpha(0.3);
        let report = DistributionalRefiner::refine(&atom_mem, &triple_store, &config);
        assert!(
            report.atoms_refined > 0,
            "Should refine at least some atoms"
        );

        let sim_after = paris_berlin_similarity();
        assert!(
            sim_after > sim_before,
            "paris and berlin should become more similar after refinement: before={:.4}, after={:.4}",
            sim_before,
            sim_after
        );
    }

    /// An atom with a single relation is below the context threshold.
    #[test]
    fn test_refine_skips_low_context_atoms() {
        let (atom_mem, triple_store) = fresh_stores();
        for name in ["lonely", "x"] {
            atom_mem.get_or_insert(name);
        }
        triple_store.add("lonely", "r", "x", "c1");

        let config = RefinerConfig {
            min_context_relations: 2,
            ..RefinerConfig::default()
        };
        let report = DistributionalRefiner::refine(&atom_mem, &triple_store, &config);
        assert_eq!(report.atoms_refined, 0);
        assert!(report.atoms_skipped > 0);
    }

    /// Without triples there is nothing to refine and nothing to skip.
    #[test]
    fn test_refine_empty_store() {
        let (atom_mem, triple_store) = fresh_stores();
        let report =
            DistributionalRefiner::refine(&atom_mem, &triple_store, &RefinerConfig::default());
        assert_eq!(report.atoms_refined, 0);
        assert_eq!(report.atoms_skipped, 0);
    }

    /// Refined vectors must not grow beyond the target index count.
    #[test]
    fn test_refine_preserves_vector_sparsity() {
        let (atom_mem, triple_store) = fresh_stores();
        let cities = ["a", "b", "c", "d", "e"];
        let countries = ["x", "y", "z", "w", "v"];
        for city in &cities {
            atom_mem.get_or_insert(city);
        }
        for country in &countries {
            atom_mem.get_or_insert(country);
        }
        atom_mem.get_or_insert("continent");
        for (city, country) in cities.iter().copied().zip(countries.iter().copied()) {
            triple_store.add(city, "capital_of", country, &format!("c_{}", city));
            triple_store.add(city, "in", "continent", &format!("l_{}", city));
        }

        let config = config_with_alpha(0.3);
        let report = DistributionalRefiner::refine(&atom_mem, &triple_store, &config);

        let target = DIM / 256;
        for city in &cities {
            let vec = atom_mem.get(city).unwrap();
            assert!(
                vec.indices().len() <= target + 1,
                "Refined vector for {} has {} indices, expected <= {}",
                city,
                vec.indices().len(),
                target + 1
            );
        }
        assert!(report.atoms_refined >= 3);
    }

    /// Repeated passes keep pulling co-occurring atoms together.
    #[test]
    fn test_multiple_refinement_rounds() {
        let (atom_mem, triple_store) = fresh_stores();
        for name in ["a", "b", "c", "x", "y"] {
            atom_mem.get_or_insert(name);
        }
        for (subject, relation, object, ctx) in [
            ("a", "r1", "x", "c1"),
            ("b", "r1", "y", "c2"),
            ("c", "r1", "x", "c3"),
            ("a", "r2", "y", "c4"),
            ("b", "r2", "x", "c5"),
            ("c", "r2", "y", "c6"),
        ] {
            triple_store.add(subject, relation, object, ctx);
        }
        let config = config_with_alpha(0.2);

        let a_b_similarity = || {
            atom_mem
                .get("a")
                .unwrap()
                .similarity(&atom_mem.get("b").unwrap())
        };

        let sim_before = a_b_similarity();
        for _round in 0..3 {
            DistributionalRefiner::refine(&atom_mem, &triple_store, &config);
        }
        let sim_after = a_b_similarity();
        assert!(
            sim_after > sim_before,
            "Multiple rounds should increase similarity: before={:.4}, after={:.4}",
            sim_before,
            sim_after
        );
    }
}