// burn_optim/optim/lbfgs.rs

#![allow(clippy::excessive_precision)]

use burn_core as burn;

use super::GradientsParams;
use crate::LearningRate;
use burn::config::Config;
use burn::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use burn::prelude::ToElement;
use burn::record::Record;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use serde::{Deserialize, Serialize};

use alloc::vec;
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
/// Cubic interpolation.
///
/// Uses two points (x1, f1), (x2, f2) and their first derivatives g1, g2 to construct
/// a cubic interpolant and returns its minimizer clamped to the given bounds.
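///
/// A small worked sketch (values taken from `test_cubic_interpolate` below); in this
/// symmetric case the interpolated minimizer is the midpoint:
///
/// ```text
/// // (x1, f1, g1) = (-1.0, 1.0, -2.0), (x2, f2, g2) = (1.0, 1.0, 2.0)
/// // d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2) = 0.0
/// // d2 = sqrt(d1^2 - g1 * g2) = 2.0
/// // min_pos = x2 - (x2 - x1) * (g2 + d2 - d1) / (g2 - g1 + 2 * d2) = 0.0
/// let t = cubic_interpolate(-1.0, 1.0, -2.0, 1.0, 1.0, 2.0, None); // ~= 0.0
/// ```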
fn cubic_interpolate(
    x1: f64,
    f1: f64,
    g1: f64,
    x2: f64,
    f2: f64,
    g2: f64,
    bounds: Option<(f64, f64)>,
) -> f64 {
    // Compute bounds of interpolation area
    let (min_bound, max_bound) = bounds.unwrap_or(if x1 <= x2 { (x1, x2) } else { (x2, x1) });
    // Code for most common case: cubic interpolation of 2 points
    // with function and derivative values for both
    // Solution in this case (where x2 is the farthest point)
    // d1 = g1 + g2 - 3*(f1 - f2) / (x1-x2);
    // d2 = sqrt(d1^2 - g1 * g2);
    // min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    // t_new = min(max(min_pos,min_bound), max_bound);
    let d1 = g1 + g2 - 3.0 * (f1 - f2) / (x1 - x2);
    let d2_square = d1 * d1 - g1 * g2;

    if d2_square >= 0.0 {
        let d2 = d2_square.sqrt();
        let min_pos = if x1 <= x2 {
            x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2.0 * d2))
        } else {
            x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2.0 * d2))
        };
        min_pos.max(min_bound).min(max_bound)
    } else {
        (min_bound + max_bound) / 2.0
    }
}
/// Auxiliary struct for the strong Wolfe line search.
struct LineSearchSample<B: Backend> {
    // step size
    t: f64,
    // loss
    f: f64,
    // gradient
    g: Tensor<B, 1>,
    // directional derivative
    gtd: f64,
}

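/// Strong Wolfe line search (bracketing followed by a zoom phase), following the
/// PyTorch/minFunc implementation referenced in the [`LBFGS`] docs below.
///
/// `obj_func(x, t, d)` evaluates the objective at `x + t * d` and returns the loss and
/// the flattened gradient. Starting from the initial step size `t`, the search first
/// brackets an interval containing an acceptable point, then zooms in on it.
/// Returns `(loss, gradient, step size, number of objective evaluations)` for the
/// accepted point.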
#[allow(clippy::too_many_arguments)]
fn strong_wolfe<B: Backend, F>(
    // obj_func(x, step size, direction) -> (loss, grad)
    obj_func: &mut F,
    x: &Tensor<B, 1>,
    // initial step size
    mut t: f64,
    d: &Tensor<B, 1>,
    f: f64,
    g: Tensor<B, 1>,
    gtd: f64,
    c1: f64,
    c2: f64,
    tolerance_change: f64,
    max_ls: usize,
) -> (f64, Tensor<B, 1>, f64, usize)
where
    F: FnMut(&Tensor<B, 1>, f64, &Tensor<B, 1>) -> (f64, Tensor<B, 1>),
{
    let d_norm = d.clone().abs().max().into_scalar().to_f64();

    // evaluate objective and gradient using initial step
    let (mut f_new, mut g_new) = obj_func(x, t, d);
    let mut ls_func_evals = 1;
    let mut gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();

    // bracket an interval [t_prev, t] containing a point satisfying the Wolfe criteria
    let (mut t_prev, mut f_prev, mut g_prev, mut gtd_prev) = (0.0, f, g.clone(), gtd);
    let mut done = false;
    let mut ls_iter = 0;

    // the interval [low, high] used for the zoom phase
    let mut bracket: Option<[LineSearchSample<B>; 2]> = None;
    // point which satisfies the strong Wolfe conditions
    let mut wolfe_bracket: Option<LineSearchSample<B>> = None;
    while ls_iter < max_ls {
        // Checking conditions.

        // Checking the Armijo condition and whether the function value increased.
        // Armijo: f(x + t*d) <= f(x) + c_1 * t * gtd
        if f_new > (f + c1 * t * gtd) || (ls_iter > 1 && f_new >= f_prev) {
            bracket = Some([
                LineSearchSample {
                    t: t_prev,
                    f: f_prev,
                    g: g_prev,
                    gtd: gtd_prev,
                },
                LineSearchSample {
                    t,
                    f: f_new,
                    g: g_new.clone(),
                    gtd: gtd_new,
                },
            ]);
            break;
        }

        // Checking the strong Wolfe (curvature) condition
        // |gtd_new| <= -c_2 * gtd
        if gtd_new.abs() <= -c2 * gtd {
            wolfe_bracket = Some(LineSearchSample {
                t,
                f: f_new,
                g: g_new.clone(),
                gtd: gtd_new,
            });
            done = true;
            break;
        }

        // If gtd_new >= 0, the directional derivative changed sign, so there must be
        // a local minimum in the interval [t_prev, t].
        if gtd_new >= 0.0 {
            bracket = Some([
                LineSearchSample {
                    t: t_prev,
                    f: f_prev,
                    g: g_prev,
                    gtd: gtd_prev,
                },
                LineSearchSample {
                    t,
                    f: f_new,
                    g: g_new.clone(),
                    gtd: gtd_new,
                },
            ]);
            break;
        }

        // interpolate
        let min_step = t + 0.01 * (t - t_prev);
        let max_step = t * 10.0;
        let t_next = cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            Some((min_step, max_step)),
        );
        t_prev = t;
        f_prev = f_new;
        g_prev = g_new;
        gtd_prev = gtd_new;

        // next step
        t = t_next;
        (f_new, g_new) = obj_func(x, t, d);
        ls_func_evals += 1;
        gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
        ls_iter += 1;
    }
    if let Some(sample) = wolfe_bracket {
        return (sample.f, sample.g, sample.t, ls_func_evals);
    }

    let mut bracket = bracket.unwrap_or_else(|| {
        [
            LineSearchSample {
                t: 0.0,
                f,
                g: g.clone(),
                gtd,
            },
            LineSearchSample {
                t,
                f: f_new,
                g: g_new.clone(),
                gtd: gtd_new,
            },
        ]
    });

    // zoom phase
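    // Under the usual assumptions, the bracketed interval contains a step length
    // satisfying the strong Wolfe conditions (standard bracketing/zoom argument, cf.
    // Nocedal & Wright, "Numerical Optimization"). Each iteration below shrinks the
    // bracket with a cubic trial point, keeping the endpoint with the lowest loss at
    // `low_idx`.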
    let mut insuf_progress = false;

    // find high and low points in bracket
    let (mut low_idx, mut high_idx) = if bracket[0].f <= bracket[1].f {
        (0, 1)
    } else {
        (1, 0)
    };

    while !done && ls_iter < max_ls {
        let diff = (bracket[1].t - bracket[0].t).abs();
        // the line-search bracket has become too small
        if diff * d_norm < tolerance_change {
            break;
        }

        // compute new trial value
        t = cubic_interpolate(
            bracket[0].t,
            bracket[0].f,
            bracket[0].gtd,
            bracket[1].t,
            bracket[1].f,
            bracket[1].gtd,
            None,
        );

        let b_min = bracket[0].t.min(bracket[1].t);
        let b_max = bracket[0].t.max(bracket[1].t);
        let eps = 0.1 * (b_max - b_min);

        if (b_max - t).min(t - b_min) < eps {
            // interpolation close to boundary
            if insuf_progress || t >= b_max || t <= b_min {
                t = if (t - b_max).abs() < (t - b_min).abs() {
                    b_max - eps
                } else {
                    b_min + eps
                };
                insuf_progress = false;
            } else {
                insuf_progress = true;
            }
        } else {
            insuf_progress = false;
        }

        // Evaluate new point
        (f_new, g_new) = obj_func(x, t, d);

        ls_func_evals += 1;
        gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
        ls_iter += 1;

        let armijo_holds = f_new <= (f + c1 * t * gtd) && f_new < bracket[low_idx].f;

        if !armijo_holds {
            bracket[high_idx] = LineSearchSample {
                t,
                f: f_new,
                g: g_new,
                gtd: gtd_new,
            };
        } else {
            if gtd_new.abs() <= -c2 * gtd {
                return (f_new, g_new, t, ls_func_evals);
            }

            if gtd_new * (bracket[high_idx].t - bracket[low_idx].t) >= 0.0 {
                bracket[high_idx] = LineSearchSample {
                    t: bracket[low_idx].t,
                    f: bracket[low_idx].f,
                    g: bracket[low_idx].g.clone(),
                    gtd: bracket[low_idx].gtd,
                };
            }
            bracket[low_idx] = LineSearchSample {
                t,
                f: f_new,
                g: g_new,
                gtd: gtd_new,
            };
        }

        if bracket[0].f <= bracket[1].f {
            low_idx = 0;
            high_idx = 1;
        } else {
            low_idx = 1;
            high_idx = 0;
        }
    }
    // return the best point found (the bracket endpoint with the lowest loss)
    (
        bracket[low_idx].f,
        bracket[low_idx].g.clone(),
        bracket[low_idx].t,
        ls_func_evals,
    )
}

/// Strategy for the line search phase of the optimizer.
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LineSearchFn {
    /// No line search performed.
    #[default]
    None,
    /// Line search enforcing the strong Wolfe conditions.
    ///
    /// See: <https://en.wikipedia.org/wiki/Wolfe_conditions>
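    ///
    /// A step size `t` along a descent direction `d` is accepted when (restating the
    /// checks performed in `strong_wolfe` above):
    ///
    /// - sufficient decrease (Armijo): `f(x + t*d) <= f(x) + c1 * t * g(x)^T d`
    /// - curvature: `|g(x + t*d)^T d| <= c2 * |g(x)^T d|`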
    StrongWolfe,
}

/// LBFGS Configuration.
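///
/// # Example
///
/// A minimal configuration sketch (mirroring the unit tests below; `B` stands for any
/// autodiff backend in scope):
///
/// ```rust,ignore
/// let mut optimizer: LBFGS<B> = LBFGSConfig::new()
///     .with_line_search_fn(LineSearchFn::StrongWolfe)
///     .init();
/// ```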
#[derive(Config, Debug)]
pub struct LBFGSConfig {
    /// Maximal number of iterations per optimization step (default: 20).
    #[config(default = 20)]
    pub max_iter: usize,
    /// Update history size (default: 100).
    #[config(default = 100)]
    pub history_size: usize,
    /// Termination tolerance on first order optimality (default: 1e-7).
    #[config(default = 1e-7)]
    pub tolerance_grad: f64,
    /// Termination tolerance on function value/parameter changes (default: 1e-9).
    #[config(default = 1e-9)]
    pub tolerance_change: f64,
    /// Maximal number of function evaluations per optimization step (default: max_iter * 1.25).
    #[config(default = "None")]
    pub max_eval: Option<usize>,
    /// Line search strategy, either `LineSearchFn::StrongWolfe` or `LineSearchFn::None`
    /// (default: `None`).
    #[config(default = "LineSearchFn::None")]
    pub line_search_fn: LineSearchFn,
}

impl LBFGSConfig {
    /// Initialize the L-BFGS optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
    pub fn init<B: AutodiffBackend>(&self) -> LBFGS<B> {
        // by default max_eval = max_iter * 5/4
        let max_eval = self.max_eval.unwrap_or(self.max_iter * 5 / 4);
        LBFGS {
            config: LBFGSConfig {
                max_iter: self.max_iter,
                history_size: self.history_size,
                tolerance_grad: self.tolerance_grad,
                tolerance_change: self.tolerance_change,
                max_eval: Some(max_eval),
                line_search_fn: self.line_search_fn,
            },
            state: Default::default(),
        }
    }
}

/// Collects gradients in module visit order.
struct FlattenGradsVisitorInner<'a, B: AutodiffBackend> {
    grads: &'a GradientsParams,
    tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
}

impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenGradsVisitorInner<'_, B> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        if let Some(g) = self.grads.get::<B::InnerBackend, D>(param.id) {
            let numel = g.shape().num_elements();
            self.tensors.push(g.reshape([numel]));
        }
    }
}

/// Flatten params to inner backend 1D tensor.
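///
/// The flatten helpers and `set_params_from_flat_inner` all rely on the same
/// deterministic module visit/map order, so slices written back line up with the
/// parameters they were read from.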
fn flatten_params_inner<B: AutodiffBackend, M: Module<B>>(
    module: &M,
) -> Tensor<B::InnerBackend, 1> {
    let mut tensors = Vec::new();
    let mut visitor = FlattenParamsVisitorInner::<B> {
        tensors: &mut tensors,
    };
    module.visit(&mut visitor);
    if tensors.is_empty() {
        return Tensor::empty([0], &module.devices()[0]);
    }
    Tensor::cat(tensors, 0)
}

struct FlattenParamsVisitorInner<'a, B: AutodiffBackend> {
    tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
}

impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenParamsVisitorInner<'_, B> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        let t = param.val().inner();
        let numel = t.shape().num_elements();
        self.tensors.push(t.reshape([numel]));
    }
}

/// Flatten gradients for a module.
fn flatten_grads_inner<B: AutodiffBackend, M: Module<B>>(
    module: &M,
    grads: &GradientsParams,
) -> Tensor<B::InnerBackend, 1> {
    let mut tensors = Vec::new();
    let mut visitor = FlattenGradsVisitorInner {
        grads,
        tensors: &mut tensors,
    };
    module.visit(&mut visitor);
    if tensors.is_empty() {
        return Tensor::empty([0], &module.devices()[0]);
    }
    Tensor::cat(tensors, 0)
}

/// Mapper that assigns each float param from a flat inner-backend 1D tensor.
struct ParamsFromFlatMapperInner<'a, B: AutodiffBackend> {
    flat: &'a Tensor<B::InnerBackend, 1>,
    offset: &'a mut usize,
}

impl<B: AutodiffBackend> ParamsFromFlatMapperInner<'_, B> {
    fn take_slice(&mut self, numel: usize) -> Tensor<B::InnerBackend, 1> {
        let start = *self.offset;
        *self.offset += numel;
        self.flat.clone().slice(start..*self.offset)
    }
}

impl<B: AutodiffBackend> ModuleMapper<B> for ParamsFromFlatMapperInner<'_, B> {
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let numel = tensor.shape().num_elements();
        let slice_1d = self.take_slice(numel);
        let new_inner = slice_1d.reshape(tensor.shape());
        let new_tensor = Tensor::from_inner(new_inner).require_grad();
        Param::from_mapped_value(id, new_tensor, mapper)
    }
}

/// Overwrite module parameters from a flat inner-backend 1D tensor
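/// (one slice per float parameter, in module map order).
///
/// `flat` is expected to hold exactly as many elements as the module has float
/// parameters; each slice is reshaped to the parameter's shape and re-marked with
/// `require_grad`.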
fn set_params_from_flat_inner<B: AutodiffBackend, M: Module<B>>(
    module: M,
    flat: Tensor<B::InnerBackend, 1>,
) -> M {
    let mut offset = 0;
    let mut mapper = ParamsFromFlatMapperInner {
        flat: &flat,
        offset: &mut offset,
    };
    module.map(&mut mapper)
}

/// L-BFGS optimizer state.
#[derive(Clone, Record)]
pub struct LBFGSState<B: Backend> {
    /// Historical displacement vectors
    pub history_s: Vec<Tensor<B, 1>>,
    /// Historical gradient difference vectors
    pub history_y: Vec<Tensor<B, 1>>,
    /// Search direction
    pub d: Option<Tensor<B, 1>>,
    /// Step size from the previous iteration
    pub t: Option<f64>,
    /// Flattened gradient from the previous iteration
    pub prev_flat_grad: Option<Tensor<B, 1>>,
    /// Loss value from the previous iteration
    pub prev_loss: Option<f64>,
    /// Global iteration count
    pub g_iter: usize,
}

impl<B: Backend> LBFGSState<B> {
    /// Moves all historical tensors to the target device.
    pub fn to_device(self, device: &B::Device) -> Self {
        Self {
            history_s: self
                .history_s
                .into_iter()
                .map(|t| t.to_device(device))
                .collect(),
            history_y: self
                .history_y
                .into_iter()
                .map(|t| t.to_device(device))
                .collect(),
            d: self.d.map(|t| t.to_device(device)),
            t: self.t,
            prev_flat_grad: self.prev_flat_grad.map(|t| t.to_device(device)),
            prev_loss: self.prev_loss,
            g_iter: self.g_iter,
        }
    }
}
impl<B: Backend> Default for LBFGSState<B> {
    fn default() -> Self {
        Self {
            history_s: Vec::new(),
            history_y: Vec::new(),
            d: None,
            t: Some(1.0),
            prev_flat_grad: None,
            prev_loss: None,
            g_iter: 0,
        }
    }
}

/// L-BFGS optimizer.
///
/// Ported from [PyTorch](https://github.com/pytorch/pytorch/blob/main/torch/optim/lbfgs.py). Heavily inspired by [minFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html).
///
/// See also:
/// - [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS)
///
/// # Note
/// This optimizer is memory intensive.
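/// It keeps up to `history_size` pairs of `s`/`y` vectors, each the size of the full
/// flattened parameter vector, plus the previous flattened gradient.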
#[derive(Clone)]
pub struct LBFGS<B: Backend + AutodiffBackend> {
    config: LBFGSConfig,
    state: LBFGSState<B::InnerBackend>,
}

impl<B: Backend + AutodiffBackend> LBFGS<B> {
    /// Performs a single optimization step on the module's parameters, running up to
    /// `max_iter` inner L-BFGS iterations.
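    ///
    /// # Example
    ///
    /// A minimal sketch (mirroring the unit tests below); `MyModule`, `compute_loss`
    /// and `model` are placeholders for your own module and objective:
    ///
    /// ```rust,ignore
    /// let mut closure = |module: MyModule<B>| {
    ///     let loss = compute_loss(&module); // any scalar loss tensor
    ///     let grads = loss.backward();
    ///     let grads = GradientsParams::from_grads(grads, &module);
    ///     (loss.into_scalar().to_f64(), grads)
    /// };
    /// let (model, loss) = optimizer.step(1e-3, model, &mut closure);
    /// ```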
    pub fn step<M, F>(&mut self, lr: LearningRate, mut module: M, mut closure: F) -> (M, f64)
    where
        M: AutodiffModule<B> + Clone,
        F: FnMut(M) -> (f64, GradientsParams),
    {
        // evaluate initial f(x) and df/dx
        let (mut loss, grads) = closure(module.clone());
        let mut current_evals = 1;

        let mut flat_grad = flatten_grads_inner::<B, M>(&module, &grads);
        let mut x_flat = flatten_params_inner::<B, M>(&module);

        let opt_cond =
            flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad;
        // already optimal: gradient below tolerance
        if opt_cond {
            return (module, loss);
        }

        // tensors cached in state
        let mut d = self
            .state
            .d
            .take()
            .unwrap_or_else(|| flat_grad.clone().neg());
        let mut t = self.state.t.unwrap_or(lr);
        let mut prev_flat_grad = self.state.prev_flat_grad.take();

        let mut n_iter = 0;

        // optimize for a max of max_iter iterations
        while n_iter < self.config.max_iter {
            // keep track of the number of iterations
            n_iter += 1;
            self.state.g_iter += 1;

            // compute gradient descent direction
            if self.state.g_iter == 1 {
                d = flat_grad.clone().neg();
                self.state.history_s.clear();
                self.state.history_y.clear();
            } else {
                // do the L-BFGS update (update memory)
                if let Some(pg) = prev_flat_grad.as_ref() {
                    let y = flat_grad.clone().sub(pg.clone());
                    let s = d.clone().mul_scalar(t);

                    let ys = y.clone().dot(s.clone()).into_scalar().to_f64();

                    if ys > 1e-10 {
                        // updating memory
                        if self.state.history_s.len() >= self.config.history_size {
                            // shift history by one (limited-memory)
                            self.state.history_s.remove(0);
                            self.state.history_y.remove(0);
                        }
                        self.state.history_s.push(s);
                        self.state.history_y.push(y);
                    }
                }

                // compute the approximate (L-BFGS) inverse Hessian
                // multiplied by the gradient
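                // Two-loop recursion (an annotation of the code below): with
                // rho_i = 1 / (y_i . s_i),
                //   q <- -grad;  for i = k..1:  alpha_i = rho_i * (s_i . q),  q <- q - alpha_i * y_i
                //   r <- H0 * q  with  H0 = (s_k . y_k) / (y_k . y_k) * I
                //   for i = 1..k:  beta = rho_i * (y_i . r),  r <- r + s_i * (alpha_i - beta)
                // r approximates -(inverse Hessian approximation) * grad and becomes
                // the new search direction d.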
                let num_old = self.state.history_s.len();
                let mut q = flat_grad.clone().neg();
                let mut alphas: Vec<Tensor<B::InnerBackend, 1>> =
                    vec![Tensor::zeros([1], &flat_grad.device()); num_old];

                if num_old > 0 {
                    // multiply by initial Hessian
                    // r (copied into d below) is the final direction
                    for i in (0..num_old).rev() {
                        let s = &self.state.history_s[i];
                        let y = &self.state.history_y[i];
                        let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);
                        let alpha = rho.clone().mul(s.clone().dot(q.clone()));
                        alphas[i] = alpha.clone();
                        q = q.sub(y.clone().mul(alpha));
                    }

                    let last_s = &self.state.history_s[num_old - 1];
                    let last_y = &self.state.history_y[num_old - 1];
                    let ys = last_y.clone().dot(last_s.clone());
                    let yy = last_y.clone().dot(last_y.clone());
                    let h_diag = ys.div(yy);

                    let mut r = q.mul(h_diag);

                    for ((s, y), alpha) in self
                        .state
                        .history_s
                        .iter()
                        .zip(self.state.history_y.iter())
                        .zip(alphas)
                        .take(num_old)
                    {
                        let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);

                        let beta = rho.mul(y.clone().dot(r.clone()));

                        r = r.add(s.clone().mul(alpha.sub(beta)));
                    }
                    d = r;
                } else {
                    d = q;
                }
            }

            prev_flat_grad = Some(flat_grad.clone());
            let prev_loss_iter = loss;

            // compute step length
            if self.state.g_iter == 1 {
                let grad_l1 = flat_grad.clone().abs().sum().into_scalar().to_f64();
                t = (1.0f64 / grad_l1).min(1.0) * lr;
            } else {
                t = lr;
            }

            // directional derivative
            let gtd = flat_grad.clone().dot(d.clone()).into_scalar().to_f64();

            if gtd > -self.config.tolerance_change {
                break;
            }

            let ls_func_evals;

            if let LineSearchFn::StrongWolfe = self.config.line_search_fn {
                // perform line search using the user-provided closure
                let mut obj_func =
                    |current_x: &Tensor<B::InnerBackend, 1>,
                     step: f64,
                     dir: &Tensor<B::InnerBackend, 1>| {
                        let update = dir.clone().mul_scalar(step);
                        let new_x = current_x.clone().add(update);
                        let tmp_module = set_params_from_flat_inner::<B, M>(module.clone(), new_x);
                        let (l, g) = closure(tmp_module);
                        (l, flatten_grads_inner::<B, M>(&module, &g))
                    };

                let (ls_f, ls_g, ls_t, evals) = strong_wolfe(
                    &mut obj_func,
                    &x_flat,
                    t,
                    &d,
                    loss,
                    flat_grad.clone(),
                    gtd,
                    1e-4,
                    0.9,
                    self.config.tolerance_change,
                    self.config.max_eval.unwrap() - current_evals,
                );

                loss = ls_f;
                flat_grad = ls_g;
                t = ls_t;
                ls_func_evals = evals;

                x_flat = x_flat.add(d.clone().mul_scalar(t));
                module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
            } else {
                // no line search, simply move with a fixed step
                let step_vec = d.clone().mul_scalar(t);
                x_flat = x_flat.add(step_vec);
                module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
                // re-evaluate the objective at the new point to obtain the updated
                // loss and gradient
                let (new_loss, new_grads) = closure(module.clone());
                loss = new_loss;
                flat_grad = flatten_grads_inner::<B, M>(&module, &new_grads);
                ls_func_evals = 1;
            }

            // update the function evaluation count
            current_evals += ls_func_evals;

            // check stopping conditions

            if current_evals >= self.config.max_eval.unwrap() {
                break;
            }

            if flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad {
                break;
            }

            if d.clone().mul_scalar(t).abs().max().into_scalar().to_f64()
                <= self.config.tolerance_change
            {
                break;
            }

            if (loss - prev_loss_iter).abs() < self.config.tolerance_change {
                break;
            }
        }
        self.state.d = Some(d);
        self.state.t = Some(t);
        self.state.prev_flat_grad = prev_flat_grad;
        self.state.prev_loss = Some(loss);
        (module, loss)
    }
    /// Moves the optimizer state to the specified device.
    pub fn to_device(self, device: &B::Device) -> Self {
        Self {
            config: self.config,
            // History tensors reside in InnerBackend, so we convert the device accordingly
            state: self.state.to_device(device),
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::GradientsParams;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, Param};
    use burn::tensor::{Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }
    #[test]
    fn test_cubic_interpolate() {
        let tolerance = 1e-8;

        // basic
        let (x1, f1, g1, x2, f2, g2) = (-1.0, 1.0, -2.0, 1.0, 1.0, 2.0);
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
        assert!(
            (result - 0.00000).abs() < tolerance,
            "Basic: Result {} should be close to 0.0",
            result
        );

        // bound
        let (x1, f1, g1, x2, f2, g2) = (0.0, 0.25, -1.0, 1.0, 0.25, 1.0);
        let bounds = Some((0.6, 1.0));
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds);
        assert!(
            (result - 0.6000000000).abs() < tolerance,
            "Bound: Result {} should be clamped to 0.6",
            result
        );

        // d2_square < 0, should return the midpoint
        let (x1, f1, g1, x2, f2, g2) = (0.0, 0.0, 10.0, 1.0, 5.0, 10.0);
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((0.0, 1.0)));
        assert!(
            (result - 0.5000000).abs() < tolerance,
            "Fallback: Result {} should be midpoint 0.5",
            result
        );

        // asymmetric
        let (x1, f1, g1, x2, f2, g2) = (0.0, 1.0, -5.0, 1.0, 0.5, 1.0);
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
        assert!(
            (result - 0.4606553370833684).abs() < tolerance,
            "Asymmetric: Result {} should be 0.4606553370833684",
            result
        );

        // ill-conditioned values
        let (x1, f1, g1, x2, f2, g2) = (
            1.231232145,
            -0.12567458754,
            9.1231243007,
            8.239105015,
            -100.9012398021,
            123201321.0293982,
        );
        let result_1 = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
        let result_2 = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((-4.4, 4.4)));
        assert!(
            (result_1 - 5.9031480234724434).abs() < tolerance,
            "Ill-conditioned 1: Result {} should be 5.9031480234724434",
            result_1
        );
        assert!(
            (result_2 - 4.4000000000000004).abs() < tolerance,
            "Ill-conditioned 2: Result {} should be 4.4000000000000004",
            result_2
        );
    }
    #[test]
    fn test_strong_wolfe_direct_comparison() {
        let device = Default::default();
        let tol = 1e-6;

        {
            let x = Tensor::<TestAutodiffBackend, 1>::from_floats([2.1321912957_f64], &device);
            let d = Tensor::<TestAutodiffBackend, 1>::from_floats([0.91312321_f64], &device);
            let t_initial = 1.213132_f64;
            fn func<B: Backend>(
                x_base: &Tensor<B, 1>,
                t_val: f64,
                d_vec: &Tensor<B, 1>,
            ) -> (f64, Tensor<B, 1>) {
                let curr_x = x_base.clone().add(d_vec.clone().mul_scalar(t_val));
                let x2 = curr_x.clone().mul(curr_x.clone());
                let x3 = x2.clone().mul(curr_x.clone());
                let x4 = x2.clone().mul(x2.clone());

                // f(x) = x^4 - 2*x^2 + x
                let f_elements = x4 - x2.mul_scalar(2.0) + curr_x.clone();

                let f_val = f_elements.sum().into_scalar().to_f64();

                // g(x) = 4*x^3 - 4*x + 1
                let g = x3.mul_scalar(4.0) - curr_x.clone().mul_scalar(4.0)
                    + Tensor::ones_like(&curr_x);

                (f_val, g)
            }
            let (f_init, g_init) = func(&x, 0.0, &d);
            let gtd_init = g_init.clone().dot(d.clone()).into_scalar().to_f64();
            println!("Initial state: f = {}, gtd = {}", f_init, gtd_init);
            assert!((f_init - 13.7080059052).abs() < tol);
            assert!((gtd_init - 28.5305728912).abs() < tol);
            let mut obj_func =
                |xb: &Tensor<TestAutodiffBackend, 1>,
                 tv: f64,
                 dv: &Tensor<TestAutodiffBackend, 1>| func(xb, tv, dv);

            let (f_final, g_final, t_final, evals) = strong_wolfe(
                &mut obj_func,
                &x,
                t_initial,
                &d,
                f_init,
                g_init,
                gtd_init,
                1e-4, // c1
                0.9,  // c2
                1e-9, // tolerance_change
                10,   // max_ls
            );
            let g_f = g_final.into_scalar().to_f64();
            println!(
                "f_final: {:?}, g_final: {:?}, t_final: {:?}, evals: {:?}",
                f_final, g_f, t_final, evals
            );
            assert!((f_final - 13.708005905151367).abs() < tol);
            assert!((g_f - 31.2450428009).abs() < tol);
            assert!((t_final - 0.0).abs() < tol);
            assert_eq!(evals, 11);
        }
    }
    #[test]
    fn test_lbfgs_strong_wolfe_comparison() {
        let device = Default::default();
        let tol = 1e-5;
        let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
        let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
        let weight = TensorData::from([[0.5f64]]);
        let bias = TensorData::from([0.1f64]);
        let module = given_linear_layer(weight, bias);

        let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
            .with_line_search_fn(LineSearchFn::StrongWolfe)
            .init();
        let mut closure = |mod_in: Linear<TestAutodiffBackend>| {
            let output = mod_in.forward(x_data.clone());
            let loss = burn_nn::loss::MseLoss::new().forward(
                output,
                y_true.clone(),
                burn_nn::loss::Reduction::Sum,
            );

            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &mod_in);

            (loss.into_scalar().to_f64(), grads_params)
        };
        let initial_loss = closure(module.clone()).0;
        assert!((initial_loss - 50.1300048828).abs() < tol);
        let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
        assert!((final_loss - 0.0234732367).abs() < tol);
        let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
        let optimized_bias: f64 = updated_module
            .bias
            .as_ref()
            .unwrap()
            .val()
            .into_scalar()
            .to_f64();
        assert!((optimized_data - 2.0570652485).abs() < tol);
        assert!((optimized_bias - 0.8106800914).abs() < tol);
    }
    #[test]
    fn test_lbfgs_no_strong_wolfe_comparison() {
        let device = Default::default();
        let tol = 1e-5;
        let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
        let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
        let weight = TensorData::from([[0.5f64]]);
        let bias = TensorData::from([0.1f64]);
        let module = given_linear_layer(weight, bias);

        let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
            .with_line_search_fn(LineSearchFn::None)
            .init();
        let mut closure = |mod_in: Linear<TestAutodiffBackend>| {
            let output = mod_in.forward(x_data.clone());
            let loss = burn_nn::loss::MseLoss::new().forward(
                output,
                y_true.clone(),
                burn_nn::loss::Reduction::Sum,
            );

            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &mod_in);

            (loss.into_scalar().to_f64(), grads_params)
        };
        let initial_loss = closure(module.clone()).0;
        assert!((initial_loss - 50.1300048828).abs() < tol);
        let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
        assert!((final_loss - 48.2181930542).abs() < tol);
        let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
        let optimized_bias: f64 = updated_module
            .bias
            .as_ref()
            .unwrap()
            .val()
            .into_scalar()
            .to_f64();

        assert!((optimized_data - 0.5302446192).abs() < tol);
        assert!((optimized_bias - 0.1142520783).abs() < tol);
    }
}