use alloc::vec;
use alloc::vec::Vec;
use crate::math::sigmoid;
/// Default target spike rate; also the initial rate estimate given to every neuron.
const DEFAULT_TARGET_RATE: f64 = 0.1;
/// The value 1.0 in Q14 fixed point (14 fractional bits).
const Q14_ONE: i32 = 1 << 14;
/// Selects how an `AstrocyteGate` turns a neuron's spike-rate deviation
/// (`rate - target_rate`) into a modulation value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AstrocyteMode {
    /// Signed modulation `2 * sigmoid(rate - target) - 1`, in (-1, 1) for a
    /// logistic sigmoid; pairs with `AstrocyteGate::modulate_weight`.
    WeightMod,
    /// Multiplicative gate `sigmoid(rate - target)`, in (0, 1);
    /// pairs with `AstrocyteGate::effective_eta_q14`.
    LearningRateGate,
}
/// Per-neuron astrocyte-style modulator driven by smoothed spike rates.
///
/// Tracks an exponential moving average of each hidden neuron's spiking and
/// converts its deviation from a target rate into a modulation value whose
/// meaning depends on the configured `AstrocyteMode`.
#[derive(Debug, Clone)]
pub struct AstrocyteGate {
    /// Smoothed spike rate per hidden neuron (length `n_hidden`).
    spike_rates: Vec<f64>,
    /// Cached modulation per hidden neuron, derived from `spike_rates` (length `n_hidden`).
    modulation: Vec<f64>,
    /// Moving-average time constant; each update nominally blends in `1 / tau`
    /// of the newest spike.
    tau: f64,
    /// Spike rate at which a neuron's modulation is neutral.
    target_rate: f64,
    /// Number of hidden neurons tracked.
    n_hidden: usize,
    /// How `modulation` is computed from the rate deviation.
    mode: AstrocyteMode,
}
impl AstrocyteGate {
pub fn new(n_hidden: usize, tau: f64) -> Self {
Self::with_mode(n_hidden, tau, AstrocyteMode::WeightMod)
}
pub fn with_mode(n_hidden: usize, tau: f64, mode: AstrocyteMode) -> Self {
let mut gate = Self {
spike_rates: vec![DEFAULT_TARGET_RATE; n_hidden],
modulation: vec![0.0; n_hidden],
tau,
target_rate: DEFAULT_TARGET_RATE,
n_hidden,
mode,
};
gate.recompute_modulation();
gate
}
#[inline]
pub fn mode(&self) -> AstrocyteMode {
self.mode
}
pub fn update(&mut self, spikes: &[u8]) {
debug_assert_eq!(spikes.len(), self.n_hidden);
let alpha = 1.0 / self.tau;
let decay = 1.0 - alpha;
for (j, &spike) in spikes.iter().enumerate().take(self.n_hidden) {
let spike_val = if spike != 0 { 1.0 } else { 0.0 };
self.spike_rates[j] = decay * self.spike_rates[j] + alpha * spike_val;
}
self.recompute_modulation();
}
fn recompute_modulation(&mut self) {
match self.mode {
AstrocyteMode::WeightMod => {
for j in 0..self.n_hidden {
self.modulation[j] =
2.0 * sigmoid(self.spike_rates[j] - self.target_rate) - 1.0;
}
}
AstrocyteMode::LearningRateGate => {
for j in 0..self.n_hidden {
self.modulation[j] = sigmoid(self.spike_rates[j] - self.target_rate);
}
}
}
}
#[inline]
pub fn modulate_weight(&self, neuron_j: usize, base_weight: i16) -> i16 {
let mod_q14 = (self.modulation[neuron_j] * 8192.0) as i32;
let scale = Q14_ONE + mod_q14; let result = (base_weight as i32 * scale) >> 14;
result.clamp(i16::MIN as i32, i16::MAX as i32) as i16
}
#[inline]
pub fn effective_eta_q14(&self, neuron_j: usize, eta_q14: i16) -> i16 {
let gate = self.modulation[neuron_j]; let gate_q14 = (gate * Q14_ONE as f64) as i32;
let result = (eta_q14 as i32 * gate_q14) >> 14;
result.clamp(i16::MIN as i32, i16::MAX as i32) as i16
}
pub fn modulation(&self) -> &[f64] {
&self.modulation
}
pub fn spike_rates(&self) -> &[f64] {
&self.spike_rates
}
pub fn reset(&mut self) {
for r in self.spike_rates.iter_mut() {
*r = self.target_rate;
}
self.recompute_modulation();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds the same spike `pattern` into `gate` for `steps` consecutive updates.
    fn drive(gate: &mut AstrocyteGate, pattern: &[u8], steps: usize) {
        (0..steps).for_each(|_| gate.update(pattern));
    }

    #[test]
    fn astrocyte_new_initializes_correctly() {
        // A fresh gate sits at the default target rate with neutral modulation.
        let gate = AstrocyteGate::new(8, 1000.0);
        assert_eq!(gate.spike_rates().len(), 8);
        assert_eq!(gate.modulation().len(), 8);
        gate.spike_rates().iter().for_each(|&r| {
            assert!((r - 0.1).abs() < 1e-10);
        });
        gate.modulation().iter().for_each(|&m| {
            assert!(m.abs() < 1e-3);
        });
    }

    #[test]
    fn astrocyte_high_spike_rate_strengthens() {
        let mut gate = AstrocyteGate::new(4, 10.0);
        drive(&mut gate, &[1, 0, 0, 0], 50);
        let active = gate.modulation()[0];
        let silent = gate.modulation()[1];
        assert!(
            active > 0.1,
            "high-rate neuron should have positive modulation, got {}",
            active
        );
        assert!(
            silent < 0.0,
            "low-rate neuron should have negative modulation, got {}",
            silent
        );
    }

    #[test]
    fn astrocyte_modulate_weight_bounded() {
        let mut gate = AstrocyteGate::new(2, 10.0);
        drive(&mut gate, &[1, 0], 100);
        let base: i16 = 1000;

        let boosted = gate.modulate_weight(0, base);
        assert!(boosted > base, "high-rate modulation should increase weight");
        assert!(boosted < i16::MAX, "modulated weight should not overflow");

        let damped = gate.modulate_weight(1, base);
        assert!(damped < base, "low-rate modulation should decrease weight");
    }

    #[test]
    fn astrocyte_reset() {
        let mut gate = AstrocyteGate::new(4, 10.0);
        drive(&mut gate, &[1, 1, 1, 1], 50);
        gate.reset();
        for r in gate.spike_rates().iter().copied() {
            assert!((r - 0.1).abs() < 1e-10);
        }
    }

    #[test]
    fn astrocyte_modulate_zero_weight() {
        // Zero must stay zero whatever the modulation is.
        let gate = AstrocyteGate::new(2, 1000.0);
        let scaled = gate.modulate_weight(0, 0);
        assert_eq!(scaled, 0);
    }

    #[test]
    fn agmp_gates_learning_rate() {
        use crate::snn::lif::f64_to_q14;

        let mut gate = AstrocyteGate::with_mode(4, 10.0, AstrocyteMode::LearningRateGate);
        drive(&mut gate, &[1, 0, 0, 0], 80);

        let eta = f64_to_q14(0.01);
        let (eta_0, eta_1) = (gate.effective_eta_q14(0, eta), gate.effective_eta_q14(1, eta));

        assert!(
            eta_0 > eta_1,
            "high-rate neuron should have larger effective eta than silent neuron: \
             eta_0={eta_0}, eta_1={eta_1}"
        );
        assert!(eta_0 >= 0, "effective eta must be non-negative, got {eta_0}");
        assert!(eta_1 >= 0, "effective eta must be non-negative, got {eta_1}");
        assert!(
            eta_0 <= eta,
            "effective eta must not exceed base eta: eta_0={eta_0} eta={eta}"
        );
        for m in gate.modulation().iter().copied() {
            assert!(
                m > 0.0 && m < 1.0,
                "LearningRateGate modulation must be in (0,1), got {m}"
            );
        }
    }

    #[test]
    fn learning_rate_gate_mode_roundtrips() {
        let gated = AstrocyteGate::with_mode(2, 500.0, AstrocyteMode::LearningRateGate);
        assert_eq!(gated.mode(), AstrocyteMode::LearningRateGate);

        let plain = AstrocyteGate::new(2, 500.0);
        assert_eq!(plain.mode(), AstrocyteMode::WeightMod);
    }
}