candle_optimisers/lib.rs

/*!
Optimisers for use with the [candle](https://github.com/huggingface/candle) framework for lightweight machine learning.
Apart from LBFGS, these all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn.

# Example

Training an MNIST model using the Adam optimiser

```
# use candle_core::{Result, Tensor};
# use candle_core::{DType, D};
# use candle_nn::{loss, ops, VarBuilder, VarMap, optim::Optimizer};
# use candle_optimisers::{
#     adam::{Adam, ParamsAdam}
# };
#
# pub trait Model: Sized {
#     fn new(vs: VarBuilder) -> Result<Self>;
#     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
# }
#
# pub fn training_loop<M: Model>(
#     m: candle_datasets::vision::Dataset,
#     varmap: &VarMap,
#     model: M,
# ) -> anyhow::Result<()> {
#     // check if a CUDA device is available, otherwise fall back to the CPU
#     let dev = candle_core::Device::cuda_if_available(0)?;
#     // get the training images from the dataset and put them on the device
#     let train_images = m.train_images.to_device(&dev)?;
#     // get the training labels on the device
#     let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
#
#     // load the test images
#     let test_images = m.test_images.to_device(&dev)?;
#     // load the test labels
#     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
#
    // create the Adam optimiser:
    // set the learning rate to 0.004 and use the default parameters for everything else
    let params = ParamsAdam {
        lr: 0.004,
        ..Default::default()
    };
    // create the optimiser by passing in the variables to be optimised and the parameters
    let mut optimiser = Adam::new(varmap.all_vars(), params)?;

    // loop for model optimisation
    for epoch in 0..100 {
        // run the model forward to get the logits
        let logits = model.forward(&train_images)?;
        // apply log-softmax to the logits to get log probabilities
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        // compute the negative log-likelihood loss
        let loss = loss::nll(&log_sm, &train_labels)?;
        // backpropagate the loss and take an optimiser step
        optimiser.backward_step(&loss)?;

        # // get the logits for the test images
        # let test_logits = model.forward(&test_images)?;
        # // count the correct predictions on the test set
        # let sum_ok = test_logits
        #     .argmax(D::Minus1)?
        #     .eq(&test_labels)?
        #     .to_dtype(DType::F32)?
        #     .sum_all()?
        #     .to_scalar::<f32>()?;
        # // compute the accuracy on the test set
        # #[allow(clippy::cast_precision_loss)]
        # let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        # println!(
        #     "{:4} train loss: {:8.5} test acc: {:5.2}%",
        #     epoch + 1,
        #     loss.to_scalar::<f32>()?,
        #     100. * test_accuracy
        # );
    }
    Ok(())
# }
```
*/
85
86use std::fmt::Debug;
87
88use candle_core::Result as CResult;
89use candle_core::Tensor;
90use candle_core::Var;
91pub mod adadelta;
92pub mod adagrad;
93pub mod adam;
94pub mod adamax;
95pub mod esgd;
96pub mod lbfgs;
97pub mod nadam;
98pub mod radam;
99pub mod rmsprop;
100
101/// Trait for optimisers to expose their parameters
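///
/// This makes it possible to reschedule quantities such as the learning rate
/// mid-training. A minimal sketch of halving the learning rate of an
/// [`adam::Adam`] optimiser (the `lr` field is taken from the crate-level
/// example above; that the parameter struct implements `Clone` is assumed here):
///
/// ```
/// use candle_optimisers::{adam::Adam, OptimParams};
///
/// fn halve_lr(optimiser: &mut Adam) {
///     // copy the current parameters, halve the learning rate, and set them back
///     let mut params = optimiser.params().clone();
///     params.lr /= 2.0;
///     optimiser.set_params(params);
/// }
/// ```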
102pub trait OptimParams: candle_nn::optim::Optimizer {
    /// get the current parameters of the optimiser
    fn params(&self) -> &Self::Config;
    /// set the current parameters of the optimiser
    fn set_params(&mut self, config: Self::Config);
107}
108
/// Trait for models: needed by optimisers, such as LBFGS, that must be able to
/// evaluate the loss themselves
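///
/// A minimal sketch of an implementation for a model that holds its own
/// training data (the struct and its fields are illustrative, not part of this crate):
///
/// ```
/// use candle_core::{D, Result, Tensor};
/// use candle_nn::{loss, ops, Linear, Module};
///
/// struct MnistModel {
///     layer: Linear,
///     train_images: Tensor,
///     train_labels: Tensor,
/// }
///
/// impl candle_optimisers::Model for MnistModel {
///     fn loss(&self) -> Result<Tensor> {
///         // forward pass on the stored inputs, then negative log-likelihood
///         let logits = self.layer.forward(&self.train_images)?;
///         let log_sm = ops::log_softmax(&logits, D::Minus1)?;
///         loss::nll(&log_sm, &self.train_labels)
///     }
/// }
/// ```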
111pub trait Model: Sized {
112    /// get the loss of the model
    fn loss(&self) -> CResult<Tensor>;
114}
115
/// Trait for optimisers, such as LBFGS, that need to calculate the loss
/// and its gradient themselves
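///
/// A minimal sketch of constructing and stepping such an optimiser, using the
/// [`lbfgs`] module (the `Lbfgs` and `ParamsLBFGS` names, and `ParamsLBFGS`
/// implementing `Default`, are assumed here rather than shown in this file):
///
/// ```no_run
/// use candle_core::Result;
/// use candle_nn::VarMap;
/// use candle_optimisers::{lbfgs::{Lbfgs, ParamsLBFGS}, LossOptimizer, Model, ModelOutcome};
///
/// fn single_step<M: Model>(varmap: &VarMap, model: M) -> Result<ModelOutcome> {
///     // the initial loss must be computed before the model is moved into the optimiser
///     let initial_loss = model.loss()?;
///     let mut optimiser = Lbfgs::new(varmap.all_vars(), ParamsLBFGS::default(), model)?;
///     // one optimiser step: reports whether the loss decreased or has converged
///     optimiser.backward_step(&initial_loss)
/// }
/// ```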
118pub trait LossOptimizer<M: Model>: Sized {
119    /// type of the optimiser configuration
120    type Config: Sized;
    /// create a new optimiser from a `Vec` of variables, configuration parameters and a model
122    fn new(vs: Vec<Var>, params: Self::Config, model: M) -> CResult<Self>;
123    /// take a step of the optimiser
    fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome>;
125    /// get the current learning rate
126    fn learning_rate(&self) -> f64;
127    /// set the learning rate
128    fn set_learning_rate(&mut self, lr: f64);
    /// consume the optimiser, returning a `Vec` of the variables being optimised
130    fn into_inner(self) -> Vec<Var>;
    /// create a new optimiser from a slice of variables, configuration parameters and a model
132    fn from_slice(vars: &[&Var], config: Self::Config, model: M) -> CResult<Self> {
133        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
134        Self::new(vars, config, model)
135    }
136}
137
138/// Outcomes of an optimiser step for methods such as LBFGS
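///
/// Both variants carry the latest loss tensor and the number of function
/// evaluations, so a training loop can be driven by matching on the outcome.
/// A minimal sketch, continuing from the [`LossOptimizer`] example above:
///
/// ```no_run
/// # use candle_core::{Result, Tensor};
/// # use candle_optimisers::{LossOptimizer, Model, ModelOutcome};
/// # fn run<M: Model, O: LossOptimizer<M>>(mut optimiser: O, mut loss: Tensor) -> Result<Tensor> {
/// for _ in 0..100 {
///     match optimiser.backward_step(&loss)? {
///         // the step reduced the loss: carry it into the next iteration
///         ModelOutcome::Stepped(new_loss, _evals) => loss = new_loss,
///         // the loss has stopped changing: stop iterating
///         ModelOutcome::Converged(final_loss, _evals) => {
///             loss = final_loss;
///             break;
///         }
///     }
/// }
/// # Ok(loss)
/// # }
/// ```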
139#[derive(Debug)]
140pub enum ModelOutcome {
    /// The model took a step and the loss decreased:
    /// contains the new loss and the number of function evaluations
    Stepped(Tensor, usize),
    /// The model has converged and the loss has not changed:
    /// contains the final loss and the number of function evaluations
146    Converged(Tensor, usize),
147}
148
149/// Method of weight decay to use
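///
/// Optimisers in this crate take an optional decay setting in their parameter
/// struct. A minimal sketch of AdamW-style behaviour (this assumes `ParamsAdam`
/// exposes a `weight_decay: Option<Decay>` field, which is not shown in this file):
///
/// ```no_run
/// use candle_optimisers::{adam::ParamsAdam, Decay};
///
/// // decoupled weight decay with lambda = 0.01, i.e. AdamW
/// let params = ParamsAdam {
///     weight_decay: Some(Decay::DecoupledWeightDecay(0.01)),
///     ..Default::default()
/// };
/// ```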
150#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
151pub enum Decay {
152    /// Weight decay regularisation to penalise large weights
153    ///
154    /// The gradient is transformed as
155    /// $$ g_{t} \\gets g_{t} + \\lambda  \\theta_{t-1}$$
156    ///
    /// This is equivalent to adding an L2 regularisation term $\\frac{\\lambda}{2}||\\theta||_{2}^{2}$ to the loss,
    /// but avoids autodifferentiation of the L2 term
159    WeightDecay(f64),
160    /// Decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
161    ///
162    /// This directly decays the weights as
163    ///
164    /// $$ \\theta_{t} \\gets (1 - \\eta \\lambda) \\theta_{t-1}$$
165    ///
    /// For SGD without momentum this is equivalent to standard weight decay regularisation, but the two differ for adaptive gradient methods
167    DecoupledWeightDecay(f64),
168}
169
170/// Type of momentum to use
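///
/// Used by the SGD implementation in the [`esgd`] module. A minimal sketch
/// (this assumes `ParamsSGD` exposes a `momentum: Option<Momentum>` field,
/// which is not shown in this file):
///
/// ```no_run
/// use candle_optimisers::{esgd::ParamsSGD, Momentum};
///
/// // SGD with Nesterov momentum of 0.9
/// let params = ParamsSGD {
///     momentum: Some(Momentum::Nesterov(0.9)),
///     ..Default::default()
/// };
/// ```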
171#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
172pub enum Momentum {
173    /// classical momentum
174    Classical(f64),
    /// Nesterov momentum
176    Nesterov(f64),
177}