// boostr 0.1.0
//
// ML framework built on numr - attention, quantization, model architectures.
//! Simple trainer for neural networks
//!
//! Integrates an optimizer, gradient accumulation, gradient clipping,
//! and LR scheduling into a single training step abstraction.

use std::collections::HashMap;

use crate::error::Result;
use crate::optimizer::{
    AdamW, AdamWConfig, GradAccumulator, LrSchedule, Optimizer, clip_grad_norm,
};
use crate::trainer::config::{TrainingConfig, TrainingMetrics};
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};

use crate::ops::FusedOptimizerOps;
use numr::runtime::{Graph, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};

/// Simple single-device trainer
///
/// Manages the training loop plumbing:
/// - Gradient accumulation across micro-batches
/// - Gradient clipping by global norm
/// - Learning rate scheduling
/// - Optimizer step (any optimizer implementing the `Optimizer` trait)
///
/// Generic over the optimizer type `O`. Use `SimpleTrainer::new()` for the default
/// AdamW optimizer, or `SimpleTrainer::with_optimizer()` for a custom one.
///
/// # Usage
///
/// ```ignore
/// let mut trainer = SimpleTrainer::new(config)?;
///
/// for micro_batch in data {
///     let loss = forward(micro_batch, &params);
///     let grads = backward(&loss, &client)?;
///
///     if let Some(metrics) = trainer.step(&client, &mut params, grads, loss_val)? {
///         println!("step {} loss={:.4} lr={:.6}", metrics.step, metrics.loss, metrics.lr);
///     }
/// }
/// ```
pub struct SimpleTrainer<R: Runtime<DType = DType>, O: Optimizer<R> = AdamW<R>> {
    /// Optimizer applied once per full (accumulated) step.
    optimizer: O,
    /// Accumulates micro-batch gradients until a full step's worth is ready.
    accumulator: GradAccumulator<R>,
    /// Optional schedule queried with the current `global_step` before each optimizer step.
    lr_schedule: Option<LrSchedule>,
    /// Clip gradients to this global norm before stepping; `None` disables clipping.
    max_grad_norm: Option<f64>,
    /// Number of completed optimizer steps (not micro-batches).
    global_step: u64,
    /// Sum of per-micro-batch loss values since the last optimizer step.
    accumulated_loss: f64,
    /// Number of loss values folded into `accumulated_loss` (divisor for the mean).
    accumulated_loss: f64 is averaged over this count when metrics are emitted.
    loss_count: usize,
    /// Captured forward-pass graph for replay, if any (CUDA; no-op elsewhere).
    forward_graph: Option<R::Graph>,
    /// Captured backward-pass graph for replay, if any (CUDA; no-op elsewhere).
    backward_graph: Option<R::Graph>,
}

impl<R: Runtime<DType = DType>> SimpleTrainer<R, AdamW<R>> {
    /// Create a new trainer with the default AdamW optimizer.
    ///
    /// Only `learning_rate` and `weight_decay` are taken from `config`;
    /// every other AdamW hyperparameter keeps its default value.
    pub fn new(config: TrainingConfig) -> Result<Self> {
        let adamw_config = AdamWConfig {
            lr: config.learning_rate,
            weight_decay: config.weight_decay,
            ..Default::default()
        };
        Self::with_optimizer(config, AdamW::new(adamw_config))
    }
}

impl<R: Runtime<DType = DType>, O: Optimizer<R>> SimpleTrainer<R, O> {
    /// Create a new trainer with a custom optimizer.
    pub fn with_optimizer(config: TrainingConfig, optimizer: O) -> Result<Self> {
        let accumulator = GradAccumulator::new(config.grad_accum_steps)?;

        Ok(Self {
            optimizer,
            accumulator,
            lr_schedule: None,
            max_grad_norm: config.max_grad_norm,
            global_step: 0,
            accumulated_loss: 0.0,
            loss_count: 0,
            forward_graph: None,
            backward_graph: None,
        })
    }

    pub fn with_lr_schedule(mut self, schedule: LrSchedule) -> Self {
        self.lr_schedule = Some(schedule);
        self
    }

    /// Process one micro-batch of gradients.
    ///
    /// Accumulates gradients. When enough micro-batches are accumulated,
    /// clips gradients, applies the optimizer step, and returns metrics.
    ///
    /// Returns `None` if still accumulating, `Some(metrics)` after a full step.
    pub fn step<C>(
        &mut self,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
        grads: GradStore<R>,
        loss_value: f64,
    ) -> Result<Option<TrainingMetrics>>
    where
        C: RuntimeClient<R>
            + BinaryOps<R>
            + UnaryOps<R>
            + ScalarOps<R>
            + ReduceOps<R>
            + FusedOptimizerOps<R>,
    {
        self.accumulated_loss += loss_value;
        self.loss_count += 1;

        let averaged_grads = match self.accumulator.accumulate(client, grads)? {
            Some(g) => g,
            None => return Ok(None),
        };

        // Apply LR schedule
        if let Some(ref schedule) = self.lr_schedule {
            let lr = schedule.get_lr(self.global_step);
            self.optimizer.set_lr(lr);
        }

        // Gradient clipping
        let grad_norm = if let Some(max_norm) = self.max_grad_norm {
            let mut grads_mut = averaged_grads;
            let norm = clip_grad_norm(client, &mut grads_mut, max_norm)?;
            self.optimizer.step(client, params, &grads_mut)?;
            Some(norm)
        } else {
            self.optimizer.step(client, params, &averaged_grads)?;
            None
        };

        let avg_loss = self.accumulated_loss / self.loss_count as f64;
        self.accumulated_loss = 0.0;
        self.loss_count = 0;

        self.global_step += 1;

        Ok(Some(TrainingMetrics {
            step: self.global_step,
            loss: avg_loss,
            grad_norm,
            lr: self.optimizer.lr(),
        }))
    }

    pub fn global_step(&self) -> u64 {
        self.global_step
    }

    pub fn optimizer(&self) -> &O {
        &self.optimizer
    }

    pub fn optimizer_mut(&mut self) -> &mut O {
        &mut self.optimizer
    }

    /// Capture a forward pass into a CUDA graph for replay.
    ///
    /// On CPU/WebGPU the closure executes eagerly and `launch_forward_graph`
    /// becomes a no-op. On CUDA the recorded kernel sequence is replayed
    /// with a single `cuGraphLaunch`, eliminating per-kernel CPU overhead.
    pub fn capture_forward_pass<F, T>(&mut self, client: &R::Client, f: F) -> Result<T>
    where
        F: FnOnce(&R::Client) -> numr::error::Result<T>,
    {
        let (graph, result) = R::capture_graph(client, f)?;
        self.forward_graph = Some(graph);
        Ok(result)
    }

    /// Capture a backward pass into a CUDA graph for replay.
    pub fn capture_backward_pass<F, T>(&mut self, client: &R::Client, f: F) -> Result<T>
    where
        F: FnOnce(&R::Client) -> numr::error::Result<T>,
    {
        let (graph, result) = R::capture_graph(client, f)?;
        self.backward_graph = Some(graph);
        Ok(result)
    }

    /// Replay the captured forward graph.
    ///
    /// Returns an error if no forward graph has been captured.
    pub fn launch_forward_graph(&self) -> Result<()> {
        match &self.forward_graph {
            Some(g) => Ok(g.launch()?),
            None => Err(crate::error::Error::TrainingError {
                reason: "no forward graph captured — call capture_forward_pass first".into(),
            }),
        }
    }

    /// Replay the captured backward graph.
    ///
    /// Returns an error if no backward graph has been captured.
    pub fn launch_backward_graph(&self) -> Result<()> {
        match &self.backward_graph {
            Some(g) => Ok(g.launch()?),
            None => Err(crate::error::Error::TrainingError {
                reason: "no backward graph captured — call capture_backward_pass first".into(),
            }),
        }
    }

    /// Check whether forward and backward graphs have been captured.
    pub fn graphs_captured(&self) -> (bool, bool) {
        (self.forward_graph.is_some(), self.backward_graph.is_some())
    }

    /// Drop both captured graphs, reverting to eager execution.
    pub fn clear_graphs(&mut self) {
        self.forward_graph = None;
        self.backward_graph = None;
    }
}