zenu-layer 0.1.1

A simple neural network layer library.
Documentation
use std::{collections::HashMap, hash::BuildHasher};

use zenu_autograd::Variable;
use zenu_matrix::{device::Device, num::Num};

pub mod layers;

/// Marker trait for the types a [`Module`] may take as input or return as output.
///
/// It declares no methods; it only restricts which types can be used for
/// [`Module::Input`] and [`Module::Output`].
pub trait ModuleParameters<T, D>
where
    T: Num,
    D: Device,
{
}

/// Allows `()` to be used as a module input or output.
impl<T, D> ModuleParameters<T, D> for ()
where
    T: Num,
    D: Device,
{
}

/// Allows a single variable to be used as a module input or output.
impl<T, D> ModuleParameters<T, D> for Variable<T, D>
where
    T: Num,
    D: Device,
{
}

/// Allows a vector of variables to be used as a module input or output.
impl<T, D> ModuleParameters<T, D> for Vec<Variable<T, D>>
where
    T: Num,
    D: Device,
{
}

/// Allows a map of variables (any key type, any hasher) to be used as a module
/// input or output.
impl<T, D, K, S> ModuleParameters<T, D> for HashMap<K, Variable<T, D>, S>
where
    T: Num,
    D: Device,
    S: BuildHasher,
{
}

/// A callable unit of computation that maps an [`Input`](Module::Input) value
/// to an [`Output`](Module::Output) value.
pub trait Module<T, D>
where
    T: Num,
    D: Device,
{
    /// Type consumed by [`Module::call`].
    type Input: ModuleParameters<T, D>;

    /// Type produced by [`Module::call`].
    type Output: ModuleParameters<T, D>;

    /// Applies the module to `input` (taken by value) and returns the result.
    fn call(&self, input: Self::Input) -> Self::Output;
}

/// Exposes an object's variables under string keys.
///
/// Implementors provide [`weights`](Parameters::weights) and
/// [`biases`](Parameters::biases). By default [`parameters`](Parameters::parameters)
/// and [`load_parameters`](Parameters::load_parameters) are derived from those two maps.
pub trait Parameters<T: Num, D: Device> {
    /// Returns the weight variables, keyed by name.
    fn weights(&self) -> HashMap<String, Variable<T, D>>;

    /// Returns the bias variables, keyed by name.
    fn biases(&self) -> HashMap<String, Variable<T, D>>;

    /// Returns the union of [`weights`](Parameters::weights) and
    /// [`biases`](Parameters::biases).
    ///
    /// If the same key appears in both maps, the bias entry wins.
    fn parameters(&self) -> HashMap<String, Variable<T, D>> {
        // Reuse the weights map and move the bias entries into it; both maps are
        // already owned, so no key or variable needs to be cloned.
        let mut parameters = self.weights();
        parameters.extend(self.biases());
        parameters
    }

    /// Copies values from `parameters` into this object's variables, matched by key.
    ///
    /// Only the keys returned by [`parameters`](Parameters::parameters) are loaded.
    /// Extra entries in `parameters` are ignored.
    ///
    /// # Panics
    ///
    /// Panics with `Failed to load model missing key: <key>` if any of this
    /// object's keys is absent from `parameters`. Every key is resolved before any
    /// variable is written, so a missing key never leaves the object partially loaded.
    fn load_parameters(&mut self, parameters: HashMap<String, Variable<T, D>>) {
        // First pass: pair each target with its source (or panic on a missing key).
        // Nothing has been mutated yet at this point.
        let pairs: Vec<_> = self
            .parameters()
            .into_iter()
            .map(|(key, target)| match parameters.get(&key) {
                Some(source) => (target, source),
                None => panic!("Failed to load model missing key: {key}"),
            })
            .collect();

        // Second pass: perform the copies.
        for (target, source) in pairs {
            target.get_as_mut().copy_from(&source.get_as_ref());
        }
    }
}

/// `()` holds no variables: its weight and bias maps are always empty.
impl<T, D> Parameters<T, D> for ()
where
    T: Num,
    D: Device,
{
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        HashMap::default()
    }

    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        HashMap::default()
    }
}
/// Aggregates the variables of every element of the vector.
///
/// Each key is prefixed with the element's position: key `k` of the element at
/// index `i` is exposed as `"{i}.{k}"`.
impl<T: Num, D: Device, P: Parameters<T, D>> Parameters<T, D> for Vec<P> {
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        self.iter()
            .enumerate()
            .flat_map(|(idx, param)| {
                // The per-element map is owned, so its variables are moved, not cloned.
                param
                    .weights()
                    .into_iter()
                    .map(move |(key, value)| (format!("{idx}.{key}"), value))
            })
            .collect()
    }

    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        self.iter()
            .enumerate()
            .flat_map(|(idx, param)| {
                param
                    .biases()
                    .into_iter()
                    .map(move |(key, value)| (format!("{idx}.{key}"), value))
            })
            .collect()
    }
}

/// Transparent wrapper: every method is forwarded to the boxed object.
///
/// Forwarding `parameters` and `load_parameters` (not just `weights` and `biases`)
/// ensures that overrides in the concrete type are honored through the box.
/// Otherwise the box would silently fall back to the trait's default implementations.
impl<T: Num, D: Device> Parameters<T, D> for Box<dyn Parameters<T, D>> {
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        (**self).weights()
    }

    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        (**self).biases()
    }

    fn parameters(&self) -> HashMap<String, Variable<T, D>> {
        (**self).parameters()
    }

    fn load_parameters(&mut self, parameters: HashMap<String, Variable<T, D>>) {
        (**self).load_parameters(parameters);
    }
}

impl<T: Num, D: Device, P: Parameters<T, D>, S: ::std::hash::BuildHasher> Parameters<T, D>
    for HashMap<String, P, S>
{
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        let mut weights = HashMap::new();
        for (key, param) in self {
            for (sub_key, value) in param.weights() {
                weights.insert(format!("{key}.{sub_key}"), value.clone());
            }
        }
        weights
    }

    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        let mut biases = HashMap::new();
        for (key, param) in self {
            for (sub_key, value) in param.biases() {
                biases.insert(format!("{key}.{sub_key}"), value.clone());
            }
        }
        biases
    }
}