Crate burn_train

A library for training neural networks using the burn crate.

Modules

checkpoint
The checkpoint module.
logger
The logger module.
metric
The metric module.
renderer
Renderer modules to display metrics and training information.
train
The trainer module.

Structs

AgentEnvAsyncLoop (feature rl)
An asynchronous agent/environment interface.
AgentEnvBaseLoop (feature rl)
A simple, synchronized agent/environment interface.
AsyncAgentEnvLoopConfig (feature rl)
Configuration for an async agent/environment loop.
AsyncProcessorEvaluation
Event processor for the model evaluation.
AsyncProcessorTraining
Event processor for the training process.
ClassificationOutput
Simple classification output adapted for multiple metrics.
EpisodeSummary (feature rl)
Summary of an episode.
EvaluationItem
An evaluation item.
Evaluator
Evaluates a model on a specific dataset.
EvaluatorBuilder
Struct to configure and create an evaluator.
FileApplicationLoggerInstaller
This struct is used to install a local file application logger to output logs to a given file path.
Interrupter
A handle that allows aborting the training/evaluation process early.
Learner
Learner struct encapsulating all components necessary to train a Neural Network model.
LearnerSummary
Detailed training summary.
LearnerSummaryConfig
Learning summary config.
LearningCheckpointer
Used to create, delete, or load checkpoints of the training process.
LearningComponentsMarker
Concrete type that implements the LearningComponentsTypes trait.
LearningResult
The result of training, containing the model along with the renderer.
MetricEarlyStoppingStrategy
An early stopping strategy based on a metric collected during training or validation.
MetricEntry
Contains the metric value at a given time.
MetricSummary
Contains the summary of recorded values for a given metric.
MultiAgentEnvLoop (feature rl)
An asynchronous runner for multiple agent/environment interfaces.
MultiLabelClassificationOutput
Multi-label classification output adapted for multiple metrics.
OffPolicyConfig (feature rl)
Parameters for off-policy training with multiple environments and double-batching.
OffPolicyStrategy (feature rl)
Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.
RLCheckpointer (feature rl)
Used to create, delete, or load checkpoints of the training process.
RLComponents (feature rl)
Struct to minimise parameters passed to RLStrategy::train.
RLComponentsMarker (feature rl)
Concrete type that implements the RLComponentsTypes trait.
RLResult (feature rl)
The result of reinforcement learning, containing the final policy along with the renderer.
RLTraining (feature rl)
Structure to configure and launch reinforcement learning trainings.
RegressionOutput
Regression output adapted for the loss metric.
SequenceOutput
Sequence prediction output adapted for multiple metrics.
SummaryMetrics
Contains the summary of recorded metrics for the training and validation steps.
SupervisedTraining
Structure to configure and launch supervised learning trainings.
TimeStep (feature rl)
A timestep describing an iteration of the state/decision process.
TrainOutput
A training output.
TrainingComponents
Struct to minimise parameters passed to SupervisedLearningStrategy::train. These components are used during training.
TrainingItem
A learner item.
Trajectory (feature rl)
A trajectory, i.e. an ordered list of TimeStep values.
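The Interrupter is described only as a handle that allows aborting training or evaluation early. A common way to build such a handle in Rust is a shared atomic flag that the loop polls between steps. The sketch below shows that idea in plain Rust; it is not burn_train's implementation, and StopHandle and train_loop are hypothetical names.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for an interrupt handle: a cheaply clonable
/// flag shared between the training loop and whoever may want to stop it.
#[derive(Clone, Default)]
struct StopHandle {
    flag: Arc<AtomicBool>,
}

impl StopHandle {
    fn new() -> Self {
        Self::default()
    }

    /// Request an early stop; every clone of the handle observes it.
    fn stop(&self) {
        self.flag.store(true, Ordering::Relaxed);
    }

    fn should_stop(&self) -> bool {
        self.flag.load(Ordering::Relaxed)
    }
}

/// Toy loop that checks the handle before each step and returns how many
/// steps it completed.
fn train_loop(handle: &StopHandle, max_steps: usize) -> usize {
    let mut completed = 0;
    for _ in 0..max_steps {
        if handle.should_stop() {
            break;
        }
        // One training step would run here.
        completed += 1;
    }
    completed
}
```

Because the flag lives behind an Arc, a clone of the handle can be moved into another thread (for example a UI or signal handler) while the loop keeps the original.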

Enums

EvaluatorEvent
Event happening during the evaluation process.
ExecutionStrategy
Describes where training runs.
LearnerEvent
Event happening during the training/validation process.
MultiDeviceOptim
Determines how the optimization is performed when training with multiple devices.
RLStrategies (feature rl)
The strategy for reinforcement learning.
StoppingCondition
The condition that early stopping strategies should follow.
TrainingStrategy
How the learner should run the learning for the model.
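MetricEarlyStoppingStrategy checks a metric collected during training or validation against a StoppingCondition. The sketch below illustrates one plausible condition, "stop after N epochs without improvement", in plain Rust. It is not burn_train's API: Direction, StopRule, and EarlyStopper are hypothetical names, and the exact conditions the crate offers may differ.

```rust
/// Whether a lower or a higher metric value counts as an improvement.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Direction {
    Lowest,
    Highest,
}

/// Hypothetical stopping condition: stop once the monitored metric has
/// not improved for `n_epochs` consecutive epochs.
#[derive(Clone, Copy, Debug, PartialEq)]
enum StopRule {
    NoImprovementSince { n_epochs: usize },
}

struct EarlyStopper {
    direction: Direction,
    rule: StopRule,
    best: Option<f64>,
    epochs_since_best: usize,
}

impl EarlyStopper {
    fn new(direction: Direction, rule: StopRule) -> Self {
        Self { direction, rule, best: None, epochs_since_best: 0 }
    }

    /// Record one epoch's metric value; returns `true` if training should stop.
    fn should_stop(&mut self, value: f64) -> bool {
        let improved = match (self.best, self.direction) {
            (None, _) => true,
            (Some(best), Direction::Lowest) => value < best,
            (Some(best), Direction::Highest) => value > best,
        };
        if improved {
            self.best = Some(value);
            self.epochs_since_best = 0;
        } else {
            self.epochs_since_best += 1;
        }
        match self.rule {
            StopRule::NoImprovementSince { n_epochs } => self.epochs_since_best >= n_epochs,
        }
    }
}
```

Keeping the condition as an enum means new rules can be added as variants without changing how the strategy is called from the training loop.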

Traits

AgentEnvLoop (feature rl)
Trait for a structure that implements an agent/environment interface.
AgentMetricRegistration (feature rl)
Trait to fake variadic generics for train step metrics.
AgentTextMetricRegistration (feature rl)
Trait to fake variadic generics for train step text metrics.
ApplicationLoggerInstaller
This trait is used to install an application logger.
CloneEarlyStoppingStrategy
A helper trait to provide type-erased cloning.
EarlyStoppingStrategy
A strategy that checks if the training should be stopped.
EpisodeMetricRegistration (feature rl)
Trait to fake variadic generics for episode metrics.
EpisodeTextMetricRegistration (feature rl)
Trait to fake variadic generics for episode text metrics.
EvalMetricRegistration
Trait to fake variadic generics.
EvalTextMetricRegistration
Trait to fake variadic generics.
EventProcessorEvaluation
Process events happening during evaluation.
EventProcessorTraining
Process events happening during training and validation.
InferenceStep
Trait to be implemented for validating models.
ItemLazy
Items that are lazy are not ready to be processed by metrics.
LearningComponentsTypes
Components used for a model to learn, grouped in one trait.
MetricRegistration
Trait to fake variadic generics.
RLComponentsTypes (feature rl)
All components used by the reinforcement learning paradigm, grouped in one trait.
RLStrategy (feature rl)
Provides the fit function for any reinforcement learning strategy.
SupervisedLearningStrategy
Provides the fit function for any supervised learning strategy.
TextMetricRegistration
Trait to fake variadic generics.
TrainMetricRegistration (feature rl)
Trait to fake variadic generics for env step metrics.
TrainStep
Trait to be implemented for models to be able to be trained.
TrainTextMetricRegistration (feature rl)
Trait to fake variadic generics for env step text metrics.

Type Aliases

CustomLearningStrategy
A reference to an implementation of SupervisedLearningStrategy.
CustomRLStrategy (feature rl)
A reference to an implementation of RLStrategy.
LearnerModelRecord
The record of the learner’s model.
LearnerOptimizerRecord
The record of the optimizer.
LearnerSchedulerRecord
The record of the LR scheduler.
RLAgentRecord (feature rl)
The record of the learning agent.
RLEventProcessorType (feature rl)
The event processor type for reinforcement learning.
RLPolicyRecord (feature rl)
The record of the policy.
SupervisedTrainingEventProcessor
The event processor type for supervised learning.
TrainLoader
A reference to the training split DataLoader.
TrainingBackend
The training backend.
TrainingModel
The model used for training.
ValidLoader
A reference to the validation split DataLoader.
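Traits such as LearningComponentsTypes group every type a learner needs into one trait, with a marker struct (LearningComponentsMarker) supplying concrete types, and aliases such as TrainingModel or TrainingBackend can then name individual pieces. The sketch below shows that grouping-plus-projection pattern in plain Rust, assuming the aliases work roughly this way; Components, Marker, ModelOf, BackendOf, and the concrete types are hypothetical and do not reflect burn_train's actual definitions.

```rust
use std::marker::PhantomData;

/// Hypothetical trait grouping the types a learner needs, so generic
/// code carries one type parameter instead of several.
trait Components {
    type Backend;
    type Model;
    type Optimizer;
}

/// Zero-sized marker that selects concrete types for the group.
struct Marker<B, M, O>(PhantomData<(B, M, O)>);

impl<B, M, O> Components for Marker<B, M, O> {
    type Backend = B;
    type Model = M;
    type Optimizer = O;
}

/// Aliases that project a single associated type out of the group.
type ModelOf<C> = <C as Components>::Model;
type BackendOf<C> = <C as Components>::Backend;

// Hypothetical concrete types.
struct CpuBackend;
#[derive(Debug, PartialEq)]
struct Mlp {
    layers: usize,
}
struct Sgd;

type MyComponents = Marker<CpuBackend, Mlp, Sgd>;

/// Generic code only needs the single `C` parameter to name the model type.
fn build_model<C: Components, F: Fn() -> ModelOf<C>>(make: F) -> ModelOf<C> {
    make()
}
```

With this layout, `build_model::<MyComponents, _>(|| Mlp { layers: 3 })` returns an `Mlp`, and swapping the backend or optimizer only changes the `MyComponents` alias, not every generic signature.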