use std::collections::HashMap;

use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Graph, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};

use crate::error::Result;
use crate::ops::FusedOptimizerOps;
use crate::optimizer::{
    AdamW, AdamWConfig, GradAccumulator, LrSchedule, Optimizer, clip_grad_norm,
};
use crate::trainer::config::{TrainingConfig, TrainingMetrics};
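/// Single-optimizer training driver: accumulates gradients over micro-batches,
/// optionally applies an LR schedule and gradient clipping, performs the
/// optimizer step, and reports per-step metrics. Forward and backward passes
/// can also be captured as runtime graphs and replayed.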
pub struct SimpleTrainer<R: Runtime<DType = DType>, O: Optimizer<R> = AdamW<R>> {
optimizer: O,
accumulator: GradAccumulator<R>,
lr_schedule: Option<LrSchedule>,
max_grad_norm: Option<f64>,
global_step: u64,
accumulated_loss: f64,
loss_count: usize,
forward_graph: Option<R::Graph>,
backward_graph: Option<R::Graph>,
}
impl<R: Runtime<DType = DType>> SimpleTrainer<R, AdamW<R>> {
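    /// Builds a trainer with an [`AdamW`] optimizer configured from
    /// `config.learning_rate` and `config.weight_decay`.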
pub fn new(config: TrainingConfig) -> Result<Self> {
let optimizer = AdamW::new(AdamWConfig {
lr: config.learning_rate,
weight_decay: config.weight_decay,
..AdamWConfig::default()
});
Self::with_optimizer(config, optimizer)
}
}
impl<R: Runtime<DType = DType>, O: Optimizer<R>> SimpleTrainer<R, O> {
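    /// Builds a trainer around a caller-supplied optimizer, taking the
    /// gradient-accumulation window and optional gradient-clipping norm
    /// from `config`.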
pub fn with_optimizer(config: TrainingConfig, optimizer: O) -> Result<Self> {
let accumulator = GradAccumulator::new(config.grad_accum_steps)?;
Ok(Self {
optimizer,
accumulator,
lr_schedule: None,
max_grad_norm: config.max_grad_norm,
global_step: 0,
accumulated_loss: 0.0,
loss_count: 0,
forward_graph: None,
backward_graph: None,
})
}
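    /// Attaches a learning-rate schedule; it is evaluated at the current
    /// global step before every optimizer step. Builder-style, consumes `self`.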
pub fn with_lr_schedule(mut self, schedule: LrSchedule) -> Self {
self.lr_schedule = Some(schedule);
self
}
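    /// Feeds one micro-batch into the trainer: `grads` go to the gradient
    /// accumulator and `loss_value` to the running loss average. Returns
    /// `Ok(None)` until the accumulation window is full; once it is, the
    /// averaged gradients are optionally clipped, the optimizer is stepped,
    /// and the metrics for the completed step are returned.
    ///
    /// A minimal sketch of the intended call pattern; `batches` and the
    /// `forward_backward` helper are illustrative assumptions, not part of
    /// this crate:
    ///
    /// ```ignore
    /// for batch in batches {
    ///     // Hypothetical helper: runs the model and returns (loss, GradStore).
    ///     let (loss, grads) = forward_backward(&client, &params, &batch)?;
    ///     if let Some(m) = trainer.step(&client, &mut params, grads, loss)? {
    ///         println!("step {} loss {:.4} lr {:.2e}", m.step, m.loss, m.lr);
    ///     }
    /// }
    /// ```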
pub fn step<C>(
&mut self,
client: &C,
params: &mut HashMap<TensorId, Tensor<R>>,
grads: GradStore<R>,
loss_value: f64,
) -> Result<Option<TrainingMetrics>>
where
C: RuntimeClient<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ FusedOptimizerOps<R>,
{
self.accumulated_loss += loss_value;
self.loss_count += 1;
let averaged_grads = match self.accumulator.accumulate(client, grads)? {
Some(g) => g,
None => return Ok(None),
};
if let Some(ref schedule) = self.lr_schedule {
let lr = schedule.get_lr(self.global_step);
self.optimizer.set_lr(lr);
}
let grad_norm = if let Some(max_norm) = self.max_grad_norm {
let mut grads_mut = averaged_grads;
let norm = clip_grad_norm(client, &mut grads_mut, max_norm)?;
self.optimizer.step(client, params, &grads_mut)?;
Some(norm)
} else {
self.optimizer.step(client, params, &averaged_grads)?;
None
};
let avg_loss = self.accumulated_loss / self.loss_count as f64;
self.accumulated_loss = 0.0;
self.loss_count = 0;
self.global_step += 1;
Ok(Some(TrainingMetrics {
step: self.global_step,
loss: avg_loss,
grad_norm,
lr: self.optimizer.lr(),
}))
}
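    /// Number of optimizer steps completed so far (micro-batches that only
    /// accumulated gradients do not advance this counter).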
pub fn global_step(&self) -> u64 {
self.global_step
}
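    /// Shared reference to the underlying optimizer.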
pub fn optimizer(&self) -> &O {
&self.optimizer
}
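    /// Mutable reference to the underlying optimizer, e.g. for adjusting
    /// hyperparameters between steps.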
pub fn optimizer_mut(&mut self) -> &mut O {
&mut self.optimizer
}
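    /// Runs `f` under runtime graph capture and stores the resulting graph so
    /// the forward pass can be replayed later via `launch_forward_graph`.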
pub fn capture_forward_pass<F, T>(&mut self, client: &R::Client, f: F) -> Result<T>
where
F: FnOnce(&R::Client) -> numr::error::Result<T>,
{
let (graph, result) = R::capture_graph(client, f)?;
self.forward_graph = Some(graph);
Ok(result)
}
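    /// Runs `f` under runtime graph capture and stores the resulting graph so
    /// the backward pass can be replayed later via `launch_backward_graph`.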
pub fn capture_backward_pass<F, T>(&mut self, client: &R::Client, f: F) -> Result<T>
where
F: FnOnce(&R::Client) -> numr::error::Result<T>,
{
let (graph, result) = R::capture_graph(client, f)?;
self.backward_graph = Some(graph);
Ok(result)
}
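    /// Replays the captured forward graph; errors if `capture_forward_pass`
    /// has not been called.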
pub fn launch_forward_graph(&self) -> Result<()> {
match &self.forward_graph {
Some(g) => Ok(g.launch()?),
None => Err(crate::error::Error::TrainingError {
reason: "no forward graph captured — call capture_forward_pass first".into(),
}),
}
}
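    /// Replays the captured backward graph; errors if `capture_backward_pass`
    /// has not been called.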
pub fn launch_backward_graph(&self) -> Result<()> {
match &self.backward_graph {
Some(g) => Ok(g.launch()?),
None => Err(crate::error::Error::TrainingError {
reason: "no backward graph captured — call capture_backward_pass first".into(),
}),
}
}
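    /// Returns whether the (forward, backward) graphs have been captured.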
pub fn graphs_captured(&self) -> (bool, bool) {
(self.forward_graph.is_some(), self.backward_graph.is_some())
}
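    /// Drops any captured forward/backward graphs so that fresh ones can be
    /// captured.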
pub fn clear_graphs(&mut self) {
self.forward_graph = None;
self.backward_graph = None;
}
}