//! ort 2.0.0-rc.12
//!
//! A safe Rust wrapper for ONNX Runtime 1.24 — optimize and accelerate machine learning inference & training.
//!
//! Documentation
use alloc::collections::VecDeque;
use std::fs;

use crate::{error::Result, session::SessionInputs, training::Trainer};

mod dataloader;
pub use self::dataloader::{DataLoader, IterableDataLoader, iterable_data_loader};
mod args;
pub use self::args::{CheckpointStrategy, EvaluationStrategy, TrainingArguments};
mod callbacks;
pub use self::callbacks::{TrainerCallbacks, TrainerControl, TrainerState};

/// Fires the callback method `$which` on every registered callback, then applies any
/// requested control actions: a learning-rate override is forwarded to the optimizer
/// immediately, and if any callback requested a halt, `handle_halt` is invoked and the
/// enclosing function returns its result.
///
/// Extra leading arguments for the callback method (e.g. the current loss) may be appended
/// after the argument tuple: `callback!(train_step(self, optimizer, args, state), loss)`.
/// The single arm replaces two near-duplicate arms; `$(, $addt:expr)*` matches both the
/// zero-extra-argument and the extra-argument invocations.
macro_rules! callback {
	($which:ident($self:expr, $optimizer:expr, $args:expr, $state:expr) $(, $addt:expr)*) => {
		let mut halt = false;
		for cb in &mut $args.callbacks {
			let mut control = TrainerControl::new($self);
			// Extra arguments are passed ahead of the shared `(state, control)` pair.
			cb.$which($($addt,)* &$state, &mut control)?;
			// One halting callback halts training, but every callback still runs this round.
			halt = halt || control.halt;
			if let Some(lr) = control.lr {
				$optimizer.set_lr(lr)?;
			}
		}
		if halt {
			return $self.handle_halt(&mut $args.callbacks, &$state);
		}
	};
}

impl Trainer {
	/// Runs the training loop over `args.loader` for `args.max_steps` steps.
	///
	/// Each step loads a batch, performs a forward/backward pass via [`Trainer::step`], and
	/// fires the registered [`TrainerCallbacks`]. The optimizer is stepped (and gradients
	/// reset) every `args.gradient_accumulation_steps` iterations. Checkpoints and
	/// evaluations fire according to `args.ckpt_strategy` / `args.eval_strategy`.
	///
	/// If any callback requests a halt, training stops early after notifying all callbacks
	/// via `handle_halt`.
	///
	/// # Errors
	/// Propagates any error from the data loader, the session step, the optimizer,
	/// checkpoint saving, or a callback.
	pub fn train<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
		&self,
		mut args: TrainingArguments<I, L, NI, NL>
	) -> Result<()> {
		let mut optimizer = self.optimizer();
		optimizer.set_lr(args.lr)?;

		// Recently saved checkpoint paths, newest first; bounded by `args.max_saved_ckpts`.
		let mut saved_ckpts = VecDeque::new();
		let mut state = TrainerState::new(&args);
		let mut last_epoch = -1.0;
		// `.enumerate()` guarantees a `usize` counter regardless of the type of `max_steps`.
		for (iter_step, _) in (0..args.max_steps).enumerate() {
			state.iter_step = iter_step;
			// Fractional epoch; `None` when the loader has no known length.
			state.epoch = args.loader.len().map(|dl_len| iter_step as f32 / dl_len as f32);

			// Fire the epoch callback once per whole-epoch boundary.
			if let Some(epoch) = state.epoch {
				if epoch.trunc() != last_epoch {
					callback!(epoch(self, optimizer, args, state));
				}

				last_epoch = epoch.trunc();
			}

			let (inputs, labels) = args.loader.load(iter_step)?;
			let (inputs, labels) = (inputs.into(), labels.into());

			let outputs = self.step(inputs, labels)?;
			// The training session's first output is the scalar loss.
			let loss = outputs[0].try_extract_scalar::<f32>()?;
			callback!(train_step(self, optimizer, args, state), loss);

			// Apply accumulated gradients every `gradient_accumulation_steps` iterations.
			if iter_step % args.gradient_accumulation_steps == 0 {
				optimizer.step()?;
				optimizer.reset_grad()?;
				state.global_step += 1;
				callback!(optimizer_step(self, optimizer, args, state), loss);
			}

			if args.ckpt_strategy.should_fire(state.global_step, iter_step, args.loader.len()) {
				if !args.ckpt_path.exists() {
					// Best-effort: a failure here will surface as a save error below.
					let _ = fs::create_dir_all(&args.ckpt_path);
				}

				let ckpt_path =
					args.ckpt_path
						.join(format!("epoch={},step={}.ortckpt", state.epoch.map(f32::trunc).unwrap_or(0.0) as usize, state.global_step));
				self.checkpoint().save(&ckpt_path, true)?;

				// Evict the oldest checkpoints once over the retention limit.
				saved_ckpts.push_front(ckpt_path.clone());
				while saved_ckpts.len() > args.max_saved_ckpts {
					let Some(old_ckpt) = saved_ckpts.pop_back() else {
						break;
					};
					let _ = fs::remove_file(old_ckpt);
				}
			}

			if args
				.eval_strategy
				.should_fire(state.global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len()))
			{
				callback!(eval_begin(self, optimizer, args, state));
				let eval_loss = self.eval_inner(&mut args)?;
				callback!(eval_end(self, optimizer, args, state), eval_loss);
			}
		}
		Ok(())
	}

	/// Notifies every callback that training is ending after a halt was requested.
	///
	/// Takes a slice rather than `&mut Vec` per Rust API convention; existing
	/// `&mut Vec<_>` call sites coerce automatically.
	fn handle_halt(&self, cbs: &mut [Box<dyn TrainerCallbacks>], state: &TrainerState) -> Result<()> {
		for cb in cbs {
			let mut control = TrainerControl::new(self);
			cb.end(state, &mut control)?;
		}
		Ok(())
	}

	/// Runs up to `args.max_eval_steps` evaluation steps and returns the mean loss.
	///
	/// Returns `Ok(0.0)` when no evaluation loader is configured.
	///
	/// # Errors
	/// Propagates any error from the evaluation loader or the evaluation step.
	pub(crate) fn eval_inner<I: Into<SessionInputs<'static, 'static, NI>>, L: Into<SessionInputs<'static, 'static, NL>>, const NI: usize, const NL: usize>(
		&self,
		args: &mut TrainingArguments<I, L, NI, NL>
	) -> Result<f32> {
		let Some(eval_loader) = &mut args.eval_loader else {
			return Ok(0.0);
		};

		let mut total_loss = 0.0;
		for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) {
			let (inputs, labels) = eval_loader.load(step)?;
			let (inputs, labels) = (inputs.into(), labels.into());

			let outputs = self.eval_step(inputs, labels)?;
			let loss = outputs[0].try_extract_scalar::<f32>()?;
			// Running mean, updated incrementally without a separate accumulator count.
			total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.);
		}

		Ok(total_loss)
	}
}