use serde::{Deserialize, Serialize};
/// Operating mode selected by the attention policy from a coherence reading.
///
/// The mode decides how much attention width is granted, which sparsity
/// factor is reported, and whether updates/writes are permitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionMode {
/// Coherence is at or above the stable threshold: the full base width is
/// used and the sparsity factor is `1.0`.
Stable,
/// Coherence lies between the freeze and stable thresholds: the width is
/// scaled by the configured cautious width factor and the configured
/// cautious sparsity factor is reported.
Cautious,
/// Coherence is at or below the freeze threshold: the width is `0`, the
/// sparsity factor is infinite, and updates and writes are disallowed.
Freeze,
}
/// Tunable parameters for the attention policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyConfig {
/// Coherence values at or above this threshold map to `AttentionMode::Stable`
/// (before hysteresis is applied).
pub stable_threshold: f32,
/// Coherence values at or below this threshold map to `AttentionMode::Freeze`
/// (before hysteresis is applied).
pub freeze_threshold: f32,
/// Multiplier applied to the base attention width while in `Cautious` mode.
pub cautious_width_factor: f32,
/// Sparsity factor reported while in `Cautious` mode.
pub cautious_sparsity_factor: f32,
/// NOTE(review): not read anywhere in the visible policy code; presumably an
/// update cadence consumed by a caller — confirm.
pub update_period: usize,
/// Margin added to / subtracted from a threshold that the coherence must
/// clear before a transition between adjacent modes is accepted.
pub hysteresis: f32,
}
impl Default for PolicyConfig {
fn default() -> Self {
Self {
stable_threshold: 0.7,
freeze_threshold: 0.3,
cautious_width_factor: 0.5,
cautious_sparsity_factor: 2.0,
update_period: 4,
hysteresis: 0.05,
}
}
}
/// Stateful policy that maps coherence readings to an [`AttentionMode`],
/// remembering the current mode and a short history of past decisions.
#[derive(Debug, Clone)]
pub struct AttentionPolicy {
// Thresholds, factors and hysteresis margin used for every decision.
config: PolicyConfig,
// Mode chosen by the most recent decision (starts as `Stable`).
current_mode: AttentionMode,
// Most recent decisions, oldest first; trimmed to at most 16 entries.
mode_history: Vec<AttentionMode>,
}
impl AttentionPolicy {
    /// Maximum number of recent decisions kept for [`Self::mode_stability`].
    const HISTORY_CAPACITY: usize = 16;

    /// Creates a policy that starts in [`AttentionMode::Stable`] with an
    /// empty decision history.
    pub fn new(config: PolicyConfig) -> Self {
        Self {
            config,
            current_mode: AttentionMode::Stable,
            mode_history: Vec::new(),
        }
    }

    /// Feeds one coherence reading into the policy and returns the mode that
    /// is now in effect.
    ///
    /// The reading is first classified against the thresholds, then filtered
    /// through the hysteresis margin. The resulting mode becomes the current
    /// mode and is appended to the bounded decision history.
    ///
    /// A NaN reading carries no usable information, so the policy fails safe
    /// and enters [`AttentionMode::Freeze`] instead of silently keeping the
    /// previous (possibly fully permissive) mode.
    pub fn determine_mode(&mut self, coherence: f32) -> AttentionMode {
        let mode = if coherence.is_nan() {
            // Every comparison against NaN is false, which would otherwise
            // leave the previous mode in place.
            AttentionMode::Freeze
        } else {
            let candidate = self.compute_mode(coherence);
            self.apply_hysteresis(candidate, coherence)
        };

        self.mode_history.push(mode);
        if self.mode_history.len() > Self::HISTORY_CAPACITY {
            // Drop the oldest entry; the history is tiny, so the shift is cheap.
            self.mode_history.remove(0);
        }
        self.current_mode = mode;
        mode
    }

    /// Classifies a coherence value purely by the configured thresholds,
    /// ignoring the current mode.
    fn compute_mode(&self, coherence: f32) -> AttentionMode {
        if coherence >= self.config.stable_threshold {
            AttentionMode::Stable
        } else if coherence <= self.config.freeze_threshold {
            AttentionMode::Freeze
        } else {
            AttentionMode::Cautious
        }
    }

    /// Decides whether the threshold-based `new_mode` is accepted or the
    /// current mode is kept.
    ///
    /// A transition between adjacent modes is only accepted once the
    /// coherence has cleared the relevant threshold by more than the
    /// hysteresis margin. Unchanged modes and direct jumps between `Stable`
    /// and `Freeze` are accepted as-is, without a margin.
    fn apply_hysteresis(&self, new_mode: AttentionMode, coherence: f32) -> AttentionMode {
        let margin = self.config.hysteresis;
        let stable = self.config.stable_threshold;
        let freeze = self.config.freeze_threshold;

        let accepted = match (self.current_mode, new_mode) {
            (AttentionMode::Stable, AttentionMode::Cautious) => coherence < stable - margin,
            (AttentionMode::Cautious, AttentionMode::Stable) => coherence > stable + margin,
            (AttentionMode::Cautious, AttentionMode::Freeze) => coherence < freeze - margin,
            (AttentionMode::Freeze, AttentionMode::Cautious) => coherence > freeze + margin,
            // Listed explicitly (no `_`) so a new variant forces a review here.
            (AttentionMode::Stable, AttentionMode::Stable)
            | (AttentionMode::Stable, AttentionMode::Freeze)
            | (AttentionMode::Cautious, AttentionMode::Cautious)
            | (AttentionMode::Freeze, AttentionMode::Freeze)
            | (AttentionMode::Freeze, AttentionMode::Stable) => true,
        };

        if accepted {
            new_mode
        } else {
            self.current_mode
        }
    }

    /// Returns the mode chosen by the most recent decision.
    pub fn current_mode(&self) -> AttentionMode {
        self.current_mode
    }

    /// Returns the attention width permitted in the current mode.
    ///
    /// * `Stable`: `base_width` unchanged.
    /// * `Cautious`: `base_width` scaled by `cautious_width_factor`
    ///   (truncated), but at least `1` so a non-empty window never collapses.
    ///   A `base_width` of `0` stays `0` rather than being widened to `1`.
    /// * `Freeze`: always `0`.
    pub fn get_attention_width(&self, base_width: usize) -> usize {
        match self.current_mode {
            AttentionMode::Stable => base_width,
            // An empty window must not be widened by the `max(1)` floor below.
            AttentionMode::Cautious if base_width == 0 => 0,
            AttentionMode::Cautious => {
                // Float-to-int `as` truncates toward zero and saturates.
                let scaled = (base_width as f32 * self.config.cautious_width_factor) as usize;
                scaled.max(1)
            }
            AttentionMode::Freeze => 0,
        }
    }

    /// Returns the sparsity factor for the current mode: `1.0` when stable,
    /// the configured cautious factor when cautious, and infinity when frozen.
    pub fn get_sparsity_factor(&self) -> f32 {
        match self.current_mode {
            AttentionMode::Stable => 1.0,
            AttentionMode::Cautious => self.config.cautious_sparsity_factor,
            AttentionMode::Freeze => f32::INFINITY,
        }
    }

    /// Returns `true` unless the policy is in [`AttentionMode::Freeze`].
    pub fn allows_updates(&self) -> bool {
        self.current_mode != AttentionMode::Freeze
    }

    /// Returns `true` unless the policy is in [`AttentionMode::Freeze`].
    pub fn allows_writes(&self) -> bool {
        self.current_mode != AttentionMode::Freeze
    }

    /// Returns the fraction of remembered decisions that equal the current
    /// mode, in `[0.0, 1.0]`. An empty history counts as fully stable (`1.0`).
    pub fn mode_stability(&self) -> f32 {
        if self.mode_history.is_empty() {
            return 1.0;
        }
        let current = self.current_mode;
        let matches = self.mode_history.iter().filter(|&&m| m == current).count();
        matches as f32 / self.mode_history.len() as f32
    }

    /// Returns the policy to its initial state: `Stable` mode, empty history.
    pub fn reset(&mut self) {
        self.current_mode = AttentionMode::Stable;
        self.mode_history.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// High, middling and low coherence map to the three modes in turn.
    #[test]
    fn test_policy_modes() {
        let mut gate = AttentionPolicy::new(PolicyConfig::default());
        let expectations = [
            (0.9, AttentionMode::Stable),
            (0.5, AttentionMode::Cautious),
            (0.1, AttentionMode::Freeze),
        ];
        for &(coherence, expected) in expectations.iter() {
            assert_eq!(gate.determine_mode(coherence), expected);
        }
    }

    /// The granted width shrinks from full, to half, to zero as coherence drops.
    #[test]
    fn test_attention_width() {
        let mut gate = AttentionPolicy::new(PolicyConfig::default());
        let expectations = [(0.9, 100_usize), (0.5, 50), (0.1, 0)];
        for &(coherence, expected_width) in expectations.iter() {
            gate.determine_mode(coherence);
            assert_eq!(gate.get_attention_width(100), expected_width);
        }
    }

    /// A dip that stays inside the hysteresis band does not leave `Stable`;
    /// a dip below the band does.
    #[test]
    fn test_hysteresis() {
        let config = PolicyConfig {
            stable_threshold: 0.7,
            freeze_threshold: 0.3,
            hysteresis: 0.1,
            ..Default::default()
        };
        let mut gate = AttentionPolicy::new(config);
        let expectations = [
            (0.8, AttentionMode::Stable),
            (0.65, AttentionMode::Stable),
            (0.55, AttentionMode::Cautious),
        ];
        for &(coherence, expected) in expectations.iter() {
            gate.determine_mode(coherence);
            assert_eq!(gate.current_mode(), expected);
        }
    }

    /// Updates and writes are allowed while stable and blocked while frozen.
    #[test]
    fn test_update_permissions() {
        let mut gate = AttentionPolicy::new(PolicyConfig::default());

        gate.determine_mode(0.8);
        assert!(gate.allows_updates());
        assert!(gate.allows_writes());

        gate.determine_mode(0.1);
        assert!(!gate.allows_updates());
        assert!(!gate.allows_writes());
    }
}