use super::cache::get_or_create_client;
fn wgpu_err(e: super::device::WgpuError) -> crate::error::Error {
crate::error::Error::Backend(format!("WebGPU client error: {e}"))
}
use super::client::WgpuClient;
use super::device::WgpuDevice;
use super::shaders;
use crate::runtime::{Allocator, NoOpGraph, Runtime, RuntimeClient};
use std::time::Duration;
#[derive(Clone, Debug, Default)]
pub struct WgpuRuntime;
impl Runtime for WgpuRuntime {
type Device = WgpuDevice;
type Client = WgpuClient;
type Allocator = super::WgpuAllocator;
type Graph = NoOpGraph;
type RawHandle = super::WgpuRawHandle;
type DType = crate::dtype::DType;
fn name() -> &'static str {
"wgpu"
}
fn supports_graph_capture() -> bool {
false }
fn capture_graph<F, T>(client: &Self::Client, f: F) -> crate::error::Result<(Self::Graph, T)>
where
F: FnOnce(&Self::Client) -> crate::error::Result<T>,
{
let result = f(client)?;
Ok((NoOpGraph, result))
}
fn allocate(size_bytes: usize, device: &Self::Device) -> crate::error::Result<u64> {
if size_bytes == 0 {
return Ok(0);
}
let client = get_or_create_client(device).map_err(wgpu_err)?;
client.allocator.allocate(size_bytes)
}
fn deallocate(ptr: u64, size_bytes: usize, device: &Self::Device) {
if ptr == 0 {
return;
}
if let Ok(client) = get_or_create_client(device) {
client.allocator.deallocate(ptr, size_bytes);
}
}
fn copy_to_device(src: &[u8], dst: u64, device: &Self::Device) -> crate::error::Result<()> {
if src.is_empty() || dst == 0 {
return Ok(());
}
let client = get_or_create_client(device).map_err(wgpu_err)?;
let buffer = super::client::get_buffer(dst).ok_or_else(|| {
crate::error::Error::Backend("Buffer not found for copy_to_device".into())
})?;
client.queue.write_buffer(&buffer, 0, src);
client.synchronize();
Ok(())
}
fn copy_from_device(
src: u64,
dst: &mut [u8],
device: &Self::Device,
) -> crate::error::Result<()> {
if dst.is_empty() || src == 0 {
return Ok(());
}
let client = get_or_create_client(device).map_err(wgpu_err)?;
let buffer = super::client::get_buffer(src).ok_or_else(|| {
crate::error::Error::Backend("Buffer not found for copy_from_device".into())
})?;
let staging = client.create_staging_buffer("copy_staging", dst.len() as u64);
let mut encoder =
client
.wgpu_device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("copy_from_device"),
});
encoder.copy_buffer_to_buffer(&buffer, 0, &staging, 0, dst.len() as u64);
client.submit_and_wait(encoder);
let slice = staging.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = sender.send(result);
});
client
.wgpu_device
.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: Some(Duration::from_secs(60)),
})
.map_err(|e| {
crate::error::Error::Backend(format!(
"GPU poll failed during copy_from_device: {e}"
))
})?;
let map_result = receiver.recv().map_err(|_| {
crate::error::Error::Backend(
"map_async callback was not invoked during copy_from_device".into(),
)
})?;
map_result.map_err(|e| {
crate::error::Error::Backend(format!("map_async failed during copy_from_device: {e}"))
})?;
{
let data = slice.get_mapped_range();
dst.copy_from_slice(&data[..dst.len()]);
}
staging.unmap();
Ok(())
}
fn copy_within_device(
src: u64,
dst: u64,
size_bytes: usize,
device: &Self::Device,
) -> crate::error::Result<()> {
if size_bytes == 0 || src == 0 || dst == 0 {
return Ok(());
}
let client = get_or_create_client(device).map_err(wgpu_err)?;
let src_buffer = super::client::get_buffer(src)
.ok_or_else(|| crate::error::Error::Backend("Source buffer not found".into()))?;
let dst_buffer = super::client::get_buffer(dst)
.ok_or_else(|| crate::error::Error::Backend("Destination buffer not found".into()))?;
let mut encoder =
client
.wgpu_device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("copy_within_device"),
});
encoder.copy_buffer_to_buffer(&src_buffer, 0, &dst_buffer, 0, size_bytes as u64);
client.submit_and_wait(encoder);
Ok(())
}
fn copy_strided(
src_handle: u64,
src_byte_offset: usize,
dst_handle: u64,
shape: &[usize],
strides: &[isize],
elem_size: usize,
device: &Self::Device,
) -> crate::error::Result<()> {
if src_handle == 0 || dst_handle == 0 || shape.is_empty() {
return Ok(());
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(());
}
if shape.len() > 8 {
return Err(crate::error::Error::BackendLimitation {
backend: "wgpu",
operation: "copy_strided",
reason: format!("rank {} exceeds maximum of 8 dimensions", shape.len()),
});
}
if !elem_size.is_multiple_of(4) {
return Err(crate::error::Error::BackendLimitation {
backend: "wgpu",
operation: "copy_strided",
reason: format!(
"element size {} is not a multiple of 4 bytes; sub-4-byte dtypes (F16, Bool) require byte-level addressing not supported by the u32-based copy shader",
elem_size
),
});
}
if !src_byte_offset.is_multiple_of(4) {
return Err(crate::error::Error::BackendLimitation {
backend: "wgpu",
operation: "copy_strided",
reason: format!(
"source byte offset {} is not 4-byte aligned",
src_byte_offset
),
});
}
let client = get_or_create_client(device).map_err(wgpu_err)?;
let src_buffer = super::client::get_buffer(src_handle).ok_or_else(|| {
crate::error::Error::Backend("Source buffer not found for copy_strided".into())
})?;
let dst_buffer = super::client::get_buffer(dst_handle).ok_or_else(|| {
crate::error::Error::Backend("Destination buffer not found for copy_strided".into())
})?;
let module = client
.wgpu_device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("strided_copy"),
source: wgpu::ShaderSource::Wgsl(shaders::copy::STRIDED_COPY_SHADER.into()),
});
let bind_group_layout =
client
.wgpu_device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("strided_copy_layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout =
client
.wgpu_device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("strided_copy_pipeline_layout"),
bind_group_layouts: &[&bind_group_layout],
immediate_size: 0,
});
let pipeline =
client
.wgpu_device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("strided_copy_pipeline"),
layout: Some(&pipeline_layout),
module: &module,
entry_point: Some("strided_copy"),
compilation_options: Default::default(),
cache: None,
});
let elem_size_units = (elem_size / 4) as u32;
let src_offset_units = (src_byte_offset / 4) as u32;
let mut shape_arr = [0u32; 8];
let mut strides_arr = [0i32; 8];
for (i, &s) in shape.iter().enumerate().take(8) {
shape_arr[i] = s as u32;
}
for (i, &s) in strides.iter().enumerate().take(8) {
strides_arr[i] = s as i32;
}
let mut params_data = Vec::with_capacity(80);
params_data.extend_from_slice(&(numel as u32).to_le_bytes());
params_data.extend_from_slice(&(shape.len() as u32).to_le_bytes());
params_data.extend_from_slice(&elem_size_units.to_le_bytes());
params_data.extend_from_slice(&src_offset_units.to_le_bytes());
for s in shape_arr {
params_data.extend_from_slice(&s.to_le_bytes());
}
for s in strides_arr {
params_data.extend_from_slice(&s.to_le_bytes());
}
let params_buffer = client.wgpu_device.create_buffer(&wgpu::BufferDescriptor {
label: Some("strided_copy_params"),
size: params_data.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
client.queue.write_buffer(¶ms_buffer, 0, ¶ms_data);
let bind_group = client
.wgpu_device
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("strided_copy_bind_group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: src_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: dst_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: params_buffer.as_entire_binding(),
},
],
});
let workgroups = shaders::workgroup_count(numel);
let mut encoder =
client
.wgpu_device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("strided_copy"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("strided_copy_pass"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(workgroups, 1, 1);
}
client.submit_and_wait(encoder);
Ok(())
}
fn default_device() -> Self::Device {
WgpuDevice::new(0)
}
fn default_client(device: &Self::Device) -> Self::Client {
get_or_create_client(device).expect("Failed to create WebGPU client")
}
fn raw_handle(client: &Self::Client) -> &Self::RawHandle {
&client.raw_handle
}
}
pub fn wgpu_device() -> WgpuDevice {
WgpuDevice::new(0)
}
pub fn wgpu_device_id(index: usize) -> WgpuDevice {
WgpuDevice::new(index)
}
pub fn is_wgpu_available() -> bool {
super::device::query_adapter_info_blocking(0).is_ok()
}