use super::models::NeuralModel;
use super::types::*;
use crate::{Error, Result};
pub use super::types::NeuralTrainingResults;
/// Drives the training loop for any [`NeuralModel`] implementation.
///
/// Holds only the training hyperparameters; all model state lives in the
/// model passed to `train`.
pub struct NeuralTrainer {
    /// Training hyperparameters read by this trainer: `epochs`, `batch_size`,
    /// `validation_split`, `early_stopping_patience`, `loss_function`, and
    /// `augmentation.noise_injection`.
    config: TrainingConfig,
}
impl NeuralTrainer {
pub fn new(config: TrainingConfig) -> Self {
Self { config }
}
pub fn train(
&mut self,
model: &mut dyn NeuralModel,
training_data: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
) -> Result<NeuralTrainingResults> {
let start_time = std::time::Instant::now();
let mut training_loss_history = Vec::new();
let mut validation_loss_history = Vec::new();
let mut best_validation_loss = f32::INFINITY;
let mut patience_counter = 0;
let mut early_stopped = false;
let split_index =
(training_data.len() as f32 * (1.0 - self.config.validation_split)) as usize;
let (train_data, val_data) = training_data.split_at(split_index);
println!(
"Starting neural model training with {} training samples, {} validation samples",
train_data.len(),
val_data.len()
);
for epoch in 0..self.config.epochs {
let epoch_start = std::time::Instant::now();
let train_loss = self.train_epoch(model, train_data)?;
training_loss_history.push(train_loss);
let val_loss = self.validate_epoch(model, val_data)?;
validation_loss_history.push(val_loss);
println!(
"Epoch {}/{}: train_loss={:.6}, val_loss={:.6}, time={:.2}s",
epoch + 1,
self.config.epochs,
train_loss,
val_loss,
epoch_start.elapsed().as_secs_f32()
);
if val_loss < best_validation_loss {
best_validation_loss = val_loss;
patience_counter = 0;
println!("New best validation loss: {val_loss:.6}");
} else {
patience_counter += 1;
if patience_counter >= self.config.early_stopping_patience {
println!(
"Early stopping triggered after {patience_counter} epochs of no improvement"
);
early_stopped = true;
break;
}
}
if (epoch + 1) % 10 == 0 {
println!("Learning rate decay would be applied here");
}
}
let training_duration = start_time.elapsed().as_secs_f32();
let epochs_completed = training_loss_history.len();
let final_accuracy = if !validation_loss_history.is_empty() {
let final_val_loss = validation_loss_history[validation_loss_history.len() - 1];
(1.0 - final_val_loss.min(1.0)).max(0.0)
} else {
0.0
};
println!(
"Training completed: {epochs_completed} epochs, {training_duration:.2}s total, final accuracy: {final_accuracy:.3}"
);
Ok(NeuralTrainingResults {
training_loss: training_loss_history,
validation_loss: validation_loss_history,
final_accuracy,
training_duration_secs: training_duration,
epochs_completed,
early_stopped,
})
}
fn train_epoch(
&self,
model: &mut dyn NeuralModel,
train_data: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
) -> Result<f32> {
let mut total_loss = 0.0;
let mut batch_count = 0;
for batch_start in (0..train_data.len()).step_by(self.config.batch_size) {
let batch_end = (batch_start + self.config.batch_size).min(train_data.len());
let batch = &train_data[batch_start..batch_end];
let batch_loss = self.train_batch(model, batch)?;
total_loss += batch_loss;
batch_count += 1;
}
Ok(total_loss / batch_count as f32)
}
fn validate_epoch(
&self,
model: &mut dyn NeuralModel,
val_data: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
) -> Result<f32> {
let mut total_loss = 0.0;
let mut batch_count = 0;
for batch_start in (0..val_data.len()).step_by(self.config.batch_size) {
let batch_end = (batch_start + self.config.batch_size).min(val_data.len());
let batch = &val_data[batch_start..batch_end];
let batch_loss = self.validate_batch(model, batch)?;
total_loss += batch_loss;
batch_count += 1;
}
Ok(total_loss / batch_count as f32)
}
fn train_batch(
&self,
model: &mut dyn NeuralModel,
batch: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
) -> Result<f32> {
let mut batch_loss = 0.0;
for (input, target) in batch {
let output = model.forward(input)?;
let loss = self.compute_loss(&output.binaural_audio, target)?;
batch_loss += loss;
if self.config.augmentation.noise_injection {
}
}
Ok(batch_loss / batch.len() as f32)
}
fn validate_batch(
&self,
model: &mut dyn NeuralModel,
batch: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
) -> Result<f32> {
let mut batch_loss = 0.0;
for (input, target) in batch {
let output = model.forward(input)?;
let loss = self.compute_loss(&output.binaural_audio, target)?;
batch_loss += loss;
}
Ok(batch_loss / batch.len() as f32)
}
fn compute_loss(&self, predicted: &[Vec<f32>], target: &[Vec<f32>]) -> Result<f32> {
if predicted.len() != target.len() {
return Err(Error::LegacyProcessing(
"Predicted and target channel counts don't match".to_string(),
));
}
let mut total_loss = 0.0;
let mut sample_count = 0;
match self.config.loss_function {
LossFunction::MSE => {
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for i in 0..min_len {
let diff = pred_channel[i] - target_channel[i];
total_loss += diff * diff;
sample_count += 1;
}
}
}
LossFunction::MAE => {
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for i in 0..min_len {
let diff = (pred_channel[i] - target_channel[i]).abs();
total_loss += diff;
sample_count += 1;
}
}
}
LossFunction::SpectralLoss => {
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for i in 0..min_len {
let diff = pred_channel[i] - target_channel[i];
total_loss += diff * diff; sample_count += 1;
}
}
total_loss *= 1.2; }
LossFunction::PerceptualLoss => {
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for i in 0..min_len {
let diff = pred_channel[i] - target_channel[i];
let perceptual_weight = 1.0 + 0.5 * (i as f32 / min_len as f32);
total_loss += diff * diff * perceptual_weight;
sample_count += 1;
}
}
}
LossFunction::MultiScaleSpectralLoss => {
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for scale in [1, 2, 4] {
for i in (0..min_len).step_by(scale) {
let diff = pred_channel[i] - target_channel[i];
total_loss += diff * diff / (scale as f32);
sample_count += 1;
}
}
}
}
LossFunction::Combined => {
let mse_loss = self.compute_mse_loss(predicted, target)?;
let spectral_loss = self.compute_spectral_loss(predicted, target)?;
total_loss = 0.7 * mse_loss + 0.3 * spectral_loss;
sample_count = 1; }
}
if sample_count > 0 {
Ok(total_loss / sample_count as f32)
} else {
Ok(0.0)
}
}
fn compute_mse_loss(&self, predicted: &[Vec<f32>], target: &[Vec<f32>]) -> Result<f32> {
let mut total_loss = 0.0;
let mut sample_count = 0;
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for i in 0..min_len {
let diff = pred_channel[i] - target_channel[i];
total_loss += diff * diff;
sample_count += 1;
}
}
Ok(if sample_count > 0 {
total_loss / sample_count as f32
} else {
0.0
})
}
fn compute_spectral_loss(&self, predicted: &[Vec<f32>], target: &[Vec<f32>]) -> Result<f32> {
let mut total_loss = 0.0;
let mut sample_count = 0;
for (pred_channel, target_channel) in predicted.iter().zip(target.iter()) {
let min_len = pred_channel.len().min(target_channel.len());
for window_size in [16, 32, 64, 128] {
for start in (0..min_len).step_by(window_size / 2) {
let end = (start + window_size).min(min_len);
if end > start {
let pred_energy: f32 = pred_channel[start..end].iter().map(|x| x * x).sum();
let target_energy: f32 =
target_channel[start..end].iter().map(|x| x * x).sum();
let diff = pred_energy - target_energy;
total_loss += diff * diff;
sample_count += 1;
}
}
}
}
Ok(if sample_count > 0 {
total_loss / sample_count as f32
} else {
0.0
})
}
}