Struct LearnerBuilder 

pub struct LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{ /* private fields */ }

Struct to configure and create a learner.

The generic parameters of the builder should generally not be set manually: they are designed to be inferred by the Rust compiler.
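To illustrate why the generic parameters are left to inference, here is a minimal, std-only sketch of the same pattern. Builder, MyModel and MyOptimizer are hypothetical stand-ins, not Burn types. Each method consumes self and returns Self so that calls chain, and the type parameters are only fixed once build receives concrete arguments.

```rust
use std::marker::PhantomData;

/// Hypothetical, heavily simplified builder. Callers never name `M` or `O`;
/// the compiler infers them from the arguments given to `build`.
struct Builder<M, O> {
    epochs: usize,
    _types: PhantomData<(M, O)>,
}

impl<M, O> Builder<M, O> {
    fn new() -> Self {
        Builder { epochs: 1, _types: PhantomData }
    }

    /// Consumes the builder and hands it back, so calls can be chained.
    fn num_epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// `M` and `O` are pinned down here by the argument types.
    fn build(self, model: M, optim: O) -> (M, O, usize) {
        (model, optim, self.epochs)
    }
}

struct MyModel;
struct MyOptimizer {
    lr: f64,
}

/// No turbofish such as `Builder::<MyModel, MyOptimizer>::new()` is needed.
fn configured_epochs() -> usize {
    let (_model, optim, epochs) = Builder::new()
        .num_epochs(10)
        .build(MyModel, MyOptimizer { lr: 1e-3 });
    assert!(optim.lr > 0.0);
    epochs
}
```

Writing the parameters out by hand would mean repeating every type the compiler can already deduce, which is why the builder is meant to be used through method chaining.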

Implementations

impl<B, M, O, S, TI, VI, TO, VO> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,

pub fn new(directory: impl AsRef<Path>) -> Self

Creates a new learner builder.

Arguments
  • directory - The directory in which checkpoints are saved.

pub fn metric_loggers<MT, MV>(self, logger_train: MT, logger_valid: MV) -> Self
where MT: MetricLogger + 'static, MV: MetricLogger + 'static,

Replace the default metric loggers with the provided ones.

Arguments
  • logger_train - The training logger.
  • logger_valid - The validation logger.

pub fn with_checkpointing_strategy<CS>(self, strategy: CS) -> Self
where CS: CheckpointingStrategy + 'static,

Set the checkpointing strategy, which decides when checkpoints are saved and when old ones are removed.
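As an illustration of what such a strategy decides, the std-only sketch below models a "keep only the last N checkpoints" policy. It assumes the policy is consulted once at the end of every epoch. The Action enum, the KeepLastN type and its actions method are hypothetical and do not reproduce the actual CheckpointingStrategy trait signature.

```rust
/// Hypothetical actions a checkpointing policy can request.
#[derive(Debug, PartialEq)]
enum Action {
    Save,
    Delete(usize),
}

/// Hypothetical "keep only the last `n` checkpoints" policy.
struct KeepLastN {
    n: usize,
}

impl KeepLastN {
    /// Assumed to be called once at the end of each epoch (epochs start at 1).
    fn actions(&self, epoch: usize) -> Vec<Action> {
        let mut actions = vec![Action::Save];
        // Once more than `n` checkpoints would exist, drop the oldest one.
        if epoch > self.n {
            actions.push(Action::Delete(epoch - self.n));
        }
        actions
    }
}
```

With n = 2, epochs 1 and 2 only save; from epoch 3 on, each save is paired with the deletion of the checkpoint from two epochs earlier.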

pub fn renderer<MR>(self, renderer: MR) -> Self
where MR: MetricsRenderer + 'static,

Replace the default CLI renderer with a custom one.

Arguments
  • renderer - The custom renderer.

pub fn metrics<Me: MetricRegistration<B, M, O, S, TI, VI, TO, VO>>(self, metrics: Me) -> Self

Register all the given metrics as numeric metrics for both the training and validation sets.

pub fn metrics_text<Me: TextMetricRegistration<B, M, O, S, TI, VI, TO, VO>>(self, metrics: Me) -> Self

Register all the given metrics as text metrics for both the training and validation sets.

pub fn metric_train<Me: Metric + 'static>(self, metric: Me) -> Self
where <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,

Register a training metric.

pub fn metric_valid<Me: Metric + 'static>(self, metric: Me) -> Self
where <VO as ItemLazy>::ItemSync: Adaptor<Me::Input>,

Register a validation metric.
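The where clauses on metric_train and metric_valid require the synchronized step output to implement Adaptor<Me::Input>, that is, to be convertible into whatever input the metric expects. The std-only sketch below models that idea with a trait shaped like the Adaptor trait listed under the blanket implementations (fn adapt(&self) -> T). ClassificationOutput, AccuracyInput, LossInput and accuracy are hypothetical types and helpers.

```rust
/// Modeled on `Adaptor`: converts a step output into a metric's input type.
trait Adaptor<T> {
    fn adapt(&self) -> T;
}

/// Hypothetical per-batch output of a training or validation step.
struct ClassificationOutput {
    predictions: Vec<usize>,
    targets: Vec<usize>,
    loss: f64,
}

/// Hypothetical inputs of two different metrics.
struct AccuracyInput {
    predictions: Vec<usize>,
    targets: Vec<usize>,
}

struct LossInput {
    loss: f64,
}

// One output type can feed several metrics by adapting to each input type.
impl Adaptor<AccuracyInput> for ClassificationOutput {
    fn adapt(&self) -> AccuracyInput {
        AccuracyInput {
            predictions: self.predictions.clone(),
            targets: self.targets.clone(),
        }
    }
}

impl Adaptor<LossInput> for ClassificationOutput {
    fn adapt(&self) -> LossInput {
        LossInput { loss: self.loss }
    }
}

/// Fraction of predictions that match their targets.
fn accuracy(input: &AccuracyInput) -> f64 {
    let correct = input
        .predictions
        .iter()
        .zip(&input.targets)
        .filter(|(p, t)| p == t)
        .count();
    correct as f64 / input.targets.len() as f64
}
```

If the output type has no adaptor for a metric's input, registering that metric fails at compile time rather than at run time.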

pub fn grads_accumulation(self, accumulation: usize) -> Self

Enable gradient accumulation over the given number of backward passes before each optimizer step.

Notes

When you enable gradient accumulation, the gradients passed to the optimizer are the sum of the gradients produced by each backward pass. Consider reducing the learning rate to compensate.

The effect is similar to multiplying both the batch size and the learning rate by the accumulation count.
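The scaling claim can be checked with a scalar toy example. Assuming each backward pass yields the mean gradient of its micro-batch and all micro-batches have the same size, summing k such gradients gives k times the mean gradient of the combined batch. One optimizer step with learning rate lr on the sum therefore equals a step with learning rate k * lr on the larger batch. The function names are illustrative only.

```rust
/// Mean of a slice of per-example gradients (scalar parameter for simplicity).
fn mean(values: &[f64]) -> f64 {
    values.iter().sum::<f64>() / values.len() as f64
}

/// Sum of the per-micro-batch mean gradients, i.e. what accumulation hands the optimizer.
/// The identity with the large-batch mean holds when every chunk has the same size.
fn accumulated_gradient(per_example: &[f64], micro_batch: usize) -> f64 {
    per_example.chunks(micro_batch).map(mean).sum()
}

/// One plain SGD step on a scalar parameter.
fn sgd_step(param: f64, lr: f64, grad: f64) -> f64 {
    param - lr * grad
}
```

With eight per-example gradients split into micro-batches of two (k = 4), a step with lr = 0.125 on the accumulated gradient lands on the same value as a step with lr = 0.5 on the full-batch mean.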

pub fn metric_train_numeric<Me>(self, metric: Me) -> Self
where Me: Metric + Numeric + 'static, <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,

Register a numeric training metric.

pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(self, metric: Me) -> Self
where <VO as ItemLazy>::ItemSync: Adaptor<Me::Input>,

Register a numeric validation metric.

pub fn num_epochs(self, num_epochs: usize) -> Self

The number of epochs the training should last.

pub fn learning_strategy(self, learning_strategy: LearningStrategy<B>) -> Self

Set the strategy used to run the training loop.

pub fn checkpoint(self, checkpoint: usize) -> Self

Set the epoch from which training resumes.
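The sketch below shows which epochs still run when resuming. It rests on assumptions this page does not state: epochs are numbered from 1, and the checkpoint for epoch N holds the state after epoch N, so training continues at N + 1. remaining_epochs is a hypothetical helper, not part of the builder.

```rust
use std::ops::RangeInclusive;

/// Epochs still to run, assuming epochs start at 1 and the checkpoint for
/// epoch `n` stores the state reached at the end of epoch `n`.
fn remaining_epochs(checkpoint: Option<usize>, num_epochs: usize) -> RangeInclusive<usize> {
    let start = checkpoint.map_or(1, |epoch| epoch + 1);
    start..=num_epochs
}
```

If the checkpoint epoch already equals the configured number of epochs, the range is empty and no further training runs.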

pub fn interrupter(&self) -> Interrupter

Provides a handle that can be used to interrupt training.

pub fn with_interrupter(self, interrupter: Interrupter) -> Self

Override the handle used to stop training with an externally provided handle.
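To show how a shared stop handle can be used from outside the training loop, here is a std-only model built on Arc<AtomicBool>: every clone shares the same flag, so any holder can request a stop and the loop notices it before the next epoch. StopHandle and run are hypothetical names; this is an illustration of the pattern, not the Interrupter implementation.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for a stop handle: every clone shares the same flag,
/// so any holder (another thread, a UI, a signal handler) can request a stop.
#[derive(Clone, Default)]
struct StopHandle {
    flag: Arc<AtomicBool>,
}

impl StopHandle {
    fn stop(&self) {
        self.flag.store(true, Ordering::SeqCst);
    }

    fn should_stop(&self) -> bool {
        self.flag.load(Ordering::SeqCst)
    }
}

/// Simulated training loop that checks the handle before each epoch.
/// `on_epoch_end` stands in for external code; returns the number of completed epochs.
fn run(handle: &StopHandle, max_epochs: usize, mut on_epoch_end: impl FnMut(usize)) -> usize {
    let mut completed = 0;
    for epoch in 1..=max_epochs {
        if handle.should_stop() {
            break;
        }
        completed = epoch;
        on_epoch_end(epoch);
    }
    completed
}
```

Because the flag is only checked between epochs in this model, a stop request takes effect at the next check rather than instantly.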

pub fn early_stopping<Strategy>(self, strategy: Strategy) -> Self
where Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,

Register an early stopping strategy that stops training when its conditions are met.
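A common condition is patience on a monitored metric. The std-only sketch below stops when the validation loss has not improved on its best value for a given number of consecutive epochs. Patience and stop_epoch are hypothetical; the sketch does not implement the EarlyStoppingStrategy trait and is only one possible rule.

```rust
/// Hypothetical patience rule: stop when the monitored loss has not improved
/// on its best value for `patience` consecutive epochs.
struct Patience {
    patience: usize,
    best: f64,
    epochs_without_improvement: usize,
}

impl Patience {
    fn new(patience: usize) -> Self {
        Patience { patience, best: f64::INFINITY, epochs_without_improvement: 0 }
    }

    /// Feed one epoch's validation loss; returns `true` when training should stop.
    fn should_stop(&mut self, loss: f64) -> bool {
        if loss < self.best {
            self.best = loss;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}

/// Returns the 1-based epoch at which training would stop, if any.
fn stop_epoch(losses: &[f64], patience: usize) -> Option<usize> {
    let mut rule = Patience::new(patience);
    losses.iter().position(|&loss| rule.should_stop(loss)).map(|i| i + 1)
}
```

With patience 2 and losses 1.0, 0.8, 0.9, 0.85, epochs 3 and 4 fail to beat 0.8, so training stops after epoch 4.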

pub fn with_application_logger(self, logger: Option<Box<dyn ApplicationLoggerInstaller>>) -> Self

By default, Rust logs are captured and written to experiment.log. If disabled, standard Rust log handling applies.

pub fn with_file_checkpointer<FR>(self, recorder: FR) -> Self
where FR: FileRecorder<B> + 'static + FileRecorder<B::InnerBackend>, O::Record: 'static, M::Record: 'static, S::Record<B>: 'static,

Register a file checkpointer that saves the model, the optimizer, and the scheduler to separate files.

pub fn summary(self) -> Self

Enable the training summary report.

The summary will be displayed after .fit(), when the renderer is dropped.

pub fn build(self, model: M, optim: O, lr_scheduler: S) -> Learner<
    LearnerComponentsMarker<
        B,
        S,
        M,
        O,
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record<B>, B>,
        AsyncProcessorTraining<FullEventProcessorTraining<TO, VO>>,
        Box<dyn CheckpointingStrategy>,
        LearningDataMarker<TI, VI, TO, VO>,
    >,
>
where
    M::Record: 'static,
    O::Record: 'static,
    S::Record<B>: 'static,

Create the learner from a model, an optimizer, and a learning rate scheduler. A plain learning rate can be passed in place of a scheduler.
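Passing a plain learning rate works when the scheduler trait is also implemented for the number type, so that a constant value acts as a schedule that never changes. The std-only sketch below models that idea. Scheduler, ExponentialDecay and first_rates are hypothetical names, not the LrScheduler API.

```rust
/// Hypothetical scheduler trait: yields the learning rate for the next step.
trait Scheduler {
    fn step(&mut self) -> f64;
}

/// A plain number acts as a constant schedule.
impl Scheduler for f64 {
    fn step(&mut self) -> f64 {
        *self
    }
}

/// A decaying schedule for comparison: multiplies the rate by `gamma` each step.
struct ExponentialDecay {
    lr: f64,
    gamma: f64,
}

impl Scheduler for ExponentialDecay {
    fn step(&mut self) -> f64 {
        let current = self.lr;
        self.lr *= self.gamma;
        current
    }
}

/// Generic over the scheduler type, so it accepts a bare `f64` or a real schedule.
fn first_rates<S: Scheduler>(mut scheduler: S, steps: usize) -> Vec<f64> {
    (0..steps).map(|_| scheduler.step()).collect()
}
```

Because the consuming function is generic over the scheduler, no wrapper type is needed for the constant case.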

Auto Trait Implementations

impl<B, M, O, S, TI, VI, TO, VO> Freeze for LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where <B as Backend>::Device: Freeze,

impl<B, M, O, S, TI, VI, TO, VO> !RefUnwindSafe for LearnerBuilder<B, M, O, S, TI, VI, TO, VO>

impl<B, M, O, S, TI, VI, TO, VO> !Send for LearnerBuilder<B, M, O, S, TI, VI, TO, VO>

impl<B, M, O, S, TI, VI, TO, VO> !Sync for LearnerBuilder<B, M, O, S, TI, VI, TO, VO>

impl<B, M, O, S, TI, VI, TO, VO> Unpin for LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where <B as Backend>::Device: Unpin, TI: Unpin, VI: Unpin, TO: Unpin, VO: Unpin,

impl<B, M, O, S, TI, VI, TO, VO> !UnwindSafe for LearnerBuilder<B, M, O, S, TI, VI, TO, VO>

Blanket Implementations

impl<T> Adaptor<()> for T

fn adapt(&self)

Adapt the type to be passed to a metric.

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V