use crate::error::{NumRs2Error, Result};
use crate::gpu::array::GpuArray;
use crate::gpu::context::GpuContextRef;
use bytemuck::{Pod, Zeroable};
/// Parameter block handed to the matrix-multiplication compute shader.
///
/// `matmul` uploads one instance of this struct as the uniform buffer bound at
/// binding 3. The struct is `#[repr(C)]` and consists of four `u32` fields, giving
/// a size of 16 bytes and an alignment of 4 (both asserted in the unit tests).
///
/// NOTE(review): the field order presumably mirrors a struct declared in the shader
/// source, which is not visible here; confirm before reordering or resizing fields.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
struct MatMulParams {
// Row count of the left-hand matrix `a` (also the row count of the output).
a_rows: u32,
// Column count of `a`, which equals the row count of `b` (the shared inner dimension).
a_cols: u32,
// Column count of the right-hand matrix `b` (also the column count of the output).
b_cols: u32,
// Always written as 0; rounds the struct up to 16 bytes.
_padding: u32,
}
/// Parameter block for a reduction compute shader.
///
/// Nothing in the visible part of this file constructs this struct; only its
/// layout (16 bytes, 4-byte alignment) is checked by the unit tests. It is
/// `#[repr(C)]` with four `u32` fields.
///
/// NOTE(review): the field meanings below are inferred from the names only;
/// verify them against the shader that consumes this block.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
struct ReductionParams {
// Presumably selects which reduction operator to apply; the encoding is not visible here.
op_type: u32,
// Presumably the number of elements to reduce; TODO confirm.
array_size: u32,
// Presumably the workgroup size the shader is dispatched with; TODO confirm.
workgroup_size: u32,
// Unused filler that rounds the struct up to 16 bytes.
_padding: u32,
}
pub fn matmul<T: Pod + Zeroable + 'static>(
a: &GpuArray<T>,
b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
if a.shape().len() != 2 || b.shape().len() != 2 {
return Err(NumRs2Error::DimensionMismatch(
"Matrix multiplication requires 2D arrays".to_string(),
));
}
let a_rows = a.shape()[0];
let a_cols = a.shape()[1];
let b_rows = b.shape()[0];
let b_cols = b.shape()[1];
if a_cols != b_rows {
return Err(NumRs2Error::DimensionMismatch(format!(
"Incompatible matrix dimensions for multiplication: ({}, {}) * ({}, {})",
a_rows, a_cols, b_rows, b_cols
)));
}
let context = a.context();
let output_shape = vec![a_rows, b_cols];
let output = GpuArray::<T>::new_with_shape(&output_shape, context.clone())?;
let params = MatMulParams {
a_rows: a_rows as u32,
a_cols: a_cols as u32,
b_cols: b_cols as u32,
_padding: 0,
};
let params_buffer = context.create_buffer(
&[params],
wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
);
let bind_group_layout =
context
.device()
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("Matrix Multiplication Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let bind_group = context
.device()
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Matrix Multiplication Bind Group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: a.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: b.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: output.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: params_buffer.as_entire_binding(),
},
],
});
let shader_module = if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
context.matmul_f32_shader()
} else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
context.matmul_f64_shader()
} else {
return Err(NumRs2Error::TypeCastError(
"Matrix multiplication only supports f32 and f64 types".to_string(),
));
};
let pipeline_layout =
context
.device()
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Matrix Multiplication Pipeline Layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let compute_pipeline =
context
.device()
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Matrix Multiplication Pipeline"),
layout: Some(&pipeline_layout),
module: shader_module,
entry_point: Some("main"),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
let workgroup_count_x = (b_cols as u32).div_ceil(16);
let workgroup_count_y = (a_rows as u32).div_ceil(16);
context.run_compute(
&compute_pipeline,
&[&bind_group],
(workgroup_count_x, workgroup_count_y, 1),
);
Ok(output)
}
/// Computes the dot product of two 1-D GPU arrays of equal length.
///
/// Non-empty inputs are reshaped to a `1 x n` row and an `n x 1` column,
/// multiplied with `matmul`, and the single element of the `1 x 1` result is read
/// back to the host. Two empty vectors yield `T::zero()` (the empty sum) without
/// dispatching any GPU work.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if either input is not 1-D or the lengths differ.
/// * `NumRs2Error::IndexError` if the result element cannot be read back.
/// * Any error propagated from reshaping, `matmul`, or the host transfer.
pub fn dot<T: Pod + Zeroable + num_traits::Zero + Clone + 'static>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<T> {
    if a.shape().len() != 1 || b.shape().len() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "Dot product requires 1D arrays".to_string(),
        ));
    }

    let len = a.size();
    if len != b.size() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Vectors must have same length: {} != {}",
            len,
            b.size()
        )));
    }

    // Short-circuit the empty case: the result is mathematically zero, and this
    // avoids launching a kernel whose inner dimension (and input buffers) are empty.
    if len == 0 {
        return Ok(<T as num_traits::Zero>::zero());
    }

    let a_2d = a.reshape(&[1, len])?;
    let b_2d = b.reshape(&[len, 1])?;
    let result_2d = matmul(&a_2d, &b_2d)?;

    // Bring the 1 x 1 result to the host and pull out its only element.
    let result_array = result_2d.to_array()?;
    result_array.get(&[0, 0]).map_err(|e| {
        NumRs2Error::IndexError(format!("Failed to extract dot product result: {}", e))
    })
}
/// Returns the Euclidean (L2) norm of a 1-D GPU array.
///
/// The value is the square root of `dot(a, a)`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `a` is not 1-D.
/// * Any error propagated from `dot`.
pub fn norm_l2<T: Pod + Zeroable + num_traits::Float + 'static>(a: &GpuArray<T>) -> Result<T> {
    let rank = a.shape().len();
    if rank != 1 {
        let message = "Norm requires a 1D array".to_string();
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    dot(a, a).map(|sum_of_squares| sum_of_squares.sqrt())
}
/// Placeholder for the L1 norm of a GPU array.
///
/// # Errors
///
/// Always returns `NumRs2Error::NotImplemented`; the argument is ignored.
pub fn norm_l1<T: Pod + Zeroable + 'static>(_a: &GpuArray<T>) -> Result<T> {
    let reason = String::from("L1 norm on GPU requires additional shader support");
    Err(NumRs2Error::NotImplemented(reason))
}
/// Computes the matrix-vector product `a * x` on the GPU.
///
/// `x` is reshaped into an `n x 1` column, multiplied with `matmul`, and the
/// resulting column is flattened into a 1-D array with one element per row of `a`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `a` is not 2-D, `x` is not 1-D, or the
///   length of `x` differs from the column count of `a`.
/// * Any error propagated from reshaping or `matmul`.
pub fn matvec<T: Pod + Zeroable + 'static>(
    a: &GpuArray<T>,
    x: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let matrix_rank = a.shape().len();
    if matrix_rank != 2 {
        let message = "Matrix-vector multiplication requires a 2D matrix".to_string();
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    let vector_rank = x.shape().len();
    if vector_rank != 1 {
        let message = "Matrix-vector multiplication requires a 1D vector".to_string();
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    let cols = a.shape()[1];
    let len = x.size();
    if len != cols {
        let message = format!(
            "Incompatible dimensions: matrix has {} columns but vector has {} elements",
            cols, len
        );
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    // Treat the vector as a single column so the general matmul kernel applies.
    let column = x.reshape(&[cols, 1])?;
    let product = matmul(a, &column)?;
    let rows = a.shape()[0];
    product.reshape(&[rows])
}
/// Computes the vector-matrix product `x * a` on the GPU.
///
/// `x` is reshaped into a `1 x m` row, multiplied with `matmul`, and the
/// resulting row is flattened into a 1-D array with one element per column of `a`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `x` is not 1-D, `a` is not 2-D, or the
///   length of `x` differs from the row count of `a`.
/// * Any error propagated from reshaping or `matmul`.
pub fn vecmat<T: Pod + Zeroable + 'static>(
    x: &GpuArray<T>,
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let vector_rank = x.shape().len();
    if vector_rank != 1 {
        let message = "Vector-matrix multiplication requires a 1D vector".to_string();
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    let matrix_rank = a.shape().len();
    if matrix_rank != 2 {
        let message = "Vector-matrix multiplication requires a 2D matrix".to_string();
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    let rows = a.shape()[0];
    let len = x.size();
    if len != rows {
        let message = format!(
            "Incompatible dimensions: vector has {} elements but matrix has {} rows",
            len, rows
        );
        return Err(NumRs2Error::DimensionMismatch(message));
    }

    // Treat the vector as a single row so the general matmul kernel applies.
    let row = x.reshape(&[1, rows])?;
    let product = matmul(&row, a)?;
    let cols = a.shape()[1];
    product.reshape(&[cols])
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::{align_of, size_of};

    /// The matmul parameter block must stay 16 bytes with 4-byte alignment.
    #[test]
    fn test_matmul_params_size() {
        let size = size_of::<MatMulParams>();
        let align = align_of::<MatMulParams>();
        assert_eq!(size, 16);
        assert_eq!(align, 4);
    }

    /// The reduction parameter block must stay 16 bytes with 4-byte alignment.
    #[test]
    fn test_reduction_params_size() {
        let size = size_of::<ReductionParams>();
        let align = align_of::<ReductionParams>();
        assert_eq!(size, 16);
        assert_eq!(align, 4);
    }
}