#[cfg(feature = "webgpu")]
use crate::webgpu::wgpu;
use crate::{
Backend, BackendCapabilities, BackendCore, BackendDeviceManager, BackendExecutor,
BackendLifecycle, BackendOperations, BackendOps, BackendResourceManager, BackendResult,
BackendType, Buffer, BufferDescriptor, BufferHandle, CapabilityValue, Device, Kernel,
KernelDescriptor, KernelHandle, MemoryManager, MemoryPool, MemoryStats, OperationsBundle,
PerformanceHints, Profiler,
};
use crate::webgpu::{WebGpuError, WebGpuResult};
use async_trait::async_trait;
use futures::future::join_all;
use parking_lot::{RwLock, Mutex};
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use torsh_core::{device::DeviceType, dtype::DType, error::TorshError};
use wgpu::util::DeviceExt;
#[derive(Debug, Clone)]
pub struct ModernWebGpuConfig {
pub device_descriptor: wgpu::DeviceDescriptor<'static>,
pub instance_descriptor: wgpu::InstanceDescriptor,
pub adapter_options: wgpu::RequestAdapterOptions<'static>,
pub max_buffer_size: u64,
pub enable_validation: bool,
pub enable_spirv: bool,
pub preferred_backend: Option<wgpu::Backends>,
}
impl Default for ModernWebGpuConfig {
fn default() -> Self {
Self {
device_descriptor: wgpu::DeviceDescriptor {
label: Some("ToRSh WebGPU Device"),
required_features: wgpu::Features::empty(),
required_limits: wgpu::Limits::downlevel_defaults(),
memory_hints: wgpu::MemoryHints::default(),
trace: wgpu::Trace::Off,
experimental_features: wgpu::ExperimentalFeatures::default(),
},
instance_descriptor: wgpu::InstanceDescriptor {
backends: wgpu::Backends::all(),
flags: wgpu::InstanceFlags::default(),
memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
backend_options: wgpu::BackendOptions {
dx12: wgpu::Dx12BackendOptions {
shader_compiler: wgpu::Dx12Compiler::Fxc,
..Default::default()
},
gl: wgpu::GlBackendOptions {
gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
..Default::default()
},
..Default::default()
},
display: None,
},
adapter_options: wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
},
max_buffer_size: 256 * 1024 * 1024, enable_validation: true,
enable_spirv: false, preferred_backend: None,
}
}
}
#[derive(Debug)]
pub struct ModernWebGpuDevice {
pub instance: wgpu::Instance,
pub adapter: wgpu::Adapter,
pub device: wgpu::Device,
pub queue: wgpu::Queue,
pub info: wgpu::AdapterInfo,
pub limits: wgpu::Limits,
pub features: wgpu::Features,
config: ModernWebGpuConfig,
shader_cache: RwLock<HashMap<String, Arc<wgpu::ShaderModule>>>,
pipeline_cache: RwLock<HashMap<String, Arc<wgpu::ComputePipeline>>>,
bind_group_layout_cache: RwLock<HashMap<String, Arc<wgpu::BindGroupLayout>>>,
}
impl ModernWebGpuDevice {
pub async fn new(config: ModernWebGpuConfig) -> WebGpuResult<Arc<Self>> {
let instance = wgpu::Instance::new(config.instance_descriptor.clone());
let adapter = instance
.request_adapter(&config.adapter_options)
.await
.ok_or_else(|| {
WebGpuError::InitializationError("Failed to request WebGPU adapter".to_string())
})?;
let info = adapter.get_info();
log::info!("WebGPU Adapter: {} ({:?})", info.name, info.backend);
let mut required_features = config.device_descriptor.required_features;
let mut required_limits = config.device_descriptor.required_limits.clone();
let adapter_limits = adapter.limits();
required_limits.max_buffer_size = required_limits
.max_buffer_size
.min(adapter_limits.max_buffer_size)
.min(config.max_buffer_size);
let supported_features = adapter.features();
if supported_features.contains(wgpu::Features::TIMESTAMP_QUERY) {
required_features |= wgpu::Features::TIMESTAMP_QUERY;
}
if supported_features.contains(wgpu::Features::BUFFER_BINDING_ARRAY) {
required_features |= wgpu::Features::BUFFER_BINDING_ARRAY;
}
if supported_features.contains(wgpu::Features::STORAGE_RESOURCE_BINDING_ARRAY) {
required_features |= wgpu::Features::STORAGE_RESOURCE_BINDING_ARRAY;
}
let device_descriptor = wgpu::DeviceDescriptor {
label: config.device_descriptor.label,
required_features,
required_limits: required_limits.clone(),
memory_hints: config.device_descriptor.memory_hints.clone(),
trace: wgpu::Trace::Off,
experimental_features: wgpu::ExperimentalFeatures::default(),
};
let (device, queue) = adapter
.request_device(&device_descriptor)
.await
.map_err(|e| {
WebGpuError::InitializationError(format!("Failed to request WebGPU device: {}", e))
})?;
Ok(Arc::new(Self {
instance,
adapter,
device,
queue,
info,
limits: required_limits,
features: required_features,
config,
shader_cache: RwLock::new(HashMap::new()),
pipeline_cache: RwLock::new(HashMap::new()),
bind_group_layout_cache: RwLock::new(HashMap::new()),
}))
}
pub fn create_shader_module(&self, label: &str, source: &str) -> Arc<wgpu::ShaderModule> {
let key = format!("{}_{}", label, md5::compute(source.as_bytes()));
{
let cache = self.shader_cache.read();
if let Some(shader) = cache.get(&key) {
return shader.clone();
}
}
let shader = Arc::new(self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(label),
source: wgpu::ShaderSource::Wgsl(source.into()),
}));
self.shader_cache.write().insert(key, shader.clone());
shader
}
pub fn create_compute_pipeline(
&self,
label: &str,
shader: &wgpu::ShaderModule,
entry_point: &str,
bind_group_layouts: &[&wgpu::BindGroupLayout],
) -> WebGpuResult<Arc<wgpu::ComputePipeline>> {
let key = format!("{}_{}", label, entry_point);
{
let cache = self.pipeline_cache.read();
if let Some(pipeline) = cache.get(&key) {
return Ok(pipeline.clone());
}
}
let bind_group_layouts_opt: Vec<_> = bind_group_layouts.iter().map(|&l| Some(l)).collect();
let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(&format!("{}_layout", label)),
bind_group_layouts: &bind_group_layouts_opt,
immediate_size: 0,
});
let pipeline = Arc::new(
self.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(label),
layout: Some(&pipeline_layout),
module: shader,
entry_point: Some(entry_point),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
}),
);
self.pipeline_cache.write().insert(key, pipeline.clone());
Ok(pipeline)
}
pub fn create_buffer_with_data<T: bytemuck::Pod>(&self, label: &str, data: &[T], usage: wgpu::BufferUsages) -> wgpu::Buffer {
self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(label),
contents: bytemuck::cast_slice(data),
usage,
})
}
pub async fn submit_and_wait(&self, commands: impl IntoIterator<Item = wgpu::CommandBuffer>) -> WebGpuResult<()> {
let submission_index = self.queue.submit(commands);
let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging_wait_buffer"),
size: 4,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
staging_buffer.slice(..).map_async(wgpu::MapMode::Read, |result| {
if let Err(e) = result {
log::error!("Buffer mapping failed during wait: {:?}", e);
}
});
let _ = self.device.poll(wgpu::PollType::Wait {
submission_index: Some(submission_index),
timeout: None,
});
Ok(())
}
}
#[derive(Debug)]
pub struct ModernWebGpuMemoryManager {
device: Arc<ModernWebGpuDevice>,
buffer_pools: RwLock<HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>>,
allocation_stats: RwLock<MemoryStats>,
next_buffer_id: Mutex<u64>,
}
impl ModernWebGpuMemoryManager {
pub fn new(device: Arc<ModernWebGpuDevice>) -> Arc<Self> {
Arc::new(Self {
device,
buffer_pools: RwLock::new(HashMap::new()),
allocation_stats: RwLock::new(MemoryStats {
total_allocated: 0,
total_deallocated: 0,
current_allocated: 0,
peak_allocated: 0,
allocation_count: 0,
deallocation_count: 0,
}),
next_buffer_id: Mutex::new(0),
})
}
pub fn allocate_buffer(&self, descriptor: &BufferDescriptor) -> WebGpuResult<ModernWebGpuBuffer> {
let usage = self.convert_buffer_usage(&descriptor.usage)?;
let size = descriptor.size as u64;
let mut pools = self.buffer_pools.write();
let pool_key = (size, usage);
let buffer = if let Some(pool) = pools.get_mut(&pool_key) {
pool.pop().unwrap_or_else(|| {
self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(&format!("ToRSh Buffer {}", descriptor.name.as_deref().unwrap_or("unnamed"))),
size,
usage,
mapped_at_creation: false,
})
})
} else {
self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(&format!("ToRSh Buffer {}", descriptor.name.as_deref().unwrap_or("unnamed"))),
size,
usage,
mapped_at_creation: false,
})
};
let buffer_id = {
let mut id = self.next_buffer_id.lock();
let current_id = *id;
*id += 1;
current_id
};
{
let mut stats = self.allocation_stats.write();
stats.total_allocated += size as usize;
stats.current_allocated += size as usize;
stats.peak_allocated = stats.peak_allocated.max(stats.current_allocated);
stats.allocation_count += 1;
}
Ok(ModernWebGpuBuffer {
buffer,
device: self.device.clone(),
handle: BufferHandle::WebGpu {
buffer_ptr: buffer_id,
size: descriptor.size,
},
size: descriptor.size,
usage,
label: descriptor.name.clone(),
})
}
fn convert_buffer_usage(&self, usage: &crate::buffer::BufferUsage) -> WebGpuResult<wgpu::BufferUsages> {
use crate::buffer::BufferUsage;
let mut wgpu_usage = wgpu::BufferUsages::empty();
match usage {
BufferUsage::Uniform => wgpu_usage |= wgpu::BufferUsages::UNIFORM,
BufferUsage::Storage => wgpu_usage |= wgpu::BufferUsages::STORAGE,
BufferUsage::Vertex => wgpu_usage |= wgpu::BufferUsages::VERTEX,
BufferUsage::Index => wgpu_usage |= wgpu::BufferUsages::INDEX,
}
wgpu_usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
Ok(wgpu_usage)
}
pub fn return_buffer_to_pool(&self, buffer: ModernWebGpuBuffer) {
let pool_key = (buffer.size as u64, buffer.usage);
let mut pools = self.buffer_pools.write();
pools.entry(pool_key).or_insert_with(Vec::new).push(buffer.buffer);
let mut stats = self.allocation_stats.write();
stats.total_deallocated += buffer.size;
stats.current_allocated = stats.current_allocated.saturating_sub(buffer.size);
stats.deallocation_count += 1;
}
}
#[derive(Debug)]
pub struct ModernWebGpuBuffer {
pub buffer: wgpu::Buffer,
pub device: Arc<ModernWebGpuDevice>,
pub handle: BufferHandle,
pub size: usize,
pub usage: wgpu::BufferUsages,
pub label: Option<String>,
}
impl ModernWebGpuBuffer {
pub async fn write_data<T: bytemuck::Pod>(&self, data: &[T]) -> WebGpuResult<()> {
let bytes = bytemuck::cast_slice(data);
if bytes.len() > self.size {
return Err(WebGpuError::InvalidArgument(format!(
"Data size {} exceeds buffer size {}",
bytes.len(),
self.size
)));
}
self.device.queue.write_buffer(&self.buffer, 0, bytes);
Ok(())
}
pub async fn read_data<T: bytemuck::Pod + Clone>(&self) -> WebGpuResult<Vec<T>> {
let staging_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging_read_buffer"),
size: self.size as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder = self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("buffer_read_copy"),
});
encoder.copy_buffer_to_buffer(
&self.buffer,
0,
&staging_buffer,
0,
self.size as u64,
);
self.device.queue.submit([encoder.finish()]);
let buffer_slice = staging_buffer.slice(..);
let (tx, rx) = futures::channel::oneshot::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
let _ = self.device.device.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
});
rx.await.expect("buffer mapping channel should not be dropped").map_err(|e| WebGpuError::RuntimeError(format!("Buffer mapping failed: {:?}", e)))?;
let data = buffer_slice.get_mapped_range();
let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
drop(data);
staging_buffer.unmap();
Ok(result)
}
}
#[derive(Debug)]
pub struct ModernWebGpuCompute {
device: Arc<ModernWebGpuDevice>,
memory_manager: Arc<ModernWebGpuMemoryManager>,
}
impl ModernWebGpuCompute {
pub fn new(device: Arc<ModernWebGpuDevice>) -> Self {
let memory_manager = ModernWebGpuMemoryManager::new(device.clone());
Self {
device,
memory_manager,
}
}
pub async fn execute_compute(
&self,
shader_source: &str,
entry_point: &str,
buffers: &[&ModernWebGpuBuffer],
workgroup_size: (u32, u32, u32),
) -> WebGpuResult<()> {
let shader = self.device.create_shader_module("compute_shader", shader_source);
let mut bind_group_layout_entries = Vec::new();
for (i, _buffer) in buffers.iter().enumerate() {
bind_group_layout_entries.push(wgpu::BindGroupLayoutEntry {
binding: i as u32,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
});
}
let bind_group_layout = self.device.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("compute_bind_group_layout"),
entries: &bind_group_layout_entries,
});
let pipeline = self.device.create_compute_pipeline(
"compute_pipeline",
&shader,
entry_point,
&[&bind_group_layout],
)?;
let mut bind_group_entries = Vec::new();
for (i, buffer) in buffers.iter().enumerate() {
bind_group_entries.push(wgpu::BindGroupEntry {
binding: i as u32,
resource: buffer.buffer.as_entire_binding(),
});
}
let bind_group = self.device.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("compute_bind_group"),
layout: &bind_group_layout,
entries: &bind_group_entries,
});
let mut encoder = self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("compute_encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
compute_pass.dispatch_workgroups(workgroup_size.0, workgroup_size.1, workgroup_size.2);
}
self.device.submit_and_wait([encoder.finish()]).await?;
Ok(())
}
}
#[derive(Debug)]
pub struct ModernWebGpuBackend {
device: Arc<ModernWebGpuDevice>,
memory_manager: Arc<ModernWebGpuMemoryManager>,
compute: ModernWebGpuCompute,
config: ModernWebGpuConfig,
}
impl ModernWebGpuBackend {
pub async fn new(config: ModernWebGpuConfig) -> WebGpuResult<Self> {
let device = ModernWebGpuDevice::new(config.clone()).await?;
let memory_manager = ModernWebGpuMemoryManager::new(device.clone());
let compute = ModernWebGpuCompute::new(device.clone());
Ok(Self {
device,
memory_manager,
compute,
config,
})
}
pub fn with_default() -> impl std::future::Future<Output = WebGpuResult<Self>> {
Self::new(ModernWebGpuConfig::default())
}
}