use crate::{LossOptimizer, Model, ModelOutcome};
use candle_core::Result as CResult;
use candle_core::{Tensor, Var};
use log::info;
use std::collections::VecDeque;
mod strong_wolfe;

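/// Line-search method used to pick the step length along the L-BFGS direction.
///
/// `StrongWolfe(c1, c2, tol)` holds the sufficient-decrease constant `c1`, the
/// curvature constant `c2`, and a tolerance forwarded to the line search.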
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum LineSearch {
    StrongWolfe(f64, f64, f64),
}

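/// Gradient-based convergence criterion.
///
/// `MinForce(tol)` converges when the largest absolute gradient component is
/// below `tol`; `RMSForce(tol)` converges when the root-mean-square of the
/// gradient is below `tol`.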
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum GradConv {
    MinForce(f64),
    RMSForce(f64),
}

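/// Step-based convergence criterion, applied to the proposed parameter update.
///
/// `MinStep(tol)` converges when the largest absolute component of the step is
/// below `tol`; `RMSStep(tol)` converges when the root-mean-square of the step
/// is below `tol`.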
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum StepConv {
    MinStep(f64),
    RMSStep(f64),
}

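/// Hyperparameters for the L-BFGS optimiser: learning rate, history size,
/// optional line search, gradient and step convergence criteria, and optional
/// weight decay.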
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ParamsLBFGS {
    pub lr: f64,
    pub history_size: usize,
    pub line_search: Option<LineSearch>,
    pub grad_conv: GradConv,
    pub step_conv: StepConv,
    pub weight_decay: Option<f64>,
}

impl Default for ParamsLBFGS {
    fn default() -> Self {
        Self {
            lr: 1.,
            history_size: 100,
            line_search: None,
            grad_conv: GradConv::MinForce(1e-7),
            step_conv: StepConv::MinStep(1e-9),
            weight_decay: None,
        }
    }
}

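/// L-BFGS optimiser that owns the model and the optimised variables, keeping a
/// bounded history of `(step, gradient difference)` pairs for the two-loop
/// recursion.
///
/// A minimal usage sketch (`MyModel` and `initial_loss` are placeholders for a
/// type implementing [`Model`] and its starting loss):
///
/// ```ignore
/// let (model, vars) = MyModel::new()?;
/// let mut lbfgs = Lbfgs::new(vars, ParamsLBFGS::default(), model)?;
/// let mut loss = initial_loss;
/// loop {
///     match lbfgs.backward_step(&loss)? {
///         ModelOutcome::Converged(l, _) => {
///             loss = l;
///             break;
///         }
///         ModelOutcome::Stepped(l, _) => loss = l,
///         _ => break,
///     }
/// }
/// ```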
#[derive(Debug)]
pub struct Lbfgs<M: Model> {
    vars: Vec<Var>,
    model: M,
    s_hist: VecDeque<(Tensor, Tensor)>,
    last_grad: Option<Var>,
    next_grad: Option<Var>,
    last_step: Option<Var>,
    params: ParamsLBFGS,
    first: bool,
}

impl<M: Model> LossOptimizer<M> for Lbfgs<M> {
    type Config = ParamsLBFGS;

    fn new(vs: Vec<Var>, params: Self::Config, model: M) -> CResult<Self> {
        let hist_size = params.history_size;
        Ok(Lbfgs {
            vars: vs,
            model,
            s_hist: VecDeque::with_capacity(hist_size),
            last_step: None,
            last_grad: None,
            next_grad: None,
            params,
            first: true,
        })
    }

    #[allow(clippy::too_many_lines)]
    fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome> {
        let mut evals = 1;

        let grad = if let Some(this_grad) = &self.next_grad {
            this_grad.as_tensor().copy()?
        } else {
            flat_grads(&self.vars, loss, self.params.weight_decay)?
        };

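        // Bail out early if the gradient already satisfies the convergence criterion.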
        match self.params.grad_conv {
            GradConv::MinForce(tol) => {
                if grad
                    .abs()?
                    .max(0)?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
            GradConv::RMSForce(tol) => {
                if grad
                    .sqr()?
                    .mean_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    .sqrt()
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
        }

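        // Form the curvature pair: y_k = g_k - g_{k-1}, paired with the last step s_k,
        // and push it into the bounded history.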
        let mut yk = None;

        if let Some(last) = &self.last_grad {
            yk = Some((&grad - last.as_tensor())?);
            last.set(&grad)?;
        } else {
            self.last_grad = Some(Var::from_tensor(&grad)?);
        }

        let q = Var::from_tensor(&grad)?;

        let hist_size = self.s_hist.len();

        if hist_size == self.params.history_size {
            self.s_hist.pop_front();
        }
        if let Some(yk) = yk {
            if let Some(step) = &self.last_step {
                self.s_hist.push_back((step.as_tensor().clone(), yk));
            }
        }

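        // gamma = (s^T y) / (y^T y) scales the initial inverse-Hessian approximation;
        // the denominator is regularised by 1e-10.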
        let gamma = if let Some((s, y)) = self.s_hist.back() {
            let numr = y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?;

            let denom = y
                .unsqueeze(0)?
                .matmul(&(y.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10;

            numr / denom
        } else {
            1.
        };

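        // First loop of the two-loop recursion: walk the history from newest to
        // oldest, computing rho and alpha for each (s, y) pair and updating q.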
        let mut rhos = VecDeque::with_capacity(hist_size);
        let mut alphas = VecDeque::with_capacity(hist_size);
        for (s, y) in self.s_hist.iter().rev() {
            let rho = (y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10)
                .powi(-1);

            let alpha = rho
                * s.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.sub(&(y * alpha)?)?)?;
            alphas.push_front(alpha);
            rhos.push_front(rho);
        }

        q.set(&(q.as_tensor() * gamma)?)?;
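        // Second loop of the two-loop recursion: walk the history from oldest to
        // newest, correcting q with beta so that q approximates H^-1 * grad.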
        for (((s, y), alpha), rho) in self
            .s_hist
            .iter()
            .zip(alphas.into_iter())
            .zip(rhos.into_iter())
        {
            let beta = rho
                * y.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.add(&(s * (alpha - beta))?)?)?;
        }

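        // Directional derivative grad . q, used by the strong Wolfe line search below.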
        let dd = grad
            .unsqueeze(0)?
            .matmul(&(q.unsqueeze(1)?))?
            .to_dtype(candle_core::DType::F64)?
            .squeeze(1)?
            .squeeze(0)?
            .to_scalar::<f64>()?;

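        // On the first iteration, scale the learning rate by min(1, 1 / l1-norm of the
        // gradient) to cap the size of the first step; afterwards use -lr directly
        // (negated because q points along the gradient).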
        let mut lr = if self.first {
            self.first = false;
            -(1_f64.min(
                1. / grad
                    .abs()?
                    .sum_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?,
            )) * self.params.lr
        } else {
            -self.params.lr
        };

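        // With a line search configured, the step length is chosen to satisfy the
        // strong Wolfe conditions; without one, q is simply scaled by lr.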
        if let Some(ls) = &self.params.line_search {
            match ls {
                LineSearch::StrongWolfe(c1, c2, tol) => {
                    let (loss, grad, t, steps) =
                        self.strong_wolfe(lr, &q, loss, &grad, dd, *c1, *c2, *tol, 25)?;
                    if let Some(next_grad) = &self.next_grad {
                        next_grad.set(&grad)?;
                    } else {
                        self.next_grad = Some(Var::from_tensor(&grad)?);
                    }

                    evals += steps;
                    lr = t;
                    q.set(&(q.as_tensor() * lr)?)?;

                    if let Some(step) = &self.last_step {
                        step.set(&q)?;
                    } else {
                        self.last_step = Some(Var::from_tensor(&q)?);
                    }

                    match self.params.step_conv {
                        StepConv::MinStep(tol) => {
                            if q.abs()?
                                .max(0)?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                        StepConv::RMSStep(tol) => {
                            if q.sqr()?
                                .mean_all()?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                .sqrt()
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                    }
                }
            }
        } else {
            q.set(&(q.as_tensor() * lr)?)?;

            if let Some(step) = &self.last_step {
                step.set(&q)?;
            } else {
                self.last_step = Some(Var::from_tensor(&q)?);
            }

            match self.params.step_conv {
                StepConv::MinStep(tol) => {
                    if q.abs()?
                        .max(0)?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
                StepConv::RMSStep(tol) => {
                    if q.sqr()?
                        .mean_all()?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        .sqrt()
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
            }
        }
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }

    #[must_use]
    fn into_inner(self) -> Vec<Var> {
        self.vars
    }
}

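/// Computes the gradient of `loss` with respect to `vs` and flattens it into a
/// single 1-D tensor, optionally adding an L2 weight-decay contribution; variables
/// without a gradient contribute zeros (or just the weight-decay term when weight
/// decay is enabled).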
#[allow(clippy::inline_always)]
#[inline(always)]
fn flat_grads(vs: &Vec<Var>, loss: &Tensor, weight_decay: Option<f64>) -> CResult<Tensor> {
    let grads = loss.backward()?;
    let mut flat_grads = Vec::with_capacity(vs.len());
    if let Some(wd) = weight_decay {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                let grad = &(grad + (wd * v.as_tensor())?)?;
                flat_grads.push(grad.flatten_all()?);
            } else {
                let grad = (wd * v.as_tensor())?;
                flat_grads.push(grad.flatten_all()?);
            }
        }
    } else {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                flat_grads.push(grad.flatten_all()?);
            } else {
                let n_elems = v.elem_count();
                flat_grads.push(candle_core::Tensor::zeros(n_elems, v.dtype(), v.device())?);
            }
        }
    }
    candle_core::Tensor::cat(&flat_grads, 0)
}

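/// Adds slices of the flat update tensor onto the corresponding variables,
/// reshaping each slice to the variable's shape.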
fn add_grad(vs: &mut Vec<Var>, flat_tensor: &Tensor) -> CResult<()> {
    let mut offset = 0;
    for var in vs {
        let n_elems = var.elem_count();
        let tensor = flat_tensor
            .narrow(0, offset, n_elems)?
            .reshape(var.shape())?;
        var.set(&var.add(&tensor)?)?;
        offset += n_elems;
    }
    Ok(())
}

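/// Overwrites each variable with the corresponding tensor from `vals`.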
fn set_vs(vs: &mut [Var], vals: &Vec<Tensor>) -> CResult<()> {
    for (var, t) in vs.iter().zip(vals) {
        var.set(t)?;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use crate::Model;
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::Device;
    use candle_core::{Module, Result as CResult};

    pub struct LinearModel {
        linear: candle_nn::Linear,
        xs: Tensor,
        ys: Tensor,
    }

    impl Model for LinearModel {
        fn loss(&self) -> CResult<Tensor> {
            let preds = self.forward(&self.xs)?;
            let loss = candle_nn::loss::mse(&preds, &self.ys)?;
            Ok(loss)
        }
    }

    impl LinearModel {
        fn new() -> CResult<(Self, Vec<Var>)> {
            let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?;
            let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?;

            let linear =
                candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone()));

            Ok((
                Self {
                    linear,
                    xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?,
                    ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?,
                },
                vec![weight, bias],
            ))
        }

        fn forward(&self, xs: &Tensor) -> CResult<Tensor> {
            self.linear.forward(xs)
        }
    }

    use super::*;
    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        let (model, vars) = LinearModel::new()?;
        let mut lbfgs = Lbfgs::new(vars, params, model)?;
        assert_approx_eq!(0.004, lbfgs.learning_rate());
        lbfgs.set_learning_rate(0.002);
        assert_approx_eq!(0.002, lbfgs.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        let (model, vars) = LinearModel::new()?;
        let slice: Vec<&Var> = vars.iter().collect();
        let lbfgs = Lbfgs::from_slice(&slice, params, model)?;
        let inner = lbfgs.into_inner();

        assert_eq!(inner[0].as_tensor().to_vec1::<f64>()?, &[3f64, 1.]);
        println!("checked weights");
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f64>()?, -2_f64);
        Ok(())
    }
}