use crate::error::{Result, TinyDancerError};
use crate::model::{FastGRNN, FastGRNNConfig};
use ndarray::{Array1, Array2};
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
pub learning_rate: f32,
pub batch_size: usize,
pub epochs: usize,
pub validation_split: f32,
pub early_stopping_patience: Option<usize>,
pub lr_decay: f32,
pub lr_decay_step: usize,
pub grad_clip: f32,
pub adam_beta1: f32,
pub adam_beta2: f32,
pub adam_epsilon: f32,
pub l2_reg: f32,
pub enable_distillation: bool,
pub distillation_temperature: f32,
pub distillation_alpha: f32,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
learning_rate: 0.001,
batch_size: 32,
epochs: 100,
validation_split: 0.2,
early_stopping_patience: Some(10),
lr_decay: 0.5,
lr_decay_step: 20,
grad_clip: 5.0,
adam_beta1: 0.9,
adam_beta2: 0.999,
adam_epsilon: 1e-8,
l2_reg: 1e-5,
enable_distillation: false,
distillation_temperature: 3.0,
distillation_alpha: 0.5,
}
}
}
#[derive(Debug, Clone)]
pub struct TrainingDataset {
pub features: Vec<Vec<f32>>,
pub labels: Vec<f32>,
pub soft_targets: Option<Vec<f32>>,
}
impl TrainingDataset {
pub fn new(features: Vec<Vec<f32>>, labels: Vec<f32>) -> Result<Self> {
if features.len() != labels.len() {
return Err(TinyDancerError::InvalidInput(
"Features and labels must have the same length".to_string(),
));
}
if features.is_empty() {
return Err(TinyDancerError::InvalidInput(
"Dataset cannot be empty".to_string(),
));
}
Ok(Self {
features,
labels,
soft_targets: None,
})
}
pub fn with_soft_targets(mut self, soft_targets: Vec<f32>) -> Result<Self> {
if soft_targets.len() != self.labels.len() {
return Err(TinyDancerError::InvalidInput(
"Soft targets must match dataset size".to_string(),
));
}
self.soft_targets = Some(soft_targets);
Ok(self)
}
pub fn split(&self, val_ratio: f32) -> Result<(Self, Self)> {
if !(0.0..=1.0).contains(&val_ratio) {
return Err(TinyDancerError::InvalidInput(
"Validation ratio must be between 0.0 and 1.0".to_string(),
));
}
let n_samples = self.features.len();
let n_val = (n_samples as f32 * val_ratio) as usize;
let n_train = n_samples - n_val;
let mut indices: Vec<usize> = (0..n_samples).collect();
let mut rng = rand::thread_rng();
indices.shuffle(&mut rng);
let train_indices = &indices[..n_train];
let val_indices = &indices[n_train..];
let train_features: Vec<Vec<f32>> = train_indices
.iter()
.map(|&i| self.features[i].clone())
.collect();
let train_labels: Vec<f32> = train_indices.iter().map(|&i| self.labels[i]).collect();
let val_features: Vec<Vec<f32>> = val_indices
.iter()
.map(|&i| self.features[i].clone())
.collect();
let val_labels: Vec<f32> = val_indices.iter().map(|&i| self.labels[i]).collect();
let mut train_dataset = Self::new(train_features, train_labels)?;
let mut val_dataset = Self::new(val_features, val_labels)?;
if let Some(soft_targets) = &self.soft_targets {
let train_soft: Vec<f32> = train_indices.iter().map(|&i| soft_targets[i]).collect();
let val_soft: Vec<f32> = val_indices.iter().map(|&i| soft_targets[i]).collect();
train_dataset.soft_targets = Some(train_soft);
val_dataset.soft_targets = Some(val_soft);
}
Ok((train_dataset, val_dataset))
}
pub fn normalize(&mut self) -> Result<(Vec<f32>, Vec<f32>)> {
if self.features.is_empty() {
return Err(TinyDancerError::InvalidInput(
"Cannot normalize empty dataset".to_string(),
));
}
let n_features = self.features[0].len();
let mut means = vec![0.0; n_features];
let mut stds = vec![0.0; n_features];
for feature in &self.features {
for (i, &val) in feature.iter().enumerate() {
means[i] += val;
}
}
for mean in &mut means {
*mean /= self.features.len() as f32;
}
for feature in &self.features {
for (i, &val) in feature.iter().enumerate() {
stds[i] += (val - means[i]).powi(2);
}
}
for std in &mut stds {
*std = (*std / self.features.len() as f32).sqrt();
if *std < 1e-8 {
*std = 1.0; }
}
for feature in &mut self.features {
for (i, val) in feature.iter_mut().enumerate() {
*val = (*val - means[i]) / stds[i];
}
}
Ok((means, stds))
}
pub fn len(&self) -> usize {
self.features.len()
}
pub fn is_empty(&self) -> bool {
self.features.is_empty()
}
}
pub struct BatchIterator<'a> {
dataset: &'a TrainingDataset,
batch_size: usize,
indices: Vec<usize>,
current_idx: usize,
}
impl<'a> BatchIterator<'a> {
pub fn new(dataset: &'a TrainingDataset, batch_size: usize, shuffle: bool) -> Self {
let mut indices: Vec<usize> = (0..dataset.len()).collect();
if shuffle {
let mut rng = rand::thread_rng();
indices.shuffle(&mut rng);
}
Self {
dataset,
batch_size,
indices,
current_idx: 0,
}
}
}
impl<'a> Iterator for BatchIterator<'a> {
type Item = (Vec<Vec<f32>>, Vec<f32>, Option<Vec<f32>>);
fn next(&mut self) -> Option<Self::Item> {
if self.current_idx >= self.indices.len() {
return None;
}
let end_idx = (self.current_idx + self.batch_size).min(self.indices.len());
let batch_indices = &self.indices[self.current_idx..end_idx];
let features: Vec<Vec<f32>> = batch_indices
.iter()
.map(|&i| self.dataset.features[i].clone())
.collect();
let labels: Vec<f32> = batch_indices
.iter()
.map(|&i| self.dataset.labels[i])
.collect();
let soft_targets = self
.dataset
.soft_targets
.as_ref()
.map(|targets| batch_indices.iter().map(|&i| targets[i]).collect());
self.current_idx = end_idx;
Some((features, labels, soft_targets))
}
}
#[derive(Debug)]
struct AdamOptimizer {
m_weights: Vec<Array2<f32>>,
m_biases: Vec<Array1<f32>>,
v_weights: Vec<Array2<f32>>,
v_biases: Vec<Array1<f32>>,
t: usize,
beta1: f32,
beta2: f32,
epsilon: f32,
}
impl AdamOptimizer {
fn new(model_config: &FastGRNNConfig, training_config: &TrainingConfig) -> Self {
let hidden_dim = model_config.hidden_dim;
let input_dim = model_config.input_dim;
let output_dim = model_config.output_dim;
Self {
m_weights: vec![
Array2::zeros((hidden_dim, input_dim)), Array2::zeros((hidden_dim, input_dim)), Array2::zeros((hidden_dim, input_dim)), Array2::zeros((hidden_dim, hidden_dim)), Array2::zeros((output_dim, hidden_dim)), ],
m_biases: vec![
Array1::zeros(hidden_dim), Array1::zeros(hidden_dim), Array1::zeros(hidden_dim), Array1::zeros(output_dim), ],
v_weights: vec![
Array2::zeros((hidden_dim, input_dim)),
Array2::zeros((hidden_dim, input_dim)),
Array2::zeros((hidden_dim, input_dim)),
Array2::zeros((hidden_dim, hidden_dim)),
Array2::zeros((output_dim, hidden_dim)),
],
v_biases: vec![
Array1::zeros(hidden_dim),
Array1::zeros(hidden_dim),
Array1::zeros(hidden_dim),
Array1::zeros(output_dim),
],
t: 0,
beta1: training_config.adam_beta1,
beta2: training_config.adam_beta2,
epsilon: training_config.adam_epsilon,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
pub epoch: usize,
pub train_loss: f32,
pub val_loss: f32,
pub train_accuracy: f32,
pub val_accuracy: f32,
pub learning_rate: f32,
}
pub struct Trainer {
config: TrainingConfig,
optimizer: AdamOptimizer,
best_val_loss: f32,
patience_counter: usize,
metrics_history: Vec<TrainingMetrics>,
}
impl Trainer {
pub fn new(model_config: &FastGRNNConfig, config: TrainingConfig) -> Self {
let optimizer = AdamOptimizer::new(model_config, &config);
Self {
config,
optimizer,
best_val_loss: f32::INFINITY,
patience_counter: 0,
metrics_history: Vec::new(),
}
}
pub fn train(
&mut self,
model: &mut FastGRNN,
dataset: &TrainingDataset,
) -> Result<Vec<TrainingMetrics>> {
let (train_dataset, val_dataset) = dataset.split(self.config.validation_split)?;
println!("Training FastGRNN model");
println!(
"Train samples: {}, Val samples: {}",
train_dataset.len(),
val_dataset.len()
);
println!("Hyperparameters: {:?}", self.config);
let mut current_lr = self.config.learning_rate;
for epoch in 0..self.config.epochs {
if epoch > 0 && epoch % self.config.lr_decay_step == 0 {
current_lr *= self.config.lr_decay;
println!("Decaying learning rate to {:.6}", current_lr);
}
let train_loss = self.train_epoch(model, &train_dataset, current_lr)?;
let (val_loss, val_accuracy) = self.evaluate(model, &val_dataset)?;
let (_, train_accuracy) = self.evaluate(model, &train_dataset)?;
let metrics = TrainingMetrics {
epoch,
train_loss,
val_loss,
train_accuracy,
val_accuracy,
learning_rate: current_lr,
};
self.metrics_history.push(metrics.clone());
println!(
"Epoch {}/{}: train_loss={:.4}, val_loss={:.4}, train_acc={:.4}, val_acc={:.4}",
epoch + 1,
self.config.epochs,
train_loss,
val_loss,
train_accuracy,
val_accuracy
);
if let Some(patience) = self.config.early_stopping_patience {
if val_loss < self.best_val_loss {
self.best_val_loss = val_loss;
self.patience_counter = 0;
println!("New best validation loss: {:.4}", val_loss);
} else {
self.patience_counter += 1;
if self.patience_counter >= patience {
println!("Early stopping triggered at epoch {}", epoch + 1);
break;
}
}
}
}
Ok(self.metrics_history.clone())
}
fn train_epoch(
&mut self,
model: &mut FastGRNN,
dataset: &TrainingDataset,
learning_rate: f32,
) -> Result<f32> {
let mut total_loss = 0.0;
let mut n_batches = 0;
let batch_iter = BatchIterator::new(dataset, self.config.batch_size, true);
for (features, labels, soft_targets) in batch_iter {
let batch_loss = self.train_batch(
model,
&features,
&labels,
soft_targets.as_ref(),
learning_rate,
)?;
total_loss += batch_loss;
n_batches += 1;
}
Ok(total_loss / n_batches as f32)
}
fn train_batch(
&mut self,
model: &mut FastGRNN,
features: &[Vec<f32>],
labels: &[f32],
soft_targets: Option<&Vec<f32>>,
learning_rate: f32,
) -> Result<f32> {
let batch_size = features.len();
let mut total_loss = 0.0;
for (i, feature) in features.iter().enumerate() {
let prediction = model.forward(feature, None)?;
let target = labels[i];
let loss = if self.config.enable_distillation {
if let Some(soft_targets) = soft_targets {
let hard_loss = binary_cross_entropy(prediction, target);
let soft_loss = binary_cross_entropy(prediction, soft_targets[i]);
self.config.distillation_alpha * soft_loss
+ (1.0 - self.config.distillation_alpha) * hard_loss
} else {
binary_cross_entropy(prediction, target)
}
} else {
binary_cross_entropy(prediction, target)
};
total_loss += loss;
}
self.apply_gradients(model, learning_rate)?;
Ok(total_loss / batch_size as f32)
}
fn apply_gradients(&mut self, _model: &mut FastGRNN, _learning_rate: f32) -> Result<()> {
self.optimizer.t += 1;
Ok(())
}
fn evaluate(&self, model: &FastGRNN, dataset: &TrainingDataset) -> Result<(f32, f32)> {
let mut total_loss = 0.0;
let mut correct = 0;
for (i, feature) in dataset.features.iter().enumerate() {
let prediction = model.forward(feature, None)?;
let target = dataset.labels[i];
let loss = binary_cross_entropy(prediction, target);
total_loss += loss;
let predicted_class = if prediction >= 0.5 { 1.0_f32 } else { 0.0_f32 };
let target_class = if target >= 0.5 { 1.0_f32 } else { 0.0_f32 };
if (predicted_class - target_class).abs() < 0.01_f32 {
correct += 1;
}
}
let avg_loss = total_loss / dataset.len() as f32;
let accuracy = correct as f32 / dataset.len() as f32;
Ok((avg_loss, accuracy))
}
pub fn metrics_history(&self) -> &[TrainingMetrics] {
&self.metrics_history
}
pub fn save_metrics<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let json = serde_json::to_string_pretty(&self.metrics_history)
.map_err(|e| TinyDancerError::SerializationError(e.to_string()))?;
std::fs::write(path, json)?;
Ok(())
}
}
fn binary_cross_entropy(prediction: f32, target: f32) -> f32 {
let eps = 1e-7;
let pred = prediction.clamp(eps, 1.0 - eps);
-target * pred.ln() - (1.0 - target) * (1.0 - pred).ln()
}
pub fn temperature_softmax(logit: f32, temperature: f32) -> f32 {
let scaled = logit / temperature;
if scaled > 0.0 {
1.0 / (1.0 + (-scaled).exp())
} else {
let ex = scaled.exp();
ex / (1.0 + ex)
}
}
pub fn generate_teacher_predictions(
teacher: &FastGRNN,
features: &[Vec<f32>],
temperature: f32,
) -> Result<Vec<f32>> {
features
.iter()
.map(|feature| {
let logit = teacher.forward(feature, None)?;
Ok(temperature_softmax(logit, temperature))
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dataset_creation() {
let features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
let labels = vec![0.0, 1.0, 0.0];
let dataset = TrainingDataset::new(features, labels).unwrap();
assert_eq!(dataset.len(), 3);
}
#[test]
fn test_dataset_split() {
let features = vec![vec![1.0; 5]; 100];
let labels = vec![0.0; 100];
let dataset = TrainingDataset::new(features, labels).unwrap();
let (train, val) = dataset.split(0.2).unwrap();
assert_eq!(train.len(), 80);
assert_eq!(val.len(), 20);
}
#[test]
fn test_batch_iterator() {
let features = vec![vec![1.0; 5]; 10];
let labels = vec![0.0; 10];
let dataset = TrainingDataset::new(features, labels).unwrap();
let mut iter = BatchIterator::new(&dataset, 3, false);
let batch1 = iter.next().unwrap();
assert_eq!(batch1.0.len(), 3);
let batch2 = iter.next().unwrap();
assert_eq!(batch2.0.len(), 3);
let batch3 = iter.next().unwrap();
assert_eq!(batch3.0.len(), 3);
let batch4 = iter.next().unwrap();
assert_eq!(batch4.0.len(), 1);
assert!(iter.next().is_none());
}
#[test]
fn test_normalization() {
let features = vec![
vec![1.0, 2.0, 3.0],
vec![4.0, 5.0, 6.0],
vec![7.0, 8.0, 9.0],
];
let labels = vec![0.0, 1.0, 0.0];
let mut dataset = TrainingDataset::new(features, labels).unwrap();
let (means, stds) = dataset.normalize().unwrap();
assert_eq!(means.len(), 3);
assert_eq!(stds.len(), 3);
let sum: f32 = dataset.features.iter().map(|f| f[0]).sum();
let mean = sum / dataset.len() as f32;
assert!((mean.abs()) < 1e-5);
}
#[test]
fn test_bce_loss() {
let loss1 = binary_cross_entropy(0.9, 1.0);
let loss2 = binary_cross_entropy(0.1, 1.0);
assert!(loss1 < loss2); }
#[test]
fn test_temperature_softmax() {
let logit = 2.0;
let soft1 = temperature_softmax(logit, 1.0);
let soft2 = temperature_softmax(logit, 2.0);
assert!((soft1 - 0.5).abs() > (soft2 - 0.5).abs());
}
}