// runn 0.1.1
//
// Runn is a feature-rich, easy-to-use library for building, training, and evaluating
// feed-forward neural networks in Rust.
pub mod cross_entropy;
pub mod mean_squared_error;

use crate::{common::matrix::DMat, Metrics};
use typetag;

/// A loss function that networks can hold and use as a trait object.
///
/// `#[typetag::serde]` lets `Box<dyn LossFunction>` be serialized and deserialized with
/// serde; each implementing `impl` block must also carry `#[typetag::serde]` to be registered.
///
/// The `LossFunctionClone` supertrait is what makes `Box<dyn LossFunction>` cloneable, and
/// the `Send + Sync` bounds allow the boxed trait object to be moved and shared across threads.
///
/// NOTE(review): argument order is not uniform across methods. `forward` and `backward` take
/// `(predicted, target)`, while `calculate_metrics` takes `(targets, predictions)`.
#[typetag::serde]
pub trait LossFunction: LossFunctionClone + Send + Sync {
    /// Forward pass: maps `predicted` and `target` to a single `f32`.
    /// Presumably this is the scalar loss value (TODO confirm the reduction, mean vs. sum,
    /// against implementors).
    fn forward(&self, predicted: &DMat, target: &DMat) -> f32;
    /// Backward pass: returns a matrix computed from `predicted` and `target`.
    /// Presumably this is the gradient of the loss with respect to `predicted`
    /// (TODO confirm shape against implementors).
    fn backward(&self, predicted: &DMat, target: &DMat) -> DMat;
    /// Computes evaluation `Metrics` from `targets` and `predictions`.
    /// Note: the argument order here is the reverse of `forward`/`backward`.
    fn calculate_metrics(&self, targets: &DMat, predictions: &DMat) -> Metrics;
}

/// Object-safe cloning hook for `LossFunction` trait objects.
///
/// `Clone` cannot be a supertrait of a trait used as `dyn` because `Clone::clone` returns
/// `Self`. Cloning therefore goes through this method, which returns a new boxed trait
/// object instead. It backs the `Clone` impl for `Box<dyn LossFunction>`.
pub trait LossFunctionClone {
    /// Implementors are expected to return an independent boxed copy of `self`.
    fn clone_box(&self) -> Box<dyn LossFunction>;
}

impl LossFunctionClone for Box<dyn LossFunction> {
    fn clone_box(&self) -> Box<dyn LossFunction> {
        (**self).clone_box()
    }
}

impl Clone for Box<dyn LossFunction> {
    fn clone(&self) -> Box<dyn LossFunction> {
        self.clone_box()
    }
}

/// Lets a `Box<dyn LossFunction>` be used anywhere a `LossFunction` is expected, by
/// delegating every method to the boxed value.
///
/// Registered with `#[typetag::serde]` like any other implementor.
#[typetag::serde]
impl LossFunction for Box<dyn LossFunction> {
    // Each method re-borrows the box contents as `&dyn LossFunction` via an explicit `&**self`.
    // The explicit dereference matters: implicitly coercing `&Box<dyn _>` to `&dyn _` can
    // resolve to an unsizing coercion of the Box itself (it implements `LossFunction` through
    // this very impl), which would dispatch back here and recurse without end.

    fn forward(&self, predicted: &DMat, target: &DMat) -> f32 {
        let wrapped: &dyn LossFunction = &**self;
        wrapped.forward(predicted, target)
    }

    fn backward(&self, predicted: &DMat, target: &DMat) -> DMat {
        let wrapped: &dyn LossFunction = &**self;
        wrapped.backward(predicted, target)
    }

    fn calculate_metrics(&self, targets: &DMat, predictions: &DMat) -> Metrics {
        // Argument order (targets, predictions) is passed through unchanged.
        let wrapped: &dyn LossFunction = &**self;
        wrapped.calculate_metrics(targets, predictions)
    }
}