use std::fmt::Debug;
use std::{mem::size_of, ops::Deref};
use futures::executor::block_on;
use crate::QUERYSET_BUFFER_USAGE;
use crate::{GpuError, ScopedBufferView};
/// A `wgpu::QuerySet` bundled with the buffer its results are resolved into.
///
/// The wrapper dereferences to the inner `wgpu::QuerySet`, so it can be passed
/// directly to wgpu APIs that expect one.
pub struct QuerySet {
/// The underlying wgpu query set.
pub(crate) inner: wgpu::QuerySet,
/// Buffer that `resolve` writes query results into and `get` maps for reading.
pub(crate) buffer: wgpu::Buffer,
/// Query type the set was created with; determines how many `u64` values
/// each query occupies in `buffer`.
pub(crate) ty: wgpu::QueryType,
}
/// Exposes the wrapped `wgpu::QuerySet` so a [`QuerySet`] can be used wherever
/// a reference to the raw wgpu type is expected.
impl Deref for QuerySet {
    type Target = wgpu::QuerySet;

    fn deref(&self) -> &wgpu::QuerySet {
        // Borrow only the wrapped query set; the other fields are irrelevant here.
        let Self { inner, .. } = self;
        inner
    }
}
/// Minimal `Debug` output: a fixed tag, since the wrapped wgpu handles carry no
/// state that is printed here.
impl Debug for QuerySet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // No formatting arguments are involved, so write the literal directly.
        f.write_str("agpu QuerySet")
    }
}
impl QuerySet {
    /// Creates a timestamp query set with room for `count` queries.
    pub(crate) fn new_timestamp(device: &wgpu::Device, count: u32) -> Self {
        let ty = wgpu::QueryType::Timestamp;
        let label = Some("Timestamp QuerySet");
        Self::new_impl(device, ty, label, count)
    }

    /// Creates a pipeline-statistics query set with room for `count` queries,
    /// with every statistic type enabled.
    pub(crate) fn new_stats(device: &wgpu::Device, count: u32) -> Self {
        let all = wgpu::PipelineStatisticsTypes::all();
        let ty = wgpu::QueryType::PipelineStatistics(all);
        let label = Some("PipelineStatistics QuerySet");
        Self::new_impl(device, ty, label, count)
    }

    /// Shared constructor: creates the query set and a result buffer large
    /// enough to hold `count` resolved queries of type `ty`.
    ///
    /// `label` is applied to both the query set and the buffer.
    fn new_impl(
        device: &wgpu::Device,
        ty: wgpu::QueryType,
        label: Option<&str>,
        count: u32,
    ) -> Self {
        let query_size = query_ty_size(ty);
        // Do the arithmetic in u64: `query_size * count * 8` can exceed
        // `u32::MAX` for large `count`, which would panic in debug builds and
        // silently allocate an undersized buffer in release builds.
        let buffer_size = u64::from(query_size) * u64::from(count) * size_of::<u64>() as u64;
        let inner = device.create_query_set(&wgpu::QuerySetDescriptor { label, ty, count });
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label,
            size: buffer_size,
            usage: QUERYSET_BUFFER_USAGE,
            mapped_at_creation: false,
        });
        QuerySet { inner, buffer, ty }
    }

    /// Number of `u64` values a single query of this set's type occupies in
    /// the result buffer.
    pub fn query_size(&self) -> u32 {
        query_ty_size(self.ty)
    }

    /// Records a command on `encoder` that resolves the first `count` queries
    /// into this set's result buffer, starting at offset 0.
    pub fn resolve(&self, count: u32, encoder: &mut wgpu::CommandEncoder) {
        // The range passed to `resolve_query_set` indexes *queries*, not `u64`
        // elements: wgpu itself writes `query_size()` values per query. Scaling
        // the range by `query_size()` would address queries beyond the set (and
        // bytes beyond the buffer) for pipeline-statistics sets.
        encoder.resolve_query_set(self, 0..count, &self.buffer, 0)
    }

    /// Reads back the resolved values of the first `count` queries.
    ///
    /// Blocks the calling thread: the buffer is mapped for reading and
    /// `device` is polled with `Maintain::Wait` until the mapping completes.
    /// The returned vector holds `count * self.query_size()` values, in the
    /// order they are laid out in the result buffer.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::BufferAsyncError`] if mapping the buffer fails.
    pub fn get(&self, device: &wgpu::Device, count: u32) -> Result<Vec<u64>, GpuError> {
        // Nothing to read; also avoids mapping an empty range.
        if count == 0 {
            return Ok(Vec::new());
        }
        // Widen before multiplying so the byte length cannot overflow u32.
        let byte_len =
            size_of::<u64>() as u64 * u64::from(count) * u64::from(self.query_size());
        let slice = self.buffer.slice(..byte_len);
        let mapping = slice.map_async(wgpu::MapMode::Read);
        // Drive the device so the mapping future can complete.
        device.poll(wgpu::Maintain::Wait);
        block_on(mapping).map_err(|_| GpuError::BufferAsyncError)?;
        let view = slice.get_mapped_range();
        // NOTE(review): `ScopedBufferView` presumably unmaps the buffer when it
        // goes out of scope — confirm against its definition.
        let view = ScopedBufferView::new(&self.buffer, view);
        let values: &[u64] = bytemuck::cast_slice(&view);
        Ok(values.to_vec())
    }
}
/// Returns how many `u64` values one query of type `ty` occupies in the result
/// buffer: one per enabled statistic for pipeline-statistics queries, and one
/// for every other query type.
fn query_ty_size(ty: wgpu::QueryType) -> u32 {
    if let wgpu::QueryType::PipelineStatistics(enabled) = ty {
        // Each set flag corresponds to one reported statistic.
        num_bits_set(enabled.bits())
    } else {
        1
    }
}
/// Counts the bits set to one in `n`, for any primitive integer type.
fn num_bits_set<N: num_traits::PrimInt>(n: N) -> u32 {
    N::count_ones(n)
}