burn_optim/optim/
base.rs

use burn_core::{self as burn, Tensor};

use burn_core::module::ParamId;
use burn_core::prelude::{Backend, DeviceOps};
use burn_core::tensor::Device;
use burn_core::tensor::backend::DeviceId;

use super::GradientsParams;
use crate::LearningRate;
use alloc::vec::Vec;
use burn::module::AutodiffModule;
use burn::record::Record;
use burn::tensor::backend::AutodiffBackend;

/// Exposes multiple gradients for each parameter.
#[derive(Default)]
pub struct MultiGradientsParams {
    /// Each [GradientsParams] has its associated [DeviceId].
    pub grads: Vec<(GradientsParams, DeviceId)>,
}

impl MultiGradientsParams {
    /// Removes the gradients for the given [parameter id](ParamId).
    ///
    /// Accumulates the gradients from multiple sources, when present, on the device associated
    /// with the parameter id. The same parameter is always accumulated on the same device
    /// throughout training.
    pub fn remove<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>)> {
        // Select the accumulation device for this parameter and take its first gradient.
        let (mut tensor, device, index) = self.select(id)?;

        // Sum the gradients from every other source onto the selected device.
        for (i, (grads, _)) in self.grads.iter_mut().enumerate() {
            if i == index {
                continue;
            }

            if let Some(grad) = grads.remove::<B, D>(id) {
                tensor = tensor + grad.to_device(&device);
            }
        }

        Some((tensor, device))
    }

    fn select<B: Backend, const D: usize>(
        &mut self,
        id: ParamId,
    ) -> Option<(Tensor<B, D>, Device<B>, usize)> {
        // Derive the starting index from the parameter id so that a given parameter is always
        // accumulated on the same device, then probe the other sources in round-robin order.
        let id_val = id.val() as usize;
        for i in 0..self.grads.len() {
            let selected_device_index = (id_val + i) % self.grads.len();

            if let Some(acc) = self.grads[selected_device_index].0.remove::<B, D>(id) {
                let device_id = self.grads[selected_device_index].1;
                let device = <B::Device as DeviceOps>::from_id(device_id);
                return Some((acc.to_device(&device), device, selected_device_index));
            }
        }

        None
    }
}
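
// A minimal usage sketch, not part of this module's API: it assumes gradients for the same
// module were computed by two sources and paired with the corresponding [DeviceId]s
// (e.g. obtained through [DeviceOps]). All names below are placeholders.
#[allow(dead_code)]
fn accumulate_sketch<B: Backend>(
    grads_a: GradientsParams,
    device_a: DeviceId,
    grads_b: GradientsParams,
    device_b: DeviceId,
    id: ParamId,
) -> Option<(Tensor<B, 2>, Device<B>)> {
    // Pair each set of per-parameter gradients with the id of the device that produced it.
    let mut multi = MultiGradientsParams::default();
    multi.grads.push((grads_a, device_a));
    multi.grads.push((grads_b, device_b));

    // The gradient for `id` summed across both sources, together with the device it was
    // accumulated on; the same parameter always maps to the same device.
    multi.remove::<B, 2>(id)
}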

/// General trait to optimize [module](AutodiffModule).
pub trait Optimizer<M, B>: Send + Clone
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
{
    /// Optimizer associated type to be used when saving and loading the state.
    type Record: Record<B>;

    /// Perform the optimizer step using the given learning rate and gradients.
    /// The updated module is returned.
    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;

    /// Perform the optimizer step using the given learning rate and gradients gathered from
    /// multiple devices ([MultiGradientsParams]). The updated module is returned.
    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M;

    /// Get the current state of the optimizer as a [record](Record).
    fn to_record(&self) -> Self::Record;

    /// Load the state of the optimizer from a [record](Record).
    fn load_record(self, record: Self::Record) -> Self;
}
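
// A minimal usage sketch, not part of this module: a helper generic over any [Optimizer]
// that applies a single update and captures the optimizer state for checkpointing.
// The function name is a placeholder.
#[allow(dead_code)]
fn apply_step<M, B, O>(
    optim: &mut O,
    lr: LearningRate,
    module: M,
    grads: GradientsParams,
) -> (M, O::Record)
where
    M: AutodiffModule<B>,
    B: AutodiffBackend,
    O: Optimizer<M, B>,
{
    // One optimizer step: consumes the module and returns the updated version.
    let module = optim.step(lr, module, grads);

    // The optimizer state as a record, e.g. to save alongside the module in a checkpoint.
    (module, optim.to_record())
}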