mod adapter;
mod config;
mod diffusion;
mod moe;
mod topology;
pub use adapter::AttentionAdapter;
pub use config::AttentionCoherenceConfig;
pub use diffusion::{DiffusionSmoothing, SmoothedEnergy};
pub use moe::{ExpertRouting, MoEResidualProcessor};
pub use topology::{AttentionScore, TopologyGate, TopologyGateResult};
use std::collections::HashMap;
/// Identifier of a node, stored as a 64-bit unsigned integer.
pub type NodeId = u64;
/// Identifier of an edge, expressed as a pair of node identifiers.
///
/// NOTE(review): presumably ordered as (source, target) — confirm against callers.
pub type EdgeId = (NodeId, NodeId);
/// Result alias used by this module, with [`AttentionError`] as the error type.
pub type Result<T> = std::result::Result<T, AttentionError>;
/// Error type carried by this module's [`Result`] alias.
///
/// Each variant's `Display` text is generated from its `#[error(...)]` attribute.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AttentionError {
/// A size did not match what was required; carries the expected and the actual value.
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
/// An input was empty. The payload is appended to the message; within this
/// file it is the name of the offending argument (e.g. `"node_states"`).
#[error("Empty input: {0}")]
EmptyInput(String),
/// The configuration was rejected; the payload is appended to the message.
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
/// A computation failed; the payload is appended to the message.
#[error("Computation failed: {0}")]
ComputationFailed(String),
/// A mode is not supported in the current state; the payload is appended to the message.
#[error("Mode not supported in current state: {0}")]
ModeNotSupported(String),
}
/// Bundles the attention adapter, topology gate, MoE residual processor and
/// diffusion smoother behind one entry point.
///
/// All four components are constructed from clones of the same configuration
/// (see `AttentionCoherence::new`).
#[derive(Debug)]
pub struct AttentionCoherence {
// Configuration retained for `config()` and for the diffusion settings read in `full_analysis`.
config: AttentionCoherenceConfig,
// Produces the per-node attention scores (`compute_scores`).
adapter: AttentionAdapter,
// Updated with the node states on every score computation; queried for gate status.
topo_gate: TopologyGate,
// Backend for `moe_process_residual`.
moe: MoEResidualProcessor,
// Backend for `smooth_energy`.
diffusion: DiffusionSmoothing,
}
impl AttentionCoherence {
    /// Builds the pipeline from `config`.
    ///
    /// Each sub-component (adapter, topology gate, MoE processor, diffusion
    /// smoother) receives its own clone of the configuration; the original is
    /// kept and exposed through [`AttentionCoherence::config`].
    pub fn new(config: AttentionCoherenceConfig) -> Self {
        let adapter = AttentionAdapter::new(config.clone());
        let topo_gate = TopologyGate::new(config.clone());
        let moe = MoEResidualProcessor::new(config.clone());
        let diffusion = DiffusionSmoothing::new(config.clone());
        Self {
            config,
            adapter,
            topo_gate,
            moe,
            diffusion,
        }
    }

    /// Builds the pipeline from `AttentionCoherenceConfig::default()`.
    pub fn default_config() -> Self {
        Self::new(AttentionCoherenceConfig::default())
    }

    /// Computes one attention score per node state.
    ///
    /// The topology gate is updated with `node_states` first, then the adapter
    /// computes the scores. The returned map is keyed by the position at which
    /// the adapter yielded each score.
    ///
    /// # Errors
    ///
    /// Returns [`AttentionError::EmptyInput`] when `node_states` is empty (the
    /// gate is not touched in that case), and propagates any error from the
    /// adapter. Note that the gate has already been updated when the adapter
    /// fails.
    pub fn compute_attention_scores(
        &mut self,
        node_states: &[&[f32]],
    ) -> Result<HashMap<usize, f32>> {
        if node_states.is_empty() {
            return Err(AttentionError::EmptyInput("node_states".to_string()));
        }
        self.topo_gate.update_coherence(node_states);
        let scores = self.adapter.compute_scores(node_states)?;
        Ok(scores.into_iter().enumerate().collect())
    }

    /// Weights every edge residual by the attention of its two endpoints.
    ///
    /// For each `(source, target, residual)` entry the attention weight is the
    /// mean of the two endpoint scores and the weighted energy is
    /// `‖residual‖² * attention_weight`. An endpoint index that has no score
    /// falls back to an attention of `1.0`.
    ///
    /// Calling this method recomputes the attention scores and therefore also
    /// updates the topology gate.
    ///
    /// # Errors
    ///
    /// Same as [`AttentionCoherence::compute_attention_scores`]; in particular
    /// an empty `node_states` yields [`AttentionError::EmptyInput`].
    pub fn weighted_residuals(
        &mut self,
        node_states: &[&[f32]],
        edge_residuals: &[(usize, usize, Vec<f32>)],
    ) -> Result<Vec<WeightedEdgeResidual>> {
        // Empty `node_states` is rejected by `compute_attention_scores` before
        // any state is modified, so no separate guard is needed here.
        let scores = self.compute_attention_scores(node_states)?;
        let score_of = |idx: &usize| scores.get(idx).copied().unwrap_or(1.0);

        Ok(edge_residuals
            .iter()
            .map(|(source, target, residual)| {
                let source_attention = score_of(source);
                let target_attention = score_of(target);
                let attention_weight = (source_attention + target_attention) / 2.0;
                let residual_norm_sq: f32 = residual.iter().map(|x| x * x).sum();
                WeightedEdgeResidual {
                    source_idx: *source,
                    target_idx: *target,
                    source_attention,
                    target_attention,
                    attention_weight,
                    residual_norm_sq,
                    weighted_energy: residual_norm_sq * attention_weight,
                }
            })
            .collect())
    }

    /// Runs `residual` through the MoE processor with the given `context`.
    ///
    /// Thin delegation to `MoEResidualProcessor::process`; errors are passed
    /// through unchanged.
    pub fn moe_process_residual(
        &self,
        residual: &[f32],
        context: &[f32],
    ) -> Result<MoEProcessedResidual> {
        self.moe.process(residual, context)
    }

    /// Smooths per-edge energies with the diffusion component for `steps` steps.
    ///
    /// Thin delegation to `DiffusionSmoothing::smooth`; errors are passed
    /// through unchanged.
    pub fn smooth_energy(
        &self,
        edge_energies: &[(usize, usize, f32)],
        node_states: &[&[f32]],
        steps: usize,
    ) -> Result<SmoothedEnergy> {
        self.diffusion.smooth(edge_energies, node_states, steps)
    }

    /// Returns the topology gate's current result.
    pub fn gate_result(&self) -> TopologyGateResult {
        self.topo_gate.current_result()
    }

    /// Reports whether the topology gate currently allows updates.
    pub fn allows_updates(&self) -> bool {
        self.topo_gate.allows_updates()
    }

    /// Returns the attention width currently reported by the topology gate.
    pub fn attention_width(&self) -> usize {
        self.topo_gate.attention_width()
    }

    /// Returns the configuration this instance was built with.
    pub fn config(&self) -> &AttentionCoherenceConfig {
        &self.config
    }

    /// Runs the whole pipeline: weighted residuals, optional diffusion
    /// smoothing, and aggregate statistics.
    ///
    /// Diffusion is applied only when `config.enable_diffusion` is set, using
    /// `config.diffusion_steps` steps. `avg_attention_weight` is `0.0` when
    /// there are no edges.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`AttentionCoherence::weighted_residuals`] and
    /// [`AttentionCoherence::smooth_energy`].
    pub fn full_analysis(
        &mut self,
        node_states: &[&[f32]],
        edge_residuals: &[(usize, usize, Vec<f32>)],
    ) -> Result<AttentionEnergyAnalysis> {
        let weighted = self.weighted_residuals(node_states, edge_residuals)?;

        // `weighted_residuals` updates the topology gate with `node_states`, so
        // the snapshot is taken afterwards: that way it describes the same
        // states as the rest of the analysis instead of the previous call's.
        let gate_result = self.topo_gate.current_result();

        let edge_energies: Vec<(usize, usize, f32)> = weighted
            .iter()
            .map(|w| (w.source_idx, w.target_idx, w.weighted_energy))
            .collect();
        let smoothed = if self.config.enable_diffusion {
            Some(self.smooth_energy(&edge_energies, node_states, self.config.diffusion_steps)?)
        } else {
            None
        };

        let total_energy: f32 = weighted.iter().map(|w| w.weighted_energy).sum();
        // `max(1)` keeps the average at 0.0 (rather than NaN) for an empty edge list.
        let avg_attention: f32 = weighted.iter().map(|w| w.attention_weight).sum::<f32>()
            / weighted.len().max(1) as f32;

        Ok(AttentionEnergyAnalysis {
            weighted_residuals: weighted,
            smoothed_energy: smoothed,
            total_energy,
            avg_attention_weight: avg_attention,
            gate_result,
            num_edges: edge_residuals.len(),
        })
    }
}
/// Attention-weighted view of a single edge residual, as produced by
/// `AttentionCoherence::weighted_residuals`.
#[derive(Debug, Clone)]
pub struct WeightedEdgeResidual {
/// Index of the edge's source node, copied from the input tuple.
pub source_idx: usize,
/// Index of the edge's target node, copied from the input tuple.
pub target_idx: usize,
/// Attention score looked up for the source node (`1.0` if none was found).
pub source_attention: f32,
/// Attention score looked up for the target node (`1.0` if none was found).
pub target_attention: f32,
/// Mean of `source_attention` and `target_attention`.
pub attention_weight: f32,
/// Sum of the squared components of the edge residual.
pub residual_norm_sq: f32,
/// `residual_norm_sq * attention_weight`.
pub weighted_energy: f32,
}
/// Return type of `AttentionCoherence::moe_process_residual`, which forwards
/// the value produced by `MoEResidualProcessor::process`.
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names only — confirm against the `moe` module.
#[derive(Debug, Clone)]
pub struct MoEProcessedResidual {
/// Presumably the processed residual vector.
pub output: Vec<f32>,
/// Presumably the indices of the experts that were selected.
pub expert_indices: Vec<usize>,
/// Presumably the weights paired with `expert_indices`.
pub expert_weights: Vec<f32>,
/// Presumably an auxiliary load-balancing loss term.
pub load_balance_loss: f32,
}
/// Aggregate output of `AttentionCoherence::full_analysis`.
#[derive(Debug, Clone)]
pub struct AttentionEnergyAnalysis {
/// One entry per input edge residual, in input order.
pub weighted_residuals: Vec<WeightedEdgeResidual>,
/// Diffusion-smoothed energies; `Some` only when `enable_diffusion` is set in the configuration.
pub smoothed_energy: Option<SmoothedEnergy>,
/// Sum of `weighted_energy` over all edges.
pub total_energy: f32,
/// Mean `attention_weight` over all edges; `0.0` when there are no edges.
pub avg_attention_weight: f32,
/// Snapshot obtained from the topology gate's `current_result()` during the analysis.
pub gate_result: TopologyGateResult,
/// Number of edge residuals that were passed to the analysis.
pub num_edges: usize,
}
impl AttentionEnergyAnalysis {
    /// Returns `true` when `total_energy` is strictly below `threshold`.
    ///
    /// A NaN total energy (or NaN threshold) compares as not coherent.
    pub fn is_coherent(&self, threshold: f32) -> bool {
        self.total_energy < threshold
    }

    /// Returns the edge with the largest `weighted_energy`, or `None` when
    /// there are no edges. Among equal maxima the last one is returned.
    ///
    /// Energies are compared with [`f32::total_cmp`], so a NaN energy cannot
    /// cause a panic; under that total order a positive NaN sorts above every
    /// number and a negative NaN below.
    pub fn highest_energy_edge(&self) -> Option<&WeightedEdgeResidual> {
        self.weighted_residuals
            .iter()
            .max_by(|a, b| a.weighted_energy.total_cmp(&b.weighted_energy))
    }

    /// Returns every edge whose `weighted_energy` is strictly greater than
    /// `threshold`, in their original order. Edges with NaN energy are never
    /// included.
    pub fn edges_above_threshold(&self, threshold: f32) -> Vec<&WeightedEdgeResidual> {
        self.weighted_residuals
            .iter()
            .filter(|r| r.weighted_energy > threshold)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` constant vectors of length `dim`; vector `i` is filled with `0.1 * (i + 1)`.
    fn make_states(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut states = Vec::with_capacity(n);
        for i in 0..n {
            states.push(vec![0.1 * (i + 1) as f32; dim]);
        }
        states
    }

    /// Borrows every owned state vector as a slice.
    fn as_refs(states: &[Vec<f32>]) -> Vec<&[f32]> {
        states.iter().map(Vec::as_slice).collect()
    }

    /// Every node gets a score, and each score lies in `[0, 1]`.
    #[test]
    fn test_basic_coherence() {
        let mut coherence = AttentionCoherence::new(AttentionCoherenceConfig {
            dimension: 16,
            ..Default::default()
        });
        let states = make_states(5, 16);
        let state_refs = as_refs(&states);

        let scores = coherence.compute_attention_scores(&state_refs).unwrap();

        assert_eq!(scores.len(), 5);
        for score in scores.values() {
            assert!((0.0f32..=1.0).contains(score));
        }
    }

    /// One weighted residual per edge, with non-negative energy and positive weight.
    #[test]
    fn test_weighted_residuals() {
        let mut coherence = AttentionCoherence::new(AttentionCoherenceConfig {
            dimension: 8,
            ..Default::default()
        });
        let states = make_states(4, 8);
        let state_refs = as_refs(&states);
        let residuals = vec![
            (0, 1, vec![0.1f32; 8]),
            (1, 2, vec![0.2f32; 8]),
            (2, 3, vec![0.3f32; 8]),
        ];

        let weighted = coherence
            .weighted_residuals(&state_refs, &residuals)
            .unwrap();

        assert_eq!(weighted.len(), 3);
        for entry in weighted.iter() {
            assert!(entry.weighted_energy >= 0.0);
            assert!(entry.attention_weight > 0.0);
        }
    }

    /// With diffusion disabled, the full analysis still reports edge count and aggregates.
    #[test]
    fn test_full_analysis() {
        let mut coherence = AttentionCoherence::new(AttentionCoherenceConfig {
            dimension: 8,
            enable_diffusion: false,
            ..Default::default()
        });
        let states = make_states(3, 8);
        let state_refs = as_refs(&states);
        let residuals = vec![(0, 1, vec![0.1f32; 8]), (1, 2, vec![0.2f32; 8])];

        let analysis = coherence.full_analysis(&state_refs, &residuals).unwrap();

        assert_eq!(analysis.num_edges, 2);
        assert!(analysis.total_energy >= 0.0);
        assert!(analysis.avg_attention_weight > 0.0);
    }
}