use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
// WGSL sources for the element-wise `where` kernels, embedded at compile time.
// Naming is WHERE_COND_<cond dtype>_<output dtype>, matching the
// `(cond_dtype, out_dtype)` pairs selected in `where_shader_info`.
const WHERE_COND_F32_F32: &str = include_str!("where_cond_f32_f32.wgsl");
const WHERE_COND_F32_I32: &str = include_str!("where_cond_f32_i32.wgsl");
const WHERE_COND_F32_U32: &str = include_str!("where_cond_f32_u32.wgsl");
const WHERE_COND_I32_F32: &str = include_str!("where_cond_i32_f32.wgsl");
const WHERE_COND_I32_I32: &str = include_str!("where_cond_i32_i32.wgsl");
const WHERE_COND_I32_U32: &str = include_str!("where_cond_i32_u32.wgsl");
const WHERE_COND_U32_F32: &str = include_str!("where_cond_u32_f32.wgsl");
const WHERE_COND_U32_I32: &str = include_str!("where_cond_u32_i32.wgsl");
const WHERE_COND_U32_U32: &str = include_str!("where_cond_u32_u32.wgsl");
// WGSL sources for the broadcasting `where` kernels, embedded at compile time.
// Naming is WHERE_BC_<cond dtype>_<output dtype>, matching the pairs selected in
// `where_broadcast_shader_info`. These kernels are launched by
// `launch_where_broadcast_op`, which additionally binds stride/shape buffers.
const WHERE_BC_F32_F32: &str = include_str!("where_broadcast_cond_f32_f32.wgsl");
const WHERE_BC_F32_I32: &str = include_str!("where_broadcast_cond_f32_i32.wgsl");
const WHERE_BC_F32_U32: &str = include_str!("where_broadcast_cond_f32_u32.wgsl");
const WHERE_BC_I32_F32: &str = include_str!("where_broadcast_cond_i32_f32.wgsl");
const WHERE_BC_I32_I32: &str = include_str!("where_broadcast_cond_i32_i32.wgsl");
const WHERE_BC_I32_U32: &str = include_str!("where_broadcast_cond_i32_u32.wgsl");
const WHERE_BC_U32_F32: &str = include_str!("where_broadcast_cond_u32_f32.wgsl");
const WHERE_BC_U32_I32: &str = include_str!("where_broadcast_cond_u32_i32.wgsl");
const WHERE_BC_U32_U32: &str = include_str!("where_broadcast_cond_u32_u32.wgsl");
/// Selects the element-wise `where_cond` shader for a condition/output dtype pair.
///
/// Returns `(shader_source, module_key, entry_point)`. For every supported pair
/// the module key and the entry-point name are the same string.
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when either dtype is outside
/// `F32` / `I32` / `U32`. The reported dtype is the one that is actually
/// unsupported: `cond_dtype` if it is outside that set, otherwise `out_dtype`.
fn where_shader_info(
    cond_dtype: DType,
    out_dtype: DType,
) -> Result<(&'static str, &'static str, &'static str)> {
    let (shader, name) = match (cond_dtype, out_dtype) {
        (DType::F32, DType::F32) => (WHERE_COND_F32_F32, "where_cond_f32_f32"),
        (DType::F32, DType::I32) => (WHERE_COND_F32_I32, "where_cond_f32_i32"),
        (DType::F32, DType::U32) => (WHERE_COND_F32_U32, "where_cond_f32_u32"),
        (DType::I32, DType::F32) => (WHERE_COND_I32_F32, "where_cond_i32_f32"),
        (DType::I32, DType::I32) => (WHERE_COND_I32_I32, "where_cond_i32_i32"),
        (DType::I32, DType::U32) => (WHERE_COND_I32_U32, "where_cond_i32_u32"),
        (DType::U32, DType::F32) => (WHERE_COND_U32_F32, "where_cond_u32_f32"),
        (DType::U32, DType::I32) => (WHERE_COND_U32_I32, "where_cond_u32_i32"),
        (DType::U32, DType::U32) => (WHERE_COND_U32_U32, "where_cond_u32_u32"),
        _ => {
            // Blame the dtype that is really unsupported: if the condition dtype
            // is one we handle, the output dtype must be the culprit.
            let dtype = if matches!(cond_dtype, DType::F32 | DType::I32 | DType::U32) {
                out_dtype
            } else {
                cond_dtype
            };
            return Err(Error::UnsupportedDType {
                dtype,
                op: "where_cond (WebGPU)",
            });
        }
    };
    // The module key and the entry point share the shader's name.
    Ok((shader, name, name))
}
/// Selects the broadcasting `where_broadcast_cond` shader for a condition/output
/// dtype pair.
///
/// Returns `(shader_source, module_key, entry_point)`. For every supported pair
/// the module key and the entry-point name are the same string.
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when either dtype is outside
/// `F32` / `I32` / `U32`. The reported dtype is the one that is actually
/// unsupported: `cond_dtype` if it is outside that set, otherwise `out_dtype`.
fn where_broadcast_shader_info(
    cond_dtype: DType,
    out_dtype: DType,
) -> Result<(&'static str, &'static str, &'static str)> {
    let (shader, name) = match (cond_dtype, out_dtype) {
        (DType::F32, DType::F32) => (WHERE_BC_F32_F32, "where_broadcast_cond_f32_f32"),
        (DType::F32, DType::I32) => (WHERE_BC_F32_I32, "where_broadcast_cond_f32_i32"),
        (DType::F32, DType::U32) => (WHERE_BC_F32_U32, "where_broadcast_cond_f32_u32"),
        (DType::I32, DType::F32) => (WHERE_BC_I32_F32, "where_broadcast_cond_i32_f32"),
        (DType::I32, DType::I32) => (WHERE_BC_I32_I32, "where_broadcast_cond_i32_i32"),
        (DType::I32, DType::U32) => (WHERE_BC_I32_U32, "where_broadcast_cond_i32_u32"),
        (DType::U32, DType::F32) => (WHERE_BC_U32_F32, "where_broadcast_cond_u32_f32"),
        (DType::U32, DType::I32) => (WHERE_BC_U32_I32, "where_broadcast_cond_u32_i32"),
        (DType::U32, DType::U32) => (WHERE_BC_U32_U32, "where_broadcast_cond_u32_u32"),
        _ => {
            // Blame the dtype that is really unsupported: if the condition dtype
            // is one we handle, the output dtype must be the culprit.
            let dtype = if matches!(cond_dtype, DType::F32 | DType::I32 | DType::U32) {
                out_dtype
            } else {
                cond_dtype
            };
            return Err(Error::UnsupportedDType {
                dtype,
                op: "where_broadcast_cond (WebGPU)",
            });
        }
    };
    // The module key and the entry point share the shader's name.
    Ok((shader, name, name))
}
/// Launches the element-wise `where` kernel with an `F32` condition buffer.
///
/// Convenience wrapper around [`launch_where_generic_op`]: the condition dtype
/// is fixed to `F32` and `dtype` is forwarded as the output dtype.
///
/// # Errors
///
/// Propagates any error from [`launch_where_generic_op`], e.g. when no shader
/// exists for `(F32, dtype)`.
// Each GPU binding is passed as its own argument, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn launch_where_op(
    cache: &PipelineCache,
    queue: &Queue,
    cond: &Buffer,
    x: &Buffer,
    y: &Buffer,
    out: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    // This entry point always interprets the condition buffer as f32.
    let cond_dtype = DType::F32;
    let out_dtype = dtype;
    launch_where_generic_op(
        cache,
        queue,
        cond,
        x,
        y,
        out,
        params_buffer,
        numel,
        cond_dtype,
        out_dtype,
    )
}
/// Launches the element-wise `where` kernel for a condition/output dtype pair.
///
/// Binds `cond`, `x`, `y`, `out` and `params_buffer` (in that order) to bind
/// group 0. It then dispatches `workgroup_count(numel)` workgroups along X and
/// submits the command buffer to `queue`. The device is not polled, so the work
/// may still be in flight when this returns.
///
/// When `numel` is zero nothing is encoded or submitted. There is no work to do,
/// and buffers backing empty tensors may be zero-sized, which wgpu rejects as
/// bind-group entries.
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] if no shader exists for
/// `(cond_dtype, out_dtype)`. This check runs even when `numel` is zero.
// Each GPU binding is passed as its own argument, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn launch_where_generic_op(
    cache: &PipelineCache,
    queue: &Queue,
    cond: &Buffer,
    x: &Buffer,
    y: &Buffer,
    out: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    cond_dtype: DType,
    out_dtype: DType,
) -> Result<()> {
    // Validate the dtype pair before the empty-input shortcut so unsupported
    // combinations are always reported.
    let (shader, module_key, entry_point) = where_shader_info(cond_dtype, out_dtype)?;
    if numel == 0 {
        return Ok(());
    }
    let module = cache.get_or_create_module(module_key, shader);
    // Four storage bindings (cond, x, y, out) plus one uniform (params).
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout);
    let bind_group = cache.create_bind_group(&layout, &[cond, x, y, out, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("where_cond"),
        });
    {
        // The pass borrows the encoder mutably; end it before `finish()`.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("where_cond"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Launches the broadcasting `where` kernel for a condition/output dtype pair.
///
/// Binds, in order, `cond`, `x`, `y`, `out`, `cond_strides`, `x_strides`,
/// `y_strides`, `out_shape` and `params_buffer` to bind group 0. It then
/// dispatches `workgroup_count(numel)` workgroups along X and submits the
/// command buffer to `queue`. The device is not polled, so the work may still be
/// in flight when this returns.
///
/// `numel` sizes the dispatch; presumably it is the element count of the
/// broadcast output (TODO confirm against callers). When it is zero nothing is
/// encoded or submitted. There is no work to do, and buffers backing empty
/// tensors may be zero-sized, which wgpu rejects as bind-group entries.
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] if no shader exists for
/// `(cond_dtype, out_dtype)`. This check runs even when `numel` is zero.
// Each GPU binding is passed as its own argument, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn launch_where_broadcast_op(
    cache: &PipelineCache,
    queue: &Queue,
    cond: &Buffer,
    x: &Buffer,
    y: &Buffer,
    out: &Buffer,
    cond_strides: &Buffer,
    x_strides: &Buffer,
    y_strides: &Buffer,
    out_shape: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    cond_dtype: DType,
    out_dtype: DType,
) -> Result<()> {
    // Validate the dtype pair before the empty-input shortcut so unsupported
    // combinations are always reported.
    let (shader, module_key, entry_point) = where_broadcast_shader_info(cond_dtype, out_dtype)?;
    if numel == 0 {
        return Ok(());
    }
    let module = cache.get_or_create_module(module_key, shader);
    // Eight storage bindings (4 data buffers + 3 stride buffers + out_shape)
    // plus one uniform (params).
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 8,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout);
    let bind_group = cache.create_bind_group(
        &layout,
        &[
            cond,
            x,
            y,
            out,
            cond_strides,
            x_strides,
            y_strides,
            out_shape,
            params_buffer,
        ],
    );
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("where_broadcast"),
        });
    {
        // The pass borrows the encoder mutably; end it before `finish()`.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("where_broadcast"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}