Optimisers for use with the candle framework for lightweight machine learning.
Apart from LBFGS, these all implement the candle_nn::optim::Optimizer trait from candle-nn.
§Example
Training an MNIST model using the Adam optimiser
use candle_core::D;
use candle_nn::{loss, ops, Optimizer};
// module path assumed from the module list below
use candle_optimisers::adam::{Adam, ParamsAdam};

// `varmap` (a candle_nn::VarMap holding the model's weights), `model`,
// `train_images` and `train_labels` are assumed to have been set up already;
// the snippet is assumed to run inside a function returning candle_core::Result<()>
// (hence the final Ok(()))

// create the Adam optimiser:
// set the learning rate to 0.004 and use the default parameters for everything else
let params = ParamsAdam {
    lr: 0.004,
    ..Default::default()
};
// create the optimiser by passing in the variables to be optimised and the parameters
let mut optimiser = Adam::new(varmap.all_vars(), params)?;
// loop for model optimisation
for epoch in 0..100 {
    // run the model forwards to get the logits
    let logits = model.forward(&train_images)?;
    // take the log softmax of the logits to get log probabilities
    let log_sm = ops::log_softmax(&logits, D::Minus1)?;
    // compute the negative log-likelihood loss
    let loss = loss::nll(&log_sm, &train_labels)?;
    // backpropagate the loss and step the optimised variables
    optimiser.backward_step(&loss)?;
}
Ok(())
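Because everything except LBFGS implements the same candle_nn::optim::Optimizer trait, the optimiser in the loop above can be swapped without changing the rest of the training code. A minimal sketch of a training step that is generic over the optimiser (the helper function and its name are illustrative, not part of this crate):

use candle_core::{Result, Tensor};
use candle_nn::Optimizer;

// A single training step that works with any optimiser implementing the
// candle_nn::optim::Optimizer trait, i.e. everything in this crate except LBFGS.
fn train_step<O: Optimizer>(optimiser: &mut O, loss: &Tensor) -> Result<()> {
    // backpropagate the loss and update the tracked variables
    optimiser.backward_step(loss)
}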
Modules§
- Adadelta optimiser
- Adagrad optimiser
- Adam optimiser (including AdamW)
- Adamax optimiser
- Stochastic Gradient Descent
- Limited memory Broyden–Fletcher–Goldfarb–Shanno algorithm
- NAdam optimiser: Adam with Nesterov momentum
- RAdam optimiser
- RMS prop algorithm
Enums§
- Method of weight decay to use
- Outcomes of an optimiser step for methods such as LBFGS
- Type of momentum to use
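The weight decay enum is how the AdamW variant mentioned above is selected: decoupled weight decay turns Adam into AdamW. A hedged sketch, assuming a weight_decay field on ParamsAdam and a decay enum with a decoupled variant; the exact field and variant names are assumptions and should be checked against the adam module docs:

// Hypothetical field and variant names; see the adam module and the
// weight decay enum for the crate's actual API.
let params = ParamsAdam {
    lr: 0.001,
    // decoupled weight decay gives AdamW-style updates (assumed variant name)
    weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
    ..Default::default()
};
let mut optimiser = Adam::new(varmap.all_vars(), params)?;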
Traits§
- Trait for optimisers, such as LBFGS, that need the ability to calculate the loss and its gradient
- Trait for models: needed by optimisers, such as LBFGS, that require the ability to calculate the loss
- Trait for optimisers to expose their parameters
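LBFGS re-evaluates the loss during its line search, which is why it works through a model trait rather than the plain Optimizer interface. A hypothetical sketch of the shape of such a trait; the names Model and loss are taken from the one-line descriptions above and are not the crate's confirmed signatures:

use candle_core::{Result, Tensor};

// Illustrative only: the real trait is defined in this crate and may differ
// in name and signature; check the lbfgs module documentation.
trait Model {
    // return the current loss so the optimiser can re-evaluate it as needed
    fn loss(&self) -> Result<Tensor>;
}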