use crate::buffer::generate_buffer_id;
use crate::memory::MemoryPoolConfig;
use crate::profiler::SimpleProfiler;
#[cfg(feature = "webgpu")]
use crate::webgpu::wgpu;
use crate::webgpu::{
WebGpuBackendConfig, WebGpuDevice, WebGpuError, WebGpuKernelExecutor, WebGpuMemoryManager,
};
use crate::{
BackendCore, BackendResult, Buffer, BufferDescriptor, BufferHandle, Device, Kernel,
KernelDescriptor, KernelHandle, MemoryManager, MemoryStats, Profiler,
};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use torsh_core::{device::DeviceType, error::TorshError};
/// WebGPU compute backend.
///
/// Holds per-device state (device handles, memory managers, kernel
/// executors) keyed by device id, each registry behind its own
/// `parking_lot::RwLock`.
#[derive(Debug)]
pub struct WebGpuBackend {
    // Static configuration supplied at construction time.
    config: WebGpuBackendConfig,
    // Initialized WebGPU devices, keyed by device id.
    devices: RwLock<HashMap<usize, Arc<WebGpuDevice>>>,
    // Per-device memory managers (inner RwLock allows shared mutation).
    memory_managers: RwLock<HashMap<usize, Arc<RwLock<WebGpuMemoryManager>>>>,
    // Per-device kernel executors.
    kernel_executors: RwLock<HashMap<usize, Arc<WebGpuKernelExecutor>>>,
    // Shared profiler; cloned out by `BackendResourceManager::profiler`.
    profiler: Arc<SimpleProfiler>,
    // Set to `true` once `BackendLifecycle::initialize` completes.
    initialized: RwLock<bool>,
}
impl WebGpuBackend {
    /// Creates a backend with the given configuration.
    ///
    /// No devices are created here; device state is populated lazily by
    /// `initialize_device`.
    pub fn new(config: WebGpuBackendConfig) -> Self {
        Self {
            config,
            devices: RwLock::new(HashMap::new()),
            memory_managers: RwLock::new(HashMap::new()),
            kernel_executors: RwLock::new(HashMap::new()),
            profiler: Arc::new(SimpleProfiler::new()),
            initialized: RwLock::new(false),
        }
    }

    /// Convenience constructor using `WebGpuBackendConfig::default()`.
    pub fn with_default_config() -> Self {
        Self::new(WebGpuBackendConfig::default())
    }

    /// Entry point for fluent configuration via `WebGpuBackendBuilder`.
    pub fn builder() -> WebGpuBackendBuilder {
        WebGpuBackendBuilder::new()
    }

    /// Read-only access to the configuration supplied at construction.
    pub fn config(&self) -> &WebGpuBackendConfig {
        &self.config
    }

    /// Returns the device registered under `device_id`, or a
    /// `BackendError` if that device has not been initialized yet.
    pub fn get_device(&self, device_id: usize) -> BackendResult<Arc<WebGpuDevice>> {
        let devices = self.devices.read();
        devices
            .get(&device_id)
            .cloned()
            .ok_or_else(|| TorshError::BackendError(format!("Device {} not found", device_id)))
    }

    /// Returns the memory manager registered for `device_id`.
    pub fn get_memory_manager(
        &self,
        device_id: usize,
    ) -> BackendResult<Arc<RwLock<WebGpuMemoryManager>>> {
        let managers = self.memory_managers.read();
        managers.get(&device_id).cloned().ok_or_else(|| {
            TorshError::BackendError(format!("Memory manager for device {} not found", device_id))
        })
    }

    /// Returns the kernel executor registered for `device_id`.
    pub fn get_kernel_executor(
        &self,
        device_id: usize,
    ) -> BackendResult<Arc<WebGpuKernelExecutor>> {
        let executors = self.kernel_executors.read();
        executors.get(&device_id).cloned().ok_or_else(|| {
            TorshError::BackendError(format!(
                "Kernel executor for device {} not found",
                device_id
            ))
        })
    }

    /// Creates the WebGPU device for `device_id` (from the configured
    /// adapter index, or the best available adapter) and registers its
    /// memory manager and kernel executor.
    ///
    /// Each registry is locked in its own short scope so no two write
    /// locks are ever held at the same time.
    async fn initialize_device(&self, device_id: usize) -> BackendResult<Arc<WebGpuDevice>> {
        let device = if let Some(adapter_index) = self.config.adapter_index {
            WebGpuDevice::from_adapter_index(adapter_index, device_id).await
        } else {
            WebGpuDevice::from_best_adapter(device_id).await
        }
        .map_err(|e| TorshError::BackendError(e.to_string()))?;
        let device = Arc::new(device);
        // The memory pool uses default settings; nothing from `self.config`
        // is forwarded here.
        let memory_config = MemoryPoolConfig::default();
        let memory_manager = Arc::new(RwLock::new(WebGpuMemoryManager::new(
            Arc::clone(&device),
            memory_config,
        )));
        let kernel_executor = Arc::new(WebGpuKernelExecutor::new(Arc::clone(&device)));
        {
            let mut devices = self.devices.write();
            devices.insert(device_id, Arc::clone(&device));
        }
        {
            let mut managers = self.memory_managers.write();
            managers.insert(device_id, memory_manager);
        }
        {
            let mut executors = self.kernel_executors.write();
            executors.insert(device_id, kernel_executor);
        }
        Ok(device)
    }

    /// Maps a backend-local `WebGpuError` onto the crate-wide error type.
    fn convert_error(error: WebGpuError) -> TorshError {
        TorshError::BackendError(error.to_string())
    }

    /// Recovers a `&wgpu::Buffer` from the raw pointer stored in a
    /// `BufferHandle::WebGpu` variant.
    fn extract_webgpu_buffer(&self, buffer: &Buffer) -> BackendResult<&wgpu::Buffer> {
        match &buffer.handle {
            BufferHandle::WebGpu {
                buffer_ptr,
                size: _,
            } => {
                // SAFETY: assumes `buffer_ptr` originated from a live
                // `wgpu::Buffer` owned elsewhere that outlives `buffer`.
                // TODO(review): confirm the handle's provenance and
                // lifetime guarantees at the creation site.
                unsafe {
                    let wgpu_buffer_ptr = *buffer_ptr as *const wgpu::Buffer;
                    Ok(&*wgpu_buffer_ptr)
                }
            }
            _ => Err(TorshError::BackendError(
                "Buffer is not a WebGPU buffer".to_string(),
            )),
        }
    }

    /// Extracts both source and destination `wgpu::Buffer`s in one call.
    fn extract_webgpu_buffers(
        &self,
        src: &Buffer,
        dst: &Buffer,
    ) -> BackendResult<(&wgpu::Buffer, &wgpu::Buffer)> {
        let src_buf = self.extract_webgpu_buffer(src)?;
        let dst_buf = self.extract_webgpu_buffer(dst)?;
        Ok((src_buf, dst_buf))
    }
}
impl BackendCore for WebGpuBackend {
    /// Reports device type as slot 0 regardless of how many devices exist.
    fn device_type(&self) -> DeviceType {
        DeviceType::Wgpu(0)
    }

    /// Human-readable backend name.
    fn name(&self) -> &str {
        "WebGPU"
    }

    /// Delegates to the module-level availability probe.
    fn is_available(&self) -> BackendResult<bool> {
        Ok(crate::webgpu::is_available())
    }

    /// Static capability description.
    ///
    /// NOTE(review): these figures (2 GiB max buffer, 8 compute units,
    /// 100 GB/s bandwidth, 50 GFLOPS) are hardcoded estimates, not
    /// queried from the adapter — confirm before using them for
    /// scheduling decisions.
    fn capabilities(&self) -> crate::backend::BackendCapabilities {
        crate::backend::BackendCapabilities {
            max_buffer_size: 2_147_483_648, // 2 GiB
            max_compute_units: 8,
            max_workgroup_size: (256, 256, 64),
            supported_dtypes: vec![
                torsh_core::dtype::DType::F32,
                torsh_core::dtype::DType::I32,
                torsh_core::dtype::DType::U32,
            ],
            supports_async: true,
            supports_unified_memory: false,
            supports_sub_buffers: true,
            supports_kernel_caching: true,
            memory_bandwidth_gbps: 100.0,
            compute_throughput_gflops: 50.0,
            extended_capabilities: crate::backend::ExtendedCapabilities::default(),
        }
    }

    /// Conservative static tuning hints (64-wide workgroups, 256-byte
    /// alignment); also not queried from the adapter.
    fn performance_hints(&self) -> crate::backend::PerformanceHints {
        crate::backend::PerformanceHints {
            preferred_workgroup_size: (64, 1, 1),
            memory_alignment: 256,
            prefer_vectorized: true,
            prefer_async: true,
            optimal_batch_size: 256,
            cache_kernels: true,
        }
    }
}
#[async_trait::async_trait]
impl crate::backend::BackendLifecycle for WebGpuBackend {
    /// Initializes the global WebGPU runtime and device 0.
    ///
    /// Idempotent: returns early if a previous call already completed.
    /// (`&mut self` gives exclusive access, so the read-then-write on the
    /// flag cannot race.)
    async fn initialize(&mut self) -> BackendResult<()> {
        // The read guard is a temporary dropped at the end of the
        // condition, before any `.await`.
        if *self.initialized.read() {
            return Ok(());
        }
        crate::webgpu::init().await.map_err(Self::convert_error)?;
        self.initialize_device(0).await?;
        *self.initialized.write() = true;
        Ok(())
    }

    /// Drops all per-device state and clears the initialized flag.
    async fn shutdown(&mut self) -> BackendResult<()> {
        self.devices.write().clear();
        self.memory_managers.write().clear();
        self.kernel_executors.write().clear();
        *self.initialized.write() = false;
        Ok(())
    }

    /// True once `initialize` has completed and `shutdown` has not run.
    fn is_initialized(&self) -> bool {
        *self.initialized.read()
    }
}
impl crate::backend::BackendDeviceManager for WebGpuBackend {
    /// Lists every initialized device.
    ///
    /// Fixed: previously every entry was reported with id 0; each
    /// `Device` now carries the id it is registered under in the map.
    fn devices(&self) -> BackendResult<Vec<Device>> {
        let devices = self.devices.read();
        Ok(devices
            .iter()
            .map(|(&id, d)| {
                Device::new(
                    id,
                    d.device_type(),
                    d.name().to_string(),
                    d.info().clone(),
                )
            })
            .collect())
    }

    /// Returns device 0, the one `BackendLifecycle::initialize` sets up.
    fn default_device(&self) -> BackendResult<Device> {
        let webgpu_device = self.get_device(0)?;
        Ok(Device::new(
            0,
            webgpu_device.device_type(),
            webgpu_device.name().to_string(),
            webgpu_device.info().clone(),
        ))
    }

    /// Returns the device for `device_id`, initializing it lazily.
    ///
    /// Fixed: the previous code cloned the `Handle` of a freshly created
    /// `tokio::runtime::Runtime` and let the runtime drop at the end of
    /// the closure, so the subsequent `block_on` panicked with
    /// "runtime has been dropped". The temporary runtime is now kept
    /// alive for the duration of the call. When already inside a Tokio
    /// runtime, `block_in_place` is used because `Handle::block_on`
    /// panics if invoked on a runtime worker thread (requires the
    /// multi-thread runtime flavor).
    fn create_device(&self, device_id: usize) -> BackendResult<Device> {
        if let Ok(webgpu_device) = self.get_device(device_id) {
            return Ok(Device::new(
                device_id,
                webgpu_device.device_type(),
                webgpu_device.name().to_string(),
                webgpu_device.info().clone(),
            ));
        }
        let webgpu_device = match tokio::runtime::Handle::try_current() {
            Ok(handle) => tokio::task::block_in_place(|| {
                handle.block_on(self.initialize_device(device_id))
            })?,
            Err(_) => {
                let runtime = tokio::runtime::Runtime::new().map_err(|e| {
                    TorshError::BackendError(format!("Failed to create async runtime: {}", e))
                })?;
                // `runtime` stays alive until after `block_on` returns.
                runtime.block_on(self.initialize_device(device_id))?
            }
        };
        Ok(Device::new(
            device_id,
            webgpu_device.device_type(),
            webgpu_device.name().to_string(),
            webgpu_device.info().clone(),
        ))
    }

    /// Number of currently initialized devices.
    fn device_count(&self) -> BackendResult<usize> {
        Ok(self.devices.read().len())
    }

    /// True once `device_id` has been initialized; this does not probe
    /// the system for uninitialized adapters.
    fn is_device_available(&self, device_id: usize) -> bool {
        self.devices.read().contains_key(&device_id)
    }
}
impl crate::backend::BackendResourceManager for WebGpuBackend {
    /// Allocates a buffer through the device's memory manager.
    fn create_buffer(
        &self,
        device: &Device,
        descriptor: &BufferDescriptor,
    ) -> BackendResult<Buffer> {
        let memory_manager = self.get_memory_manager(device.id())?;
        let buffer = memory_manager.write().allocate(descriptor)?;
        Ok(buffer)
    }

    /// Compiles a kernel via the device's executor and wraps it in the
    /// generic `Kernel` type.
    ///
    /// NOTE(review): the compiled `_webgpu_kernel` is discarded; the
    /// returned handle holds only a formatted shader-module id with a
    /// fixed "main" entry point, and the metadata (compile time 0,
    /// binary size 0, kernel id 0) is placeholder data — confirm whether
    /// callers rely on any of these fields.
    fn create_kernel(
        &self,
        device: &Device,
        descriptor: &KernelDescriptor,
    ) -> BackendResult<Kernel> {
        let kernel_executor = self.get_kernel_executor(device.id())?;
        let _webgpu_kernel = kernel_executor
            .create_kernel(descriptor.clone())
            .map_err(Self::convert_error)?;
        let kernel_handle = KernelHandle::WebGpu {
            shader_module_id: format!("webgpu_shader_{}", descriptor.name),
            entry_point: "main".to_string(),
        };
        let kernel_metadata = crate::kernel::KernelMetadata {
            compile_time_ms: 0.0,
            binary_size: 0,
            registers_per_thread: None,
            shared_memory_usage: None,
            max_workgroup_size: descriptor.workgroup_size_hint,
            compiler_version: "wgpu".to_string(),
            warnings: Vec::new(),
            performance_hints: Vec::new(),
        };
        Ok(Kernel::new(
            0,
            device.clone(),
            descriptor.name.clone(),
            descriptor.clone(),
            kernel_handle,
            kernel_metadata,
        ))
    }

    /// Boxes the device's memory manager behind the object-safe trait.
    fn memory_manager(
        &self,
        device: &Device,
    ) -> BackendResult<Box<dyn MemoryManager + Send + Sync>> {
        let manager = self.get_memory_manager(device.id())?;
        Ok(Box::new(WebGpuMemoryManagerWrapper { inner: manager })
            as Box<dyn MemoryManager + Send + Sync>)
    }

    /// Hands out a clone of the shared profiler.
    fn profiler(&self) -> BackendResult<Box<dyn Profiler + Send + Sync>> {
        Ok(Box::new((*self.profiler).clone()) as Box<dyn Profiler + Send + Sync>)
    }

    /// Scoped buffers are not specialized; delegates to `create_buffer`.
    fn create_scoped_buffer(
        &self,
        device: &Device,
        descriptor: &BufferDescriptor,
    ) -> BackendResult<Buffer> {
        self.create_buffer(device, descriptor)
    }
}
#[async_trait::async_trait]
impl crate::backend::BackendExecutor for WebGpuBackend {
    /// Blocks until all submitted work on `device` completes.
    async fn synchronize(&self, device: &Device) -> BackendResult<()> {
        let webgpu_device = self.get_device(device.id())?;
        webgpu_device
            .wait_for_completion()
            .await
            .map_err(Self::convert_error)
    }

    /// STUB: creates a command encoder but never records or submits a
    /// copy, so no data is transferred. All arguments are ignored and
    /// the device id is hardcoded to 0. TODO(review): implement the
    /// actual buffer-to-buffer copy.
    async fn copy_buffer(
        &self,
        _src: &Buffer,
        _dst: &Buffer,
        _src_offset: usize,
        _dst_offset: usize,
        _size: usize,
    ) -> BackendResult<()> {
        let device_id = 0;
        let webgpu_device = self.get_device(device_id)?;
        let _encoder = webgpu_device.create_command_encoder(Some("Buffer Copy"));
        Ok(())
    }

    /// STUB: returns `Ok` without writing anything to the device.
    /// TODO(review): implement host-to-device upload.
    async fn copy_to_device(
        &self,
        _src: &[u8],
        _dst: &Buffer,
        _dst_offset: usize,
    ) -> BackendResult<()> {
        let device_id = 0;
        let _webgpu_device = self.get_device(device_id)?;
        Ok(())
    }

    /// STUB: returns `Ok` without reading anything back into `_dst`.
    /// TODO(review): implement device-to-host readback.
    async fn copy_from_device(
        &self,
        _src: &Buffer,
        _dst: &mut [u8],
        _src_offset: usize,
    ) -> BackendResult<()> {
        let device_id = 0;
        let _webgpu_device = self.get_device(device_id)?;
        Ok(())
    }

    /// Dispatches a kernel by name through the device-0 executor.
    ///
    /// NOTE(review): `_buffers` is ignored — an empty slice is passed to
    /// `execute_simple_kernel`, so only `uniform_data` reaches the
    /// shader. Device id is hardcoded to 0. Confirm this matches the
    /// executor's expectations.
    async fn execute_kernel(
        &self,
        kernel: &Kernel,
        _buffers: &[&Buffer],
        uniform_data: &[u8],
        workgroup_size: (u32, u32, u32),
        workgroup_count: (u32, u32, u32),
    ) -> BackendResult<()> {
        let device_id = 0;
        let kernel_executor = self.get_kernel_executor(device_id)?;
        match &kernel.handle {
            KernelHandle::WebGpu {
                shader_module_id: _,
                entry_point: _,
            } => {
                kernel_executor
                    .execute_simple_kernel(
                        &kernel.name,
                        &[],
                        uniform_data,
                        workgroup_size,
                        workgroup_count,
                    )
                    .await
                    .map_err(Self::convert_error)
            }
            _ => Err(TorshError::BackendError(
                "Invalid kernel handle for WebGPU backend".to_string(),
            )),
        }
    }
}
impl crate::backend::BackendOperations for WebGpuBackend {
    /// FFT is delegated to the CPU implementation (no WebGPU FFT yet).
    fn fft_ops(&self) -> Box<dyn crate::fft::FftOps> {
        Box::new(crate::cpu::fft::CpuFftOps::new(None))
    }

    /// Convolution is delegated to the CPU implementation.
    fn convolution_ops(&self) -> Box<dyn crate::convolution::ConvolutionOps> {
        Box::new(crate::cpu::convolution::CpuConvolutionOps::new(None))
    }

    /// RNN ops are delegated to the CPU implementation.
    fn rnn_ops(&self) -> Box<dyn crate::rnn::RnnOps> {
        Box::new(crate::cpu::rnn::CpuRnnOps::new(None))
    }

    /// Sparse ops use the default implementation bound to a synthetic
    /// device-0 descriptor.
    ///
    /// NOTE(review): `BackendOps::supports_sparse` returns `false` while
    /// this still hands out a sparse-ops object — confirm which is the
    /// intended contract.
    fn sparse_ops(&self) -> Box<dyn crate::sparse_ops::SparseOps<f32>> {
        Box::new(crate::sparse_ops::DefaultSparseOps::new(
            crate::Device::new(
                0,
                torsh_core::device::DeviceType::Wgpu(0),
                "WebGPU Device".to_string(),
                crate::DeviceInfo::default(),
            ),
        ))
    }

    /// Quantization is delegated to the CPU implementation.
    fn quantization_ops(&self) -> Box<dyn crate::quantization::QuantizationOps> {
        Box::new(crate::quantization::CpuQuantizationOps::new())
    }

    /// Bundles fresh instances of every ops object.
    fn operations_bundle(&self) -> crate::backend::OperationsBundle {
        crate::backend::OperationsBundle {
            fft: self.fft_ops(),
            convolution: self.convolution_ops(),
            rnn: self.rnn_ops(),
            quantization: self.quantization_ops(),
            sparse: self.sparse_ops(),
        }
    }
}
impl crate::backend::BackendOps for WebGpuBackend {
    fn backend_type(&self) -> crate::backend::BackendType {
        crate::backend::BackendType::WebGpu
    }

    /// Names of the operations this backend advertises.
    fn available_ops(&self) -> Vec<&str> {
        [
            "elementwise_add",
            "elementwise_mul",
            "elementwise_sub",
            "elementwise_div",
            "matmul",
            "conv2d",
            "relu",
            "softmax",
            "batch_norm",
            "reduction",
        ]
        .to_vec()
    }

    /// True when `op_name` is one of the advertised operations.
    fn supports_op(&self, op_name: &str) -> bool {
        self.available_ops().iter().any(|&op| op == op_name)
    }

    fn supports_fft(&self) -> bool {
        true
    }

    fn supports_convolution(&self) -> bool {
        true
    }

    fn supports_rnn(&self) -> bool {
        true
    }

    fn supports_sparse(&self) -> bool {
        false
    }

    fn supports_quantization(&self) -> bool {
        true
    }

    /// No per-operation capability metadata is published.
    fn operation_capabilities(
        &self,
        _op_name: &str,
    ) -> Option<std::collections::HashMap<String, crate::backend::CapabilityValue>> {
        None
    }
}
/// Umbrella trait implementation: each accessor returns `self` viewed
/// through one of the capability traits implemented above.
impl crate::backend::Backend for WebGpuBackend {
    fn as_core(&self) -> &dyn crate::backend::BackendCore {
        self
    }
    fn as_lifecycle(&mut self) -> &mut dyn crate::backend::BackendLifecycle {
        self
    }
    fn as_device_manager(&self) -> &dyn crate::backend::BackendDeviceManager {
        self
    }
    fn as_resource_manager(&self) -> &dyn crate::backend::BackendResourceManager {
        self
    }
    fn as_executor(&self) -> &dyn crate::backend::BackendExecutor {
        self
    }
    fn as_operations(&self) -> &dyn crate::backend::BackendOperations {
        self
    }
}
/// Fluent builder for `WebGpuBackend`; accumulates settings into a
/// `WebGpuBackendConfig` and hands them to `WebGpuBackend::new`.
#[derive(Debug)]
pub struct WebGpuBackendBuilder {
    // Configuration being assembled; starts from `default()`.
    config: WebGpuBackendConfig,
}
impl WebGpuBackendBuilder {
    /// Creates a builder seeded with `WebGpuBackendConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: WebGpuBackendConfig::default(),
        }
    }

    /// Selects a specific adapter by index instead of best-adapter search.
    pub fn adapter_index(mut self, index: usize) -> Self {
        self.config.adapter_index = Some(index);
        self
    }

    /// Sets the device id.
    ///
    /// NOTE(review): this writes `adapter_index`, exactly like
    /// `adapter_index()` — there is no separate device-id field in the
    /// config. Confirm the alias is intentional.
    pub fn device_id(mut self, id: usize) -> Self {
        self.config.adapter_index = Some(id);
        self
    }

    /// Sets the wgpu power preference used during adapter selection.
    pub fn power_preference(mut self, preference: wgpu::PowerPreference) -> Self {
        self.config.power_preference = preference;
        self
    }

    /// Enables or disables debug mode.
    pub fn debug_mode(mut self, enable: bool) -> Self {
        self.config.debug_mode = enable;
        self
    }

    /// Sets the maximum buffer size in bytes.
    pub fn max_buffer_size(mut self, size: u64) -> Self {
        self.config.max_buffer_size = size;
        self
    }

    /// Enables or disables the pipeline cache.
    pub fn enable_pipeline_cache(mut self, enable: bool) -> Self {
        self.config.enable_pipeline_cache = enable;
        self
    }

    /// Sets the preferred compute workgroup size.
    pub fn preferred_workgroup_size(mut self, size: (u32, u32, u32)) -> Self {
        self.config.preferred_workgroup_size = size;
        self
    }

    /// Consumes the builder and constructs the backend.
    pub fn build(self) -> WebGpuBackend {
        WebGpuBackend::new(self.config)
    }
}

/// `Default` mirrors `new()` (addresses clippy `new_without_default`).
impl Default for WebGpuBackendBuilder {
    fn default() -> Self {
        Self::new()
    }
}
/// Adapts an `Arc<RwLock<WebGpuMemoryManager>>` to the object-safe
/// `MemoryManager` trait so it can be boxed and handed to callers.
#[derive(Debug)]
pub struct WebGpuMemoryManagerWrapper {
    // Shared handle to the underlying per-device memory manager.
    inner: Arc<RwLock<WebGpuMemoryManager>>,
}
impl MemoryManager for WebGpuMemoryManagerWrapper {
    /// Allocates a buffer from the underlying buffer pool and wraps it
    /// in the generic `Buffer` type.
    ///
    /// NOTE(review): the returned buffer is tagged with a synthetic
    /// device-0 descriptor rather than the actual device the pool
    /// belongs to — confirm this is acceptable to callers.
    fn allocate(
        &mut self,
        descriptor: &BufferDescriptor,
    ) -> torsh_core::error::Result<crate::Buffer> {
        let webgpu_buffer = self
            .inner
            .read()
            .buffer_pool()
            .get_buffer(descriptor.clone())
            .map_err(|e| TorshError::BackendError(e.to_string()))?;
        let handle = webgpu_buffer.handle().clone();
        let buffer = crate::Buffer::new(
            generate_buffer_id(),
            crate::Device::new(
                0,
                torsh_core::device::DeviceType::Wgpu(0),
                "WebGPU Device".to_string(),
                crate::DeviceInfo::default(),
            ),
            webgpu_buffer.descriptor().size as usize,
            descriptor.usage.clone(),
            descriptor.clone(),
            handle,
        );
        Ok(buffer)
    }

    /// No-op: the buffer is not returned to the pool here.
    /// TODO(review): confirm reclamation happens elsewhere (e.g. Drop).
    fn deallocate(&mut self, _buffer: &crate::Buffer) -> torsh_core::error::Result<()> {
        Ok(())
    }

    /// Delegates to the underlying manager's statistics.
    fn stats(&self) -> MemoryStats {
        self.inner.read().stats()
    }

    /// STUB: reports zero bytes collected without doing any work.
    fn garbage_collect(&mut self) -> torsh_core::error::Result<usize> {
        Ok(0)
    }

    /// The WebGPU pool is fixed; swapping it in is unsupported.
    fn set_pool(
        &mut self,
        _pool: Box<dyn crate::memory::MemoryPool>,
    ) -> torsh_core::error::Result<()> {
        Err(TorshError::BackendError(
            "WebGPU memory pool cannot be replaced".to_string(),
        ))
    }

    /// Returns a process-wide synthetic device-0 descriptor, created
    /// lazily once via `OnceLock`.
    fn device(&self) -> &crate::Device {
        static WEBGPU_DEVICE: std::sync::OnceLock<crate::Device> = std::sync::OnceLock::new();
        WEBGPU_DEVICE.get_or_init(|| {
            crate::Device::new(
                0,
                torsh_core::device::DeviceType::Wgpu(0),
                "WebGPU Device".to_string(),
                crate::DeviceInfo::default(),
            )
        })
    }

    /// Raw host-pointer allocation is not available on WebGPU.
    fn allocate_raw(
        &mut self,
        _size: usize,
        _alignment: usize,
    ) -> torsh_core::error::Result<*mut u8> {
        Err(TorshError::BackendError(
            "WebGPU doesn't support raw memory allocation".to_string(),
        ))
    }

    /// Raw host-pointer deallocation is not available on WebGPU.
    fn deallocate_raw(&mut self, _ptr: *mut u8, _size: usize) -> torsh_core::error::Result<()> {
        Err(TorshError::BackendError(
            "WebGPU doesn't support raw memory deallocation".to_string(),
        ))
    }

    fn supports_unified_memory(&self) -> bool {
        false
    }

    /// Unified memory is not available on WebGPU.
    fn allocate_unified(&mut self, _size: usize) -> torsh_core::error::Result<*mut u8> {
        Err(TorshError::BackendError(
            "WebGPU doesn't support unified memory allocation".to_string(),
        ))
    }

    /// Unified memory is not available on WebGPU.
    fn deallocate_unified(&mut self, _ptr: *mut u8, _size: usize) -> torsh_core::error::Result<()> {
        Err(TorshError::BackendError(
            "WebGPU doesn't support unified memory deallocation".to_string(),
        ))
    }

    /// No-op: prefetch hints have no WebGPU equivalent.
    fn prefetch_to_device(&self, _ptr: *mut u8, _size: usize) -> torsh_core::error::Result<()> {
        Ok(())
    }

    /// No-op: prefetch hints have no WebGPU equivalent.
    fn prefetch_to_host(&self, _ptr: *mut u8, _size: usize) -> torsh_core::error::Result<()> {
        Ok(())
    }

    /// No-op: memory advice has no WebGPU equivalent.
    fn set_memory_advice(
        &self,
        _ptr: *mut u8,
        _size: usize,
        _advice: crate::memory::MemoryAdvice,
    ) -> torsh_core::error::Result<()> {
        Ok(())
    }

    /// Hardcoded 1 GiB — not queried from the adapter.
    /// TODO(review): report the real available memory when possible.
    fn available_memory(&self) -> torsh_core::error::Result<usize> {
        Ok(1024 * 1024 * 1024)
    }

    /// Hardcoded 4 GiB — not queried from the adapter.
    fn total_memory(&self) -> torsh_core::error::Result<usize> {
        Ok(4 * 1024 * 1024 * 1024)
    }

    /// No-op: synchronization is handled at the backend level.
    fn synchronize(&self) -> torsh_core::error::Result<()> {
        Ok(())
    }

    /// STUB: reports a successful defragmentation without moving data.
    fn defragment(&mut self) -> torsh_core::error::Result<crate::memory::DefragmentationResult> {
        Ok(crate::memory::DefragmentationResult {
            blocks_moved: 0,
            memory_compacted: 0,
            duration_ms: 0.0,
            fragmentation_before: 0.0,
            fragmentation_after: 0.0,
            efficiency_improvement: 0.0,
            success: true,
        })
    }

    fn needs_defragmentation(&self) -> bool {
        false
    }

    /// STUB: returns idealized fragmentation figures (one 1 GiB free
    /// block, nothing allocated) rather than measured values.
    fn fragmentation_info(&self) -> crate::memory::FragmentationInfo {
        crate::memory::FragmentationInfo {
            overall_fragmentation: 0.0,
            external_fragmentation: 0.0,
            internal_fragmentation: 0.0,
            free_blocks: 1,
            allocated_blocks: 0,
            largest_free_block: 1024 * 1024 * 1024,
            smallest_free_block: 1024 * 1024 * 1024,
            average_free_block: 1024 * 1024 * 1024,
            total_free_memory: 1024 * 1024 * 1024,
            total_allocated_memory: 0,
            utilization_efficiency: 1.0,
            allocation_efficiency: 1.0,
        }
    }

    /// STUB: reports a successful compaction without moving data.
    fn compact_memory(&mut self) -> torsh_core::error::Result<crate::memory::CompactionResult> {
        Ok(crate::memory::CompactionResult {
            allocations_moved: 0,
            bytes_moved: 0,
            duration_ms: 0.0,
            largest_free_before: 1024 * 1024 * 1024,
            largest_free_after: 1024 * 1024 * 1024,
            free_blocks_before: 1,
            free_blocks_after: 1,
            success: true,
        })
    }

    /// No-op: the policy is accepted and discarded.
    fn set_defragmentation_policy(&mut self, _policy: crate::memory::DefragmentationPolicy) {
    }
}
/// Marker type reserved for WebGPU-native RNN operations.
///
/// Currently stateless; RNN work is routed to the CPU fallback
/// (`CpuRnnOps`) by `BackendOperations::rnn_ops`. Derives added so the
/// public type is debuggable and satisfies clippy `new_without_default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct WebGpuRnnOps;

impl WebGpuRnnOps {
    /// Creates a new handle; equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self
    }
}
/// Marker type reserved for WebGPU-native quantization operations.
///
/// Currently stateless; quantization is routed to the CPU fallback
/// (`CpuQuantizationOps`) by `BackendOperations::quantization_ops`.
/// Derives added so the public type is debuggable and satisfies clippy
/// `new_without_default`.
#[derive(Debug, Default, Clone, Copy)]
pub struct WebGpuQuantizationOps;

impl WebGpuQuantizationOps {
    /// Creates a new handle; equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{BackendCore, BackendDeviceManager, BackendLifecycle, BackendOps};
    use crate::BackendType;
    use torsh_core::DType;

    /// Backend identity: name and device type are fixed constants.
    #[test]
    fn test_backend_creation() {
        let backend = WebGpuBackend::with_default_config();
        assert_eq!(backend.name(), "WebGPU");
        assert_eq!(backend.device_type(), DeviceType::Wgpu(0));
    }

    /// Every builder setter round-trips into the resulting config.
    #[test]
    fn test_backend_builder() {
        let backend = WebGpuBackendBuilder::new()
            .adapter_index(0)
            .power_preference(wgpu::PowerPreference::HighPerformance)
            .debug_mode(true)
            .max_buffer_size(2 * 1024 * 1024 * 1024) // 2 GiB
            .enable_pipeline_cache(true)
            .preferred_workgroup_size((128, 1, 1))
            .build();
        assert_eq!(backend.config().adapter_index, Some(0));
        assert_eq!(
            backend.config().power_preference,
            wgpu::PowerPreference::HighPerformance
        );
        assert!(backend.config().debug_mode);
        assert_eq!(backend.config().max_buffer_size, 2 * 1024 * 1024 * 1024);
        assert!(backend.config().enable_pipeline_cache);
        assert_eq!(backend.config().preferred_workgroup_size, (128, 1, 1));
    }

    /// Availability probe must not error; either outcome is acceptable
    /// since CI machines may lack a GPU.
    #[tokio::test]
    async fn test_backend_availability() {
        let backend = WebGpuBackend::with_default_config();
        match backend.is_available() {
            Ok(available) => {
                if available {
                    println!("WebGPU backend is available");
                } else {
                    println!("WebGPU backend is not available");
                }
            }
            Err(e) => {
                println!("Error checking WebGPU availability: {}", e);
            }
        }
    }

    /// Full init/shutdown cycle; only runs when a WebGPU adapter exists,
    /// and tolerates init failure (headless CI).
    #[tokio::test]
    async fn test_backend_initialization() {
        if cfg!(feature = "webgpu") && crate::webgpu::is_available() {
            let mut backend = WebGpuBackend::with_default_config();
            let result = backend.initialize().await;
            if result.is_ok() {
                assert!(*backend.initialized.read());
                let device_result = backend.default_device();
                if device_result.is_ok() {
                    let device = device_result.expect("operation should succeed");
                    assert_eq!(device.device_type(), DeviceType::Wgpu(0));
                }
                let shutdown_result = backend.shutdown().await;
                assert!(shutdown_result.is_ok());
                assert!(!*backend.initialized.read());
            }
        }
    }

    /// `supports_op` agrees with the advertised `available_ops` list.
    #[test]
    fn test_backend_ops() {
        let backend = WebGpuBackend::with_default_config();
        assert_eq!(backend.backend_type(), BackendType::WebGpu);
        assert!(backend.supports_op("elementwise_add"));
        assert!(backend.supports_op("matmul"));
        assert!(backend.supports_op("conv2d"));
        assert!(!backend.supports_op("nonexistent_op"));
        let ops = backend.available_ops();
        assert!(!ops.is_empty());
        assert!(ops.contains(&"elementwise_add"));
    }

    /// Spot-checks the static capability flags.
    #[test]
    fn test_capabilities() {
        let backend = WebGpuBackend::with_default_config();
        let capabilities = backend.capabilities();
        assert!(capabilities.supported_dtypes.contains(&DType::F32));
        assert!(capabilities.supports_async);
        assert!(capabilities.supports_kernel_caching);
    }

    /// Spot-checks the static performance hints.
    #[test]
    fn test_performance_hints() {
        let backend = WebGpuBackend::with_default_config();
        let hints = backend.performance_hints();
        assert_eq!(hints.preferred_workgroup_size, (64, 1, 1));
        assert_eq!(hints.memory_alignment, 256);
        assert!(hints.prefer_vectorized);
        assert!(hints.prefer_async);
        assert!(hints.cache_kernels);
    }
}