use anyhow::Result;
use serde::{Deserialize, Serialize};
use tch::{
    nn::{Adam, AdamW, Optimizer as Optimizer_, OptimizerConfig as OptimizerConfig_, VarStore},
    Tensor,
};

/// Serializable optimizer configuration. `Adam` keeps the tch-rs defaults
/// for everything except the learning rate; `AdamW` exposes the full set of
/// hyperparameters.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum OptimizerConfig {
    /// Adam with default betas, epsilon, and weight decay.
    Adam { lr: f64 },

    /// AdamW with explicit hyperparameters.
    AdamW {
        lr: f64,
        beta1: f64,
        beta2: f64,
        wd: f64,
        eps: f64,
        amsgrad: bool,
    },
}
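
// With serde's default externally tagged enum representation, a JSON config
// for this type would look like the following (values are illustrative):
//   {"Adam": {"lr": 0.001}}
//   {"AdamW": {"lr": 0.001, "beta1": 0.9, "beta2": 0.999,
//              "wd": 0.01, "eps": 1e-8, "amsgrad": false}}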

impl OptimizerConfig {
    /// Builds an optimizer over the trainable variables of `vs` using this
    /// configuration's hyperparameters.
    pub fn build(&self, vs: &VarStore) -> Result<Optimizer> {
        match self {
            OptimizerConfig::Adam { lr } => {
                let opt = Adam::default().build(vs, *lr)?;
                Ok(Optimizer::Adam(opt))
            }
            OptimizerConfig::AdamW {
                lr,
                beta1,
                beta2,
                wd,
                eps,
                amsgrad,
            } => {
                let opt = AdamW {
                    beta1: *beta1,
                    beta2: *beta2,
                    wd: *wd,
                    eps: *eps,
                    amsgrad: *amsgrad,
                }
                .build(vs, *lr)?;
                Ok(Optimizer::AdamW(opt))
            }
        }
    }
}

/// A built optimizer, tagged by the algorithm it was constructed with. Both
/// variants wrap the same underlying `tch::nn::Optimizer` type.
pub enum Optimizer {
    Adam(Optimizer_),
    AdamW(Optimizer_),
}

impl Optimizer {
    /// Zeroes gradients, backpropagates `loss`, and applies one update step.
    pub fn backward_step(&mut self, loss: &Tensor) {
        match self {
            Self::Adam(opt) | Self::AdamW(opt) => opt.backward_step(loss),
        }
    }
}
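
// Illustrative smoke tests sketching how the types above fit together. They
// assume `serde_json` is available as a dev dependency; the 4 -> 1 linear
// model and the hyperparameter values are arbitrary examples.
#[cfg(test)]
mod tests {
    use super::*;
    use tch::{nn, nn::Module, Device, Kind};

    #[test]
    fn config_round_trips_through_json() {
        let config = OptimizerConfig::Adam { lr: 1e-3 };
        let json = serde_json::to_string(&config).unwrap();
        // serde's default externally tagged enum representation.
        assert_eq!(json, r#"{"Adam":{"lr":0.001}}"#);
        let back: OptimizerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back, config);
    }

    #[test]
    fn backward_step_runs_on_a_tiny_model() -> Result<()> {
        let vs = VarStore::new(Device::Cpu);
        let model = nn::linear(vs.root(), 4, 1, Default::default());
        let config = OptimizerConfig::AdamW {
            lr: 1e-2,
            beta1: 0.9,
            beta2: 0.999,
            wd: 0.01,
            eps: 1e-8,
            amsgrad: false,
        };
        let mut opt = config.build(&vs)?;

        let xs = Tensor::randn(&[8, 4], (Kind::Float, Device::Cpu));
        let ys = Tensor::randn(&[8, 1], (Kind::Float, Device::Cpu));

        // One optimization step: zero grads, backward, apply the update.
        let loss = model.forward(&xs).mse_loss(&ys, tch::Reduction::Mean);
        opt.backward_step(&loss);
        assert!(loss.double_value(&[]).is_finite());
        Ok(())
    }
}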