flashlight 0.0.12

neural network library
Documentation
use flashlight_tensor::prelude::*;
use rand::prelude::*;
use async_trait::async_trait;

/// Basic trait for models whose forward and backward passes are plain synchronous calls.
///
/// Both methods take `&mut self`, so an implementor is allowed to update its own state
/// during either pass.
pub trait ModelCpu{
    /// Forward propagation for the model.
    ///
    /// Takes ownership of `input` and returns the model output as a `Tensor<f32>`.
    fn forward(&mut self, input: Tensor<f32>) -> Tensor<f32>;
    /// Backward propagation for the model.
    ///
    /// `grad_output` is the output gradient; it is expected to come from
    /// `last_activation.grad_output(target: &Tensor<f32>)`. Takes ownership of
    /// `grad_output` and returns nothing.
    fn backward(&mut self, grad_output: Tensor<f32>);
}

/// Basic trait for models whose forward and backward passes are `async`.
///
/// Mirrors the synchronous model trait, but both passes must be awaited, and an extra
/// `clear_buffers` hook is required. `#[async_trait]` is what allows `async fn` in this trait.
#[async_trait]
pub trait ModelGpu{
    /// Forward propagation for the model.
    ///
    /// Takes ownership of `input`; awaiting the call yields the model output as a `Tensor<f32>`.
    async fn forward(&mut self, input: Tensor<f32>) -> Tensor<f32>;
    /// Backward propagation for the model.
    ///
    /// `grad_output` is the output gradient; it is expected to come from
    /// `last_activation.grad_output(target: &Tensor<f32>)`. Takes ownership of
    /// `grad_output`; awaiting the call yields nothing.
    async fn backward(&mut self, grad_output: Tensor<f32>);

    /// Synchronous hook with no arguments and no return value.
    ///
    /// NOTE(review): presumably drops or resets buffers the model keeps between passes —
    /// the trait itself does not say; confirm against the implementors.
    fn clear_buffers(&mut self);
}

/// Returns the Xavier (Glorot) uniform initialization bound for a layer.
///
/// Computes `sqrt(6 / (input_neurons + output_neurons))`. The result is a single
/// scalar limit, not a set of weights; weights are conventionally sampled uniformly
/// from `[-limit, limit]`.
///
/// The two counts are summed in `u64` before the conversion to `f32`, so very large
/// layer sizes cannot overflow `u32` (which would panic in debug builds and silently
/// wrap to a wrong value in release builds).
///
/// If both counts are zero the division is by zero and the result is `f32::INFINITY`.
pub fn xavier_weights(input_neurons: u32, output_neurons: u32) -> f32 {
    // Widen first: u32::MAX + u32::MAX still fits comfortably in u64.
    let fan_sum = u64::from(input_neurons) + u64::from(output_neurons);
    (6.0 / fan_sum as f32).sqrt()
}