use std::sync::Arc;
use tracing::{debug, trace};
use wgpu::{CommandEncoder, Device, Queue};
use super::buffer::{GpuBuffer, GpuBufferPool};
use super::device::GpuDevice;
use super::error::{GpuError, GpuResult};
use super::pipeline::{ComputePipeline, PipelineCache};
/// Options that control how a compute dispatch is labelled and whether the
/// caller waits for it after submission.
///
/// The derived [`Default`] yields no label, `wait == false`, and
/// `timeout_ms == 0`.
#[derive(Debug, Clone, Default)]
pub struct DispatchConfig {
    /// Label applied to the command encoder and compute pass. When `None`,
    /// each dispatch method substitutes its own fallback label.
    pub label: Option<String>,
    /// When `true`, the dispatch methods call `device.poll(true)` right after
    /// submitting the work.
    pub wait: bool,
    /// Timeout in milliseconds.
    ///
    /// NOTE(review): none of the dispatch methods visible in this file read
    /// this field — confirm whether it is consumed elsewhere.
    pub timeout_ms: u64,
}
impl DispatchConfig {
    /// Returns a configuration with `wait` enabled and every other field at
    /// its default value.
    pub fn wait() -> Self {
        Self::default().with_wait(true)
    }

    /// Returns a configuration carrying `label`, with every other field at
    /// its default value.
    pub fn with_label(label: impl Into<String>) -> Self {
        let label = Some(label.into());
        Self {
            label,
            ..Self::default()
        }
    }

    /// Consumes the configuration and returns it with `timeout_ms` replaced.
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self { timeout_ms, ..self }
    }

    /// Consumes the configuration and returns it with `wait` replaced.
    pub fn with_wait(self, wait: bool) -> Self {
        Self { wait, ..self }
    }
}
/// Records compute passes and submits them through a shared [`GpuDevice`].
///
/// Also owns a [`PipelineCache`] and a [`GpuBufferPool`], both constructed
/// from the device's `device_arc()` handle in `GpuDispatcher::new`.
pub struct GpuDispatcher {
/// Shared device wrapper used for encoder creation, submission and polling.
device: Arc<GpuDevice>,
/// Pipeline cache exposed through `pipeline_cache()`.
pipeline_cache: PipelineCache,
/// Buffer pool exposed through `buffer_pool()`.
buffer_pool: GpuBufferPool,
}
impl GpuDispatcher {
    /// Creates a dispatcher for `device`, building a pipeline cache and a
    /// buffer pool from the device's `device_arc()` handle.
    pub fn new(device: Arc<GpuDevice>) -> Self {
        let pipeline_cache = PipelineCache::new(device.device_arc());
        let buffer_pool = GpuBufferPool::new(device.device_arc());
        Self {
            device,
            pipeline_cache,
            buffer_pool,
        }
    }

    /// Returns the device wrapper this dispatcher submits work through.
    pub fn device(&self) -> &GpuDevice {
        &self.device
    }

    /// Returns the pipeline cache owned by this dispatcher.
    pub fn pipeline_cache(&self) -> &PipelineCache {
        &self.pipeline_cache
    }

    /// Returns the buffer pool owned by this dispatcher.
    pub fn buffer_pool(&self) -> &GpuBufferPool {
        &self.buffer_pool
    }

    /// Checks `workgroups` against the per-dimension limits reported by
    /// `device.info().max_workgroups`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidWorkgroupSize`] if any dimension exceeds
    /// its limit.
    fn check_workgroup_limits(&self, workgroups: [u32; 3]) -> GpuResult<()> {
        let limits = &self.device.info().max_workgroups;
        if workgroups[0] > limits[0] || workgroups[1] > limits[1] || workgroups[2] > limits[2] {
            return Err(GpuError::InvalidWorkgroupSize {
                x: workgroups[0],
                y: workgroups[1],
                z: workgroups[2],
            });
        }
        Ok(())
    }

    /// Creates a command encoder carrying `label`.
    fn new_encoder(&self, label: &str) -> CommandEncoder {
        self.device
            .device()
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(label) })
    }

    /// Finishes `encoder`, submits it, and calls `device.poll(true)` when
    /// `wait` is set.
    fn submit_encoder(&self, encoder: CommandEncoder, wait: bool) {
        self.device.submit(encoder.finish());
        if wait {
            self.device.poll(true);
        }
    }

    /// Dispatches one compute pass using [`DispatchConfig::default`].
    ///
    /// # Errors
    ///
    /// See [`GpuDispatcher::dispatch_with_config`].
    pub async fn dispatch(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroups: [u32; 3],
    ) -> GpuResult<()> {
        self.dispatch_with_config(pipeline, bind_group, workgroups, DispatchConfig::default())
            .await
    }

    /// Records a single compute pass (`bind_group` at index 0) and submits it.
    ///
    /// The encoder and pass are labelled with `config.label`, or `"dispatch"`
    /// when no label is set. If `config.wait` is set, `device.poll(true)` is
    /// called after submission. `config.timeout_ms` is not consulted.
    ///
    /// NOTE(review): the wait is a direct call inside this `async fn`; if
    /// `poll(true)` blocks the thread it will stall the executor — confirm.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidWorkgroupSize`] if any dimension of
    /// `workgroups` exceeds `device.info().max_workgroups`; nothing is
    /// recorded or submitted in that case.
    pub async fn dispatch_with_config(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroups: [u32; 3],
        config: DispatchConfig,
    ) -> GpuResult<()> {
        self.check_workgroup_limits(workgroups)?;

        let label = config.label.as_deref().unwrap_or("dispatch");
        debug!(
            "Dispatching '{}' with workgroups [{}, {}, {}]",
            label, workgroups[0], workgroups[1], workgroups[2]
        );

        let mut encoder = self.new_encoder(label);
        self.record_dispatch(&mut encoder, pipeline, bind_group, workgroups, Some(label));
        self.submit_encoder(encoder, config.wait);
        Ok(())
    }

    /// Dispatches one indirect compute pass at offset 0 using
    /// [`DispatchConfig::default`].
    pub async fn dispatch_indirect(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        indirect_buffer: &GpuBuffer,
    ) -> GpuResult<()> {
        self.dispatch_indirect_with_config(
            pipeline,
            bind_group,
            indirect_buffer,
            0,
            DispatchConfig::default(),
        )
        .await
    }

    /// Records a single indirect compute pass and submits it.
    ///
    /// The workgroup counts are taken from `indirect_buffer` starting at
    /// `indirect_offset`, so no workgroup-limit check is performed here.
    /// The encoder and pass are labelled with `config.label`, or
    /// `"dispatch_indirect"` when no label is set. If `config.wait` is set,
    /// `device.poll(true)` is called after submission.
    ///
    /// This method currently always returns `Ok(())`.
    pub async fn dispatch_indirect_with_config(
        &self,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        indirect_buffer: &GpuBuffer,
        indirect_offset: u64,
        config: DispatchConfig,
    ) -> GpuResult<()> {
        let label = config.label.as_deref().unwrap_or("dispatch_indirect");
        debug!("Dispatching indirect '{}'", label);

        let mut encoder = self.new_encoder(label);
        {
            // Scope the pass so it is dropped before the encoder is finished.
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some(label),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline.pipeline());
            pass.set_bind_group(0, Some(bind_group), &[]);
            pass.dispatch_workgroups_indirect(indirect_buffer.buffer(), indirect_offset);
        }
        self.submit_encoder(encoder, config.wait);
        Ok(())
    }

    /// Dispatches a sequence of kernels in one submission using
    /// [`DispatchConfig::default`].
    ///
    /// # Errors
    ///
    /// See [`GpuDispatcher::dispatch_chain_with_config`].
    pub async fn dispatch_chain(
        &self,
        dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
    ) -> GpuResult<()> {
        self.dispatch_chain_with_config(dispatches, DispatchConfig::default())
            .await
    }

    /// Records one compute pass per entry of `dispatches`, in order, into a
    /// single command encoder and submits it once.
    ///
    /// An empty slice is a no-op that returns `Ok(())`. The encoder is
    /// labelled with `config.label` (or `"dispatch_chain"`), and pass `i` is
    /// labelled `"{label}_pass_{i}"`. If `config.wait` is set,
    /// `device.poll(true)` is called after submission.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::InvalidWorkgroupSize`] for the first entry whose
    /// workgroup counts exceed `device.info().max_workgroups`. Every entry is
    /// checked before anything is recorded, so a failing chain submits no
    /// work at all.
    pub async fn dispatch_chain_with_config(
        &self,
        dispatches: &[(&ComputePipeline, &wgpu::BindGroup, [u32; 3])],
        config: DispatchConfig,
    ) -> GpuResult<()> {
        if dispatches.is_empty() {
            return Ok(());
        }

        // Apply the same limit check as the single-dispatch path, up front,
        // so a bad entry cannot leave a half-recorded chain behind.
        for &(_, _, workgroups) in dispatches {
            self.check_workgroup_limits(workgroups)?;
        }

        let label = config.label.as_deref().unwrap_or("dispatch_chain");
        debug!("Dispatching chain '{}' with {} kernels", label, dispatches.len());

        let mut encoder = self.new_encoder(label);
        for (i, &(pipeline, bind_group, workgroups)) in dispatches.iter().enumerate() {
            trace!(
                "Chain dispatch {}: workgroups [{}, {}, {}]",
                i,
                workgroups[0],
                workgroups[1],
                workgroups[2]
            );
            let pass_label = format!("{}_pass_{}", label, i);
            self.record_dispatch(
                &mut encoder,
                pipeline,
                bind_group,
                workgroups,
                Some(pass_label.as_str()),
            );
        }
        self.submit_encoder(encoder, config.wait);
        Ok(())
    }

    /// Records one compute pass into a caller-owned `encoder` without
    /// submitting it.
    ///
    /// The pass is labelled with `label`, or `"recorded_dispatch"` when
    /// `None`. No workgroup-limit check is performed here; callers that need
    /// one must validate `workgroups` themselves.
    pub fn record_dispatch(
        &self,
        encoder: &mut CommandEncoder,
        pipeline: &ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroups: [u32; 3],
        label: Option<&str>,
    ) {
        let pass_label = label.unwrap_or("recorded_dispatch");
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(pass_label),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline.pipeline());
        pass.set_bind_group(0, Some(bind_group), &[]);
        pass.dispatch_workgroups(workgroups[0], workgroups[1], workgroups[2]);
    }

    /// Calls `device.poll(true)` and discards its result.
    pub fn synchronize(&self) {
        self.device.poll(true);
    }

    /// Calls `device.poll(false)` and returns its result.
    ///
    /// NOTE(review): presumably `true` means all submitted work has
    /// finished — confirm against `GpuDevice::poll`.
    pub fn poll(&self) -> bool {
        self.device.poll(false)
    }
}
/// Builder that accumulates compute dispatches and a [`DispatchConfig`], then
/// hands them to the borrowed dispatcher's `dispatch_chain_with_config` when
/// `execute` is called.
pub struct DispatchBuilder<'a> {
/// Dispatcher that will run the accumulated dispatches.
dispatcher: &'a GpuDispatcher,
/// Queued `(pipeline, bind group, workgroup counts)` entries, in the order added.
dispatches: Vec<(Arc<ComputePipeline>, wgpu::BindGroup, [u32; 3])>,
/// Configuration passed along with the chain on `execute`.
config: DispatchConfig,
}
impl<'a> DispatchBuilder<'a> {
pub fn new(dispatcher: &'a GpuDispatcher) -> Self {
Self {
dispatcher,
dispatches: Vec::new(),
config: DispatchConfig::default(),
}
}
pub fn add(
mut self,
pipeline: Arc<ComputePipeline>,
bind_group: wgpu::BindGroup,
workgroups: [u32; 3],
) -> Self {
self.dispatches.push((pipeline, bind_group, workgroups));
self
}
pub fn config(mut self, config: DispatchConfig) -> Self {
self.config = config;
self
}
pub fn label(mut self, label: impl Into<String>) -> Self {
self.config.label = Some(label.into());
self
}
pub fn wait(mut self) -> Self {
self.config.wait = true;
self
}
pub async fn execute(self) -> GpuResult<()> {
if self.dispatches.is_empty() {
return Ok(());
}
let refs: Vec<(&ComputePipeline, &wgpu::BindGroup, [u32; 3])> = self
.dispatches
.iter()
.map(|(p, b, w)| (p.as_ref(), b, *w))
.collect();
self.dispatcher
.dispatch_chain_with_config(&refs, self.config)
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has no label, does not wait, and has a zero
    /// timeout.
    #[test]
    fn test_dispatch_config_default() {
        let DispatchConfig {
            label,
            wait,
            timeout_ms,
        } = DispatchConfig::default();
        assert!(!wait);
        assert_eq!(label, None);
        assert_eq!(timeout_ms, 0);
    }

    /// `DispatchConfig::wait()` turns the wait flag on.
    #[test]
    fn test_dispatch_config_wait() {
        assert!(DispatchConfig::wait().wait);
    }

    /// Chained setters each leave their value in the final configuration.
    #[test]
    fn test_dispatch_config_builder() {
        let labelled = DispatchConfig::with_label("test");
        let timed = labelled.with_timeout(1000);
        let built = timed.with_wait(true);

        assert_eq!(built.label, Some(String::from("test")));
        assert_eq!(built.timeout_ms, 1000);
        assert!(built.wait);
    }
}