border_candle_agent/opt.rs

use anyhow::Result;
use candle_core::{Tensor, Var};
use candle_nn::{AdamW, Optimizer as _, ParamsAdamW};
use candle_optimisers::adam::{Adam, ParamsAdam};
use serde::{Deserialize, Serialize};

/// Optimizer configuration.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum OptimizerConfig {
    /// AdamW from `candle_nn`; omitted fields fall back to `ParamsAdamW::default()`.
    AdamW {
        lr: f64,
        #[serde(default = "default_beta1")]
        beta1: f64,
        #[serde(default = "default_beta2")]
        beta2: f64,
        #[serde(default = "default_eps")]
        eps: f64,
        #[serde(default = "default_weight_decay")]
        weight_decay: f64,
    },

    /// Adam from `candle_optimisers`, configured with the learning rate only.
    Adam {
        lr: f64,
    },
}

// These functions back the `#[serde(default = "...")]` attributes above, so that
// omitted AdamW fields fall back to candle's `ParamsAdamW` defaults.
fn default_beta1() -> f64 {
    ParamsAdamW::default().beta1
}

fn default_beta2() -> f64 {
    ParamsAdamW::default().beta2
}

fn default_eps() -> f64 {
    ParamsAdamW::default().eps
}

fn default_weight_decay() -> f64 {
    ParamsAdamW::default().weight_decay
}
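
// A minimal sketch illustrating how the `#[serde(default = "...")]` attributes are
// expected to behave; it assumes `serde_json` is available to this crate, which the
// imports above do not show, and the 3e-4 value is only an illustrative choice.
#[cfg(test)]
mod serde_default_sketch {
    use super::OptimizerConfig;
    use candle_nn::ParamsAdamW;

    #[test]
    fn missing_adamw_fields_fall_back_to_defaults() {
        // Only `lr` is given (externally tagged enum representation); the remaining
        // fields come from the `default_*` functions defined above.
        let cfg: OptimizerConfig =
            serde_json::from_str(r#"{"AdamW":{"lr":3e-4}}"#).unwrap();
        let d = ParamsAdamW::default();
        match cfg {
            OptimizerConfig::AdamW { lr, beta1, beta2, eps, weight_decay } => {
                assert_eq!(lr, 3e-4);
                assert_eq!(
                    (beta1, beta2, eps, weight_decay),
                    (d.beta1, d.beta2, d.eps, d.weight_decay)
                );
            }
            _ => panic!("expected the AdamW variant"),
        }
    }
}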

impl OptimizerConfig {
    /// Builds an optimizer over the given trainable variables.
    pub fn build(&self, vars: Vec<Var>) -> Result<Optimizer> {
        match &self {
            OptimizerConfig::AdamW {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            } => {
                let params = ParamsAdamW {
                    lr: *lr,
                    beta1: *beta1,
                    beta2: *beta2,
                    eps: *eps,
                    weight_decay: *weight_decay,
                };
                let opt = AdamW::new(vars, params)?;
                Ok(Optimizer::AdamW(opt))
            }
            OptimizerConfig::Adam { lr } => {
                let params = ParamsAdam {
                    lr: *lr,
                    ..ParamsAdam::default()
                };
                let opt = Adam::new(vars, params)?;
                Ok(Optimizer::Adam(opt))
            }
        }
    }

    /// Sets the learning rate, consuming and returning the configuration.
    pub fn learning_rate(self, lr: f64) -> Self {
        match self {
            Self::AdamW {
                lr: _,
                beta1,
                beta2,
                eps,
                weight_decay,
            } => Self::AdamW {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            },
            Self::Adam { lr: _ } => Self::Adam { lr },
        }
    }
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        let params = ParamsAdamW::default();
        Self::AdamW {
            lr: params.lr,
            beta1: params.beta1,
            beta2: params.beta2,
            eps: params.eps,
            weight_decay: params.weight_decay,
        }
    }
}
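
// A minimal sketch of composing the `Default` impl with the consuming
// `learning_rate` setter; the 1e-4 value is only an illustrative choice.
#[cfg(test)]
mod builder_sketch {
    use super::OptimizerConfig;
    use candle_nn::ParamsAdamW;

    #[test]
    fn override_learning_rate_keeps_other_fields() {
        let cfg = OptimizerConfig::default().learning_rate(1e-4);
        match cfg {
            OptimizerConfig::AdamW { lr, weight_decay, .. } => {
                assert_eq!(lr, 1e-4);
                // The remaining hyperparameters are untouched.
                assert_eq!(weight_decay, ParamsAdamW::default().weight_decay);
            }
            _ => panic!("the default configuration is the AdamW variant"),
        }
    }
}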

/// Optimizer wrapping the supported backends.
pub enum Optimizer {
    /// AdamW from `candle_nn`.
    AdamW(AdamW),

    /// Adam from `candle_optimisers`.
    Adam(Adam),
}

impl Optimizer {
    /// Computes gradients of `loss` and applies a single update step.
    pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        match self {
            Self::AdamW(opt) => Ok(opt.backward_step(loss)?),
            Self::Adam(opt) => Ok(opt.backward_step(loss)?),
        }
    }

    /// Applies a single update step from precomputed gradients.
    pub fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> {
        match self {
            Self::AdamW(opt) => Ok(opt.step(grads)?),
            Self::Adam(opt) => Ok(opt.step(grads)?),
        }
    }
}
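
// A minimal end-to-end sketch of building an optimizer from a config and taking one
// `backward_step` on a toy quadratic loss; the shape, device (CPU, f32) and learning
// rate are illustrative assumptions, not values prescribed by this module.
#[cfg(test)]
mod step_sketch {
    use super::OptimizerConfig;
    use candle_core::{DType, Device, Tensor, Var};

    #[test]
    fn one_backward_step_moves_the_parameter() -> anyhow::Result<()> {
        let dev = Device::Cpu;
        // A single trainable 2x2 parameter, initialised to zero.
        let w = Var::zeros((2, 2), DType::F32, &dev)?;
        let mut opt = OptimizerConfig::default()
            .learning_rate(1e-1)
            .build(vec![w.clone()])?;

        // Toy loss: sum((w - 1)^2); its gradient is nonzero at w = 0.
        let target = Tensor::ones((2, 2), DType::F32, &dev)?;
        let loss = w.as_tensor().sub(&target)?.sqr()?.sum_all()?;
        opt.backward_step(&loss)?;

        // After one AdamW step the parameter should have moved away from zero.
        let moved = w.as_tensor().abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(moved > 0.0);
        Ok(())
    }
}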