use super::config::TrainingConfig;
use super::types::{LearnedIndexResult, TrainingExample};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
/// Summary of a single training run, as returned by `IndexTrainer::train`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
/// Number of training epochs the trainer reports as completed.
pub epochs_completed: usize,
/// Training loss of the last epoch that was executed.
pub final_loss: f64,
/// Loss measured by the validation pass of the last epoch that was executed.
pub validation_loss: f64,
/// Fraction of evaluated examples whose predicted position fell within the
/// trainer's tolerance.
pub final_accuracy: f64,
/// Wall-clock duration of the training call, in seconds.
pub training_time_secs: f64,
/// `true` when training ended because the early-stopping patience ran out.
pub early_stopped: bool,
}
/// Trains learned-index model parameters according to a [`TrainingConfig`].
pub struct IndexTrainer {
/// Settings read during training: epoch count, batch size, learning rate,
/// validation split and early-stopping patience.
config: TrainingConfig,
}
impl IndexTrainer {
    /// Creates a trainer that uses the given training configuration.
    pub fn new(config: TrainingConfig) -> Self {
        Self { config }
    }

    /// Runs the full training loop and returns statistics about the run.
    ///
    /// `weights` and `biases` are reset and re-initialized before training, so
    /// any previous content is discarded. `examples` is split into a leading
    /// training part and a trailing validation part according to
    /// `validation_split`. When the validation part is empty, the training
    /// examples are used for the validation loss and the accuracy instead, so
    /// that early stopping and the reported statistics never become NaN.
    ///
    /// Target positions are normalized by the total number of examples, which
    /// assumes `target_position` values lie in `0..examples.len()`.
    ///
    /// An empty `examples` slice is not an error: a warning is logged, the
    /// parameters are left untouched and zeroed statistics are returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the per-epoch training or validation
    /// passes.
    // `Vec` parameters are required: initialization pushes new layers.
    #[allow(clippy::ptr_arg)]
    pub fn train(
        &self,
        weights: &mut Vec<Vec<f32>>,
        biases: &mut Vec<f32>,
        examples: &[TrainingExample],
    ) -> LearnedIndexResult<TrainingStats> {
        let start = std::time::Instant::now();

        // Indexing `examples[0]` would panic on empty input; bail out instead.
        let input_size = match examples.first() {
            Some(example) => example.features.len(),
            None => {
                tracing::warn!("No training examples provided; skipping training");
                return Ok(TrainingStats {
                    epochs_completed: 0,
                    final_loss: 0.0,
                    validation_loss: 0.0,
                    final_accuracy: 0.0,
                    training_time_secs: start.elapsed().as_secs_f64(),
                    early_stopped: false,
                });
            }
        };

        let total = examples.len();
        // Clamp so that an out-of-range `validation_split` cannot make
        // `split_at` panic.
        let split_idx =
            ((total as f32 * (1.0 - self.config.validation_split)) as usize).min(total);
        let (train_examples, val_examples) = examples.split_at(split_idx);
        tracing::info!(
            "Training on {} examples, validating on {}",
            train_examples.len(),
            val_examples.len()
        );

        // Without held-out data, monitor the training examples instead of
        // dividing by zero.
        let eval_examples = if val_examples.is_empty() {
            train_examples
        } else {
            val_examples
        };

        self.initialize_weights(weights, biases, input_size);

        let mut best_val_loss = f64::INFINITY;
        let mut patience_counter = 0;
        let mut epochs_completed = 0;
        let mut final_loss = 0.0;
        let mut validation_loss = 0.0;
        let mut early_stopped = false;

        for epoch in 0..self.config.num_epochs {
            let train_loss = self.train_epoch(weights, biases, train_examples, total, epoch)?;
            let val_loss = self.validate(weights, biases, eval_examples, total)?;
            epochs_completed = epoch + 1;
            final_loss = train_loss;
            validation_loss = val_loss;

            if val_loss < best_val_loss {
                best_val_loss = val_loss;
                patience_counter = 0;
            } else {
                patience_counter += 1;
            }

            if patience_counter >= self.config.early_stopping_patience {
                tracing::info!("Early stopping at epoch {}", epoch);
                early_stopped = true;
                break;
            }

            if epoch % 10 == 0 {
                tracing::debug!(
                    "Epoch {}: train_loss={:.4}, val_loss={:.4}",
                    epoch,
                    train_loss,
                    val_loss
                );
            }
        }

        let final_accuracy = self.compute_accuracy(weights, biases, eval_examples, total);

        Ok(TrainingStats {
            epochs_completed,
            final_loss,
            validation_loss,
            final_accuracy,
            training_time_secs: start.elapsed().as_secs_f64(),
            early_stopped,
        })
    }

    /// Resets `weights`/`biases` and fills them with two layers of small
    /// pseudo-random values (`input_size * 32` and `32 * 1` weights, plus one
    /// bias per layer), drawn from a fixed seed so runs are reproducible.
    fn initialize_weights(
        &self,
        weights: &mut Vec<Vec<f32>>,
        biases: &mut Vec<f32>,
        input_size: usize,
    ) {
        const HIDDEN_SIZE: usize = 32;
        const OUTPUT_SIZE: usize = 1;

        let mut rng = Random::seed(42);

        // Start from a clean slate so repeated `train` calls do not keep
        // appending layers.
        weights.clear();
        biases.clear();

        let layer1: Vec<f32> = (0..input_size * HIDDEN_SIZE)
            .map(|_| rng.gen_range(-0.1..0.1))
            .collect();
        weights.push(layer1);

        let layer2: Vec<f32> = (0..HIDDEN_SIZE * OUTPUT_SIZE)
            .map(|_| rng.gen_range(-0.1..0.1))
            .collect();
        weights.push(layer2);

        biases.push(rng.gen_range(-0.1..0.1));
        biases.push(rng.gen_range(-0.1..0.1));
    }

    /// Runs one pass over `examples` in shuffled mini-batches, updating the
    /// parameters after every example, and returns the mean squared error.
    ///
    /// `position_scale` is the divisor that maps `target_position` to the
    /// normalized target; `epoch` varies the shuffle seed so each epoch visits
    /// the examples in a different (but reproducible) order. Returns `0.0`
    /// for an empty slice.
    fn train_epoch(
        &self,
        weights: &mut [Vec<f32>],
        biases: &mut [f32],
        examples: &[TrainingExample],
        position_scale: usize,
        epoch: usize,
    ) -> LearnedIndexResult<f64> {
        if examples.is_empty() {
            return Ok(0.0);
        }

        // Fisher-Yates shuffle with an epoch-dependent seed.
        let mut rng = Random::seed(42_u64.wrapping_add(epoch as u64));
        let mut indices: Vec<usize> = (0..examples.len()).collect();
        for i in (1..indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            indices.swap(i, j);
        }

        let scale = position_scale.max(1) as f32;
        // A zero batch size would make chunking panic; treat it as 1.
        let batch_size = self.config.batch_size.max(1);

        let mut total_loss = 0.0_f64;
        for batch in indices.chunks(batch_size) {
            let mut batch_loss = 0.0_f64;
            for &idx in batch {
                let example = &examples[idx];
                let prediction = self.forward_simple(&example.features, weights, biases);
                let target = example.target_position as f32 / scale;
                let error = prediction - target;
                batch_loss += f64::from(error * error);
                // Derivative of the squared error with respect to the prediction.
                self.update_weights(weights, biases, 2.0 * error, &example.features);
            }
            total_loss += batch_loss;
        }

        Ok(total_loss / examples.len() as f64)
    }

    /// Returns the mean squared error of the current predictions on
    /// `examples`, with targets normalized by `position_scale`. Returns `0.0`
    /// for an empty slice.
    fn validate(
        &self,
        weights: &[Vec<f32>],
        biases: &[f32],
        examples: &[TrainingExample],
        position_scale: usize,
    ) -> LearnedIndexResult<f64> {
        if examples.is_empty() {
            return Ok(0.0);
        }

        let scale = position_scale.max(1) as f32;
        let total_loss: f64 = examples
            .iter()
            .map(|example| {
                let prediction = self.forward_simple(&example.features, weights, biases);
                let error = prediction - example.target_position as f32 / scale;
                f64::from(error * error)
            })
            .sum();

        Ok(total_loss / examples.len() as f64)
    }

    /// Returns the fraction of `examples` whose predicted position is within
    /// 10% of `position_scale` from the target position. Returns `0.0` for an
    /// empty slice.
    fn compute_accuracy(
        &self,
        weights: &[Vec<f32>],
        biases: &[f32],
        examples: &[TrainingExample],
        position_scale: usize,
    ) -> f64 {
        const TOLERANCE: f32 = 0.1;

        if examples.is_empty() {
            return 0.0;
        }

        let scale = position_scale.max(1) as f32;
        let correct = examples
            .iter()
            .filter(|example| {
                let prediction = self.forward_simple(&example.features, weights, biases);
                let predicted_pos = (prediction * scale) as usize;
                let error = predicted_pos.abs_diff(example.target_position) as f32;
                error / scale < TOLERANCE
            })
            .count();

        correct as f64 / examples.len() as f64
    }

    /// Computes the prediction for one feature vector: the logistic sigmoid of
    /// the mean feature value.
    ///
    /// NOTE(review): the weights and biases are currently not used by this
    /// forward pass. An empty `input` is treated as a mean of `0.0`.
    fn forward_simple(&self, input: &[f32], _weights: &[Vec<f32>], _biases: &[f32]) -> f32 {
        let mean = if input.is_empty() {
            0.0
        } else {
            input.iter().sum::<f32>() / input.len() as f32
        };
        1.0 / (1.0 + (-mean).exp())
    }

    /// Subtracts the same step, `learning_rate * 0.1 * gradient`, from every
    /// weight and bias.
    ///
    /// NOTE(review): the example features are not used, so all parameters move
    /// by an identical amount.
    fn update_weights(
        &self,
        weights: &mut [Vec<f32>],
        biases: &mut [f32],
        gradient: f32,
        _features: &[f32],
    ) {
        let step = self.config.learning_rate * 0.1 * gradient;
        for weight in weights.iter_mut().flatten() {
            *weight -= step;
        }
        for bias in biases.iter_mut() {
            *bias -= step;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Result alias that lets tests propagate any error with `?`.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Builds `n` single-feature examples where example `i` has the feature
    /// `i / n` and the target position `i`.
    fn create_test_examples(n: usize) -> Vec<TrainingExample> {
        let mut examples = Vec::with_capacity(n);
        for position in 0..n {
            let feature = position as f32 / n as f32;
            examples.push(TrainingExample::new(vec![feature], position));
        }
        examples
    }

    /// A trainer built from the default configuration keeps a usable epoch count.
    #[test]
    fn test_trainer_creation() {
        let trainer = IndexTrainer::new(TrainingConfig::default_config());
        assert!(trainer.config.num_epochs > 0);
    }

    /// Training on a small synthetic data set succeeds and reports sane statistics.
    #[test]
    fn test_training() -> Result<()> {
        let trainer = IndexTrainer::new(TrainingConfig::speed_optimized());
        let examples = create_test_examples(100);
        let mut weights: Vec<Vec<f32>> = Vec::new();
        let mut biases: Vec<f32> = Vec::new();

        let outcome = trainer.train(&mut weights, &mut biases, &examples);
        assert!(outcome.is_ok());
        let stats = outcome?;

        assert!(stats.epochs_completed > 0);
        assert!(stats.final_loss >= 0.0);
        assert!((0.0..=1.0).contains(&stats.final_accuracy));
        Ok(())
    }
}