use crate::error::{NumRs2Error, Result};
use crate::gpu::array::GpuArray;
use wgpu::util::DeviceExt;
/// Number of invocations per workgroup assumed by the 1-D element-wise and
/// reduction dispatches in this module.
///
/// NOTE(review): this must match the `@workgroup_size` declared by the
/// element-wise and reduction shaders, which are not visible here — confirm.
const WORKGROUP_SIZE: u32 = 256;

/// Operation selector written as the first `u32` of the element-wise params
/// uniform.
///
/// The explicit discriminants are a contract with the shader code, so the
/// representation is pinned to `u32` to match the value placed in the uniform
/// buffer. `Copy` mirrors `ReductionOp` and lets callers reuse a value after
/// casting it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
enum ElementWiseOp {
    Add = 0,
    Subtract = 1,
    Multiply = 2,
    Divide = 3,
    Exp = 4,
    Log = 5,
    Sin = 6,
    Cos = 7,
    Tan = 8,
    Sqrt = 9,
    Abs = 10,
    Neg = 11,
    Pow = 12,
}
/// Element-wise sum `a + b` of two GPU arrays.
///
/// # Errors
/// Propagates errors from the shared element-wise helper, which rejects
/// operands whose shapes differ.
pub fn add<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Add;
    element_wise_op(a, b, op)
}
/// Element-wise difference `a - b` of two GPU arrays.
///
/// # Errors
/// Propagates errors from the shared element-wise helper, which rejects
/// operands whose shapes differ.
pub fn subtract<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Subtract;
    element_wise_op(a, b, op)
}
/// Element-wise product `a * b` of two GPU arrays.
///
/// # Errors
/// Propagates errors from the shared element-wise helper, which rejects
/// operands whose shapes differ.
pub fn multiply<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Multiply;
    element_wise_op(a, b, op)
}
/// Element-wise quotient `a / b` of two GPU arrays.
///
/// # Errors
/// Propagates errors from the shared element-wise helper, which rejects
/// operands whose shapes differ.
pub fn divide<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Divide;
    element_wise_op(a, b, op)
}
/// Element-wise exponential of `a`, evaluated by the element-wise shader's
/// `Exp` operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn exp<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Exp;
    unary_element_wise_op(a, op)
}
/// Element-wise logarithm of `a`, evaluated by the element-wise shader's
/// `Log` operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn log<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Log;
    unary_element_wise_op(a, op)
}
/// Element-wise sine of `a`, evaluated by the element-wise shader's `Sin`
/// operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn sin<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Sin;
    unary_element_wise_op(a, op)
}
/// Element-wise cosine of `a`, evaluated by the element-wise shader's `Cos`
/// operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn cos<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Cos;
    unary_element_wise_op(a, op)
}
/// Element-wise tangent of `a`, evaluated by the element-wise shader's `Tan`
/// operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn tan<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Tan;
    unary_element_wise_op(a, op)
}
/// Element-wise square root of `a`, evaluated by the element-wise shader's
/// `Sqrt` operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn sqrt<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Sqrt;
    unary_element_wise_op(a, op)
}
/// Element-wise absolute value of `a`, evaluated by the element-wise shader's
/// `Abs` operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn abs<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Abs;
    unary_element_wise_op(a, op)
}
/// Element-wise negation of `a`, evaluated by the element-wise shader's `Neg`
/// operation.
///
/// # Errors
/// Propagates errors from the shared unary dispatch helper.
pub fn neg<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Neg;
    unary_element_wise_op(a, op)
}
/// Element-wise power of `a` by `b`, evaluated by the element-wise shader's
/// `Pow` operation.
///
/// # Errors
/// Propagates errors from the shared element-wise helper, which rejects
/// operands whose shapes differ.
pub fn pow<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Pow;
    element_wise_op(a, b, op)
}
/// Multiplies two 2-D GPU arrays: `(m, k) x (k, n) -> (m, n)`.
///
/// The product is computed by the `main` entry point of the context's matmul
/// shader. It is dispatched as a 2-D grid of `16 x 16` tiles covering the
/// output matrix and submitted through `run_compute`.
///
/// # Errors
/// * [`NumRs2Error::ShapeMismatch`] if either input is not 2-D.
/// * [`NumRs2Error::DimensionMismatch`] if the inner dimensions differ.
/// * [`NumRs2Error::InvalidOperation`] if `T` is neither `f32` nor `f64`, if a
///   dimension does not fit in `u32`, or if the dispatch grid would exceed the
///   device's `max_compute_workgroups_per_dimension` limit.
/// * Any error from allocating the output array.
pub fn matmul<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    /// Builds a compute-visible buffer layout entry with no dynamic offset.
    fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    // Edge length of the output tile handled by one workgroup.
    // NOTE(review): must match the workgroup size declared in the matmul
    // shaders, which are not visible here — confirm.
    const TILE: u32 = 16;

    let a_shape = a.shape();
    let b_shape = b.shape();
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![2],
            actual: vec![a_shape.len(), b_shape.len()],
        });
    }
    if a_shape[1] != b_shape[0] {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Cannot multiply matrices with shapes {:?} and {:?}",
            a_shape, b_shape
        )));
    }

    // Only f32/f64 shaders exist. Selecting by `size_of::<T>()` alone would run
    // a float shader over integer data (i32, u64, ...) or over elements of any
    // other width, silently producing garbage. `Pod` implies `'static`, so
    // `TypeId` is available for every permitted `T`.
    let is_f32 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>();
    let is_f64 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>();
    if !is_f32 && !is_f64 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "GPU matmul supports only f32 and f64 elements, got {}",
            std::any::type_name::<T>()
        )));
    }

    // The shader receives dimensions as u32; reject sizes that would truncate.
    let to_u32 = |value: usize, what: &str| {
        u32::try_from(value).map_err(|_| {
            NumRs2Error::InvalidOperation(format!(
                "Matrix {} ({}) does not fit in the u32 range used by the GPU shader",
                what, value
            ))
        })
    };
    let m = to_u32(a_shape[0], "row count")?;
    let k = to_u32(a_shape[1], "inner dimension")?;
    let n = to_u32(b_shape[1], "column count")?;

    let context = a.context().clone();

    // Integer ceil-division: an f32 round trip loses precision above 2^24 and
    // can drop the last tile row/column.
    let workgroup_count_x = n.div_ceil(TILE);
    let workgroup_count_y = m.div_ceil(TILE);
    let max_groups = context
        .device()
        .limits()
        .max_compute_workgroups_per_dimension;
    if workgroup_count_x > max_groups || workgroup_count_y > max_groups {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Matrix product of shape [{}, {}] needs a {}x{} workgroup grid, exceeding the device limit of {} per dimension",
            m, n, workgroup_count_x, workgroup_count_y, max_groups
        )));
    }

    let out_shape = vec![a_shape[0], b_shape[1]];
    let result = GpuArray::<T>::new_with_shape(&out_shape, context.clone())?;
    if m == 0 || n == 0 {
        // Empty output: nothing to compute, and wgpu rejects zero-sized bindings.
        return Ok(result);
    }

    let bind_group_layout =
        context
            .device()
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("NumRS2 MatMul Bind Group Layout"),
                entries: &[
                    buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
                    buffer_entry(3, wgpu::BufferBindingType::Uniform),
                ],
            });

    let shader = if is_f32 {
        context.matmul_f32_shader()
    } else {
        context.matmul_f64_shader()
    };

    let pipeline_layout =
        context
            .device()
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("NumRS2 MatMul Pipeline Layout"),
                bind_group_layouts: &[Some(&bind_group_layout)],
                immediate_size: 0,
            });
    let pipeline = context
        .device()
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NumRS2 MatMul Pipeline"),
            layout: Some(&pipeline_layout),
            module: shader,
            entry_point: Some("main"),
            cache: None,
            compilation_options: Default::default(),
        });

    // Uniform layout: [m, k, n, padding].
    let dims = [m, k, n, 0];
    let dimensions_buffer =
        context
            .device()
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("MatMul Dimensions"),
                contents: bytemuck::cast_slice(&dims),
                usage: wgpu::BufferUsages::UNIFORM,
            });

    let bind_group = context
        .device()
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("NumRS2 MatMul Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: dimensions_buffer.as_entire_binding(),
                },
            ],
        });

    context.run_compute(
        &pipeline,
        &[&bind_group],
        (workgroup_count_x, workgroup_count_y, 1),
    );
    Ok(result)
}
/// Transposes a 2-D GPU array, producing an array of shape `[cols, rows]`.
///
/// The `transpose` entry point of the element-wise shader is dispatched as a
/// 2-D grid of `16 x 16` tiles over the output.
///
/// # Errors
/// * [`NumRs2Error::InvalidOperation`] if the array has fewer than 2
///   dimensions, if the element width is not 4 or 8 bytes, if a dimension does
///   not fit in `u32`, or if the dispatch grid would exceed the device's
///   workgroup-count limit.
/// * [`NumRs2Error::NotImplemented`] for arrays with more than 2 dimensions.
/// * Any error from allocating the output array.
pub fn transpose<T: bytemuck::Pod + bytemuck::Zeroable>(a: &GpuArray<T>) -> Result<GpuArray<T>> {
    /// Builds a compute-visible buffer layout entry with no dynamic offset.
    fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    // Edge length of the output tile handled by one workgroup.
    // NOTE(review): must match the workgroup size of the `transpose` entry
    // point, which is not visible here — confirm.
    const TILE: u32 = 16;

    let in_shape = a.shape();
    if in_shape.len() < 2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Cannot transpose array with less than 2 dimensions, got shape {:?}",
            in_shape
        )));
    }
    if in_shape.len() > 2 {
        return Err(NumRs2Error::NotImplemented(
            "Transpose for arrays with more than 2 dimensions is not implemented yet".to_string(),
        ));
    }

    // The shader variant is picked by element width only. Transpose just moves
    // elements, so any 4- or 8-byte type is accepted, but other widths would be
    // fed to the f64 shader with a mismatched buffer layout.
    let element_size = std::mem::size_of::<T>();
    if element_size != 4 && element_size != 8 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "GPU transpose supports only 4- or 8-byte elements, got {} ({} bytes)",
            std::any::type_name::<T>(),
            element_size
        )));
    }

    // The shader receives dimensions as u32; reject sizes that would truncate.
    let to_u32 = |value: usize| {
        u32::try_from(value).map_err(|_| {
            NumRs2Error::InvalidOperation(format!(
                "Dimension {} does not fit in the u32 range used by the GPU shader",
                value
            ))
        })
    };
    let rows = to_u32(in_shape[0])?;
    let cols = to_u32(in_shape[1])?;

    let context = a.context().clone();

    // Output is [cols, rows]: x spans output columns (= input rows), y spans
    // output rows (= input columns). Integer ceil-division avoids the f32
    // precision loss above 2^24.
    let workgroup_count_x = rows.div_ceil(TILE);
    let workgroup_count_y = cols.div_ceil(TILE);
    let max_groups = context
        .device()
        .limits()
        .max_compute_workgroups_per_dimension;
    if workgroup_count_x > max_groups || workgroup_count_y > max_groups {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Transpose of shape [{}, {}] needs a {}x{} workgroup grid, exceeding the device limit of {} per dimension",
            rows, cols, workgroup_count_x, workgroup_count_y, max_groups
        )));
    }

    let out_shape = vec![in_shape[1], in_shape[0]];
    let result = GpuArray::<T>::new_with_shape(&out_shape, context.clone())?;
    if rows == 0 || cols == 0 {
        // Empty array: nothing to move, and wgpu rejects zero-sized bindings.
        return Ok(result);
    }

    let bind_group_layout =
        context
            .device()
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("NumRS2 Transpose Bind Group Layout"),
                entries: &[
                    buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: false }),
                    buffer_entry(2, wgpu::BufferBindingType::Uniform),
                ],
            });

    let shader = if element_size == 4 {
        context.element_wise_f32_shader()
    } else {
        context.element_wise_f64_shader()
    };

    let pipeline_layout =
        context
            .device()
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("NumRS2 Transpose Pipeline Layout"),
                bind_group_layouts: &[Some(&bind_group_layout)],
                immediate_size: 0,
            });
    let pipeline = context
        .device()
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NumRS2 Transpose Pipeline"),
            layout: Some(&pipeline_layout),
            module: shader,
            entry_point: Some("transpose"),
            cache: None,
            compilation_options: Default::default(),
        });

    // Uniform layout: [input rows, input cols, padding, padding].
    let dims = [rows, cols, 0, 0];
    let dimensions_buffer =
        context
            .device()
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Transpose Dimensions"),
                contents: bytemuck::cast_slice(&dims),
                usage: wgpu::BufferUsages::UNIFORM,
            });

    let bind_group = context
        .device()
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("NumRS2 Transpose Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: result.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: dimensions_buffer.as_entire_binding(),
                },
            ],
        });

    context.run_compute(
        &pipeline,
        &[&bind_group],
        (workgroup_count_x, workgroup_count_y, 1),
    );
    Ok(result)
}
/// Shared implementation of the binary element-wise operations.
///
/// Dispatches the `binary_op` entry point of the element-wise shader with
/// these bindings:
/// * 0: `a`
/// * 1: `b`
/// * 2: the output
/// * 3: a `[op, len, 0, 0]` `u32` params uniform
///
/// One thread is used per element, in workgroups of `WORKGROUP_SIZE`.
///
/// # Errors
/// * [`NumRs2Error::ShapeMismatch`] if `a` and `b` have different shapes.
/// * [`NumRs2Error::InvalidOperation`] if `T` is neither `f32` nor `f64`, the
///   element count does not fit in `u32`, or the dispatch would exceed the
///   device's workgroup-count limit.
/// * Any error from allocating the output array.
fn element_wise_op<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
    op: ElementWiseOp,
) -> Result<GpuArray<T>> {
    /// Builds a compute-visible buffer layout entry with no dynamic offset.
    fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    if a.shape() != b.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: a.shape().to_vec(),
            actual: b.shape().to_vec(),
        });
    }

    // Only f32/f64 shaders exist; selecting by size alone would run float math
    // over integer data or mis-sized elements. `Pod` implies `'static`.
    let is_f32 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>();
    let is_f64 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>();
    if !is_f32 && !is_f64 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "GPU element-wise operations support only f32 and f64 elements, got {}",
            std::any::type_name::<T>()
        )));
    }

    // The element count is passed to the shader as u32; reject truncation.
    let len = u32::try_from(a.size()).map_err(|_| {
        NumRs2Error::InvalidOperation(format!(
            "Array with {} elements is too large for a GPU element-wise operation (maximum is {})",
            a.size(),
            u32::MAX
        ))
    })?;
    let workgroup_count = len.div_ceil(WORKGROUP_SIZE);

    let context = a.context().clone();
    let max_groups = context
        .device()
        .limits()
        .max_compute_workgroups_per_dimension;
    if workgroup_count > max_groups {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Element-wise operation over {} elements needs {} workgroups, exceeding the device limit of {} per dimension",
            len, workgroup_count, max_groups
        )));
    }

    let result = GpuArray::<T>::new_with_shape(a.shape(), context.clone())?;
    if len == 0 {
        // Empty arrays: nothing to compute, and wgpu rejects zero-sized bindings.
        return Ok(result);
    }

    let bind_group_layout =
        context
            .device()
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("NumRS2 Element-wise Bind Group Layout"),
                entries: &[
                    buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
                    buffer_entry(3, wgpu::BufferBindingType::Uniform),
                ],
            });

    let shader = if is_f32 {
        context.element_wise_f32_shader()
    } else {
        context.element_wise_f64_shader()
    };

    let pipeline_layout =
        context
            .device()
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("NumRS2 Element-wise Pipeline Layout"),
                bind_group_layouts: &[Some(&bind_group_layout)],
                immediate_size: 0,
            });
    let pipeline = context
        .device()
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NumRS2 Element-wise Pipeline"),
            layout: Some(&pipeline_layout),
            module: shader,
            entry_point: Some("binary_op"),
            cache: None,
            compilation_options: Default::default(),
        });

    // Uniform layout: [op selector, element count, padding, padding].
    let params = [op as u32, len, 0, 0];
    let params_buffer = context
        .device()
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Element-wise Op Params"),
            contents: bytemuck::cast_slice(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

    let bind_group = context
        .device()
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("NumRS2 Element-wise Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    context.run_compute(&pipeline, &[&bind_group], (workgroup_count, 1, 1));
    Ok(result)
}
/// Reduction selector written as the first `u32` of the reduction params
/// uniform.
///
/// The explicit discriminants are a contract with the reduction shaders, so
/// the representation is pinned to `u32` to match the value placed in the
/// uniform buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
enum ReductionOp {
    Sum = 0,
    Mean = 1,
    Max = 2,
    Min = 3,
}
/// Sum of all elements of an `f32` GPU array.
///
/// # Errors
/// Propagates errors from the shared f32 reduction helper.
pub fn sum_f32(a: &GpuArray<f32>) -> Result<f32> {
    let op = ReductionOp::Sum;
    reduction_op_f32(a, op)
}
/// Sum of all elements of an `f64` GPU array.
///
/// # Errors
/// Propagates errors from the shared f64 reduction helper.
pub fn sum_f64(a: &GpuArray<f64>) -> Result<f64> {
    let op = ReductionOp::Sum;
    reduction_op_f64(a, op)
}
/// Arithmetic mean of all elements of an `f32` GPU array.
///
/// # Errors
/// Propagates errors from the shared f32 reduction helper.
pub fn mean_f32(a: &GpuArray<f32>) -> Result<f32> {
    let op = ReductionOp::Mean;
    reduction_op_f32(a, op)
}
/// Arithmetic mean of all elements of an `f64` GPU array.
///
/// # Errors
/// Propagates errors from the shared f64 reduction helper.
pub fn mean_f64(a: &GpuArray<f64>) -> Result<f64> {
    let op = ReductionOp::Mean;
    reduction_op_f64(a, op)
}
/// Largest element of an `f32` GPU array.
///
/// # Errors
/// Propagates errors from the shared f32 reduction helper.
pub fn max_f32(a: &GpuArray<f32>) -> Result<f32> {
    let op = ReductionOp::Max;
    reduction_op_f32(a, op)
}
/// Largest element of an `f64` GPU array.
///
/// # Errors
/// Propagates errors from the shared f64 reduction helper.
pub fn max_f64(a: &GpuArray<f64>) -> Result<f64> {
    let op = ReductionOp::Max;
    reduction_op_f64(a, op)
}
/// Smallest element of an `f32` GPU array.
///
/// # Errors
/// Propagates errors from the shared f32 reduction helper.
pub fn min_f32(a: &GpuArray<f32>) -> Result<f32> {
    let op = ReductionOp::Min;
    reduction_op_f32(a, op)
}
/// Smallest element of an `f64` GPU array.
///
/// # Errors
/// Propagates errors from the shared f64 reduction helper.
pub fn min_f64(a: &GpuArray<f64>) -> Result<f64> {
    let op = ReductionOp::Min;
    reduction_op_f64(a, op)
}
/// Reduces an `f32` GPU array to a single scalar according to `op`.
///
/// The reduction runs in two stages:
/// 1. The `reduction` entry point of the f32 reduction shader writes one
///    partial result per workgroup of `WORKGROUP_SIZE` threads.
/// 2. The partials are copied to a mappable staging buffer, read back, and
///    combined on the CPU. For `Mean`, the CPU stage sums the partials and
///    divides by the element count. This presumes the shader emits partial
///    *sums* for `Mean` — confirm against the shader.
///
/// Empty input is answered on the CPU without touching the GPU, because wgpu
/// rejects zero-sized bindings. `Sum` is `0.0`, `Mean` is `NaN` (0/0, matching
/// the sum/count formula), and `Max`/`Min` return an error.
///
/// Blocks until the GPU work has finished.
///
/// # Errors
/// * [`NumRs2Error::InvalidOperation`] for `Max`/`Min` of an empty array, for
///   arrays with more than `u32::MAX` elements, or when the dispatch would
///   exceed the device's workgroup-count limit.
/// * [`NumRs2Error::RuntimeError`] if polling the device or mapping the
///   staging buffer fails.
fn reduction_op_f32(a: &GpuArray<f32>, op: ReductionOp) -> Result<f32> {
    /// Builds a compute-visible buffer layout entry with no dynamic offset.
    fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    // The element count is passed to the shader as u32; reject truncation.
    let total_elements = u32::try_from(a.size()).map_err(|_| {
        NumRs2Error::InvalidOperation(format!(
            "Array with {} elements is too large for a GPU reduction (maximum is {})",
            a.size(),
            u32::MAX
        ))
    })?;
    if total_elements == 0 {
        return match op {
            ReductionOp::Sum => Ok(0.0),
            ReductionOp::Mean => Ok(f32::NAN),
            ReductionOp::Max | ReductionOp::Min => Err(NumRs2Error::InvalidOperation(
                "Cannot compute the maximum or minimum of an empty array".to_string(),
            )),
        };
    }

    let context = a.context().clone();
    let workgroup_count = total_elements.div_ceil(WORKGROUP_SIZE);
    let max_groups = context
        .device()
        .limits()
        .max_compute_workgroups_per_dimension;
    if workgroup_count > max_groups {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Reduction over {} elements needs {} workgroups, exceeding the device limit of {} per dimension",
            total_elements, workgroup_count, max_groups
        )));
    }

    let partial_results_size = u64::from(workgroup_count) * std::mem::size_of::<f32>() as u64;
    let partial_results_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
        label: Some("Reduction Partial Results"),
        size: partial_results_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let bind_group_layout =
        context
            .device()
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("NumRS2 Reduction Bind Group Layout"),
                entries: &[
                    buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: false }),
                    buffer_entry(2, wgpu::BufferBindingType::Uniform),
                ],
            });

    let shader = context.reduction_f32_shader();
    let pipeline_layout =
        context
            .device()
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("NumRS2 Reduction Pipeline Layout"),
                bind_group_layouts: &[Some(&bind_group_layout)],
                immediate_size: 0,
            });
    let pipeline = context
        .device()
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NumRS2 Reduction Pipeline"),
            layout: Some(&pipeline_layout),
            module: shader,
            entry_point: Some("reduction"),
            cache: None,
            compilation_options: Default::default(),
        });

    // Uniform layout: [op selector, element count, padding, padding].
    let params = [op as u32, total_elements, 0, 0];
    let params_buffer = context
        .device()
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Reduction Op Params"),
            contents: bytemuck::cast_slice(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

    let bind_group = context
        .device()
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("NumRS2 Reduction Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: partial_results_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    context.run_compute(&pipeline, &[&bind_group], (workgroup_count, 1, 1));

    // Storage buffers cannot be mapped directly; copy into a MAP_READ buffer.
    let staging_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
        label: Some("Reduction Staging Buffer"),
        size: partial_results_size,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });
    let mut encoder = context
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Reduction Copy Encoder"),
        });
    encoder.copy_buffer_to_buffer(
        &partial_results_buffer,
        0,
        &staging_buffer,
        0,
        partial_results_size,
    );
    context.queue().submit(Some(encoder.finish()));

    let buffer_slice = staging_buffer.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        // The receiver is gone only if this function already bailed out with
        // an error; there is nobody left to notify, so ignore the send failure
        // instead of panicking inside the wgpu callback.
        let _ = tx.send(result);
    });
    context
        .device()
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| {
            NumRs2Error::RuntimeError(format!(
                "GPU device poll failed during f32 reduction buffer mapping: {:?}",
                e
            ))
        })?;
    rx.recv()
        .map_err(|e| {
            NumRs2Error::RuntimeError(format!(
                "Failed to receive f32 reduction buffer mapping result: {:?}",
                e
            ))
        })?
        .map_err(|e| NumRs2Error::RuntimeError(format!("Failed to map buffer: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let partial_results: &[f32] = bytemuck::cast_slice(&data);
    let final_result = match op {
        ReductionOp::Sum => partial_results.iter().sum(),
        ReductionOp::Mean => partial_results.iter().sum::<f32>() / total_elements as f32,
        ReductionOp::Max => partial_results
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max),
        ReductionOp::Min => partial_results
            .iter()
            .copied()
            .fold(f32::INFINITY, f32::min),
    };
    // The mapped view must be released before the buffer can be unmapped.
    drop(data);
    staging_buffer.unmap();
    Ok(final_result)
}
/// Reduces an `f64` GPU array to a single scalar according to `op`.
///
/// The reduction runs in two stages:
/// 1. The `reduction` entry point of the f64 reduction shader writes one
///    partial result per workgroup of `WORKGROUP_SIZE` threads.
/// 2. The partials are copied to a mappable staging buffer, read back, and
///    combined on the CPU. For `Mean`, the CPU stage sums the partials and
///    divides by the element count. This presumes the shader emits partial
///    *sums* for `Mean` — confirm against the shader.
///
/// Empty input is answered on the CPU without touching the GPU, because wgpu
/// rejects zero-sized bindings. `Sum` is `0.0`, `Mean` is `NaN` (0/0, matching
/// the sum/count formula), and `Max`/`Min` return an error.
///
/// Blocks until the GPU work has finished.
///
/// # Errors
/// * [`NumRs2Error::InvalidOperation`] for `Max`/`Min` of an empty array, for
///   arrays with more than `u32::MAX` elements, or when the dispatch would
///   exceed the device's workgroup-count limit.
/// * [`NumRs2Error::RuntimeError`] if polling the device or mapping the
///   staging buffer fails.
fn reduction_op_f64(a: &GpuArray<f64>, op: ReductionOp) -> Result<f64> {
    /// Builds a compute-visible buffer layout entry with no dynamic offset.
    fn buffer_entry(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    // The element count is passed to the shader as u32; reject truncation.
    let total_elements = u32::try_from(a.size()).map_err(|_| {
        NumRs2Error::InvalidOperation(format!(
            "Array with {} elements is too large for a GPU reduction (maximum is {})",
            a.size(),
            u32::MAX
        ))
    })?;
    if total_elements == 0 {
        return match op {
            ReductionOp::Sum => Ok(0.0),
            ReductionOp::Mean => Ok(f64::NAN),
            ReductionOp::Max | ReductionOp::Min => Err(NumRs2Error::InvalidOperation(
                "Cannot compute the maximum or minimum of an empty array".to_string(),
            )),
        };
    }

    let context = a.context().clone();
    let workgroup_count = total_elements.div_ceil(WORKGROUP_SIZE);
    let max_groups = context
        .device()
        .limits()
        .max_compute_workgroups_per_dimension;
    if workgroup_count > max_groups {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Reduction over {} elements needs {} workgroups, exceeding the device limit of {} per dimension",
            total_elements, workgroup_count, max_groups
        )));
    }

    let partial_results_size = u64::from(workgroup_count) * std::mem::size_of::<f64>() as u64;
    let partial_results_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
        label: Some("Reduction Partial Results"),
        size: partial_results_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let bind_group_layout =
        context
            .device()
            .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("NumRS2 Reduction Bind Group Layout"),
                entries: &[
                    buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: true }),
                    buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: false }),
                    buffer_entry(2, wgpu::BufferBindingType::Uniform),
                ],
            });

    let shader = context.reduction_f64_shader();
    let pipeline_layout =
        context
            .device()
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("NumRS2 Reduction Pipeline Layout"),
                bind_group_layouts: &[Some(&bind_group_layout)],
                immediate_size: 0,
            });
    let pipeline = context
        .device()
        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NumRS2 Reduction Pipeline"),
            layout: Some(&pipeline_layout),
            module: shader,
            entry_point: Some("reduction"),
            cache: None,
            compilation_options: Default::default(),
        });

    // Uniform layout: [op selector, element count, padding, padding].
    let params = [op as u32, total_elements, 0, 0];
    let params_buffer = context
        .device()
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Reduction Op Params"),
            contents: bytemuck::cast_slice(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

    let bind_group = context
        .device()
        .create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("NumRS2 Reduction Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer().as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: partial_results_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

    context.run_compute(&pipeline, &[&bind_group], (workgroup_count, 1, 1));

    // Storage buffers cannot be mapped directly; copy into a MAP_READ buffer.
    let staging_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
        label: Some("Reduction Staging Buffer"),
        size: partial_results_size,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });
    let mut encoder = context
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Reduction Copy Encoder"),
        });
    encoder.copy_buffer_to_buffer(
        &partial_results_buffer,
        0,
        &staging_buffer,
        0,
        partial_results_size,
    );
    context.queue().submit(Some(encoder.finish()));

    let buffer_slice = staging_buffer.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        // The receiver is gone only if this function already bailed out with
        // an error; there is nobody left to notify, so ignore the send failure
        // instead of panicking inside the wgpu callback.
        let _ = tx.send(result);
    });
    context
        .device()
        .poll(wgpu::PollType::wait_indefinitely())
        .map_err(|e| {
            NumRs2Error::RuntimeError(format!(
                "GPU device poll failed during f64 reduction buffer mapping: {:?}",
                e
            ))
        })?;
    rx.recv()
        .map_err(|e| {
            NumRs2Error::RuntimeError(format!(
                "Failed to receive f64 reduction buffer mapping result: {:?}",
                e
            ))
        })?
        .map_err(|e| NumRs2Error::RuntimeError(format!("Failed to map buffer: {:?}", e)))?;

    let data = buffer_slice.get_mapped_range();
    let partial_results: &[f64] = bytemuck::cast_slice(&data);
    let final_result = match op {
        ReductionOp::Sum => partial_results.iter().sum(),
        ReductionOp::Mean => partial_results.iter().sum::<f64>() / total_elements as f64,
        ReductionOp::Max => partial_results
            .iter()
            .copied()
            .fold(f64::NEG_INFINITY, f64::max),
        ReductionOp::Min => partial_results
            .iter()
            .copied()
            .fold(f64::INFINITY, f64::min),
    };
    // The mapped view must be released before the buffer can be unmapped.
    drop(data);
    staging_buffer.unmap();
    Ok(final_result)
}
fn unary_element_wise_op<T: bytemuck::Pod + bytemuck::Zeroable>(
a: &GpuArray<T>,
op: ElementWiseOp,
) -> Result<GpuArray<T>> {
let context = a.context().clone();
let result = GpuArray::<T>::new_with_shape(a.shape(), context.clone())?;
let bind_group_layout =
context
.device()
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("NumRS2 Unary Element-wise Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let shader = if std::mem::size_of::<T>() == 4 {
context.element_wise_f32_shader()
} else {
context.element_wise_f64_shader()
};
let pipeline_layout =
context
.device()
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("NumRS2 Unary Element-wise Pipeline Layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let pipeline = context
.device()
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("NumRS2 Unary Element-wise Pipeline"),
layout: Some(&pipeline_layout),
module: shader,
entry_point: Some("unary_op"),
cache: None,
compilation_options: Default::default(),
});
let params = [op as u32, a.size() as u32, 0, 0];
let params_buffer = context
.device()
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Unary Element-wise Op Params"),
contents: bytemuck::cast_slice(¶ms),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = context
.device()
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("NumRS2 Unary Element-wise Bind Group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: a.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: result.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: params_buffer.as_entire_binding(),
},
],
});
let total_threads = a.size() as u32;
let workgroup_count = total_threads.div_ceil(WORKGROUP_SIZE);
context.run_compute(&pipeline, &[&bind_group], (workgroup_count, 1, 1));
Ok(result)
}
/// Element-wise `a + b` with NumPy-style broadcasting of the operand shapes.
///
/// # Errors
/// Returns an error if the shapes are not broadcast-compatible, and
/// propagates errors from the broadcasting dispatch helper.
pub fn broadcast_add<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Add;
    broadcast_shapes(a.shape(), b.shape())
        .and_then(|output_shape| broadcast_binary_op(a, b, &output_shape, op))
}
/// Element-wise `a * b` with NumPy-style broadcasting of the operand shapes.
///
/// # Errors
/// Returns an error if the shapes are not broadcast-compatible, and
/// propagates errors from the broadcasting dispatch helper.
pub fn broadcast_multiply<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let op = ElementWiseOp::Multiply;
    broadcast_shapes(a.shape(), b.shape())
        .and_then(|output_shape| broadcast_binary_op(a, b, &output_shape, op))
}
/// Computes the NumPy-style broadcast shape of two array shapes.
///
/// Shapes are aligned from the trailing axis, and missing leading axes count
/// as size 1. Two sizes are compatible if they are equal or either is 1; the
/// result takes the non-1 size.
///
/// # Errors
/// Returns [`NumRs2Error::ShapeMismatch`] carrying both input shapes if any
/// aligned pair of sizes is incompatible.
fn broadcast_shapes(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
    let rank = shape_a.len().max(shape_b.len());
    // Size of the axis `offset` positions from the right, or 1 if absent.
    let dim_from_right = |shape: &[usize], offset: usize| -> usize {
        if offset < shape.len() {
            shape[shape.len() - 1 - offset]
        } else {
            1
        }
    };

    let mut broadcast = Vec::with_capacity(rank);
    for offset in (0..rank).rev() {
        let lhs = dim_from_right(shape_a, offset);
        let rhs = dim_from_right(shape_b, offset);
        let merged = match (lhs, rhs) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: shape_a.to_vec(),
                    actual: shape_b.to_vec(),
                })
            }
        };
        broadcast.push(merged);
    }
    Ok(broadcast)
}
/// Applies a binary element-wise operation over operands that broadcast to
/// `output_shape`.
///
/// Only the trivial case is supported: when both operands already have the
/// same shape, this delegates to the plain element-wise kernel. Any real
/// broadcasting is reported as not implemented. `output_shape` is included in
/// that error, so the caller can see what the broadcast would have produced.
///
/// # Errors
/// * [`NumRs2Error::NotImplemented`] when the operand shapes differ.
/// * Any error from the element-wise kernel.
fn broadcast_binary_op<T: bytemuck::Pod + bytemuck::Zeroable>(
    a: &GpuArray<T>,
    b: &GpuArray<T>,
    output_shape: &[usize],
    op: ElementWiseOp,
) -> Result<GpuArray<T>> {
    if a.shape() == b.shape() {
        return element_wise_op(a, b, op);
    }
    Err(NumRs2Error::NotImplemented(format!(
        "Full broadcasting support is not yet implemented for GPU arrays (shapes {:?} and {:?} broadcast to {:?})",
        a.shape(),
        b.shape(),
        output_shape
    )))
}
/// Returns a new GPU array with the same shape as `src`, holding a
/// byte-for-byte copy of its buffer contents.
///
/// The copy is recorded on a fresh command encoder and submitted to the
/// context's queue. This function does not wait for the copy to complete.
///
/// NOTE(review): wgpu requires buffer copy sizes to be a multiple of 4 bytes.
/// For element types narrower than 4 bytes, a byte length that is not a
/// multiple of 4 would fail validation unless `GpuArray` pads its buffers —
/// confirm.
///
/// # Errors
/// Propagates any error from allocating the destination array.
pub fn copy_with_format<T: bytemuck::Pod + bytemuck::Zeroable>(
    src: &GpuArray<T>,
) -> Result<GpuArray<T>> {
    let ctx = src.context().clone();
    let dst = GpuArray::<T>::new_with_shape(src.shape(), ctx.clone())?;

    let mut encoder = ctx
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("NumRS2 Copy Encoder"),
        });
    let byte_len = (src.size() * src.element_size()) as u64;
    encoder.copy_buffer_to_buffer(src.buffer(), 0, dst.buffer(), 0, byte_len);

    let commands = encoder.finish();
    ctx.queue().submit(std::iter::once(commands));
    Ok(dst)
}
/// Overwrites every element of `array` with `value`.
///
/// A host-side buffer of `array.size()` copies of `value` is built and
/// uploaded with a single `Queue::write_buffer` call at offset 0.
///
/// # Errors
/// Always returns `Ok(())`; the `Result` is kept for API consistency.
pub fn fill<T: bytemuck::Pod + bytemuck::Zeroable + Clone>(
    array: &mut GpuArray<T>,
    value: T,
) -> Result<()> {
    let element_count = array.size();
    let host_data: Vec<T> = std::iter::repeat(value).take(element_count).collect();
    let payload: &[u8] = bytemuck::cast_slice(&host_data);
    array.context().queue().write_buffer(array.buffer(), 0, payload);
    Ok(())
}
/// Extracts the sub-array selected by one half-open `(start, end)` range per
/// dimension.
///
/// The ranges are validated, but the extraction itself is not implemented
/// yet. Valid input currently returns `NotImplemented`. No GPU work or
/// host transfer is performed; the previous version downloaded the whole
/// array only to discard it.
///
/// # Errors
/// * [`NumRs2Error::DimensionMismatch`] if `ranges.len()` differs from the
///   array's number of dimensions.
/// * [`NumRs2Error::IndexError`] if any range is empty (`start >= end`) or
///   extends past its dimension.
/// * [`NumRs2Error::NotImplemented`] for every valid request.
pub fn slice<T: bytemuck::Pod + bytemuck::Zeroable>(
    array: &GpuArray<T>,
    ranges: &[(usize, usize)],
) -> Result<GpuArray<T>> {
    let shape = array.shape();
    if ranges.len() != shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Number of slice ranges ({}) does not match array dimensions ({})",
            ranges.len(),
            shape.len()
        )));
    }
    for (i, (&(start, end), &dim)) in ranges.iter().zip(shape).enumerate() {
        if start >= end || end > dim {
            return Err(NumRs2Error::IndexError(format!(
                "Invalid range [{}..{}] for dimension {} with size {}",
                start, end, i, dim
            )));
        }
    }
    Err(NumRs2Error::NotImplemented(
        "GPU array slicing is not yet fully implemented".to_string(),
    ))
}