Struct SupervisedTraining

pub struct SupervisedTraining<LC>{ /* private fields */ }

Structure to configure and launch supervised training runs.

Implementations

impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
where B: AutodiffBackend, LR: LrScheduler + 'static, M: TrainStep + AutodiffModule<B> + Display + 'static, M::InnerModule: InferenceStep, O: Optimizer<M, B> + 'static,

pub fn new( directory: impl AsRef<Path>, dataloader_train: Arc<dyn DataLoader<B, M::Input>>, dataloader_valid: Arc<dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>>, ) -> Self

Creates a new runner for supervised training.

Arguments
  • directory - The directory where checkpoints are saved.
  • dataloader_train - The dataloader for the training split.
  • dataloader_valid - The dataloader for the validation split.

impl<LC: LearningComponentsTypes> SupervisedTraining<LC>

pub fn with_training_strategy( self, training_strategy: TrainingStrategy<LC>, ) -> Self

Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.

Arguments
  • training_strategy - The training strategy.

pub fn with_metric_logger<ML>(self, logger: ML) -> Self
where ML: MetricLogger + 'static,

Replace the default metric loggers with the provided logger.

Arguments
  • logger - The metric logger.
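The setters on this type consume the builder and return it (`self -> Self`), so calls can be chained. The sketch below illustrates that pattern and the documented "replace the defaults" behavior using hypothetical `Logger`, `DefaultLogger`, and `Builder` types; it is not Burn's `MetricLogger` trait or the actual internals of `SupervisedTraining`. The `'static` bound matters because a value stored as `Box<dyn Trait>` must not borrow short-lived data.

```rust
/// Hypothetical logger trait used only for this sketch (not Burn's `MetricLogger`).
pub trait Logger {
    fn name(&self) -> String;
}

/// Hypothetical logger installed when the user provides none.
struct DefaultLogger;

impl Logger for DefaultLogger {
    fn name(&self) -> String {
        "default".to_string()
    }
}

/// Hypothetical consuming builder mirroring the `self -> Self` setter style.
pub struct Builder {
    loggers: Vec<Box<dyn Logger>>,
    num_epochs: usize,
}

impl Builder {
    pub fn new() -> Self {
        Self {
            loggers: vec![Box::new(DefaultLogger)],
            num_epochs: 1,
        }
    }

    /// Replaces the default loggers with the provided one and returns the builder.
    /// `L: 'static` is required because the logger is stored as `Box<dyn Logger>`,
    /// which means `Box<dyn Logger + 'static>`.
    pub fn with_logger<L>(mut self, logger: L) -> Self
    where
        L: Logger + 'static,
    {
        let boxed: Box<dyn Logger> = Box::new(logger);
        self.loggers = vec![boxed];
        self
    }

    /// Another chained setter, to show how calls compose.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    /// Names of the currently registered loggers.
    pub fn logger_names(&self) -> Vec<String> {
        self.loggers.iter().map(|l| l.name()).collect()
    }
}
```

Because every setter hands back the updated builder, a whole configuration reads as one expression, such as `Builder::new().with_logger(my_logger).num_epochs(10)`.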

pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>( self, strategy: CS, ) -> Self

Set the checkpointing strategy.

Arguments
  • strategy - The checkpointing strategy.
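To illustrate the kind of decision a checkpointing strategy can make, here is a hypothetical "keep the last n checkpoints" policy. It assumes that after each epoch a strategy chooses whether to save a checkpoint and which older one to delete; the `CheckpointAction` and `KeepLastN` names are invented for this sketch and are not the `CheckpointingStrategy` trait's actual interface.

```rust
/// Hypothetical action a strategy returns after an epoch finishes.
#[derive(Debug, PartialEq)]
pub enum CheckpointAction {
    Save,
    Delete(usize),
}

/// Hypothetical policy: keep only the checkpoints of the last `n` epochs.
pub struct KeepLastN {
    pub n: usize,
}

impl KeepLastN {
    /// Actions to take once `epoch` (counted from 1) has finished.
    pub fn actions(&self, epoch: usize) -> Vec<CheckpointAction> {
        // Always save the checkpoint for the epoch that just ended.
        let mut actions = vec![CheckpointAction::Save];
        // After saving, more than `n` checkpoints would exist, so drop the oldest kept one.
        if epoch > self.n {
            actions.push(CheckpointAction::Delete(epoch - self.n));
        }
        actions
    }
}
```

With `n = 2`, the policy saves epochs 1 and 2, then at epoch 3 saves and deletes epoch 1, leaving checkpoints for epochs 2 and 3.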

pub fn renderer<MR>(self, renderer: MR) -> Self
where MR: MetricsRenderer + 'static,

Replace the default CLI renderer with a custom one.

Arguments
  • renderer - The custom renderer.

pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self

Register all metrics as numeric for the training and validation sets.

pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self

Register all metrics as text for the training and validation sets.

pub fn metric_train<Me: Metric + 'static>(self, metric: Me) -> Self

Register a training metric.

pub fn metric_valid<Me: Metric + 'static>(self, metric: Me) -> Self

Register a validation metric.

pub fn grads_accumulation(self, accumulation: usize) -> Self

Enable gradient accumulation.

Notes

When gradient accumulation is enabled, the gradients passed to the optimizer are the sum of the gradients produced by each backward pass. It might be a good idea to reduce the learning rate to compensate.

The effect is similar to increasing the batch size and the learning rate by the accumulation amount.
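The equivalence in the note can be checked numerically. This sketch assumes each mini-batch loss is a mean over its samples and that all accumulated batches have the same size; it uses a one-parameter model `y ≈ w * x` with a squared-error loss and hypothetical helpers (`batch_grad`, `accumulated_grad`, `sgd_step`), not Burn's optimizer API. Summing the gradients of `k` batches gives `k` times the mean gradient over the merged batch, so an SGD step with learning rate `lr / k` matches a single large-batch step with `lr`.

```rust
/// Gradient d/dw of the mean squared error of `y ≈ w * x` over one mini-batch.
pub fn batch_grad(w: f64, batch: &[(f64, f64)]) -> f64 {
    let sum: f64 = batch.iter().map(|&(x, y)| 2.0 * x * (w * x - y)).sum();
    sum / batch.len() as f64
}

/// Sums the per-batch gradients, as accumulation does before one optimizer step.
pub fn accumulated_grad(w: f64, batches: &[Vec<(f64, f64)>]) -> f64 {
    batches.iter().map(|b| batch_grad(w, b)).sum()
}

/// One plain SGD update.
pub fn sgd_step(w: f64, grad: f64, lr: f64) -> f64 {
    w - lr * grad
}
```

For two batches `[(1, 2), (2, 3)]` and `[(3, 5), (4, 9)]` at `w = 0.5`, the per-batch gradients are -5.5 and -38.5, so the accumulated gradient is -44.0, exactly twice the merged-batch gradient of -22.0.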

pub fn metric_train_numeric<Me>(self, metric: Me) -> Self

Register a numeric training metric.

pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>( self, metric: Me, ) -> Self

Register a numeric validation metric.

pub fn num_epochs(self, num_epochs: usize) -> Self

The number of epochs the training should last.

pub fn checkpoint(self, checkpoint: usize) -> Self

The epoch from which the training must resume.
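A small sketch of how `num_epochs` and `checkpoint` could combine, assuming `checkpoint(n)` restores the state saved at the end of epoch `n` and training then continues with epoch `n + 1` up to `num_epochs`, with epochs counted from 1. The `remaining_epochs` function is hypothetical and only encodes that assumption.

```rust
use std::ops::RangeInclusive;

/// Epochs that still have to run, assuming `checkpoint = Some(n)` means
/// "reload the state saved after epoch n and continue with epoch n + 1".
pub fn remaining_epochs(num_epochs: usize, checkpoint: Option<usize>) -> RangeInclusive<usize> {
    let start = checkpoint.map_or(1, |epoch| epoch + 1);
    start..=num_epochs
}
```

Under this assumption, resuming from the last epoch yields an empty range, so no further training happens.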

pub fn interrupter(&self) -> Interrupter

Provides a handle that can be used to interrupt training.

pub fn with_interrupter(self, interrupter: Interrupter) -> Self

Override the handle for stopping training with an externally provided handle.
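An interrupt handle is shared state: anyone holding a copy can ask the training loop to stop, and the loop checks it periodically. This sketch models that idea with a hypothetical `StopHandle` built on `Arc<AtomicBool>`; it is not Burn's `Interrupter` implementation, and the toy `run` function stands in for the real training loop. Supplying an externally created handle, as `with_interrupter` allows, lets code outside the training (for example a signal handler or a UI thread) keep its own clone.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stop handle: every clone shares the same flag.
#[derive(Clone, Default)]
pub struct StopHandle {
    flag: Arc<AtomicBool>,
}

impl StopHandle {
    pub fn new() -> Self {
        Self::default()
    }

    /// Request that training stop at its next check.
    pub fn stop(&self) {
        self.flag.store(true, Ordering::Relaxed);
    }

    /// Whether a stop has been requested through any clone.
    pub fn should_stop(&self) -> bool {
        self.flag.load(Ordering::Relaxed)
    }
}

/// Toy training loop: checks the handle before each iteration and
/// returns how many iterations ran.
pub fn run(max_iterations: usize, handle: &StopHandle) -> usize {
    let mut completed = 0;
    while completed < max_iterations && !handle.should_stop() {
        completed += 1; // one unit of training work would happen here
    }
    completed
}
```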

pub fn early_stopping<Strategy>(self, strategy: Strategy) -> Self
where Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,

Register an early stopping strategy that stops the training when its conditions are met.
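A common early stopping rule is patience-based: stop once the monitored metric has not improved for a given number of epochs. The sketch below implements that rule for a lower-is-better metric such as validation loss. `PatienceStopper` and its `should_stop` method are hypothetical and do not reflect the real methods of the `EarlyStoppingStrategy` trait; the type derives `Clone` (and is automatically `Send + Sync`) only to mirror the bounds on `early_stopping`.

```rust
/// Hypothetical patience-based rule for a lower-is-better metric.
#[derive(Clone, Debug)]
pub struct PatienceStopper {
    patience: usize,
    best: f64,
    best_epoch: usize,
}

impl PatienceStopper {
    pub fn new(patience: usize) -> Self {
        Self {
            patience,
            best: f64::INFINITY,
            best_epoch: 0,
        }
    }

    /// Feed the metric observed at the end of `epoch` (counted from 1);
    /// returns `true` when training should stop.
    pub fn should_stop(&mut self, epoch: usize, value: f64) -> bool {
        if value < self.best {
            self.best = value;
            self.best_epoch = epoch;
        }
        // Number of epochs since the last improvement.
        epoch.saturating_sub(self.best_epoch) >= self.patience
    }
}
```

With a patience of 2 and losses 1.0, 0.8, 0.9, 0.85, the best value is reached at epoch 2, and the rule asks to stop at epoch 4 after two epochs without improvement.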

pub fn with_application_logger( self, logger: Option<Box<dyn ApplicationLoggerInstaller>>, ) -> Self

By default, Rust logs are captured and written into experiment.log. Pass None to disable this, in which case standard Rust log handling applies.

pub fn with_file_checkpointer<FR>(self, recorder: FR) -> Self

Register a checkpointer that will save the optimizer, the model and the scheduler to different files.

pub fn summary(self) -> Self

Enable the training summary report.

The summary will be displayed after training completes, when the renderer is dropped.

impl<LC: LearningComponentsTypes + Send + 'static> SupervisedTraining<LC>

pub fn launch( self, learner: Learner<LC>, ) -> LearningResult<<LC as LearningComponentsTypes>::InferenceModel>

Launch this training with the given Learner. Returns a LearningResult parameterized by the learner's InferenceModel type.

Blanket Implementations

impl<T> Adaptor<()> for T

fn adapt(&self)

Adapt the type to be passed to a metric.

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V