border_tch_agent/opt.rs

//! Optimizers.
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tch::{
    nn::{Adam, AdamW, Optimizer as Optimizer_, OptimizerConfig as OptimizerConfig_, VarStore},
    Tensor,
};

/// Configures an optimizer for training neural networks in an RL agent.
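///
/// As the config derives [Serialize]/[Deserialize], it can be stored alongside an
/// agent's other hyperparameters. A minimal sketch using `serde_json` (assumed here
/// for illustration; it is not a dependency of this module):
///
/// ```ignore
/// let cfg: OptimizerConfig = serde_json::from_str(r#"{"Adam": {"lr": 0.0003}}"#)?;
/// ```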
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub enum OptimizerConfig {
    /// Adam optimizer.
    Adam {
        /// Learning rate.
        lr: f64,
    },

    /// AdamW optimizer (Adam with decoupled weight decay).
    AdamW {
        /// Learning rate.
        lr: f64,
        /// Coefficient for the first moment estimate.
        beta1: f64,
        /// Coefficient for the second moment estimate.
        beta2: f64,
        /// Weight decay coefficient.
        wd: f64,
        /// Term added to the denominator for numerical stability.
        eps: f64,
        /// Whether to use the AMSGrad variant.
        amsgrad: bool,
    },
}

impl OptimizerConfig {
    /// Builds an optimizer that updates the trainable variables in `vs`.
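    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; the layer shape and learning rate
    /// are illustrative assumptions, not part of this crate):
    ///
    /// ```ignore
    /// use tch::{nn, nn::VarStore, Device};
    ///
    /// let vs = VarStore::new(Device::Cpu);
    /// let _net = nn::linear(vs.root(), 4, 2, Default::default());
    /// let mut opt = OptimizerConfig::Adam { lr: 3e-4 }.build(&vs)?;
    /// ```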
    pub fn build(&self, vs: &VarStore) -> Result<Optimizer> {
        match self {
            OptimizerConfig::Adam { lr } => {
                let opt = Adam::default().build(vs, *lr)?;
                Ok(Optimizer::Adam(opt))
            }
            OptimizerConfig::AdamW {
                lr,
                beta1,
                beta2,
                wd,
                eps,
                amsgrad,
            } => {
                let opt = AdamW {
                    beta1: *beta1,
                    beta2: *beta2,
                    wd: *wd,
                    eps: *eps,
                    amsgrad: *amsgrad,
                }
                .build(vs, *lr)?;
                Ok(Optimizer::AdamW(opt))
            }
        }
    }
}

/// Optimizers.
///
/// This is a thin wrapper around [tch::nn::Optimizer].
///
/// [tch::nn::Optimizer]: https://docs.rs/tch/0.16.0/tch/nn/struct.Optimizer.html
pub enum Optimizer {
    /// Adam optimizer.
    Adam(Optimizer_),

    /// AdamW optimizer.
    AdamW(Optimizer_),
}

impl Optimizer {
    /// Computes the gradients of `loss` and performs a single optimization step,
    /// delegating to the wrapped [tch::nn::Optimizer]'s `backward_step`.
    pub fn backward_step(&mut self, loss: &Tensor) {
        match self {
            Self::Adam(opt) => {
                opt.backward_step(loss);
            }
            Self::AdamW(opt) => {
                opt.backward_step(loss);
            }
        }
    }
}
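
// A minimal end-to-end sketch of the intended usage, kept as a unit test.
// Assumptions not in the original file: the linear layer shape, the
// hyperparameter values, and mean-squared error as the loss are illustrative
// choices only.
#[cfg(test)]
mod tests {
    use super::*;
    use tch::{nn, nn::Module, Device, Kind, Reduction};

    #[test]
    fn builds_and_steps_adamw() -> Result<()> {
        // Variables registered under this store are updated by the optimizer.
        let vs = VarStore::new(Device::Cpu);
        let net = nn::linear(vs.root(), 4, 2, Default::default());

        let cfg = OptimizerConfig::AdamW {
            lr: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            wd: 0.01,
            eps: 1e-8,
            amsgrad: false,
        };
        let mut opt = cfg.build(&vs)?;

        // One gradient step on a random regression batch.
        let x = Tensor::randn(&[8, 4], (Kind::Float, Device::Cpu));
        let y = Tensor::randn(&[8, 2], (Kind::Float, Device::Cpu));
        let loss = net.forward(&x).mse_loss(&y, Reduction::Mean);
        opt.backward_step(&loss);
        Ok(())
    }
}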