use std::collections::HashMap;
// Arc/Mutex are used by both the real and the stub backend (struct fields,
// shader cache, memory pool), so this import must not be feature-gated.
use std::sync::{Arc, Mutex};
use crate::gpu::{GpuBufferImpl, GpuCompilerImpl, GpuContextImpl, GpuError, GpuKernelImpl};
#[cfg(feature = "wgpu_backend")]
#[allow(unused_imports)]
use wgpu::{
util::DeviceExt, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, Buffer,
BufferBindingType, BufferDescriptor, BufferUsages, ComputePipeline, Device, DeviceDescriptor,
Features, Instance, InstanceDescriptor, Limits, PowerPreference, Queue, RequestAdapterOptions,
ShaderModuleDescriptor, ShaderSource, ShaderStages, StorageTextureAccess, TextureFormat,
TextureSampleType, TextureViewDimension,
};
// Opaque stand-in handles used when the `wgpu_backend` feature is disabled.
// They let the rest of the module use one set of type names in both builds;
// the pointer values are synthetic and are never dereferenced.
#[cfg(not(feature = "wgpu_backend"))]
type WgpuDevice = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuQueue = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuBuffer = *mut std::ffi::c_void;
#[cfg(not(feature = "wgpu_backend"))]
type WgpuComputePipeline = *mut std::ffi::c_void;
/// WGSL compute shader implementing one Adam optimizer step over flat f32
/// arrays. One invocation updates one parameter; dispatch with
/// ceil(n / 64) workgroups.
///
/// Fixed: the string previously contained a Rust `#[allow(dead_code)]`
/// attribute line before `fn adam_update`, which is not valid WGSL and would
/// make shader-module creation fail.
#[allow(dead_code)]
const ADAM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read_write> params: array<f32>;
@group(0) @binding(1) var<storage, read> grads: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;
struct AdamUniforms {
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
bias_correction1: f32,
bias_correction2: f32,
n: u32,
};
@group(0) @binding(4) var<uniform> uniforms: AdamUniforms;
@compute @workgroup_size(64)
fn adam_update(@builtin(global_invocation_id) global_id: vec3<u32>) {
let idx = global_id.x;
if (idx >= uniforms.n) {
return;
}
var grad = grads[idx];
// Apply weight decay
if (uniforms.weight_decay > 0.0) {
grad += uniforms.weight_decay * params[idx];
}
// Update biased first moment estimate
m[idx] = uniforms.beta1 * m[idx] + (1.0 - uniforms.beta1) * grad;
// Update biased second raw moment estimate
v[idx] = uniforms.beta2 * v[idx] + (1.0 - uniforms.beta2) * grad * grad;
// Compute bias-corrected moment estimates
let m_hat = m[idx] / uniforms.bias_correction1;
let v_hat = v[idx] / uniforms.bias_correction2;
// Update parameters
params[idx] -= uniforms.lr * m_hat / (sqrt(v_hat) + uniforms.eps);
}
"#;
/// WGSL compute shader for a naive GEMM: C = alpha * A * B + beta * C with
/// row-major MxK (A), KxN (B), MxN (C). One invocation computes one output
/// element; dispatch with ceil(M/8) x ceil(N/8) workgroups.
///
/// Fixed: the string previously contained a Rust `#[allow(dead_code)]`
/// attribute line before `fn gemm`, which is not valid WGSL and would make
/// shader-module creation fail.
#[allow(dead_code)]
const GEMM_SHADER_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
@group(0) @binding(2) var<storage, read_write> matrix_c: array<f32>;
struct GemmUniforms {
M: u32,
N: u32,
K: u32,
alpha: f32,
beta: f32,
};
@group(0) @binding(3) var<uniform> uniforms: GemmUniforms;
@compute @workgroup_size(8, 8)
fn gemm(@builtin(global_invocation_id) global_id: vec3<u32>) {
let row = global_id.x;
let col = global_id.y;
if (row >= uniforms.M || col >= uniforms.N) {
return;
}
var sum = 0.0;
for (var k = 0u; k < uniforms.K; k++) {
sum += matrix_a[row * uniforms.K + k] * matrix_b[k * uniforms.N + col];
}
let idx = row * uniforms.N + col;
matrix_c[idx] = uniforms.alpha * sum + uniforms.beta * matrix_c[idx];
}
"#;
/// WebGPU execution context: owns the device/queue pair (real wgpu handles or
/// opaque stubs), a cache of compiled shaders keyed by name, and a
/// size-bucketed buffer reuse pool. All shared state is behind `Arc` so
/// compilers and kernel handles created from this context stay consistent.
pub struct WebGPUContext {
#[cfg(feature = "wgpu_backend")]
device: Arc<Device>,
#[cfg(feature = "wgpu_backend")]
queue: Arc<Queue>,
#[cfg(not(feature = "wgpu_backend"))]
device: Arc<WgpuDevice>,
#[cfg(not(feature = "wgpu_backend"))]
queue: Arc<WgpuQueue>,
// Compiled-shader cache shared with compilers/kernel handles.
compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
// Buffer reuse pool keyed by exact byte size.
memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
// SAFETY: NOTE(review) — asserts thread-safety for both builds. wgpu's
// Device/Queue are Send + Sync; the stub build stores raw pointers that are
// synthetic and never dereferenced. Confirm no thread-affine state is added.
unsafe impl Send for WebGPUContext {}
unsafe impl Sync for WebGPUContext {}
impl WebGPUContext {
/// Create a context on the best available adapter.
///
/// With `wgpu_backend`: enumerates all backends, requests a
/// high-performance adapter and a default-limits device (blocking via
/// `pollster`). Without the feature: fabricates opaque stub handles.
///
/// # Errors
/// Returns `GpuError::Other` when no adapter is found or device creation
/// fails.
pub fn new() -> Result<Self, GpuError> {
    #[cfg(feature = "wgpu_backend")]
    {
        let instance_desc = InstanceDescriptor {
            backends: Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            memory_budget_thresholds: Default::default(),
            backend_options: Default::default(),
            display: None,
        };
        let instance = Instance::new(instance_desc);
        let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions {
            power_preference: PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        }))
        .map_err(|e| GpuError::Other(format!("Failed to find WebGPU adapter: {e}")))?;
        let device_descriptor = DeviceDescriptor {
            label: Some("SciRS2 WebGPU Device"),
            required_features: Features::empty(),
            required_limits: Limits::default(),
            ..Default::default()
        };
        let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor))
            .map_err(|e| GpuError::Other(format!("{e}")))?;
        Ok(Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
            // 1 GiB nominal pool budget.
            memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))),
        })
    }
    #[cfg(not(feature = "wgpu_backend"))]
    {
        // Stub build: the fields are `Arc<WgpuDevice>` / `Arc<WgpuQueue>`,
        // so the bare handles must be wrapped (fixes a type mismatch where
        // the raw handles were stored directly).
        let device = Self::initialize_webgpu()?;
        let queue = Self::create_queue(device)?;
        Ok(Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            compiled_shaders: Arc::new(Mutex::new(HashMap::new())),
            memory_pool: Arc::new(Mutex::new(WebGPUMemoryPool::new(1024 * 1024 * 1024))),
        })
    }
}
/// Report whether a WebGPU adapter can be acquired on this machine.
/// Always `false` when the `wgpu_backend` feature is disabled.
pub fn is_available() -> bool {
    #[cfg(feature = "wgpu_backend")]
    {
        // Probe every backend with default power preference; the mere
        // success of `request_adapter` is the availability signal.
        let descriptor = InstanceDescriptor {
            backends: Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            memory_budget_thresholds: Default::default(),
            backend_options: Default::default(),
            display: None,
        };
        let probe = Instance::new(descriptor);
        let request = probe.request_adapter(&RequestAdapterOptions {
            power_preference: PowerPreference::default(),
            compatible_surface: None,
            force_fallback_adapter: false,
        });
        pollster::block_on(request).is_ok()
    }
    #[cfg(not(feature = "wgpu_backend"))]
    {
        false
    }
}
/// Compile WGSL `source` into a `WebGPUShader`: creates the shader module,
/// infers the group-0 bind group layout from the shader text, and builds a
/// compute pipeline for the `@compute` entry point (falling back to "main").
///
/// # Errors
/// Propagates layout-inference failures; stub build can fail in
/// `compile_wgsl_source`.
fn compile_shader_internal(&self, source: &str, name: &str) -> Result<WebGPUShader, GpuError> {
#[cfg(feature = "wgpu_backend")]
{
let shader_module = self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: ShaderSource::Wgsl(source.into()),
});
// Entry point name parsed from the WGSL text; "main" if none found.
let entry_point = Self::extract_entry_point(source).unwrap_or("main");
let (bind_group_layout, binding_infos) =
self.create_bind_group_layout_from_source(source, name)?;
let pipeline_layout =
self.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(&format!("{}_layout", name)),
bind_group_layouts: &[Some(&bind_group_layout)],
..Default::default()
});
let compute_pipeline =
self.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(&format!("{}_pipeline", name)),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: Some(entry_point),
compilation_options: Default::default(),
cache: None,
});
Ok(WebGPUShader {
pipeline: compute_pipeline,
bind_group_layout,
name: name.to_string(),
binding_infos,
})
}
#[cfg(not(feature = "wgpu_backend"))]
{
// Stub build: fabricate an opaque pipeline handle; no real compilation.
let pipeline = Self::compile_wgsl_source(source, name)?;
Ok(WebGPUShader {
pipeline,
bind_group_layout: std::ptr::null_mut(),
name: name.to_string(),
binding_infos: Vec::new(),
})
}
}
/// Build a `BindGroupLayout` (group 0 only) by textually scanning the WGSL
/// `source` for `@group(...)`/`@binding(...)` attributes followed by a
/// `var<storage|uniform>` declaration. Returns the layout plus per-binding
/// metadata (`BindingInfo`) used later to match kernel parameters by name.
#[cfg(feature = "wgpu_backend")]
fn create_bind_group_layout_from_source(
&self,
source: &str,
name: &str,
) -> Result<(BindGroupLayout, Vec<BindingInfo>), GpuError> {
// Attributes seen so far for the declaration currently being scanned;
// reset after every `var<...>` line.
#[derive(Default)]
struct PendingAttr {
group: Option<u32>,
binding: Option<u32>,
}
let mut pending = PendingAttr::default();
let mut entries: Vec<BindGroupLayoutEntry> = Vec::new();
let mut infos: Vec<BindingInfo> = Vec::new();
// Drop `//` line comments so commented-out declarations are ignored.
fn strip_comment(line: &str) -> &str {
line.split_once("//").map(|(a, _)| a).unwrap_or(line)
}
for raw_line in source.lines() {
let line = strip_comment(raw_line).trim();
if line.is_empty() {
continue;
}
// Parse `@group(N)` if present on this line.
if let Some(i) = line.find("@group(") {
if let Some(end) = line[i + 7..].find(')') {
if let Ok(g) = line[i + 7..i + 7 + end].parse::<u32>() {
pending.group = Some(g);
}
}
}
// Parse `@binding(N)` if present on this line.
if let Some(i) = line.find("@binding(") {
if let Some(end) = line[i + 9..].find(')') {
if let Ok(b) = line[i + 9..i + 9 + end].parse::<u32>() {
pending.binding = Some(b);
}
}
}
if line.contains("var<") {
// Only group 0 is supported; declarations in other groups are skipped.
if pending.group.unwrap_or(0) == 0 {
// Missing @binding falls back to declaration order.
let binding_num = pending.binding.unwrap_or_else(|| entries.len() as u32);
let name = extract_var_name(line).unwrap_or("");
let storage = line.contains("var<storage");
let uniform = line.contains("var<uniform");
// `read` access qualifier marks a read-only storage buffer.
let read_only = storage
&& (line.contains(", read>")
|| line.contains("var<storage, read>")
|| line.contains("var<storage, read,"));
if storage {
entries.push(BindGroupLayoutEntry {
binding: binding_num,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
infos.push(BindingInfo {
binding: binding_num,
name: name.to_string(),
kind: if read_only {
BindingKind::StorageRead
} else {
BindingKind::StorageRw
},
});
} else if uniform {
entries.push(BindGroupLayoutEntry {
binding: binding_num,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
infos.push(BindingInfo {
binding: binding_num,
name: name.to_string(),
kind: BindingKind::Uniform,
});
}
}
// A `var<...>` declaration consumes any pending attributes.
pending = PendingAttr::default();
}
}
// Guarantee at least one entry so every pipeline gets a non-empty layout.
if entries.is_empty() {
entries.push(BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
infos.push(BindingInfo {
binding: 0,
name: "_unnamed".into(),
kind: BindingKind::StorageRw,
});
}
// Keep only the first entry seen for each binding index (wgpu rejects
// duplicate binding numbers in one layout).
let mut seen = std::collections::HashSet::new();
let mut dedup_entries = Vec::new();
let mut dedup_infos = Vec::new();
for (e, info) in entries.into_iter().zip(infos) {
if seen.insert(e.binding) {
dedup_entries.push(e);
dedup_infos.push(info);
}
}
let bind_group_layout = self
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some(&format!("{}_bind_group_layout", name)),
entries: &dedup_entries,
});
Ok((bind_group_layout, dedup_infos))
}
/// Allocate a fresh `size`-byte storage buffer usable as a compute binding
/// and as source/destination of buffer-to-buffer copies.
#[cfg(feature = "wgpu_backend")]
pub fn allocate_device_memory(&self, size: usize) -> Result<Buffer, GpuError> {
let buffer = self.device.create_buffer(&BufferDescriptor {
label: Some("SciRS2 Buffer"),
size: size as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Ok(buffer)
}
/// Stub-build buffer "allocation": fabricates a distinct non-null handle
/// derived from the size; no memory is actually reserved.
#[cfg(not(feature = "wgpu_backend"))]
pub fn allocate_device_memory_2(&self, size: usize) -> Result<WgpuBuffer, GpuError> {
Ok((0x1000 + size) as WgpuBuffer)
}
// Stub device handle; the value is opaque and never dereferenced.
#[cfg(not(feature = "wgpu_backend"))]
fn initialize_webgpu() -> Result<WgpuDevice, GpuError> {
Ok(0x1 as WgpuDevice)
}
// Stub queue handle; `device` is accepted for signature parity but unused.
#[cfg(not(feature = "wgpu_backend"))]
fn create_queue(device: WgpuDevice) -> Result<WgpuQueue, GpuError> {
Ok(0x2 as WgpuQueue)
}
// Stub "compilation": the WGSL is not parsed; returns an opaque handle.
#[cfg(not(feature = "wgpu_backend"))]
fn compile_wgsl_source(source: &str, name: &str) -> Result<WgpuComputePipeline, GpuError> {
Ok(0x3 as WgpuComputePipeline)
}
/// Find the name of the WGSL `@compute` entry point in `source`.
///
/// Scans for a line containing `@compute`, then looks on that line and the
/// next few lines for `fn NAME(` and returns `NAME`.
///
/// Fixed: the forward search previously started at index 0 instead of the
/// `@compute` line's own index, so the one-line lookahead inspected
/// `lines[1]` regardless of where `@compute` appeared. The search now starts
/// at the attribute line and tolerates a couple of intervening lines.
fn extract_entry_point(source: &str) -> Option<&str> {
    let lines: Vec<&str> = source.lines().collect();
    for (i, line) in lines.iter().enumerate() {
        if !line.contains("@compute") {
            continue;
        }
        // The `fn` may be on the same line as `@compute` or shortly after it.
        for candidate in lines[i..].iter().take(3) {
            let trimmed = candidate.trim();
            if let Some(start) = trimmed.find("fn ") {
                let remaining = &trimmed[start + 3..];
                if let Some(end) = remaining.find('(') {
                    return Some(remaining[..end].trim());
                }
            }
        }
    }
    None
}
}
impl GpuContextImpl for WebGPUContext {
/// Create a `size`-byte GPU buffer, preferring a pooled buffer of the exact
/// size, then a fresh device allocation, and finally (real backend only) a
/// CPU fallback buffer if device allocation fails.
fn create_buffer(&self, size: usize) -> Arc<dyn GpuBufferImpl> {
    // Fast path: reuse a pooled buffer of exactly this size.
    if let Ok(mut pool) = self.memory_pool.lock() {
        if let Some(device_buffer) = pool.allocate(size) {
            return Arc::new(WebGPUBuffer {
                device_buffer: Some(device_buffer),
                #[cfg(feature = "wgpu_backend")]
                queue: Arc::clone(&self.queue),
                #[cfg(feature = "wgpu_backend")]
                device: Arc::clone(&self.device),
                // Stub field is a bare `WgpuQueue` (Copy); deref the Arc.
                #[cfg(not(feature = "wgpu_backend"))]
                queue: *self.queue,
                size,
                memory_pool: Arc::clone(&self.memory_pool),
            });
        }
    }
    // Fixed: the stub allocator is named `allocate_device_memory_2`, so the
    // call must be cfg-gated (the old code called the feature-gated name
    // unconditionally).
    #[cfg(feature = "wgpu_backend")]
    let allocation = self.allocate_device_memory(size);
    #[cfg(not(feature = "wgpu_backend"))]
    let allocation = self.allocate_device_memory_2(size);
    let device_buffer = match allocation {
        Ok(buffer) => buffer,
        Err(e) => {
            eprintln!(
                "Warning: WebGPU buffer allocation failed ({}), creating CPU fallback buffer",
                e
            );
            #[cfg(feature = "wgpu_backend")]
            {
                return Arc::new(WebGPUCpuFallbackBuffer {
                    data: vec![0u8; size],
                    size,
                    memory_pool: Arc::clone(&self.memory_pool),
                });
            }
            // Stub build: fabricate a distinct handle instead of failing.
            #[cfg(not(feature = "wgpu_backend"))]
            {
                (0x2000 + size) as WgpuBuffer
            }
        }
    };
    Arc::new(WebGPUBuffer {
        device_buffer: Some(device_buffer),
        #[cfg(feature = "wgpu_backend")]
        queue: Arc::clone(&self.queue),
        #[cfg(feature = "wgpu_backend")]
        device: Arc::clone(&self.device),
        #[cfg(not(feature = "wgpu_backend"))]
        queue: *self.queue,
        size,
        memory_pool: Arc::clone(&self.memory_pool),
    })
}
/// Create a compiler that shares this context's device, queue, shader cache,
/// and memory pool (all via `Arc`), so kernels it produces interoperate with
/// buffers from this context.
fn create_compiler(&self) -> Arc<dyn GpuCompilerImpl> {
    // Both cfg arms were byte-identical (`Arc::clone` works for the real and
    // the stub field types alike), so the duplication was removed.
    Arc::new(WebGPUCompiler {
        context: Arc::new(WebGPUContext {
            memory_pool: Arc::clone(&self.memory_pool),
            compiled_shaders: Arc::clone(&self.compiled_shaders),
            device: Arc::clone(&self.device),
            queue: Arc::clone(&self.queue),
        }),
    })
}
/// Enables downcasting to the concrete `WebGPUContext`.
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
/// A compiled compute shader: the pipeline, its group-0 bind group layout,
/// and the binding metadata inferred from the WGSL text (used to match
/// kernel parameters to bindings by variable name).
struct WebGPUShader {
#[cfg(feature = "wgpu_backend")]
pipeline: ComputePipeline,
#[cfg(not(feature = "wgpu_backend"))]
pipeline: WgpuComputePipeline,
#[cfg(feature = "wgpu_backend")]
#[allow(dead_code)]
bind_group_layout: BindGroupLayout,
#[cfg(not(feature = "wgpu_backend"))]
#[allow(dead_code)]
bind_group_layout: *mut std::ffi::c_void,
// Name under which the shader is cached.
#[allow(dead_code)]
name: String,
// One entry per group-0 binding parsed from the shader source.
#[allow(dead_code)]
binding_infos: Vec<BindingInfo>, }
// SAFETY: NOTE(review) — wgpu pipeline/layout handles are Send + Sync; in the
// stub build the raw pointers are synthetic and never dereferenced. Confirm
// this still holds if real pointers are ever stored here.
unsafe impl Send for WebGPUShader {}
unsafe impl Sync for WebGPUShader {}
/// Shader compiler bound to a context; shares the context's shader cache so
/// compiled kernels are visible to handles created later.
struct WebGPUCompiler {
context: Arc<WebGPUContext>,
}
impl GpuCompilerImpl for WebGPUCompiler {
fn compile(&self, source: &str) -> Result<Arc<dyn GpuKernelImpl>, GpuError> {
let shader = self.context.compile_shader_internal(source, "shader")?;
Ok(Arc::new(WebGPUKernelHandle {
shader_name: shader.name.clone(),
compiled_shaders: Arc::clone(&self.context.compiled_shaders),
params: Arc::new(Mutex::new(HashMap::new())),
#[cfg(feature = "wgpu_backend")]
device: Arc::clone(&self.context.device),
#[cfg(feature = "wgpu_backend")]
queue: Arc::clone(&self.context.queue),
#[cfg(feature = "wgpu_backend")]
ephemeral_uniforms: Mutex::new(Vec::new()),
#[cfg(not(feature = "wgpu_backend"))]
device: self.context.device,
#[cfg(not(feature = "wgpu_backend"))]
queue: self.context.queue,
}))
}
/// Return a kernel handle that resolves `name` against the shared shader
/// cache at dispatch time. The type ids are accepted for interface parity
/// but currently unused.
fn compile_typed(
    &self,
    name: &str,
    _input_type: std::any::TypeId,
    _output_type: std::any::TypeId,
) -> Arc<dyn GpuKernelImpl> {
    Arc::new(WebGPUKernelHandle {
        shader_name: name.to_string(),
        compiled_shaders: Arc::clone(&self.context.compiled_shaders),
        params: Arc::new(Mutex::new(HashMap::new())),
        #[cfg(feature = "wgpu_backend")]
        device: Arc::clone(&self.context.device),
        #[cfg(feature = "wgpu_backend")]
        queue: Arc::clone(&self.context.queue),
        #[cfg(feature = "wgpu_backend")]
        ephemeral_uniforms: Mutex::new(Vec::new()),
        // Fixed: the stub handle stores bare Copy handles, while the context
        // keeps them behind `Arc` — deref-copy them out.
        #[cfg(not(feature = "wgpu_backend"))]
        device: *self.context.device,
        #[cfg(not(feature = "wgpu_backend"))]
        queue: *self.context.queue,
    })
}
}
/// Handle for launching one named shader: accumulates parameters via the
/// `set_*` methods and resolves the pipeline from the shared cache at
/// dispatch time.
struct WebGPUKernelHandle {
shader_name: String,
compiled_shaders: Arc<Mutex<HashMap<String, WebGPUShader>>>,
// Parameters staged for the next dispatch, keyed by WGSL variable name.
params: Arc<Mutex<HashMap<String, KernelParam>>>,
#[cfg(feature = "wgpu_backend")]
device: Arc<Device>,
#[cfg(feature = "wgpu_backend")]
queue: Arc<Queue>,
// Keeps per-dispatch uniform buffers alive until the next bind-group build.
#[cfg(feature = "wgpu_backend")]
ephemeral_uniforms: Mutex<Vec<wgpu::Buffer>>,
#[cfg(not(feature = "wgpu_backend"))]
device: WgpuDevice,
#[cfg(not(feature = "wgpu_backend"))]
queue: WgpuQueue,
}
/// A staged kernel argument: either a device buffer bound to a storage
/// binding, or a scalar/byte-blob packed into the uniform buffer.
enum KernelParam {
#[allow(dead_code)]
Buffer(Arc<dyn GpuBufferImpl>),
#[allow(dead_code)]
U32(u32),
#[allow(dead_code)]
I32(i32),
#[allow(dead_code)]
F32(f32),
#[allow(dead_code)]
F64(f64),
// Pre-packed uniform payload, appended verbatim.
Bytes(Vec<u8>),
}
/// Kind of a group-0 binding as parsed from the WGSL declaration.
#[derive(Clone, Debug)]
enum BindingKind {
StorageRw,
StorageRead,
Uniform,
}
/// Metadata for one shader binding: its `@binding` index, the WGSL variable
/// name (used to match kernel parameters), and its binding kind.
#[derive(Clone, Debug)]
struct BindingInfo {
binding: u32,
name: String,
kind: BindingKind,
}
/// Extract the variable name from a WGSL `var<...> NAME: TYPE;` declaration
/// line; returns `None` when the line does not match that shape.
fn extract_var_name(line: &str) -> Option<&str> {
    // Everything from `var<` onward; bail out if the marker is absent.
    let tail = &line[line.find("var<")?..];
    // Skip past the address-space/access qualifier list.
    let close = tail.find('>')?;
    let rest = tail[close + 1..].trim_start();
    // The name ends at the type annotation's colon.
    let name = rest[..rest.find(':')?].trim();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
impl GpuKernelImpl for WebGPUKernelHandle {
// Parameter setters: each stages a value in the shared params map under the
// WGSL variable name `name`, for use when the bind group is assembled at
// dispatch time. The lock expect is acceptable here: poisoning means another
// thread already panicked mid-dispatch.
fn set_buffer(&self, name: &str, buffer: &Arc<dyn GpuBufferImpl>) {
let mut params = self.params.lock().expect("Operation failed");
params.insert(name.to_string(), KernelParam::Buffer(Arc::clone(buffer)));
}
fn set_u32(&self, name: &str, value: u32) {
let mut params = self.params.lock().expect("Operation failed");
params.insert(name.to_string(), KernelParam::U32(value));
}
fn set_i32(&self, name: &str, value: i32) {
let mut params = self.params.lock().expect("Operation failed");
params.insert(name.to_string(), KernelParam::I32(value));
}
fn set_f32(&self, name: &str, value: f32) {
let mut params = self.params.lock().expect("Operation failed");
params.insert(name.to_string(), KernelParam::F32(value));
}
fn set_f64(&self, name: &str, value: f64) {
let mut params = self.params.lock().expect("Operation failed");
params.insert(name.to_string(), KernelParam::F64(value));
}
#[allow(dead_code)]
fn dispatch(&self, workgroups: [u32; 3]) {
#[cfg(feature = "wgpu_backend")]
{
let shaders = self.compiled_shaders.lock().expect("Operation failed");
if let Some(shader) = shaders.get(&self.shader_name) {
let params = self.params.lock().expect("Operation failed");
let mut encoder =
self.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Compute Command Encoder"),
});
{
let mut compute_pass =
encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Compute Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&shader.pipeline);
if let Ok(bind_group) = self.create_bind_group_from_params(shader, ¶ms) {
compute_pass.set_bind_group(0, &bind_group, &[]);
} else {
eprintln!(
"Warning: Failed to create bind group for shader {}",
self.shader_name
);
}
compute_pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
}
let command_buffer = encoder.finish();
self.queue.submit(std::iter::once(command_buffer));
eprintln!(
"WebGPU compute shader {} dispatched with workgroups: {:?}",
self.shader_name, workgroups
);
}
}
#[cfg(not(feature = "wgpu_backend"))]
{
eprintln!("Executing WebGPU shader {} (simulated)", self.shader_name);
eprintln!("Work groups: {:?}", workgroups);
}
}
}
/// A device buffer plus the queue/device handles needed for host<->device
/// copies; returned to the shared pool on drop.
struct WebGPUBuffer {
#[cfg(feature = "wgpu_backend")]
device_buffer: Option<Buffer>,
#[cfg(feature = "wgpu_backend")]
queue: Arc<Queue>,
#[cfg(feature = "wgpu_backend")]
device: Arc<Device>,
#[cfg(not(feature = "wgpu_backend"))]
device_buffer: Option<WgpuBuffer>,
#[cfg(not(feature = "wgpu_backend"))]
queue: WgpuQueue,
// Logical size in bytes (also the pool bucket key).
size: usize,
memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
// SAFETY: NOTE(review) — wgpu Buffer/Queue/Device are Send + Sync; the stub
// build's raw pointer is synthetic and never dereferenced. Confirm if real
// pointers are ever stored in the stub path.
unsafe impl Send for WebGPUBuffer {}
unsafe impl Sync for WebGPUBuffer {}
impl GpuBufferImpl for WebGPUBuffer {
/// Logical buffer size in bytes.
fn size(&self) -> usize {
self.size
}
/// Upload `size` bytes from host memory into the device buffer via the
/// queue; oversize requests are rejected with a warning.
///
/// # Safety
/// `data` must point to at least `size` readable bytes.
unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
#[cfg(feature = "wgpu_backend")]
{
if size > self.size {
eprintln!(
"Warning: Data size {} exceeds buffer size {}",
size, self.size
);
return;
}
// SAFETY: caller guarantees `data` is valid for `size` bytes.
let data_slice = std::slice::from_raw_parts(data, size);
if let Some(ref buffer) = self.device_buffer {
self.queue.write_buffer(buffer, 0, data_slice);
}
}
#[cfg(not(feature = "wgpu_backend"))]
{
// Stub build: only validate the size; there is no real device memory.
if size > self.size {
eprintln!(
"Warning: Data size {} exceeds buffer size {}",
size, self.size
);
}
}
}
/// Read `size` bytes from the device buffer into host memory via a staging
/// buffer (copy, map, blocking wait). The stub build zero-fills instead.
///
/// NOTE(review): on native targets wgpu's `map_async` callbacks generally
/// only fire during `device.poll(...)`; confirm this blocking `recv` cannot
/// hang without an explicit poll.
///
/// # Safety
/// `data` must point to at least `size` writable bytes.
unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
#[cfg(feature = "wgpu_backend")]
{
if size > self.size {
eprintln!(
"Warning: Data size {} exceeds buffer size {}",
size, self.size
);
return;
}
if let Some(ref buffer) = self.device_buffer {
// Storage buffers cannot be mapped directly; go through a
// MAP_READ staging buffer.
let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("scirs2-readback"),
size: size as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder =
self.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("scirs2-readback-enc"),
});
encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, size as u64);
self.queue.submit(Some(encoder.finish()));
let slice = staging.slice(0..size as u64);
// Bridge the async map callback to a blocking receive.
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx.send(r);
});
if let Ok(Ok(())) = rx.recv() {
let mapped = slice.get_mapped_range();
// SAFETY: caller guarantees `data` is valid for `size` bytes.
let dst = std::slice::from_raw_parts_mut(data, size);
dst.copy_from_slice(&mapped);
drop(mapped);
staging.unmap();
} else {
eprintln!("Warning: map_async failed for readback");
}
}
}
#[cfg(not(feature = "wgpu_backend"))]
{
if size > self.size {
eprintln!(
"Warning: Data size {} exceeds buffer size {}",
size, self.size
);
}
// Stub build: no device data exists, so hand back zeros.
// SAFETY: caller guarantees `data` is valid for `size` bytes.
let data_slice = std::slice::from_raw_parts_mut(data, size);
data_slice.fill(0);
}
}
/// Return an opaque numeric identifier for the buffer.
///
/// With `wgpu_backend` this is the host address of the buffer slot — wgpu
/// does not expose raw GPU addresses — so it is a stable identifier, not a
/// device pointer. The stub build returns the synthetic handle value (0 when
/// the buffer was already taken).
///
/// Fixed: the stub path previously did `self.device_buffer as u64`, which is
/// not a legal cast from `Option<*mut c_void>`.
fn device_ptr(&self) -> u64 {
    #[cfg(feature = "wgpu_backend")]
    {
        &self.device_buffer as *const _ as u64
    }
    #[cfg(not(feature = "wgpu_backend"))]
    {
        self.device_buffer.map_or(0, |p| p as u64)
    }
}
/// Enables downcasting to the concrete `WebGPUBuffer`.
fn as_any(&self) -> &dyn std::any::Any {
    self
}
}
#[cfg(feature = "wgpu_backend")]
impl WebGPUKernelHandle {
/// Assemble a group-0 bind group from the staged parameters: storage
/// bindings are matched to `KernelParam::Buffer` entries by WGSL variable
/// name; scalar params are packed into a single uniform buffer.
///
/// NOTE(review): uniform bytes are appended in `HashMap` iteration order,
/// which is unspecified — a multi-field uniform struct may be packed in the
/// wrong order. Confirm callers either pass a single pre-packed
/// `KernelParam::Bytes` blob or a single scalar per uniform block.
///
/// # Errors
/// `GpuError::InvalidParameter` when a storage binding has no matching
/// buffer parameter.
fn create_bind_group_from_params(
&self,
shader: &WebGPUShader,
params: &HashMap<String, KernelParam>,
) -> Result<wgpu::BindGroup, GpuError> {
let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
// Keeps the uniform buffer alive until create_bind_group below;
// `ephemeral_uniforms` keeps a clone alive past this call for the GPU.
let mut owned_uniform_buffers: Vec<wgpu::Buffer> = Vec::new();
let mut uniform_bytes: Vec<u8> = Vec::new();
for info in &shader.binding_infos {
match info.kind {
BindingKind::StorageRw | BindingKind::StorageRead => {
if let Some(KernelParam::Buffer(buf)) = params.get(&info.name) {
// Only real device-backed buffers can be bound; CPU
// fallback buffers are silently skipped here.
if let Some(wbuf) = buf.as_any().downcast_ref::<WebGPUBuffer>() {
if let Some(ref inner) = wbuf.device_buffer {
entries.push(wgpu::BindGroupEntry {
binding: info.binding,
resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
buffer: inner,
offset: 0,
size: None,
}),
});
}
}
} else {
return Err(GpuError::InvalidParameter(format!(
"Missing buffer param '{}'",
info.name
)));
}
}
BindingKind::Uniform => {
// Collect scalars whose key is the uniform's name or a
// dotted field path under it (e.g. "uniforms.lr").
for (k, v) in params.iter() {
if k == &info.name || k.starts_with(&(info.name.clone() + ".")) {
match v {
KernelParam::U32(u) => {
uniform_bytes.extend_from_slice(&u.to_le_bytes())
}
KernelParam::I32(i) => {
uniform_bytes.extend_from_slice(&i.to_le_bytes())
}
KernelParam::F32(f) => {
uniform_bytes.extend_from_slice(&f.to_le_bytes())
}
KernelParam::F64(f) => {
uniform_bytes.extend_from_slice(&f.to_le_bytes())
}
KernelParam::Bytes(b) => uniform_bytes.extend_from_slice(b),
KernelParam::Buffer(_) => {}
}
}
}
}
}
}
if !uniform_bytes.is_empty() {
// Pad to 16 bytes to satisfy uniform-buffer alignment rules.
while uniform_bytes.len() % 16 != 0 {
uniform_bytes.push(0);
}
if let Some(uinfo) = shader
.binding_infos
.iter()
.find(|b| matches!(b.kind, BindingKind::Uniform))
{
if let Ok(mut list) = self.ephemeral_uniforms.lock() {
// Drop the previous dispatch's uniform buffer(s).
list.clear();
let ubuf = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("scirs2-uniforms"),
contents: &uniform_bytes,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
});
list.push(ubuf.clone());
owned_uniform_buffers.push(ubuf.clone());
let idx = owned_uniform_buffers.len() - 1;
let buf_ref = &owned_uniform_buffers[idx];
entries.push(wgpu::BindGroupEntry {
binding: uinfo.binding,
resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
buffer: buf_ref,
offset: 0,
size: None,
}),
});
}
}
} else if let Ok(mut list) = self.ephemeral_uniforms.lock() {
list.clear();
}
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("scirs2-bind-group"),
layout: &shader.bind_group_layout,
entries: &entries,
});
Ok(bind_group)
}
}
impl Drop for WebGPUBuffer {
    /// Return the device buffer to the shared pool for reuse instead of
    /// destroying it. If the pool lock is poisoned the buffer is simply
    /// dropped.
    fn drop(&mut self) {
        // Both cfg arms were byte-identical, so the duplication was removed:
        // `take()` + `deallocate` compiles for the real and the stub buffer
        // type alike.
        if let Ok(mut pool) = self.memory_pool.lock() {
            if let Some(buffer) = self.device_buffer.take() {
                pool.deallocate(buffer);
            }
        }
    }
}
/// Host-memory buffer used when device allocation fails; keeps the API
/// usable without a working GPU.
struct WebGPUCpuFallbackBuffer {
// CPU-side backing store (zero-initialized at creation).
data: Vec<u8>,
size: usize,
#[allow(dead_code)]
memory_pool: Arc<Mutex<WebGPUMemoryPool>>,
}
impl GpuBufferImpl for WebGPUCpuFallbackBuffer {
    /// Logical buffer size in bytes.
    fn size(&self) -> usize {
        self.size
    }
    /// Host->buffer copy. The backing store is not writable through `&self`
    /// (no interior mutability), so the data cannot actually be persisted;
    /// this validates the size and warns.
    /// NOTE(review): writes are silently dropped — confirm callers tolerate
    /// this, or give `data` interior mutability.
    ///
    /// # Safety
    /// `data` must point to at least `size` readable bytes.
    unsafe fn copy_from_host(&self, data: *const u8, size: usize) {
        // Fixed: removed a dead `from_raw_parts` slice that was built from
        // `data` and never used.
        let _ = data;
        if size > self.size {
            eprintln!("Warning: WebGPU CPU fallback buffer copy_from_host size mismatch");
            return;
        }
        eprintln!(
            "Warning: CPU fallback buffer copy_from_host called (size: {})",
            size
        );
    }
    /// Buffer->host copy: returns the CPU-side backing store contents
    /// (zeros unless written by other means).
    ///
    /// # Safety
    /// `data` must point to at least `size` writable bytes.
    unsafe fn copy_to_host(&self, data: *mut u8, size: usize) {
        if size > self.size {
            eprintln!("Warning: WebGPU CPU fallback buffer copy_to_host size mismatch");
            return;
        }
        // SAFETY: caller guarantees `data` is valid for `size` bytes.
        let data_slice = std::slice::from_raw_parts_mut(data, size);
        let copy_size = size.min(self.data.len());
        data_slice[..copy_size].copy_from_slice(&self.data[..copy_size]);
        eprintln!(
            "Warning: CPU fallback buffer copy_to_host called (size: {})",
            size
        );
    }
    /// Host address of the backing store (not a real device pointer).
    fn device_ptr(&self) -> u64 {
        self.data.as_ptr() as u64
    }
    /// Enables downcasting to the concrete fallback type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
// SAFETY: NOTE(review) — all fields are owned (Vec, usize, Arc<Mutex<_>>)
// and Send + Sync on their own; these impls appear redundant but harmless.
unsafe impl Send for WebGPUCpuFallbackBuffer {}
unsafe impl Sync for WebGPUCpuFallbackBuffer {}
/// Size-bucketed free list of released device buffers; `allocate` only ever
/// reuses an exact-size match, it never creates new buffers.
struct WebGPUMemoryPool {
#[cfg(feature = "wgpu_backend")]
available_buffers: HashMap<usize, Vec<Buffer>>,
#[cfg(not(feature = "wgpu_backend"))]
available_buffers: HashMap<usize, Vec<WgpuBuffer>>,
// Nominal capacity budget in bytes (reported, not enforced).
#[allow(dead_code)]
total_size: usize,
// Bytes currently handed out from the pool.
used_size: usize,
}
impl WebGPUMemoryPool {
    /// Create an empty pool with a nominal capacity budget in bytes.
    fn new(totalsize: usize) -> Self {
        Self {
            available_buffers: HashMap::new(),
            total_size: totalsize,
            used_size: 0,
        }
    }
    /// Reuse a previously released buffer of exactly `size` bytes, if any.
    #[cfg(feature = "wgpu_backend")]
    fn allocate(&mut self, size: usize) -> Option<Buffer> {
        let buffer = self.available_buffers.get_mut(&size)?.pop()?;
        self.used_size += size;
        Some(buffer)
    }
    /// Stub-build counterpart of `allocate`.
    #[cfg(not(feature = "wgpu_backend"))]
    fn allocate(&mut self, size: usize) -> Option<WgpuBuffer> {
        let buffer = self.available_buffers.get_mut(&size)?.pop()?;
        self.used_size += size;
        Some(buffer)
    }
    /// Return a buffer to the free list, keyed by its exact size.
    #[cfg(feature = "wgpu_backend")]
    fn deallocate(&mut self, buffer: Buffer) {
        let size = buffer.size() as usize;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }
    /// Stub-build counterpart of `deallocate`. Stub handles carry no size.
    /// NOTE(review): the fixed 1024-byte bucket skews `used_size` accounting
    /// and means stub buffers are only reusable for 1024-byte requests —
    /// confirm acceptable for the simulated backend.
    #[cfg(not(feature = "wgpu_backend"))]
    fn deallocate(&mut self, buffer: WgpuBuffer) {
        let size = 1024;
        self.available_buffers.entry(size).or_default().push(buffer);
        self.used_size = self.used_size.saturating_sub(size);
    }
    /// (used bytes, total budget) for diagnostics.
    #[allow(dead_code)]
    fn get_memory_usage(&self) -> (usize, usize) {
        (self.used_size, self.total_size)
    }
}