#![allow(clippy::excessive_precision)]

use burn_core as burn;

use super::GradientsParams;
use crate::LearningRate;
use burn::config::Config;
use burn::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};
use burn::prelude::ToElement;
use burn::record::Record;
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, backend::AutodiffBackend};
use serde::{Deserialize, Serialize};

use alloc::vec;
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

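/// Returns the minimizer of the cubic polynomial interpolating two points
/// `(x1, f1)` and `(x2, f2)` with derivatives `g1` and `g2`, clamped to
/// `bounds` (or to `[min(x1, x2), max(x1, x2)]` when `bounds` is `None`).
/// When the interpolation has no real minimizer, the midpoint of the bounds
/// is returned instead.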
fn cubic_interpolate(
    x1: f64,
    f1: f64,
    g1: f64,
    x2: f64,
    f2: f64,
    g2: f64,
    bounds: Option<(f64, f64)>,
) -> f64 {
    let (min_bound, max_bound) = bounds.unwrap_or(if x1 <= x2 { (x1, x2) } else { (x2, x1) });

    let d1 = g1 + g2 - 3.0 * (f1 - f2) / (x1 - x2);
    let d2_square = d1 * d1 - g1 * g2;

    if d2_square >= 0.0 {
        let d2 = d2_square.sqrt();
        let min_pos = if x1 <= x2 {
            x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2.0 * d2))
        } else {
            x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2.0 * d2))
        };
        min_pos.max(min_bound).min(max_bound)
    } else {
        (min_bound + max_bound) / 2.0
    }
}

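/// A candidate point evaluated during the line search: step size `t`, loss
/// `f`, flattened gradient `g` and directional derivative `gtd = g.dot(d)`.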
struct LineSearchSample<B: Backend> {
    t: f64,
    f: f64,
    g: Tensor<B, 1>,
    gtd: f64,
}

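/// Line search satisfying the strong Wolfe conditions along direction `d`
/// starting from `x` with initial step `t`:
///
/// - sufficient decrease: `f(x + t * d) <= f + c1 * t * gtd`
/// - curvature: `|g(x + t * d).dot(d)| <= -c2 * gtd`
///
/// `f`, `g` and `gtd` are the loss, gradient and directional derivative at
/// `t = 0`. Uses a bracketing phase followed by a zoom phase (as in PyTorch's
/// `_strong_wolfe`) and returns `(f_new, g_new, t, n_func_evals)`.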
#[allow(clippy::too_many_arguments)]
fn strong_wolfe<B: Backend, F>(
    obj_func: &mut F,
    x: &Tensor<B, 1>,
    mut t: f64,
    d: &Tensor<B, 1>,
    f: f64,
    g: Tensor<B, 1>,
    gtd: f64,
    c1: f64,
    c2: f64,
    tolerance_change: f64,
    max_ls: usize,
) -> (f64, Tensor<B, 1>, f64, usize)
where
    F: FnMut(&Tensor<B, 1>, f64, &Tensor<B, 1>) -> (f64, Tensor<B, 1>),
{
    let d_norm = d.clone().abs().max().into_scalar().to_f64();

    // Evaluate the objective and gradient at the initial step size.
    let (mut f_new, mut g_new) = obj_func(x, t, d);
    let mut ls_func_evals = 1;
    let mut gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();

    let (mut t_prev, mut f_prev, mut g_prev, mut gtd_prev) = (0.0, f, g.clone(), gtd);
    let mut done = false;
    let mut ls_iter = 0;

    let mut bracket: Option<[LineSearchSample<B>; 2]> = None;
    let mut wolfe_bracket: Option<LineSearchSample<B>> = None;

    // Bracketing phase: grow the step until an interval containing an
    // acceptable point is found.
    while ls_iter < max_ls {
        // Armijo condition violated (or no decrease versus the previous step):
        // the minimum is bracketed between `t_prev` and `t`.
        if f_new > (f + c1 * t * gtd) || (ls_iter > 1 && f_new >= f_prev) {
            bracket = Some([
                LineSearchSample {
                    t: t_prev,
                    f: f_prev,
                    g: g_prev,
                    gtd: gtd_prev,
                },
                LineSearchSample {
                    t,
                    f: f_new,
                    g: g_new.clone(),
                    gtd: gtd_new,
                },
            ]);
            break;
        }

        // Strong Wolfe conditions already satisfied at `t`: done.
        if gtd_new.abs() <= -c2 * gtd {
            wolfe_bracket = Some(LineSearchSample {
                t,
                f: f_new,
                g: g_new.clone(),
                gtd: gtd_new,
            });
            done = true;
            break;
        }

        // Directional derivative became non-negative: the minimum is bracketed.
        if gtd_new >= 0.0 {
            bracket = Some([
                LineSearchSample {
                    t: t_prev,
                    f: f_prev,
                    g: g_prev,
                    gtd: gtd_prev,
                },
                LineSearchSample {
                    t,
                    f: f_new,
                    g: g_new.clone(),
                    gtd: gtd_new,
                },
            ]);
            break;
        }

        // Otherwise extrapolate to a larger step with cubic interpolation.
        let min_step = t + 0.01 * (t - t_prev);
        let max_step = t * 10.0;
        let t_next = cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            Some((min_step, max_step)),
        );
        t_prev = t;
        f_prev = f_new;
        g_prev = g_new;
        gtd_prev = gtd_new;

        t = t_next;
        (f_new, g_new) = obj_func(x, t, d);
        ls_func_evals += 1;
        gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
        ls_iter += 1;
    }

    if let Some(sample) = wolfe_bracket {
        return (sample.f, sample.g, sample.t, ls_func_evals);
    }

    // If the bracketing phase ran out of iterations, fall back to the
    // interval [0, t].
    let mut bracket = bracket.unwrap_or_else(|| {
        [
            LineSearchSample {
                t: 0.0,
                f,
                g: g.clone(),
                gtd,
            },
            LineSearchSample {
                t,
                f: f_new,
                g: g_new.clone(),
                gtd: gtd_new,
            },
        ]
    });

    let mut insuf_progress = false;

    // Indices of the bracket endpoints with the lower and higher loss.
    let (mut low_idx, mut high_idx) = if bracket[0].f <= bracket[1].f {
        (0, 1)
    } else {
        (1, 0)
    };

    // Zoom phase: shrink the bracket until a point satisfying the strong
    // Wolfe conditions is found or progress stalls.
    while !done && ls_iter < max_ls {
        let diff = (bracket[1].t - bracket[0].t).abs();
        if diff * d_norm < tolerance_change {
            break;
        }

        t = cubic_interpolate(
            bracket[0].t,
            bracket[0].f,
            bracket[0].gtd,
            bracket[1].t,
            bracket[1].f,
            bracket[1].gtd,
            None,
        );

        // Guard against insufficient progress: if the trial point sits too close
        // to a bracket boundary twice in a row, push it away from the boundary.
        let b_min = bracket[0].t.min(bracket[1].t);
        let b_max = bracket[0].t.max(bracket[1].t);
        let eps = 0.1 * (b_max - b_min);

        if (b_max - t).min(t - b_min) < eps {
            if insuf_progress || t >= b_max || t <= b_min {
                t = if (t - b_max).abs() < (t - b_min).abs() {
                    b_max - eps
                } else {
                    b_min + eps
                };
                insuf_progress = false;
            } else {
                insuf_progress = true;
            }
        } else {
            insuf_progress = false;
        }

        (f_new, g_new) = obj_func(x, t, d);

        ls_func_evals += 1;
        gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();
        ls_iter += 1;

        let armijo_holds = f_new <= (f + c1 * t * gtd) && f_new < bracket[low_idx].f;

        if !armijo_holds {
            // Armijo condition failed (or no improvement over the low point):
            // the new point becomes the high end of the bracket.
            bracket[high_idx] = LineSearchSample {
                t,
                f: f_new,
                g: g_new,
                gtd: gtd_new,
            };
        } else {
            if gtd_new.abs() <= -c2 * gtd {
                // Strong Wolfe conditions satisfied.
                return (f_new, g_new, t, ls_func_evals);
            }

            if gtd_new * (bracket[high_idx].t - bracket[low_idx].t) >= 0.0 {
                // The old low point becomes the new high point.
                bracket[high_idx] = LineSearchSample {
                    t: bracket[low_idx].t,
                    f: bracket[low_idx].f,
                    g: bracket[low_idx].g.clone(),
                    gtd: bracket[low_idx].gtd,
                };
            }
            bracket[low_idx] = LineSearchSample {
                t,
                f: f_new,
                g: g_new,
                gtd: gtd_new,
            };
        }

        if bracket[0].f <= bracket[1].f {
            low_idx = 0;
            high_idx = 1;
        } else {
            low_idx = 1;
            high_idx = 0;
        }
    }

    // Return the best point found so far.
    (
        bracket[low_idx].f,
        bracket[low_idx].g.clone(),
        bracket[low_idx].t,
        ls_func_evals,
    )
}

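/// Line search strategy used by [LBFGS] to determine the step size along the
/// search direction.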
#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LineSearchFn {
    /// No line search: update with a fixed step size.
    #[default]
    None,
    /// Line search enforcing the strong Wolfe conditions.
    StrongWolfe,
}

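/// Configuration to create the [LBFGS] optimizer.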
#[derive(Config, Debug)]
pub struct LBFGSConfig {
    /// Maximum number of inner iterations per optimizer step.
    #[config(default = 20)]
    pub max_iter: usize,
    /// Number of `(s, y)` curvature pairs kept in history.
    #[config(default = 100)]
    pub history_size: usize,
    /// Termination tolerance on the infinity norm of the gradient.
    #[config(default = 1e-7)]
    pub tolerance_grad: f64,
    /// Termination tolerance on parameter and loss changes.
    #[config(default = 1e-9)]
    pub tolerance_change: f64,
    /// Maximum number of function evaluations per step
    /// (defaults to `max_iter * 5 / 4` when `None`).
    #[config(default = "None")]
    pub max_eval: Option<usize>,
    /// Line search strategy.
    #[config(default = "LineSearchFn::None")]
    pub line_search_fn: LineSearchFn,
}

impl LBFGSConfig {
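    /// Initializes the [LBFGS] optimizer from this configuration.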
    pub fn init<B: AutodiffBackend>(&self) -> LBFGS<B> {
        // Default max_eval to 1.25 * max_iter, as in PyTorch's L-BFGS.
        let max_eval = self.max_eval.unwrap_or(self.max_iter * 5 / 4);
        LBFGS {
            config: LBFGSConfig {
                max_iter: self.max_iter,
                history_size: self.history_size,
                tolerance_grad: self.tolerance_grad,
                tolerance_change: self.tolerance_change,
                max_eval: Some(max_eval),
                line_search_fn: self.line_search_fn,
            },
            state: Default::default(),
        }
    }
}

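/// Visitor collecting the gradient of every float parameter as a flattened
/// rank-1 tensor, in module visiting order.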
struct FlattenGradsVisitorInner<'a, B: AutodiffBackend> {
    grads: &'a GradientsParams,
    tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
}

impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenGradsVisitorInner<'_, B> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        if let Some(g) = self.grads.get::<B::InnerBackend, D>(param.id) {
            let numel = g.shape().num_elements();
            self.tensors.push(g.reshape([numel]));
        }
    }
}

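/// Flattens all float parameters of `module` into a single rank-1 tensor on
/// the inner (non-autodiff) backend.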
fn flatten_params_inner<B: AutodiffBackend, M: Module<B>>(
    module: &M,
) -> Tensor<B::InnerBackend, 1> {
    let mut tensors = Vec::new();
    let mut visitor = FlattenParamsVisitorInner::<B> {
        tensors: &mut tensors,
    };
    module.visit(&mut visitor);
    if tensors.is_empty() {
        return Tensor::empty([0], &module.devices()[0]);
    }
    Tensor::cat(tensors, 0)
}

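/// Visitor collecting the value of every float parameter as a flattened
/// rank-1 tensor, in module visiting order.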
struct FlattenParamsVisitorInner<'a, B: AutodiffBackend> {
    tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,
}

impl<B: AutodiffBackend> ModuleVisitor<B> for FlattenParamsVisitorInner<'_, B> {
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
        let t = param.val().inner();
        let numel = t.shape().num_elements();
        self.tensors.push(t.reshape([numel]));
    }
}

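/// Flattens the gradients of `module`'s parameters into a single rank-1
/// tensor, in the same order as [`flatten_params_inner`].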
fn flatten_grads_inner<B: AutodiffBackend, M: Module<B>>(
    module: &M,
    grads: &GradientsParams,
) -> Tensor<B::InnerBackend, 1> {
    let mut tensors = Vec::new();
    let mut visitor = FlattenGradsVisitorInner {
        grads,
        tensors: &mut tensors,
    };
    module.visit(&mut visitor);
    if tensors.is_empty() {
        return Tensor::empty([0], &module.devices()[0]);
    }
    Tensor::cat(tensors, 0)
}

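/// Mapper writing consecutive slices of a flat rank-1 tensor back into the
/// module's parameters, in module visiting order.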
struct ParamsFromFlatMapperInner<'a, B: AutodiffBackend> {
    flat: &'a Tensor<B::InnerBackend, 1>,
    offset: &'a mut usize,
}

impl<B: AutodiffBackend> ParamsFromFlatMapperInner<'_, B> {
    fn take_slice(&mut self, numel: usize) -> Tensor<B::InnerBackend, 1> {
        let start = *self.offset;
        *self.offset += numel;
        self.flat.clone().slice(start..*self.offset)
    }
}

impl<B: AutodiffBackend> ModuleMapper<B> for ParamsFromFlatMapperInner<'_, B> {
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let numel = tensor.shape().num_elements();
        let slice_1d = self.take_slice(numel);
        let new_inner = slice_1d.reshape(tensor.shape());
        let new_tensor = Tensor::from_inner(new_inner).require_grad();
        Param::from_mapped_value(id, new_tensor, mapper)
    }
}

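/// Rebuilds `module` with parameter values taken from `flat`, preserving
/// parameter ids and re-enabling gradient tracking.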
fn set_params_from_flat_inner<B: AutodiffBackend, M: Module<B>>(
    module: M,
    flat: Tensor<B::InnerBackend, 1>,
) -> M {
    let mut offset = 0;
    let mut mapper = ParamsFromFlatMapperInner {
        flat: &flat,
        offset: &mut offset,
    };
    module.map(&mut mapper)
}

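/// State of the [LBFGS] optimizer, persisted between calls to [`LBFGS::step`].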
#[derive(Clone, Record)]
pub struct LBFGSState<B: Backend> {
    /// History of parameter differences `s_k = x_{k+1} - x_k`.
    pub history_s: Vec<Tensor<B, 1>>,
    /// History of gradient differences `y_k = g_{k+1} - g_k`.
    pub history_y: Vec<Tensor<B, 1>>,
    /// Last search direction.
    pub d: Option<Tensor<B, 1>>,
    /// Last step size.
    pub t: Option<f64>,
    /// Flattened gradient of the previous iteration.
    pub prev_flat_grad: Option<Tensor<B, 1>>,
    /// Loss of the previous iteration.
    pub prev_loss: Option<f64>,
    /// Total number of L-BFGS iterations performed so far.
    pub g_iter: usize,
}

impl<B: Backend> LBFGSState<B> {
    /// Moves the state to the given device.
    pub fn to_device(self, device: &B::Device) -> Self {
        Self {
            history_s: self
                .history_s
                .into_iter()
                .map(|t| t.to_device(device))
                .collect(),
            history_y: self
                .history_y
                .into_iter()
                .map(|t| t.to_device(device))
                .collect(),
            d: self.d.map(|t| t.to_device(device)),
            t: self.t,
            prev_flat_grad: self.prev_flat_grad.map(|t| t.to_device(device)),
            prev_loss: self.prev_loss,
            g_iter: self.g_iter,
        }
    }
}

impl<B: Backend> Default for LBFGSState<B> {
    fn default() -> Self {
        Self {
            history_s: Vec::new(),
            history_y: Vec::new(),
            d: None,
            t: Some(1.0),
            prev_flat_grad: None,
            prev_loss: None,
            g_iter: 0,
        }
    }
}

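/// Limited-memory BFGS (L-BFGS) optimizer.
///
/// L-BFGS approximates the inverse Hessian from a short history of gradient
/// and parameter differences and may re-evaluate the loss several times per
/// step, so the whole evaluation is driven through a closure instead of a
/// single precomputed gradient.
///
/// A minimal usage sketch (`MyBackend`, `model`, `batch`, `targets` and the
/// loss computation are placeholders, not part of this module):
///
/// ```rust,ignore
/// let mut optim: LBFGS<MyBackend> = LBFGSConfig::new()
///     .with_line_search_fn(LineSearchFn::StrongWolfe)
///     .init();
///
/// let (model, loss) = optim.step(1.0, model, |m| {
///     // Forward pass and loss for the candidate parameters.
///     let loss = compute_loss(m.forward(batch.clone()), targets.clone());
///     // Gradients keyed by parameter id, as the optimizer expects.
///     let grads = GradientsParams::from_grads(loss.backward(), &m);
///     (loss.into_scalar().to_f64(), grads)
/// });
/// ```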
#[derive(Clone)]
pub struct LBFGS<B: Backend + AutodiffBackend> {
    config: LBFGSConfig,
    state: LBFGSState<B::InnerBackend>,
}

impl<B: Backend + AutodiffBackend> LBFGS<B> {
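    /// Performs a full optimization step, running up to `max_iter` inner
    /// L-BFGS iterations.
    ///
    /// The `closure` re-evaluates the module and returns the loss together
    /// with its gradients; it may be called several times per step (once for
    /// the initial evaluation and once per line-search probe). Returns the
    /// updated module and the final loss.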
    pub fn step<M, F>(&mut self, lr: LearningRate, mut module: M, mut closure: F) -> (M, f64)
    where
        M: AutodiffModule<B> + Clone,
        F: FnMut(M) -> (f64, GradientsParams),
    {
        // Initial evaluation of the loss and gradients.
        let (mut loss, grads) = closure(module.clone());
        let mut current_evals = 1;

        let mut flat_grad = flatten_grads_inner::<B, M>(&module, &grads);
        let mut x_flat = flatten_params_inner::<B, M>(&module);

        // Already converged: the gradient infinity norm is below tolerance.
        let opt_cond =
            flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad;
        if opt_cond {
            return (module, loss);
        }

        let mut d = self
            .state
            .d
            .take()
            .unwrap_or_else(|| flat_grad.clone().neg());
        let mut t = self.state.t.unwrap_or(lr);
        let mut prev_flat_grad = self.state.prev_flat_grad.take();

        let mut n_iter = 0;
        while n_iter < self.config.max_iter {
            n_iter += 1;
            self.state.g_iter += 1;

            if self.state.g_iter == 1 {
                // Very first iteration: fall back to steepest descent and reset history.
                d = flat_grad.clone().neg();
                self.state.history_s.clear();
                self.state.history_y.clear();
            } else {
                // Update the curvature history with s = t * d and y = g_new - g_prev,
                // skipping pairs that violate the curvature condition y.dot(s) > 0.
                if let Some(pg) = prev_flat_grad.as_ref() {
                    let y = flat_grad.clone().sub(pg.clone());
                    let s = d.clone().mul_scalar(t);

                    let ys = y.clone().dot(s.clone()).into_scalar().to_f64();

                    if ys > 1e-10 {
                        if self.state.history_s.len() >= self.config.history_size {
                            self.state.history_s.remove(0);
                            self.state.history_y.remove(0);
                        }
                        self.state.history_s.push(s);
                        self.state.history_y.push(y);
                    }
                }

                // Two-loop recursion: compute the search direction d = -H * g,
                // where H is the current inverse Hessian approximation.
                let num_old = self.state.history_s.len();
                let mut q = flat_grad.clone().neg();
                let mut alphas: Vec<Tensor<B::InnerBackend, 1>> =
                    vec![Tensor::zeros([1], &flat_grad.device()); num_old];

                if num_old > 0 {
                    // First loop, newest to oldest: q <- q - alpha_i * y_i.
                    for i in (0..num_old).rev() {
                        let s = &self.state.history_s[i];
                        let y = &self.state.history_y[i];
                        let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);
                        let alpha = rho.clone().mul(s.clone().dot(q.clone()));
                        alphas[i] = alpha.clone();
                        q = q.sub(y.clone().mul(alpha));
                    }

                    // Scale by the initial inverse Hessian approximation (y.s / y.y) * I.
                    let last_s = &self.state.history_s[num_old - 1];
                    let last_y = &self.state.history_y[num_old - 1];
                    let ys = last_y.clone().dot(last_s.clone());
                    let yy = last_y.clone().dot(last_y.clone());
                    let h_diag = ys.div(yy);

                    let mut r = q.mul(h_diag);

                    // Second loop, oldest to newest: r <- r + (alpha_i - beta_i) * s_i.
                    for ((s, y), alpha) in self
                        .state
                        .history_s
                        .iter()
                        .zip(self.state.history_y.iter())
                        .zip(alphas)
                        .take(num_old)
                    {
                        let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);

                        let beta = rho.mul(y.clone().dot(r.clone()));

                        r = r.add(s.clone().mul(alpha.sub(beta)));
                    }
                    d = r;
                } else {
                    d = q;
                }
            }

            prev_flat_grad = Some(flat_grad.clone());
            let prev_loss_iter = loss;

            // Initial step size: scale the very first step by the gradient's
            // L1 norm, afterwards start each iteration from the learning rate.
            if self.state.g_iter == 1 {
                let grad_l1 = flat_grad.clone().abs().sum().into_scalar().to_f64();
                t = (1.0f64 / grad_l1).min(1.0) * lr;
            } else {
                t = lr;
            }

            // Directional derivative; stop if d is not a descent direction.
            let gtd = flat_grad.clone().dot(d.clone()).into_scalar().to_f64();

            if gtd > -self.config.tolerance_change {
                break;
            }

            let ls_func_evals;

            if let LineSearchFn::StrongWolfe = self.config.line_search_fn {
                // Objective evaluated at x + step * dir, returning the loss and the
                // flattened gradient. Parameter ids are preserved by
                // `set_params_from_flat_inner`, so visiting the original module
                // flattens the gradients in the same order.
                let mut obj_func =
                    |current_x: &Tensor<B::InnerBackend, 1>,
                     step: f64,
                     dir: &Tensor<B::InnerBackend, 1>| {
                        let update = dir.clone().mul_scalar(step);
                        let new_x = current_x.clone().add(update);
                        let tmp_module = set_params_from_flat_inner::<B, M>(module.clone(), new_x);
                        let (l, g) = closure(tmp_module);
                        (l, flatten_grads_inner::<B, M>(&module, &g))
                    };

                let (ls_f, ls_g, ls_t, evals) = strong_wolfe(
                    &mut obj_func,
                    &x_flat,
                    t,
                    &d,
                    loss,
                    flat_grad.clone(),
                    gtd,
                    1e-4,
                    0.9,
                    self.config.tolerance_change,
                    self.config.max_eval.unwrap() - current_evals,
                );

                loss = ls_f;
                flat_grad = ls_g;
                t = ls_t;
                ls_func_evals = evals;

                x_flat = x_flat.add(d.clone().mul_scalar(t));
                module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
            } else {
                // No line search: take a single step of size t along d and re-evaluate.
                let step_vec = d.clone().mul_scalar(t);
                x_flat = x_flat.add(step_vec);
                module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());
                let (new_loss, new_grads) = closure(module.clone());
                loss = new_loss;
                flat_grad = flatten_grads_inner::<B, M>(&module, &new_grads);
                ls_func_evals = 1;
            }

            current_evals += ls_func_evals;

            // Stopping criteria for the inner loop.
            if current_evals >= self.config.max_eval.unwrap() {
                break;
            }

            if flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad {
                break;
            }

            if d.clone().mul_scalar(t).abs().max().into_scalar().to_f64()
                <= self.config.tolerance_change
            {
                break;
            }

            if (loss - prev_loss_iter).abs() < self.config.tolerance_change {
                break;
            }
        }

        // Persist the state for the next call to `step`.
        self.state.d = Some(d);
        self.state.t = Some(t);
        self.state.prev_flat_grad = prev_flat_grad;
        self.state.prev_loss = Some(loss);
        (module, loss)
    }

    /// Moves the optimizer state to the given device.
    pub fn to_device(self, device: &B::Device) -> Self {
        Self {
            config: self.config,
            state: self.state.to_device(device),
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::GradientsParams;
    use crate::TestAutodiffBackend;
    use burn::module::{Module, Param};
    use burn::tensor::{Tensor, TensorData};
    use burn_nn::{Linear, LinearConfig, LinearRecord};

    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {
        let device = Default::default();
        let record = LinearRecord {
            weight: Param::from_data(weight, &device),
            bias: Some(Param::from_data(bias, &device)),
        };

        LinearConfig::new(6, 6).init(&device).load_record(record)
    }

    #[test]
    fn test_cubic_interpolate() {
        let tolerance = 1e-8;

        // Symmetric case: the minimizer lies at the midpoint.
        let (x1, f1, g1, x2, f2, g2) = (-1.0, 1.0, -2.0, 1.0, 1.0, 2.0);
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
        assert!(
            (result - 0.00000).abs() < tolerance,
            "Basic: Result {} should be close to 0.0",
            result
        );

        // The unconstrained minimizer (0.5) is clamped to the lower bound.
        let (x1, f1, g1, x2, f2, g2) = (0.0, 0.25, -1.0, 1.0, 0.25, 1.0);
        let bounds = Some((0.6, 1.0));
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds);
        assert!(
            (result - 0.6000000000).abs() < tolerance,
            "Bound: Result {} should be clamped to 0.6",
            result
        );

        // Negative discriminant: falls back to the midpoint of the bounds.
        let (x1, f1, g1, x2, f2, g2) = (0.0, 0.0, 10.0, 1.0, 5.0, 10.0);
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((0.0, 1.0)));
        assert!(
            (result - 0.5000000).abs() < tolerance,
            "Fallback: Result {} should be midpoint 0.5",
            result
        );

        let (x1, f1, g1, x2, f2, g2) = (0.0, 1.0, -5.0, 1.0, 0.5, 1.0);
        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
        assert!(
            (result - 0.4606553370833684).abs() < tolerance,
            "Asymmetric: Result {} should be 0.4606553370833684",
            result
        );

        // Ill-conditioned inputs, with and without explicit bounds.
        let (x1, f1, g1, x2, f2, g2) = (
            1.231232145,
            -0.12567458754,
            9.1231243007,
            8.239105015,
            -100.9012398021,
            123201321.0293982,
        );
        let result_1 = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);
        let result_2 = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((-4.4, 4.4)));
        assert!(
            (result_1 - 5.9031480234724434).abs() < tolerance,
            "Unbounded: Result {} should be 5.9031480234724434",
            result_1
        );
        assert!(
            (result_2 - 4.4000000000000004).abs() < tolerance,
            "Bounded: Result {} should be 4.4000000000000004",
            result_2
        );
    }

    #[test]
    fn test_strong_wolfe_direct_comparison() {
        let device = Default::default();
        let tol = 1e-6;

        {
            let x = Tensor::<TestAutodiffBackend, 1>::from_floats([2.1321912957_f64], &device);
            let d = Tensor::<TestAutodiffBackend, 1>::from_floats([0.91312321_f64], &device);
            let t_initial = 1.213132_f64;
            fn func<B: Backend>(
                x_base: &Tensor<B, 1>,
                t_val: f64,
                d_vec: &Tensor<B, 1>,
            ) -> (f64, Tensor<B, 1>) {
                let curr_x = x_base.clone().add(d_vec.clone().mul_scalar(t_val));
                let x2 = curr_x.clone().mul(curr_x.clone());
                let x3 = x2.clone().mul(curr_x.clone());
                let x4 = x2.clone().mul(x2.clone());

                // f(x) = x^4 - 2x^2 + x, evaluated element-wise and summed.
                let f_elements = x4 - x2.mul_scalar(2.0) + curr_x.clone();

                let f_val = f_elements.sum().into_scalar().to_f64();

                // Analytic gradient: 4x^3 - 4x + 1.
                let g = x3.mul_scalar(4.0) - curr_x.clone().mul_scalar(4.0)
                    + Tensor::ones_like(&curr_x);

                (f_val, g)
            }
            let (f_init, g_init) = func(&x, 0.0, &d);
            let gtd_init = g_init.clone().dot(d.clone()).into_scalar().to_f64();
            println!("Initial state: f = {}, gtd = {}", f_init, gtd_init);
            assert!((f_init - 13.7080059052).abs() < tol);
            assert!((gtd_init - 28.5305728912).abs() < tol);
            let mut obj_func =
                |xb: &Tensor<TestAutodiffBackend, 1>,
                 tv: f64,
                 dv: &Tensor<TestAutodiffBackend, 1>| func(xb, tv, dv);

            let (f_final, _g_final, t_final, evals) = strong_wolfe(
                &mut obj_func,
                &x,
                t_initial,
                &d,
                f_init,
                g_init,
                gtd_init,
                1e-4,
                0.9,
                1e-9,
                10,
            );
            let g_f = _g_final.into_scalar().to_f64();
            println!(
                "f_final: {:?}, g_final: {:?}, t_final: {:?}, evals: {:?}",
                f_final, g_f, t_final, evals
            );
            assert!((f_final - 13.708005905151367).abs() < tol);
            assert!((g_f - 31.2450428009).abs() < tol);
            assert!((t_final - 0.0).abs() < tol);
            assert_eq!(evals, 11);
        }
    }

    #[test]
    fn test_lbfgs_strong_wolfe_comparison() {
        let device = Default::default();
        let tol = 1e-5;
        let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
        let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
        let weight = TensorData::from([[0.5f64]]);
        let bias = TensorData::from([0.1f64]);
        let module = given_linear_layer(weight, bias);

        let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
            .with_line_search_fn(LineSearchFn::StrongWolfe)
            .init();
        let mut closure = |mod_in: Linear<TestAutodiffBackend>| {
            let output = mod_in.forward(x_data.clone());
            let loss = burn_nn::loss::MseLoss::new().forward(
                output,
                y_true.clone(),
                burn_nn::loss::Reduction::Sum,
            );

            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &mod_in);

            (loss.into_scalar().to_f64(), grads_params)
        };
        let initial_loss = closure(module.clone()).0;
        assert!((initial_loss - 50.1300048828).abs() < tol);
        let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
        assert!((final_loss - 0.0234732367).abs() < tol);
        let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
        let optimized_bias: f64 = updated_module
            .bias
            .as_ref()
            .unwrap()
            .val()
            .into_scalar()
            .to_f64();
        assert!((optimized_data - 2.0570652485).abs() < tol);
        assert!((optimized_bias - 0.8106800914).abs() < tol);
    }

    #[test]
    fn test_lbfgs_no_strong_wolfe_comparison() {
        let device = Default::default();
        let tol = 1e-5;
        let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);
        let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);
        let weight = TensorData::from([[0.5f64]]);
        let bias = TensorData::from([0.1f64]);
        let module = given_linear_layer(weight, bias);

        let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()
            .with_line_search_fn(LineSearchFn::None)
            .init();
        let mut closure = |mod_in: Linear<TestAutodiffBackend>| {
            let output = mod_in.forward(x_data.clone());
            let loss = burn_nn::loss::MseLoss::new().forward(
                output,
                y_true.clone(),
                burn_nn::loss::Reduction::Sum,
            );

            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &mod_in);

            (loss.into_scalar().to_f64(), grads_params)
        };
        let initial_loss = closure(module.clone()).0;
        assert!((initial_loss - 50.1300048828).abs() < tol);
        let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);
        assert!((final_loss - 48.2181930542).abs() < tol);
        let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();
        let optimized_bias: f64 = updated_module
            .bias
            .as_ref()
            .unwrap()
            .val()
            .into_scalar()
            .to_f64();

        assert!((optimized_data - 0.5302446192).abs() < tol);
        assert!((optimized_bias - 0.1142520783).abs() < tol);
    }
}