burn 0.3.0

BURN: Burn Unstoppable Rusty Neurons
Documentation
use super::ParamId;
use crate::tensor::{DataSerialize, Element};
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use std::collections::HashMap;
use std::io::{Read, Write};

#[derive(Debug, PartialEq, Eq, Default)]
pub struct StateNamed<E> {
    pub values: HashMap<String, State<E>>,
}

#[derive(Debug, PartialEq, Eq)]
pub enum State<E> {
    StateNamed(StateNamed<E>),
    Data(DataSerialize<E>),
    ParamId(ParamId),
}

#[derive(Debug)]
pub enum StateError {
    InvalidFormat(String),
    FileNotFound(String),
}

impl std::fmt::Display for StateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut message = "State error => ".to_string();

        match self {
            Self::InvalidFormat(err) => {
                message += format!("Invalid format: {}", err).as_str();
            }
            Self::FileNotFound(err) => {
                message += format!("File not found: {}", err).as_str();
            }
        };

        f.write_str(message.as_str())
    }
}
impl std::error::Error for StateError {}

impl<E: Element> From<State<E>> for serde_json::Value
where
    E: serde::de::DeserializeOwned,
    E: serde::Serialize,
{
    fn from(state: State<E>) -> serde_json::Value {
        match state {
            State::StateNamed(state) => state.into(),
            State::Data(data) => serde_json::to_value(data).unwrap(),
            State::ParamId(id) => serde_json::to_value(id.to_string()).unwrap(),
        }
    }
}

impl<E: Element> From<StateNamed<E>> for serde_json::Value
where
    E: serde::de::DeserializeOwned,
    E: serde::Serialize,
{
    fn from(state: StateNamed<E>) -> serde_json::Value {
        let mut map = serde_json::Map::new();

        for (key, state) in state.values {
            map.insert(key, state.into());
        }

        serde_json::Value::Object(map)
    }
}

impl<E> TryFrom<serde_json::Value> for State<E>
where
    E: serde::de::DeserializeOwned,
    E: serde::Serialize,
{
    type Error = StateError;

    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
        if let Ok(data) = serde_json::from_value(value.clone()) {
            return Ok(State::Data(data));
        };

        if let Ok(state) = StateNamed::<E>::try_from(value.clone()) {
            return Ok(State::StateNamed(state));
        };

        match serde_json::from_value::<String>(value.clone()) {
            Ok(id) => Ok(State::ParamId(ParamId::from(id.as_str()))),
            Err(_) => Err(StateError::InvalidFormat(format!(
                "Invalid value {:?}",
                value
            ))),
        }
    }
}

impl<E> TryFrom<serde_json::Value> for StateNamed<E>
where
    E: serde::de::DeserializeOwned,
    E: serde::Serialize,
{
    type Error = StateError;

    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
        let map = match value {
            serde_json::Value::Object(map) => map,
            _ => {
                return Err(StateError::InvalidFormat(format!(
                    "Invalid value {:?}",
                    value
                )))
            }
        };

        let mut values = HashMap::new();
        for (key, value) in map {
            values.insert(key, State::try_from(value)?);
        }

        Ok(Self { values })
    }
}

impl<E: Element> StateNamed<E> {
    pub fn new() -> Self {
        Self {
            values: HashMap::new(),
        }
    }

    pub fn register_state(&mut self, name: &str, state: State<E>) {
        self.values.insert(name.to_string(), state);
    }
}

impl<E: Element> StateNamed<E> {
    pub fn get(&self, name: &str) -> Option<&State<E>> {
        self.values.get(name)
    }

    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    pub fn convert<O: Element>(self) -> StateNamed<O> {
        let mut values = HashMap::with_capacity(self.values.len());

        for (key, value) in self.values {
            values.insert(key, value.convert());
        }

        StateNamed { values }
    }
}

impl<E: Element> State<E> {
    pub fn get(&self, name: &str) -> Option<&Self> {
        match self {
            State::StateNamed(named) => named.get(name),
            _ => None,
        }
    }

    pub fn is_empty(&self) -> bool {
        match self {
            State::StateNamed(named) => named.is_empty(),
            State::Data(_) => false,
            State::ParamId(_) => false,
        }
    }

    pub fn convert<O: Element>(self) -> State<O> {
        match self {
            State::StateNamed(named) => State::StateNamed(named.convert()),
            State::Data(data) => State::Data(data.convert()),
            State::ParamId(id) => State::ParamId(id),
        }
    }
}

impl<E: Element> State<E>
where
    E: serde::de::DeserializeOwned,
    E: serde::Serialize,
{
    pub fn save(self, file: &str) -> std::io::Result<()> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        let value: serde_json::Value = self.into();

        let content = value.to_string();
        encoder.write_all(content.as_bytes()).unwrap();
        let content_compressed = encoder.finish().unwrap();

        std::fs::write(file, content_compressed)
    }

    pub fn load(file: &str) -> Result<Self, StateError> {
        let content_compressed =
            std::fs::read(file).map_err(|err| StateError::FileNotFound(format!("{:?}", err)))?;

        let mut decoder = GzDecoder::new(content_compressed.as_slice());
        let mut content = String::new();
        decoder.read_to_string(&mut content).unwrap();

        let value: serde_json::Value = serde_json::from_str(&content)
            .map_err(|err| StateError::InvalidFormat(format!("{:?}", err)))?;
        Self::try_from(value)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::module::Module;
    use crate::nn;
    use crate::tensor::backend::Backend;

    #[test]
    fn test_state_to_from_value() {
        let linear = nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
            d_input: 32,
            d_output: 32,
            bias: true,
        });

        let state = linear.state();
        let value: serde_json::Value = state.into();
        println!("{:?}", value);
        let state_from: State<<crate::TestBackend as Backend>::Elem> =
            State::try_from(value.clone()).unwrap();
        let value_from: serde_json::Value = state_from.into();

        assert_eq!(value, value_from);
    }

    #[test]
    fn test_can_save_and_load_from_file() {
        let mut linear = nn::Linear::<crate::TestBackend>::new(&nn::LinearConfig {
            d_input: 32,
            d_output: 32,
            bias: true,
        });
        linear.state().save("/tmp/test.json").unwrap();
        linear
            .load(&State::load("/tmp/test.json").unwrap())
            .unwrap();
    }
}