use crate::gpu::{buffer::GpuBuffer, ops::BinaryOp};
use crate::{Result, TensorError};
use std::sync::Arc;
use wgpu::util::DeviceExt;
/// Memory-access pattern used by the binary-op compute kernels.
///
/// The variant chosen drives both the shader entry point and the workgroup
/// dispatch geometry computed in `select_shader_and_config`.
#[derive(Debug, Clone, Copy)]
pub enum CoalescingStrategy {
    /// Flat 1-D indexing, one element per thread (dispatched in groups of 64).
    Linear,
    /// 2-D tiled access for matrix-shaped tensors (16x16 tiles; 16x8 for 64-bit types).
    Tiled2D,
    /// Each thread handles a small vector of elements (4 for 32-bit types, 2 for 64-bit).
    Vectorized,
    /// Layout dispatched in 32x8 groups; presumably avoids shared-memory bank
    /// conflicts in the shader — confirm against the WGSL source.
    BankConflictFree,
    /// Fallback pattern dispatched in flat groups of 256; the shader side is
    /// expected to adapt — confirm against the WGSL source.
    Adaptive,
}
/// Shape-derived properties of a tensor, used to pick a coalescing strategy
/// and size the compute dispatch.
#[derive(Debug)]
pub struct TensorCharacteristics {
    // Product of all shape dimensions.
    pub total_elements: usize,
    // Innermost (fastest-varying) dimension.
    pub width: usize,
    // Second-to-innermost dimension (1 for 1-D shapes).
    pub height: usize,
    // Product of all remaining leading dimensions (1 for 1-D/2-D shapes).
    pub depth: usize,
    // NOTE(review): always set to true by `from_shape`; strides are never
    // inspected, so non-contiguous tensors are not detected here.
    pub is_contiguous: bool,
    // Size in bytes of one element of the tensor's dtype.
    pub dtype_size: usize,
}
impl TensorCharacteristics {
    /// Derives dispatch-relevant characteristics from a tensor shape.
    ///
    /// The last dimension maps to `width`, the second-to-last to `height`,
    /// and all remaining leading dimensions are folded into `depth`.
    /// An empty shape is treated as a scalar (1 x 1 x 1); previously this
    /// case panicked on `shape.len() - 2` underflowing in the catch-all arm.
    pub fn from_shape(shape: &[usize], dtype_size: usize) -> Self {
        // Product over an empty iterator is 1, so a scalar shape yields one element.
        let total_elements = shape.iter().product();
        let (width, height, depth) = match shape.len() {
            0 => (1, 1, 1),
            1 => (shape[0], 1, 1),
            2 => (shape[1], shape[0], 1),
            3 => (shape[2], shape[1], shape[0]),
            _ => {
                let width = shape[shape.len() - 1];
                let height = shape[shape.len() - 2];
                let depth = shape[..shape.len() - 2].iter().product();
                (width, height, depth)
            }
        };
        Self {
            total_elements,
            width,
            height,
            depth,
            // NOTE(review): contiguity is assumed rather than derived from
            // strides — callers with non-contiguous layouts must not rely on it.
            is_contiguous: true,
            dtype_size,
        }
    }

    /// Picks a coalescing strategy from simple size heuristics, checked in
    /// priority order: very large tensors vectorize, sufficiently 2-D tensors
    /// tile, small or effectively 1-D tensors stay linear, and everything
    /// else falls back to the adaptive kernel.
    pub fn select_optimal_strategy(&self) -> CoalescingStrategy {
        if self.total_elements > 1_000_000 {
            return CoalescingStrategy::Vectorized;
        }
        if self.width > 16 && self.height > 16 {
            return CoalescingStrategy::Tiled2D;
        }
        if self.total_elements < 1024 || self.height == 1 {
            return CoalescingStrategy::Linear;
        }
        CoalescingStrategy::Adaptive
    }
}
/// Executes one elementwise binary GPU op over `input_a` and `input_b` with
/// the requested memory-coalescing `strategy`, returning a new output buffer.
///
/// Pipeline: allocate output storage buffer -> upload `[width, height, depth, 1]`
/// shape metadata -> pick shader source, entry point, and dispatch geometry via
/// `select_shader_and_config` -> create pipeline and bind group
/// (0: input a, 1: input b, 2: output, 3: shape) -> dispatch -> submit.
///
/// Returns immediately after `queue.submit` without waiting for the GPU;
/// synchronization/readback is the caller's responsibility.
pub fn execute_binary_op_coalesced<T>(
    input_a: &GpuBuffer<T>,
    input_b: &GpuBuffer<T>,
    operation: BinaryOp,
    output_len: usize,
    shape: &[usize],
    strategy: CoalescingStrategy,
) -> Result<GpuBuffer<T>>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    // Device and queue are taken from `input_a`; `input_b` is assumed to live
    // on the same device — TODO confirm callers guarantee this.
    let device = &input_a.device;
    let queue = &input_a.queue;
    // STORAGE so the kernel can write it; COPY_SRC so it can later be copied out.
    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Binary Op Coalesced Output"),
        size: (output_len * std::mem::size_of::<T>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    let characteristics = TensorCharacteristics::from_shape(shape, std::mem::size_of::<T>());
    // Shape uploaded as four u32s; the trailing 1 is presumably padding to a
    // 16-byte stride for the WGSL struct — confirm against the shader source.
    let shape_metadata = [
        characteristics.width as u32,
        characteristics.height as u32,
        characteristics.depth as u32,
        1u32,
    ];
    let shape_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Shape Metadata"),
        contents: bytemuck::cast_slice(&shape_metadata),
        usage: wgpu::BufferUsages::STORAGE,
    });
    // Shader source is picked by element type; entry point and workgroup
    // counts by (strategy, operation, type).
    let (shader_source, dispatch_config) =
        select_shader_and_config::<T>(strategy, &characteristics, operation);
    let compute_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Binary Op Coalesced Shader"),
        source: wgpu::ShaderSource::Wgsl(shader_source.into()),
    });
    // Layout: bindings 0/1 read-only inputs, 2 writable output, 3 read-only
    // shape metadata — must mirror the bind group below and the WGSL bindings.
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("Binary Op Coalesced Bind Group Layout"),
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 2,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 3,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });
    let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Binary Op Coalesced Pipeline Layout"),
        bind_group_layouts: &[Some(&bind_group_layout)],
        // No immediate (push-constant-style) data is used by these kernels.
        immediate_size: 0,
    });
    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("Binary Op Coalesced Pipeline"),
        layout: Some(&compute_pipeline_layout),
        module: &compute_shader,
        entry_point: Some(dispatch_config.entry_point),
        cache: None,
        compilation_options: Default::default(),
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Binary Op Coalesced Bind Group"),
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: input_a.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: input_b.buffer().as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: output_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: shape_buffer.as_entire_binding(),
            },
        ],
    });
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("Binary Op Coalesced Encoder"),
    });
    // Scope the pass so the encoder borrow ends before `encoder.finish()`.
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("Binary Op Coalesced Pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&compute_pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        compute_pass.dispatch_workgroups(
            dispatch_config.workgroups_x,
            dispatch_config.workgroups_y,
            dispatch_config.workgroups_z,
        );
    }
    queue.submit(Some(encoder.finish()));
    // Wrap the raw wgpu buffer in the project's GpuBuffer, sharing the
    // device/queue handles of the first input.
    Ok(GpuBuffer::from_wgpu_buffer(
        output_buffer,
        Arc::clone(&input_a.device),
        Arc::clone(&input_a.queue),
        input_a.device_enum().clone(),
        output_len,
    ))
}
/// The WGSL entry point plus 3-D workgroup counts for one compute dispatch.
#[derive(Debug)]
struct DispatchConfig {
    // Name of the entry-point function inside the selected shader module.
    entry_point: &'static str,
    // Arguments passed to `ComputePass::dispatch_workgroups(x, y, z)`.
    workgroups_x: u32,
    workgroups_y: u32,
    workgroups_z: u32,
}
fn select_shader_and_config<T>(
strategy: CoalescingStrategy,
characteristics: &TensorCharacteristics,
operation: BinaryOp,
) -> (&'static str, DispatchConfig) {
let type_name = std::any::type_name::<T>();
let entry_point = match operation {
BinaryOp::Add => match strategy {
CoalescingStrategy::Linear => "add_op",
CoalescingStrategy::Tiled2D => match type_name {
"f32" => "add_op_coalesced",
"f64" => "add_op_coalesced_f64",
"i32" => "add_op_coalesced_i32",
"i64" => "add_op_coalesced_i64",
"u32" => "add_op_coalesced_u32",
"u64" => "add_op_coalesced_u64",
_ => "add_op_coalesced", },
CoalescingStrategy::Vectorized => match type_name {
"f32" => "add_op_vectorized",
"f64" => "add_op_vectorized_f64",
"i32" => "add_op_vectorized_i32",
"i64" => "add_op_vectorized_i64",
"u32" => "add_op_vectorized_u32",
"u64" => "add_op_vectorized_u64",
_ => "add_op_vectorized", },
CoalescingStrategy::BankConflictFree => "add_op_bank_conflict_free",
CoalescingStrategy::Adaptive => "add_op_adaptive",
},
BinaryOp::Sub => match strategy {
CoalescingStrategy::Tiled2D => match type_name {
"f32" => "sub_op_coalesced",
"f64" => "sub_op_coalesced_f64",
"i32" => "sub_op_coalesced_i32",
"i64" => "sub_op_coalesced_i64",
"u32" => "sub_op_coalesced_u32",
"u64" => "sub_op_coalesced_u64",
_ => "sub_op_coalesced", },
_ => "sub_op", },
BinaryOp::Mul => match strategy {
CoalescingStrategy::Tiled2D => match type_name {
"f32" => "mul_op_coalesced",
"f64" => "mul_op_coalesced_f64",
"i32" => "mul_op_coalesced_i32",
"i64" => "mul_op_coalesced_i64",
"u32" => "mul_op_coalesced_u32",
"u64" => "mul_op_coalesced_u64",
_ => "mul_op_coalesced", },
_ => "mul_op", },
BinaryOp::Div => match strategy {
CoalescingStrategy::Tiled2D => match type_name {
"f32" => "div_op_coalesced",
"f64" => "div_op_coalesced_f64",
"i32" => "div_op_coalesced_i32",
"i64" => "div_op_coalesced_i64",
"u32" => "div_op_coalesced_u32",
"u64" => "div_op_coalesced_u64",
_ => "div_op_coalesced", },
_ => "div_op", },
BinaryOp::Pow => match strategy {
CoalescingStrategy::Tiled2D => match type_name {
"f32" => "pow_op_coalesced",
"f64" => "pow_op_coalesced_f64",
"i32" => "pow_op_coalesced_i32",
"i64" => "pow_op_coalesced_i64",
"u32" => "pow_op_coalesced_u32",
"u64" => "pow_op_coalesced_u64",
_ => "pow_op_coalesced", },
_ => "pow_op", },
BinaryOp::PReLU => match strategy {
CoalescingStrategy::Tiled2D => match type_name {
"f32" => "prelu_op_coalesced",
"f64" => "prelu_op_coalesced_f64",
"i32" => "prelu_op_coalesced_i32",
"i64" => "prelu_op_coalesced_i64",
"u32" => "prelu_op_coalesced_u32",
"u64" => "prelu_op_coalesced_u64",
_ => "prelu_op_coalesced", },
_ => "prelu_op", },
_ => "add_op", };
let dispatch_config = match strategy {
CoalescingStrategy::Linear => DispatchConfig {
entry_point,
workgroups_x: ((characteristics.total_elements + 63) / 64) as u32,
workgroups_y: 1,
workgroups_z: 1,
},
CoalescingStrategy::Tiled2D => {
let (tile_x, tile_y) = match type_name {
"f64" | "i64" | "u64" => (16, 8), _ => (16, 16), };
DispatchConfig {
entry_point,
workgroups_x: ((characteristics.width + tile_x - 1) / tile_x) as u32,
workgroups_y: ((characteristics.height + tile_y - 1) / tile_y) as u32,
workgroups_z: characteristics.depth.max(1) as u32,
}
}
CoalescingStrategy::Vectorized => {
let vector_size = match type_name {
"f64" | "i64" | "u64" => 2, _ => 4, };
DispatchConfig {
entry_point,
workgroups_x: ((characteristics.total_elements + (64 * vector_size - 1))
/ (64 * vector_size)) as u32,
workgroups_y: 1,
workgroups_z: 1,
}
}
CoalescingStrategy::BankConflictFree => DispatchConfig {
entry_point,
workgroups_x: ((characteristics.width + 31) / 32) as u32,
workgroups_y: ((characteristics.height + 7) / 8) as u32,
workgroups_z: 1,
},
CoalescingStrategy::Adaptive => DispatchConfig {
entry_point,
workgroups_x: ((characteristics.total_elements + 255) / 256) as u32,
workgroups_y: 1,
workgroups_z: 1,
},
};
let shader_source = match type_name {
"f32" => include_str!("shaders/binary_ops_coalesced.wgsl"),
"f64" => include_str!("shaders/binary_ops_coalesced_f64.wgsl"),
"i32" => include_str!("shaders/binary_ops_coalesced_i32.wgsl"),
"i64" => include_str!("shaders/binary_ops_coalesced_i64.wgsl"),
"u32" => include_str!("shaders/binary_ops_coalesced_u32.wgsl"),
"u64" => include_str!("shaders/binary_ops_coalesced_u64.wgsl"),
_ => include_str!("shaders/binary_ops_coalesced.wgsl"), };
(shader_source, dispatch_config)
}
/// Empirically benchmarks every coalescing strategy by executing the binary
/// op ten times per strategy and returns the one with the lowest wall time.
///
/// NOTE(review): `Instant::elapsed` is taken right after the submits, with no
/// device poll or readback, so this mostly measures CPU-side encode/submit
/// cost rather than GPU kernel time — confirm whether that is the intent.
pub fn benchmark_coalescing_strategies<T>(
    input_a: &GpuBuffer<T>,
    input_b: &GpuBuffer<T>,
    operation: BinaryOp,
    output_len: usize,
    shape: &[usize],
) -> Result<CoalescingStrategy>
where
    T: bytemuck::Pod + bytemuck::Zeroable + Clone + Send + Sync + 'static,
{
    let strategies = [
        CoalescingStrategy::Linear,
        CoalescingStrategy::Tiled2D,
        CoalescingStrategy::Vectorized,
        CoalescingStrategy::BankConflictFree,
        CoalescingStrategy::Adaptive,
    ];
    let mut best_strategy = CoalescingStrategy::Linear;
    let mut best_time = std::time::Duration::MAX;
    // CoalescingStrategy is Copy, so iterate the array by value.
    for strategy in strategies {
        let start = std::time::Instant::now();
        for _ in 0..10 {
            // The returned output buffer is dropped immediately; only the
            // execution cost matters here. `?` aborts the benchmark on error.
            execute_binary_op_coalesced(input_a, input_b, operation, output_len, shape, strategy)?;
        }
        let elapsed = start.elapsed();
        if elapsed < best_time {
            best_time = elapsed;
            best_strategy = strategy;
        }
    }
    Ok(best_strategy)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: strategy chosen for a shape with 4-byte (f32-sized) elements.
    fn strategy_for(shape: &[usize]) -> CoalescingStrategy {
        TensorCharacteristics::from_shape(shape, 4).select_optimal_strategy()
    }

    #[test]
    fn test_tensor_characteristics_strategy_selection() {
        // Small 1-D tensor -> Linear.
        assert!(matches!(strategy_for(&[100]), CoalescingStrategy::Linear));
        // 4M elements exceeds the vectorization threshold -> Vectorized.
        assert!(matches!(
            strategy_for(&[2000, 2000]),
            CoalescingStrategy::Vectorized
        ));
        // Both dims comfortably above 16 -> Tiled2D.
        assert!(matches!(
            strategy_for(&[100, 100]),
            CoalescingStrategy::Tiled2D
        ));
    }

    #[test]
    fn test_dispatch_config_calculation() {
        let characteristics = TensorCharacteristics::from_shape(&[64, 64], 4);
        let (_shader, cfg) = select_shader_and_config::<f32>(
            CoalescingStrategy::Tiled2D,
            &characteristics,
            BinaryOp::Add,
        );
        // 64 / 16 = 4 tiles along each axis, with a single depth slice.
        assert_eq!(
            (cfg.workgroups_x, cfg.workgroups_y, cfg.workgroups_z),
            (4, 4, 1)
        );
    }
}