burn_optim/optim/simple/record/
base.rs

1use burn_core as burn;
2
3use super::{AdaptorRecordItemV1, AdaptorRecordV1};
4use crate::optim::SimpleOptimizer;
5use burn::record::{PrecisionSettings, Record};
6use burn::tensor::backend::AutodiffBackend;
7use serde::{Deserialize, Serialize};
8
9/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
10///
11/// Records are versioned for backward compatibility, so old records can be loaded.
12pub enum AdaptorRecord<O, B>
13where
14    O: SimpleOptimizer<B::InnerBackend>,
15    B: AutodiffBackend,
16{
17    /// Version 1.
18    V1(AdaptorRecordV1<O, B::InnerBackend>),
19}
20
21/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
22#[derive(Serialize, Deserialize)]
23#[serde(bound = "")]
24pub enum AdaptorRecordItem<
25    O: SimpleOptimizer<B::InnerBackend>,
26    B: AutodiffBackend,
27    S: PrecisionSettings,
28> {
29    /// Version 1.
30    V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
31}
32
33impl<O, B> Record<B> for AdaptorRecord<O, B>
34where
35    O: SimpleOptimizer<B::InnerBackend>,
36    B: AutodiffBackend,
37{
38    type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;
39
40    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
41        match self {
42            AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()),
43        }
44    }
45
46    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
47        match item {
48            AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),
49        }
50    }
51}
52
53impl<O, B> Clone for AdaptorRecord<O, B>
54where
55    O: SimpleOptimizer<B::InnerBackend>,
56    B: AutodiffBackend,
57{
58    fn clone(&self) -> Self {
59        match self {
60            AdaptorRecord::V1(record) => Self::V1(record.clone()),
61        }
62    }
63}
64
65impl<O, B> AdaptorRecord<O, B>
66where
67    O: SimpleOptimizer<B::InnerBackend>,
68    B: AutodiffBackend,
69{
70    /// Converts the record into the optimizer state.
71    ///
72    /// # Returns
73    ///
74    /// The optimizer state.
75    pub fn into_state<const D: usize>(self) -> O::State<D> {
76        match self {
77            AdaptorRecord::V1(record) => record.into_state(),
78        }
79    }
80
81    /// Converts the optimizer state into the record.
82    ///
83    /// # Arguments
84    ///
85    /// * `state`: The optimizer state.
86    ///
87    /// # Returns
88    ///
89    /// The record.
90    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
91        Self::V1(AdaptorRecordV1::from_state(state))
92    }
93}