pub struct Trainer {
pub metrics: MetricsTracker,
/* private fields */
}
High-level trainer that orchestrates the training loop.
§Example
use entrenar::train::{Trainer, TrainConfig, Batch, MSELoss, EarlyStopping};
use entrenar::optim::Adam;
use entrenar::Tensor;
// Setup
let params = vec![Tensor::zeros(10, true)];
let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
let config = TrainConfig::default();
let mut trainer = Trainer::new(params, Box::new(optimizer), config);
trainer.set_loss(Box::new(MSELoss));
trainer.add_callback(EarlyStopping::new(5, 0.001));
// Training with callbacks
// let result = trainer.train(10, || batches.clone(), |x| x.clone());
Fields§
§metrics: MetricsTracker
Metrics tracker
Implementations§
Source§impl Trainer
impl Trainer
Sourcepub fn new(
params: Vec<Tensor>,
optimizer: Box<dyn Optimizer>,
config: TrainConfig,
) -> Self
pub fn new( params: Vec<Tensor>, optimizer: Box<dyn Optimizer>, config: TrainConfig, ) -> Self
Create a new trainer
Sourcepub fn add_callback<C: TrainerCallback + 'static>(&mut self, callback: C)
pub fn add_callback<C: TrainerCallback + 'static>(&mut self, callback: C)
Add a callback to the trainer
Sourcepub fn params_mut(&mut self) -> &mut [Tensor]
pub fn params_mut(&mut self) -> &mut [Tensor]
Get mutable reference to model parameters
Sourcepub fn callbacks(&self) -> &CallbackManager
pub fn callbacks(&self) -> &CallbackManager
Get reference to callback manager
Sourcepub fn callbacks_mut(&mut self) -> &mut CallbackManager
pub fn callbacks_mut(&mut self) -> &mut CallbackManager
Get mutable reference to callback manager
Sourcepub fn save(
&self,
path: impl AsRef<Path>,
name: &str,
architecture: &str,
) -> Result<()>
pub fn save( &self, path: impl AsRef<Path>, name: &str, architecture: &str, ) -> Result<()>
Save model parameters to a file
This method persists the trained model weights to disk in SafeTensors format. Call this after training completes to preserve the learned parameters.
§Arguments
path - Output file path (should end in .safetensors)
name - Model name for metadata
architecture - Model architecture description
§Example
// After training...
trainer.save("model.safetensors", "my-model", "linear").expect("save failed");
§Errors
Returns an error if the file cannot be written.
Sourcepub fn save_with_names(
&self,
path: impl AsRef<Path>,
name: &str,
architecture: &str,
param_names: &[&str],
) -> Result<()>
pub fn save_with_names( &self, path: impl AsRef<Path>, name: &str, architecture: &str, param_names: &[&str], ) -> Result<()>
Save model with custom parameter names
Like save() but allows specifying custom names for each parameter tensor.
§Arguments
path - Output file path
name - Model name
architecture - Architecture description
param_names - Names for each parameter tensor (its length must match the number of parameters)
§Errors
Returns an error if the length of param_names doesn’t match the number of parameters, or if the file cannot be written.
Source§impl Trainer
impl Trainer
Sourcepub fn train_epoch<F, I>(&mut self, batches: I, forward_fn: F) -> f32
pub fn train_epoch<F, I>(&mut self, batches: I, forward_fn: F) -> f32
Sourcepub fn validate<F, I>(&mut self, batches: I, forward_fn: F) -> f32
pub fn validate<F, I>(&mut self, batches: I, forward_fn: F) -> f32
Validate on a dataset without updating parameters
§Arguments
batches - Iterator over validation batches
forward_fn - Closure that computes predictions from inputs
§Returns
Average validation loss
§Example
let val_loss = trainer.validate(val_batches, |x| x.clone());
println!("Validation loss: {:.4}", val_loss);
Source§impl Trainer
impl Trainer
Sourcepub fn train_step<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
pub fn train_step<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
Perform a single training step
§Arguments
batch - Training batch with inputs and targets
forward_fn - Closure that computes predictions from inputs
§Returns
Scalar loss value for this batch
§Example
let loss = trainer.train_step(&batch, |inputs| {
// Forward pass: compute predictions
inputs.clone() // Simplified example
});
Source§impl Trainer
impl Trainer
Sourcepub fn train<F, B, I>(
&mut self,
max_epochs: usize,
batch_fn: B,
forward_fn: F,
) -> TrainResult
pub fn train<F, B, I>( &mut self, max_epochs: usize, batch_fn: B, forward_fn: F, ) -> TrainResult
Train for multiple epochs with full callback support
§Arguments
max_epochs - Maximum number of epochs to train
batch_fn - Function that returns batches for each epoch
forward_fn - Closure that computes predictions from inputs
§Returns
TrainResult with final metrics
§Example
trainer.add_callback(EarlyStopping::new(5, 0.001));
let result = trainer.train(100, || batches.clone(), |x| x.clone());
println!("Trained {} epochs, final loss: {:.4}", result.final_epoch, result.final_loss);
Source§impl Trainer
impl Trainer
Sourcepub fn train_with_val<F, BT, BV, IT, IV>(
&mut self,
max_epochs: usize,
train_fn: BT,
val_fn: BV,
forward_fn: F,
) -> TrainResultwhere
F: Fn(&Tensor) -> Tensor,
BT: Fn() -> IT,
BV: Fn() -> IV,
IT: IntoIterator<Item = Batch>,
IV: IntoIterator<Item = Batch>,
pub fn train_with_val<F, BT, BV, IT, IV>(
&mut self,
max_epochs: usize,
train_fn: BT,
val_fn: BV,
forward_fn: F,
) -> TrainResultwhere
F: Fn(&Tensor) -> Tensor,
BT: Fn() -> IT,
BV: Fn() -> IV,
IT: IntoIterator<Item = Batch>,
IV: IntoIterator<Item = Batch>,
Train for multiple epochs with validation after each epoch
This method runs training and validation each epoch, passing validation loss to callbacks for proper early stopping and checkpointing.
§Arguments
max_epochs - Maximum number of epochs to train
train_fn - Function that returns training batches for each epoch
val_fn - Function that returns validation batches for each epoch
forward_fn - Closure that computes predictions from inputs
§Returns
TrainResult with final metrics including best validation loss
§Example
trainer.add_callback(EarlyStopping::new(5, 0.001).monitor_validation());
let result = trainer.train_with_val(
100,
|| train_batches.clone(),
|| val_batches.clone(),
|x| x.clone()
);
println!("Best val loss: {:.4}", result.best_loss);
Auto Trait Implementations§
impl Freeze for Trainer
impl !RefUnwindSafe for Trainer
impl !Send for Trainer
impl !Sync for Trainer
impl Unpin for Trainer
impl UnsafeUnpin for Trainer
impl !UnwindSafe for Trainer
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> FmtForward for T
impl<T> FmtForward for T
Source§fn fmt_binary(self) -> FmtBinary<Self>where
Self: Binary,
fn fmt_binary(self) -> FmtBinary<Self>where
Self: Binary,
Causes self to use its Binary implementation when Debug-formatted.
Source§fn fmt_display(self) -> FmtDisplay<Self>where
Self: Display,
fn fmt_display(self) -> FmtDisplay<Self>where
Self: Display,
Causes self to use its Display implementation when Debug-formatted.
Source§fn fmt_lower_exp(self) -> FmtLowerExp<Self>where
Self: LowerExp,
fn fmt_lower_exp(self) -> FmtLowerExp<Self>where
Self: LowerExp,
Causes self to use its LowerExp implementation when Debug-formatted.
Source§fn fmt_lower_hex(self) -> FmtLowerHex<Self>where
Self: LowerHex,
fn fmt_lower_hex(self) -> FmtLowerHex<Self>where
Self: LowerHex,
Causes self to use its LowerHex implementation when Debug-formatted.
Source§fn fmt_octal(self) -> FmtOctal<Self>where
Self: Octal,
fn fmt_octal(self) -> FmtOctal<Self>where
Self: Octal,
Causes self to use its Octal implementation when Debug-formatted.
Source§fn fmt_pointer(self) -> FmtPointer<Self>where
Self: Pointer,
fn fmt_pointer(self) -> FmtPointer<Self>where
Self: Pointer,
Causes self to use its Pointer implementation when Debug-formatted.
Source§fn fmt_upper_exp(self) -> FmtUpperExp<Self>where
Self: UpperExp,
fn fmt_upper_exp(self) -> FmtUpperExp<Self>where
Self: UpperExp,
Causes self to use its UpperExp implementation when Debug-formatted.
Source§fn fmt_upper_hex(self) -> FmtUpperHex<Self>where
Self: UpperHex,
fn fmt_upper_hex(self) -> FmtUpperHex<Self>where
Self: UpperHex,
Causes self to use its UpperHex implementation when Debug-formatted.
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.
Source§impl<T> Pipe for Twhere
T: ?Sized,
impl<T> Pipe for Twhere
T: ?Sized,
Source§fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> Rwhere
Self: Sized,
fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> Rwhere
Self: Sized,
Source§fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> Rwhere
R: 'a,
fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> Rwhere
R: 'a,
Borrows self and passes that borrow into the pipe function.
Source§fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> Rwhere
R: 'a,
fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> Rwhere
R: 'a,
Mutably borrows self and passes that borrow into the pipe function.
Source§fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
Source§fn pipe_borrow_mut<'a, B, R>(
&'a mut self,
func: impl FnOnce(&'a mut B) -> R,
) -> R
fn pipe_borrow_mut<'a, B, R>( &'a mut self, func: impl FnOnce(&'a mut B) -> R, ) -> R
Source§fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
Borrows self, then passes self.as_ref() into the pipe function.
Source§fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
Mutably borrows self, then passes self.as_mut() into the pipe function.
Source§fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
Borrows self, then passes self.deref() into the pipe function.
Source§impl<T> Pointable for T
impl<T> Pointable for T
Source§impl<T> PolicyExt for Twhere
T: ?Sized,
impl<T> PolicyExt for Twhere
T: ?Sized,
Source§impl<T> Tap for T
impl<T> Tap for T
Source§fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
Immutable access to the Borrow<B> of a value.
Source§fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
Mutable access to the BorrowMut<B> of a value.
Source§fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
Immutable access to the AsRef<R> view of a value.
Source§fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
Mutable access to the AsMut<R> view of a value.
Source§fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
Immutable access to the Deref::Target of a value.
Source§fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
Mutable access to the Deref::Target of a value.
Source§fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
Calls .tap() only in debug builds, and is erased in release builds.
Source§fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
Calls .tap_mut() only in debug builds, and is erased in release builds.
Source§fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
Calls .tap_borrow() only in debug builds, and is erased in release builds.
Source§fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
Calls .tap_borrow_mut() only in debug builds, and is erased in release builds.
Source§fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
Calls .tap_ref() only in debug builds, and is erased in release builds.
Source§fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
Calls .tap_ref_mut() only in debug builds, and is erased in release builds.
Source§fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
Calls .tap_deref() only in debug builds, and is erased in release builds.