Skip to main content

Trainer

Struct Trainer 

Source
pub struct Trainer {
    pub metrics: MetricsTracker,
    /* private fields */
}
Expand description

High-level trainer that orchestrates the training loop

§Example

use entrenar::train::{Trainer, TrainConfig, Batch, MSELoss, EarlyStopping};
use entrenar::optim::Adam;
use entrenar::Tensor;

// Setup
let params = vec![Tensor::zeros(10, true)];
let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
let config = TrainConfig::default();

let mut trainer = Trainer::new(params, Box::new(optimizer), config);
trainer.set_loss(Box::new(MSELoss));
trainer.add_callback(EarlyStopping::new(5, 0.001));

// Training with callbacks
// (requires a cloneable collection of `Batch` values named `batches` to be in scope)
// let result = trainer.train(10, || batches.clone(), |x| x.clone());

Fields§

§metrics: MetricsTracker

Metrics tracker

Implementations§

Source§

impl Trainer

Source

pub fn new( params: Vec<Tensor>, optimizer: Box<dyn Optimizer>, config: TrainConfig, ) -> Self

Create a new trainer

Source

pub fn set_loss(&mut self, loss_fn: Box<dyn LossFn>)

Set the loss function

Source

pub fn add_callback<C: TrainerCallback + 'static>(&mut self, callback: C)

Add a callback to the trainer

Source

pub fn lr(&self) -> f32

Get current learning rate

Source

pub fn set_lr(&mut self, lr: f32)

Set learning rate

Source

pub fn params(&self) -> &[Tensor]

Get reference to model parameters

Source

pub fn params_mut(&mut self) -> &mut [Tensor]

Get mutable reference to model parameters

Source

pub fn callbacks(&self) -> &CallbackManager

Get reference to callback manager

Source

pub fn callbacks_mut(&mut self) -> &mut CallbackManager

Get mutable reference to callback manager

Source

pub fn save( &self, path: impl AsRef<Path>, name: &str, architecture: &str, ) -> Result<()>

Save model parameters to a file

This method persists the trained model weights to disk in SafeTensors format. Call this after training completes to preserve the learned parameters.

§Arguments
  • path - Output file path (should end in .safetensors)
  • name - Model name for metadata
  • architecture - Model architecture description
§Example
// After training...
trainer.save("model.safetensors", "my-model", "linear").expect("save failed");
§Errors

Returns an error if the file cannot be written.

Source

pub fn save_with_names( &self, path: impl AsRef<Path>, name: &str, architecture: &str, param_names: &[&str], ) -> Result<()>

Save model with custom parameter names

Like `save()`, but allows specifying a custom name for each parameter tensor.

§Arguments
  • path - Output file path
  • name - Model name
  • architecture - Architecture description
  • param_names - Names for each parameter tensor (must contain exactly one name per parameter)
§Errors

Returns an error if the length of `param_names` does not match the number of parameters, or if the file cannot be written.

Source§

impl Trainer

Source

pub fn train_epoch<F, I>(&mut self, batches: I, forward_fn: F) -> f32
where F: Fn(&Tensor) -> Tensor, I: IntoIterator<Item = Batch>,

Train for one epoch

§Arguments
  • batches - Iterator over training batches
  • forward_fn - Closure that computes predictions from inputs
§Returns

Average loss over the epoch

Source

pub fn validate<F, I>(&mut self, batches: I, forward_fn: F) -> f32
where F: Fn(&Tensor) -> Tensor, I: IntoIterator<Item = Batch>,

Validate on a dataset without updating parameters

§Arguments
  • batches - Iterator over validation batches
  • forward_fn - Closure that computes predictions from inputs
§Returns

Average validation loss

§Example
let val_loss = trainer.validate(val_batches, |x| x.clone());
println!("Validation loss: {:.4}", val_loss);
Source§

impl Trainer

Source

pub fn train_step<F>(&mut self, batch: &Batch, forward_fn: F) -> f32
where F: FnOnce(&Tensor) -> Tensor,

Perform a single training step

§Arguments
  • batch - Training batch with inputs and targets
  • forward_fn - Closure that computes predictions from inputs
§Returns

Scalar loss value for this batch

§Example
let loss = trainer.train_step(&batch, |inputs| {
    // Forward pass: compute predictions
    inputs.clone() // Simplified example
});
Source§

impl Trainer

Source

pub fn train<F, B, I>( &mut self, max_epochs: usize, batch_fn: B, forward_fn: F, ) -> TrainResult
where F: Fn(&Tensor) -> Tensor, B: Fn() -> I, I: IntoIterator<Item = Batch>,

Train for multiple epochs with full callback support

§Arguments
  • max_epochs - Maximum number of epochs to train
  • batch_fn - Function that returns batches for each epoch
  • forward_fn - Closure that computes predictions from inputs
§Returns

TrainResult with final metrics

§Example
trainer.add_callback(EarlyStopping::new(5, 0.001));

let result = trainer.train(100, || batches.clone(), |x| x.clone());
println!("Trained {} epochs, final loss: {:.4}", result.final_epoch, result.final_loss);
Source§

impl Trainer

Source

pub fn train_with_val<F, BT, BV, IT, IV>( &mut self, max_epochs: usize, train_fn: BT, val_fn: BV, forward_fn: F, ) -> TrainResult
where F: Fn(&Tensor) -> Tensor, BT: Fn() -> IT, BV: Fn() -> IV, IT: IntoIterator<Item = Batch>, IV: IntoIterator<Item = Batch>,

Train for multiple epochs with validation after each epoch

This method runs training and validation each epoch, passing validation loss to callbacks for proper early stopping and checkpointing.

§Arguments
  • max_epochs - Maximum number of epochs to train
  • train_fn - Function that returns training batches for each epoch
  • val_fn - Function that returns validation batches for each epoch
  • forward_fn - Closure that computes predictions from inputs
§Returns

TrainResult with final metrics including best validation loss

§Example
trainer.add_callback(EarlyStopping::new(5, 0.001).monitor_validation());

let result = trainer.train_with_val(
    100,
    || train_batches.clone(),
    || val_batches.clone(),
    |x| x.clone()
);
println!("Best val loss: {:.4}", result.best_loss);

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> Conv for T

Source§

fn conv<T>(self) -> T
where Self: Into<T>,

Converts self into T using Into<T>. Read more
Source§

impl<T> Downcast<T> for T

Source§

fn downcast(&self) -> &T

Source§

impl<T> FmtForward for T

Source§

fn fmt_binary(self) -> FmtBinary<Self>
where Self: Binary,

Causes self to use its Binary implementation when Debug-formatted.
Source§

fn fmt_display(self) -> FmtDisplay<Self>
where Self: Display,

Causes self to use its Display implementation when Debug-formatted.
Source§

fn fmt_lower_exp(self) -> FmtLowerExp<Self>
where Self: LowerExp,

Causes self to use its LowerExp implementation when Debug-formatted.
Source§

fn fmt_lower_hex(self) -> FmtLowerHex<Self>
where Self: LowerHex,

Causes self to use its LowerHex implementation when Debug-formatted.
Source§

fn fmt_octal(self) -> FmtOctal<Self>
where Self: Octal,

Causes self to use its Octal implementation when Debug-formatted.
Source§

fn fmt_pointer(self) -> FmtPointer<Self>
where Self: Pointer,

Causes self to use its Pointer implementation when Debug-formatted.
Source§

fn fmt_upper_exp(self) -> FmtUpperExp<Self>
where Self: UpperExp,

Causes self to use its UpperExp implementation when Debug-formatted.
Source§

fn fmt_upper_hex(self) -> FmtUpperHex<Self>
where Self: UpperHex,

Causes self to use its UpperHex implementation when Debug-formatted.
Source§

fn fmt_list(self) -> FmtList<Self>
where for<'a> &'a Self: IntoIterator,

Formats each item in a sequence. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pipe for T
where T: ?Sized,

Source§

fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> R
where Self: Sized,

Pipes by value. This is generally the method you want to use. Read more
Source§

fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> R
where R: 'a,

Borrows self and passes that borrow into the pipe function. Read more
Source§

fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> R
where R: 'a,

Mutably borrows self and passes that borrow into the pipe function. Read more
Source§

fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
where Self: Borrow<B>, B: 'a + ?Sized, R: 'a,

Borrows self, then passes self.borrow() into the pipe function. Read more
Source§

fn pipe_borrow_mut<'a, B, R>( &'a mut self, func: impl FnOnce(&'a mut B) -> R, ) -> R
where Self: BorrowMut<B>, B: 'a + ?Sized, R: 'a,

Mutably borrows self, then passes self.borrow_mut() into the pipe function. Read more
Source§

fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
where Self: AsRef<U>, U: 'a + ?Sized, R: 'a,

Borrows self, then passes self.as_ref() into the pipe function.
Source§

fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
where Self: AsMut<U>, U: 'a + ?Sized, R: 'a,

Mutably borrows self, then passes self.as_mut() into the pipe function.
Source§

fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
where Self: Deref<Target = T>, T: 'a + ?Sized, R: 'a,

Borrows self, then passes self.deref() into the pipe function.
Source§

fn pipe_deref_mut<'a, T, R>( &'a mut self, func: impl FnOnce(&'a mut T) -> R, ) -> R
where Self: DerefMut<Target = T> + Deref, T: 'a + ?Sized, R: 'a,

Mutably borrows self, then passes self.deref_mut() into the pipe function.
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> PolicyExt for T
where T: ?Sized,

Source§

fn and<P, B, E>(self, other: P) -> And<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow only if self and other return Action::Follow. Read more
Source§

fn or<P, B, E>(self, other: P) -> Or<T, P>
where T: Policy<B, E>, P: Policy<B, E>,

Create a new Policy that returns Action::Follow if either self or other returns Action::Follow. Read more
Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<T> Tap for T

Source§

fn tap(self, func: impl FnOnce(&Self)) -> Self

Immutable access to a value. Read more
Source§

fn tap_mut(self, func: impl FnOnce(&mut Self)) -> Self

Mutable access to a value. Read more
Source§

fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
where Self: Borrow<B>, B: ?Sized,

Immutable access to the Borrow<B> of a value. Read more
Source§

fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
where Self: BorrowMut<B>, B: ?Sized,

Mutable access to the BorrowMut<B> of a value. Read more
Source§

fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
where Self: AsRef<R>, R: ?Sized,

Immutable access to the AsRef<R> view of a value. Read more
Source§

fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
where Self: AsMut<R>, R: ?Sized,

Mutable access to the AsMut<R> view of a value. Read more
Source§

fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
where Self: Deref<Target = T>, T: ?Sized,

Immutable access to the Deref::Target of a value. Read more
Source§

fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
where Self: DerefMut<Target = T> + Deref, T: ?Sized,

Mutable access to the Deref::Target of a value. Read more
Source§

fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self

Calls .tap() only in debug builds, and is erased in release builds.
Source§

fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self

Calls .tap_mut() only in debug builds, and is erased in release builds.
Source§

fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
where Self: Borrow<B>, B: ?Sized,

Calls .tap_borrow() only in debug builds, and is erased in release builds.
Source§

fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
where Self: BorrowMut<B>, B: ?Sized,

Calls .tap_borrow_mut() only in debug builds, and is erased in release builds.
Source§

fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
where Self: AsRef<R>, R: ?Sized,

Calls .tap_ref() only in debug builds, and is erased in release builds.
Source§

fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
where Self: AsMut<R>, R: ?Sized,

Calls .tap_ref_mut() only in debug builds, and is erased in release builds.
Source§

fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
where Self: Deref<Target = T>, T: ?Sized,

Calls .tap_deref() only in debug builds, and is erased in release builds.
Source§

fn tap_deref_mut_dbg<T>(self, func: impl FnOnce(&mut T)) -> Self
where Self: DerefMut<Target = T> + Deref, T: ?Sized,

Calls .tap_deref_mut() only in debug builds, and is erased in release builds.
Source§

impl<T> TryConv for T

Source§

fn try_conv<T>(self) -> Result<T, Self::Error>
where Self: TryInto<T>,

Attempts to convert self into T using TryInto<T>. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<S, T> Upcast<T> for S
where T: UpcastFrom<S> + ?Sized, S: ?Sized,

Source§

fn upcast(&self) -> &T
where Self: ErasableGeneric, T: ErasableGeneric<Repr = Self::Repr>,

Perform a zero-cost type-safe upcast to a wider ref type within the Wasm bindgen generics type system. Read more
Source§

fn upcast_into(self) -> T
where Self: Sized + ErasableGeneric, T: ErasableGeneric<Repr = Self::Repr>,

Perform a zero-cost type-safe upcast to a wider type within the Wasm bindgen generics type system. Read more
Source§

impl<T> Upcast<T> for T

Source§

fn upcast(&self) -> Option<&T>

Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more