vkml 0.0.2

High-level Vulkan-based machine learning library
pub mod activations;
pub mod concat;
pub mod conv;
pub mod element_wise;
pub mod execution;
pub mod factory;
pub mod input_buffer;
pub mod linear;
pub mod reshape;

use std::collections::HashMap;

use crate::{tensor::TensorDesc, tensor_graph::TensorId, utils::error::VKMLError};

use self::execution::LayerExecution;

pub trait Layer {
    // Calculate the output shapes for all outputs of this layer
    fn output_shapes(
        &self,
        batch_size: i64,
        input_shapes: &[&TensorDesc],
    ) -> Result<Vec<TensorDesc>, VKMLError>;

    // Whether this layer requires trainable parameters
    fn requires_parameters(&self) -> bool {
        self.parameter_count(0, &[]) > 0
    }

    // For parameterised layers, describes the required weight and bias tensors
    fn parameter_shapes(&self, _input_shapes: &[&TensorDesc]) -> Option<(TensorDesc, TensorDesc)> {
        None
    }

    // Return the total number of parameters in this layer
    fn parameter_count(&self, _batch_size: i64, _input_shapes: &[&TensorDesc]) -> i64 {
        0
    }

    // For graph verification, how many inputs this layer requires (min and max)
    fn input_requirements(&self) -> (usize, Option<usize>);

    // Return a string representation of the layers name
    fn name(&self) -> String;

    // Return optional configuration details for the layer
    fn config_string(&self) -> Option<String> {
        None
    }

    // Get input features
    fn in_features(&self) -> i64 {
        0
    }

    // Get output features
    fn out_features(&self) -> i64 {
        0
    }

    fn map_input_tensors(&self, num_inputs: usize) -> HashMap<TensorId, (usize, TensorId)> {
        let mut mappings = HashMap::new();
        // Default implementation: first N tensors map directly to inputs
        for i in 0..num_inputs {
            mappings.insert(i, (i, 0)); // Local tensor i maps to input connection i, output 0
        }
        mappings
    }

    // Generate tensor descriptions, instructions, and outputs for this layer
    fn build_layer_exec(
        &self,
        batch_size: i64,
        input_shapes: &[&TensorDesc],
    ) -> Result<LayerExecution, VKMLError>;
}