burn 0.3.0

BURN: Burn Unstoppable Rusty Neurons
Documentation
use super::ParamId;
use crate::module::{LoadingError, State, StateNamed};
use crate::tensor::Element;

#[derive(Debug)]
pub struct Param<T> {
    pub(super) id: ParamId,
    pub(super) value: T,
}

impl<T> std::fmt::Display for Param<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(format!("Param: {}", self.id).as_str())
    }
}

impl<T> Param<T> {
    pub fn new(value: T) -> Self {
        Self {
            id: ParamId::new(),
            value,
        }
    }
}

impl<T> std::ops::Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

pub(super) fn state_with_id<E: Element>(id: ParamId, state: State<E>) -> State<E> {
    let mut state_wrapper = StateNamed::new();

    state_wrapper.register_state("data", state);
    state_wrapper.register_state("id", State::ParamId(id));

    State::StateNamed(state_wrapper)
}

pub(super) fn load_with_id<E: Element>(
    state: &State<E>,
) -> Result<(&ParamId, &State<E>), LoadingError> {
    let state_wrapper = match state {
        State::StateNamed(state) => state,
        _ => {
            return Err(LoadingError::new(
                "Can't load state wrapper to fetch id and data".to_string(),
            ))
        }
    };

    let state = match state_wrapper.get("data") {
        Some(state) => state,
        None => {
            return Err(LoadingError::new(
                "Can't load state data from state wrapper".to_string(),
            ))
        }
    };

    let id = match state_wrapper.get("id") {
        Some(State::ParamId(id)) => id,
        _ => {
            return Err(LoadingError::new(
                "Can't load state id from state wrapper".to_string(),
            ))
        }
    };

    Ok((id, state))
}