use super::{Param, ParamId};
use crate::module::{ADModule, Module, ModuleMapper, ModuleVisitor};
use crate::tensor::{
    backend::{ADBackend, Backend},
    Tensor,
};
impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
    fn from(value: Tensor<B, D>) -> Self {
        Param::new(ParamId::new(), value.require_grad())
    }
}
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
    type Record = Param<Tensor<B, D>>;
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        visitor.visit(&self.id, &self.value)
    }
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let value = mapper.map(&self.id, self.value);
        Self::new(self.id, value)
    }
    fn into_record(self) -> Self::Record {
        self
    }
    fn load_record(self, record: Self::Record) -> Self {
        let mut tensor = record.value.detach();
        let device = self.device();
        if tensor.device() != device {
            tensor = tensor.to_device(&device).detach();
        }
        if self.is_require_grad() {
            tensor = tensor.require_grad();
        }
        Self::new(record.id, tensor)
    }
}
impl<const D: usize, B: ADBackend> ADModule<B> for Param<Tensor<B, D>> {
    type InnerModule = Param<Tensor<B::InnerBackend, D>>;
    fn valid(&self) -> Self::InnerModule {
        Param::new(
            self.id.clone(),
            self.value.clone().inner().set_require_grad(false),
        )
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::{
        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
        TestADBackend,
    };
    #[test]
    fn test_load_record_setting() {
        let tensor = Tensor::<TestADBackend, 2>::ones([3, 3]);
        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();
        let bytes = byte_recorder
            .record(Param::from(tensor.clone()).into_record(), ())
            .unwrap();
        let no_grad_is_require_grad = Param::from(tensor.clone())
            .no_grad()
            .load_record(byte_recorder.load(bytes.clone()).unwrap())
            .value
            .is_require_grad();
        let with_default_is_require_grad = Param::from(tensor)
            .load_record(byte_recorder.load(bytes).unwrap())
            .value
            .is_require_grad();
        assert!(!no_grad_is_require_grad);
        assert!(with_default_is_require_grad);
    }
}