// astrelis_render/query.rs
//! GPU query and profiling support.
//!
//! This module provides wrappers for GPU queries (timestamps, occlusion) and
//! a high-level profiler for measuring GPU execution times.
//!
//! # Features Required
//!
//! - `TIMESTAMP_QUERY` - Required for timestamp queries and GPU profiling
//!
//! # Example
//!
//! ```ignore
//! use astrelis_render::{GpuProfiler, GraphicsContext, GraphicsContextExt};
//!
//! // Create profiler (requires TIMESTAMP_QUERY feature)
//! let mut profiler = GpuProfiler::new(context.clone(), 256);
//!
//! // In render loop:
//! profiler.begin_frame();
//!
//! // begin_region returns None when the per-frame query budget is exhausted
//! if let Some(region) = profiler.begin_region(&mut encoder, "Shadow Pass") {
//!     // ... render shadow pass ...
//!     profiler.end_region(&mut encoder, region);
//! }
//!
//! profiler.resolve(&mut encoder);
//!
//! // Later, read results
//! for (label, duration_ms) in profiler.read_results() {
//!     println!("{}: {:.2}ms", label, duration_ms);
//! }
//! ```
34
35use std::sync::Arc;
36
37use crate::context::GraphicsContext;
38use crate::extension::GraphicsContextExt;
39
40// =============================================================================
41// Query Types
42// =============================================================================
43
/// Types of GPU queries.
///
/// Converted to the wgpu equivalent with [`QueryType::to_wgpu`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// Timestamp query for measuring GPU execution time.
    /// Requires `TIMESTAMP_QUERY` feature.
    Timestamp,
    /// Occlusion query for counting visible fragments.
    Occlusion,
}
53
54impl QueryType {
55 /// Convert to wgpu query type.
56 pub fn to_wgpu(self) -> wgpu::QueryType {
57 match self {
58 QueryType::Timestamp => wgpu::QueryType::Timestamp,
59 QueryType::Occlusion => wgpu::QueryType::Occlusion,
60 }
61 }
62}
63
64// =============================================================================
65// QuerySet
66// =============================================================================
67
/// A wrapper around wgpu::QuerySet with metadata.
pub struct QuerySet {
    // The underlying GPU query set object.
    query_set: wgpu::QuerySet,
    // Kind of queries stored (timestamp or occlusion).
    query_type: QueryType,
    // Number of query slots allocated in the set.
    count: u32,
}
74
75impl QuerySet {
76 /// Create a new query set.
77 ///
78 /// # Arguments
79 ///
80 /// * `device` - The wgpu device
81 /// * `label` - Optional debug label
82 /// * `query_type` - Type of queries in this set
83 /// * `count` - Number of queries in the set
84 pub fn new(
85 device: &wgpu::Device,
86 label: Option<&str>,
87 query_type: QueryType,
88 count: u32,
89 ) -> Self {
90 let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
91 label,
92 ty: query_type.to_wgpu(),
93 count,
94 });
95
96 Self {
97 query_set,
98 query_type,
99 count,
100 }
101 }
102
103 /// Get the underlying wgpu query set.
104 #[inline]
105 pub fn query_set(&self) -> &wgpu::QuerySet {
106 &self.query_set
107 }
108
109 /// Get the query type.
110 #[inline]
111 pub fn query_type(&self) -> QueryType {
112 self.query_type
113 }
114
115 /// Get the number of queries in the set.
116 #[inline]
117 pub fn count(&self) -> u32 {
118 self.count
119 }
120}
121
122// =============================================================================
123// QueryResultBuffer
124// =============================================================================
125
/// Buffer for storing and reading query results.
pub struct QueryResultBuffer {
    // GPU-side buffer queries are resolved into (QUERY_RESOLVE | COPY_SRC).
    resolve_buffer: wgpu::Buffer,
    // CPU-mappable staging buffer (COPY_DST | MAP_READ) for readback.
    read_buffer: wgpu::Buffer,
    // Number of u64 results both buffers are sized for.
    count: u32,
}
132
133impl QueryResultBuffer {
134 /// Create a new query result buffer.
135 ///
136 /// # Arguments
137 ///
138 /// * `device` - The wgpu device
139 /// * `label` - Optional debug label
140 /// * `count` - Number of query results to store
141 pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
142 let size = (count as u64) * std::mem::size_of::<u64>() as u64;
143
144 let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
145 label: label.map(|l| format!("{} Resolve", l)).as_deref(),
146 size,
147 usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
148 mapped_at_creation: false,
149 });
150
151 let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
152 label: label.map(|l| format!("{} Read", l)).as_deref(),
153 size,
154 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
155 mapped_at_creation: false,
156 });
157
158 Self {
159 resolve_buffer,
160 read_buffer,
161 count,
162 }
163 }
164
165 /// Get the resolve buffer (used for query resolution).
166 #[inline]
167 pub fn resolve_buffer(&self) -> &wgpu::Buffer {
168 &self.resolve_buffer
169 }
170
171 /// Get the read buffer (used for CPU readback).
172 #[inline]
173 pub fn read_buffer(&self) -> &wgpu::Buffer {
174 &self.read_buffer
175 }
176
177 /// Get the number of results this buffer can hold.
178 #[inline]
179 pub fn count(&self) -> u32 {
180 self.count
181 }
182
183 /// Resolve queries from a query set into this buffer.
184 pub fn resolve(
185 &self,
186 encoder: &mut wgpu::CommandEncoder,
187 query_set: &QuerySet,
188 query_range: std::ops::Range<u32>,
189 destination_offset: u32,
190 ) {
191 encoder.resolve_query_set(
192 query_set.query_set(),
193 query_range,
194 &self.resolve_buffer,
195 (destination_offset as u64) * std::mem::size_of::<u64>() as u64,
196 );
197 }
198
199 /// Copy resolved results to the readable buffer.
200 pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
201 let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
202 encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
203 }
204
205 /// Map the read buffer for CPU access.
206 ///
207 /// Returns a future that completes when the buffer is mapped.
208 pub fn map_async(&self) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
209 let slice = self.read_buffer.slice(..);
210 let (tx, rx) = std::sync::mpsc::channel();
211
212 slice.map_async(wgpu::MapMode::Read, move |result| {
213 let _ = tx.send(result);
214 });
215
216 async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
217 }
218
219 /// Read the query results (must be mapped first).
220 ///
221 /// Returns the raw u64 timestamps/occlusion counts.
222 pub fn read_results(&self) -> Vec<u64> {
223 let slice = self.read_buffer.slice(..);
224 let data = slice.get_mapped_range();
225 let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
226 drop(data);
227 self.read_buffer.unmap();
228 results
229 }
230}
231
232// =============================================================================
233// ProfileRegion
234// =============================================================================
235
/// A handle to a profiling region.
///
/// Created by `GpuProfiler::begin_region` and consumed by `GpuProfiler::end_region`.
#[derive(Debug)]
pub struct ProfileRegion {
    // Human-readable label reported alongside the measured duration.
    label: String,
    // Index of the timestamp query written when the region began.
    start_query: u32,
}
244
245// =============================================================================
246// GpuProfiler
247// =============================================================================
248
/// High-level GPU profiler for measuring execution times.
///
/// This profiler uses timestamp queries to measure GPU execution time
/// for different regions of your rendering code.
///
/// # Requirements
///
/// - Device must support `TIMESTAMP_QUERY` feature
/// - Must call `begin_frame()` at the start of each frame
/// - Must call `resolve()` before submitting commands
///
/// # Example
///
/// ```ignore
/// let mut profiler = GpuProfiler::new(context.clone(), 256);
///
/// // Each frame:
/// profiler.begin_frame();
///
/// // begin_region returns None once the query budget is used up
/// if let Some(region) = profiler.begin_region(&mut encoder, "My Pass") {
///     // ... do rendering ...
///     profiler.end_region(&mut encoder, region);
/// }
///
/// profiler.resolve(&mut encoder);
///
/// // Read results (may be from previous frame)
/// for (label, duration_ms) in profiler.read_results() {
///     println!("{}: {:.2}ms", label, duration_ms);
/// }
/// ```
pub struct GpuProfiler {
    // Shared graphics context providing the device and queue.
    context: Arc<GraphicsContext>,
    // Timestamp query set sized to `max_queries`.
    query_set: QuerySet,
    // Resolve + staging buffers for reading timestamps back on the CPU.
    result_buffer: QueryResultBuffer,
    /// Current query index for the frame (next free slot).
    current_query: u32,
    /// Maximum queries per frame (a region consumes two: start and end).
    max_queries: u32,
    /// Regions from the current frame (label, start_query, end_query).
    regions: Vec<(String, u32, u32)>,
    /// Cached results from the previous successful read (label, milliseconds).
    cached_results: Vec<(String, f64)>,
    /// Timestamp period in nanoseconds per tick (queried from the wgpu queue).
    timestamp_period: f32,
}
294
295impl GpuProfiler {
296 /// Create a new GPU profiler.
297 ///
298 /// # Arguments
299 ///
300 /// * `context` - Graphics context (must support TIMESTAMP_QUERY)
301 /// * `max_queries` - Maximum number of timestamp queries per frame
302 ///
303 /// # Panics
304 ///
305 /// Panics if the device doesn't support timestamp queries.
306 pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
307 let timestamp_period = context.queue().get_timestamp_period();
308
309 let query_set = QuerySet::new(
310 context.device(),
311 Some("GPU Profiler Queries"),
312 QueryType::Timestamp,
313 max_queries,
314 );
315
316 let result_buffer = QueryResultBuffer::new(
317 context.device(),
318 Some("GPU Profiler Results"),
319 max_queries,
320 );
321
322 Self {
323 context,
324 query_set,
325 result_buffer,
326 current_query: 0,
327 max_queries,
328 regions: Vec::new(),
329 cached_results: Vec::new(),
330 timestamp_period,
331 }
332 }
333
334 /// Begin a new frame.
335 ///
336 /// Call this at the start of each frame before recording any regions.
337 pub fn begin_frame(&mut self) {
338 self.current_query = 0;
339 self.regions.clear();
340 }
341
342 /// Begin a profiling region.
343 ///
344 /// # Arguments
345 ///
346 /// * `encoder` - Command encoder to write the timestamp
347 /// * `label` - Human-readable label for this region
348 ///
349 /// # Returns
350 ///
351 /// A `ProfileRegion` handle that must be passed to `end_region`.
352 pub fn begin_region(
353 &mut self,
354 encoder: &mut wgpu::CommandEncoder,
355 label: &str,
356 ) -> Option<ProfileRegion> {
357 if self.current_query >= self.max_queries {
358 return None;
359 }
360
361 let start_query = self.current_query;
362 encoder.write_timestamp(&self.query_set.query_set, start_query);
363 self.current_query += 1;
364
365 Some(ProfileRegion {
366 label: label.to_string(),
367 start_query,
368 })
369 }
370
371 /// End a profiling region.
372 ///
373 /// # Arguments
374 ///
375 /// * `encoder` - Command encoder to write the timestamp
376 /// * `region` - The region handle from `begin_region`
377 pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
378 if self.current_query >= self.max_queries {
379 return;
380 }
381
382 let end_query = self.current_query;
383 encoder.write_timestamp(&self.query_set.query_set, end_query);
384 self.current_query += 1;
385
386 self.regions
387 .push((region.label, region.start_query, end_query));
388 }
389
390 /// Resolve all queries from this frame.
391 ///
392 /// Call this after all regions have been recorded, before submitting commands.
393 pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
394 if self.current_query == 0 {
395 return;
396 }
397
398 self.result_buffer.resolve(
399 encoder,
400 &self.query_set,
401 0..self.current_query,
402 0,
403 );
404 self.result_buffer.copy_to_readable(encoder);
405 }
406
407 /// Read profiling results synchronously.
408 ///
409 /// This blocks until the results are available from the GPU.
410 /// For non-blocking reads, consider using double-buffering or
411 /// reading results from the previous frame.
412 ///
413 /// # Returns
414 ///
415 /// A vector of (label, duration_ms) pairs for each completed region.
416 pub fn read_results(&mut self) -> &[(String, f64)] {
417 if self.regions.is_empty() {
418 return &self.cached_results;
419 }
420
421 let device = self.context.device();
422
423 // Map the buffer
424 let slice = self.result_buffer.read_buffer().slice(..);
425 let (tx, rx) = std::sync::mpsc::channel();
426
427 slice.map_async(wgpu::MapMode::Read, move |result| {
428 let _ = tx.send(result);
429 });
430
431 // Wait for the buffer to be mapped (blocking)
432 let _ = device.poll(wgpu::PollType::Wait {
433 submission_index: None,
434 timeout: None,
435 });
436
437 // Wait for the callback
438 if rx.recv().is_ok() {
439 let data = slice.get_mapped_range();
440 let timestamps: &[u64] = bytemuck::cast_slice(&data);
441
442 self.cached_results.clear();
443
444 for (label, start, end) in &self.regions {
445 let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
446 let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
447
448 // Convert ticks to milliseconds
449 let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
450 * self.timestamp_period as f64;
451 let duration_ms = duration_ns / 1_000_000.0;
452
453 self.cached_results.push((label.clone(), duration_ms));
454 }
455
456 drop(data);
457 self.result_buffer.read_buffer().unmap();
458 }
459
460 &self.cached_results
461 }
462
463 /// Try to read profiling results without blocking.
464 ///
465 /// Returns None if the results are not yet available.
466 /// This is useful when you want to display results from the previous frame.
467 ///
468 /// # Returns
469 ///
470 /// Some reference to the cached results if new data was read, or the existing cached results.
471 pub fn try_read_results(&mut self) -> &[(String, f64)] {
472 if self.regions.is_empty() {
473 return &self.cached_results;
474 }
475
476 let device = self.context.device();
477
478 // Try to map the buffer
479 let slice = self.result_buffer.read_buffer().slice(..);
480 let (tx, rx) = std::sync::mpsc::channel();
481
482 slice.map_async(wgpu::MapMode::Read, move |result| {
483 let _ = tx.send(result);
484 });
485
486 // Non-blocking poll
487 let _ = device.poll(wgpu::PollType::Poll);
488
489 // Check if mapping succeeded
490 if let Ok(Ok(())) = rx.try_recv() {
491 let data = slice.get_mapped_range();
492 let timestamps: &[u64] = bytemuck::cast_slice(&data);
493
494 self.cached_results.clear();
495
496 for (label, start, end) in &self.regions {
497 let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
498 let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
499
500 // Convert ticks to milliseconds
501 let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
502 * self.timestamp_period as f64;
503 let duration_ms = duration_ns / 1_000_000.0;
504
505 self.cached_results.push((label.clone(), duration_ms));
506 }
507
508 drop(data);
509 self.result_buffer.read_buffer().unmap();
510 }
511
512 &self.cached_results
513 }
514
515 /// Get the number of queries used this frame.
516 #[inline]
517 pub fn queries_used(&self) -> u32 {
518 self.current_query
519 }
520
521 /// Get the maximum queries per frame.
522 #[inline]
523 pub fn max_queries(&self) -> u32 {
524 self.max_queries
525 }
526
527 /// Get the timestamp period in nanoseconds per tick.
528 #[inline]
529 pub fn timestamp_period(&self) -> f32 {
530 self.timestamp_period
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Both enum variants must convert to a wgpu query type without panicking.
    #[test]
    fn test_query_type_conversion() {
        for ty in [QueryType::Timestamp, QueryType::Occlusion] {
            let _ = ty.to_wgpu();
        }
    }

    /// `ProfileRegion` must be constructible and `Debug`-formattable.
    #[test]
    fn test_profile_region_debug() {
        let region = ProfileRegion {
            label: String::from("Test"),
            start_query: 0,
        };
        let _ = format!("{:?}", region);
    }
}
554}