1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use super::{AdaptorRecordItemV1, AdaptorRecordV1};
use crate::{
    optim::SimpleOptimizer,
    record::{PrecisionSettings, Record},
};
use burn_tensor::backend::AutodiffBackend;
use serde::{Deserialize, Serialize};

/// Versioned record of an [optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor).
///
/// Each variant corresponds to one record format version. Older variants are kept for
/// backward compatibility, so that records written in an earlier format can still be loaded.
pub enum AdaptorRecord<O: SimpleOptimizer<B::InnerBackend>, B: AutodiffBackend> {
    /// Format version 1, parameterized over the autodiff backend's `B::InnerBackend`.
    V1(AdaptorRecordV1<O, B::InnerBackend>),
}

/// Serializable item form of an
/// [optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.
///
/// Has one variant per record format version and is parameterized by the precision
/// settings `S` used when converting to and from the record.
#[derive(Serialize, Deserialize)]
// An empty bound stops serde from inferring `O`, `B` and `S` trait bounds on the generated impls.
#[serde(bound = "")]
pub enum AdaptorRecordItem<O, B, S>
where
    O: SimpleOptimizer<B::InnerBackend>,
    B: AutodiffBackend,
    S: PrecisionSettings,
{
    /// Format version 1.
    V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),
}

impl<O, B> Record<B> for AdaptorRecord<O, B>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;

    /// Converts the record into its serializable item, keeping the same format version.
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        // Irrefutable only while there is a single version. Adding a variant turns this
        // into a compile error, which forces this conversion to be updated.
        let Self::V1(inner) = self;
        AdaptorRecordItem::V1(inner.into_item())
    }

    /// Rebuilds the record from its serializable item, loading tensors onto `device`.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
        let AdaptorRecordItem::V1(inner) = item;
        Self::V1(AdaptorRecordV1::from_item(inner, device))
    }
}

// NOTE(review): presumably written by hand because `#[derive(Clone)]` would add
// `O: Clone` and `B: Clone` bounds; only the wrapped versioned record needs to be cloneable.
impl<O, B> Clone for AdaptorRecord<O, B>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    /// Returns a copy holding a clone of the wrapped versioned record.
    fn clone(&self) -> Self {
        let Self::V1(inner) = self;
        Self::V1(inner.clone())
    }
}

impl<O, B> AdaptorRecord<O, B>
where
    B: AutodiffBackend,
    O: SimpleOptimizer<B::InnerBackend>,
{
    /// Consumes the record and returns the optimizer state it contains.
    ///
    /// # Returns
    ///
    /// The wrapped `O::State<D>`, whichever record version stored it.
    pub fn into_state<const D: usize>(self) -> O::State<D> {
        let Self::V1(inner) = self;
        inner.into_state()
    }

    /// Wraps an optimizer state in a record.
    ///
    /// # Arguments
    ///
    /// * `state`: The optimizer state to store.
    ///
    /// # Returns
    ///
    /// A record in the current format, which is always [`AdaptorRecord::V1`].
    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
        let inner = AdaptorRecordV1::from_state(state);
        Self::V1(inner)
    }
}