use crate::{instruction, tensor::TensorDesc, utils::error::VKMLError};
use super::{Layer, execution::LayerExecution};
#[derive(Clone)]
pub enum ElementWiseOperation {
Add,
Subtract,
Multiply,
Divide,
Maximum,
Minimum,
}
impl ElementWiseOperation {
fn name(&self) -> String {
match self {
ElementWiseOperation::Add => "Add".to_string(),
ElementWiseOperation::Subtract => "Subtract".to_string(),
ElementWiseOperation::Multiply => "Multiply".to_string(),
ElementWiseOperation::Divide => "Divide".to_string(),
ElementWiseOperation::Maximum => "Maximum".to_string(),
ElementWiseOperation::Minimum => "Minimum".to_string(),
}
}
}
#[derive(Clone)]
pub struct ElementWiseLayer {
pub operation: ElementWiseOperation,
}
impl ElementWiseLayer {
pub fn new(operation: ElementWiseOperation) -> Self {
Self { operation }
}
}
impl Layer for ElementWiseLayer {
fn output_shapes(
&self,
_batch_size: i64,
input_shapes: &[&TensorDesc],
) -> Result<Vec<TensorDesc>, VKMLError> {
if input_shapes.len() < 2 {
return Err(VKMLError::Layer(format!(
"Element-wise operation requires at least 2 inputs, got {}",
input_shapes.len()
)));
}
let first_shape = input_shapes[0];
for shape in &input_shapes[1..] {
if shape.dims() != first_shape.dims() {
return Err(VKMLError::Layer(format!(
"Element-wise operations require matching dimensions: {:?} vs {:?}",
first_shape.dims(),
shape.dims()
)));
}
}
Ok(vec![first_shape.clone()])
}
fn input_requirements(&self) -> (usize, Option<usize>) {
(2, None) }
fn name(&self) -> String {
self.operation.name()
}
fn build_layer_exec(
&self,
_batch_size: i64,
input_shapes: &[&TensorDesc],
) -> Result<LayerExecution, VKMLError> {
if input_shapes.len() < 2 {
return Err(VKMLError::Layer(format!(
"Element-wise operation requires at least 2 inputs, got {}",
input_shapes.len()
)));
}
let first_shape = input_shapes[0];
for shape in &input_shapes[1..] {
if shape.dims() != first_shape.dims() {
return Err(VKMLError::Layer(format!(
"Element-wise operations require matching dimensions: {:?} vs {:?}",
first_shape.dims(),
shape.dims()
)));
}
}
let tensors = vec![
first_shape.clone(),
first_shape.clone(),
first_shape.clone(),
];
let element_wise_instruction = match self.operation {
ElementWiseOperation::Add => instruction::add(0, 1, 2),
ElementWiseOperation::Subtract => instruction::sub(0, 1, 2),
ElementWiseOperation::Multiply => instruction::mul(0, 1, 2),
ElementWiseOperation::Divide => instruction::div(0, 1, 2),
ElementWiseOperation::Maximum => instruction::max(0, 1, 2),
ElementWiseOperation::Minimum => instruction::min(0, 1, 2),
};
let input_mappings = self.map_input_tensors(input_shapes.len().min(2));
Ok(LayerExecution {
tensors,
instructions: vec![element_wise_instruction],
outputs: vec![2],
input_mappings,
})
}
}