burn_optim/optim/simple/record/
v1.rs

1use burn_core as burn;
2
3use crate::optim::SimpleOptimizer;
4use burn::record::{PrecisionSettings, Record};
5use burn::tensor::backend::Backend;
6use core::any::Any;
7use serde::{Deserialize, Serialize};
8
9#[cfg(not(feature = "std"))]
10use alloc::boxed::Box;
11
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
///
/// Stores an optimizer state `O::State<D>` with its rank `D` encoded by the variant, so a
/// single `AdaptorRecordV1<O, B>` type can carry a state of any rank from 0 to 8. Convert
/// to and from the rank-typed state with [`AdaptorRecordV1::into_state`] and
/// [`AdaptorRecordV1::from_state`]; the serializable counterpart is
/// [`AdaptorRecordItemV1`].
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
    /// State of rank 0 (`O::State<0>`).
    Rank0(O::State<0>),

    /// State of rank 1 (`O::State<1>`).
    Rank1(O::State<1>),

    /// State of rank 2 (`O::State<2>`).
    Rank2(O::State<2>),

    /// State of rank 3 (`O::State<3>`).
    Rank3(O::State<3>),

    /// State of rank 4 (`O::State<4>`).
    Rank4(O::State<4>),

    /// State of rank 5 (`O::State<5>`).
    Rank5(O::State<5>),

    /// State of rank 6 (`O::State<6>`).
    Rank6(O::State<6>),

    /// State of rank 7 (`O::State<7>`).
    Rank7(O::State<7>),

    /// State of rank 8 (`O::State<8>`).
    Rank8(O::State<8>),
}
41
42impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
43    fn clone(&self) -> Self {
44        match self {
45            AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()),
46            AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()),
47            AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()),
48            AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()),
49            AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()),
50            AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()),
51            AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()),
52            AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()),
53            AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()),
54        }
55    }
56}
57
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
///
/// Serializable form of [`AdaptorRecordV1`] for precision settings `S`: each variant holds
/// the `Record::Item<S>` of the state with the matching rank.
///
/// The derived serde impls encode which variant is present, so renaming or reordering the
/// variants may change the serialized format and break previously saved records.
#[derive(Serialize, Deserialize)]
// `bound = ""` replaces the trait bounds serde would otherwise infer from `O`, `B` and `S`
// with none: those parameters are never serialized themselves, only the `Record::Item<S>`
// payloads are.
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
    /// Item of the rank 0 state.
    Rank0(<O::State<0> as Record<B>>::Item<S>),

    /// Item of the rank 1 state.
    Rank1(<O::State<1> as Record<B>>::Item<S>),

    /// Item of the rank 2 state.
    Rank2(<O::State<2> as Record<B>>::Item<S>),

    /// Item of the rank 3 state.
    Rank3(<O::State<3> as Record<B>>::Item<S>),

    /// Item of the rank 4 state.
    Rank4(<O::State<4> as Record<B>>::Item<S>),

    /// Item of the rank 5 state.
    Rank5(<O::State<5> as Record<B>>::Item<S>),

    /// Item of the rank 6 state.
    Rank6(<O::State<6> as Record<B>>::Item<S>),

    /// Item of the rank 7 state.
    Rank7(<O::State<7> as Record<B>>::Item<S>),

    /// Item of the rank 8 state.
    Rank8(<O::State<8> as Record<B>>::Item<S>),
}
89
90impl<O, B> AdaptorRecordV1<O, B>
91where
92    O: SimpleOptimizer<B>,
93    B: Backend,
94{
95    /// Convert the record into the state.
96    ///
97    /// # Returns
98    ///
99    /// The state.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the state dimension is not supported.
104    pub fn into_state<const D: usize>(self) -> O::State<D> {
105        let boxed_state: Box<dyn Any> = match self {
106            AdaptorRecordV1::Rank0(s) => Box::new(s),
107            AdaptorRecordV1::Rank1(s) => Box::new(s),
108            AdaptorRecordV1::Rank2(s) => Box::new(s),
109            AdaptorRecordV1::Rank3(s) => Box::new(s),
110            AdaptorRecordV1::Rank4(s) => Box::new(s),
111            AdaptorRecordV1::Rank5(s) => Box::new(s),
112            AdaptorRecordV1::Rank6(s) => Box::new(s),
113            AdaptorRecordV1::Rank7(s) => Box::new(s),
114            AdaptorRecordV1::Rank8(s) => Box::new(s),
115        };
116        let state = boxed_state
117            .downcast::<O::State<D>>()
118            .expect("Unsupported state dimension, dimension up to 8 are supported.");
119        *state
120    }
121
122    /// Convert the state into the record.
123    ///
124    /// # Arguments
125    ///
126    /// * `state`: The state.
127    ///
128    /// # Returns
129    ///
130    /// The record.
131    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
132        let state: Box<dyn Any> = Box::new(state);
133
134        match D {
135            0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),
136            1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),
137            2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),
138            3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),
139            4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()),
140            5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()),
141            6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()),
142            7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()),
143            8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()),
144            _ => panic!("Unsupported state dimension, dimension up to 8 are supported."),
145        }
146    }
147}
148
149impl<O, B> Record<B> for AdaptorRecordV1<O, B>
150where
151    O: SimpleOptimizer<B>,
152    B: Backend,
153{
154    type Item<S: PrecisionSettings> = AdaptorRecordItemV1<O, B, S>;
155
156    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
157        match self {
158            AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),
159            AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),
160            AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),
161            AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),
162            AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()),
163            AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()),
164            AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()),
165            AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()),
166            AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()),
167        }
168    }
169
170    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
171        match item {
172            AdaptorRecordItemV1::Rank0(item) => {
173                AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))
174            }
175            AdaptorRecordItemV1::Rank1(item) => {
176                AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
177            }
178            AdaptorRecordItemV1::Rank2(item) => {
179                AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))
180            }
181            AdaptorRecordItemV1::Rank3(item) => {
182                AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))
183            }
184            AdaptorRecordItemV1::Rank4(item) => {
185                AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))
186            }
187            AdaptorRecordItemV1::Rank5(item) => {
188                AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))
189            }
190            AdaptorRecordItemV1::Rank6(item) => {
191                AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))
192            }
193            AdaptorRecordItemV1::Rank7(item) => {
194                AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))
195            }
196            AdaptorRecordItemV1::Rank8(item) => {
197                AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))
198            }
199        }
200    }
201}