burn 0.3.0

BURN: Burn Unstoppable Rusty Neurons
Documentation
use super::{load_with_id, state_with_id, Param};
use crate::module::{ADModule, LoadingError, Module, State, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::backend::{ADBackend, Backend};

impl<M: Module> Module for Param<M> {
    type Backend = M::Backend;

    fn num_params(&self) -> usize {
        self.value.num_params()
    }

    fn update_params<O: Optimizer<Backend = M::Backend>>(
        &mut self,
        grads: &<M::Backend as ADBackend>::Gradients,
        optim: &mut O,
    ) where
        M::Backend: ADBackend,
    {
        self.value.update_params(grads, optim);
    }

    fn load_optim_state<O: Optimizer<Backend = M::Backend>>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<<M::Backend as Backend>::Elem>,
    ) where
        M::Backend: ADBackend,
    {
        self.value.load_optim_state(optim, state_optim);
    }

    fn register_optim_state<O: Optimizer<Backend = M::Backend>>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<<M::Backend as Backend>::Elem>,
    ) where
        M::Backend: ADBackend,
    {
        self.value.register_optim_state(optim, state_optim);
    }

    fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
        self.value.devices()
    }

    fn to_device(&mut self, device: <Self::Backend as Backend>::Device) {
        self.value.to_device(device)
    }

    fn state(&self) -> State<<M::Backend as Backend>::Elem> {
        let state = self.value.state();

        state_with_id(self.id.clone(), state)
    }

    fn load(&mut self, state: &State<<M::Backend as Backend>::Elem>) -> Result<(), LoadingError> {
        let (id, state) = load_with_id(state)?;
        self.id = id.clone();

        self.value.load(state)
    }

    fn detach(&mut self) {
        self.value.detach()
    }
}

impl<M: Module> Module for Param<Vec<M>> {
    type Backend = M::Backend;

    fn num_params(&self) -> usize {
        let mut num_params = 0;
        for module in self.value.iter() {
            num_params += module.num_params();
        }

        num_params
    }

    fn update_params<O: Optimizer<Backend = M::Backend>>(
        &mut self,
        grads: &<M::Backend as ADBackend>::Gradients,
        optim: &mut O,
    ) where
        M::Backend: ADBackend,
    {
        for module in self.value.iter_mut() {
            module.update_params(grads, optim);
        }
    }

    fn load_optim_state<O: Optimizer<Backend = M::Backend>>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<<M::Backend as Backend>::Elem>,
    ) where
        M::Backend: ADBackend,
    {
        for module in self.value.iter() {
            module.load_optim_state(optim, state_optim);
        }
    }
    fn register_optim_state<O: Optimizer<Backend = M::Backend>>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<<M::Backend as Backend>::Elem>,
    ) where
        M::Backend: ADBackend,
    {
        for module in self.value.iter() {
            module.register_optim_state(optim, state_optim);
        }
    }

    fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
        let mut devices = Vec::new();
        for module in self.value.iter() {
            devices.append(&mut module.devices());
        }
        devices
    }

    fn to_device(&mut self, device: <M::Backend as Backend>::Device) {
        for module in self.value.iter_mut() {
            module.to_device(device);
        }
    }

    fn state(&self) -> State<<M::Backend as Backend>::Elem> {
        let mut state = StateNamed::new();

        for (i, module) in self.value.iter().enumerate() {
            state.register_state(format!("mod-{}", i).as_str(), module.state());
        }

        let state = State::StateNamed(state);

        state_with_id(self.id.clone(), state)
    }

    fn load(&mut self, state: &State<<M::Backend as Backend>::Elem>) -> Result<(), LoadingError> {
        let (id, state) = load_with_id(state)?;
        self.id = id.clone();

        let num = self.value.len();
        for (i, module) in self.value.iter_mut().enumerate() {
            module
                .load(state.get(format!("mod-{}", i).as_str()).ok_or_else(|| {
                    LoadingError::new(format!(
                        "Invalid number of modules, expected {} modules missing #{}",
                        num, i
                    ))
                })?)
                .map_err(|err| {
                    LoadingError::new(format!("Can't load modules mod-{}: {}", i, err))
                })?;
        }

        Ok(())
    }

    fn detach(&mut self) {
        for value in self.value.iter_mut() {
            value.detach();
        }
    }
}

impl<M: Module> Param<Vec<M>> {
    pub fn inner(&self) -> Param<Vec<M::InnerModule>>
    where
        M: ADModule,
        M::Backend: ADBackend,
    {
        Param::new(self.value.iter().map(|v| v.inner()).collect())
    }
}

impl<M: Module> Param<M> {
    pub fn inner(&self) -> Param<M::InnerModule>
    where
        M: ADModule,
        M::Backend: ADBackend,
    {
        Param::new(self.value.inner())
    }
}