use std::{collections::HashMap, marker::PhantomData, time::Instant};
use smol_str::format_smolstr;
use crate::{
Array, Dtype, Result,
error::{
EmptyInputPayload, Error, InvariantViolationPayload, LengthMismatchPayload, MissingKeyPayload,
OutOfRangePayload, RankMismatchPayload, ShapePairMismatchPayload,
},
lm::{
cache::KvCache,
load::Weights,
model::Model,
perplexity,
tuner::{
datasets::{Dataset, Example},
optimizers::Optimizer,
},
},
ops::{arithmetic, comparison, logical, reduction},
transforms,
};
/// Configuration consumed by [`train`] and its validation helper.
///
/// Build one with [`TrainingArgs::new`] (or `Default`) and adjust individual
/// settings through the chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct TrainingArgs {
    /// Number of examples per batch handed to the batch iterator.
    batch_size: usize,
    /// Number of micro-batch iterations the training loop runs.
    iters: usize,
    /// Cap on the number of validation batches; `None` means no cap.
    val_batches: Option<usize>,
    /// Emit a training report every this many optimizer steps.
    steps_per_report: usize,
    /// Run validation every this many optimizer steps.
    steps_per_eval: usize,
    /// Invoke the save callback every this many optimizer steps.
    steps_per_save: usize,
    /// Upper bound on the padded sequence length of a batch.
    max_seq_length: usize,
    /// Adapter file name forwarded to the save callback.
    adapter_file: String,
    /// Gradient-checkpointing flag.
    /// NOTE(review): not read by the training loop in this file — confirm the consumer.
    grad_checkpoint: bool,
    /// Number of micro-batches accumulated per optimizer step.
    grad_accumulation_steps: usize,
    /// Cache-clearing threshold.
    /// NOTE(review): not read by the training loop in this file — confirm the consumer.
    clear_cache_threshold: usize,
    /// Must be `true` for `train` to run; it returns an error otherwise.
    acknowledge_no_real_gradients: bool,
}

impl TrainingArgs {
    /// Creates the default configuration (identical to `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the batch size.
    #[inline(always)]
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns the number of micro-batch iterations.
    #[inline(always)]
    pub fn iters(&self) -> usize {
        self.iters
    }

    /// Returns the validation batch cap, if any.
    #[inline(always)]
    pub fn val_batches(&self) -> Option<usize> {
        self.val_batches
    }

    /// Returns the reporting cadence in optimizer steps.
    #[inline(always)]
    pub fn steps_per_report(&self) -> usize {
        self.steps_per_report
    }

    /// Returns the validation cadence in optimizer steps.
    #[inline(always)]
    pub fn steps_per_eval(&self) -> usize {
        self.steps_per_eval
    }

    /// Returns the save cadence in optimizer steps.
    #[inline(always)]
    pub fn steps_per_save(&self) -> usize {
        self.steps_per_save
    }

    /// Returns the maximum sequence length.
    #[inline(always)]
    pub fn max_seq_length(&self) -> usize {
        self.max_seq_length
    }

    /// Returns the adapter file name.
    #[inline(always)]
    pub fn adapter_file(&self) -> &str {
        &self.adapter_file
    }

    /// Returns the gradient-checkpointing flag.
    #[inline(always)]
    pub fn grad_checkpoint(&self) -> bool {
        self.grad_checkpoint
    }

    /// Returns the number of micro-batches per optimizer step.
    #[inline(always)]
    pub fn grad_accumulation_steps(&self) -> usize {
        self.grad_accumulation_steps
    }

    /// Returns the cache-clearing threshold.
    #[inline(always)]
    pub fn clear_cache_threshold(&self) -> usize {
        self.clear_cache_threshold
    }

    /// Returns whether the caller acknowledged the no-real-gradients mode.
    #[inline(always)]
    pub fn acknowledge_no_real_gradients(&self) -> bool {
        self.acknowledge_no_real_gradients
    }

    /// Returns the configuration with `batch_size` replaced.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Returns the configuration with `iters` replaced.
    #[must_use]
    pub fn with_iters(self, iters: usize) -> Self {
        Self { iters, ..self }
    }

    /// Returns the configuration with `val_batches` replaced.
    #[must_use]
    pub fn with_val_batches(self, val_batches: Option<usize>) -> Self {
        Self { val_batches, ..self }
    }

    /// Returns the configuration with `steps_per_report` replaced.
    #[must_use]
    pub fn with_steps_per_report(self, steps_per_report: usize) -> Self {
        Self {
            steps_per_report,
            ..self
        }
    }

    /// Returns the configuration with `steps_per_eval` replaced.
    #[must_use]
    pub fn with_steps_per_eval(self, steps_per_eval: usize) -> Self {
        Self {
            steps_per_eval,
            ..self
        }
    }

    /// Returns the configuration with `steps_per_save` replaced.
    #[must_use]
    pub fn with_steps_per_save(self, steps_per_save: usize) -> Self {
        Self {
            steps_per_save,
            ..self
        }
    }

    /// Returns the configuration with `max_seq_length` replaced.
    #[must_use]
    pub fn with_max_seq_length(self, max_seq_length: usize) -> Self {
        Self {
            max_seq_length,
            ..self
        }
    }

    /// Returns the configuration with `adapter_file` replaced.
    #[must_use]
    pub fn with_adapter_file(self, adapter_file: impl Into<String>) -> Self {
        Self {
            adapter_file: adapter_file.into(),
            ..self
        }
    }

    /// Returns the configuration with `grad_checkpoint` replaced.
    #[must_use]
    pub fn with_grad_checkpoint(self, grad_checkpoint: bool) -> Self {
        Self {
            grad_checkpoint,
            ..self
        }
    }

    /// Returns the configuration with `grad_accumulation_steps` replaced.
    #[must_use]
    pub fn with_grad_accumulation_steps(self, grad_accumulation_steps: usize) -> Self {
        Self {
            grad_accumulation_steps,
            ..self
        }
    }

    /// Returns the configuration with `clear_cache_threshold` replaced.
    #[must_use]
    pub fn with_clear_cache_threshold(self, clear_cache_threshold: usize) -> Self {
        Self {
            clear_cache_threshold,
            ..self
        }
    }

    /// Returns the configuration with `acknowledge_no_real_gradients` replaced.
    #[must_use]
    pub fn with_acknowledge_no_real_gradients(self, acknowledge_no_real_gradients: bool) -> Self {
        Self {
            acknowledge_no_real_gradients,
            ..self
        }
    }
}

impl Default for TrainingArgs {
    /// Produces the stock configuration.
    fn default() -> Self {
        Self {
            batch_size: 4,
            iters: 100,
            val_batches: Some(25),
            steps_per_report: 10,
            steps_per_eval: 200,
            steps_per_save: 100,
            max_seq_length: 2048,
            adapter_file: String::from("adapters.safetensors"),
            grad_checkpoint: false,
            grad_accumulation_steps: 1,
            clear_cache_threshold: 0,
            acknowledge_no_real_gradients: false,
        }
    }
}
pub fn default_loss<M>(model: &M, batch: &Array, lengths: &Array) -> Result<(Array, Array)>
where
M: Model,
{
let shape = batch.shape();
let (_b, s) = match shape.as_slice() {
[b, s] => (*b, *s),
other => {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"default_loss: batch must be rank-2 [B, S]",
other.len() as u32,
other.to_vec(),
)));
}
};
if s < 2 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"default_loss: batch S",
"must be >= 2 for next-token prediction",
format_smolstr!("{s}"),
)));
}
let lengths_shape = lengths.shape();
let expected_lengths_shape = [shape[0], 2_usize];
if lengths_shape.as_slice() != expected_lengths_shape {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"default_loss: lengths must be [B, 2] = (offset, length)",
expected_lengths_shape.to_vec(),
lengths_shape.to_vec(),
)));
}
let b_dim = shape[0] as i32;
let s_dim = s as i32;
let inputs = crate::ops::indexing::slice(batch, &[0, 0], &[b_dim, s_dim - 1], &[1, 1])?;
let targets = crate::ops::indexing::slice(batch, &[0, 1], &[b_dim, s_dim], &[1, 1])?;
let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
let logits = model.forward(&inputs, &mut cache)?;
let t_dim = targets.shape()[1] as f32;
let steps = Array::arange::<f32>(1.0, t_dim + 1.0, 1.0)?;
let offset = crate::ops::indexing::slice(lengths, &[0, 0], &[b_dim, 1], &[1, 1])?;
let length = crate::ops::indexing::slice(lengths, &[0, 1], &[b_dim, 2], &[1, 1])?;
let offset_f = offset.astype(Dtype::F32)?;
let length_f = length.astype(Dtype::F32)?;
let ge = comparison::greater_equal(&steps, &offset_f)?;
let lt = comparison::less(&steps, &length_f)?;
let mask = logical::logical_and(&ge, <)?;
let ce = perplexity::cross_entropy_none(&logits, &targets)?;
let mask_f = mask.astype(Dtype::F32)?;
let ce_masked = arithmetic::multiply(&ce, &mask_f)?;
let mut ntoks = reduction::sum(&mask_f, false)?;
let ntoks_count = ntoks.item::<f32>()?;
if ntoks_count == 0.0 {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"default_loss: supervised tokens after the length mask (batch produced 0 supervised tokens)",
)));
}
let ce_sum = reduction::sum(&ce_masked.astype(Dtype::F32)?, false)?;
let loss = arithmetic::divide(&ce_sum, &ntoks)?;
Ok((loss, ntoks))
}
/// Wraps `f` by delegating to `transforms::checkpoint::checkpoint`.
///
/// The returned closure has the same call shape as `f` (a slice of arrays in,
/// a vector of arrays out). What the wrapper changes about evaluation is
/// defined by `transforms::checkpoint::checkpoint`, not here — presumably
/// gradient checkpointing; confirm against that function's documentation.
///
/// # Errors
///
/// Returns whatever error `transforms::checkpoint::checkpoint` reports.
pub fn grad_checkpoint<F>(f: F) -> Result<impl Fn(&[Array]) -> Result<Vec<Array>>>
where
F: Fn(&[Array]) -> Result<Vec<Array>> + 'static,
{
transforms::checkpoint::checkpoint(f)
}
/// Hooks invoked by the training loop in this file.
///
/// Every method has a do-nothing default, so implementors override only the
/// events they care about.
pub trait TrainingCallback {
/// Called with the windowed training statistics each time `train` reports.
fn on_train_loss_report(&mut self, _info: &TrainInfo) {}
/// Called with the result of each validation pass.
fn on_val_loss_report(&mut self, _info: &ValInfo) {}
/// Called when the training loop reaches a save point, with the optimizer
/// step count and the configured adapter file name.
///
/// An `Err` returned here is propagated out of `train`.
fn on_save(&mut self, _it: usize, _adapter_file: &str) -> Result<()> {
Ok(())
}
}
/// Training statistics handed to [`TrainingCallback::on_train_loss_report`].
///
/// The field descriptions below reflect how `train` in this file fills the
/// struct; other callers of [`TrainInfo::new`] may supply anything.
#[derive(Debug, Clone)]
pub struct TrainInfo {
// Optimizer step count at the time of the report.
iteration: usize,
// Mean per-micro-batch loss over the reporting window.
train_loss: f32,
// Learning rate reported by the optimizer.
learning_rate: f32,
// Optimizer steps per second over the reporting window.
iterations_per_second: f32,
// Supervised tokens per second over the reporting window.
tokens_per_second: f32,
// Cumulative supervised-token count since training started.
trained_tokens: usize,
}
impl TrainInfo {
/// Bundles the given values without validation.
pub fn new(
iteration: usize,
train_loss: f32,
learning_rate: f32,
iterations_per_second: f32,
tokens_per_second: f32,
trained_tokens: usize,
) -> Self {
Self {
iteration,
train_loss,
learning_rate,
iterations_per_second,
tokens_per_second,
trained_tokens,
}
}
/// Returns the iteration number of this report.
#[inline(always)]
pub fn iteration(&self) -> usize {
self.iteration
}
/// Returns the reported training loss.
#[inline(always)]
pub fn train_loss(&self) -> f32 {
self.train_loss
}
/// Returns the reported learning rate.
#[inline(always)]
pub fn learning_rate(&self) -> f32 {
self.learning_rate
}
/// Returns the reported iterations-per-second rate.
#[inline(always)]
pub fn iterations_per_second(&self) -> f32 {
self.iterations_per_second
}
/// Returns the reported tokens-per-second rate.
#[inline(always)]
pub fn tokens_per_second(&self) -> f32 {
self.tokens_per_second
}
/// Returns the cumulative trained-token count.
#[inline(always)]
pub fn trained_tokens(&self) -> usize {
self.trained_tokens
}
}
/// Validation statistics handed to [`TrainingCallback::on_val_loss_report`].
///
/// The field descriptions below reflect how `run_val` in this file fills the
/// struct; other callers of [`ValInfo::new`] may supply anything.
#[derive(Debug, Clone)]
pub struct ValInfo {
// Optimizer step count at which validation ran (0 for the pre-training pass).
iteration: usize,
// Loss returned by `evaluate`.
val_loss: f32,
// Wall-clock duration of the validation pass, in seconds.
val_time: f32,
}
impl ValInfo {
/// Bundles the given values without validation.
pub fn new(iteration: usize, val_loss: f32, val_time: f32) -> Self {
Self {
iteration,
val_loss,
val_time,
}
}
/// Returns the iteration number of this validation pass.
#[inline(always)]
pub fn iteration(&self) -> usize {
self.iteration
}
/// Returns the validation loss.
#[inline(always)]
pub fn val_loss(&self) -> f32 {
self.val_loss
}
/// Returns the validation duration.
#[inline(always)]
pub fn val_time(&self) -> f32 {
self.val_time
}
}
/// A [`TrainingCallback`] that keeps every default hook, i.e. ignores all
/// events and always succeeds on save.
pub struct NoopCallback;
impl TrainingCallback for NoopCallback {}
/// One training/evaluation batch: a token array paired with a lengths array.
///
/// As produced by `build_batch` in this file, `tokens` has shape
/// `(batch_size, padded_len)` and `lengths` has shape `(batch_size, 2)` with
/// `(offset, truncated_length)` per row. [`Batch::new`] itself does not check
/// shapes.
pub struct Batch {
tokens: Array,
lengths: Array,
// Zero-sized marker field.
// NOTE(review): its purpose is not evident from this file — confirm before removing.
_marker: PhantomData<()>,
}
impl Batch {
/// Wraps the two arrays without validation.
pub fn new(tokens: Array, lengths: Array) -> Self {
Self {
tokens,
lengths,
_marker: PhantomData,
}
}
/// Borrows the token array.
#[inline(always)]
pub fn tokens_ref(&self) -> &Array {
&self.tokens
}
/// Borrows the lengths array.
#[inline(always)]
pub fn lengths_ref(&self) -> &Array {
&self.lengths
}
}
/// Builds an iterator over length-sorted, fixed-size batches of `dataset`.
///
/// Every example is processed once up front to learn its token count; example
/// indices are then sorted by ascending length (stable, so ties keep dataset
/// order) and cut into consecutive groups of `batch_size`. Examples left over
/// after the last full group are dropped.
///
/// * `max_seq_length` — upper bound on the padded width of each batch.
/// * `loop_forever` — when `true` the iterator restarts after each pass
///   instead of ending.
/// * `shuffle_seed` — when `Some`, the batch order is shuffled on every pass,
///   seeded deterministically; when `None`, batches come in sorted order.
///
/// # Errors
///
/// * `OutOfRange` when `batch_size` is zero (previously this panicked on an
///   integer division by zero).
/// * `OutOfRange` when the dataset holds fewer than `batch_size` examples.
/// * Any error returned by `Dataset::process` during the length scan.
pub fn iterate_batches<'a, D: Dataset + 'a>(
    dataset: &'a D,
    batch_size: usize,
    max_seq_length: usize,
    loop_forever: bool,
    shuffle_seed: Option<u64>,
) -> Result<impl Iterator<Item = Result<Batch>> + 'a> {
    // Guard first: a zero batch size would slip past the size check below
    // (`len < 0` is never true) and then divide by zero while chunking.
    if batch_size == 0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "iterate_batches: batch_size",
            "must be >= 1",
            format_smolstr!("{batch_size}"),
        )));
    }
    let example_count = dataset.len();
    if example_count < batch_size {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "iterate_batches: dataset size",
            "must be >= batch_size",
            format_smolstr!("{} (batch_size={batch_size})", example_count),
        )));
    }

    let lens: Vec<usize> = (0..example_count)
        .map(|i| dataset.process(i).map(|(toks, _)| toks.len()))
        .collect::<Result<_>>()?;

    // Stable sort keeps the batch composition deterministic for equal lengths.
    let mut idx: Vec<usize> = (0..example_count).collect();
    idx.sort_by_key(|&i| lens[i]);

    // `chunks_exact` yields exactly `example_count / batch_size` full groups
    // and silently leaves out the remainder.
    let batch_idx: Vec<Vec<usize>> = idx
        .chunks_exact(batch_size)
        .map(<[usize]>::to_vec)
        .collect();

    Ok(BatchIter {
        dataset,
        batch_idx,
        max_seq_length,
        cursor: 0,
        order: Vec::new(),
        loop_forever,
        shuffle_seed,
        rng_state: shuffle_seed,
        first_pass: true,
    })
}
/// Iterator state behind `iterate_batches`.
struct BatchIter<'a, D: Dataset> {
    /// Source of examples.
    dataset: &'a D,
    /// Example indices of every batch, fixed at construction.
    batch_idx: Vec<Vec<usize>>,
    /// Forwarded to `build_batch` for each batch.
    max_seq_length: usize,
    /// Position of the next entry of `order` to emit.
    cursor: usize,
    /// Batch slots of the current pass, in emission order.
    order: Vec<usize>,
    /// Start a new pass when the current one is exhausted.
    loop_forever: bool,
    /// Shuffling is enabled only when this is `Some`.
    shuffle_seed: Option<u64>,
    /// Seed for the next pass; bumped by one after each shuffle.
    rng_state: Option<u64>,
    /// `true` until the first pass has been set up.
    first_pass: bool,
}

impl<D: Dataset> Iterator for BatchIter<'_, D> {
    type Item = Result<Batch>;

    /// Emits the next batch, starting a new pass when the current one is used up.
    ///
    /// A new pass is started for the very first call and, afterwards, only when
    /// `loop_forever` is set. Returns `None` once no further pass is allowed or
    /// when there are no batches at all.
    fn next(&mut self) -> Option<Self::Item> {
        let pass_exhausted = self.cursor >= self.order.len();
        if pass_exhausted {
            let may_start_pass = self.first_pass || self.loop_forever;
            if !may_start_pass {
                return None;
            }
            self.first_pass = false;
            self.cursor = 0;
            self.order = (0..self.batch_idx.len()).collect();
            if let (Some(_), Some(seed)) = (self.shuffle_seed, self.rng_state) {
                fisher_yates_shuffle(&mut self.order, seed);
                // A different seed per pass gives a different order per pass.
                self.rng_state = Some(seed.wrapping_add(1));
            }
            if self.order.is_empty() {
                return None;
            }
        }

        let slot = self.order[self.cursor];
        self.cursor += 1;
        let batch = build_batch(self.dataset, &self.batch_idx[slot], self.max_seq_length);
        Some(batch)
    }
}
/// Materialises one padded batch from the examples at `indices`.
///
/// The batch width is one more than the longest example rounded up to a
/// multiple of 32, clamped to `max_seq_length`. Each row holds the (possibly
/// truncated) tokens followed by zero padding. The companion lengths array has
/// one `(offset, kept_length)` row per example, with the offset clamped to the
/// kept length.
///
/// # Errors
///
/// Propagates failures from `Dataset::process` and from array construction.
fn build_batch<D: Dataset>(dataset: &D, indices: &[usize], max_seq_length: usize) -> Result<Batch> {
    const PAD_MULTIPLE: usize = 32;

    let examples: Vec<Example> = indices
        .iter()
        .map(|&i| dataset.process(i))
        .collect::<Result<_>>()?;

    let longest = examples
        .iter()
        .map(|(toks, _)| toks.len())
        .max()
        .unwrap_or(0);
    let width = (1 + PAD_MULTIPLE * longest.div_ceil(PAD_MULTIPLE)).min(max_seq_length);

    let rows = examples.len();
    let mut token_buf = vec![0i32; rows * width];
    let mut span_buf = vec![0i32; rows * 2];

    for (row, (toks, offset)) in examples.iter().enumerate() {
        let kept = toks.len().min(max_seq_length).min(width);
        let start = row * width;
        let dst_row = &mut token_buf[start..start + kept];
        for (dst, &tok) in dst_row.iter_mut().zip(toks[..kept].iter()) {
            *dst = tok as i32;
        }
        span_buf[row * 2] = (*offset).min(kept) as i32;
        span_buf[row * 2 + 1] = kept as i32;
    }

    let tokens = Array::from_slice::<i32>(&token_buf, &(rows, width))?;
    let lengths_arr = Array::from_slice::<i32>(&span_buf, &(rows, 2usize))?;
    Ok(Batch::new(tokens, lengths_arr))
}
/// Shuffles `slice` in place with a Fisher–Yates pass driven by a SplitMix64
/// stream seeded with `seed`.
///
/// The result is fully determined by `seed` and the slice length. Slices with
/// fewer than two elements are left untouched. The swap index is taken as the
/// generator output modulo the remaining range.
fn fisher_yates_shuffle<T>(slice: &mut [T], seed: u64) {
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    let mut counter = seed;
    let mut upper = slice.len();
    // Walk from the last index down to 1, swapping each with a pick in 0..=upper.
    while upper > 1 {
        upper -= 1;
        counter = counter.wrapping_add(GAMMA);
        let mixed = {
            let a = (counter ^ (counter >> 30)).wrapping_mul(MIX_A);
            let b = (a ^ (a >> 27)).wrapping_mul(MIX_B);
            b ^ (b >> 31)
        };
        let pick = (mixed as usize) % (upper + 1);
        slice.swap(upper, pick);
    }
}
/// Computes the token-weighted mean loss of `model` over `dataset`.
///
/// Batches are produced by `iterate_batches` in a single, unshuffled pass. At
/// most `num_batches` batches are evaluated (`None` means all of them). Each
/// batch loss returned by `loss_fn` is weighted by the token count returned
/// alongside it, and the weighted sum is divided by the total token count.
///
/// # Errors
///
/// * Any error from `iterate_batches`, batch construction or `loss_fn`.
/// * `EmptyInput` when no tokens were evaluated at all (for example when
///   `num_batches` is `Some(0)`).
pub fn evaluate<M: Model, D: Dataset, F>(
    model: &M,
    dataset: &D,
    batch_size: usize,
    num_batches: Option<usize>,
    max_seq_length: usize,
    mut loss_fn: F,
) -> Result<f32>
where
    F: FnMut(&M, &Array, &Array) -> Result<(Array, Array)>,
{
    let mut total_loss = 0.0_f32;
    let mut total_tokens = 0.0_f32;
    let batches = iterate_batches(dataset, batch_size, max_seq_length, false, None)?;
    let cap = num_batches.unwrap_or(usize::MAX);

    // `take` stops pulling from the iterator once the cap is reached, so the
    // batch just past the cap is never built (an enumerate-and-break loop
    // would construct it only to throw it away).
    for batch in batches.take(cap) {
        let batch = batch?;
        let (mut loss, mut ntoks) = loss_fn(model, batch.tokens_ref(), batch.lengths_ref())?;
        let loss_f = loss.item::<f32>()?;
        let ntoks_f = ntoks.item::<f32>()?;
        total_loss += loss_f * ntoks_f;
        total_tokens += ntoks_f;
    }

    if total_tokens == 0.0 {
        return Err(Error::EmptyInput(EmptyInputPayload::new(
            "evaluate: eval set (produced no batches with tokens)",
        )));
    }
    Ok(total_loss / total_tokens)
}
/// Runs the training loop mechanics: batching, loss evaluation, gradient
/// accumulation, optimizer application, reporting, validation and save hooks.
///
/// The gradients fed to the optimizer are all-zero arrays shaped like `params`
/// (see `build_zero_grads`); `loss_fn` is evaluated only for its scalar loss
/// and token count. Callers must opt in to this by setting
/// `TrainingArgs::acknowledge_no_real_gradients`.
///
/// One loop iteration processes one micro-batch; an optimizer step happens
/// every `grad_accumulation_steps` micro-batches. Report, validation and save
/// cadences are counted in optimizer steps. The save callback is additionally
/// invoked once after the loop finishes.
///
/// # Errors
///
/// * `InvariantViolation` when the acknowledgement flag is not set, or when
///   `grad_accumulation_steps`, `steps_per_report`, `steps_per_eval` or
///   `steps_per_save` is zero (checked only if `iters` is non-zero).
/// * Any error from `iterate_batches`, `loss_fn`, the optimizer, validation or
///   the save callback.
// The argument count is inherent to this entry point, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn train<M, D, O, L, C>(
model: &M,
optimizer: &mut O,
params: &mut Weights,
train_dataset: &D,
val_dataset: Option<&D>,
args: &TrainingArgs,
loss_fn: L,
callback: &mut C,
) -> Result<()>
where
M: Model,
D: Dataset,
O: Optimizer + ?Sized,
L: Fn(&M, &Array, &Array) -> Result<(Array, Array)>,
C: TrainingCallback,
{
// Refuse to run unless the caller explicitly accepted the zero-gradient path.
if !args.acknowledge_no_real_gradients() {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"train: TrainingArgs::acknowledge_no_real_gradients",
"must be set to `true` to run the v1 mechanics-only training path",
)));
}
// Nothing to do; note this returns before the remaining arguments are validated.
if args.iters() == 0 {
return Ok(());
}
if args.grad_accumulation_steps() == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"train: grad_accumulation_steps",
"must be >= 1",
)));
}
if args.steps_per_report() == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"train: steps_per_report",
"must be >= 1",
)));
}
if args.steps_per_eval() == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"train: steps_per_eval",
"must be >= 1",
)));
}
if args.steps_per_save() == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"train: steps_per_save",
"must be >= 1",
)));
}
// Integer division: a partial accumulation group at the end does not count as a step.
// NOTE(review): when `iters` is not a multiple of `grad_accumulation_steps`, the
// trailing micro-batches are still evaluated but never applied or reported.
let total_optim_steps = args.iters() / args.grad_accumulation_steps();
// Statistics for the current reporting window (reset after every report).
let mut window_loss = 0.0_f32;
let mut window_tokens = 0.0_f32;
let mut window_steps = 0usize;
let mut window_microbatches = 0usize;
let mut window_secs = 0.0_f32;
// Running total over the whole run.
let mut trained_tokens = 0usize;
// Gradient accumulation state for the optimizer step in progress.
let mut accumulated_grads: Option<Weights> = None;
let mut accum_count: usize = 0;
let mut optim_step: usize = 0;
// Micro-batch time not yet folded into `window_secs`.
let mut window_micro_secs = 0.0_f32;
// Endless, unshuffled batch stream.
let mut iter = iterate_batches(
train_dataset,
args.batch_size(),
args.max_seq_length(),
true,
None,
)?;
// Baseline validation (reported as iteration 0) before any optimizer step.
if let Some(val) = val_dataset
&& total_optim_steps >= 1
{
run_val(model, val, args, 0, callback, &loss_fn)?;
}
for _microbatch_it in 1..=args.iters() {
let micro_start = Instant::now();
// Outer `?` handles an exhausted iterator, inner `?` a failed batch build.
let batch = iter.next().ok_or_else(|| {
Error::InvariantViolation(InvariantViolationPayload::new(
"train: batch iterator",
"must never be exhausted (loop=true should never end)",
))
})??;
let (loss_scalar, ntoks_scalar) = (loss_fn)(model, batch.tokens_ref(), batch.lengths_ref())?;
// `item` is called on mutable clones of the returned scalars.
let mut loss_val = loss_scalar.try_clone()?;
let mut ntoks_val = ntoks_scalar.try_clone()?;
let loss_f = loss_val.item::<f32>()?;
let ntoks_f = ntoks_val.item::<f32>()?;
// Placeholder gradients: zeros shaped like the parameters.
let grads: Weights = build_zero_grads(params)?;
accumulated_grads = Some(match accumulated_grads {
None => grads,
Some(acc) => add_weights(&acc, &grads)?,
});
accum_count += 1;
window_loss += loss_f;
window_tokens += ntoks_f;
window_microbatches += 1;
trained_tokens += ntoks_f as usize;
// Timing covers batch fetch, loss evaluation and accumulation only; the
// optimizer step, reporting, validation and saving below are excluded.
window_micro_secs += micro_start.elapsed().as_secs_f32();
// Keep accumulating until a full group of micro-batches is collected.
if accum_count < args.grad_accumulation_steps() {
continue;
}
// Average the accumulated gradients and take one optimizer step.
let avg = divide_weights(
accumulated_grads
.as_ref()
.expect("accumulated_grads must be Some after at least one accum"),
args.grad_accumulation_steps() as f32,
)?;
optimizer.apply_gradients(&avg, params)?;
optim_step += 1;
accumulated_grads = None;
accum_count = 0;
window_steps += 1;
window_secs += window_micro_secs;
window_micro_secs = 0.0;
let is_last_optim_step = optim_step == total_optim_steps;
// Report on the configured cadence and always on the final optimizer step.
if optim_step.is_multiple_of(args.steps_per_report()) || is_last_optim_step {
let mean_loss = if window_microbatches > 0 {
window_loss / (window_microbatches as f32)
} else {
0.0
};
let it_sec = if window_secs > 0.0 {
(window_steps as f32) / window_secs
} else {
0.0
};
let tok_sec = if window_secs > 0.0 {
window_tokens / window_secs
} else {
0.0
};
callback.on_train_loss_report(&TrainInfo::new(
optim_step,
mean_loss,
optimizer.learning_rate(),
it_sec,
tok_sec,
trained_tokens,
));
// Start a fresh reporting window.
window_loss = 0.0;
window_tokens = 0.0;
window_steps = 0;
window_microbatches = 0;
window_secs = 0.0;
}
// Validate on the configured cadence and always on the final optimizer step.
if let Some(val) = val_dataset
&& (optim_step.is_multiple_of(args.steps_per_eval()) || is_last_optim_step)
{
run_val(model, val, args, optim_step, callback, &loss_fn)?;
}
if optim_step.is_multiple_of(args.steps_per_save()) {
callback.on_save(optim_step, args.adapter_file())?;
}
}
// Final save, issued unconditionally.
// NOTE(review): when the last optimizer step is a multiple of `steps_per_save`,
// the callback has already been invoked for the same step inside the loop.
callback.on_save(optim_step, args.adapter_file())?;
Ok(())
}
/// Runs one validation pass and reports it through the callback.
///
/// Evaluates `val` with the batch size, batch cap and sequence limit taken
/// from `args`, measures the wall-clock time of the evaluation, and passes the
/// result to `TrainingCallback::on_val_loss_report` tagged with `iteration`.
///
/// # Errors
///
/// Propagates any error from `evaluate`; the callback is not invoked then.
fn run_val<M, D, L, C>(
    model: &M,
    val: &D,
    args: &TrainingArgs,
    iteration: usize,
    callback: &mut C,
    loss_fn: &L,
) -> Result<()>
where
    M: Model,
    D: Dataset,
    L: Fn(&M, &Array, &Array) -> Result<(Array, Array)>,
    C: TrainingCallback,
{
    let started = Instant::now();
    let mean_loss = evaluate(
        model,
        val,
        args.batch_size(),
        args.val_batches(),
        args.max_seq_length(),
        |m, tokens, lengths| (loss_fn)(m, tokens, lengths),
    )?;
    let elapsed_secs = started.elapsed().as_secs_f32();

    let report = ValInfo::new(iteration, mean_loss, elapsed_secs);
    callback.on_val_loss_report(&report);
    Ok(())
}
fn build_zero_grads(params: &Weights) -> Result<Weights> {
let mut grads: Weights = HashMap::with_capacity(params.len());
for (key, value) in params {
grads.insert(key.clone(), crate::ops::misc::zeros_like(value)?);
}
Ok(grads)
}
fn add_weights(a: &Weights, b: &Weights) -> Result<Weights> {
if a.len() != b.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"trainer::add_weights: lhs vs rhs key counts",
a.len(),
b.len(),
)));
}
let mut out: Weights = HashMap::with_capacity(a.len());
for (key, lhs) in a {
let Some(rhs) = b.get(key) else {
return Err(Error::MissingKey(MissingKeyPayload::new(
"trainer::add_weights: key missing from rhs",
key.as_str(),
)));
};
out.insert(key.clone(), arithmetic::add(lhs, rhs)?);
}
Ok(out)
}
fn divide_weights(w: &Weights, divisor: f32) -> Result<Weights> {
let divisor_scalar = Array::full::<f32>(&[0i32; 0], divisor)?;
let mut out: Weights = HashMap::with_capacity(w.len());
for (key, value) in w {
out.insert(key.clone(), arithmetic::divide(value, &divisor_scalar)?);
}
Ok(out)
}
#[cfg(test)]
mod tests;