use std::sync::Arc;
use std::time::Duration;
use wgpu::{
Adapter, Buffer, BufferDescriptor, BufferUsages, CommandEncoder, ComputePipeline, Device,
DeviceDescriptor, Instance, Limits, PowerPreference, Queue, RequestAdapterOptions,
};
/// Shared handle to a wgpu device/queue pair plus cached adapter
/// capabilities, used by the compute pipelines in this crate.
pub struct GpuDevice {
    /// Logical device; `Arc` so buffers/pipelines can share it across threads.
    pub device: Arc<Device>,
    /// Command queue for the device; `Arc` for the same reason.
    pub queue: Arc<Queue>,
    /// Adapter the device was created from; kept for name/backend queries.
    pub adapter: Adapter,
    /// Limits reported by the adapter at creation time.
    pub limits: Limits,
    /// Whether `wgpu::Features::SUBGROUP` was advertised and requested.
    pub subgroups_supported: bool,
    /// Minimum subgroup size from the adapter limits; 0 when subgroups are unsupported.
    pub min_subgroup_size: u32,
    /// Maximum subgroup size from the adapter limits; 0 when subgroups are unsupported.
    pub max_subgroup_size: u32,
}
impl GpuDevice {
pub fn new() -> Option<Self> {
pollster::block_on(Self::new_async())
}
async fn new_async() -> Option<Self> {
let instance = Instance::default();
let adapter = instance
.request_adapter(&RequestAdapterOptions {
power_preference: PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
})
.await
.ok()?;
let _info = adapter.get_info();
let limits = adapter.limits();
let adapter_features = adapter.features();
let subgroups_supported = adapter_features.contains(wgpu::Features::SUBGROUP);
let required_features = if subgroups_supported {
wgpu::Features::SUBGROUP
} else {
wgpu::Features::empty()
};
let (device, queue) = match adapter
.request_device(&DeviceDescriptor {
label: Some("TreeBoost GPU Device"),
required_features,
required_limits: Limits::default(),
memory_hints: wgpu::MemoryHints::Performance,
trace: wgpu::Trace::Off,
experimental_features: wgpu::ExperimentalFeatures::default(),
})
.await
{
Ok(result) => result,
Err(_) => return None,
};
let (min_subgroup_size, max_subgroup_size) = if subgroups_supported {
(limits.min_subgroup_size, limits.max_subgroup_size)
} else {
(0, 0)
};
Some(Self {
device: Arc::new(device),
queue: Arc::new(queue),
adapter,
limits,
subgroups_supported,
min_subgroup_size,
max_subgroup_size,
})
}
pub fn name(&self) -> String {
self.adapter.get_info().name
}
pub fn backend(&self) -> wgpu::Backend {
self.adapter.get_info().backend
}
pub fn create_storage_buffer(&self, label: &str, size: u64, read_write: bool) -> Buffer {
let usage = if read_write {
BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC
} else {
BufferUsages::STORAGE | BufferUsages::COPY_DST
};
self.device.create_buffer(&BufferDescriptor {
label: Some(label),
size,
usage,
mapped_at_creation: false,
})
}
pub fn create_uniform_buffer(&self, label: &str, size: u64) -> Buffer {
self.device.create_buffer(&BufferDescriptor {
label: Some(label),
size,
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
mapped_at_creation: false,
})
}
pub fn create_staging_buffer(&self, label: &str, size: u64) -> Buffer {
self.device.create_buffer(&BufferDescriptor {
label: Some(label),
size,
usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
mapped_at_creation: false,
})
}
pub fn write_buffer<T: bytemuck::Pod>(&self, buffer: &Buffer, data: &[T]) {
self.queue
.write_buffer(buffer, 0, bytemuck::cast_slice(data));
}
pub fn create_encoder(&self, label: &str) -> CommandEncoder {
self.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) })
}
pub fn submit_and_wait(&self, encoder: CommandEncoder) {
let submission = self.queue.submit(std::iter::once(encoder.finish()));
let _ = self.device.poll(wgpu::PollType::Wait {
submission_index: Some(submission),
timeout: Some(Duration::from_secs(60)),
});
}
pub fn submit_async(&self, encoder: CommandEncoder) -> wgpu::SubmissionIndex {
self.queue.submit(std::iter::once(encoder.finish()))
}
pub fn wait_for_submission(&self, submission: wgpu::SubmissionIndex) {
let _ = self.device.poll(wgpu::PollType::Wait {
submission_index: Some(submission),
timeout: Some(Duration::from_secs(60)),
});
}
pub fn poll(&self) -> bool {
self.device
.poll(wgpu::PollType::Poll)
.map(|status| status.is_queue_empty())
.unwrap_or(false)
}
pub fn read_buffer<T: bytemuck::Pod>(&self, staging: &Buffer, output: &mut [T]) {
let slice = staging.slice(..);
slice.map_async(wgpu::MapMode::Read, |_| {});
let _ = self.device.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: Some(Duration::from_secs(60)),
});
{
let data = slice.get_mapped_range();
let src: &[T] = bytemuck::cast_slice(&data);
output.copy_from_slice(&src[..output.len()]);
}
staging.unmap();
}
pub fn read_buffer_partial<T: bytemuck::Pod>(&self, staging: &Buffer, output: &mut [T]) {
let byte_size = (output.len() * std::mem::size_of::<T>()) as u64;
let slice = staging.slice(..byte_size);
slice.map_async(wgpu::MapMode::Read, |_| {});
let _ = self.device.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: Some(Duration::from_secs(60)),
});
{
let data = slice.get_mapped_range();
let src: &[T] = bytemuck::cast_slice(&data);
output.copy_from_slice(src);
}
staging.unmap();
}
pub fn create_compute_pipeline(
&self,
label: &str,
shader_source: &str,
entry_point: &str,
) -> ComputePipeline {
let shader = self
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(label),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});
self.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(label),
layout: None, module: &shader,
entry_point: Some(entry_point),
compilation_options: Default::default(),
cache: None,
})
}
pub fn try_create_compute_pipeline(
&self,
label: &str,
shader_source: &str,
entry_point: &str,
) -> Option<ComputePipeline> {
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
self.create_compute_pipeline(label, shader_source, entry_point)
}))
.ok()
}
pub fn max_workgroup_size(&self) -> u32 {
self.limits.max_compute_workgroup_size_x
}
pub fn max_storage_buffer_size(&self) -> u64 {
self.limits.max_storage_buffer_binding_size as u64
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: acquire a device when a GPU is present and sanity-check
    /// its reported compute limit; skips quietly on headless machines.
    #[test]
    fn test_gpu_device_creation() {
        match GpuDevice::new() {
            Some(gpu) => {
                println!(
                    "GPU device created: {} ({:?})",
                    gpu.name(),
                    gpu.backend()
                );
                assert!(gpu.max_workgroup_size() >= 256);
            }
            None => println!("No GPU available, skipping test"),
        }
    }
}