Crate candle_optimisers

Optimisers for use with the candle framework for lightweight machine learning. Apart from LBFGS, these all implement the candle_nn::optim::Optimizer trait from candle-nn.

Example

Training an MNIST model using the Adam optimiser

    use candle_core::D;
    use candle_nn::{loss, ops, Optimizer};
    // Adam and ParamsAdam are assumed to come from this crate's adam module
    use candle_optimisers::adam::{Adam, ParamsAdam};

    // `model`, `varmap`, `train_images` and `train_labels` are assumed to have
    // been set up beforehand, as in the candle MNIST examples

    // create the Adam optimiser:
    // set the learning rate to 0.004 and use the default parameters for everything else
    let params = ParamsAdam {
        lr: 0.004,
        ..Default::default()
    };
    // create the optimiser by passing in the variables to be optimised and the parameters
    let mut optimiser = Adam::new(varmap.all_vars(), params)?;

    // loop for model optimisation
    for _epoch in 0..100 {
        // run the model forwards to get the logits
        let logits = model.forward(&train_images)?;
        // apply log-softmax to obtain log probabilities
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        // compute the negative log-likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;
        // step the tensors by backpropagating the loss
        optimiser.backward_step(&loss)?;
    }
    Ok(())

Modules

  • Adadelta optimiser
  • Adagrad optimiser
  • Adam optimiser (including AdamW)
  • Adamax optimiser
  • Stochastic Gradient Descent
  • Limited-memory Broyden–Fletcher–Goldfarb–Shanno (LBFGS) algorithm
  • NAdam optimiser: Adam with Nesterov momentum (see the sketch after this list)
  • RAdam optimiser
  • RMSProp algorithm
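
All of the optimisers above follow the same construction pattern as the Adam example: fill in a parameter struct, then build the optimiser from the variables to be trained. Below is a rough sketch for NAdam, assuming its module exports NAdam and ParamsNAdam analogously to adam::{Adam, ParamsAdam}; check the module docs for the exact names.

    use candle_nn::Optimizer;
    // assumed names: the nadam module is taken to export `NAdam` and `ParamsNAdam`,
    // mirroring `adam::{Adam, ParamsAdam}` from the example above
    use candle_optimisers::nadam::{NAdam, ParamsNAdam};

    // as before, `varmap` holds the model's trainable variables
    let params = ParamsNAdam {
        lr: 0.004,
        ..Default::default()
    };
    let mut optimiser = NAdam::new(varmap.all_vars(), params)?;
    // the training loop itself is unchanged, since `backward_step` comes from
    // the shared candle_nn::optim::Optimizer trait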

Enums

  • Method of weight decay to use (sketched below)
  • Outcomes of an optimiser step for methods such as LBFGS
  • Type of momentum to use
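
The weight-decay enum is how AdamW-style behaviour is selected: decay is a field on the parameter struct rather than a separate optimiser. A minimal sketch, assuming the enum is named Decay with WeightDecay and DecoupledWeightDecay variants and that ParamsAdam carries an optional weight_decay field; the exact names should be checked against the enum and struct docs.

    // assumed names: a crate-level `Decay` enum and an `Option<Decay>`
    // `weight_decay` field on `ParamsAdam`
    use candle_optimisers::{adam::ParamsAdam, Decay};

    let params = ParamsAdam {
        lr: 0.004,
        // decoupled decay for AdamW-style updates; the `WeightDecay` variant is
        // assumed to apply conventional L2-style decay to the gradients instead
        weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
        ..Default::default()
    };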

Traits

  • Trait for optimisers, such as LBFGS, that need to calculate the loss and its gradient
  • Trait for models: needed by optimisers, such as LBFGS, that must be able to evaluate the loss (see the sketch after this list)
  • Trait for optimisers to expose their parameters
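
For LBFGS, which must re-evaluate the loss during its line search, the model itself is handed to the optimiser through these traits rather than just a loss tensor. A rough sketch of the model-side trait, assuming it requires only a single loss method returning a candle Result<Tensor>; the real signature may differ, so consult the trait docs.

    use candle_core::{Result, Tensor};

    // hypothetical model type, used only for illustration
    struct Mlp {
        // layers, plus whatever training data is needed to evaluate the loss
    }

    // assumed trait shape: one `loss` method, so LBFGS can re-evaluate the
    // objective (and its gradient via backprop) during the line search;
    // check the trait documentation for the actual required methods
    impl candle_optimisers::Model for Mlp {
        fn loss(&self) -> Result<Tensor> {
            // forward pass over the stored training batch and loss computation
            todo!()
        }
    }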