#[cfg(feature = "distributed")]
use burn_core::tensor::backend::distributed::DistributedParamId;
use burn_core::{self as burn, prelude::Backend, tensor::Device};

use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate, MultiGradientsParams,
    grad_clipping::GradientClipping,
    optim::{GradientsParams, Optimizer},
};

use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

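/// Wrapper that adapts any [`SimpleOptimizer`] into a full [`Optimizer`].
///
/// It keeps one state record per parameter and applies optional gradient
/// clipping before each update.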
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    optim: O,
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    module: PhantomData<M>,
    grad_clipping: Option<GradientClipping>,
}

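/// Creates an adaptor with empty state records and no gradient clipping.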
impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
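    /// Returns a reference to the wrapped [`SimpleOptimizer`].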
    pub fn optim(&self) -> &O {
        &self.optim
    }

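    /// Returns `true` if gradient clipping is configured.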
    pub fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }

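    /// Returns the gradient clipping configuration, if any.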
    pub fn grad_clipping(&self) -> Option<&GradientClipping> {
        self.grad_clipping.as_ref()
    }

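    /// Sets the gradient clipping applied before each optimization step.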
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

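    /// Maps every module parameter through the simple optimizer, consuming
    /// the matching gradient from `grads`.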
    fn step_common(&mut self, lr: LearningRate, module: M, mut grads: GradAdaptor) -> M {
        module.map(&mut SimpleOptimizerMapper::<B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        ))
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {
        self.step_common(lr, module, grads.into())
    }

    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M {
        self.step_common(lr, module, grads.into())
    }

    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

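/// Gradients container passed to the optimizer step, unifying single-device
/// and multi-device gradients.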
pub enum GradAdaptor {
    /// Gradients from a single device.
    Single(GradientsParams),

    /// Gradients spread across multiple devices.
    Multi(MultiGradientsParams),
}

impl From<GradientsParams> for GradAdaptor {
    fn from(grads: GradientsParams) -> Self {
        Self::Single(grads)
    }
}

impl From<MultiGradientsParams> for GradAdaptor {
    fn from(grads: MultiGradientsParams) -> Self {
        Self::Multi(grads)
    }
}

impl GradAdaptor {
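    /// Removes and returns the gradient for the parameter `id`, along with
    /// the device it lives on, if a gradient was registered for it.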
    pub fn remove<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>)> {
        match self {
            GradAdaptor::Single(grads) => grads.remove(id).map(|t| {
                let device = t.device();
                (t, device)
            }),
            GradAdaptor::Multi(grads) => grads.remove(id),
        }
    }
}

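/// [`ModuleMapper`] that applies one simple-optimizer step to each float
/// parameter that has a gradient.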
#[derive(new)]
struct SimpleOptimizerMapper<'a, B, O>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradAdaptor,
    lr: LearningRate,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, B, O>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let grad = self.grads.remove(id);

        let tensor = if let Some((grad, device)) = grad {
            let is_require_grad = tensor.is_require_grad();
            #[cfg(feature = "distributed")]
            let is_distributed = tensor.is_distributed();

            // Take the previous optimizer state for this parameter, if any.
            let (key, record) = self.records.remove_entry(&id).unzip();

            // Move the parameter to the gradient's device if needed.
            let tensor = if tensor.device() != device {
                tensor.to_device(&device)
            } else {
                tensor
            };

            debug_assert_eq!(
                grad.device(),
                device,
                "The gradient should be on the provided device."
            );

            // Apply gradient clipping before the optimizer step, when configured.
            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

            debug_assert_eq!(
                tensor.device(),
                device,
                "The tensor should be on the same device as the gradient."
            );

            // Run one optimizer step on the inner (non-autodiff) tensor,
            // restoring any previous state on the right device.
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

            // Persist the updated optimizer state under the original key.
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

            // Re-attach autodiff (and distributed) metadata to the updated value.
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            #[cfg(feature = "distributed")]
            if is_distributed {
                tensor = tensor.set_distributed(DistributedParamId::from(id.val()));
            }

            tensor
        } else {
            tensor
        };

        Param::from_mapped_value(id, tensor, mapper)
    }
}