#[derive(Debug, Clone)]
pub struct LateralInhibition {
size: usize,
strength: f32,
decay: f32,
radius: usize,
}
impl LateralInhibition {
pub fn new(size: usize, strength: f32, decay: f32) -> Self {
Self {
size,
strength: strength.clamp(0.0, 1.0),
decay: decay.clamp(0.0, 1.0),
radius: (size as f32).sqrt() as usize, }
}
pub fn with_radius(mut self, radius: usize) -> Self {
self.radius = radius;
self
}
pub fn apply(&self, activations: &mut [f32], winner: usize) {
assert_eq!(activations.len(), self.size, "Activation size mismatch");
assert!(winner < self.size, "Winner index out of bounds");
let winner_activation = activations[winner];
for (i, activation) in activations.iter_mut().enumerate() {
if i == winner {
continue; }
let distance = if i > winner { i - winner } else { winner - i };
if distance <= self.radius {
let distance_factor = 1.0 - (distance as f32 / self.radius as f32);
let inhibition = self.strength * distance_factor * winner_activation;
*activation *= 1.0 - inhibition;
}
}
}
pub fn apply_global(&self, activations: &mut [f32]) {
let total_activation: f32 = activations.iter().sum();
let mean_activation = total_activation / activations.len() as f32;
for activation in activations.iter_mut() {
let inhibition = self.strength * mean_activation;
*activation = (*activation - inhibition).max(0.0);
}
}
pub fn weight(&self, from: usize, to: usize) -> f32 {
if from == to {
return 0.0; }
let distance = if to > from { to - from } else { from - to };
if distance > self.radius {
return 0.0;
}
let distance_factor = 1.0 - (distance as f32 / self.radius as f32);
self.strength * distance_factor
}
pub fn weight_matrix(&self) -> Vec<Vec<f32>> {
let mut matrix = vec![vec![0.0; self.size]; self.size];
for i in 0..self.size {
for j in 0..self.size {
matrix[i][j] = self.weight(i, j);
}
}
matrix
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_inhibition_basic() {
let inhibition = LateralInhibition::new(10, 0.8, 0.9);
let mut activations = vec![0.0; 10];
activations[5] = 1.0; activations[4] = 0.5; activations[6] = 0.5;
inhibition.apply(&mut activations, 5);
assert_eq!(activations[5], 1.0);
assert!(activations[4] < 0.5, "Nearby neuron should be inhibited");
assert!(activations[6] < 0.5, "Nearby neuron should be inhibited");
}
#[test]
fn test_inhibition_radius() {
let inhibition = LateralInhibition::new(20, 0.8, 0.9).with_radius(2);
let mut activations = vec![0.5; 20];
activations[10] = 1.0;
inhibition.apply(&mut activations, 10);
assert!(activations[9] < 0.5);
assert!(activations[11] < 0.5);
assert!(activations[7] >= activations[9]);
assert!(activations[13] >= activations[11]);
}
#[test]
fn test_inhibition_no_self_inhibition() {
let inhibition = LateralInhibition::new(10, 1.0, 0.9);
assert_eq!(inhibition.weight(5, 5), 0.0, "No self-inhibition");
}
#[test]
fn test_inhibition_symmetric() {
let inhibition = LateralInhibition::new(10, 0.8, 0.9);
assert_eq!(
inhibition.weight(3, 7),
inhibition.weight(7, 3),
"Inhibition should be symmetric"
);
}
#[test]
fn test_global_inhibition() {
let inhibition = LateralInhibition::new(10, 0.5, 0.9);
let mut activations = vec![0.8; 10];
inhibition.apply_global(&mut activations);
assert!(activations.iter().all(|&x| x < 0.8));
assert!(activations.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-6));
}
#[test]
fn test_inhibition_strength_bounds() {
let inhibition1 = LateralInhibition::new(10, -0.5, 0.9);
let inhibition2 = LateralInhibition::new(10, 1.5, 0.9);
assert_eq!(inhibition1.strength, 0.0);
assert_eq!(inhibition2.strength, 1.0);
}
#[test]
fn test_weight_matrix_structure() {
let inhibition = LateralInhibition::new(5, 0.8, 0.9).with_radius(1);
let matrix = inhibition.weight_matrix();
assert_eq!(matrix.len(), 5);
assert!(matrix.iter().all(|row| row.len() == 5));
for i in 0..5 {
assert_eq!(matrix[i][i], 0.0);
}
for i in 0..5 {
for j in 0..5 {
assert_eq!(matrix[i][j], matrix[j][i]);
}
}
}
#[test]
fn test_mexican_hat_profile() {
let inhibition = LateralInhibition::new(10, 0.8, 0.9).with_radius(3);
let w1 = inhibition.weight(5, 6); let w2 = inhibition.weight(5, 7); let w3 = inhibition.weight(5, 8);
assert!(w1 > w2, "Inhibition decreases with distance");
assert!(w2 > w3, "Inhibition decreases with distance");
assert_eq!(inhibition.weight(5, 9), 0.0, "Beyond radius");
}
}