vkml 0.0.3

High-level Vulkan-based machine learning library
use crate::{tensor::TensorDesc, utils::OnnxAutoPad};

use super::{
    Layer,
    activations::{ActivationLayer, ActivationType},
    concat::ConcatLayer,
    conv::ConvLayer,
    element_wise::{ElementWiseLayer, ElementWiseOperation},
    input_buffer::InputLayer,
    linear::LinearLayer,
    reshape::ReshapeLayer,
};

pub struct Layers;

impl Layers {
    pub fn input_buffer(out_features: i64) -> Box<dyn Layer> {
        Box::new(InputLayer::new(out_features))
    }

    pub fn input_buffer_with(out_features: i64, track_gradients: bool) -> Box<dyn Layer> {
        Box::new(InputLayer::new_with(out_features, track_gradients))
    }

    pub fn linear(in_features: i64, out_features: i64) -> Box<dyn Layer> {
        Box::new(LinearLayer::new(in_features, out_features))
    }

    pub fn linear_with(in_features: i64, out_features: i64, bias: bool) -> Box<dyn Layer> {
        Box::new(LinearLayer::new_with(in_features, out_features, bias))
    }

    pub fn conv_with(
        in_features: i64,
        out_features: i64,
        auto_pad: OnnxAutoPad,
        dilations: Vec<i64>,
        kernel_shape: Vec<i64>,
        pads: Vec<i64>,
        strides: Vec<i64>,
        bias: bool,
    ) -> Box<dyn Layer> {
        Box::new(ConvLayer::new_with(
            in_features,
            out_features,
            auto_pad,
            dilations,
            kernel_shape,
            pads,
            strides,
            bias,
        ))
    }

    pub fn reshape(target_shape: TensorDesc) -> Box<dyn Layer> {
        Box::new(ReshapeLayer::new(target_shape))
    }

    pub fn flatten() -> Box<dyn Layer> {
        Box::new(ReshapeLayer::flatten())
    }

    pub fn concat(dim: usize) -> Box<dyn Layer> {
        Box::new(ConcatLayer::new(dim))
    }

    pub fn add() -> Box<dyn Layer> {
        Box::new(ElementWiseLayer::new(ElementWiseOperation::Add))
    }

    pub fn sub() -> Box<dyn Layer> {
        Box::new(ElementWiseLayer::new(ElementWiseOperation::Subtract))
    }

    pub fn mul() -> Box<dyn Layer> {
        Box::new(ElementWiseLayer::new(ElementWiseOperation::Multiply))
    }

    pub fn div() -> Box<dyn Layer> {
        Box::new(ElementWiseLayer::new(ElementWiseOperation::Divide))
    }

    pub fn max() -> Box<dyn Layer> {
        Box::new(ElementWiseLayer::new(ElementWiseOperation::Maximum))
    }

    pub fn min() -> Box<dyn Layer> {
        Box::new(ElementWiseLayer::new(ElementWiseOperation::Minimum))
    }

    pub fn relu() -> Box<dyn Layer> {
        Box::new(ActivationLayer::new(ActivationType::ReLU))
    }

    pub fn sigmoid() -> Box<dyn Layer> {
        Box::new(ActivationLayer::new(ActivationType::Sigmoid))
    }

    pub fn softmax(axis: Option<i64>) -> Box<dyn Layer> {
        Box::new(ActivationLayer::new(ActivationType::Softmax(axis)))
    }
}