use crate::error::{NumRs2Error, Result};
use std::sync::Arc;
use wgpu::util::DeviceExt;
/// Shared, atomically reference-counted handle to a [`GpuContext`].
pub type GpuContextRef = Arc<GpuContext>;
/// Bundles a wgpu device, its command queue and the shader modules compiled
/// for that device.
///
/// Fields are private; read access goes through the accessor methods on the
/// type.
// The struct is forced to an alignment of at least 64 bytes. The reason is not
// stated in this file. NOTE(review): possibly meant to match a cache line;
// confirm before changing.
#[repr(align(64))]
pub struct GpuContext {
// Logical device used to create buffers, shader modules and command encoders.
device: wgpu::Device,
// Queue that receives the finished command buffers.
queue: wgpu::Queue,
// Shader modules compiled once when the context is constructed.
shader_modules: ShaderModules,
}
/// Compiled shader modules held by a [`GpuContext`], one per operation family
/// and element type.
///
/// The `*_f64` entries are stubs: `GpuContext::create_shader_modules` builds
/// each of them from the same WGSL file as the matching `*_f32` entry.
struct ShaderModules {
// Element-wise operations, f32.
element_wise_f32: wgpu::ShaderModule,
// Element-wise operations, f64 slot (stub, built from the f32 source).
element_wise_f64: wgpu::ShaderModule,
// Reductions, f32.
reduction_f32: wgpu::ShaderModule,
// Reductions, f64 slot (stub, built from the f32 source).
reduction_f64: wgpu::ShaderModule,
// Matrix multiplication, f32.
matmul_f32: wgpu::ShaderModule,
// Matrix multiplication, f64 slot (stub, built from the f32 source).
matmul_f64: wgpu::ShaderModule,
}
impl GpuContext {
pub async fn new() -> Result<Self> {
let adapter = wgpu::Instance::default()
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.map_err(|e| {
NumRs2Error::RuntimeError(format!(
"Failed to find an appropriate GPU adapter: {}",
e
))
})?;
let info = adapter.get_info();
println!("Selected GPU: {} ({:?})", info.name, info.backend);
let (device, queue) = adapter
.request_device(&wgpu::DeviceDescriptor {
label: Some("NumRS2 GPU device"),
required_features: wgpu::Features::empty(),
required_limits: wgpu::Limits::default(),
memory_hints: wgpu::MemoryHints::Performance,
trace: wgpu::Trace::default(),
experimental_features: wgpu::ExperimentalFeatures::default(),
})
.await
.map_err(|e| {
NumRs2Error::RuntimeError(format!("Failed to create GPU device: {}", e))
})?;
let shader_modules = Self::create_shader_modules(&device)?;
Ok(Self {
device,
queue,
shader_modules,
})
}
fn create_shader_modules(device: &wgpu::Device) -> Result<ShaderModules> {
let element_wise_f32 = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Element-wise F32 Shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/element_wise_f32.wgsl").into()),
});
let element_wise_f64 = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Element-wise F64 Shader stub"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/element_wise_f32.wgsl").into()),
});
let reduction_f32 = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Reduction F32 Shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/reduction_f32.wgsl").into()),
});
let reduction_f64 = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Reduction F64 Shader stub"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/reduction_f32.wgsl").into()),
});
let matmul_f32 = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Matrix Multiplication F32 Shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/matmul_f32.wgsl").into()),
});
let matmul_f64 = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Matrix Multiplication F64 Shader stub"),
source: wgpu::ShaderSource::Wgsl(include_str!("shaders/matmul_f32.wgsl").into()),
});
Ok(ShaderModules {
element_wise_f32,
element_wise_f64,
reduction_f32,
reduction_f64,
matmul_f32,
matmul_f64,
})
}
pub fn device(&self) -> &wgpu::Device {
&self.device
}
pub fn queue(&self) -> &wgpu::Queue {
&self.queue
}
pub fn element_wise_f32_shader(&self) -> &wgpu::ShaderModule {
&self.shader_modules.element_wise_f32
}
pub fn element_wise_f64_shader(&self) -> &wgpu::ShaderModule {
&self.shader_modules.element_wise_f64
}
pub fn reduction_f32_shader(&self) -> &wgpu::ShaderModule {
&self.shader_modules.reduction_f32
}
pub fn reduction_f64_shader(&self) -> &wgpu::ShaderModule {
&self.shader_modules.reduction_f64
}
pub fn matmul_f32_shader(&self) -> &wgpu::ShaderModule {
&self.shader_modules.matmul_f32
}
pub fn matmul_f64_shader(&self) -> &wgpu::ShaderModule {
&self.shader_modules.matmul_f64
}
pub fn create_buffer<T: bytemuck::Pod + bytemuck::Zeroable>(
&self,
data: &[T],
usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
self.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("NumRS2 GPU Buffer"),
contents: bytemuck::cast_slice(data),
usage,
})
}
pub fn create_empty_buffer(&self, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer {
self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("NumRS2 GPU Buffer"),
size,
usage,
mapped_at_creation: false,
})
}
pub fn run_compute(
&self,
compute_pipeline: &wgpu::ComputePipeline,
bind_groups: &[&wgpu::BindGroup],
workgroup_count: (u32, u32, u32),
) {
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("NumRS2 Compute Encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("NumRS2 Compute Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(compute_pipeline);
for (i, bind_group) in bind_groups.iter().enumerate() {
compute_pass.set_bind_group(i as u32, *bind_group, &[]);
}
compute_pass.dispatch_workgroups(
workgroup_count.0,
workgroup_count.1,
workgroup_count.2,
);
}
self.queue.submit(std::iter::once(encoder.finish()));
}
}
/// Synchronously builds a [`GpuContext`] and wraps it in an [`Arc`].
///
/// A single-threaded tokio runtime is created for this call only and used to
/// drive [`GpuContext::new`] to completion on the current thread.
///
/// # Errors
///
/// Returns [`NumRs2Error::RuntimeError`] if the runtime cannot be built, and
/// forwards any error produced by [`GpuContext::new`].
///
/// NOTE(review): tokio documents that `block_on` panics when invoked from
/// within an asynchronous execution context. Confirm that callers only use
/// this function from synchronous code.
pub fn new_context() -> Result<GpuContextRef> {
    let runtime = match tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
    {
        Ok(runtime) => runtime,
        Err(err) => {
            return Err(NumRs2Error::RuntimeError(format!(
                "Failed to create async runtime: {}",
                err
            )));
        }
    };
    runtime.block_on(GpuContext::new()).map(Arc::new)
}