// candle_optimisers/lbfgs.rs

/*!

Limited memory Broyden–Fletcher–Goldfarb–Shanno algorithm

A pseudo-second-order optimiser based on the BFGS method.

Described in [On the limited memory BFGS method for large scale optimization](https://link.springer.com/article/10.1007/BF01589116)

For a history of size $n$, assume we have stored the last $n$ updates in the form $s_{k} = x_{k+1} - x_{k}$ and $y_{k} = g_{k+1}-g_{k}$, where $g_{k} = \\nabla f(x_{k})$.
We use a two-loop recursion to compute the direction of descent:

$$
\\begin{aligned}
    &q = g_k\\\\
    &// \\texttt{ Iterate over history from newest to oldest}\\\\
    &\\mathbf{For}\\ i=k-1 \\: \\mathbf{to}\\: k-n \\: \\mathbf{do}\\\\
    &\\hspace{5mm}\\rho_{i} = \\frac{1}{y_{i}^{\\top} s_{i}} \\\\
    &\\hspace{5mm} \\alpha_i = \\rho_i s^\\top_i q\\\\
    &\\hspace{5mm} q = q - \\alpha_i y_i\\\\
    &\\gamma_k = \\frac{s_{k - 1}^{\\top} y_{k - 1}}{y_{k - 1}^{\\top} y_{k - 1}} \\\\
    &q = \\gamma_{k} q\\\\
    &// \\texttt{ Iterate over history from oldest to newest}\\\\
    &\\mathbf{For}\\ i=k-n \\: \\mathbf{to}\\: k-1 \\: \\mathbf{do}\\\\
    &\\hspace{5mm} \\beta_i = \\rho_i y^\\top_i q\\\\
    &\\hspace{5mm} q = q + s_i (\\alpha_i - \\beta_i)\\\\
    &q = -q
\\end{aligned}
$$
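
A rough usage sketch (not authoritative): `build_model` is a hypothetical helper returning a type
implementing [`Model`] together with its trainable `Var`s, and the module path
`candle_optimisers::lbfgs` is assumed. The [`LossOptimizer`] methods used (`new`, `backward_step`)
are the ones implemented for [`Lbfgs`] below.

```ignore
use candle_optimisers::lbfgs::{Lbfgs, LineSearch, ParamsLBFGS};
use candle_optimisers::{LossOptimizer, ModelOutcome};

let (model, vars) = build_model()?; // hypothetical: returns (impl Model, Vec<Var>)
let params = ParamsLBFGS {
    lr: 1.,
    line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)),
    ..Default::default()
};

// the optimiser takes ownership of the model so that it can re-evaluate the loss internally
let mut loss = model.loss()?;
let mut lbfgs = Lbfgs::new(vars, params, model)?;

for _ in 0..100 {
    match lbfgs.backward_step(&loss)? {
        ModelOutcome::Converged(new_loss, _evals) => {
            loss = new_loss;
            break;
        }
        ModelOutcome::Stepped(new_loss, _evals) => loss = new_loss,
        _ => {} // any other outcome variants
    }
}
```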
*/

// <https://sagecal.sourceforge.net/pytorch/index.html> possible extensions

use crate::{LossOptimizer, Model, ModelOutcome};
use candle_core::Result as CResult;
use candle_core::{Tensor, Var};
use log::info;
use std::collections::VecDeque;
// use candle_nn::optim::Optimizer;

mod strong_wolfe;

/// Line search method
/// Only Strong Wolfe is currently implemented
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum LineSearch {
    /// Strong Wolfe line search: c1, c2, tolerance
    /// suggested values: c1 = 1e-4, c2 = 0.9, tolerance = 1e-9
    ///
    /// Ensures the Strong Wolfe conditions are met for step size $t$ in direction $\\bm{d}$:
    ///
    /// Armijo rule (sufficient decrease):
    /// $$ f(x + t \\bm{d}) \\leq f(x) + c_1 t \\bm{d}^T \\nabla f(x) $$
    ///
    /// and
    ///
    /// Strong curvature condition:
    /// $$ |\\bm{d}^{T} \\nabla f(x + t \\bm{d})| \\leq c_{2} |\\bm{d}^{T} \\nabla f(x)| $$
    StrongWolfe(f64, f64, f64),
}

/// Conditions for termination based on the gradient
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum GradConv {
    /// convergence based on the max absolute component of the gradient
    MinForce(f64),
    /// convergence based on the root mean square (RMS) force
    RMSForce(f64),
}

/// Conditions for termination based on step size
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum StepConv {
    /// convergence based on the max absolute component of the step
    MinStep(f64),
    /// convergence based on the root mean square (RMS) step size
    RMSStep(f64),
}

/// Parameters for LBFGS optimiser
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ParamsLBFGS {
    /// 'Learning rate': used for the initial step size guess
    /// and as the step size when no line search is used
    pub lr: f64,
    /// size of history to retain
    pub history_size: usize,
    /// line search method to use
    pub line_search: Option<LineSearch>,
    /// convergence criterion for the gradient
    pub grad_conv: GradConv,
    /// convergence criterion for the step size
    pub step_conv: StepConv,
    /// weight decay
    pub weight_decay: Option<f64>,
}

impl Default for ParamsLBFGS {
    fn default() -> Self {
        Self {
            lr: 1.,
            // max_iter: 20,
            // max_eval: None,
            history_size: 100,
            line_search: None,
            grad_conv: GradConv::MinForce(1e-7),
            step_conv: StepConv::MinStep(1e-9),
            weight_decay: None,
        }
    }
}

/// LBFGS optimiser
///
/// A pseudo-second-order optimiser based on the BFGS method.
///
/// Described in [On the limited memory BFGS method for large scale optimization](https://link.springer.com/article/10.1007/BF01589116)
///
/// <https://sagecal.sourceforge.net/pytorch/index.html>
#[derive(Debug)]
pub struct Lbfgs<M: Model> {
    vars: Vec<Var>,
    model: M,
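    // history of (s_k, y_k) pairs, where s is the step taken and y the change in gradient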
    s_hist: VecDeque<(Tensor, Tensor)>,
    last_grad: Option<Var>,
    next_grad: Option<Var>,
    last_step: Option<Var>,
    params: ParamsLBFGS,
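    // true until the first step has been taken; used to scale the initial step size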
    first: bool,
}

impl<M: Model> LossOptimizer<M> for Lbfgs<M> {
    type Config = ParamsLBFGS;

    fn new(vs: Vec<Var>, params: Self::Config, model: M) -> CResult<Self> {
        let hist_size = params.history_size;
        Ok(Lbfgs {
            vars: vs,
            model,
            s_hist: VecDeque::with_capacity(hist_size),
            last_step: None,
            last_grad: None,
            next_grad: None,
            params,
            first: true,
        })
    }

    #[allow(clippy::too_many_lines)]
    fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome> {
        let mut evals = 1;

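        // reuse the gradient cached by the previous line search if available,
        // otherwise backpropagate the supplied loss to get a flattened gradient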
        let grad = if let Some(this_grad) = &self.next_grad {
            this_grad.as_tensor().copy()?
        } else {
            flat_grads(&self.vars, loss, self.params.weight_decay)?
        };

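        // terminate early if the gradient already satisfies the convergence criterion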
        match self.params.grad_conv {
            GradConv::MinForce(tol) => {
                if grad
                    .abs()?
                    .max(0)?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
            GradConv::RMSForce(tol) => {
                if grad
                    .sqr()?
                    .mean_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    .sqrt()
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
        }

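        // y_k = g_{k+1} - g_k; also cache the current gradient for the next iteration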
        let mut yk = None;

        if let Some(last) = &self.last_grad {
            yk = Some((&grad - last.as_tensor())?);
            last.set(&grad)?;
        } else {
            self.last_grad = Some(Var::from_tensor(&grad)?);
        }

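        // q starts as the gradient and is transformed in place by the two-loop recursion
        // (no explicit negation here: the step size applied below is negative)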
        let q = Var::from_tensor(&grad)?;

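        // maintain the bounded history: drop the oldest pair when full, then push the newest (s, y)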
        let hist_size = self.s_hist.len();

        if hist_size == self.params.history_size {
            self.s_hist.pop_front();
        }
        if let Some(yk) = yk {
            if let Some(step) = &self.last_step {
                self.s_hist.push_back((step.as_tensor().clone(), yk));
            }
        }

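        // gamma_k = s^T y / y^T y scales the initial inverse Hessian approximation H_k^0 = gamma_k * I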
        let gamma = if let Some((s, y)) = self.s_hist.back() {
            let numr = y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?;

            let denom = y
                .unsqueeze(0)?
                .matmul(&(y.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10;

            numr / denom
        } else {
            1.
        };

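        // first loop of the two-loop recursion: iterate from newest to oldest,
        // recording rho_i and alpha_i for use in the second loop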
        let mut rhos = VecDeque::with_capacity(hist_size);
        let mut alphas = VecDeque::with_capacity(hist_size);
        for (s, y) in self.s_hist.iter().rev() {
            let rho = (y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10)
                .powi(-1);

            let alpha = rho
                * s.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.sub(&(y * alpha)?)?)?;
            // we are iterating in reverse and so want to insert at the front of the VecDeque
            alphas.push_front(alpha);
            rhos.push_front(rho);
        }

        // z = q * gamma so use interior mutability of q to set it
        q.set(&(q.as_tensor() * gamma)?)?;
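
        // second loop: iterate from oldest to newest, adding the correction s_i * (alpha_i - beta_i)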
        for (((s, y), alpha), rho) in self
            .s_hist
            .iter()
            .zip(alphas.into_iter())
            .zip(rhos.into_iter())
        {
            let beta = rho
                * y.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.add(&(s * (alpha - beta))?)?)?;
        }

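        // dd = g^T q: the directional derivative along q, needed by the strong Wolfe line search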
        // let dd = (&grad * q.as_tensor())?.sum_all()?;
        let dd = grad
            .unsqueeze(0)?
            .matmul(&(q.unsqueeze(1)?))?
            .to_dtype(candle_core::DType::F64)?
            .squeeze(1)?
            .squeeze(0)?
            .to_scalar::<f64>()?;

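        // on the first iteration, scale the step by min(1, 1 / ||g||_1) for a conservative start;
        // the step size is negative because q has not been negated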
        let mut lr = if self.first {
            self.first = false;
            -(1_f64.min(
                1. / grad
                    .abs()?
                    .sum_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?,
            )) * self.params.lr
        } else {
            -self.params.lr
        };

        if let Some(ls) = &self.params.line_search {
            match ls {
                LineSearch::StrongWolfe(c1, c2, tol) => {
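                    // the line search returns the new loss, the gradient at the new point,
                    // the accepted step size t and the number of evaluations performed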
                    let (loss, grad, t, steps) =
                        self.strong_wolfe(lr, &q, loss, &grad, dd, *c1, *c2, *tol, 25)?;
                    if let Some(next_grad) = &self.next_grad {
                        next_grad.set(&grad)?;
                    } else {
                        self.next_grad = Some(Var::from_tensor(&grad)?);
                    }

                    evals += steps;
                    lr = t;
                    q.set(&(q.as_tensor() * lr)?)?;

                    if let Some(step) = &self.last_step {
                        step.set(&q)?;
                    } else {
                        self.last_step = Some(Var::from_tensor(&q)?);
                    }

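                    // check step-size convergence; the step q is applied to the variables either way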
                    match self.params.step_conv {
                        StepConv::MinStep(tol) => {
                            if q.abs()?
                                .max(0)?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                        StepConv::RMSStep(tol) => {
                            if q.sqr()?
                                .mean_all()?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                .sqrt()
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                    }
                }
            }
        } else {
            q.set(&(q.as_tensor() * lr)?)?;

            if let Some(step) = &self.last_step {
                step.set(&q)?;
            } else {
                self.last_step = Some(Var::from_tensor(&q)?);
            }

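            // check step-size convergence; apply the step and re-evaluate the loss either way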
            match self.params.step_conv {
                StepConv::MinStep(tol) => {
                    if q.abs()?
                        .max(0)?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
                StepConv::RMSStep(tol) => {
                    if q.sqr()?
                        .mean_all()?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        .sqrt()
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
            }
        }
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }

    #[must_use]
    fn into_inner(self) -> Vec<Var> {
        self.vars
    }
}

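/// Backpropagates `loss` and returns the gradients of all `vs` flattened and concatenated into a
/// single 1-D tensor. Optional weight decay adds `wd * v` to each gradient; variables with no
/// gradient contribute zeros (or just the weight-decay term).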
#[allow(clippy::inline_always)]
#[inline(always)]
fn flat_grads(vs: &Vec<Var>, loss: &Tensor, weight_decay: Option<f64>) -> CResult<Tensor> {
    let grads = loss.backward()?;
    let mut flat_grads = Vec::with_capacity(vs.len());
    if let Some(wd) = weight_decay {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                let grad = &(grad + (wd * v.as_tensor())?)?;
                flat_grads.push(grad.flatten_all()?);
            } else {
                let grad = (wd * v.as_tensor())?; // treat as if grad were 0
                flat_grads.push(grad.flatten_all()?);
            }
        }
    } else {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                flat_grads.push(grad.flatten_all()?);
            } else {
                let n_elems = v.elem_count();
                flat_grads.push(candle_core::Tensor::zeros(n_elems, v.dtype(), v.device())?);
            }
        }
    }
    candle_core::Tensor::cat(&flat_grads, 0)
}

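/// Splits the flat step tensor into per-variable chunks, reshapes each chunk to its variable's
/// shape, and adds it to that variable in place.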
fn add_grad(vs: &mut Vec<Var>, flat_tensor: &Tensor) -> CResult<()> {
    let mut offset = 0;
    for var in vs {
        let n_elems = var.elem_count();
        let tensor = flat_tensor
            .narrow(0, offset, n_elems)?
            .reshape(var.shape())?;
        var.set(&var.add(&tensor)?)?;
        offset += n_elems;
    }
    Ok(())
}

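/// Sets each variable to the corresponding value from `vals`.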
fn set_vs(vs: &mut [Var], vals: &Vec<Tensor>) -> CResult<()> {
    for (var, t) in vs.iter().zip(vals) {
        var.set(t)?;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    // use candle_core::test_utils::{to_vec0_round, to_vec2_round};

    use crate::Model;
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::Device;
    use candle_core::{Module, Result as CResult};
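
    // simple linear least-squares model used to exercise the optimiser interface in these tests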
    pub struct LinearModel {
        linear: candle_nn::Linear,
        xs: Tensor,
        ys: Tensor,
    }

    impl Model for LinearModel {
        fn loss(&self) -> CResult<Tensor> {
            let preds = self.forward(&self.xs)?;
            let loss = candle_nn::loss::mse(&preds, &self.ys)?;
            Ok(loss)
        }
    }

    impl LinearModel {
        fn new() -> CResult<(Self, Vec<Var>)> {
            let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?;
            let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?;

            let linear =
                candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone()));

            Ok((
                Self {
                    linear,
                    xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?,
                    ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?,
                },
                vec![weight, bias],
            ))
        }

        fn forward(&self, xs: &Tensor) -> CResult<Tensor> {
            self.linear.forward(xs)
        }
    }

    use super::*;

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        let (model, vars) = LinearModel::new()?;
        let mut lbfgs = Lbfgs::new(vars, params, model)?;
        assert_approx_eq!(0.004, lbfgs.learning_rate());
        lbfgs.set_learning_rate(0.002);
        assert_approx_eq!(0.002, lbfgs.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        // check that into_inner returns the optimiser's variables (here, the initial coefficients)

        let (model, vars) = LinearModel::new()?;
        let slice: Vec<&Var> = vars.iter().collect();
        let lbfgs = Lbfgs::from_slice(&slice, params, model)?;
        let inner = lbfgs.into_inner();

        assert_eq!(inner[0].as_tensor().to_vec1::<f64>()?, &[3f64, 1.]);
        println!("checked weights");
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f64>()?, -2_f64);
        Ok(())
    }
}