1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use anyhow::Result;
use core::f64;
use serde::{Deserialize, Serialize};
use tch::{
nn,
nn::{Adam, Optimizer as Optimizer_, OptimizerConfig as OptimizerConfig_, VarStore},
Tensor,
};
#[cfg(not(feature = "adam_eps"))]
/// Serializable description of the optimizer to construct.
///
/// This definition is compiled when the `adam_eps` feature is disabled and
/// offers a single choice, [`OptimizerConfig::Adam`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum OptimizerConfig {
    /// Adam optimizer with otherwise default settings.
    Adam {
        /// Learning rate handed to the Adam builder.
        lr: f64,
    },
}
#[cfg(feature = "adam_eps")]
/// Serializable description of the optimizer to construct.
///
/// This definition is compiled when the `adam_eps` feature is enabled; in
/// addition to plain Adam it offers a variant carrying an explicit `eps`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum OptimizerConfig {
    /// Adam optimizer with otherwise default settings.
    Adam {
        /// Learning rate handed to the Adam builder.
        lr: f64,
    },
    /// Adam optimizer whose `eps` setting is overridden.
    AdamEps {
        /// Learning rate handed to the Adam builder.
        lr: f64,
        /// Value written to the Adam configuration's `eps` field.
        eps: f64,
    },
}
#[cfg(not(feature = "adam_eps"))]
impl OptimizerConfig {
    /// Constructs the optimizer described by this configuration.
    ///
    /// `vs` is forwarded, together with the configured learning rate, to the
    /// builder of a default Adam configuration.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `build` call.
    pub fn build(&self, vs: &VarStore) -> Result<Optimizer> {
        // `lr` is an `f64` (Copy), so it can be bound by value out of `*self`.
        match *self {
            Self::Adam { lr } => {
                let adam = Adam::default().build(vs, lr)?;
                Ok(Optimizer::Adam(adam))
            }
        }
    }
}
#[cfg(feature = "adam_eps")]
impl OptimizerConfig {
    /// Constructs the optimizer described by this configuration.
    ///
    /// Both variants produce an [`Optimizer::Adam`]; `AdamEps` additionally
    /// overrides the `eps` field of the default Adam configuration before
    /// building. `vs` is forwarded to the builder along with the learning rate.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `build` call.
    pub fn build(&self, vs: &VarStore) -> Result<Optimizer> {
        // All fields are `f64` (Copy), so they can be bound by value out of `*self`.
        match *self {
            Self::Adam { lr } => Ok(Optimizer::Adam(Adam::default().build(vs, lr)?)),
            Self::AdamEps { lr, eps } => {
                // Start from the defaults and change only `eps`.
                let mut adam_cfg = Adam::default();
                adam_cfg.eps = eps;
                Ok(Optimizer::Adam(adam_cfg.build(vs, lr)?))
            }
        }
    }
}
pub enum Optimizer {
Adam(Optimizer_<nn::Adam>),
}
impl Optimizer {
pub fn backward_step(&mut self, loss: &Tensor) {
match self {
Self::Adam(opt) => {
opt.backward_step(loss);
}
}
}
}