use crate::{LossOptimizer, Model, ModelOutcome};
use candle_core::Result as CResult;
use candle_core::{Tensor, Var};
use log::info;
use std::collections::VecDeque;
mod strong_wolfe;

#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum LineSearch {
    /// Strong Wolfe line search, parameterised as `(c1, c2, tol)`: the sufficient-decrease
    /// and curvature constants and a tolerance for terminating the search.
    StrongWolfe(f64, f64, f64),
}

#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum GradConv {
    /// Converge when the largest absolute component of the gradient is below the tolerance.
    MinForce(f64),
    /// Converge when the root-mean-square of the gradient is below the tolerance.
    RMSForce(f64),
}

#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum StepConv {
    /// Converge when the largest absolute component of the step is below the tolerance.
    MinStep(f64),
    /// Converge when the root-mean-square of the step is below the tolerance.
    RMSStep(f64),
}

#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct ParamsLBFGS {
    /// Learning rate used to scale the step.
    pub lr: f64,
    /// Number of `(s, y)` pairs kept for the two-loop recursion.
    pub history_size: usize,
    /// Optional line search; `None` takes a fixed step of size `lr`.
    pub line_search: Option<LineSearch>,
    /// Convergence criterion on the gradient.
    pub grad_conv: GradConv,
    /// Convergence criterion on the step.
    pub step_conv: StepConv,
    /// Optional L2 weight decay added to the gradient.
    pub weight_decay: Option<f64>,
}

impl Default for ParamsLBFGS {
    fn default() -> Self {
        Self {
            lr: 1.,
            history_size: 100,
            line_search: None,
            grad_conv: GradConv::MinForce(1e-7),
            step_conv: StepConv::MinStep(1e-9),
            weight_decay: None,
        }
    }
}

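// A minimal configuration sketch (hypothetical `vars: Vec<Var>` and `model: impl Model`;
// the Strong Wolfe constants below are illustrative values, not defaults):
//
//     let params = ParamsLBFGS {
//         history_size: 10,
//         line_search: Some(LineSearch::StrongWolfe(1e-4, 0.9, 1e-9)),
//         ..Default::default()
//     };
//     let mut lbfgs = Lbfgs::new(vars, params, &model)?;
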
#[derive(Debug)]
pub struct Lbfgs<'a, M: Model> {
    vars: Vec<Var>,
    model: &'a M,
    /// History of `(s_k, y_k)` pairs: parameter steps and gradient differences.
    s_hist: VecDeque<(Tensor, Tensor)>,
    last_grad: Option<Var>,
    next_grad: Option<Var>,
    last_step: Option<Var>,
    params: ParamsLBFGS,
    first: bool,
}

impl<'a, M: Model> LossOptimizer<'a, M> for Lbfgs<'a, M> {
    type Config = ParamsLBFGS;

    fn new(vs: Vec<Var>, params: Self::Config, model: &'a M) -> CResult<Self> {
        let hist_size = params.history_size;
        Ok(Lbfgs {
            vars: vs,
            model,
            s_hist: VecDeque::with_capacity(hist_size),
            last_step: None,
            last_grad: None,
            next_grad: None,
            params,
            first: true,
        })
    }

    #[allow(clippy::too_many_lines)]
    fn backward_step(&mut self, loss: &Tensor) -> CResult<ModelOutcome> {
        let mut evals = 1;

        // Reuse the gradient computed by the previous line search if one is cached,
        // otherwise compute a fresh flattened gradient of the loss.
        let grad = if let Some(this_grad) = &self.next_grad {
            this_grad.as_tensor().copy()?
        } else {
            flat_grads(&self.vars, loss, self.params.weight_decay)?
        };

        // Check for convergence of the gradient before taking a step.
        match self.params.grad_conv {
            GradConv::MinForce(tol) => {
                if grad
                    .abs()?
                    .max(0)?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
            GradConv::RMSForce(tol) => {
                if grad
                    .sqr()?
                    .mean_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?
                    .sqrt()
                    < tol
                {
                    info!("grad converged");
                    return Ok(ModelOutcome::Converged(loss.clone(), evals));
                }
            }
        }

        // Gradient difference y = g - g_prev; also refresh the stored previous gradient.
        let mut yk = None;

        if let Some(last) = &self.last_grad {
            yk = Some((&grad - last.as_tensor())?);
            last.set(&grad)?;
        } else {
            self.last_grad = Some(Var::from_tensor(&grad)?);
        }

        let q = Var::from_tensor(&grad)?;

        let hist_size = self.s_hist.len();

        // Keep at most `history_size` (s, y) pairs, dropping the oldest first.
        if hist_size == self.params.history_size {
            self.s_hist.pop_front();
        }
        if let Some(yk) = yk {
            if let Some(step) = &self.last_step {
                self.s_hist.push_back((step.as_tensor().clone(), yk));
            }
        }

        // Initial inverse-Hessian scaling: gamma = (s . y) / (y . y), from the most recent pair.
        let gamma = if let Some((s, y)) = self.s_hist.back() {
            let numr = y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?;

            let denom = y
                .unsqueeze(0)?
                .matmul(&(y.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10;

            numr / denom
        } else {
            1.
        };

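        // Two-loop recursion: apply the inverse-Hessian approximation to the gradient
        // without materialising it. The first loop (newest pair first) computes
        // alpha_i = rho_i * (s_i . q), with rho_i = 1 / (y_i . s_i), and subtracts
        // alpha_i * y_i from q; q is then scaled by gamma, and the second loop
        // (oldest pair first) adds s_i * (alpha_i - beta_i), where beta_i = rho_i * (y_i . q).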
        let mut rhos = VecDeque::with_capacity(hist_size);
        let mut alphas = VecDeque::with_capacity(hist_size);
        for (s, y) in self.s_hist.iter().rev() {
            let rho = (y
                .unsqueeze(0)?
                .matmul(&(s.unsqueeze(1)?))?
                .to_dtype(candle_core::DType::F64)?
                .squeeze(1)?
                .squeeze(0)?
                .to_scalar::<f64>()?
                + 1e-10)
                .powi(-1);

            let alpha = rho
                * s.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.sub(&(y * alpha)?)?)?;
            alphas.push_front(alpha);
            rhos.push_front(rho);
        }

        q.set(&(q.as_tensor() * gamma)?)?;
        for (((s, y), alpha), rho) in self.s_hist.iter().zip(alphas).zip(rhos) {
            let beta = rho
                * y.unsqueeze(0)?
                    .matmul(&(q.unsqueeze(1)?))?
                    .to_dtype(candle_core::DType::F64)?
                    .squeeze(1)?
                    .squeeze(0)?
                    .to_scalar::<f64>()?;

            q.set(&q.add(&(s * (alpha - beta))?)?)?;
        }

        // Inner product of the gradient with the two-loop direction, used by the line search.
        let dd = grad
            .unsqueeze(0)?
            .matmul(&(q.unsqueeze(1)?))?
            .to_dtype(candle_core::DType::F64)?
            .squeeze(1)?
            .squeeze(0)?
            .to_scalar::<f64>()?;

        // The step is taken along -q, hence the negative scale. On the first iteration
        // there is no curvature history, so the step is additionally damped by
        // min(1, 1 / ||grad||_1).
        let mut lr = if self.first {
            self.first = false;
            -(1_f64.min(
                1. / grad
                    .abs()?
                    .sum_all()?
                    .to_dtype(candle_core::DType::F64)?
                    .to_scalar::<f64>()?,
            )) * self.params.lr
        } else {
            -self.params.lr
        };

        if let Some(ls) = &self.params.line_search {
            match ls {
                LineSearch::StrongWolfe(c1, c2, tol) => {
                    // The line search returns the new loss and gradient, the accepted step
                    // size `t`, and the number of additional evaluations it performed.
                    let (loss, grad, t, steps) =
                        self.strong_wolfe(lr, &q, loss, &grad, dd, *c1, *c2, *tol, 25)?;
                    if let Some(next_grad) = &self.next_grad {
                        next_grad.set(&grad)?;
                    } else {
                        self.next_grad = Some(Var::from_tensor(&grad)?);
                    }

                    evals += steps;
                    lr = t;
                    q.set(&(q.as_tensor() * lr)?)?;

                    if let Some(step) = &self.last_step {
                        step.set(&q)?;
                    } else {
                        self.last_step = Some(Var::from_tensor(&q)?);
                    }

                    match self.params.step_conv {
                        StepConv::MinStep(tol) => {
                            if q.abs()?
                                .max(0)?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                        StepConv::RMSStep(tol) => {
                            if q.sqr()?
                                .mean_all()?
                                .to_dtype(candle_core::DType::F64)?
                                .to_scalar::<f64>()?
                                .sqrt()
                                < tol
                            {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                info!("step converged");
                                Ok(ModelOutcome::Converged(loss, evals))
                            } else {
                                add_grad(&mut self.vars, q.as_tensor())?;
                                Ok(ModelOutcome::Stepped(loss, evals))
                            }
                        }
                    }
                }
            }
        } else {
            // No line search: take a fixed step of size `lr` along the L-BFGS direction.
            q.set(&(q.as_tensor() * lr)?)?;

            if let Some(step) = &self.last_step {
                step.set(&q)?;
            } else {
                self.last_step = Some(Var::from_tensor(&q)?);
            }

            match self.params.step_conv {
                StepConv::MinStep(tol) => {
                    if q.abs()?
                        .max(0)?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
                StepConv::RMSStep(tol) => {
                    if q.sqr()?
                        .mean_all()?
                        .to_dtype(candle_core::DType::F64)?
                        .to_scalar::<f64>()?
                        .sqrt()
                        < tol
                    {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        info!("step converged");
                        Ok(ModelOutcome::Converged(next_loss, evals))
                    } else {
                        add_grad(&mut self.vars, q.as_tensor())?;

                        let next_loss = self.model.loss()?;
                        evals += 1;
                        Ok(ModelOutcome::Stepped(next_loss, evals))
                    }
                }
            }
        }
    }

    fn learning_rate(&self) -> f64 {
        self.params.lr
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.lr = lr;
    }

    fn into_inner(self) -> Vec<Var> {
        self.vars
    }
}

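// A minimal training-loop sketch (hypothetical `vars`, `model`, and `max_steps`; the loop
// structure is illustrative only):
//
//     let mut lbfgs = Lbfgs::new(vars, ParamsLBFGS::default(), &model)?;
//     for _ in 0..max_steps {
//         let loss = model.loss()?;
//         match lbfgs.backward_step(&loss)? {
//             ModelOutcome::Converged(_, _) => break,
//             _ => {}
//         }
//     }
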
#[allow(clippy::inline_always)]
#[inline(always)]
fn flat_grads(vs: &Vec<Var>, loss: &Tensor, weight_decay: Option<f64>) -> CResult<Tensor> {
    let grads = loss.backward()?;
    let mut flat_grads = Vec::with_capacity(vs.len());
    if let Some(wd) = weight_decay {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                let grad = &(grad + (wd * v.as_tensor())?)?;
                flat_grads.push(grad.flatten_all()?);
            } else {
                // No gradient for this var: the weight-decay term alone contributes.
                let grad = (wd * v.as_tensor())?;
                flat_grads.push(grad.flatten_all()?);
            }
        }
    } else {
        for v in vs {
            if let Some(grad) = grads.get(v) {
                flat_grads.push(grad.flatten_all()?);
            } else {
                // No gradient for this var: contribute zeros of the right size.
                let n_elems = v.elem_count();
                flat_grads.push(candle_core::Tensor::zeros(n_elems, v.dtype(), v.device())?);
            }
        }
    }
    // Concatenate the per-variable gradients into a single flat vector.
    candle_core::Tensor::cat(&flat_grads, 0)
}

/// Add `flat_tensor`, split and reshaped per variable, onto the variables in place.
fn add_grad(vs: &mut Vec<Var>, flat_tensor: &Tensor) -> CResult<()> {
    let mut offset = 0;
    for var in vs {
        let n_elems = var.elem_count();
        let tensor = flat_tensor
            .narrow(0, offset, n_elems)?
            .reshape(var.shape())?;
        var.set(&var.add(&tensor)?)?;
        offset += n_elems;
    }
    Ok(())
}

/// Set each variable to the corresponding tensor in `vals`.
fn set_vs(vs: &mut [Var], vals: &Vec<Tensor>) -> CResult<()> {
    for (var, t) in vs.iter().zip(vals) {
        var.set(t)?;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Model;
    use anyhow::Result;
    use assert_approx_eq::assert_approx_eq;
    use candle_core::{Device, Module, Result as CResult};

    pub struct LinearModel {
        linear: candle_nn::Linear,
        xs: Tensor,
        ys: Tensor,
    }

    impl Model for LinearModel {
        fn loss(&self) -> CResult<Tensor> {
            let preds = self.forward(&self.xs)?;
            let loss = candle_nn::loss::mse(&preds, &self.ys)?;
            Ok(loss)
        }
    }

    impl LinearModel {
        fn new() -> CResult<(Self, Vec<Var>)> {
            let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?;
            let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?;

            let linear =
                candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone()));

            Ok((
                Self {
                    linear,
                    xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?,
                    ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?,
                },
                vec![weight, bias],
            ))
        }

        fn forward(&self, xs: &Tensor) -> CResult<Tensor> {
            self.linear.forward(xs)
        }
    }

    #[test]
    fn lr_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        let (model, vars) = LinearModel::new()?;
        let mut lbfgs = Lbfgs::new(vars, params, &model)?;
        assert_approx_eq!(0.004, lbfgs.learning_rate());
        lbfgs.set_learning_rate(0.002);
        assert_approx_eq!(0.002, lbfgs.learning_rate());
        Ok(())
    }

    #[test]
    fn into_inner_test() -> Result<()> {
        let params = ParamsLBFGS {
            lr: 0.004,
            ..Default::default()
        };
        let (model, vars) = LinearModel::new()?;
        let slice: Vec<&Var> = vars.iter().collect();
        let lbfgs = Lbfgs::from_slice(&slice, params, &model)?;
        let inner = lbfgs.into_inner();

        assert_eq!(inner[0].as_tensor().to_vec1::<f64>()?, &[3f64, 1.]);
        println!("checked weights");
        assert_approx_eq!(inner[1].as_tensor().to_vec0::<f64>()?, -2_f64);
        Ok(())
    }
}
550}