use std::sync::Arc;
use crate::capability::{GpuRequirements, RenderCapability};
use crate::context::GraphicsContext;
use crate::features::GpuFeatures;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
Timestamp,
Occlusion,
}
impl QueryType {
pub fn to_wgpu(self) -> wgpu::QueryType {
match self {
QueryType::Timestamp => wgpu::QueryType::Timestamp,
QueryType::Occlusion => wgpu::QueryType::Occlusion,
}
}
}
/// Owns a `wgpu::QuerySet` together with the parameters it was created with.
pub struct QuerySet {
    /// The underlying wgpu query set.
    query_set: wgpu::QuerySet,
    /// Which kind of query every slot in the set holds.
    query_type: QueryType,
    /// Number of query slots in the set.
    count: u32,
}
impl QuerySet {
    /// Creates a query set with `count` slots of kind `query_type` on `device`.
    ///
    /// `label` is forwarded verbatim to the wgpu descriptor for debugging tools.
    pub fn new(
        device: &wgpu::Device,
        label: Option<&str>,
        query_type: QueryType,
        count: u32,
    ) -> Self {
        let descriptor = wgpu::QuerySetDescriptor {
            label,
            ty: query_type.to_wgpu(),
            count,
        };
        Self {
            query_set: device.create_query_set(&descriptor),
            query_type,
            count,
        }
    }

    /// Borrows the raw wgpu query set, e.g. for `write_timestamp` or `resolve_query_set`.
    #[inline]
    pub fn query_set(&self) -> &wgpu::QuerySet {
        &self.query_set
    }

    /// The kind of query this set was created for.
    #[inline]
    pub fn query_type(&self) -> QueryType {
        self.query_type
    }

    /// Number of query slots in the set.
    #[inline]
    pub fn count(&self) -> u32 {
        self.count
    }
}
pub struct QueryResultBuffer {
resolve_buffer: wgpu::Buffer,
read_buffer: wgpu::Buffer,
count: u32,
}
impl QueryResultBuffer {
pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
let size = (count as u64) * std::mem::size_of::<u64>() as u64;
let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: label.map(|l| format!("{} Resolve", l)).as_deref(),
size,
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: label.map(|l| format!("{} Read", l)).as_deref(),
size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
Self {
resolve_buffer,
read_buffer,
count,
}
}
#[inline]
pub fn resolve_buffer(&self) -> &wgpu::Buffer {
&self.resolve_buffer
}
#[inline]
pub fn read_buffer(&self) -> &wgpu::Buffer {
&self.read_buffer
}
#[inline]
pub fn count(&self) -> u32 {
self.count
}
pub fn resolve(
&self,
encoder: &mut wgpu::CommandEncoder,
query_set: &QuerySet,
query_range: std::ops::Range<u32>,
destination_offset: u32,
) {
encoder.resolve_query_set(
query_set.query_set(),
query_range,
&self.resolve_buffer,
(destination_offset as u64) * std::mem::size_of::<u64>() as u64,
);
}
pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
}
pub fn map_async(
&self,
) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
let slice = self.read_buffer.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
}
pub fn read_results(&self) -> Vec<u64> {
let slice = self.read_buffer.slice(..);
let data = slice.get_mapped_range();
let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
drop(data);
self.read_buffer.unmap();
results
}
}
/// Handle for an open profiling region, returned by `GpuProfiler::begin_region`.
///
/// Pass it back to `GpuProfiler::end_region` to close the region.
#[derive(Debug)]
pub struct ProfileRegion {
    /// Human-readable name reported alongside the measured duration.
    label: String,
    /// Query-set slot that received the region's start timestamp.
    start_query: u32,
}
impl RenderCapability for GpuProfiler {
    /// The profiler needs timestamp-query support on the device.
    ///
    /// NOTE(review): `GpuProfiler` records timestamps via `CommandEncoder::write_timestamp`,
    /// which in wgpu also requires `Features::TIMESTAMP_QUERY_INSIDE_ENCODERS`. Confirm
    /// whether `GpuFeatures` exposes that flag and whether it should be required here.
    fn requirements() -> GpuRequirements {
        let needed = GpuFeatures::TIMESTAMP_QUERY;
        GpuRequirements::new().require_features(needed)
    }

    /// Stable identifier of this capability.
    fn name() -> &'static str {
        "GpuProfiler"
    }
}
/// Measures GPU time spent in labelled regions of command encoders using timestamp queries.
pub struct GpuProfiler {
    /// Shared graphics context providing the device and queue.
    context: Arc<GraphicsContext>,
    /// Timestamp query set with `max_queries` slots.
    query_set: QuerySet,
    /// Resolve/readback buffers sized for `max_queries` timestamps.
    result_buffer: QueryResultBuffer,
    /// Next free query slot in the current frame.
    current_query: u32,
    /// Capacity of `query_set`.
    max_queries: u32,
    /// Closed regions this frame as `(label, start slot, end slot)`.
    regions: Vec<(String, u32, u32)>,
    /// Most recently read durations as `(label, milliseconds)`.
    cached_results: Vec<(String, f64)>,
    /// Nanoseconds per timestamp tick, as reported by the queue.
    timestamp_period: f32,
}
impl GpuProfiler {
    /// Creates a profiler backed by a timestamp query set with `max_queries` slots.
    ///
    /// Every recorded region consumes two slots (start and end), so at most
    /// `max_queries / 2` regions are captured per frame. The queue's timestamp period is
    /// sampled once here and later used to convert raw ticks to nanoseconds.
    pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
        let timestamp_period = context.queue().get_timestamp_period();
        let query_set = QuerySet::new(
            context.device(),
            Some("GPU Profiler Queries"),
            QueryType::Timestamp,
            max_queries,
        );
        let result_buffer =
            QueryResultBuffer::new(context.device(), Some("GPU Profiler Results"), max_queries);
        Self {
            context,
            query_set,
            result_buffer,
            current_query: 0,
            max_queries,
            regions: Vec::new(),
            cached_results: Vec::new(),
            timestamp_period,
        }
    }

    /// Resets per-frame state so query slots are reused starting from index 0.
    ///
    /// Previously read durations remain available from the read methods.
    pub fn begin_frame(&mut self) {
        self.current_query = 0;
        self.regions.clear();
    }

    /// Writes a start timestamp into `encoder` and returns a handle for `end_region`.
    ///
    /// Returns `None` when fewer than two query slots remain this frame. A region needs
    /// one slot now and one when it ends, so starting it in the last free slot would only
    /// produce a start timestamp that can never be paired.
    ///
    /// NOTE(review): wgpu's `CommandEncoder::write_timestamp` requires
    /// `Features::TIMESTAMP_QUERY_INSIDE_ENCODERS` on the device — confirm it is enabled.
    #[must_use = "pass the region to `end_region`; dropping it leaves an unpaired start timestamp"]
    pub fn begin_region(
        &mut self,
        encoder: &mut wgpu::CommandEncoder,
        label: &str,
    ) -> Option<ProfileRegion> {
        if self.max_queries.saturating_sub(self.current_query) < 2 {
            return None;
        }
        let start_query = self.current_query;
        encoder.write_timestamp(self.query_set.query_set(), start_query);
        self.current_query += 1;
        Some(ProfileRegion {
            label: label.to_string(),
            start_query,
        })
    }

    /// Writes the end timestamp for `region` and records it for later readback.
    ///
    /// If every query slot is already used, the region is discarded without a timestamp.
    pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
        if self.current_query >= self.max_queries {
            return;
        }
        let end_query = self.current_query;
        encoder.write_timestamp(self.query_set.query_set(), end_query);
        self.current_query += 1;
        self.regions
            .push((region.label, region.start_query, end_query));
    }

    /// Records commands that resolve this frame's timestamps and copy them into the
    /// readback buffer. Does nothing if no timestamp was written this frame.
    pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
        if self.current_query == 0 {
            return;
        }
        self.result_buffer
            .resolve(encoder, &self.query_set, 0..self.current_query, 0);
        self.result_buffer.copy_to_readable(encoder);
    }

    /// Maps the readback buffer, waits for the device, and returns per-region durations
    /// in milliseconds.
    ///
    /// Blocks until the device poll with `PollType::Wait` (no timeout) returns. If no
    /// region was recorded, or the map fails, the previously cached results are returned
    /// unchanged.
    pub fn read_results(&mut self) -> &[(String, f64)] {
        if self.regions.is_empty() {
            return &self.cached_results;
        }
        let outcome = self.request_map();
        let _ = self.context.device().poll(wgpu::PollType::Wait {
            submission_index: None,
            timeout: None,
        });
        // Only a successful map may be read: `get_mapped_range` panics on a failed map.
        if let Ok(Ok(())) = outcome.recv() {
            self.consume_mapped_results();
        }
        &self.cached_results
    }

    /// Non-blocking variant of [`GpuProfiler::read_results`].
    ///
    /// Issues a map and performs a single non-blocking device poll. If the map completed,
    /// fresh durations are computed. Otherwise the still-pending map is aborted and the
    /// previously cached results are returned, so the next call can start a new map.
    pub fn try_read_results(&mut self) -> &[(String, f64)] {
        if self.regions.is_empty() {
            return &self.cached_results;
        }
        let outcome = self.request_map();
        let _ = self.context.device().poll(wgpu::PollType::Poll);
        match outcome.try_recv() {
            Ok(Ok(())) => self.consume_mapped_results(),
            Err(std::sync::mpsc::TryRecvError::Empty) => {
                // The GPU has not finished yet. Leaving the map outstanding would make the
                // next `map_async` on this buffer panic ("Buffer is already mapped"). It
                // would also leave the buffer mapped when a later `resolve` copies into it.
                // Unmapping aborts the pending map.
                self.result_buffer.read_buffer().unmap();
            }
            // The map failed or its callback was dropped: there is nothing to read.
            Ok(Err(_)) | Err(std::sync::mpsc::TryRecvError::Disconnected) => {}
        }
        &self.cached_results
    }

    /// Starts a read map of the whole readback buffer.
    ///
    /// The returned receiver yields the map outcome once the device has been polled.
    fn request_map(&self) -> std::sync::mpsc::Receiver<Result<(), wgpu::BufferAsyncError>> {
        let (tx, rx) = std::sync::mpsc::channel();
        self.result_buffer
            .read_buffer()
            .slice(..)
            .map_async(wgpu::MapMode::Read, move |result| {
                // The receiver may already be gone; a lost notification is harmless here.
                let _ = tx.send(result);
            });
        rx
    }

    /// Converts the mapped timestamps into `cached_results`, then unmaps the buffer.
    ///
    /// Must only be called after a successful map of the readback buffer.
    fn consume_mapped_results(&mut self) {
        let slice = self.result_buffer.read_buffer().slice(..);
        let data = slice.get_mapped_range();
        Self::fill_durations(
            &mut self.cached_results,
            &self.regions,
            bytemuck::cast_slice(&data),
            self.timestamp_period,
        );
        // The mapped view must be released before the buffer can be unmapped.
        drop(data);
        self.result_buffer.read_buffer().unmap();
    }

    /// Rebuilds `out` with one `(label, milliseconds)` entry per recorded region.
    ///
    /// Slots missing from `timestamps` read as 0, and an end tick earlier than the start
    /// tick clamps the duration to 0.
    fn fill_durations(
        out: &mut Vec<(String, f64)>,
        regions: &[(String, u32, u32)],
        timestamps: &[u64],
        timestamp_period: f32,
    ) {
        let ns_per_tick = f64::from(timestamp_period);
        let tick_at = |slot: u32| timestamps.get(slot as usize).copied().unwrap_or(0);
        out.clear();
        out.extend(regions.iter().map(|(label, start, end)| {
            let ticks = tick_at(*end).saturating_sub(tick_at(*start));
            let duration_ns = ticks as f64 * ns_per_tick;
            (label.clone(), duration_ns / 1_000_000.0)
        }));
    }

    /// Number of query slots written so far this frame.
    #[inline]
    pub fn queries_used(&self) -> u32 {
        self.current_query
    }

    /// Total number of query slots available per frame.
    #[inline]
    pub fn max_queries(&self) -> u32 {
        self.max_queries
    }

    /// Nanoseconds per timestamp tick, sampled from the queue at construction.
    #[inline]
    pub fn timestamp_period(&self) -> f32 {
        self.timestamp_period
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each local variant must convert to the wgpu variant of the same name.
    #[test]
    fn test_query_type_conversion() {
        assert!(matches!(
            QueryType::Timestamp.to_wgpu(),
            wgpu::QueryType::Timestamp
        ));
        assert!(matches!(
            QueryType::Occlusion.to_wgpu(),
            wgpu::QueryType::Occlusion
        ));
    }

    /// The derived `Debug` output exposes the region's label and start slot.
    #[test]
    fn test_profile_region_debug() {
        let region = ProfileRegion {
            label: "Test".to_string(),
            start_query: 0,
        };
        assert_eq!(
            format!("{:?}", region),
            r#"ProfileRegion { label: "Test", start_query: 0 }"#
        );
    }
}