burn_core/optim/simple/
adaptor.rs

use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate,
    grad_clipping::GradientClipping,
    module::{AutodiffModule, ModuleMapper, ParamId},
    optim::{GradientsParams, Optimizer},
};
use burn_tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
/// an [optimizer](Optimizer).
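///
/// # Examples
///
/// A minimal sketch of wrapping a simple optimizer via the [`From`] impl
/// below; `MySgd` and `MyModel` are hypothetical types used only for
/// illustration:
///
/// ```ignore
/// let mut optim: OptimizerAdaptor<MySgd, MyModel<B>, B> = MySgd::new().into();
/// let model = optim.step(1.0e-2, model, grads);
/// ```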
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// The wrapped simple optimizer, which runs on the inner (non-autodiff) backend.
    optim: O,
    /// Per-parameter optimizer state, keyed by parameter id.
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    module: PhantomData<M>,
    /// Optional clipping applied to each gradient before the optimizer step.
    grad_clipping: Option<GradientClipping>,
}

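/// Wraps a simple optimizer with empty state and no gradient clipping,
/// enabling plain `simple_optim.into()` at call sites.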
impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Sets the gradient clipping strategy.
    ///
    /// # Arguments
    ///
    /// * `gradient_clipping` - The clipping applied to each parameter's gradient
    ///   before the optimizer step.
    ///
    /// # Returns
    ///
    /// The optimizer with gradient clipping enabled.
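    ///
    /// # Examples
    ///
    /// A sketch of enabling clipping; the `GradientClipping::Norm` variant is
    /// assumed here, so check the `grad_clipping` module for the actual variants:
    ///
    /// ```ignore
    /// let optim = OptimizerAdaptor::from(simple_optim)
    ///     .with_grad_clipping(GradientClipping::Norm(1.0));
    /// ```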
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

    #[cfg(test)]
    pub(crate) fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

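    // The per-parameter work is delegated to `SimpleOptimizerMapper` below,
    // which visits every float tensor in the module and applies one step of
    // the wrapped optimizer to each parameter that has a gradient.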
    fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M {
        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

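    // Checkpointing: the adaptor's entire state is the per-parameter records
    // map, so saving and restoring amount to a clone and an assignment.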
    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

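/// [`ModuleMapper`] that applies one step of a [simple optimizer](SimpleOptimizer)
/// to every float parameter of a module that has a matching gradient.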
#[derive(new)]
struct SimpleOptimizerMapper<'a, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradientsParams,
    lr: LearningRate,
    phantom: PhantomData<M>,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn map_float<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
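        // Consume this parameter's gradient; parameters without a gradient
        // are returned untouched.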
        let grad = self.grads.remove(id);

        if let Some(grad) = grad {
            let device = grad.device();
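            // Remember the autodiff flag: `tensor.inner()` strips autodiff
            // tracking, so it must be restored on the updated tensor below.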
            let is_require_grad = tensor.is_require_grad();
            let (key, record) = self.records.remove_entry(&id).unzip();

            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

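            // Run the simple optimizer on the inner backend, moving any saved
            // state onto the gradient's device first.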
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

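            // Persist the updated state, reusing the key removed above when
            // one existed.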
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            return tensor;
        }

        tensor
    }
}