use crate::plasticity::consolidate::EWC;
use crate::{NervousSystemError, Result};
use std::collections::HashMap;
/// Exponentially decaying eligibility trace.
///
/// Between two updates the stored trace is multiplied by `exp(-dt / tau)`,
/// where `dt` is the difference between the two timestamps; the update's
/// value is then added on top of the decayed trace.
#[derive(Debug, Clone)]
pub struct EligibilityState {
    /// Current trace value, decayed up to `last_update`.
    pub trace: f32,
    /// Timestamp of the most recent update. `0` doubles as the
    /// "never updated" marker: no decay is applied while it is `0`.
    pub last_update: u64,
    /// Decay time constant, expressed in the same units as the timestamps.
    pub tau: f32,
}

impl EligibilityState {
    /// Creates an empty trace (`trace == 0.0`, never updated) with the given
    /// decay time constant.
    pub fn new(tau: f32) -> Self {
        Self {
            trace: 0.0,
            last_update: 0,
            tau,
        }
    }

    /// Decays the trace for the time elapsed since the previous update, then
    /// adds `value` and records `timestamp` as the last update time.
    ///
    /// A `timestamp` earlier than the previous one (out-of-order event or a
    /// clock that stepped backwards) is treated as zero elapsed time, so the
    /// trace is left undecayed instead of underflowing the unsigned
    /// subtraction.
    pub fn update(&mut self, value: f32, timestamp: u64) {
        if self.last_update > 0 {
            // Saturate rather than subtract: `timestamp - last_update` would
            // panic in debug builds and wrap to a huge `dt` in release builds
            // (wiping the trace) whenever `timestamp < last_update`.
            let dt = timestamp.saturating_sub(self.last_update) as f32;
            self.trace *= (-dt / self.tau).exp();
        }
        self.trace += value;
        self.last_update = timestamp;
    }

    /// Returns the trace value as of the last update (no decay is applied
    /// for time elapsed since then).
    pub fn trace(&self) -> f32 {
        self.trace
    }
}
/// Settings and bookkeeping that decide when a consolidation pass is due.
#[derive(Debug, Clone)]
pub struct ConsolidationSchedule {
    /// Minimum time that must elapse after `last_consolidation` before the
    /// next pass is due. It is compared directly against the difference of
    /// the supplied timestamps, so those must use the same unit.
    pub replay_interval_secs: u64,
    /// Batch size setting (stored only; not read in this type).
    pub batch_size: usize,
    /// Learning-rate setting (stored only; not read in this type).
    pub learning_rate: f32,
    /// Timestamp of the last consolidation; `0` means "never consolidated".
    pub last_consolidation: u64,
}

impl Default for ConsolidationSchedule {
    /// Interval of 3600, batch size 32, learning rate 0.01, never consolidated.
    fn default() -> Self {
        Self {
            replay_interval_secs: 3600,
            batch_size: 32,
            learning_rate: 0.01,
            last_consolidation: 0,
        }
    }
}

impl ConsolidationSchedule {
    /// Creates a schedule with the given settings that has never consolidated.
    pub fn new(interval_secs: u64, batch_size: usize, learning_rate: f32) -> Self {
        Self {
            replay_interval_secs: interval_secs,
            batch_size,
            learning_rate,
            last_consolidation: 0,
        }
    }

    /// Returns `true` once at least `replay_interval_secs` have elapsed since
    /// `last_consolidation`.
    ///
    /// Always returns `false` while `last_consolidation` is `0` (nothing has
    /// been consolidated yet). A `current_time` earlier than
    /// `last_consolidation` counts as zero elapsed time instead of
    /// underflowing the unsigned subtraction.
    pub fn should_consolidate(&self, current_time: u64) -> bool {
        if self.last_consolidation == 0 {
            return false;
        }
        // Saturate: a plain `-` panics in debug builds (and wraps to a huge
        // value in release builds) when the clock is behind the last pass.
        current_time.saturating_sub(self.last_consolidation) >= self.replay_interval_secs
    }
}
/// Snapshot metadata for one version of a collection's parameters:
/// per-parameter eligibility traces plus an optional Fisher diagonal.
#[derive(Debug, Clone)]
pub struct ParameterVersion {
    /// Identifier of the owning collection.
    pub collection_id: u64,
    /// Version number of this snapshot.
    pub version: u32,
    /// Eligibility trace per parameter id, created lazily on first update.
    pub eligibility_windows: HashMap<u64, EligibilityState>,
    /// Fisher diagonal stored via [`ParameterVersion::set_fisher`], if any.
    pub fisher_diagonal: Option<Vec<f32>>,
    /// Creation timestamp supplied by the caller.
    pub created_at: u64,
    /// Time constant handed to every newly created eligibility trace.
    tau: f32,
}

impl ParameterVersion {
    /// Creates an empty snapshot with no traces, no Fisher diagonal and a
    /// default trace time constant of `2000.0`.
    pub fn new(collection_id: u64, version: u32, created_at: u64) -> Self {
        Self {
            collection_id,
            version,
            eligibility_windows: HashMap::new(),
            fisher_diagonal: None,
            created_at,
            tau: 2000.0,
        }
    }

    /// Builder-style override of the time constant used for traces created
    /// from now on. Traces that already exist keep their own `tau`.
    pub fn with_tau(mut self, tau: f32) -> Self {
        self.tau = tau;
        self
    }

    /// Feeds `value` at `timestamp` into the trace of `param_id`, creating
    /// the trace on first use.
    pub fn update_eligibility(&mut self, param_id: u64, value: f32, timestamp: u64) {
        // Copy `tau` out first so the closure does not need to borrow `self`
        // while `eligibility_windows` is mutably borrowed by `entry`; this
        // keeps the code valid on editions without disjoint closure captures.
        let tau = self.tau;
        self.eligibility_windows
            .entry(param_id)
            .or_insert_with(|| EligibilityState::new(tau))
            .update(value, timestamp);
    }

    /// Returns the stored trace for `param_id`, or `0.0` when the parameter
    /// has never been updated.
    pub fn get_eligibility(&self, param_id: u64) -> f32 {
        self.eligibility_windows
            .get(&param_id)
            .map_or(0.0, EligibilityState::trace)
    }

    /// Stores `fisher` as this snapshot's Fisher diagonal, replacing any
    /// previous value.
    pub fn set_fisher(&mut self, fisher: Vec<f32>) {
        self.fisher_diagonal = Some(fisher);
    }

    /// Reports whether a Fisher diagonal has been stored.
    pub fn has_fisher(&self) -> bool {
        self.fisher_diagonal.is_some()
    }
}
/// Tracks the live parameters of one collection together with a numbered
/// history of [`ParameterVersion`] snapshots, an EWC regulariser and the
/// schedule that decides when consolidation is due.
pub struct CollectionVersioning {
    /// Identifier of the collection being tracked.
    collection_id: u64,
    /// Current version number; starts at `0` and grows with `bump_version`.
    version: u32,
    /// Most recently supplied parameter vector.
    current_params: Vec<f32>,
    /// Snapshot metadata keyed by version number.
    versions: HashMap<u32, ParameterVersion>,
    /// EWC state used for the penalty gradient and loss.
    ewc: EWC,
    /// Schedule consulted by `should_consolidate`.
    consolidation_policy: ConsolidationSchedule,
}

impl CollectionVersioning {
    /// Creates a tracker at version `0` with no parameters, no snapshots and
    /// an EWC instance built with a lambda of `1000.0`.
    pub fn new(collection_id: u64, consolidation_policy: ConsolidationSchedule) -> Self {
        let ewc = EWC::new(1000.0);
        Self {
            collection_id,
            consolidation_policy,
            ewc,
            version: 0,
            versions: HashMap::new(),
            current_params: Vec::new(),
        }
    }

    /// Builder-style replacement of the EWC instance with a fresh one built
    /// from `lambda`.
    pub fn with_lambda(mut self, lambda: f32) -> Self {
        let replacement = EWC::new(lambda);
        self.ewc = replacement;
        self
    }

    /// Advances to the next version number and registers an empty snapshot
    /// for it, stamped with the current wall-clock time in milliseconds.
    pub fn bump_version(&mut self) {
        self.version += 1;
        let next = self.version;
        let created_at = current_timestamp_ms();
        let snapshot = ParameterVersion::new(self.collection_id, next, created_at);
        self.versions.insert(next, snapshot);
    }

    /// Replaces the live parameter vector with a copy of `params`.
    pub fn update_parameters(&mut self, params: &[f32]) {
        self.current_params = params.to_owned();
    }

    /// Borrows the live parameter vector.
    pub fn current_parameters(&self) -> &[f32] {
        self.current_params.as_slice()
    }

    /// Adds the EWC penalty gradient, evaluated at the live parameters, to
    /// `base_gradient` element by element.
    ///
    /// While the EWC state is uninitialised the base gradient is returned
    /// unchanged. The two gradients are paired with `zip`, so the result is
    /// as long as the shorter of them.
    pub fn apply_ewc(&self, base_gradient: &[f32]) -> Vec<f32> {
        if self.ewc.is_initialized() {
            let penalty = self.ewc.ewc_gradient(&self.current_params);
            base_gradient
                .iter()
                .zip(penalty.iter())
                .map(|(grad, extra)| grad + extra)
                .collect()
        } else {
            base_gradient.to_vec()
        }
    }

    /// Asks the consolidation schedule whether a pass is due at `current_time`.
    pub fn should_consolidate(&self, current_time: u64) -> bool {
        self.consolidation_policy.should_consolidate(current_time)
    }

    /// Recomputes the Fisher information from `gradients` at the live
    /// parameters, records `current_time` as the last consolidation, and
    /// copies a non-empty Fisher diagonal into the current version's
    /// snapshot when one exists.
    ///
    /// # Errors
    /// Propagates any error returned by the Fisher computation; in that case
    /// neither the schedule nor the snapshot is touched.
    pub fn consolidate(&mut self, gradients: &[Vec<f32>], current_time: u64) -> Result<()> {
        self.ewc.compute_fisher(&self.current_params, gradients)?;
        self.consolidation_policy.last_consolidation = current_time;

        let has_fisher = !self.ewc.fisher_diag.is_empty();
        if has_fisher {
            // No snapshot exists until `bump_version` has been called at
            // least once; in that case there is nowhere to store the copy.
            if let Some(snapshot) = self.versions.get_mut(&self.version) {
                snapshot.set_fisher(self.ewc.fisher_diag.clone());
            }
        }
        Ok(())
    }

    /// Current version number.
    pub fn version(&self) -> u32 {
        self.version
    }

    /// Identifier of the tracked collection.
    pub fn collection_id(&self) -> u64 {
        self.collection_id
    }

    /// Looks up the snapshot registered for `version`, if any.
    pub fn get_version(&self, version: u32) -> Option<&ParameterVersion> {
        self.versions.get(&version)
    }

    /// EWC loss evaluated at the live parameters.
    pub fn ewc_loss(&self) -> f32 {
        self.ewc.ewc_loss(&self.current_params)
    }

    /// Feeds `value` into the eligibility trace of `param_id` in the current
    /// version's snapshot, timestamped with the current wall-clock time in
    /// milliseconds. Does nothing when the current version has no snapshot.
    pub fn update_eligibility(&mut self, param_id: u64, value: f32) {
        let now = current_timestamp_ms();
        if let Some(active) = self.versions.get_mut(&self.version) {
            active.update_eligibility(param_id, value, now);
        }
    }

    /// Borrows the consolidation schedule.
    pub fn consolidation_schedule(&self) -> &ConsolidationSchedule {
        &self.consolidation_policy
    }
}
/// Returns the system clock's distance from the Unix epoch in whole
/// milliseconds, narrowed to `u64`.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn current_timestamp_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap();
    since_epoch.as_millis() as u64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A trace decays by roughly `1/e` after one time constant.
    #[test]
    fn test_eligibility_state() {
        let mut state = EligibilityState::new(1000.0);
        state.update(1.0, 100);
        assert_eq!(state.trace(), 1.0);
        // One full time constant later: expect about exp(-1) ~= 0.368.
        state.update(0.0, 1100);
        assert!(
            state.trace() > 0.3 && state.trace() < 0.4,
            "trace: {}",
            state.trace()
        );
    }

    /// The schedule is never due before a first consolidation and becomes
    /// due once the interval has elapsed.
    #[test]
    fn test_consolidation_schedule() {
        let mut schedule = ConsolidationSchedule::new(3600, 32, 0.01);
        assert!(!schedule.should_consolidate(0));
        schedule.last_consolidation = 1;
        assert!(schedule.should_consolidate(7201));
        schedule.last_consolidation = 7200;
        assert!(!schedule.should_consolidate(7200));
    }

    /// Traces are tracked per parameter id and default to zero; the Fisher
    /// diagonal is absent until set.
    #[test]
    fn test_parameter_version() {
        let mut version = ParameterVersion::new(1, 0, 0);
        version.update_eligibility(0, 1.0, 100);
        version.update_eligibility(1, 0.5, 100);
        assert_eq!(version.get_eligibility(0), 1.0);
        assert_eq!(version.get_eligibility(1), 0.5);
        assert_eq!(version.get_eligibility(999), 0.0);
        assert!(!version.has_fisher());
        version.set_fisher(vec![0.1; 10]);
        assert!(version.has_fisher());
    }

    /// Version numbers start at zero and increase by one per bump.
    #[test]
    fn test_collection_versioning() {
        let schedule = ConsolidationSchedule::default();
        let mut versioning = CollectionVersioning::new(1, schedule);
        assert_eq!(versioning.version(), 0);
        versioning.bump_version();
        assert_eq!(versioning.version(), 1);
        versioning.bump_version();
        assert_eq!(versioning.version(), 2);
    }

    /// Updating parameters stores a copy that can be read back.
    #[test]
    fn test_update_parameters() {
        let schedule = ConsolidationSchedule::default();
        let mut versioning = CollectionVersioning::new(1, schedule);
        let params = vec![0.5; 100];
        versioning.update_parameters(&params);
        assert_eq!(versioning.current_parameters(), params.as_slice());
    }

    /// Consolidating records the time, so the schedule is due again only
    /// after the interval has passed.
    #[test]
    fn test_consolidation() {
        let schedule = ConsolidationSchedule::new(10, 32, 0.01);
        let mut versioning = CollectionVersioning::new(1, schedule);
        versioning.bump_version();
        let params = vec![0.5; 50];
        versioning.update_parameters(&params);
        let gradients: Vec<Vec<f32>> = vec![vec![0.1; 50]; 10];
        let result = versioning.consolidate(&gradients, 5);
        assert!(result.is_ok());
        assert!(!versioning.should_consolidate(5));
        assert!(versioning.should_consolidate(20));
    }

    /// After consolidation, moving the parameters yields a positive EWC loss
    /// and a gradient that differs from the base gradient.
    #[test]
    fn test_ewc_integration() {
        let schedule = ConsolidationSchedule::default();
        let mut versioning = CollectionVersioning::new(1, schedule).with_lambda(1000.0);
        versioning.bump_version();
        let params = vec![0.5; 20];
        versioning.update_parameters(&params);
        let gradients: Vec<Vec<f32>> = vec![vec![0.1; 20]; 5];
        versioning.consolidate(&gradients, 0).unwrap();
        let new_params = vec![0.6; 20];
        versioning.update_parameters(&new_params);
        let loss = versioning.ewc_loss();
        assert!(loss > 0.0, "EWC loss should be positive");
        let base_grad = vec![0.1; 20];
        let modified_grad = versioning.apply_ewc(&base_grad);
        assert_eq!(modified_grad.len(), 20);
        assert!(modified_grad.iter().any(|&g| g != 0.1));
    }

    /// Eligibility updates land in the snapshot of the current version.
    #[test]
    fn test_eligibility_tracking() {
        let schedule = ConsolidationSchedule::default();
        let mut versioning = CollectionVersioning::new(1, schedule);
        versioning.bump_version();
        versioning.update_eligibility(0, 1.0);
        versioning.update_eligibility(1, 0.5);
        let version = versioning.get_version(1).unwrap();
        assert!(version.get_eligibility(0) > 0.9);
        assert!(version.get_eligibility(1) > 0.4);
    }

    /// Every bump registers a retrievable snapshot with the matching number.
    #[test]
    fn test_multiple_versions() {
        let schedule = ConsolidationSchedule::default();
        let mut versioning = CollectionVersioning::new(1, schedule);
        for v in 1..=5 {
            versioning.bump_version();
            assert_eq!(versioning.version(), v);
            let version = versioning.get_version(v);
            assert!(version.is_some());
            assert_eq!(version.unwrap().version, v);
        }
    }
}