candle_optimisers/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/*!
Optimisers for use with the [candle](https://github.com/huggingface/candle) framework for lightweight machine learning.
Apart from LBFGS, these all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn.

# Example

Training an MNIST model using the Adam optimiser

```
# use candle_core::{Result, Tensor};
# use candle_core::{DType, D};
# use candle_nn::{loss, ops, VarBuilder, VarMap, optim::Optimizer};
# use candle_optimisers::{
#     adam::{Adam, ParamsAdam}
# };
#
# pub trait Model: Sized {
#     fn new(vs: VarBuilder) -> Result<Self>;
#     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
# }
#
# pub fn training_loop<M: Model>(
#     m: candle_datasets::vision::Dataset,
#     varmap: &VarMap,
#     model: M,
# ) -> anyhow::Result<()> {
#     // check whether a CUDA device is available
#     let dev = candle_core::Device::cuda_if_available(0)?;
#     // get the input from the dataset and put on device
#     let train_images = m.train_images.to_device(&dev)?;
#     // get the training labels on the device
#     let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
#
#
#     // load the test images
#     let test_images = m.test_images.to_device(&dev)?;
#     // load the test labels
#     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
#
    // create the Adam optimiser

    // set the learning rate to 0.004 and use the default parameters for everything else
    let params = ParamsAdam {
            lr: 0.004,
            ..Default::default()
        };
    // create the optimiser by passing in the variables to be optimised and the parameters
    let mut optimiser = Adam::new(varmap.all_vars(),  params)?;

    // loop for model optimisation
    for epoch in 0..100 {
        // run the model forwards
        // get the logits (unnormalised scores) for the training images
        let logits = model.forward(&train_images)?;
        // apply log-softmax to turn the logits into log probabilities
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        // get the loss
        let loss = loss::nll(&log_sm, &train_labels)?;
        // step the tensors by backpropagating the loss
        optimiser.backward_step(&loss)?;

        # // get the log probabilities of the test images
        # let test_logits = model.forward(&test_images)?;
        # // get the sum of the correct predictions
        # let sum_ok = test_logits
        #     .argmax(D::Minus1)?
        #     .eq(&test_labels)?
        #     .to_dtype(DType::F32)?
        #     .sum_all()?
        #     .to_scalar::<f32>()?;
        # // get the accuracy on the test set
        # #[allow(clippy::cast_precision_loss)]
        # let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        # println!(
        #     "{:4} train loss: {:8.5} test acc: {:5.2}%",
        #     epoch + 1,
        #     loss.to_scalar::<f32>()?,
        #     100. * test_accuracy
        # );
    }
    Ok(())
# }
```
*/

use std::fmt::Debug;

use candle_core::Result as CResult;
use candle_core::Tensor;
use candle_core::Var;
pub mod adadelta;
pub mod adagrad;
pub mod adam;
pub mod adamax;
pub mod esgd;
pub mod lbfgs;
pub mod nadam;
pub mod radam;
pub mod rmsprop;

/// Trait for optimisers to expose their parameters.
///
/// This trait extends [`candle_nn::optim::Optimizer`] and declares no associated
/// types of its own, so `Self::Config` in the methods below is the configuration
/// type provided by that supertrait.
pub trait OptimParams: candle_nn::optim::Optimizer {
    /// Get a reference to the current parameters of the optimiser.
    fn params(&self) -> &Self::Config;
    /// Set the current parameters of the optimiser to `config`.
    fn set_params(&mut self, config: Self::Config);
}

/// Trait for models: this is needed for optimisers that require the ability to
/// calculate the loss themselves, such as LBFGS.
///
/// `loss` receives nothing but `&self`.
// NOTE(review): presumably the implementing type therefore holds (or can reach)
// the inputs and targets it needs — confirm against implementors.
pub trait Model: Sized {
    /// Calculate the loss of the model.
    ///
    /// # Errors
    ///
    /// Returns a candle error (`CResult`) if the loss cannot be computed.
    fn loss(&self) -> CResult<Tensor>; //, xs: &Tensor, ys: &Tensor
}

/// Trait for optimisers, such as LBFGS, that need the ability to calculate the
/// loss and its gradient.
///
/// An implementor is constructed together with a [`Model`] `M`, in addition to
/// the variables to optimise and its configuration.
pub trait LossOptimizer<M: Model>: Sized {
    /// Type of the optimiser configuration.
    type Config: Sized;
    /// Create a new optimiser from a `Vec` of variables, setup parameters and a model.
    fn new(vs: Vec<Var>, params: Self::Config, model: M) -> CResult<Self>;
    /// Take a step of the optimiser, given the current `loss`.
    ///
    /// On success, the returned [`ModelOutcome`] reports whether the optimiser
    /// stepped or converged.
    fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome>; //, xs: &Tensor, ys: &Tensor
    /// Get the current learning rate.
    fn learning_rate(&self) -> f64;
    /// Set the learning rate.
    fn set_learning_rate(&mut self, lr: f64);
    /// Consume the optimiser and return a `Vec` of the variables being optimised.
    fn into_inner(self) -> Vec<Var>;
    /// Create a new optimiser from a slice of variable references, setup
    /// parameters and a model.
    ///
    /// Every referenced variable is cloned into an owned `Vec`, which is then
    /// passed on to `Self::new` together with `config` and `model`.
    fn from_slice(vars: &[&Var], config: Self::Config, model: M) -> CResult<Self> {
        // `copied` peels the outer reference (`&&Var` -> `&Var`); `cloned` then
        // produces the owned `Var` values that `new` expects.
        let owned_vars: Vec<Var> = vars.iter().copied().cloned().collect();
        Self::new(owned_vars, config, model)
    }
}

/// Outcomes of an optimiser step for methods such as LBFGS.
///
/// Both variants carry a loss tensor followed by the number of function
/// evaluations.
#[derive(Debug)]
pub enum ModelOutcome {
    /// The model took a step and the loss decreased.
    ///
    /// Contains the next loss and the number of function evaluations.
    Stepped(Tensor, usize),
    /// The model has converged and the loss has not changed.
    ///
    /// Contains the loss and the number of function evaluations.
    Converged(Tensor, usize),
}

/// Method of weight decay to use.
// NOTE(review): the `f64` payload of each variant is presumably the decay
// coefficient $\lambda$ used in the formulas below — confirm against the
// optimiser implementations.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Decay {
    /// Weight decay regularisation to penalise large weights.
    ///
    /// The gradient is transformed as
    /// $$ g_{t} \\gets g_{t} + \\lambda  \\theta_{t-1}$$
    ///
    /// This is equivalent to adding an L2 regularisation term $\\frac{\\lambda}{2}||\\theta||_{2}^{2}$ to the loss, but avoids autodifferentiation
    /// of the L2 term.
    WeightDecay(f64),
    /// Decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
    ///
    /// This directly decays the weights as
    ///
    /// $$ \\theta_{t} \\gets (1 - \\eta \\lambda) \\theta_{t-1}$$
    ///
    /// This is equivalent to L2 regularisation only for SGD without momentum; for adaptive gradient methods the two differ.
    DecoupledWeightDecay(f64),
}

/// Type of momentum to use.
// NOTE(review): the `f64` payload of each variant is presumably the momentum
// coefficient — confirm against the optimisers that consume this enum.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub enum Momentum {
    /// Classical momentum.
    Classical(f64),
    /// Nesterov momentum.
    Nesterov(f64),
}