use burn_core::{self as burn, prelude::Backend, tensor::Device};

use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate, MultiGradientsParams,
    grad_clipping::GradientClipping,
    optim::{GradientsParams, Optimizer},
};

use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

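/// Adapts a [simple optimizer](SimpleOptimizer) into a full [optimizer](Optimizer) usable on an
/// [autodiff module](AutodiffModule), keeping one state record per parameter.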
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    optim: O,
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    module: PhantomData<M>,
    grad_clipping: Option<GradientClipping>,
}

impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

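    /// Returns `true` when a gradient clipping strategy is configured (test-only helper).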
    #[cfg(test)]
    pub(crate) fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

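    /// Performs one optimization step; each parameter is updated on the device of its gradient.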
    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {
        let mut grads = GradAdaptor::Single(grads);

        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

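    /// Performs one optimization step with gradients that each carry the device on which the
    /// update should run.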
    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M {
        let mut grads = GradAdaptor::Multi(grads);

        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

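    /// Clones the per-parameter optimizer state so it can be saved and restored via `load_record`.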
    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

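/// Unifies [GradientsParams] and [MultiGradientsParams] behind a single gradient lookup.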
enum GradAdaptor {
    Single(GradientsParams),
    Multi(MultiGradientsParams),
}

impl GradAdaptor {
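    /// Removes the gradient for `id`, returning it together with the device the update should
    /// run on.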
    fn remove<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>)> {
        match self {
            // Single-source gradients: the update runs on the gradient's own device.
            GradAdaptor::Single(grads) => grads.remove(id).map(|t| {
                let device = t.device();
                (t, device)
            }),
            // Multi-source gradients already store the target device alongside each gradient.
            GradAdaptor::Multi(grads) => grads.remove(id),
        }
    }
}

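/// [ModuleMapper] that applies one optimizer update to every float parameter it visits.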
#[derive(new)]
struct SimpleOptimizerMapper<'a, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradAdaptor,
    lr: LearningRate,
    phantom: PhantomData<M>,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
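    /// Updates a single parameter: takes its gradient and saved state, optionally clips the
    /// gradient, runs the simple optimizer on the inner backend, and stores the updated state.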
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let grad = self.grads.remove(id);

        let tensor = if let Some((grad, device)) = grad {
            let is_require_grad = tensor.is_require_grad();
            let (key, record) = self.records.remove_entry(&id).unzip();

            // Move the parameter to the device where the update should happen.
            let tensor = if tensor.device() != device {
                tensor.to_device(&device)
            } else {
                tensor
            };

            debug_assert_eq!(
                grad.device(),
                device,
                "The gradient should be on the provided device."
            );

            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

            debug_assert_eq!(
                tensor.device(),
                device,
                "The parameter tensor and its gradient should be on the same device."
            );

            // Run the simple optimizer on the inner (non-autodiff) backend, moving any existing
            // state to the target device first.
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

            // Persist the updated optimizer state for the next step.
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

            // Re-attach the tensor to the autodiff backend, restoring its requires-grad flag.
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            tensor
        } else {
            // No gradient for this parameter: leave it untouched.
            tensor
        };

        Param::from_mapped_value(id, tensor, mapper)
    }
}