#[cfg(feature = "attention")]
use ruvector_mincut_gated_transformer::{
GatePacket, MincutDepthRouter, ModRoutingConfig, RoutingStats, TokenRoute,
CoherenceEarlyExit, EarlyExitConfig, EarlyExitDecision, ExitReason,
};
use crate::tile::{GateDecision, TileReport};
/// Tuning parameters for coherence-driven attention routing.
///
/// With the `attention` feature enabled these values are translated into the
/// router configuration when a `CoherenceAttention` is built; the early-exit
/// fields are consulted by its early-exit check.
#[derive(Clone, Debug)]
pub struct AttentionConfig {
    /// Share of work to skip; the router receives `1.0 - flops_reduction`
    /// as its layer capacity ratio.
    pub flops_reduction: f32,
    /// Minimum number of entries processed per round (forwarded as the
    /// router's minimum tokens per layer).
    pub min_entries_per_round: u16,
    /// Lambda-delta threshold forwarded to the router's skip logic.
    // NOTE(review): 3276 looks like ~10% on a Q15 scale — confirm the unit.
    pub lambda_delta_skip_threshold: i32,
    /// Forwarded to the router's `adaptive_capacity` switch.
    pub adaptive_capacity: bool,
    /// When `false`, the early-exit check never asks for an exit.
    pub enable_early_exit: bool,
    /// Stability value that must be exceeded before an early exit is allowed.
    pub early_exit_threshold: f32,
}

impl Default for AttentionConfig {
    /// Balanced preset: half of the work skipped, early exit enabled.
    fn default() -> Self {
        AttentionConfig {
            flops_reduction: 0.5,
            min_entries_per_round: 4,
            lambda_delta_skip_threshold: 3276,
            adaptive_capacity: true,
            enable_early_exit: true,
            early_exit_threshold: 0.95,
        }
    }
}

impl AttentionConfig {
    /// Preset that skips more work and exits earlier than the default.
    pub fn realtime() -> Self {
        // Adaptive capacity and early exit stay enabled, exactly as in the
        // default preset, so only the numeric knobs are overridden here.
        AttentionConfig {
            flops_reduction: 0.6,
            min_entries_per_round: 2,
            lambda_delta_skip_threshold: 2000,
            early_exit_threshold: 0.9,
            ..Self::default()
        }
    }

    /// Preset that skips less work and disables both adaptive capacity and
    /// early exit.
    pub fn accurate() -> Self {
        AttentionConfig {
            flops_reduction: 0.3,
            min_entries_per_round: 8,
            lambda_delta_skip_threshold: 5000,
            adaptive_capacity: false,
            enable_early_exit: false,
            early_exit_threshold: 0.99,
        }
    }
}
/// Converts per-tile reports into gate packets, carrying a small amount of
/// state from one conversion to the next.
#[derive(Clone, Copy, Debug, Default)]
pub struct GatePacketBridge {
    /// Lambda written into the previous packet; reported as `lambda_prev`
    /// in the next one.
    prev_lambda: u32,
    /// Boundary count blended across calls (mean of the previous blended
    /// value and the newest raw count).
    smoothed_boundary: u16,
}

impl GatePacketBridge {
    /// Creates a bridge with no history (both state fields are zero).
    pub fn new() -> Self {
        Self::default()
    }

    /// Aggregates `reports` into a single `GatePacket` and updates the
    /// bridge's history.
    ///
    /// * `lambda` is the smallest strictly positive `local_cut`, clamped to
    ///   `[0, 1000]` and scaled by `32.767` (so `1000.0` maps to `32767`).
    /// * `boundary_edges` is the mean of the previously smoothed count and
    ///   the number of non-zero boundary candidates across all reports.
    /// * `boundary_concentration_q15` is the largest `shift_score * 32767`,
    ///   capped at `32767`.
    ///
    /// An empty slice yields `GatePacket::default()` and leaves the bridge
    /// state untouched. Narrowing conversions saturate instead of wrapping.
    #[cfg(feature = "attention")]
    pub fn to_gate_packet(&mut self, reports: &[TileReport]) -> GatePacket {
        if reports.is_empty() {
            return GatePacket::default();
        }

        let mut min_cut = f64::MAX;
        let mut total_boundary = 0u32;
        let mut max_boundary_concentration = 0u32;

        for report in reports {
            // Zero (or negative) cuts are ignored when searching for the minimum.
            if report.local_cut > 0.0 && report.local_cut < min_cut {
                min_cut = report.local_cut;
            }

            let nonzero_candidates = report
                .boundary_candidates
                .iter()
                .filter(|&&c| c != 0)
                .count() as u32;
            total_boundary = total_boundary.saturating_add(nonzero_candidates);

            // Float-to-int `as` saturates, so negative/NaN scores become 0.
            let concentration = (report.shift_score * 32767.0) as u32;
            max_boundary_concentration = max_boundary_concentration.max(concentration);
        }

        // NOTE(review): when no report has a positive cut, `min_cut` is still
        // `f64::MAX` and clamps to 1000.0, i.e. lambda = 32767 — confirm that
        // "no data" is meant to read as the maximum value.
        let lambda = (min_cut.clamp(0.0, 1000.0) * 32.767) as u32;

        let blended = total_boundary.saturating_add(u32::from(self.smoothed_boundary)) / 2;
        let boundary_edges = blended.min(u32::from(u16::MAX)) as u16;
        self.smoothed_boundary = boundary_edges;

        let partition_count = reports.len().min(usize::from(u16::MAX)) as u16;

        let packet = GatePacket {
            lambda,
            lambda_prev: self.prev_lambda,
            boundary_edges,
            boundary_concentration_q15: max_boundary_concentration.min(32767) as u16,
            partition_count,
            flags: 0,
        };
        self.prev_lambda = lambda;
        packet
    }

    /// Reverses the scaling applied by `to_gate_packet`.
    ///
    /// Returns `(min_cut, shift_score, partition_count)` where
    /// `min_cut = lambda / 32.767` and
    /// `shift_score = boundary_concentration_q15 / 32767`.
    #[cfg(feature = "attention")]
    pub fn from_gate_packet(packet: &GatePacket) -> (f64, f64, usize) {
        let min_cut = packet.lambda as f64 / 32.767;
        let shift_score = packet.boundary_concentration_q15 as f64 / 32767.0;
        let partition_count = packet.partition_count as usize;
        (min_cut, shift_score, partition_count)
    }
}
/// Attention optimizer that routes tile reports through the mincut depth
/// router and accumulates routing statistics.
#[cfg(feature = "attention")]
pub struct CoherenceAttention {
    config: AttentionConfig,
    router: MincutDepthRouter,
    bridge: GatePacketBridge,
    stats: AttentionStats,
}

#[cfg(feature = "attention")]
impl CoherenceAttention {
    /// Builds the optimizer, translating `config` into the router's
    /// configuration.
    ///
    /// If the router rejects the derived configuration, a default router is
    /// used instead (the error is not reported to the caller).
    pub fn new(config: AttentionConfig) -> Self {
        let mod_config = ModRoutingConfig {
            lambda_delta_skip_threshold: config.lambda_delta_skip_threshold,
            boundary_token_force_compute: true,
            layer_capacity_ratio: 1.0 - config.flops_reduction,
            min_tokens_per_layer: config.min_entries_per_round,
            adaptive_capacity: config.adaptive_capacity,
        };
        Self {
            config,
            router: MincutDepthRouter::new(mod_config).unwrap_or_default(),
            bridge: GatePacketBridge::new(),
            stats: AttentionStats::default(),
        }
    }

    /// Converts `reports` into a gate packet, routes one position per report
    /// and folds the router's statistics into the running totals.
    ///
    /// Returns the gate packet together with the route chosen for each
    /// position. At most `u16::MAX` positions are generated; the count
    /// saturates rather than wrapping for larger inputs.
    pub fn optimize(&mut self, reports: &[TileReport]) -> (GatePacket, Vec<TokenRoute>) {
        let gate = self.bridge.to_gate_packet(reports);

        let position_count = reports.len().min(usize::from(u16::MAX)) as u16;
        let positions: Vec<u16> = (0..position_count).collect();
        let routes = self.router.route_tokens(&gate, &positions);

        let routing_stats = self.router.routing_stats(&routes);
        self.stats.total_entries += routing_stats.total_tokens;
        self.stats.computed_entries += routing_stats.compute_tokens;
        self.stats.skipped_entries += routing_stats.skip_tokens;
        self.stats.boundary_entries += routing_stats.boundary_tokens;
        self.stats.decisions += 1;

        (gate, routes)
    }

    /// Decides whether processing may stop at `current_layer`.
    ///
    /// Stability is `1 - min(|lambda_delta| / 32768, 1)`. An exit is
    /// requested only when early exit is enabled, stability exceeds
    /// `early_exit_threshold`, and more than half of `max_layers` has been
    /// processed. The returned `confidence` is the stability value (or `0.0`
    /// when early exit is disabled).
    pub fn check_early_exit(
        &self,
        gate: &GatePacket,
        current_layer: usize,
        max_layers: usize,
    ) -> EarlyExitDecision {
        if !self.config.enable_early_exit {
            return EarlyExitDecision {
                should_exit: false,
                confidence: 0.0,
                reason: ExitReason::None,
            };
        }

        // `unsigned_abs` cannot overflow, unlike `abs()` on `i32::MIN`.
        let lambda_delta_abs = gate.lambda_delta().unsigned_abs() as f32;
        let stability = 1.0 - (lambda_delta_abs / 32768.0).min(1.0);

        // With no layers there is no progress to speak of; avoid the
        // inf/NaN that a division by zero would produce.
        let progress = if max_layers == 0 {
            0.0
        } else {
            current_layer as f32 / max_layers as f32
        };

        let should_exit = stability > self.config.early_exit_threshold && progress > 0.5;
        EarlyExitDecision {
            should_exit,
            confidence: stability,
            reason: if should_exit {
                ExitReason::HighConfidence
            } else {
                ExitReason::None
            },
        }
    }

    /// Running totals accumulated by `optimize`.
    pub fn stats(&self) -> &AttentionStats {
        &self.stats
    }

    /// Clears all accumulated statistics.
    pub fn reset_stats(&mut self) {
        self.stats = AttentionStats::default();
    }
}
/// Running totals describing how entries were routed.
#[derive(Clone, Copy, Debug, Default)]
pub struct AttentionStats {
    /// Number of entries seen across all decisions.
    pub total_entries: usize,
    /// Entries that were routed to computation.
    pub computed_entries: usize,
    /// Entries that were skipped.
    pub skipped_entries: usize,
    /// Entries classified as boundary entries.
    pub boundary_entries: usize,
    /// Number of routing decisions recorded.
    pub decisions: usize,
}

impl AttentionStats {
    /// Fraction of entries that were skipped, or `0.0` when nothing has been
    /// recorded yet.
    pub fn flops_reduction(&self) -> f32 {
        self.share_of_total(self.skipped_entries)
    }

    /// Fraction of entries that were computed, or `0.0` when nothing has been
    /// recorded yet.
    pub fn compute_ratio(&self) -> f32 {
        self.share_of_total(self.computed_entries)
    }

    /// Divides `part` by `total_entries`, treating an empty total as zero.
    fn share_of_total(&self, part: usize) -> f32 {
        match self.total_entries {
            0 => 0.0,
            total => part as f32 / total as f32,
        }
    }
}
#[cfg(not(feature = "attention"))]
pub mod fallback {
use super::*;
/// Routing outcome for a single entry in the fallback implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TokenRoute {
    /// The entry is processed normally.
    Compute,
    /// The entry is skipped.
    Skip,
    /// The entry has boundary candidates and is processed.
    Boundary,
}

impl TokenRoute {
    /// Returns `true` for every route except `Skip`.
    pub fn requires_compute(&self) -> bool {
        match self {
            TokenRoute::Skip => false,
            TokenRoute::Compute | TokenRoute::Boundary => true,
        }
    }
}
/// Fallback gate packet used when the `attention` feature is disabled.
#[derive(Clone, Copy, Debug, Default)]
pub struct GatePacket {
    /// Current lambda value.
    pub lambda: u32,
    /// Lambda value of the previous packet.
    pub lambda_prev: u32,
    /// Boundary edge count.
    pub boundary_edges: u16,
    /// Boundary concentration as a Q15-scaled value.
    pub boundary_concentration_q15: u16,
    /// Number of partitions (reports) summarized by this packet.
    pub partition_count: u16,
    /// Flag bits; the fallback bridge always writes `0`.
    pub flags: u16,
}

impl GatePacket {
    /// Signed change from `lambda_prev` to `lambda`.
    ///
    /// The subtraction is carried out in 64 bits and clamped to the `i32`
    /// range, so values above `i32::MAX` neither flip sign nor overflow.
    pub fn lambda_delta(&self) -> i32 {
        let delta = i64::from(self.lambda) - i64::from(self.lambda_prev);
        delta.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
    }
}
/// Heuristic stand-in for the router-backed optimizer, used when the
/// `attention` feature is disabled.
pub struct CoherenceAttention {
    // Stored for parity with the feature-enabled type; the fallback
    // heuristic does not read it.
    #[allow(dead_code)]
    config: AttentionConfig,
    bridge: GatePacketBridge,
    stats: AttentionStats,
}

impl CoherenceAttention {
    /// Creates the fallback optimizer with empty statistics.
    pub fn new(config: AttentionConfig) -> Self {
        CoherenceAttention {
            config,
            bridge: GatePacketBridge::new(),
            stats: AttentionStats::default(),
        }
    }

    /// Builds a gate packet for `reports` and assigns one route per report.
    ///
    /// Routing rules, checked in order:
    /// 1. any non-zero boundary candidate makes the entry `Boundary`;
    /// 2. an even-indexed report with `shift_score < 0.1` is `Skip`;
    /// 3. everything else is `Compute`.
    ///
    /// Statistics are updated so that `computed_entries` counts every
    /// non-skipped entry (boundary entries included).
    pub fn optimize(&mut self, reports: &[TileReport]) -> (GatePacket, Vec<TokenRoute>) {
        let gate = self.bridge.to_gate_packet_fallback(reports);

        let mut routes = Vec::with_capacity(reports.len());
        let mut skipped = 0usize;
        let mut boundary = 0usize;

        for (index, report) in reports.iter().enumerate() {
            let has_boundary = report.boundary_candidates.iter().any(|&c| c != 0);
            let route = if has_boundary {
                boundary += 1;
                TokenRoute::Boundary
            } else if index % 2 == 0 && report.shift_score < 0.1 {
                skipped += 1;
                TokenRoute::Skip
            } else {
                TokenRoute::Compute
            };
            routes.push(route);
        }

        let total = routes.len();
        self.stats.total_entries += total;
        self.stats.computed_entries += total - skipped;
        self.stats.skipped_entries += skipped;
        self.stats.boundary_entries += boundary;
        self.stats.decisions += 1;

        (gate, routes)
    }

    /// Running totals accumulated by `optimize`.
    pub fn stats(&self) -> &AttentionStats {
        &self.stats
    }

    /// Clears all accumulated statistics.
    pub fn reset_stats(&mut self) {
        self.stats = AttentionStats::default();
    }
}
impl GatePacketBridge {
    /// Fallback conversion of tile reports into a gate packet.
    ///
    /// * `lambda` is the smallest strictly positive `local_cut`, clamped to
    ///   `[0, 1000]` and scaled by `32.767`.
    /// * `boundary_concentration_q15` is the largest `shift_score` scaled by
    ///   `32767` and capped at `32767`, so scores above `1.0` stay inside
    ///   the Q15 range.
    /// * `boundary_edges` is always `0` here; the smoothed boundary state is
    ///   not touched by this path.
    /// * `partition_count` saturates at `u16::MAX` instead of wrapping.
    ///
    /// An empty slice yields `GatePacket::default()` and leaves the stored
    /// previous lambda unchanged.
    pub fn to_gate_packet_fallback(&mut self, reports: &[TileReport]) -> GatePacket {
        if reports.is_empty() {
            return GatePacket::default();
        }

        let mut min_cut = f64::MAX;
        let mut max_shift = 0.0f64;
        for report in reports {
            // Zero (or negative) cuts are ignored when searching for the minimum.
            if report.local_cut > 0.0 && report.local_cut < min_cut {
                min_cut = report.local_cut;
            }
            if report.shift_score > max_shift {
                max_shift = report.shift_score;
            }
        }

        // NOTE(review): when no report has a positive cut, `min_cut` is still
        // `f64::MAX` and clamps to 1000.0, i.e. lambda = 32767 — confirm that
        // this is the intended reading of "no data".
        let lambda = (min_cut.clamp(0.0, 1000.0) * 32.767) as u32;

        // Cap before narrowing so the result never exceeds the Q15 maximum.
        let boundary_concentration_q15 = (max_shift * 32767.0).min(32767.0) as u16;
        let partition_count = reports.len().min(usize::from(u16::MAX)) as u16;

        let packet = GatePacket {
            lambda,
            lambda_prev: self.prev_lambda,
            boundary_edges: 0,
            boundary_concentration_q15,
            partition_count,
            flags: 0,
        };
        self.prev_lambda = lambda;
        packet
    }
}
}
#[cfg(not(feature = "attention"))]
pub use fallback::*;
#[cfg(test)]
mod tests {
    use super::*;

    /// The default preset skips half the work and allows early exit.
    #[test]
    fn test_attention_config_default() {
        let AttentionConfig {
            flops_reduction,
            enable_early_exit,
            ..
        } = AttentionConfig::default();
        assert_eq!(flops_reduction, 0.5);
        assert!(enable_early_exit);
    }

    /// The realtime preset skips more work than the default.
    #[test]
    fn test_attention_config_realtime() {
        let realtime = AttentionConfig::realtime();
        assert!(realtime.flops_reduction > 0.5);
    }

    /// Two reports with positive cuts yield a non-zero lambda and a
    /// partition count of two, on whichever conversion path is compiled in.
    #[test]
    fn test_gate_packet_bridge() {
        let make_report = |id, local_cut, shift_score| {
            let mut report = TileReport::new(id);
            report.local_cut = local_cut;
            report.shift_score = shift_score;
            report
        };
        let reports = vec![make_report(1, 10.0, 0.2), make_report(2, 15.0, 0.1)];

        let mut bridge = GatePacketBridge::new();

        #[cfg(feature = "attention")]
        let packet = bridge.to_gate_packet(&reports);
        #[cfg(not(feature = "attention"))]
        let packet = bridge.to_gate_packet_fallback(&reports);

        assert!(packet.lambda > 0);
        assert_eq!(packet.partition_count, 2);
    }

    /// Ratios are the skipped / computed counts divided by the total.
    #[test]
    fn test_attention_stats() {
        let stats = AttentionStats {
            total_entries: 100,
            computed_entries: 60,
            skipped_entries: 40,
            ..AttentionStats::default()
        };
        assert_eq!(stats.flops_reduction(), 0.4);
        assert_eq!(stats.compute_ratio(), 0.6);
    }
}