burn_optim/optim/simple/adaptor.rs

use burn_core::{self as burn, prelude::Backend, tensor::Device};

use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate, MultiGradientsParams,
    grad_clipping::GradientClipping,
    optim::{GradientsParams, Optimizer},
};

use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
/// an [optimizer](Optimizer).
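///
/// # Examples
///
/// A minimal usage sketch. `MySimpleOptim` and `MyModel` are illustrative placeholders;
/// only the adaptor API defined in this file is assumed.
///
/// ```rust,ignore
/// // Wrap a `SimpleOptimizer` so it can optimize a full autodiff module.
/// let mut optim: OptimizerAdaptor<MySimpleOptim, MyModel<B>, B> =
///     OptimizerAdaptor::from(MySimpleOptim::new());
///
/// // One training step: every float parameter of the module is visited and
/// // updated by the simple optimizer on the inner (non-autodiff) backend.
/// let model = optim.step(1.0e-2, model, grads);
///
/// // Optimizer state (one record per parameter id) can be saved and restored.
/// let record = optim.to_record();
/// let optim = optim.load_record(record);
/// ```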
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    optim: O,
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    module: PhantomData<M>,
    grad_clipping: Option<GradientClipping>,
}

impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Sets the gradient clipping strategy.
    ///
    /// # Arguments
    ///
    /// * `gradient_clipping` - The gradient clipping strategy applied to each parameter's
    ///   gradient before the optimizer step.
    ///
    /// # Returns
    ///
    /// The optimizer with gradient clipping enabled.
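    ///
    /// # Examples
    ///
    /// A minimal sketch; the `Norm` variant used below is an assumption, substitute
    /// whichever [`GradientClipping`] strategy you actually use.
    ///
    /// ```rust,ignore
    /// // Clip gradient norms to 1.0 before every optimizer step.
    /// let optim = OptimizerAdaptor::from(inner_optim)
    ///     .with_grad_clipping(GradientClipping::Norm(1.0));
    /// ```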
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

    #[cfg(test)]
    pub(crate) fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {
        let mut grads = GradAdaptor::Single(grads);

        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M {
        let mut grads = GradAdaptor::Multi(grads);

        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

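/// Unifies [`GradientsParams`] and [`MultiGradientsParams`] behind a single `remove`
/// interface so [`SimpleOptimizerMapper`] does not need to distinguish the two.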
enum GradAdaptor {
    Single(GradientsParams),
    Multi(MultiGradientsParams),
}

impl GradAdaptor {
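    /// Removes and returns the gradient registered for `id` together with the device it
    /// lives on. For [`GradientsParams`] the device is read from the gradient tensor
    /// itself, while the multi-gradient variant returns it directly.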
    fn remove<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>)> {
        match self {
            GradAdaptor::Single(grads) => grads.remove(id).map(|t| {
                let device = t.device();
                (t, device)
            }),
            GradAdaptor::Multi(grads) => grads.remove(id),
        }
    }
}

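/// [`ModuleMapper`] that visits every float parameter of a module and applies a
/// [`SimpleOptimizer`] step to it, keeping the per-parameter optimizer state in `records`.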
#[derive(new)]
struct SimpleOptimizerMapper<'a, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradAdaptor,
    lr: LearningRate,
    phantom: PhantomData<M>,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let grad = self.grads.remove(id);

        // Parameters without a registered gradient are passed through unchanged.
        let tensor = if let Some((grad, device)) = grad {
            let is_require_grad = tensor.is_require_grad();
            let (key, record) = self.records.remove_entry(&id).unzip();
            // Move the parameter onto the gradient's device if necessary.
            let tensor = if tensor.device() != device {
                tensor.to_device(&device)
            } else {
                tensor
            };

            debug_assert_eq!(
                grad.device(),
                device,
                "The gradient is on the provided device"
            );
            // Apply gradient clipping, if configured.
            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

            debug_assert_eq!(
                tensor.device(),
                device,
                "Tensor and gradients are on the same device."
            );

            // Run the simple optimizer on the inner (non-autodiff) backend, moving any
            // previously saved state onto the gradient's device first.
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

            // Persist the updated optimizer state for this parameter.
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

            // Re-wrap the updated tensor in the autodiff backend and restore its
            // requires-grad flag.
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            tensor
        } else {
            tensor
        };

        Param::from_mapped_value(id, tensor, mapper)
    }
}