burn 0.3.0

BURN: Burn Unstoppable Rusty Neurons
Documentation
use super::{load_with_id, state_with_id, Param};
use crate::module::{LoadingError, Module, State, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::{
    backend::{ADBackend, Backend},
    Data, Tensor,
};

impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
    type Backend = B;

    fn num_params(&self) -> usize {
        self.value.shape().num_elements()
    }

    fn update_params<O: Optimizer<Backend = B>>(
        &mut self,
        grads: &<B as ADBackend>::Gradients,
        optim: &mut O,
    ) where
        B: ADBackend,
    {
        optim.update(&self.id, &mut self.value, grads);
    }

    fn load_optim_state<O: Optimizer<Backend = B>>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<B::Elem>,
    ) where
        B: ADBackend,
    {
        optim.load_param_state::<D>(&self.id, state_optim, &self.value.device());
    }

    fn register_optim_state<O: Optimizer<Backend = B>>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<B::Elem>,
    ) where
        B: ADBackend,
    {
        optim.register_param_state::<D>(&self.id, state_optim);
    }

    fn devices(&self) -> Vec<B::Device> {
        vec![self.value.device()]
    }

    fn to_device(&mut self, device: B::Device) {
        self.value = self.value.to_device(device);
    }

    fn state(&self) -> State<B::Elem> {
        let state = State::Data(self.value.to_data().serialize());

        state_with_id(self.id.clone(), state)
    }

    fn load(&mut self, state: &State<B::Elem>) -> Result<(), LoadingError> {
        let (id, state) = load_with_id(state)?;
        self.id = id.clone();

        match state {
            State::Data(data) => {
                self.value = Tensor::from_data_device(Data::from(data), self.value.device());
            }
            _ => return Err(LoadingError::new("Can't load tensor".to_string())),
        };

        Ok(())
    }

    fn detach(&mut self) {
        self.value = self.value.clone().detach()
    }
}

impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
    type Backend = B;

    fn num_params(&self) -> usize {
        if let Some(value) = &self.value {
            return value.shape().num_elements();
        }

        0
    }

    fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &B::Gradients, optim: &mut O)
    where
        B: ADBackend,
    {
        if let Some(value) = &mut self.value {
            optim.update(&self.id, value, grads);
        }
    }

    fn load_optim_state<O: Optimizer<Backend = B>>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<B::Elem>,
    ) where
        B: ADBackend,
    {
        if let Some(value) = &self.value {
            optim.load_param_state::<D>(&self.id, state_optim, &value.device());
        }
    }

    fn register_optim_state<O: Optimizer<Backend = B>>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<B::Elem>,
    ) where
        B: ADBackend,
    {
        if self.value.is_some() {
            optim.register_param_state::<D>(&self.id, state_optim);
        }
    }

    fn devices(&self) -> Vec<B::Device> {
        if let Some(value) = &self.value {
            return vec![value.device()];
        }

        vec![]
    }

    fn to_device(&mut self, device: B::Device) {
        if let Some(value) = &self.value {
            self.value = Some(value.to_device(device));
        }
    }

    fn state(&self) -> State<B::Elem> {
        let state = match &self.value {
            Some(value) => State::Data(value.to_data().serialize()),
            None => State::StateNamed(StateNamed::new()),
        };

        state_with_id(self.id.clone(), state)
    }

    fn load(&mut self, state: &State<B::Elem>) -> Result<(), LoadingError> {
        let (id, state) = load_with_id(state)?;
        self.id = id.clone();

        let data = match state {
            State::Data(data) => data,
            _ => {
                return Err(LoadingError::new(
                    "Can't load Option<Tensor> from NamedState".to_string(),
                ))
            }
        };

        if let Some(value) = &self.value {
            self.value = Some(Tensor::from_data_device(Data::from(data), value.device()));
        }

        Ok(())
    }

    fn detach(&mut self) {
        self.value = self.value.clone().map(|tensor| tensor.detach());
    }
}

impl<const D: usize, B: Backend> Param<Tensor<B, D>> {
    pub fn inner(&self) -> Param<Tensor<B::InnerBackend, D>>
    where
        B: ADBackend,
    {
        Param::new(self.value.inner())
    }
}

impl<const D: usize, B: Backend> Param<Option<Tensor<B, D>>> {
    pub fn inner(&self) -> Param<Option<Tensor<B::InnerBackend, D>>>
    where
        B: ADBackend,
    {
        match &self.value {
            Some(tensor) => Param::new(Some(tensor.inner())),
            None => Param::new(None),
        }
    }
}