use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// A batch of training examples stored as three parallel columns.
///
/// Entries pushed through [`TrainingBatch::add`] occupy the same index in
/// `sources`, `targets` and `expected_residuals`. The fields are public, so
/// callers that mutate them directly are responsible for keeping the columns
/// the same length.
#[derive(Debug, Clone, Default)]
pub struct TrainingBatch {
    /// Source vectors, one per example.
    pub sources: Vec<Vec<f32>>,
    /// Target vectors, one per example.
    pub targets: Vec<Vec<f32>>,
    /// Expected residual vectors, one per example.
    pub expected_residuals: Vec<Vec<f32>>,
}

impl TrainingBatch {
    /// Creates an empty batch without allocating.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty batch with room for `capacity` examples in each column.
    pub fn with_capacity(capacity: usize) -> Self {
        let column = || Vec::with_capacity(capacity);
        Self {
            sources: column(),
            targets: column(),
            expected_residuals: column(),
        }
    }

    /// Appends one example: its source, target and expected residual.
    pub fn add(&mut self, source: Vec<f32>, target: Vec<f32>, expected: Vec<f32>) {
        self.expected_residuals.push(expected);
        self.targets.push(target);
        self.sources.push(source);
    }

    /// Number of examples, measured by the `sources` column.
    pub fn len(&self) -> usize {
        self.sources.len()
    }

    /// Returns `true` when the `sources` column holds no examples.
    pub fn is_empty(&self) -> bool {
        self.sources.is_empty()
    }

    /// Removes every example from all three columns, keeping their allocations.
    pub fn clear(&mut self) {
        for column in [
            &mut self.sources,
            &mut self.targets,
            &mut self.expected_residuals,
        ] {
            column.clear();
        }
    }
}
/// Fixed-capacity FIFO store of past training experiences.
///
/// Once `capacity` entries are held, every newly added entry evicts the
/// oldest one.
#[derive(Debug)]
pub struct ReplayBuffer {
    /// Stored experiences, oldest at the front.
    experiences: VecDeque<Experience>,
    /// Maximum number of experiences retained at any time.
    capacity: usize,
}

impl ReplayBuffer {
    /// Creates an empty buffer that retains at most `capacity` experiences.
    ///
    /// Storage for `capacity` entries is reserved up front.
    pub fn new(capacity: usize) -> Self {
        Self {
            experiences: VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    /// Records a new experience, stamped with the current wall-clock time.
    ///
    /// When the buffer is full the oldest experience is dropped first. A
    /// buffer created with capacity `0` stores nothing.
    pub fn add(&mut self, source: Vec<f32>, target: Vec<f32>, expected: Vec<f32>) {
        // Without this guard a zero-capacity buffer would still end up holding
        // one entry (popping an empty deque is a no-op, then we push).
        if self.capacity == 0 {
            return;
        }
        while self.experiences.len() >= self.capacity {
            self.experiences.pop_front();
        }
        self.experiences.push_back(Experience {
            source,
            target,
            expected_residual: expected,
            timestamp_ms: current_time_ms(),
        });
    }

    /// Draws up to `batch_size` distinct experiences pseudo-randomly and
    /// returns clones of them as a [`TrainingBatch`].
    ///
    /// The batch holds `min(batch_size, self.len())` examples; an empty buffer
    /// or `batch_size == 0` yields an empty batch. The generator is seeded
    /// from the wall clock in milliseconds, so it is not suitable for anything
    /// security-sensitive, and two calls within the same millisecond on an
    /// unchanged buffer return the same selection.
    ///
    /// Builds a temporary index table of `self.len()` entries per call.
    pub fn sample(&self, batch_size: usize) -> TrainingBatch {
        let n = self.experiences.len();
        // Clamp before reserving: a huge `batch_size` must not translate into
        // a huge (or overflowing) allocation request.
        let take = batch_size.min(n);
        let mut batch = TrainingBatch::with_capacity(take);
        if take == 0 {
            return batch;
        }

        // SplitMix64: small, stateless-to-seed generator with good mixing even
        // for low-entropy seeds such as a millisecond timestamp.
        let mut state = current_time_ms();
        let mut next_u64 = || {
            state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = state;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        };

        // Partial Fisher-Yates shuffle: after step `slot`, `order[..=slot]` is
        // a uniformly chosen set of distinct indices.
        let mut order: Vec<usize> = (0..n).collect();
        for slot in 0..take {
            let remaining = (n - slot) as u64;
            let pick = slot + (next_u64() % remaining) as usize;
            order.swap(slot, pick);

            let exp = &self.experiences[order[slot]];
            batch.add(
                exp.source.clone(),
                exp.target.clone(),
                exp.expected_residual.clone(),
            );
        }
        batch
    }

    /// Number of experiences currently stored.
    pub fn len(&self) -> usize {
        self.experiences.len()
    }

    /// Returns `true` when no experiences are stored.
    pub fn is_empty(&self) -> bool {
        self.experiences.is_empty()
    }

    /// Discards every stored experience; the capacity limit is unchanged.
    pub fn clear(&mut self) {
        self.experiences.clear();
    }
}
/// One stored training example, as held inside `ReplayBuffer`.
#[derive(Debug, Clone)]
struct Experience {
/// Source vector passed to `ReplayBuffer::add`.
source: Vec<f32>,
/// Target vector passed to `ReplayBuffer::add`.
target: Vec<f32>,
/// Expected residual vector passed to `ReplayBuffer::add`.
expected_residual: Vec<f32>,
/// Value of `current_time_ms()` at the moment the entry was added.
timestamp_ms: u64,
}
/// Metrics recorded for a single training step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Primary loss for the step.
    pub loss: f32,
    /// EWC loss term for the step.
    // NOTE(review): "EWC" presumably means elastic weight consolidation — confirm.
    pub ewc_loss: f32,
    /// `loss + ewc_loss` when the value is built through [`TrainingMetrics::new`].
    pub total_loss: f32,
    /// Gradient norm reported for the step.
    pub gradient_norm: f32,
    /// Learning rate reported for the step.
    pub learning_rate: f32,
    /// Number of examples in the step's batch.
    pub batch_size: usize,
    /// Step counter.
    pub step: usize,
}

impl TrainingMetrics {
    /// Builds the metrics for one step, deriving `total_loss` as
    /// `loss + ewc_loss`; every other field is stored as given.
    pub fn new(
        loss: f32,
        ewc_loss: f32,
        gradient_norm: f32,
        learning_rate: f32,
        batch_size: usize,
        step: usize,
    ) -> Self {
        let total_loss = loss + ewc_loss;
        Self {
            loss,
            ewc_loss,
            total_loss,
            gradient_norm,
            learning_rate,
            batch_size,
            step,
        }
    }
}
/// Aggregate summary over a sequence of per-step [`TrainingMetrics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Mean of `loss` over the aggregated steps.
    pub avg_loss: f32,
    /// Mean of `ewc_loss` over the aggregated steps.
    pub avg_ewc_loss: f32,
    /// Number of metric entries that were aggregated.
    pub batches: usize,
    /// Sum of `batch_size` over the aggregated steps.
    pub samples: usize,
    /// Epoch number supplied by the caller.
    pub epoch: usize,
    /// Duration in milliseconds supplied by the caller.
    pub duration_ms: u64,
}

impl TrainingResult {
    /// Summarises `metrics` into averages and totals.
    ///
    /// The averages divide by `max(metrics.len(), 1)`, so an empty slice
    /// produces zero averages instead of NaN. `epoch` and `duration_ms` are
    /// stored unchanged.
    pub fn from_metrics(metrics: &[TrainingMetrics], epoch: usize, duration_ms: u64) -> Self {
        let batches = metrics.len();
        // Guard against division by zero for an empty slice.
        let divisor = (batches as f32).max(1.0);

        let loss_total: f32 = metrics.iter().map(|entry| entry.loss).sum();
        let ewc_total: f32 = metrics.iter().map(|entry| entry.ewc_loss).sum();
        let samples: usize = metrics.iter().map(|entry| entry.batch_size).sum();

        Self {
            avg_loss: loss_total / divisor,
            avg_ewc_loss: ewc_total / divisor,
            batches,
            samples,
            epoch,
            duration_ms,
        }
    }
}
/// Returns the system clock reading as whole milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch. The
/// millisecond count is narrowed to `u64` with a plain cast.
fn current_time_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A batch reports one entry after a single `add` and is empty after `clear`.
    #[test]
    fn test_training_batch() {
        let mut examples = TrainingBatch::new();
        examples.add(vec![1.0, 2.0], vec![3.0, 4.0], vec![0.1, 0.2]);

        assert_eq!(examples.len(), 1);
        assert!(!examples.is_empty());

        examples.clear();
        assert!(examples.is_empty());
    }

    /// Filling below capacity keeps every entry, and sampling returns the
    /// requested number of examples.
    #[test]
    fn test_replay_buffer() {
        let mut replay = ReplayBuffer::new(100);
        (0..50)
            .map(|step| step as f32)
            .for_each(|value| replay.add(vec![value], vec![value + 1.0], vec![0.1]));
        assert_eq!(replay.len(), 50);

        let drawn = replay.sample(10);
        assert_eq!(drawn.len(), 10);
    }

    /// Adding more entries than the capacity never grows the buffer past it.
    #[test]
    fn test_replay_buffer_overflow() {
        let mut replay = ReplayBuffer::new(10);
        (0..20)
            .map(|step| step as f32)
            .for_each(|value| replay.add(vec![value], vec![value], vec![0.0]));

        assert_eq!(replay.len(), 10);
    }
}