use crate::{NervousSystemError, Result};
use parking_lot::RwLock;
use std::sync::Arc;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// One stored training sample: an input vector, the target vector it maps to, and a
/// scalar weight.
///
/// Within this file `importance` multiplies the sample's replay loss in
/// `ComplementaryLearning::consolidate`; `input` is carried along but is not read by the
/// replay code visible here.
#[derive(Debug, Clone)]
pub struct Experience {
    /// Input vector recorded for the sample.
    pub input: Vec<f32>,
    /// Target vector the replayed parameters are compared against.
    pub target: Vec<f32>,
    /// Multiplier applied to this sample's loss.
    pub importance: f32,
}

impl Experience {
    /// Packs the three components into an `Experience`.
    ///
    /// No validation is performed: empty vectors and any `importance` value (including
    /// negative or non-finite ones) are accepted as given.
    pub fn new(input: Vec<f32>, target: Vec<f32>, importance: f32) -> Self {
        Experience {
            input,
            target,
            importance,
        }
    }
}
/// Quadratic regulariser that discourages parameters from drifting away from a stored
/// snapshot, weighting each parameter by the mean of its squared gradients (a diagonal
/// Fisher estimate). The name presumably abbreviates Elastic Weight Consolidation.
#[derive(Debug, Clone)]
pub struct EWC {
    /// Mean squared gradient per parameter; empty until `compute_fisher` has succeeded.
    pub(crate) fisher_diag: Vec<f32>,
    /// Parameter snapshot the penalty is anchored to.
    optimal_params: Vec<f32>,
    /// Strength of the penalty.
    lambda: f32,
    /// Number of gradient samples behind the current estimate.
    num_samples: usize,
}

impl EWC {
    /// Creates an uninitialised regulariser with penalty strength `lambda`.
    ///
    /// Until [`EWC::compute_fisher`] succeeds, [`EWC::ewc_loss`] returns `0.0` and
    /// [`EWC::ewc_gradient`] returns all zeros.
    pub fn new(lambda: f32) -> Self {
        Self {
            fisher_diag: Vec::new(),
            optimal_params: Vec::new(),
            lambda,
            num_samples: 0,
        }
    }

    /// Estimates the per-parameter importance from `gradients` and anchors the penalty
    /// at `params`.
    ///
    /// Each entry of `gradients` is one gradient sample and must have exactly
    /// `params.len()` elements. The stored estimate for parameter `i` is the mean over
    /// samples of `gradient[i]^2`.
    ///
    /// # Errors
    ///
    /// * `NervousSystemError::InvalidGradients` when `gradients` is empty.
    /// * `NervousSystemError::DimensionMismatch` for the first sample whose length
    ///   differs from `params.len()`.
    ///
    /// On error the previous state is left untouched.
    pub fn compute_fisher(&mut self, params: &[f32], gradients: &[Vec<f32>]) -> Result<()> {
        if gradients.is_empty() {
            return Err(NervousSystemError::InvalidGradients(
                "No gradient samples provided".to_string(),
            ));
        }
        let num_params = params.len();
        let num_samples = gradients.len();
        if let Some(bad) = gradients.iter().find(|g| g.len() != num_params) {
            return Err(NervousSystemError::DimensionMismatch {
                expected: num_params,
                actual: bad.len(),
            });
        }

        #[cfg(feature = "parallel")]
        let fisher: Vec<f32> = (0..num_params)
            .into_par_iter()
            .map(|i| {
                let sum_sq: f32 = gradients.iter().map(|g| g[i] * g[i]).sum();
                sum_sq / num_samples as f32
            })
            .collect();

        // Sequential path: walk each sample once (row by row) so memory is read
        // contiguously; per-parameter additions still happen in sample order.
        #[cfg(not(feature = "parallel"))]
        let fisher: Vec<f32> = {
            let mut acc = vec![0.0_f32; num_params];
            for grad in gradients {
                for (slot, g) in acc.iter_mut().zip(grad) {
                    *slot += g * g;
                }
            }
            let scale = num_samples as f32;
            for slot in &mut acc {
                *slot /= scale;
            }
            acc
        };

        self.fisher_diag = fisher;
        self.num_samples = num_samples;
        self.optimal_params = params.to_vec();
        Ok(())
    }

    /// Returns `(lambda / 2) * sum_i fisher[i] * (current[i] - anchor[i])^2`.
    ///
    /// Returns `0.0` while uninitialised. If `current_params` and the stored snapshot
    /// differ in length, only the overlapping prefix contributes.
    pub fn ewc_loss(&self, current_params: &[f32]) -> f32 {
        if self.fisher_diag.is_empty() {
            return 0.0;
        }

        #[cfg(feature = "parallel")]
        let weighted_sq: f32 = current_params
            .par_iter()
            .zip(self.optimal_params.par_iter())
            .zip(self.fisher_diag.par_iter())
            .map(|((curr, opt), fisher)| {
                let diff = curr - opt;
                fisher * diff * diff
            })
            .sum();

        #[cfg(not(feature = "parallel"))]
        let weighted_sq: f32 = current_params
            .iter()
            .zip(self.optimal_params.iter())
            .zip(self.fisher_diag.iter())
            .map(|((curr, opt), fisher)| {
                let diff = curr - opt;
                fisher * diff * diff
            })
            .sum();

        (self.lambda / 2.0) * weighted_sq
    }

    /// Returns the gradient of [`EWC::ewc_loss`] with respect to `current_params`:
    /// `lambda * fisher[i] * (current[i] - anchor[i])` per element.
    ///
    /// The result always has exactly `current_params.len()` elements. Parameters beyond
    /// the stored snapshot do not contribute to the loss, so their gradient is `0.0`;
    /// while uninitialised every element is `0.0`.
    pub fn ewc_gradient(&self, current_params: &[f32]) -> Vec<f32> {
        if self.fisher_diag.is_empty() {
            return vec![0.0; current_params.len()];
        }
        let lambda = self.lambda;

        #[cfg(feature = "parallel")]
        let mut gradient: Vec<f32> = current_params
            .par_iter()
            .zip(self.optimal_params.par_iter())
            .zip(self.fisher_diag.par_iter())
            .map(|((curr, opt), fisher)| lambda * fisher * (curr - opt))
            .collect();

        #[cfg(not(feature = "parallel"))]
        let mut gradient: Vec<f32> = current_params
            .iter()
            .zip(self.optimal_params.iter())
            .zip(self.fisher_diag.iter())
            .map(|((curr, opt), fisher)| lambda * fisher * (curr - opt))
            .collect();

        // `zip` stops at the shortest input; pad so callers can index by parameter.
        gradient.resize(current_params.len(), 0.0);
        gradient
    }

    /// Number of parameters covered by the current estimate (`0` while uninitialised).
    pub fn num_params(&self) -> usize {
        self.fisher_diag.len()
    }

    /// Current penalty strength.
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Number of gradient samples used by the last successful `compute_fisher`.
    pub fn num_samples(&self) -> usize {
        self.num_samples
    }

    /// `true` once `compute_fisher` has stored a non-empty estimate.
    pub fn is_initialized(&self) -> bool {
        !self.fisher_diag.is_empty()
    }
}
/// Fixed-capacity buffer that overwrites its oldest entry once every slot is occupied.
#[derive(Debug)]
struct RingBuffer<T> {
    /// Backing slots; `None` marks a slot that holds no item.
    buffer: Vec<Option<T>>,
    /// Maximum number of items retained (also the length of `buffer`).
    capacity: usize,
    /// Index of the slot the next `push` writes to.
    head: usize,
    /// Number of occupied slots; never exceeds `capacity`.
    size: usize,
}

impl<T> RingBuffer<T> {
    /// Creates an empty buffer with `capacity` slots. A capacity of `0` is allowed and
    /// yields a buffer that never retains anything.
    fn new(capacity: usize) -> Self {
        // `T` is not required to be `Clone`, so the slots are filled one by one.
        let mut buffer = Vec::with_capacity(capacity);
        buffer.resize_with(capacity, || None);
        Self {
            buffer,
            capacity,
            head: 0,
            size: 0,
        }
    }

    /// Stores `item` in the slot at `head`, replacing whatever was there, and advances
    /// `head` with wrap-around.
    ///
    /// With a capacity of `0` the item is dropped; previously this path panicked on the
    /// out-of-bounds slot access.
    fn push(&mut self, item: T) {
        if self.capacity == 0 {
            return;
        }
        self.buffer[self.head] = Some(item);
        self.head = (self.head + 1) % self.capacity;
        if self.size < self.capacity {
            self.size += 1;
        }
    }

    /// Returns up to `n` distinct stored items chosen at random (fewer when the buffer
    /// holds fewer than `n`).
    fn sample(&self, n: usize) -> Vec<&T> {
        use rand::seq::SliceRandom;

        let occupied: Vec<&T> = self.buffer.iter().filter_map(Option::as_ref).collect();
        let mut rng = rand::thread_rng();
        occupied
            .choose_multiple(&mut rng, n.min(occupied.len()))
            .copied()
            .collect()
    }

    /// Number of items currently stored.
    fn len(&self) -> usize {
        self.size
    }

    /// `true` when nothing is stored.
    fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Drops every stored item and rewinds the write position; capacity is unchanged.
    fn clear(&mut self) {
        for slot in &mut self.buffer {
            *slot = None;
        }
        self.head = 0;
        self.size = 0;
    }
}
/// Couples a bounded replay buffer of experiences (`hippocampus`) with a flat parameter
/// vector (`neocortex_params`) that is nudged towards replayed targets, optionally held
/// back by an [`EWC`] penalty.
#[derive(Debug)]
pub struct ComplementaryLearning {
    /// Replay buffer; behind a lock so experiences can be stored through `&self`.
    hippocampus: Arc<RwLock<RingBuffer<Experience>>>,
    /// Parameters updated during consolidation; the length is fixed at construction.
    neocortex_params: Vec<f32>,
    /// Penalty applied during consolidation once `update_ewc` has initialised it.
    ewc: EWC,
    /// Upper bound on the number of experiences replayed per consolidation iteration.
    replay_batch_size: usize,
}

impl ComplementaryLearning {
    /// Replay iterations run per newly supplied experience in `interleaved_training`.
    const REPLAY_RATIO: f32 = 0.5;
    /// Batch size used by `new`.
    const DEFAULT_REPLAY_BATCH: usize = 32;

    /// Creates a learner with `param_size` zeroed parameters, a replay buffer holding at
    /// most `buffer_size` experiences, and an uninitialised penalty of strength `lambda`.
    pub fn new(param_size: usize, buffer_size: usize, lambda: f32) -> Self {
        Self {
            hippocampus: Arc::new(RwLock::new(RingBuffer::new(buffer_size))),
            neocortex_params: vec![0.0; param_size],
            ewc: EWC::new(lambda),
            replay_batch_size: Self::DEFAULT_REPLAY_BATCH,
        }
    }

    /// Appends `exp` to the replay buffer.
    pub fn store_experience(&self, exp: Experience) {
        self.hippocampus.write().push(exp);
    }

    /// Runs `iterations` replay rounds with learning rate `lr` and returns the loss
    /// averaged over all requested iterations.
    ///
    /// Each round samples up to `replay_batch_size` stored experiences and applies one
    /// gradient step per experience (see `replay_step`). Rounds that find the buffer
    /// empty contribute zero loss but still count in the average.
    ///
    /// Returns `Ok(0.0)` when `iterations` is `0` (previously `NaN`). The current
    /// implementation has no failing path, so the result is always `Ok`.
    pub fn consolidate(&mut self, iterations: usize, lr: f32) -> Result<f32> {
        if iterations == 0 {
            return Ok(0.0);
        }
        let mut total_loss = 0.0_f32;
        for _ in 0..iterations {
            // Clone the sampled batch so the read lock is released before parameters
            // are mutated; the guard lives only for this statement.
            let batch: Vec<Experience> = self
                .hippocampus
                .read()
                .sample(self.replay_batch_size)
                .into_iter()
                .cloned()
                .collect();
            if batch.is_empty() {
                continue;
            }
            let mut batch_loss = 0.0_f32;
            for exp in &batch {
                batch_loss += self.replay_step(exp, lr);
            }
            total_loss += batch_loss / batch.len() as f32;
        }
        Ok(total_loss / iterations as f32)
    }

    /// Applies one replay step for `exp` and returns its importance-weighted loss.
    ///
    /// The leading parameters are compared with the target element-wise; only the
    /// overlap of the two is used, so a target longer than the parameter vector no
    /// longer panics. Both the squared-error loss and its gradient are divided by the
    /// full target length. Experiences with no overlap (for example an empty target)
    /// yield `0.0` and leave the parameters untouched instead of producing `NaN`.
    ///
    /// `importance` scales the reported loss only, not the parameter update.
    fn replay_step(&mut self, exp: &Experience, lr: f32) -> f32 {
        let target = &exp.target;
        let dims = target.len().min(self.neocortex_params.len());
        if dims == 0 {
            return 0.0;
        }
        let norm = target.len() as f32;

        // Loss is measured before the parameters move.
        let loss = self.neocortex_params[..dims]
            .iter()
            .zip(target)
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f32>()
            / norm;

        // The penalty gradient for index `i` depends only on parameter `i`, so one
        // evaluation per experience gives the same values as recomputing it per index.
        let penalty = if self.ewc.is_initialized() {
            self.ewc.ewc_gradient(&self.neocortex_params)
        } else {
            Vec::new()
        };

        for (i, (param, t)) in self.neocortex_params.iter_mut().zip(target).enumerate() {
            let data_grad = 2.0 * (*param - *t) / norm;
            let ewc_grad = penalty.get(i).copied().unwrap_or(0.0);
            *param -= lr * (data_grad + ewc_grad);
        }

        loss * exp.importance
    }

    /// Stores every experience in `new_data`, then runs
    /// `floor(new_data.len() * 0.5)` consolidation iterations with learning rate `lr`.
    ///
    /// A single new experience therefore triggers no replay. The consolidation loss is
    /// discarded; errors from `consolidate` are propagated.
    pub fn interleaved_training(&mut self, new_data: &[Experience], lr: f32) -> Result<()> {
        for exp in new_data {
            self.store_experience(exp.clone());
        }
        let num_replay = (new_data.len() as f32 * Self::REPLAY_RATIO) as usize;
        if num_replay > 0 {
            self.consolidate(num_replay, lr)?;
        }
        Ok(())
    }

    /// Removes every stored experience.
    pub fn clear_hippocampus(&self) {
        self.hippocampus.write().clear();
    }

    /// Number of experiences currently stored.
    pub fn hippocampus_size(&self) -> usize {
        self.hippocampus.read().len()
    }

    /// Read-only view of the current parameters.
    pub fn neocortex_params(&self) -> &[f32] {
        &self.neocortex_params
    }

    /// Re-estimates the penalty from `gradients`, anchoring it at the current
    /// parameters. Errors from `EWC::compute_fisher` are returned unchanged.
    pub fn update_ewc(&mut self, gradients: &[Vec<f32>]) -> Result<()> {
        self.ewc.compute_fisher(&self.neocortex_params, gradients)
    }
}
/// Scales the strength of an [`EWC`] penalty with an exponentially smoothed reward
/// signal.
///
/// NOTE(review): `threshold` is used as a divisor in `modulate`; a zero or negative
/// value is not guarded against here — confirm callers always pass a positive value.
#[derive(Debug)]
pub struct RewardConsolidation {
    /// Penalty whose `lambda` is rewritten on every `modulate` call.
    ewc: EWC,
    /// Smoothed reward; starts at `0.0`.
    reward_trace: f32,
    /// Time constant of the smoothing.
    tau_reward: f32,
    /// Trace level at which consolidation is signalled.
    threshold: f32,
    /// Penalty strength used when the trace is zero or negative.
    base_lambda: f32,
}

impl RewardConsolidation {
    /// Builds a modulator around a fresh `EWC::new(base_lambda)` with a zero reward
    /// trace.
    pub fn new(base_lambda: f32, tau_reward: f32, threshold: f32) -> Self {
        let ewc = EWC::new(base_lambda);
        RewardConsolidation {
            ewc,
            reward_trace: 0.0,
            tau_reward,
            threshold,
            base_lambda,
        }
    }

    /// Folds `reward` into the trace for a time step of `dt` and rescales the penalty.
    ///
    /// The mixing weight is `1 - exp(-dt / tau_reward)`. Afterwards the penalty strength
    /// becomes `base_lambda * (1 + max(reward_trace / threshold, 0))`, so a negative
    /// trace never pushes it below `base_lambda`.
    pub fn modulate(&mut self, reward: f32, dt: f32) {
        let blend = 1.0 - (-dt / self.tau_reward).exp();
        let previous = self.reward_trace;
        self.reward_trace = (1.0 - blend) * previous + blend * reward;

        let boost = f32::max(self.reward_trace / self.threshold, 0.0);
        self.ewc.lambda = self.base_lambda * (1.0 + boost);
    }

    /// `true` once the smoothed reward has reached `threshold`.
    pub fn should_consolidate(&self) -> bool {
        self.threshold <= self.reward_trace
    }

    /// Current smoothed reward.
    pub fn reward_trace(&self) -> f32 {
        self.reward_trace
    }

    /// Shared access to the wrapped penalty.
    pub fn ewc(&self) -> &EWC {
        &self.ewc
    }

    /// Exclusive access to the wrapped penalty.
    pub fn ewc_mut(&mut self) -> &mut EWC {
        &mut self.ewc
    }
}
/// Unit tests for the consolidation types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh `EWC` reports its lambda and is not yet initialised.
    #[test]
    fn test_ewc_creation() {
        let ewc = EWC::new(1000.0);
        assert_eq!(ewc.lambda(), 1000.0);
        assert!(!ewc.is_initialized());
    }

    /// `compute_fisher` records the parameter and sample counts.
    #[test]
    fn test_ewc_fisher_computation() {
        let mut ewc = EWC::new(1000.0);
        let params = vec![0.5; 10];
        let gradients: Vec<Vec<f32>> = vec![vec![0.1; 10]; 5];
        ewc.compute_fisher(&params, &gradients).unwrap();
        assert!(ewc.is_initialized());
        assert_eq!(ewc.num_params(), 10);
        assert_eq!(ewc.num_samples(), 5);
    }

    /// Moving every parameter above the anchor gives a positive loss and gradient.
    #[test]
    fn test_ewc_loss_gradient() {
        let mut ewc = EWC::new(1000.0);
        let params = vec![0.5; 10];
        let gradients: Vec<Vec<f32>> = vec![vec![0.1; 10]; 5];
        ewc.compute_fisher(&params, &gradients).unwrap();
        let new_params = vec![0.6; 10];
        let loss = ewc.ewc_loss(&new_params);
        let grad = ewc.ewc_gradient(&new_params);
        assert!(loss > 0.0);
        assert_eq!(grad.len(), 10);
        assert!(grad.iter().all(|&g| g > 0.0));
    }

    /// Storing one experience and consolidating succeeds.
    #[test]
    fn test_complementary_learning() {
        let mut cls = ComplementaryLearning::new(10, 100, 1000.0);
        let exp = Experience::new(vec![1.0; 5], vec![0.5; 5], 1.0);
        cls.store_experience(exp);
        assert_eq!(cls.hippocampus_size(), 1);
        let result = cls.consolidate(10, 0.01);
        assert!(result.is_ok());
    }

    /// Sustained reward pushes the trace past the threshold.
    #[test]
    fn test_reward_consolidation() {
        let mut rc = RewardConsolidation::new(1000.0, 1.0, 0.5);
        assert!(!rc.should_consolidate());
        rc.modulate(1.0, 0.1);
        assert!(rc.reward_trace() > 0.0);
        for _ in 0..10 {
            rc.modulate(1.0, 0.1);
        }
        assert!(rc.should_consolidate());
    }

    /// Pushing past capacity keeps the length capped; sampling honours the request.
    #[test]
    fn test_ring_buffer() {
        let mut buffer: RingBuffer<i32> = RingBuffer::new(3);
        buffer.push(1);
        buffer.push(2);
        buffer.push(3);
        assert_eq!(buffer.len(), 3);
        buffer.push(4);
        assert_eq!(buffer.len(), 3);
        let samples = buffer.sample(2);
        assert_eq!(samples.len(), 2);
    }

    /// Interleaved training stores the new data and completes without error.
    #[test]
    fn test_interleaved_training() {
        let mut cls = ComplementaryLearning::new(10, 100, 1000.0);
        let new_data = vec![
            Experience::new(vec![1.0; 5], vec![0.5; 5], 1.0),
            Experience::new(vec![0.8; 5], vec![0.4; 5], 1.0),
        ];
        let result = cls.interleaved_training(&new_data, 0.01);
        assert!(result.is_ok());
        assert!(cls.hippocampus_size() > 0);
    }
}