burn_optim/optim/simple/record/
base.rs1use burn_core as burn;
2
3use super::{AdaptorRecordItemV1, AdaptorRecordV1};
4use crate::optim::SimpleOptimizer;
5use burn::record::{PrecisionSettings, Record};
6use burn::tensor::backend::AutodiffBackend;
7use serde::{Deserialize, Serialize};
8
9pub enum AdaptorRecord<O, B>
13where
14 O: SimpleOptimizer<B::InnerBackend>,
15 B: AutodiffBackend,
16{
17 V1(AdaptorRecordV1<O, B::InnerBackend>),
19}
20
21#[derive(Serialize, Deserialize)]
23#[serde(bound = "")]
24pub enum AdaptorRecordItem<
25 O: SimpleOptimizer<B::InnerBackend>,
26 B: AutodiffBackend,
27 S: PrecisionSettings,
28> {
29 V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
31}
32
33impl<O, B> Record<B> for AdaptorRecord<O, B>
34where
35 O: SimpleOptimizer<B::InnerBackend>,
36 B: AutodiffBackend,
37{
38 type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;
39
40 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
41 match self {
42 AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()),
43 }
44 }
45
46 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
47 match item {
48 AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),
49 }
50 }
51}
52
53impl<O, B> Clone for AdaptorRecord<O, B>
54where
55 O: SimpleOptimizer<B::InnerBackend>,
56 B: AutodiffBackend,
57{
58 fn clone(&self) -> Self {
59 match self {
60 AdaptorRecord::V1(record) => Self::V1(record.clone()),
61 }
62 }
63}
64
65impl<O, B> AdaptorRecord<O, B>
66where
67 O: SimpleOptimizer<B::InnerBackend>,
68 B: AutodiffBackend,
69{
70 pub fn into_state<const D: usize>(self) -> O::State<D> {
76 match self {
77 AdaptorRecord::V1(record) => record.into_state(),
78 }
79 }
80
81 pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
91 Self::V1(AdaptorRecordV1::from_state(state))
92 }
93}