burn_optim/optim/simple/record/
v1.rs1use burn_core as burn;
2
3use crate::optim::SimpleOptimizer;
4use burn::record::{PrecisionSettings, Record};
5use burn::tensor::backend::Backend;
6use core::any::Any;
7use serde::{Deserialize, Serialize};
8
9#[cfg(not(feature = "std"))]
10use alloc::boxed::Box;
11
12pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
14 Rank0(O::State<0>),
16
17 Rank1(O::State<1>),
19
20 Rank2(O::State<2>),
22
23 Rank3(O::State<3>),
25
26 Rank4(O::State<4>),
28
29 Rank5(O::State<5>),
31
32 Rank6(O::State<6>),
34
35 Rank7(O::State<7>),
37
38 Rank8(O::State<8>),
40}
41
42impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
43 fn clone(&self) -> Self {
44 match self {
45 AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()),
46 AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()),
47 AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()),
48 AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()),
49 AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()),
50 AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()),
51 AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()),
52 AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()),
53 AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()),
54 }
55 }
56}
57
58#[derive(Serialize, Deserialize)]
60#[serde(bound = "")]
61pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
62 Rank0(<O::State<0> as Record<B>>::Item<S>),
64
65 Rank1(<O::State<1> as Record<B>>::Item<S>),
67
68 Rank2(<O::State<2> as Record<B>>::Item<S>),
70
71 Rank3(<O::State<3> as Record<B>>::Item<S>),
73
74 Rank4(<O::State<4> as Record<B>>::Item<S>),
76
77 Rank5(<O::State<5> as Record<B>>::Item<S>),
79
80 Rank6(<O::State<6> as Record<B>>::Item<S>),
82
83 Rank7(<O::State<7> as Record<B>>::Item<S>),
85
86 Rank8(<O::State<8> as Record<B>>::Item<S>),
88}
89
90impl<O, B> AdaptorRecordV1<O, B>
91where
92 O: SimpleOptimizer<B>,
93 B: Backend,
94{
95 pub fn into_state<const D: usize>(self) -> O::State<D> {
105 let boxed_state: Box<dyn Any> = match self {
106 AdaptorRecordV1::Rank0(s) => Box::new(s),
107 AdaptorRecordV1::Rank1(s) => Box::new(s),
108 AdaptorRecordV1::Rank2(s) => Box::new(s),
109 AdaptorRecordV1::Rank3(s) => Box::new(s),
110 AdaptorRecordV1::Rank4(s) => Box::new(s),
111 AdaptorRecordV1::Rank5(s) => Box::new(s),
112 AdaptorRecordV1::Rank6(s) => Box::new(s),
113 AdaptorRecordV1::Rank7(s) => Box::new(s),
114 AdaptorRecordV1::Rank8(s) => Box::new(s),
115 };
116 let state = boxed_state
117 .downcast::<O::State<D>>()
118 .expect("Unsupported state dimension, dimension up to 8 are supported.");
119 *state
120 }
121
122 pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
132 let state: Box<dyn Any> = Box::new(state);
133
134 match D {
135 0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),
136 1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),
137 2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),
138 3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),
139 4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()),
140 5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()),
141 6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()),
142 7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()),
143 8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()),
144 _ => panic!("Unsupported state dimension, dimension up to 8 are supported."),
145 }
146 }
147}
148
149impl<O, B> Record<B> for AdaptorRecordV1<O, B>
150where
151 O: SimpleOptimizer<B>,
152 B: Backend,
153{
154 type Item<S: PrecisionSettings> = AdaptorRecordItemV1<O, B, S>;
155
156 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
157 match self {
158 AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),
159 AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),
160 AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),
161 AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),
162 AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()),
163 AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()),
164 AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()),
165 AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()),
166 AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()),
167 }
168 }
169
170 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
171 match item {
172 AdaptorRecordItemV1::Rank0(item) => {
173 AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))
174 }
175 AdaptorRecordItemV1::Rank1(item) => {
176 AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))
177 }
178 AdaptorRecordItemV1::Rank2(item) => {
179 AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))
180 }
181 AdaptorRecordItemV1::Rank3(item) => {
182 AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))
183 }
184 AdaptorRecordItemV1::Rank4(item) => {
185 AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))
186 }
187 AdaptorRecordItemV1::Rank5(item) => {
188 AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))
189 }
190 AdaptorRecordItemV1::Rank6(item) => {
191 AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))
192 }
193 AdaptorRecordItemV1::Rank7(item) => {
194 AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))
195 }
196 AdaptorRecordItemV1::Rank8(item) => {
197 AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))
198 }
199 }
200 }
201}