use crate::compete::inhibition::LateralInhibition;
/// A layer of neurons that compete winner-take-all style.
///
/// Two modes are offered: `compete` picks a single winner (subject to a threshold and a
/// per-neuron refractory period), while `compete_soft` returns softmax-style activations
/// for every neuron.
#[derive(Debug, Clone)]
pub struct WTALayer {
/// Per-neuron membrane values. Overwritten with the inputs by `compete` (for neurons not
/// in their refractory period) and by `compete_soft`; modified by lateral inhibition after
/// a hard win. Exposed read-only via `membranes()`.
membranes: Vec<f32>,
/// Minimum input the best candidate must reach for `compete` to report a winner.
threshold: f32,
/// Inhibition strength, clamped to `[0.0, 1.0]` in `new`; sets the `compete_soft`
/// temperature `1 / (1 + inhibition_strength)`.
inhibition_strength: f32,
/// Number of subsequent `compete` calls a winner sits out (10 unless changed via
/// `set_refractory_period`).
refractory_period: u32,
/// Remaining refractory calls per neuron; a neuron competes only when its counter is 0.
refractory_counters: Vec<u32>,
/// Lateral-inhibition stage applied to `membranes` around the winning index in `compete`.
inhibition: LateralInhibition,
}
impl WTALayer {
    /// Creates a layer of `size` neurons with every membrane and refractory counter at zero.
    ///
    /// * `threshold` — minimum input the winner of [`WTALayer::compete`] must reach.
    /// * `inhibition` — inhibition strength, clamped to `[0.0, 1.0]`. The clamped value is
    ///   used both for the [`WTALayer::compete_soft`] temperature and for the lateral
    ///   inhibition stage used by [`WTALayer::compete`].
    ///
    /// The refractory period starts at 10 calls; change it with
    /// [`WTALayer::set_refractory_period`].
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize, threshold: f32, inhibition: f32) -> Self {
        assert!(size > 0, "size must be > 0");
        // Clamp once so the soft and hard competition paths see the same strength.
        let inhibition_strength = inhibition.clamp(0.0, 1.0);
        Self {
            membranes: vec![0.0; size],
            threshold,
            inhibition_strength,
            refractory_period: 10,
            refractory_counters: vec![0; size],
            // NOTE(review): the meaning of the 0.9 argument is defined by
            // `LateralInhibition::new`, which is not visible here.
            inhibition: LateralInhibition::new(size, inhibition_strength, 0.9),
        }
    }

    /// Runs one hard winner-take-all step and returns the winning neuron's index, if any.
    ///
    /// Neurons whose refractory counter is zero copy their input into their membrane and
    /// take part; neurons still refractory are skipped, keep their previous membrane value,
    /// and have their counter decremented by one. The largest participating input wins
    /// (ties go to the lowest index; NaN inputs never win).
    ///
    /// Returns `None` — without applying inhibition or starting a refractory period — when
    /// no neuron could participate or the best input is below the threshold. Otherwise
    /// `LateralInhibition::apply` is called on the membranes with the winner's index, and
    /// the winner sits out the next `refractory_period` calls.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len()` differs from the layer size.
    pub fn compete(&mut self, inputs: &[f32]) -> Option<usize> {
        assert_eq!(inputs.len(), self.membranes.len(), "Input size mismatch");
        let mut best_idx = None;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &input) in inputs.iter().enumerate() {
            let counter = &mut self.refractory_counters[i];
            if *counter > 0 {
                // Still refractory: count down and sit this round out.
                *counter -= 1;
                continue;
            }
            self.membranes[i] = input;
            // Strict `>` keeps the first index on ties and never selects NaN.
            if input > best_val {
                best_val = input;
                best_idx = Some(i);
            }
        }
        let winner_idx = best_idx?;
        if best_val < self.threshold {
            return None;
        }
        self.inhibition.apply(&mut self.membranes, winner_idx);
        self.refractory_counters[winner_idx] = self.refractory_period;
        Some(winner_idx)
    }

    /// Runs a soft (softmax) competition and returns one activation per neuron.
    ///
    /// Every membrane is overwritten with its input; refractory counters are neither
    /// consulted nor updated. Activations are `exp((x - max) / T)` normalised to sum to 1,
    /// with temperature `T = 1 / (1 + inhibition_strength)`, so a stronger inhibition
    /// yields a sharper distribution.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len()` differs from the layer size.
    pub fn compete_soft(&mut self, inputs: &[f32]) -> Vec<f32> {
        assert_eq!(inputs.len(), self.membranes.len(), "Input size mismatch");
        self.membranes.copy_from_slice(inputs);
        // Shift by the maximum so `exp` cannot overflow. `f32::max` returns the non-NaN
        // operand, so the shift is the largest non-NaN input.
        let max_val = self
            .membranes
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max);
        let temperature = 1.0 / (1.0 + self.inhibition_strength);
        let mut activations: Vec<f32> = self
            .membranes
            .iter()
            .map(|&x| ((x - max_val) / temperature).exp())
            .collect();
        let sum: f32 = activations.iter().sum();
        // For finite inputs the maximal neuron contributes exp(0) = 1, so `sum >= 1`; the
        // guard only skips normalisation when non-finite inputs make the sum NaN.
        if sum > 0.0 {
            for a in &mut activations {
                *a /= sum;
            }
        }
        activations
    }

    /// Zeroes every membrane and clears all refractory counters.
    ///
    /// Threshold, inhibition strength, refractory period and the `inhibition` field are
    /// left untouched.
    pub fn reset(&mut self) {
        self.membranes.fill(0.0);
        self.refractory_counters.fill(0);
    }

    /// Returns the current membrane value of every neuron.
    pub fn membranes(&self) -> &[f32] {
        &self.membranes
    }

    /// Sets how many subsequent [`WTALayer::compete`] calls a winner sits out.
    ///
    /// Counters already running are unaffected; the new period applies to future winners.
    /// A period of 0 lets a winner compete again on the very next call.
    pub fn set_refractory_period(&mut self, period: u32) {
        self.refractory_period = period;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wta_basic() {
        let mut wta = WTALayer::new(5, 0.5, 0.8);
        let inputs = vec![0.1, 0.3, 0.9, 0.2, 0.4];
        let winner = wta.compete(&inputs);
        assert_eq!(winner, Some(2), "Highest activation should win");
    }

    #[test]
    fn test_wta_threshold() {
        let mut wta = WTALayer::new(5, 0.95, 0.8);
        let inputs = vec![0.1, 0.3, 0.9, 0.2, 0.4];
        let winner = wta.compete(&inputs);
        assert_eq!(winner, None, "No neuron exceeds threshold");
    }

    #[test]
    fn test_wta_soft_competition() {
        let mut wta = WTALayer::new(5, 0.5, 0.8);
        let inputs = vec![0.1, 0.3, 0.9, 0.2, 0.4];
        let activations = wta.compete_soft(&inputs);
        let sum: f32 = activations.iter().sum();
        assert!((sum - 1.0).abs() < 0.001, "Activations should sum to 1.0");
        let max_idx = activations
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap();
        assert_eq!(max_idx, 2, "Highest input should have highest activation");
    }

    #[test]
    fn test_wta_soft_inhibition_is_clamped() {
        // Strengths above 1.0 are clamped, so they must behave exactly like 1.0.
        let inputs = [0.1_f32, 0.3, 0.9, 0.2, 0.4];
        let mut clamped = WTALayer::new(5, 0.5, 5.0);
        let mut at_max = WTALayer::new(5, 0.5, 1.0);
        assert_eq!(clamped.compete_soft(&inputs), at_max.compete_soft(&inputs));
    }

    #[test]
    fn test_wta_refractory_period() {
        let mut wta = WTALayer::new(3, 0.5, 0.8);
        wta.set_refractory_period(2);
        let inputs = vec![0.6, 0.7, 0.8];
        let winner1 = wta.compete(&inputs);
        assert_eq!(winner1, Some(2));
        let inputs = vec![0.6, 0.7, 0.8];
        let winner2 = wta.compete(&inputs);
        assert_ne!(winner2, Some(2), "Winner should be in refractory period");
    }

    #[test]
    fn test_wta_refractory_recovery() {
        let mut wta = WTALayer::new(3, 0.5, 0.8);
        wta.set_refractory_period(2);
        let inputs = [0.6_f32, 0.7, 0.8];
        // Neuron 2 wins, then sits out exactly two calls while neurons 1 and 0 win in turn.
        assert_eq!(wta.compete(&inputs), Some(2));
        assert_eq!(wta.compete(&inputs), Some(1));
        assert_eq!(wta.compete(&inputs), Some(0));
        assert_eq!(
            wta.compete(&inputs),
            Some(2),
            "Neuron 2 should compete again after its refractory period"
        );
    }

    #[test]
    fn test_wta_all_refractory_yields_none() {
        let mut wta = WTALayer::new(1, 0.5, 0.8);
        let inputs = [0.9_f32];
        assert_eq!(wta.compete(&inputs), Some(0));
        assert_eq!(wta.compete(&inputs), None, "Sole neuron is refractory");
    }

    #[test]
    fn test_wta_determinism() {
        let mut wta1 = WTALayer::new(10, 0.5, 0.8);
        let mut wta2 = WTALayer::new(10, 0.5, 0.8);
        let inputs = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let winner1 = wta1.compete(&inputs);
        let winner2 = wta2.compete(&inputs);
        assert_eq!(winner1, winner2, "WTA should be deterministic");
    }

    #[test]
    fn test_wta_reset() {
        let mut wta = WTALayer::new(5, 0.5, 0.8);
        let inputs = vec![0.1, 0.3, 0.9, 0.2, 0.4];
        wta.compete(&inputs);
        wta.reset();
        assert!(
            wta.membranes().iter().all(|&x| x == 0.0),
            "Membranes should be reset"
        );
    }

    #[test]
    fn test_wta_reset_clears_refractory() {
        let mut wta = WTALayer::new(3, 0.5, 0.8);
        let inputs = [0.6_f32, 0.7, 0.8];
        assert_eq!(wta.compete(&inputs), Some(2));
        wta.reset();
        assert_eq!(wta.compete(&inputs), Some(2), "Reset should end the refractory period");
    }

    // Wall-clock budgets are only meaningful for optimised builds; in debug builds (and on
    // loaded CI machines) this assertion is flaky, so it runs only with `--release`.
    #[test]
    #[cfg_attr(
        debug_assertions,
        ignore = "timing budget only meaningful in optimised builds"
    )]
    fn test_wta_performance() {
        use std::time::Instant;
        let mut wta = WTALayer::new(1000, 0.5, 0.8);
        let inputs: Vec<f32> = (0..1000).map(|i| (i as f32) / 1000.0).collect();
        let start = Instant::now();
        for _ in 0..1000 {
            wta.reset();
            let _ = wta.compete(&inputs);
        }
        let elapsed = start.elapsed();
        let avg_micros = elapsed.as_micros() as f64 / 1000.0;
        println!("Average WTA competition time: {:.2}μs", avg_micros);
        assert!(
            avg_micros < 100.0,
            "WTA should be fast (got {:.2}μs)",
            avg_micros
        );
    }
}