use burn_core as burn;

use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate,
    grad_clipping::GradientClipping,
    optim::{GradientsParams, Optimizer},
};

use burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};
use burn::tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

#[derive(Clone)]
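/// Wraps a [`SimpleOptimizer`] so it can be used as a full [`Optimizer`] on an
/// autodiff backend, keeping per-parameter optimizer state and optional
/// gradient clipping.
///
/// A minimal usage sketch (`SgdLike`, `MyModule`, and `MyBackend` are
/// hypothetical placeholders, not types from this crate):
///
/// ```ignore
/// let optim: OptimizerAdaptor<SgdLike, MyModule, MyBackend> =
///     OptimizerAdaptor::from(SgdLike::new()).with_grad_clipping(GradientClipping::Norm(1.0));
/// ```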
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    optim: O,
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    module: PhantomData<M>,
    grad_clipping: Option<GradientClipping>,
}

impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
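    /// Sets the gradient clipping strategy applied to each gradient before the
    /// optimizer step.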
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

    #[cfg(test)]
    pub(crate) fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

    fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M {
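        // Visit every float parameter in the module, applying one optimizer step
        // per parameter that received a gradient.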
        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

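/// [`ModuleMapper`] that applies a [`SimpleOptimizer`] step to every float
/// parameter of the visited module, using the gradients collected in
/// [`GradientsParams`].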
#[derive(new)]
struct SimpleOptimizerMapper<'a, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradientsParams,
    lr: LearningRate,
    phantom: PhantomData<M>,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        let grad = self.grads.remove(id);

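        // Only parameters that received a gradient are updated; the rest pass
        // through unchanged.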
        let tensor = if let Some(grad) = grad {
            let device = grad.device();
            let is_require_grad = tensor.is_require_grad();
            let (key, record) = self.records.remove_entry(&id).unzip();

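            // Clip the gradient first, if clipping is configured.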
            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

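            // Run the inner-backend optimizer step, moving any saved state to the
            // gradient's device beforehand.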
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

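            // Persist the updated optimizer state for the next step.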
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

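            // Wrap the result back into the autodiff backend and restore the
            // requires-grad flag if it was set.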
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            tensor
        } else {
            tensor
        };

        Param::from_mapped_value(id, tensor, mapper)
    }
}