burn_optim/optim/simple/adaptor.rs

#[cfg(feature = "distributed")]
use burn_core::tensor::backend::distributed::DistributedParamId;
use burn_core::{self as burn, prelude::Backend, tensor::Device};

use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate, MultiGradientsParams,
    grad_clipping::GradientClipping,
    optim::{GradientsParams, Optimizer},
};

use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
/// an [optimizer](Optimizer).
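///
/// # Example
///
/// A minimal sketch of how the adaptor is typically obtained and used. The
/// module type `MyModule`, the simple optimizer value `simple_optim`, the
/// `module` value, and the `grads` value are assumptions for illustration,
/// not items defined in this file:
///
/// ```ignore
/// // Any `SimpleOptimizer` converts into an `OptimizerAdaptor` via `From`.
/// let mut optim: OptimizerAdaptor<_, MyModule<B>, B> = simple_optim.into();
///
/// // The adaptor then behaves like a full `Optimizer`.
/// let module = optim.step(1.0e-2, module, grads);
/// ```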
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// The wrapped simple optimizer.
    optim: O,
    /// Per-parameter optimizer state, keyed by parameter ID.
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    /// Marker for the module type the adaptor optimizes.
    module: PhantomData<M>,
    /// Optional gradient clipping applied before each update.
    grad_clipping: Option<GradientClipping>,
}

impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Access the wrapped [`SimpleOptimizer`].
    pub fn optim(&self) -> &O {
        &self.optim
    }

    /// Check if the optimizer has gradient clipping.
    pub fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }

    /// Access the gradient clipping.
    pub fn grad_clipping(&self) -> Option<&GradientClipping> {
        self.grad_clipping.as_ref()
    }

    /// Sets the gradient clipping.
    ///
    /// # Arguments
    ///
    /// * `gradient_clipping` - The gradient clipping strategy to apply to each
    ///   gradient before the optimizer update.
    ///
    /// # Returns
    ///
    /// The optimizer with gradient clipping enabled.
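    ///
    /// # Example
    ///
    /// A minimal sketch of the builder-style call; `simple_optim` and
    /// `clipping` stand in for any `SimpleOptimizer` value and
    /// `GradientClipping` value:
    ///
    /// ```ignore
    /// let optim = OptimizerAdaptor::from(simple_optim).with_grad_clipping(clipping);
    /// assert!(optim.has_gradient_clipping());
    /// ```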
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

    /// Shared implementation of [`Optimizer::step`] and [`Optimizer::step_multi`]:
    /// maps every float parameter of the module through the simple optimizer.
    fn step_common(&mut self, lr: LearningRate, module: M, mut grads: GradAdaptor) -> M {
        module.map(&mut SimpleOptimizerMapper::<B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        ))
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {
        self.step_common(lr, module, grads.into())
    }

    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M {
        self.step_common(lr, module, grads.into())
    }

    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

/// Wrapper to unify the `remove` method for [`GradientsParams`] and [`MultiGradientsParams`].
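///
/// Both gradient containers convert into this wrapper through the `From`
/// implementations below; a minimal sketch (the `single_grads` and
/// `multi_grads` values are assumed):
///
/// ```ignore
/// let single: GradAdaptor = single_grads.into(); // from GradientsParams
/// let multi: GradAdaptor = multi_grads.into();   // from MultiGradientsParams
/// ```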
pub enum GradAdaptor {
    /// Wrapper for [`GradientsParams`].
    Single(GradientsParams),

    /// Wrapper for [`MultiGradientsParams`].
    Multi(MultiGradientsParams),
}

impl From<GradientsParams> for GradAdaptor {
    fn from(grads: GradientsParams) -> Self {
        Self::Single(grads)
    }
}

impl From<MultiGradientsParams> for GradAdaptor {
    fn from(grads: MultiGradientsParams) -> Self {
        Self::Multi(grads)
    }
}

impl GradAdaptor {
    /// Remove a gradient parameter by ID.
    ///
    /// # Returns
    ///
    /// The gradient tensor and the device it lives on, or `None` if no gradient
    /// is registered for `id`.
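    ///
    /// A sketch of the call shape (the `grads` adaptor and `id` are assumed,
    /// and `B`/`2` stand in for the backend and tensor rank):
    ///
    /// ```ignore
    /// if let Some((grad, device)) = grads.remove::<B, 2>(id) {
    ///     // `grad` is the gradient tensor, already paired with its device.
    /// }
    /// ```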
    pub fn remove<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>)> {
        match self {
            // Single gradients don't carry a device, so use the tensor's device.
            GradAdaptor::Single(grads) => grads.remove(id).map(|t| {
                let device = t.device();
                (t, device)
            }),
            // Multi gradients already return the (tensor, device) pair.
            GradAdaptor::Multi(grads) => grads.remove(id),
        }
    }
}

/// [`ModuleMapper`] that applies one simple-optimizer update to every float
/// parameter it visits, keeping the per-parameter state in `records`.
#[derive(new)]
struct SimpleOptimizerMapper<'a, B, O>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradAdaptor,
    lr: LearningRate,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, B, O>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        // Parameters without a registered gradient are passed through unchanged.
        let grad = self.grads.remove(id);

        let tensor = if let Some((grad, device)) = grad {
            // Remember the autodiff (and distributed) flags so they can be
            // restored after stepping on the inner backend.
            let is_require_grad = tensor.is_require_grad();
            #[cfg(feature = "distributed")]
            let is_distributed = tensor.is_distributed();

            // Take any existing optimizer state for this parameter and move the
            // parameter onto the gradient's device if needed.
            let (key, record) = self.records.remove_entry(&id).unzip();
            let tensor = if tensor.device() != device {
                tensor.to_device(&device)
            } else {
                tensor
            };

            debug_assert_eq!(
                grad.device(),
                device,
                "The gradient is on the provided device"
            );
            // Apply gradient clipping when configured.
            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

            debug_assert_eq!(
                tensor.device(),
                device,
                "Tensor and gradients are on the same device."
            );

            // Run the simple optimizer on the inner (non-autodiff) backend, with
            // its state moved to the parameter's device.
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

            // Keep the updated state for the next step.
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

            // Re-wrap the updated tensor in the autodiff backend and restore the
            // flags captured above.
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            #[cfg(feature = "distributed")]
            if is_distributed {
                tensor = tensor.set_distributed(DistributedParamId::from(id.val()));
            }

            tensor
        } else {
            tensor
        };

        Param::from_mapped_value(id, tensor, mapper)
    }
}