use super::{AttentionCoherenceConfig, AttentionError, Result};
use super::config::AttentionMode;
/// Attention score assigned to a single node.
///
/// A plain value type: equality compares all three fields (note that `f32`
/// equality follows IEEE rules, so a NaN score never compares equal).
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionScore {
    /// Index of the node this score refers to.
    pub node_idx: usize,
    /// Attention score for the node.
    pub score: f32,
    /// Coherence-related contribution of this node (semantics are defined by
    /// whoever produces the score; not set anywhere in this module).
    pub coherence_contribution: f32,
}
/// Snapshot of a topology gate's state at one point in time.
#[derive(Debug, Clone)]
pub struct TopologyGateResult {
    /// Coherence score the snapshot was taken at.
    pub coherence: f32,
    /// Attention mode associated with `coherence`.
    pub mode: AttentionMode,
    /// Attention width associated with `coherence`.
    pub width: usize,
    /// Whether `mode` permits updates.
    pub allows_updates: bool,
    /// Number of ticks elapsed since coherence was last recomputed.
    pub ticks_since_update: usize,
}
impl TopologyGateResult {
    /// Builds a fully coherent snapshot.
    ///
    /// The snapshot has coherence `1.0` and [`AttentionMode::Stable`], uses the
    /// config's `base_width`, allows updates, and has a zeroed tick counter.
    pub fn stable(config: &AttentionCoherenceConfig) -> Self {
        let width = config.base_width;
        Self {
            mode: AttentionMode::Stable,
            coherence: 1.0,
            width,
            ticks_since_update: 0,
            allows_updates: true,
        }
    }
}
/// Tracks how coherent recently observed attention keys are.
///
/// Maps that coherence onto an [`AttentionMode`] and an attention width.
/// Coherence is refreshed by `update_coherence`. Between refreshes, callers
/// advance `tick` and poll `needs_update`.
#[derive(Debug)]
pub struct TopologyGate {
    /// Thresholds, widths, neighbour count and update period.
    config: AttentionCoherenceConfig,
    /// Most recent coherence score; a fresh gate starts at `1.0`.
    coherence: f32,
    /// Mode derived from `coherence`.
    mode: AttentionMode,
    /// Ticks since the last `update_coherence` that actually ran.
    ticks_since_update: usize,
    /// Metrics from the last update; `None` until the first one.
    cached_metrics: Option<CoherenceMetrics>,
}
impl TopologyGate {
    /// Creates a gate that starts fully coherent (`1.0`) in [`AttentionMode::Stable`].
    ///
    /// No metrics are cached yet, so [`TopologyGate::needs_update`] reports `true`
    /// until the first [`TopologyGate::update_coherence`] call.
    pub fn new(config: AttentionCoherenceConfig) -> Self {
        Self {
            coherence: 1.0,
            mode: AttentionMode::Stable,
            ticks_since_update: 0,
            cached_metrics: None,
            config,
        }
    }

    /// Recomputes the coherence score from `keys` and re-derives the mode.
    ///
    /// Resets the tick counter and caches the new metrics. An empty `keys` slice
    /// is ignored: coherence, mode, tick counter and cached metrics are left as-is.
    pub fn update_coherence(&mut self, keys: &[&[f32]]) {
        if keys.is_empty() {
            return;
        }
        let metrics = self.compute_coherence_metrics(keys);
        self.coherence = metrics.coherence_score;
        self.mode = AttentionMode::from_coherence(self.coherence, &self.config);
        self.ticks_since_update = 0;
        self.cached_metrics = Some(metrics);
    }

    /// Records that one more step has passed since the last coherence update.
    pub fn tick(&mut self) {
        self.ticks_since_update += 1;
    }

    /// Reports whether coherence should be recomputed.
    ///
    /// Returns `true` once `coherence_update_period` ticks have elapsed since the
    /// last update, or if coherence has never been computed.
    pub fn needs_update(&self) -> bool {
        self.ticks_since_update >= self.config.coherence_update_period
            || self.cached_metrics.is_none()
    }

    /// Attention mode derived from the most recent coherence score.
    pub fn current_mode(&self) -> AttentionMode {
        self.mode
    }

    /// Most recent coherence score.
    pub fn current_coherence(&self) -> f32 {
        self.coherence
    }

    /// Whether the current mode permits updates (delegates to the mode).
    pub fn allows_updates(&self) -> bool {
        self.mode.allows_updates()
    }

    /// Attention width the config assigns to the current coherence score.
    pub fn attention_width(&self) -> usize {
        self.config.width_for_coherence(self.coherence)
    }

    /// Snapshot of the gate's current state.
    pub fn current_result(&self) -> TopologyGateResult {
        TopologyGateResult {
            coherence: self.coherence,
            mode: self.mode,
            width: self.attention_width(),
            allows_updates: self.allows_updates(),
            ticks_since_update: self.ticks_since_update,
        }
    }

    /// Scores how coherent `keys` are as a k-nearest-neighbour similarity graph.
    ///
    /// The score is `0.5 * mean + 0.3 * (1 - std) + 0.2 * (1 - boundary_ratio)`,
    /// clamped to `[0, 1]`. Here:
    /// - `mean` and `std` are taken over the cosine similarities of all distinct
    ///   key pairs.
    /// - `boundary_ratio` is the average non-negative similarity from each key to
    ///   the keys outside its `k` most similar neighbours, with
    ///   `k = min(config.k_neighbors, n - 1)`.
    ///
    /// When `k == 0` the keys are reported as perfectly coherent.
    fn compute_coherence_metrics(&self, keys: &[&[f32]]) -> CoherenceMetrics {
        let n = keys.len();
        if n == 0 {
            return CoherenceMetrics::empty();
        }
        // A key can have at most `n - 1` neighbours (everyone but itself).
        let k = self.config.k_neighbors.min(n - 1);
        if k == 0 {
            // Single key, or a k-NN graph with no edges: nothing can disagree.
            // Keep the real node count rather than the helper's default.
            return CoherenceMetrics {
                num_nodes: n,
                ..CoherenceMetrics::with_score(1.0)
            };
        }

        let similarities = self.similarity_matrix(keys);
        let (total_boundary_mass, total_edges) = Self::boundary_mass(&similarities, k);

        // Population mean/variance over the strict upper triangle (distinct pairs).
        let pair_count = (n * (n - 1) / 2).max(1) as f32;
        let mean_sim = Self::upper_triangle(&similarities).sum::<f32>() / pair_count;
        let variance = Self::upper_triangle(&similarities)
            .map(|s| (s - mean_sim).powi(2))
            .sum::<f32>()
            / pair_count;

        let boundary_ratio = if total_edges > 0 {
            total_boundary_mass / total_edges as f32
        } else {
            0.0
        };
        // High mean similarity, low spread, and little positive similarity
        // leaking outside each key's neighbourhood all count as "coherent".
        let coherence_score = (mean_sim * 0.5
            + (1.0 - variance.sqrt()) * 0.3
            + (1.0 - boundary_ratio) * 0.2)
            .clamp(0.0, 1.0);

        CoherenceMetrics {
            coherence_score,
            mean_similarity: mean_sim,
            similarity_variance: variance,
            boundary_mass: total_boundary_mass,
            num_nodes: n,
        }
    }

    /// Builds the symmetric `n x n` cosine-similarity matrix with `1.0` on the diagonal.
    ///
    /// Cosine similarity is symmetric, so each pair is computed once and mirrored.
    fn similarity_matrix(&self, keys: &[&[f32]]) -> Vec<Vec<f32>> {
        let n = keys.len();
        let mut sims = vec![vec![1.0f32; n]; n];
        for (i, a) in keys.iter().enumerate() {
            for (j, b) in keys.iter().enumerate().skip(i + 1) {
                let s = self.cosine_similarity(a, b);
                sims[i][j] = s;
                sims[j][i] = s;
            }
        }
        sims
    }

    /// Sums non-negative similarity from each node to every node outside its `k`
    /// most similar neighbours. Returns `(mass, edge_count)`.
    fn boundary_mass(similarities: &[Vec<f32>], k: usize) -> (f32, usize) {
        let n = similarities.len();
        let mut mass = 0.0f32;
        let mut edges = 0usize;
        // Reused per row; avoids an O(k) `contains` scan for every (i, j) pair.
        let mut is_neighbor = vec![false; n];
        for (i, row) in similarities.iter().enumerate() {
            let mut candidates: Vec<(usize, f32)> = row
                .iter()
                .copied()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .collect();
            // Stable descending sort: ties keep index order, so the selected
            // neighbour set is deterministic.
            candidates
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            is_neighbor.fill(false);
            for &(j, _) in candidates.iter().take(k) {
                is_neighbor[j] = true;
            }
            for (j, &s) in row.iter().enumerate() {
                if j != i && !is_neighbor[j] {
                    mass += s.max(0.0);
                    edges += 1;
                }
            }
        }
        (mass, edges)
    }

    /// Iterates the strict upper triangle (`j > i`) of a square matrix, row-major.
    fn upper_triangle(similarities: &[Vec<f32>]) -> impl Iterator<Item = f32> + '_ {
        similarities
            .iter()
            .enumerate()
            .flat_map(|(i, row)| row[i + 1..].iter().copied())
    }

    /// Computes the cosine similarity of `a` and `b`, clamped to `[-1, 1]`.
    ///
    /// Returns `0.0` when either vector has a (near-)zero norm, or when the result
    /// is not finite (e.g. NaN or infinite components). Without that guard, NaN
    /// would propagate into the coherence score. If the lengths differ, the dot
    /// product covers only the common prefix while each norm covers its full vector.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a < 1e-10 || norm_b < 1e-10 {
            return 0.0;
        }
        let cos = dot / (norm_a * norm_b);
        if cos.is_finite() {
            cos.clamp(-1.0, 1.0)
        } else {
            0.0
        }
    }
}
/// Summary statistics produced by one coherence computation.
#[derive(Debug, Clone)]
struct CoherenceMetrics {
    /// Blended coherence score that drives the gate's mode.
    coherence_score: f32,
    /// Mean pairwise cosine similarity over distinct key pairs.
    mean_similarity: f32,
    /// Population variance of those pairwise similarities.
    similarity_variance: f32,
    /// Summed non-negative similarity to keys outside each key's neighbourhood.
    boundary_mass: f32,
    /// Number of keys the metrics describe.
    num_nodes: usize,
}
impl CoherenceMetrics {
    /// Metrics for an empty key set: perfect coherence and mean similarity,
    /// no variance, no boundary mass, and zero nodes.
    fn empty() -> Self {
        Self {
            num_nodes: 0,
            boundary_mass: 0.0,
            similarity_variance: 0.0,
            mean_similarity: 1.0,
            coherence_score: 1.0,
        }
    }

    /// Metrics carrying only a fixed `score`.
    ///
    /// `score` is also used as the mean similarity. Variance and boundary mass
    /// are zero, and the node count is one.
    fn with_score(score: f32) -> Self {
        Self {
            coherence_score: score,
            mean_similarity: score,
            num_nodes: 1,
            ..Self::empty()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned key vector as a slice, the shape `update_coherence` takes.
    fn as_slices(keys: &[Vec<f32>]) -> Vec<&[f32]> {
        keys.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_topology_gate_creation() {
        let gate = TopologyGate::new(AttentionCoherenceConfig::default());
        assert_eq!(gate.current_mode(), AttentionMode::Stable);
        assert!(gate.allows_updates());
    }

    #[test]
    fn test_update_coherence_similar_keys() {
        let mut gate = TopologyGate::new(AttentionCoherenceConfig::default());
        // Ten identical keys: maximally coherent.
        let keys = vec![vec![1.0f32, 0.0, 0.0, 0.0]; 10];
        gate.update_coherence(&as_slices(&keys));
        assert!(gate.current_coherence() > 0.5);
        assert_eq!(gate.current_mode(), AttentionMode::Stable);
    }

    #[test]
    fn test_update_coherence_diverse_keys() {
        let config = AttentionCoherenceConfig {
            stable_threshold: 0.9,
            freeze_threshold: 0.5,
            ..Default::default()
        };
        let mut gate = TopologyGate::new(config);
        // Ten mutually orthogonal one-hot keys.
        let keys: Vec<Vec<f32>> = (0..10)
            .map(|i| {
                let mut one_hot = vec![0.0f32; 16];
                one_hot[i % 16] = 1.0;
                one_hot
            })
            .collect();
        gate.update_coherence(&as_slices(&keys));
        assert!(matches!(
            gate.current_mode(),
            AttentionMode::Cautious | AttentionMode::Freeze
        ));
    }

    #[test]
    fn test_tick_and_update_period() {
        let config = AttentionCoherenceConfig {
            coherence_update_period: 4,
            ..Default::default()
        };
        let mut gate = TopologyGate::new(config);
        assert!(gate.needs_update());

        let keys = vec![vec![1.0f32; 8]; 5];
        gate.update_coherence(&as_slices(&keys));
        assert!(!gate.needs_update());

        (0..4).for_each(|_| gate.tick());
        assert!(gate.needs_update());
    }

    #[test]
    fn test_attention_width() {
        let config = AttentionCoherenceConfig {
            base_width: 64,
            stable_threshold: 0.7,
            freeze_threshold: 0.3,
            ..Default::default()
        };
        let mut gate = TopologyGate::new(config);
        // (coherence, expected width): stable, cautious, frozen.
        for (coherence, expected) in [(0.8, 64), (0.5, 32), (0.2, 1)] {
            gate.coherence = coherence;
            gate.mode = AttentionMode::from_coherence(coherence, &gate.config);
            assert_eq!(gate.attention_width(), expected, "coherence = {}", coherence);
        }
    }
}