use super::{SimpleOptimizer, record::AdaptorRecord};
use crate::{
    LearningRate,
    grad_clipping::GradientClipping,
    module::{AutodiffModule, ModuleMapper, ParamId},
    optim::{GradientsParams, Optimizer},
};
use burn_tensor::{Tensor, backend::AutodiffBackend};
use core::marker::PhantomData;
use hashbrown::HashMap;

#[derive(Clone)]
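/// Wrapper that adapts a [`SimpleOptimizer`] into a full [`Optimizer`], keeping one
/// optimizer state record per parameter and an optional gradient clipping strategy.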
pub struct OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    optim: O,
    records: HashMap<ParamId, AdaptorRecord<O, B>>,
    module: PhantomData<M>,
    grad_clipping: Option<GradientClipping>,
}

impl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn from(optim: O) -> Self {
        Self {
            optim,
            records: HashMap::new(),
            module: PhantomData,
            grad_clipping: None,
        }
    }
}

impl<O, M, B> OptimizerAdaptor<O, M, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
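    /// Sets the gradient clipping strategy to apply before each optimizer step.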
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

    #[cfg(test)]
    pub(crate) fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }
}

impl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>
where
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;

    fn step(&mut self, lr: LearningRate, module: M, mut grads: GradientsParams) -> M {
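        // Visit every parameter of the module; each float parameter that has a
        // gradient receives one simple-optimizer update.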
        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(
            &self.optim,
            &mut self.records,
            &mut grads,
            lr,
            self.grad_clipping.as_ref(),
        );
        module.map(&mut mapper)
    }

    fn to_record(&self) -> Self::Record {
        self.records.clone()
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        self.records = record;
        self
    }
}

#[derive(new)]
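/// Module mapper that performs one optimizer step on every float parameter it visits.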
struct SimpleOptimizerMapper<'a, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    optimizer: &'a O,
    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,
    grads: &'a mut GradientsParams,
    lr: LearningRate,
    phantom: PhantomData<M>,
    grad_clipping: Option<&'a GradientClipping>,
}

impl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    fn map_float<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
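        // Gradients are consumed from `GradientsParams`; a parameter without a
        // gradient is returned unchanged.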
        let grad = self.grads.remove(id);

        if let Some(grad) = grad {
            let device = grad.device();
            let is_require_grad = tensor.is_require_grad();
            let (key, record) = self.records.remove_entry(&id).unzip();

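            // Apply gradient clipping, if configured, before stepping the optimizer.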
            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {
                g_clipping.clip_gradient(grad)
            } else {
                grad
            };

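            // Run the optimizer on the inner (non-autodiff) backend, moving any
            // previous state onto the gradient's device first.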
            let (tensor, state) = self.optimizer.step(
                self.lr,
                tensor.inner(),
                clipped_grad,
                record.map(|record| O::to_device(record.into_state(), &device)),
            );

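            // Persist the updated optimizer state (if any) for the next step.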
            if let Some(state) = state {
                self.records
                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));
            }

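            // Bring the updated tensor back to the autodiff backend and restore
            // its requires-grad flag.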
            let mut tensor = Tensor::from_inner(tensor);
            if is_require_grad {
                tensor = tensor.require_grad();
            }
            return tensor;
        }

        tensor
    }
}