use burn_core::{self as burn, Tensor};

use burn_core::module::ParamId;
use burn_core::prelude::{Backend, DeviceOps};
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;

use super::GradientsParams;
use crate::LearningRate;
use alloc::vec::Vec;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;

/// Gradients collected on multiple devices, to be aggregated per parameter
/// before an optimizer step.
#[derive(Default)]
pub struct MultiGradientsParams {
    /// One [`GradientsParams`] container per device, tagged with its [`DeviceId`].
    pub grads: Vec<(GradientsParams, DeviceId)>,
}

impl MultiGradientsParams {
    /// Remove the gradient for the given parameter id, summing the
    /// contributions from every device.
    ///
    /// The aggregated gradient is returned together with the device it lives
    /// on, or `None` if no device holds a gradient for this parameter.
    pub fn remove<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>)> {
        let (mut tensor, device, index) = self.select(id)?;

        // Sum the contributions from every other device on the selected one.
        for (i, (grads, _)) in self.grads.iter_mut().enumerate() {
            if i == index {
                continue;
            }

            if let Some(grad) = grads.remove::<B, D>(id) {
                tensor = tensor + grad.to_device(&device);
            }
        }

        Some((tensor, device))
    }

    /// Select the device that owns the aggregated gradient for the given
    /// parameter id and remove the local gradient from it.
    ///
    /// The probe starts at `id % len` so that parameters are spread across
    /// devices deterministically, falling back to the next device when the
    /// preferred one has no gradient registered.
    fn select<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>, usize)> {
        let id_val = id.val() as usize;
        for i in 0..self.grads.len() {
            let selected_device_index = (id_val + i) % self.grads.len();

            if let Some(acc) = self.grads[selected_device_index].0.remove::<B, D>(id) {
                let device_id = self.grads[selected_device_index].1;
                let device = <B::Device as DeviceOps>::from_id(device_id);
                return Some((acc.to_device(&device), device, selected_device_index));
            }
        }

        None
    }
}
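// Usage sketch (illustrative, not part of this module): aggregating gradients
// computed on two devices. `grads_a`, `grads_b`, `device_a`, `device_b`, and
// `param_id` are placeholder names, and `B` stands for any concrete backend.
//
//     let mut multi = MultiGradientsParams::default();
//     multi.grads.push((grads_a, device_a.id()));
//     multi.grads.push((grads_b, device_b.id()));
//
//     // Both contributions are summed on the device selected for `param_id`.
//     if let Some((grad, device)) = multi.remove::<B, 2>(param_id) {
//         // Update the parameter with `grad` on `device`.
//     }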

/// General trait to optimize [modules](AutodiffModule).
pub trait Optimizer<M, B>: Send + Clone
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Optimizer associative type to be used when saving and loading the state.
    type Record: Record<B>;

    /// Perform the optimizer step using the given learning rate and gradients.
    /// The updated module is returned.
    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;

    /// Perform the optimizer step using gradients collected on multiple
    /// devices. The updated module is returned.
    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M;

    /// Get the current state of the optimizer as a [record](Record).
    fn to_record(&self) -> Self::Record;

    /// Load the state of the optimizer from a [record](Record).
    fn load_record(self, record: Self::Record) -> Self;
}
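// Usage sketch (illustrative only): a single training step driving any
// `Optimizer` implementation. `model`, `optim`, `loss`, and `lr` are
// placeholders; `GradientsParams::from_grads` converts the gradients of a
// backward pass into per-parameter gradients.
//
//     let grads = loss.backward();
//     let grads = GradientsParams::from_grads(grads, &model);
//     model = optim.step(lr, model, grads);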