use alloc::vec::Vec;
use super::ContinualStrategy;
use crate::drift::DriftSignal;
use crate::math::{abs, ceil};
/// Continual-learning strategy that keeps a per-group utility score (an
/// exponential moving average of each group's mean absolute gradient) and
/// periodically re-initialises or suppresses the lowest-utility parameter groups.
#[derive(Debug, Clone)]
pub struct NeuronRegeneration {
    /// Utility EMA, one entry per parameter group.
    utility: Vec<f64>,
    /// EMA decay: `u <- alpha * u + (1 - alpha) * mean|grad|`.
    utility_alpha: f64,
    /// Fraction of groups selected per regeneration round (clamped to `[0, 1]`).
    regen_fraction: f64,
    /// Number of updates between scheduled regeneration rounds (at least 1).
    regen_interval: u64,
    /// Number of consecutive parameters in each group.
    group_size: usize,
    /// Total number of parameters covered by this strategy.
    n_params: usize,
    /// Updates seen since construction or the last reset.
    n_updates: u64,
    /// RNG state used when drawing re-initialisation values.
    rng_state: u64,
    /// Groups selected in the most recent regeneration round.
    regenerated_mask: Vec<bool>,
}
/// Draws one value from the crate RNG and maps it to a small symmetric
/// perturbation around zero, advancing `state`.
///
/// Assuming `xorshift64` yields a `u64`, the result lies in `[-0.01, 0.01]`.
#[inline]
fn small_random(state: &mut u64) -> f64 {
    const SPAN: f64 = 0.02;
    let unit = crate::rng::xorshift64(state) as f64 / u64::MAX as f64;
    let centred = unit - 0.5;
    centred * SPAN
}
impl NeuronRegeneration {
    /// Creates a regeneration strategy over `n_params` parameters split into
    /// contiguous groups of `group_size` (the last group may be shorter).
    ///
    /// * `regen_fraction`: share of groups selected per regeneration round,
    ///   clamped to `[0.0, 1.0]`. At least one group is always selected.
    /// * `regen_interval`: number of updates between scheduled rounds; `0` is
    ///   treated as `1`.
    /// * `utility_alpha`: decay of the per-group utility EMA, stored as given.
    /// * `seed`: initial RNG state. `0` is replaced by `1`; presumably this is
    ///   because an all-zero xorshift state never changes (confirm against
    ///   `crate::rng`).
    ///
    /// # Panics
    ///
    /// Panics if `group_size == 0` or `n_params == 0`.
    pub fn new(
        n_params: usize,
        group_size: usize,
        regen_fraction: f64,
        regen_interval: u64,
        utility_alpha: f64,
        seed: u64,
    ) -> Self {
        assert!(group_size > 0, "group_size must be >= 1");
        assert!(n_params > 0, "n_params must be >= 1");
        let n_groups = n_params.div_ceil(group_size);
        Self {
            utility: alloc::vec![0.0; n_groups],
            utility_alpha,
            regen_fraction: regen_fraction.clamp(0.0, 1.0),
            regen_interval: regen_interval.max(1),
            group_size,
            n_params,
            n_updates: 0,
            rng_state: if seed == 0 { 1 } else { seed },
            regenerated_mask: alloc::vec![false; n_groups],
        }
    }

    /// Creates a strategy with these defaults: `group_size = 1`,
    /// `regen_fraction = 0.01`, `regen_interval = 1000`, `utility_alpha = 0.99`
    /// and `seed = 0xDEAD_BEEF`.
    pub fn with_defaults(n_params: usize) -> Self {
        Self::new(n_params, 1, 0.01, 1000, 0.99, 0xDEAD_BEEF)
    }

    /// Per-group utility scores, indexed by group.
    pub fn utility(&self) -> &[f64] {
        &self.utility
    }

    /// Number of parameter groups, i.e. `ceil(n_params / group_size)`.
    pub fn n_groups(&self) -> usize {
        self.utility.len()
    }

    /// Whether `group_idx` was selected in the most recent regeneration round.
    /// Out-of-range indices return `false`.
    pub fn was_regenerated(&self, group_idx: usize) -> bool {
        self.regenerated_mask
            .get(group_idx)
            .copied()
            .unwrap_or(false)
    }

    /// Immediately selects the lowest-utility groups and overwrites their slots
    /// in `params` with small random values. Selection uses the same rule as
    /// scheduled rounds. Each selected group's utility is reset to the
    /// pre-selection mean.
    ///
    /// Groups lying past the end of `params` are still marked, but nothing is
    /// written for them.
    pub fn force_regenerate(&mut self, params: &mut [f64]) {
        self.perform_regeneration(params);
    }

    /// Number of updates seen since construction or the last reset.
    pub fn n_updates(&self) -> u64 {
        self.n_updates
    }

    /// Number of groups to select per round: `ceil(n_groups * regen_fraction)`,
    /// raised to at least 1 and capped at `n_groups`.
    fn regen_target_count(&self, n_groups: usize) -> usize {
        (ceil(n_groups as f64 * self.regen_fraction) as usize)
            .max(1)
            .min(n_groups)
    }

    /// Half-open index range `(start, end)` of group `g` inside a buffer of
    /// length `len`.
    ///
    /// Both ends are clamped to `len`, so a buffer shorter than `n_params`
    /// yields a truncated or empty range instead of an invalid slice.
    fn group_bounds(&self, g: usize, len: usize) -> (usize, usize) {
        let start = g.saturating_mul(self.group_size).min(len);
        let end = start.saturating_add(self.group_size).min(len);
        (start, end)
    }

    /// Returns the k-th smallest utility (1-based), where k is the per-round
    /// target count. Groups at or below this value are regeneration candidates.
    fn utility_threshold(&self) -> f64 {
        let n = self.utility.len();
        if n == 0 {
            return 0.0;
        }
        let k = self.regen_target_count(n);
        let mut scratch = self.utility.clone();
        // Only one order statistic is needed. Selection is O(n) on average,
        // whereas fully ordering the scores would cost more on large models.
        let (_, kth, _) = scratch.select_nth_unstable_by(k - 1, f64::total_cmp);
        *kth
    }

    /// Arithmetic mean of all group utilities (`0.0` when there are no groups).
    fn mean_utility(&self) -> f64 {
        if self.utility.is_empty() {
            return 0.0;
        }
        let sum: f64 = self.utility.iter().sum();
        sum / self.utility.len() as f64
    }

    /// Clears the regeneration mask, then marks up to `regen_target_count`
    /// groups whose utility is at or below the threshold.
    ///
    /// Groups are scanned in index order, so ties favour lower indices. Each
    /// marked group's utility is reset to the mean computed before selection.
    fn select_groups_for_regeneration(&mut self) {
        let n_groups = self.n_groups();
        if n_groups == 0 {
            return;
        }
        let threshold = self.utility_threshold();
        let mean_util = self.mean_utility();
        let target_count = self.regen_target_count(n_groups);
        self.regenerated_mask.fill(false);
        let mut regen_count = 0;
        for g in 0..n_groups {
            if regen_count >= target_count {
                break;
            }
            if self.utility[g] <= threshold {
                self.regenerated_mask[g] = true;
                self.utility[g] = mean_util;
                regen_count += 1;
            }
        }
    }

    /// Selects low-utility groups and re-initialises their parameters with
    /// `small_random`. Groups are processed in index order, so the RNG draw
    /// order is deterministic.
    fn perform_regeneration(&mut self, params: &mut [f64]) {
        self.select_groups_for_regeneration();
        for g in 0..self.n_groups() {
            if !self.regenerated_mask[g] {
                continue;
            }
            let (start, end) = self.group_bounds(g, params.len());
            for p in &mut params[start..end] {
                *p = small_random(&mut self.rng_state);
            }
        }
    }

    /// Selects low-utility groups and zeroes their gradients, so the pending
    /// optimiser step leaves those parameters unchanged.
    fn mark_for_regeneration(&mut self, gradients: &mut [f64]) {
        self.select_groups_for_regeneration();
        for g in 0..self.n_groups() {
            if self.regenerated_mask[g] {
                let (start, end) = self.group_bounds(g, gradients.len());
                gradients[start..end].fill(0.0);
            }
        }
    }
}
impl ContinualStrategy for NeuronRegeneration {
    /// Folds each group's mean absolute gradient into its utility EMA and
    /// increments the update counter. Every `regen_interval` calls it also zeroes
    /// the gradients of the lowest-utility groups via `mark_for_regeneration`.
    ///
    /// `params` is not read. In the utility pass both ends of each group range
    /// are clamped to `gradients.len()`. A slice shorter than `n_params`
    /// therefore cannot underflow `end - start`, and groups lying wholly past
    /// its end keep their current utility.
    fn pre_update(&mut self, params: &[f64], gradients: &mut [f64]) {
        let _ = params;
        let len = gradients.len();
        let alpha = self.utility_alpha;
        for g in 0..self.n_groups() {
            // Clamp `start` as well as `end`: with a short gradient slice `start`
            // can exceed `len`, which would make `end - start` underflow.
            let start = (g * self.group_size).min(len);
            let end = (start + self.group_size).min(len);
            let count = end - start;
            if count == 0 {
                continue;
            }
            let sum = gradients[start..end]
                .iter()
                .fold(0.0, |acc, &grad| acc + abs(grad));
            let group_mag = sum / count as f64;
            self.utility[g] = alpha * self.utility[g] + (1.0 - alpha) * group_mag;
        }
        self.n_updates += 1;
        if self.n_updates % self.regen_interval == 0 {
            self.mark_for_regeneration(gradients);
        }
    }

    /// Clears the regeneration mask after an optimiser step.
    ///
    /// For each marked group it first draws and discards one RNG value per
    /// parameter in the group (bounded by `n_params`). `params` is read-only, so
    /// no parameters are rewritten here.
    /// NOTE(review): the discarded draws presumably keep the RNG stream in step
    /// with an actual re-initialisation. Confirm the intent.
    fn post_update(&mut self, _params: &[f64]) {
        if !self.regenerated_mask.iter().any(|&m| m) {
            return;
        }
        let n_params = self.n_params;
        let group_size = self.group_size;
        let draws: usize = self
            .regenerated_mask
            .iter()
            .enumerate()
            .filter(|&(_, &marked)| marked)
            .map(|(g, _)| {
                let start = g * group_size;
                (start + group_size).min(n_params).saturating_sub(start)
            })
            .sum();
        for _ in 0..draws {
            let _ = crate::rng::xorshift64(&mut self.rng_state);
        }
        self.regenerated_mask.fill(false);
    }

    /// On `DriftSignal::Drift`, marks the lowest-utility groups in the
    /// regeneration mask and resets their utility to the current mean. Up to
    /// `ceil(n_groups * regen_fraction)` groups are marked, and at least one.
    /// Neither parameters nor gradients are modified.
    ///
    /// `Warning` and `Stable` are no-ops.
    fn on_drift(&mut self, _params: &[f64], signal: DriftSignal) {
        match signal {
            DriftSignal::Drift => {
                let n_groups = self.n_groups();
                if n_groups == 0 {
                    return;
                }
                let threshold = self.utility_threshold();
                let mean_util = self.mean_utility();
                let target_count = (ceil(n_groups as f64 * self.regen_fraction) as usize)
                    .max(1)
                    .min(n_groups);
                self.regenerated_mask.fill(false);
                let mut regen_count = 0;
                for g in 0..n_groups {
                    if regen_count >= target_count {
                        break;
                    }
                    if self.utility[g] <= threshold {
                        self.regenerated_mask[g] = true;
                        regen_count += 1;
                        self.utility[g] = mean_util;
                    }
                }
            }
            DriftSignal::Warning | DriftSignal::Stable => {}
        }
    }

    /// Total number of parameters this strategy was built for.
    fn n_params(&self) -> usize {
        self.n_params
    }

    /// Zeroes all utilities and clears the regeneration mask and update counter.
    ///
    /// The RNG state is not rewound, so later random draws continue the existing
    /// stream rather than restarting from the seed.
    fn reset(&mut self) {
        self.utility.fill(0.0);
        self.regenerated_mask.fill(false);
        self.n_updates = 0;
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `NeuronRegeneration`. Every call passing the parameter
    // vector by reference uses `&params`.
    use super::*;
    extern crate alloc;
    use alloc::vec;
    use alloc::vec::Vec;

    #[test]
    fn initially_zero_utility() {
        let regen = NeuronRegeneration::with_defaults(10);
        assert_eq!(
            regen.n_groups(),
            10,
            "should have 10 groups with group_size=1"
        );
        for &u in regen.utility() {
            assert!(
                u == 0.0,
                "all utility values should start at 0.0, got {}",
                u
            );
        }
    }

    #[test]
    fn utility_tracks_gradient_magnitude() {
        let mut regen = NeuronRegeneration::new(4, 1, 0.01, 10_000, 0.5, 42);
        let params = vec![1.0; 4];
        let mut grads = vec![0.0; 4];
        grads[0] = 10.0;
        grads[1] = 0.001;
        grads[2] = 0.001;
        grads[3] = 0.001;
        for _ in 0..20 {
            regen.pre_update(&params, &mut grads);
        }
        let u = regen.utility();
        assert!(
            u[0] > u[1],
            "group 0 utility ({}) should exceed group 1 ({}), larger gradients mean higher utility",
            u[0],
            u[1]
        );
        assert!(
            u[0] > u[2],
            "group 0 utility ({}) should exceed group 2 ({})",
            u[0],
            u[2]
        );
    }

    #[test]
    fn regeneration_triggers_at_interval() {
        let interval = 5;
        let mut regen = NeuronRegeneration::new(10, 1, 0.5, interval, 0.99, 42);
        let params = vec![1.0; 10];
        let mut grads = vec![0.01; 10];
        for _ in 0..(interval - 1) {
            regen.pre_update(&params, &mut grads);
            assert!(
                !regen.regenerated_mask.iter().any(|&m| m),
                "no regeneration should happen before regen_interval is reached"
            );
        }
        regen.pre_update(&params, &mut grads);
        assert!(
            regen.regenerated_mask.iter().any(|&m| m),
            "regeneration should trigger at step {}",
            interval
        );
    }

    #[test]
    fn bottom_fraction_gets_regenerated() {
        let mut regen = NeuronRegeneration::new(10, 1, 0.20, 5, 0.5, 42);
        let params = vec![1.0; 10];
        let mut grads = vec![5.0; 10];
        grads[8] = 0.0001;
        grads[9] = 0.0001;
        for _ in 0..4 {
            regen.pre_update(&params, &mut grads);
        }
        regen.pre_update(&params, &mut grads);
        assert!(
            regen.was_regenerated(8),
            "group 8 (low utility) should be regenerated"
        );
        assert!(
            regen.was_regenerated(9),
            "group 9 (low utility) should be regenerated"
        );
        assert!(
            !regen.was_regenerated(0),
            "group 0 (high utility) should not be regenerated"
        );
    }

    #[test]
    fn regenerated_params_are_reinitialized() {
        let mut regen = NeuronRegeneration::new(4, 2, 0.5, 1, 0.5, 42);
        let mut params = vec![100.0; 4];
        let mut grads = vec![0.0; 4];
        grads[0] = 0.0;
        grads[1] = 0.0;
        grads[2] = 10.0;
        grads[3] = 10.0;
        regen.pre_update(&params, &mut grads);
        regen.force_regenerate(&mut params);
        assert!(
            abs(params[0]) < 1.0,
            "regenerated param[0] should be small random, got {}",
            params[0]
        );
        assert!(
            abs(params[1]) < 1.0,
            "regenerated param[1] should be small random, got {}",
            params[1]
        );
    }

    #[test]
    fn non_regenerated_params_unchanged() {
        let mut regen = NeuronRegeneration::new(4, 2, 0.5, 1, 0.5, 42);
        let mut params = vec![100.0; 4];
        let mut grads = vec![0.0; 4];
        grads[0] = 0.0;
        grads[1] = 0.0;
        grads[2] = 10.0;
        grads[3] = 10.0;
        regen.pre_update(&params, &mut grads);
        let before_2 = params[2];
        let before_3 = params[3];
        regen.force_regenerate(&mut params);
        assert_eq!(
            params[2], before_2,
            "non-regenerated param[2] should be unchanged"
        );
        assert_eq!(
            params[3], before_3,
            "non-regenerated param[3] should be unchanged"
        );
    }

    #[test]
    fn drift_forces_immediate_regeneration() {
        let mut regen = NeuronRegeneration::new(10, 1, 0.20, 1_000_000, 0.5, 42);
        let params = vec![1.0; 10];
        for _ in 0..10 {
            let mut grads = vec![5.0; 10];
            grads[8] = 0.0001;
            grads[9] = 0.0001;
            regen.pre_update(&params, &mut grads);
        }
        assert!(
            !regen.regenerated_mask.iter().any(|&m| m),
            "should not have regenerated yet (interval is 1_000_000)"
        );
        regen.on_drift(&params, DriftSignal::Drift);
        assert!(
            regen.regenerated_mask.iter().any(|&m| m),
            "drift signal should force immediate regeneration marking"
        );
        assert!(
            regen.was_regenerated(8) || regen.was_regenerated(9),
            "low-utility groups should be marked for regeneration on drift"
        );
    }

    #[test]
    fn reset_clears_all_state() {
        let mut regen = NeuronRegeneration::new(10, 1, 0.5, 2, 0.5, 42);
        let params = vec![1.0; 10];
        let mut grads = vec![1.0; 10];
        regen.pre_update(&params, &mut grads);
        regen.pre_update(&params, &mut grads);
        assert!(
            regen.n_updates() > 0,
            "should have incremented update counter"
        );
        regen.reset();
        assert_eq!(regen.n_updates(), 0, "counter should be zeroed after reset");
        for &u in regen.utility() {
            assert!(u == 0.0, "utility should be 0.0 after reset, got {}", u);
        }
        for g in 0..regen.n_groups() {
            assert!(
                !regen.was_regenerated(g),
                "regenerated mask should be cleared after reset"
            );
        }
    }

    #[test]
    fn group_size_groups_params_correctly() {
        let mut regen = NeuronRegeneration::new(12, 4, 0.33, 5, 0.5, 42);
        assert_eq!(regen.n_groups(), 3, "12 params / 4 group_size = 3 groups");
        let params = vec![1.0; 12];
        let mut grads = vec![0.0; 12];
        for g in grads.iter_mut().skip(4) {
            *g = 5.0;
        }
        for _ in 0..5 {
            regen.pre_update(&params, &mut grads);
        }
        assert!(
            regen.was_regenerated(0),
            "group 0 (zero gradients) should be regenerated"
        );
        assert!(
            !regen.was_regenerated(1),
            "group 1 (high gradients) should not be regenerated"
        );
        assert!(
            !regen.was_regenerated(2),
            "group 2 (high gradients) should not be regenerated"
        );
        let mut params_mut = vec![100.0; 12];
        regen.force_regenerate(&mut params_mut);
        for (i, &p) in params_mut.iter().enumerate().take(4) {
            assert!(
                abs(p) < 1.0,
                "all params in regenerated group should be reinitialized, param[{}] = {}",
                i,
                p
            );
        }
        for (i, &p) in params_mut.iter().enumerate().skip(4) {
            assert_eq!(
                p, 100.0,
                "params outside regenerated group should be untouched, param[{}] = {}",
                i, p
            );
        }
    }

    #[test]
    fn high_utility_neurons_preserved() {
        let mut regen = NeuronRegeneration::new(5, 1, 0.20, 3, 0.5, 42);
        let params = vec![1.0; 5];
        let mut grads = vec![0.001; 5];
        grads[0] = 100.0;
        for _ in 0..3 {
            regen.pre_update(&params, &mut grads);
        }
        assert!(
            !regen.was_regenerated(0),
            "high-utility group 0 should be preserved, not regenerated"
        );
        let any_low_regenerated = (1..5).any(|g| regen.was_regenerated(g));
        assert!(
            any_low_regenerated,
            "at least one low-utility group should be regenerated"
        );
    }

    #[test]
    fn warning_and_stable_signals_are_noop() {
        let mut regen = NeuronRegeneration::new(10, 1, 0.5, 1_000_000, 0.5, 42);
        let params = vec![1.0; 10];
        let mut grads = vec![1.0; 10];
        grads[9] = 0.0;
        for _ in 0..5 {
            regen.pre_update(&params, &mut grads);
        }
        regen.on_drift(&params, DriftSignal::Warning);
        assert!(
            !regen.regenerated_mask.iter().any(|&m| m),
            "Warning signal should not trigger regeneration"
        );
        regen.on_drift(&params, DriftSignal::Stable);
        assert!(
            !regen.regenerated_mask.iter().any(|&m| m),
            "Stable signal should not trigger regeneration"
        );
    }

    #[test]
    fn xorshift_produces_distinct_values() {
        let mut state: u64 = 42;
        let mut vals: Vec<f64> = Vec::new();
        for _ in 0..100 {
            vals.push(small_random(&mut state));
        }
        for &v in &vals {
            assert!(
                (-0.011..=0.011).contains(&v),
                "small_random should produce values in ~[-0.01, 0.01], got {}",
                v
            );
        }
        let first = vals[0];
        let all_same = vals.iter().all(|&v| (v - first).abs() < 1e-15);
        assert!(!all_same, "PRNG should produce distinct values");
    }
}