candle_optimisers/lib.rs
/*!
Optimisers for use with the [candle](https://github.com/huggingface/candle) framework for lightweight machine learning.
Apart from LBFGS, these all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn.

# Example

Training an MNIST model using the Adam optimiser

```
# use candle_core::{Result, Tensor};
# use candle_core::{DType, D};
# use candle_nn::{loss, ops, VarBuilder, VarMap, optim::Optimizer};
# use candle_optimisers::{
#     adam::{Adam, ParamsAdam}
# };
#
# pub trait Model: Sized {
#     fn new(vs: VarBuilder) -> Result<Self>;
#     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
# }
#
# pub fn training_loop<M: Model>(
#     m: candle_datasets::vision::Dataset,
#     varmap: &VarMap,
#     model: M,
# ) -> anyhow::Result<()> {
#     // use the CUDA device if one is available
#     let dev = candle_core::Device::cuda_if_available(0)?;
#     // get the input from the dataset and put on device
#     let train_images = m.train_images.to_device(&dev)?;
#     // get the training labels on the device
#     let train_labels = m.train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
#
#
#     // load the test images
#     let test_images = m.test_images.to_device(&dev)?;
#     // load the test labels
#     let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
#
    // create the Adam optimiser

    // set the learning rate to 0.004 and use the default parameters for everything else
    let params = ParamsAdam {
        lr: 0.004,
        ..Default::default()
    };
    // create the optimiser by passing in the variable to be optimised and the parameters
    let mut optimiser = Adam::new(varmap.all_vars(), params)?;

    // loop for model optimisation
    for epoch in 0..100 {
        // run the model forwards
        // get the logits for the training images
        let logits = model.forward(&train_images)?;
        // take the log-softmax of the logits to get log probabilities
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        // get the loss
        let loss = loss::nll(&log_sm, &train_labels)?;
        // step the tensors by backpropagating the loss
        optimiser.backward_step(&loss)?;

#         // get the logits for the test images
#         let test_logits = model.forward(&test_images)?;
#         // get the sum of the correct predictions
#         let sum_ok = test_logits
#             .argmax(D::Minus1)?
#             .eq(&test_labels)?
#             .to_dtype(DType::F32)?
#             .sum_all()?
#             .to_scalar::<f32>()?;
#         // get the accuracy on the test set
#         #[allow(clippy::cast_precision_loss)]
#         let test_accuracy = sum_ok / test_labels.dims1()? as f32;
#         println!(
#             "{:4} train loss: {:8.5} test acc: {:5.2}%",
#             epoch + 1,
#             loss.to_scalar::<f32>()?,
#             100. * test_accuracy
#         );
    }
    Ok(())
# }
```
*/
85
86use std::fmt::Debug;
87
88use candle_core::Result as CResult;
89use candle_core::Tensor;
90use candle_core::Var;
91pub mod adadelta;
92pub mod adagrad;
93pub mod adam;
94pub mod adamax;
95pub mod esgd;
96pub mod lbfgs;
97pub mod nadam;
98pub mod radam;
99pub mod rmsprop;
100
101/// Trait for optimisers to expose their parameters
102pub trait OptimParams: candle_nn::optim::Optimizer {
103 /// get the current parameters of the Optimiser
104 fn params(&self) -> &Self::Config;
105 /// set the current parameters of the Optimiser
106 fn set_params(&mut self, config: Self::Config);
107}
108
109/// Trait for Models: this is needed for optimisers that require the ability to calculate the loss
110/// such as LBFGS
111pub trait Model: Sized {
112 /// get the loss of the model
113 fn loss(&self) -> CResult<Tensor>; //, xs: &Tensor, ys: &Tensor
114}
115
116/// trait for optimisers like LBFGS that need the ability to calculate the loss
117/// and its gradient
118pub trait LossOptimizer<M: Model>: Sized {
119 /// type of the optimiser configuration
120 type Config: Sized;
121 /// create a new optimiser from a Vec of variables, setup parameters and a model
122 fn new(vs: Vec<Var>, params: Self::Config, model: M) -> CResult<Self>;
123 /// take a step of the optimiser
124 fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome>; //, xs: &Tensor, ys: &Tensor
125 /// get the current learning rate
126 fn learning_rate(&self) -> f64;
127 /// set the learning rate
128 fn set_learning_rate(&mut self, lr: f64);
129 /// get the a vec of the variables being optimised
130 fn into_inner(self) -> Vec<Var>;
131 /// create a new optimiser from a slice of variables, setup parameters and a model
132 fn from_slice(vars: &[&Var], config: Self::Config, model: M) -> CResult<Self> {
133 let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
134 Self::new(vars, config, model)
135 }
136}
137
138/// Outcomes of an optimiser step for methods such as LBFGS
139#[derive(Debug)]
140pub enum ModelOutcome {
141 /// The model took a step and the loss decreased
142 /// contains next loss and the number of func evals
143 Stepped(Tensor, usize),
144 /// The model has converged and the loss has not changed
145 /// contains loss and the number of func evals
146 Converged(Tensor, usize),
147}
148
/// Method of weight decay to use.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Decay {
    /// Weight decay regularisation to penalise large weights.
    ///
    /// The gradient is transformed as
    /// $$ g_{t} \\gets g_{t} + \\lambda \\theta_{t-1}$$
    ///
    /// This is equivalent to an L2 regularisation term in the loss adding
    /// $\\frac{\\lambda}{2}||\\theta||_{2}^{2}$ but avoids autodifferentiation
    /// of the L2 term.
    WeightDecay(f64),
    /// Decoupled weight decay as described in
    /// [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
    ///
    /// This directly decays the weights as
    ///
    /// $$ \\theta_{t} \\gets (1 - \\eta \\lambda) \\theta_{t-1}$$
    ///
    /// For SGD without momentum this is equivalent to the regularisation above,
    /// but it is different for adaptive gradient methods.
    DecoupledWeightDecay(f64),
}
169
/// Type of momentum to use.
///
/// Each variant carries an `f64` parameter.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum Momentum {
    /// Classical momentum.
    Classical(f64),
    /// Nesterov momentum.
    Nesterov(f64),
}