candle_optimisers/lbfgs.rs

/*!

Limited memory Broyden–Fletcher–Goldfarb–Shanno algorithm

A pseudo second order optimiser based on the BFGS method.

Described in [On the limited memory BFGS method for large scale optimization](https://link.springer.com/article/10.1007/BF01589116)

For a history of size $n$, assume the last $n$ updates are stored in the form $s_{k} = x_{k+1} - x_{k}$ and $y_{k} = g_{k+1} - g_{k}$, where $g_{k} = \\nabla f(x_{k})$.
A two-loop recursion is used to compute the direction of descent:

$$
\\begin{aligned}
    &q = g_k\\\\
    &// \\texttt{ Iterate over history from newest to oldest}\\\\
    &\\mathbf{For}\\ i=k-1 \\: \\mathbf{to}\\: k-n \\: \\mathbf{do}\\\\
    &\\hspace{5mm}\\rho_{i} = \\frac{1}{y_{i}^{\\top} s_{i}} \\\\
    &\\hspace{5mm} \\alpha_i = \\rho_i s^\\top_i q\\\\
    &\\hspace{5mm} q = q - \\alpha_i y_i\\\\
    &\\gamma_k = \\frac{s_{k - 1}^{\\top} y_{k - 1}}{y_{k - 1}^{\\top} y_{k - 1}} \\\\
    &q = \\gamma_{k} q\\\\
    &// \\texttt{ Iterate over history from oldest to newest}\\\\
    &\\mathbf{For}\\ i=k-n \\: \\mathbf{to}\\: k-1 \\: \\mathbf{do}\\\\
    &\\hspace{5mm} \\beta_i = \\rho_i y^\\top_i q\\\\
    &\\hspace{5mm} q = q + s_i (\\alpha_i - \\beta_i)\\\\
    &q = -q
\\end{aligned}
$$
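
As a purely illustrative sketch (not this crate's implementation: it runs the same recursion
on plain `Vec<f64>` buffers with a local `dot` helper, rather than on tensors):

```rust
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

// `history` holds the (s_i, y_i) pairs, oldest first, newest last.
fn two_loop_direction(grad: &[f64], history: &[(Vec<f64>, Vec<f64>)]) -> Vec<f64> {
    let mut q: Vec<f64> = grad.to_vec();
    let mut alphas = vec![0.0; history.len()];
    let mut rhos = vec![0.0; history.len()];
    // First loop: newest to oldest.
    for (i, (s, y)) in history.iter().enumerate().rev() {
        let rho = 1.0 / dot(y, s);
        let alpha = rho * dot(s, &q);
        for (qj, yj) in q.iter_mut().zip(y) {
            *qj -= alpha * yj;
        }
        rhos[i] = rho;
        alphas[i] = alpha;
    }
    // Scale by gamma_k computed from the most recent pair.
    if let Some((s, y)) = history.last() {
        let gamma = dot(s, y) / dot(y, y);
        for qj in q.iter_mut() {
            *qj *= gamma;
        }
    }
    // Second loop: oldest to newest.
    for (i, (s, y)) in history.iter().enumerate() {
        let beta = rhos[i] * dot(y, &q);
        for (qj, sj) in q.iter_mut().zip(s) {
            *qj += (alphas[i] - beta) * sj;
        }
    }
    // Negate to obtain a descent direction.
    for qj in q.iter_mut() {
        *qj = -*qj;
    }
    q
}

let d = two_loop_direction(&[1.0, -2.0], &[(vec![0.1, 0.2], vec![0.3, -0.1])]);
assert_eq!(d.len(), 2);
```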
*/

// <https://sagecal.sourceforge.net/pytorch/index.html> possible extensions

use crate::{LossOptimizer, Model, ModelOutcome};
use candle_core::Result as CResult;
use candle_core::{Tensor, Var};
use log::info;
use std::collections::VecDeque;
// use candle_nn::optim::Optimizer;

mod strong_wolfe;

/// Line search method
///
/// Only strong Wolfe line search is currently implemented.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum LineSearch {
    /// Strong Wolfe line search with parameters `(c1, c2, tolerance)`
    /// (suggested values: `c1 = 1e-4`, `c2 = 0.9`, `tolerance = 1e-9`)
    ///
    /// Ensures the Strong Wolfe conditions are met for step size $t$ in direction $\\bm{d}$:
    ///
    /// Armijo rule:
    /// $$ f(x + t \\bm{d}) \\leq f(x) + c_1 t \\bm{d}^T \\nabla f(x) $$
    ///
    /// and
    ///
    /// Strong Curvature Condition:
    /// $$ |\\bm{d}^{T} \\nabla f(x + t \\bm{d})| \\leq c_{2} |\\bm{d}^{T} \\nabla f(x)| $$
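    ///
    /// As an illustrative scalar sketch (the names `wolfe_ok`, `f0`, `g0`, `ft` and `gt` are
    /// hypothetical and not part of this crate's API), the two acceptance tests are:
    ///
    /// ```
    /// // f0, g0: loss value and directional derivative d^T grad f(x) at the current point;
    /// // ft, gt: the same quantities at the trial point x + t*d.
    /// fn wolfe_ok(f0: f64, g0: f64, ft: f64, gt: f64, t: f64, c1: f64, c2: f64) -> bool {
    ///     let armijo = ft <= f0 + c1 * t * g0;
    ///     let strong_curvature = gt.abs() <= c2 * g0.abs();
    ///     armijo && strong_curvature
    /// }
    ///
    /// assert!(wolfe_ok(1.0, -2.0, 0.9, 0.5, 0.1, 1e-4, 0.9));
    /// assert!(!wolfe_ok(1.0, -2.0, 0.9, 1.9, 0.1, 1e-4, 0.9)); // curvature condition fails
    /// ```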
    StrongWolfe(f64, f64, f64),
}

/// Conditions for termination based on the gradient
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum GradConv {
    /// convergence when the maximum absolute component of the gradient falls below the given value
    MinForce(f64),
    /// convergence when the root mean square of the gradient falls below the given value
    RMSForce(f64),
}

/// Conditions for termination based on the step size
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum StepConv {
    /// convergence when the maximum absolute component of the step falls below the given value
    MinStep(f64),
    /// convergence when the root mean square of the step falls below the given value
    RMSStep(f64),
}

/// Parameters for LBFGS optimiser
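///
/// A typical construction (the values are illustrative only, and the import path assumes
/// this module is exposed as `candle_optimisers::lbfgs`):
///
/// ```
/// use candle_optimisers::lbfgs::{LineSearch, ParamsLBFGS};
///
/// let params = ParamsLBFGS {
///     lr: 1.,
///     history_size: 10,
///     line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)),
///     ..Default::default()
/// };
/// ```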
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ParamsLBFGS {
    /// 'Learning rate': used for the initial step size guess
    /// and as the step size when no line search is used
    pub lr: f64,
    /// size of the history to retain
    pub history_size: usize,
    /// line search method to use
    pub line_search: Option<LineSearch>,
    /// convergence criterion for the gradient
    pub grad_conv: GradConv,
    /// convergence criterion for the step size
    pub step_conv: StepConv,
    /// weight decay
    pub weight_decay: Option<f64>,
}

impl Default for ParamsLBFGS {
    fn default() -> Self {
        Self {
            lr: 1.,
            // max_iter: 20,
            // max_eval: None,
            history_size: 100,
            line_search: None,
            grad_conv: GradConv::MinForce(1e-7),
            step_conv: StepConv::MinStep(1e-9),
            weight_decay: None,
        }
    }
}

/// LBFGS optimiser
///
/// A pseudo second order optimiser based on the BFGS method.
///
/// Described in [On the limited memory BFGS method for large scale optimization](https://link.springer.com/article/10.1007/BF01589116)
///
/// See <https://sagecal.sourceforge.net/pytorch/index.html> for possible extensions.
#[derive(Debug)]
pub struct Lbfgs<'a, M: Model> {
    vars: Vec<Var>,
    model: &'a M,
    s_hist: VecDeque<(Tensor, Tensor)>,
    last_grad: Option<Var>,
    next_grad: Option<Var>,
    last_step: Option<Var>,
    params: ParamsLBFGS,
    first: bool,
}

impl<'a, M: Model> LossOptimizer<'a, M> for Lbfgs<'a, M> {
    type Config = ParamsLBFGS;

    fn new(vs: Vec<Var>, params: Self::Config, model: &'a M) -> CResult<Self> {
        let hist_size = params.history_size;
        Ok(Lbfgs {
            vars: vs,
            model,
            s_hist: VecDeque::with_capacity(hist_size),
            last_step: None,
            last_grad: None,
            next_grad: None,
            params,
            first: true,
        })
    }

    #[allow(clippy::too_many_lines)]
    fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome> {
        let mut evals = 1;

        let grad = if let Some(this_grad) = &self.next_grad {
            this_grad.as_tensor().copy()?
        } else {
            flat_grads(&self.vars, loss, self.params.weight_decay)?
        };

        match self.params.grad_conv {
            GradConv::MinForce(tol) => {
                if grad
                    .abs()?
                    .max(0)?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
            GradConv::RMSForce(tol) => {
                if grad
                    .sqr()?
                    .mean_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    .sqrt()
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
        }

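        // Update the curvature history: y_k = g_{k+1} - g_k is paired with the stored step
        // s_k and pushed onto the bounded history (evicting the oldest pair when full).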
        let mut yk = None;

        if let Some(last) = &self.last_grad {
            yk = Some((&grad - last.as_tensor())?);
            last.set(&grad)?;
        } else {
            self.last_grad = Some(Var::from_tensor(&grad)?);
        }

        let q = Var::from_tensor(&grad)?;

        let hist_size = self.s_hist.len();

        if hist_size == self.params.history_size {
            self.s_hist.pop_front();
        }
        if let Some(yk) = yk {
            if let Some(step) = &self.last_step {
                self.s_hist.push_back((step.as_tensor().clone(), yk));
            }
        }

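        // Initial inverse-Hessian scaling: gamma_k = s^T y / (y^T y) from the most recent
        // curvature pair (with a small epsilon guarding the denominator).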
        let gamma = if let Some((s, y)) = self.s_hist.back() {
            let numr = y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?;

            let denom = y
                .unsqueeze(0)?
                .matmul(&(y.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10;

            numr / denom
        } else {
            1.
        };

        let mut rhos = VecDeque::with_capacity(hist_size);
        let mut alphas = VecDeque::with_capacity(hist_size);
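        // First loop of the two-loop recursion: iterate over the history from newest to
        // oldest, computing rho_i and alpha_i and updating q -= alpha_i * y_i.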
        for (s, y) in self.s_hist.iter().rev() {
            let rho = (y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10)
                .powi(-1);

            let alpha = rho
                * s.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.sub(&(y * alpha)?)?)?;
            // we are iterating in reverse and so want to insert at the front of the VecDeque
            alphas.push_front(alpha);
            rhos.push_front(rho);
        }

        // z = q * gamma so use interior mutability of q to set it
        q.set(&(q.as_tensor() * gamma)?)?;
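        // Second loop: iterate from oldest to newest, applying the correction
        // q += s_i * (alpha_i - beta_i).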
        for (((s, y), alpha), rho) in self.s_hist.iter().zip(alphas).zip(rhos) {
            let beta = rho
                * y.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.add(&(s * (alpha - beta))?)?)?;
        }

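        // dd = grad^T q: inner product between the gradient and the two-loop direction q,
        // used by the strong Wolfe line search when it is enabled.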
        // let dd = (&grad * q.as_tensor())?.sum_all()?;
        let dd = grad
            .unsqueeze(0)?
            .matmul(&(q.unsqueeze(1)?))?
            .to_dtype(candle_core::DType::F64)?
            .squeeze(1)?
            .squeeze(0)?
            .to_scalar::<f64>()?;

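        // On the first step, scale the step length by min(1, 1 / ||g||_1), a heuristic also
        // used by PyTorch's L-BFGS; afterwards use the configured learning rate. The sign is
        // negative because q currently points along the gradient, not against it.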
        let mut lr = if self.first {
            self.first = false;
            -(1_f64.min(
                1. / grad
                    .abs()?
                    .sum_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?,
            )) * self.params.lr
        } else {
            -self.params.lr
        };

        if let Some(ls) = &self.params.line_search {
            match ls {
                LineSearch::StrongWolfe(c1, c2, tol) => {
                    let (loss, grad, t, steps) =
                        self.strong_wolfe(lr, &q, loss, &grad, dd, *c1, *c2, *tol, 25)?;
                    if let Some(next_grad) = &self.next_grad {
                        next_grad.set(&grad)?;
                    } else {
                        self.next_grad = Some(Var::from_tensor(&grad)?);
                    }

                    evals += steps;
                    lr = t;
                    q.set(&(q.as_tensor() * lr)?)?;

                    if let Some(step) = &self.last_step {
                        step.set(&q)?;
                    } else {
                        self.last_step = Some(Var::from_tensor(&q)?);
                    }

                    match self.params.step_conv {
                        StepConv::MinStep(tol) => {
                            if q.abs()?
                                .max(0)?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                        StepConv::RMSStep(tol) => {
                            if q.sqr()?
                                .mean_all()?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                .sqrt()
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                    }
                }
            }
        } else {
            q.set(&(q.as_tensor() * lr)?)?;

            if let Some(step) = &self.last_step {
                step.set(&q)?;
            } else {
                self.last_step = Some(Var::from_tensor(&q)?);
            }

            match self.params.step_conv {
                StepConv::MinStep(tol) => {
                    if q.abs()?
                        .max(0)?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
                StepConv::RMSStep(tol) => {
                    if q.sqr()?
                        .mean_all()?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        .sqrt()
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
            }
        }
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }

    fn into_inner(self) -> Vec<Var> {
        self.vars
    }
}

#[allow(clippy::inline_always)]
#[inline(always)]
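/// Computes the gradient of `loss` with respect to each `Var`, optionally adds an L2
/// weight-decay term `wd * v`, then flattens and concatenates everything into a single
/// one-dimensional tensor (variables with no gradient are treated as having zero gradient).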
fn flat_grads(vs: &Vec<Var>, loss: &Tensor, weight_decay: Option<f64>) -> CResult<Tensor> {
    let grads = loss.backward()?;
    let mut flat_grads = Vec::with_capacity(vs.len());
    if let Some(wd) = weight_decay {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                let grad = &(grad + (wd * v.as_tensor())?)?;
                flat_grads.push(grad.flatten_all()?);
            } else {
                let grad = (wd * v.as_tensor())?; // treat as if grad were 0
                flat_grads.push(grad.flatten_all()?);
            }
        }
    } else {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                flat_grads.push(grad.flatten_all()?);
            } else {
                let n_elems = v.elem_count();
                flat_grads.push(candle_core::Tensor::zeros(n_elems, v.dtype(), v.device())?);
            }
        }
    }
    candle_core::Tensor::cat(&flat_grads, 0)
}

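/// Splits the flat step tensor back into per-variable slices, reshapes each slice to the
/// variable's shape, and adds it to the variable in place.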
fn add_grad(vs: &mut Vec<Var>, flat_tensor: &Tensor) -> CResult<()> {
    let mut offset = 0;
    for var in vs {
        let n_elems = var.elem_count();
        let tensor = flat_tensor
            .narrow(0, offset, n_elems)?
            .reshape(var.shape())?;
        var.set(&var.add(&tensor)?)?;
        offset += n_elems;
    }
    Ok(())
}

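/// Overwrites each `Var` with the corresponding tensor from `vals`.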
fn set_vs(vs: &mut [Var], vals: &Vec<Tensor>) -> CResult<()> {
    for (var, t) in vs.iter().zip(vals) {
        var.set(t)?;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use crate::Model;
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::Device;
    use candle_core::{Module, Result as CResult};

    pub struct LinearModel {
        linear: candle_nn::Linear,
        xs: Tensor,
        ys: Tensor,
    }

    impl Model for LinearModel {
        fn loss(&self) -> CResult<Tensor> {
            let preds = self.forward(&self.xs)?;
            let loss = candle_nn::loss::mse(&preds, &self.ys)?;
            Ok(loss)
        }
    }

    impl LinearModel {
        fn new() -> CResult<(Self, Vec<Var>)> {
            let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?;
            let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?;

            let linear =
                candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone()));

            Ok((
                Self {
                    linear,
                    xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?,
                    ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?,
                },
                vec![weight, bias],
            ))
        }

        fn forward(&self, xs: &Tensor) -> CResult<Tensor> {
            self.linear.forward(xs)
        }
    }

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        let (model, vars) = LinearModel::new()?;
        let mut lbfgs = Lbfgs::new(vars, params, &model)?;
        assert_approx_eq!(0.004, lbfgs.learning_rate());
        lbfgs.set_learning_rate(0.002);
        assert_approx_eq!(0.002, lbfgs.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        // Check that the variables handed to the optimiser are returned unchanged by `into_inner`.

        let (model, vars) = LinearModel::new()?;
        let slice: Vec<&Var> = vars.iter().collect();
        let lbfgs = Lbfgs::from_slice(&slice, params, &model)?;
        let inner = lbfgs.into_inner();

        assert_eq!(inner[0].as_tensor().to_vec1::<f64>()?, &[3f64, 1.]);
        println!("checked weights");
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f64>()?, -2_f64);
        Ok(())
    }
}