1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use super::{Param, ParamId};
use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor};
use crate::tensor::{
    backend::{ADBackend, Backend},
    Tensor,
};

/// Conversion from a plain tensor into a tensor parameter.
impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
    /// Wraps `value` in a [`Param`] identified by a newly created [`ParamId`].
    ///
    /// The tensor is passed through `require_grad()` before being stored.
    fn from(value: Tensor<B, D>) -> Self {
        let id = ParamId::new();
        let tracked = value.require_grad();

        Self::new(id, tracked)
    }
}

/// [`Module`] implementation for a single tensor parameter.
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    /// A tensor parameter is its own record.
    type Record = Param<Tensor<B, D>>;

    /// Hands the parameter id and tensor to `visitor`.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        let (id, tensor) = (&self.id, &self.value);
        visitor.visit(id, tensor)
    }

    /// Runs the tensor through `mapper` and rebuilds the parameter under the same id.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let id = self.id;
        let mapped = mapper.map(&id, self.value);

        Self::new(id, mapped)
    }

    /// Returns the parameter itself as the record.
    fn into_record(self) -> Self::Record {
        self
    }

    /// Builds a parameter from `record`, keeping the device and the autodiff
    /// setting of `self`.
    ///
    /// The resulting parameter carries the id stored in `record`.
    fn load_record(self, record: Self::Record) -> Self {
        let loaded = record.value.detach();
        let target_device = self.device();

        // The record may live on another device: move it onto this module's device.
        let loaded = if loaded.device() != target_device {
            loaded.to_device(&target_device).detach()
        } else {
            loaded
        };

        // Mirror the autodiff setting of the parameter being replaced.
        let loaded = if self.is_require_grad() {
            loaded.require_grad()
        } else {
            loaded
        };

        Self::new(record.id, loaded)
    }
}

/// [`ADModule`] implementation for a single tensor parameter.
impl<const D: usize, B: ADBackend> ADModule<B> for Param<Tensor<B, D>> {
    /// The same parameter expressed on the inner backend.
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;

    /// Returns a copy of this parameter on the inner backend, keeping the same id
    /// and with `set_require_grad(false)` applied to the tensor.
    fn valid(&self) -> Self::InnerModule {
        let id = self.id.clone();
        let inner = self.value.clone().inner();

        Param::new(id, inner.set_require_grad(false))
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
        TestADBackend,
    };

    /// Loading a record must keep the `require_grad` setting of the parameter the
    /// record is loaded into.
    #[test]
    fn test_load_record_setting() {
        let tensor = Tensor::<TestADBackend, 2>::ones([3, 3]);

        // Serialize a parameter once; the same bytes are loaded into both targets below.
        let recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let record = Param::from(tensor.clone()).into_record();
        let bytes = recorder.record(record, ()).unwrap();

        // Target with gradient tracking switched off.
        let frozen = Param::from(tensor.clone()).no_grad();
        let frozen = frozen.load_record(recorder.load(bytes.clone()).unwrap());
        let no_grad_is_require_grad = frozen.value.is_require_grad();

        // Target built with the default conversion.
        let tracked = Param::from(tensor);
        let tracked = tracked.load_record(recorder.load(bytes).unwrap());
        let with_default_is_require_grad = tracked.value.is_require_grad();

        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
}