pub(crate) mod dispatch;
mod operations;
use super::{BufferId, GpuCommandBatch};
use std::sync::Arc;
impl GpuCommandBatch {
/// Runs every queued operation as one submitted command batch.
///
/// A throwaway [`dispatch::PipelineCache`] is built for this call alone;
/// callers that want pipeline reuse across batches should prefer
/// [`Self::execute_with_cache`]. The `contract_pre_/post_` macros
/// (defined elsewhere in the crate) bracket the inner execution.
pub async fn execute(&mut self) -> Result<(), String> {
    contract_pre_single_encoder_batch!();

    // One-shot cache: anything compiled here is dropped on return.
    let mut cache = dispatch::PipelineCache::new();
    let outcome = self.execute_inner(&mut cache);

    contract_post_single_encoder_batch!(outcome);
    outcome
}
/// Runs the batch with a caller-owned pipeline cache so compiled
/// pipelines survive across multiple executions.
///
/// Note: unlike [`Self::execute`], this path does not invoke the
/// contract pre/post macros.
pub async fn execute_with_cache(
    &mut self,
    pipeline_cache: &mut dispatch::PipelineCache,
) -> Result<(), String> {
    self.execute_inner(pipeline_cache)
}
/// Core execution path shared by `execute` and `execute_with_cache`.
///
/// Runs in three strictly ordered phases:
/// 1. Lazily allocate a GPU buffer for every registered buffer that does
///    not have one yet.
/// 2. Upload any CPU-side `data` into its GPU buffer via the queue.
/// 3. Record all queued operations into a single command encoder and
///    submit the whole batch in one call.
///
/// Returns the first error produced by `encode_operation`, if any.
fn execute_inner(
&mut self,
pipeline_cache: &mut dispatch::PipelineCache,
) -> Result<(), String> {
// Phase 1: allocate GPU buffers on demand. Iterating `&mut self.buffers`
// only borrows that field, so `self.device` stays usable inside the loop.
for (buffer_id, buffer_info) in &mut self.buffers {
if buffer_info.gpu_buffer.is_some() {
continue;
}
// `size` counts f32 elements (see also `read()`, which casts back to f32).
let size_bytes = (buffer_info.size * std::mem::size_of::<f32>()) as u64;
let gpu_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(&format!("Buffer {:?}", buffer_id)),
size: size_bytes,
// STORAGE for shader access, COPY_SRC so `read()` can copy it out,
// COPY_DST so `write_buffer` uploads below are permitted.
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
buffer_info.gpu_buffer = Some(Arc::new(gpu_buffer));
}
// Phase 2: upload CPU-side data. NOTE(review): `data` is never cleared
// here, so buffers holding data are re-written on every execution —
// confirm callers clear `data` if a one-shot upload is intended.
for buffer_info in self.buffers.values() {
if let Some(data) = &buffer_info.data {
if let Some(gpu_buffer) = &buffer_info.gpu_buffer {
self.device.queue.write_buffer(gpu_buffer, 0, bytemuck::cast_slice(data));
}
}
}
// Phase 3: encode every queued operation into one encoder, then submit
// once — the "single encoder batch" the contract macros refer to.
let mut encoder =
self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Batch Encoder"),
});
for op in &self.operations {
self.encode_operation(op, &mut encoder, pipeline_cache)?;
}
self.device.queue.submit(Some(encoder.finish()));
Ok(())
}
/// Reads back the contents of `buffer_id` from the GPU as a `Vec<f32>`.
///
/// The storage buffer (created without MAP_READ in `execute_inner`) is
/// first copied into a fresh MAP_READ staging buffer; the staging buffer
/// is then mapped asynchronously while the device is polled to completion.
///
/// # Errors
/// - `buffer_id` is not registered in this batch.
/// - The buffer has no GPU allocation yet (`execute()` not called).
/// - The device poll fails or the buffer mapping fails.
pub async fn read(&self, buffer_id: BufferId) -> Result<Vec<f32>, String> {
contract_pre_read!();
let buffer_info = self.buffers.get(&buffer_id).ok_or("Invalid buffer ID")?;
let gpu_buffer = buffer_info
.gpu_buffer
.as_ref()
.ok_or("Buffer not executed yet - call execute() first")?;
// `size` counts f32 elements, matching the allocation in `execute_inner`.
let size_bytes = (buffer_info.size * std::mem::size_of::<f32>()) as u64;
// Per-call staging buffer: the only buffer here with MAP_READ usage.
let staging_buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Buffer"),
size: size_bytes,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder =
self.device.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Read Encoder"),
});
encoder.copy_buffer_to_buffer(gpu_buffer, 0, &staging_buffer, 0, size_bytes);
self.device.queue.submit(Some(encoder.finish()));
let buffer_slice = staging_buffer.slice(..);
// map_async completes via callback; bridge it to async/await with a
// oneshot channel. The statement order below is load-bearing:
// request the map, then poll the device (which drives the callback),
// then await the result.
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).ok();
});
// Blocking wait for all submitted GPU work, including the map request.
self.device
.device
.poll(wgpu::PollType::Wait { submission_index: None, timeout: None })
.map_err(|e| format!("GPU poll failed: {:?}", e))?;
receiver
.receive()
.await
.ok_or("Failed to receive mapping result")?
.map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
// Copy out inside a scope so the mapped view is dropped before unmap().
let data = {
let mapped_range = buffer_slice.get_mapped_range();
let float_data: &[f32] = bytemuck::cast_slice(&mapped_range);
float_data.to_vec()
};
staging_buffer.unmap();
Ok(data)
}
}