// astrelis_render/query.rs
1//! GPU query and profiling support.
2//!
3//! This module provides wrappers for GPU queries (timestamps, occlusion) and
4//! a high-level profiler for measuring GPU execution times.
5//!
6//! # Features Required
7//!
8//! - `TIMESTAMP_QUERY` - Required for timestamp queries and GPU profiling
9//!
10//! # Example
11//!
12//! ```ignore
13//! use astrelis_render::{GpuProfiler, GraphicsContext};
14//!
15//! // Create profiler (requires TIMESTAMP_QUERY feature)
16//! let mut profiler = GpuProfiler::new(context.clone(), 256);
17//!
18//! // In render loop:
19//! profiler.begin_frame();
20//!
//! {
//!     if let Some(region) = profiler.begin_region(&mut encoder, "Shadow Pass") {
//!         // ... render shadow pass ...
//!         profiler.end_region(&mut encoder, region);
//!     }
//! }
26//!
27//! profiler.resolve(&mut encoder);
28//!
29//! // Later, read results
//! for (label, duration_ms) in profiler.read_results() {
//!     println!("{}: {:.2}ms", label, duration_ms);
//! }
33//! ```
34
35use std::sync::Arc;
36
37use crate::capability::{GpuRequirements, RenderCapability};
38use crate::context::GraphicsContext;
39use crate::features::GpuFeatures;
40
41// =============================================================================
42// Query Types
43// =============================================================================
44
/// The kinds of GPU queries supported by this module.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QueryType {
    /// Timestamp query for measuring GPU execution time.
    /// Requires `TIMESTAMP_QUERY` feature.
    Timestamp,
    /// Occlusion query for counting visible fragments.
    Occlusion,
}
54
55impl QueryType {
56 /// Convert to wgpu query type.
57 pub fn to_wgpu(self) -> wgpu::QueryType {
58 match self {
59 QueryType::Timestamp => wgpu::QueryType::Timestamp,
60 QueryType::Occlusion => wgpu::QueryType::Occlusion,
61 }
62 }
63}
64
65// =============================================================================
66// QuerySet
67// =============================================================================
68
/// A wrapper around wgpu::QuerySet with metadata.
pub struct QuerySet {
    // Underlying wgpu query set handle.
    query_set: wgpu::QuerySet,
    // Kind of queries this set stores (timestamp or occlusion).
    query_type: QueryType,
    // Number of query slots allocated in the set.
    count: u32,
}
75
76impl QuerySet {
77 /// Create a new query set.
78 ///
79 /// # Arguments
80 ///
81 /// * `device` - The wgpu device
82 /// * `label` - Optional debug label
83 /// * `query_type` - Type of queries in this set
84 /// * `count` - Number of queries in the set
85 pub fn new(
86 device: &wgpu::Device,
87 label: Option<&str>,
88 query_type: QueryType,
89 count: u32,
90 ) -> Self {
91 let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
92 label,
93 ty: query_type.to_wgpu(),
94 count,
95 });
96
97 Self {
98 query_set,
99 query_type,
100 count,
101 }
102 }
103
104 /// Get the underlying wgpu query set.
105 #[inline]
106 pub fn query_set(&self) -> &wgpu::QuerySet {
107 &self.query_set
108 }
109
110 /// Get the query type.
111 #[inline]
112 pub fn query_type(&self) -> QueryType {
113 self.query_type
114 }
115
116 /// Get the number of queries in the set.
117 #[inline]
118 pub fn count(&self) -> u32 {
119 self.count
120 }
121}
122
123// =============================================================================
124// QueryResultBuffer
125// =============================================================================
126
/// Buffer for storing and reading query results.
pub struct QueryResultBuffer {
    // GPU-side destination for resolve_query_set (QUERY_RESOLVE | COPY_SRC).
    resolve_buffer: wgpu::Buffer,
    // CPU-mappable staging buffer used for readback (COPY_DST | MAP_READ).
    read_buffer: wgpu::Buffer,
    // Number of u64 results both buffers are sized for.
    count: u32,
}
133
impl QueryResultBuffer {
    /// Create a new query result buffer.
    ///
    /// Allocates two buffers sized for `count` 8-byte (u64) results: a
    /// GPU-side resolve target and a CPU-mappable readback buffer.
    ///
    /// # Arguments
    ///
    /// * `device` - The wgpu device
    /// * `label` - Optional debug label
    /// * `count` - Number of query results to store
    pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
        // Each query result is a single u64 (8 bytes).
        let size = (count as u64) * std::mem::size_of::<u64>() as u64;

        // Destination for resolve_query_set; COPY_SRC so results can later
        // be copied into the mappable read buffer.
        let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Resolve", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // CPU-visible staging buffer read back via map_async.
        let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Read", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        Self {
            resolve_buffer,
            read_buffer,
            count,
        }
    }

    /// Get the resolve buffer (used for query resolution).
    #[inline]
    pub fn resolve_buffer(&self) -> &wgpu::Buffer {
        &self.resolve_buffer
    }

    /// Get the read buffer (used for CPU readback).
    #[inline]
    pub fn read_buffer(&self) -> &wgpu::Buffer {
        &self.read_buffer
    }

    /// Get the number of results this buffer can hold.
    #[inline]
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Resolve queries from a query set into this buffer.
    ///
    /// `destination_offset` is measured in results (u64 elements), not bytes.
    ///
    /// NOTE(review): wgpu requires the byte offset passed to
    /// `resolve_query_set` to be 256-byte aligned, so a nonzero
    /// `destination_offset` must be a multiple of 32 — confirm callers
    /// only pass 0 or aligned offsets.
    pub fn resolve(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        query_set: &QuerySet,
        query_range: std::ops::Range<u32>,
        destination_offset: u32,
    ) {
        encoder.resolve_query_set(
            query_set.query_set(),
            query_range,
            &self.resolve_buffer,
            // Convert element offset to a byte offset (8 bytes per result).
            (destination_offset as u64) * std::mem::size_of::<u64>() as u64,
        );
    }

    /// Copy resolved results to the readable buffer.
    ///
    /// Must be recorded after `resolve` in the same (or a later) submission.
    pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
        let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
        encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
    }

    /// Map the read buffer for CPU access.
    ///
    /// Returns a future that completes when the buffer is mapped.
    ///
    /// NOTE(review): the returned future performs a blocking
    /// `mpsc::Receiver::recv`, so awaiting it will stall the executor
    /// thread unless `device.poll` is being driven elsewhere to fire the
    /// map callback — confirm how callers schedule this.
    pub fn map_async(&self) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
        let slice = self.read_buffer.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();

        // The callback runs when the map request completes (driven by
        // device.poll); forward the result over the channel.
        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });

        // A RecvError (sender dropped without sending) is reported as a
        // generic BufferAsyncError.
        async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
    }

    /// Read the query results (must be mapped first).
    ///
    /// Returns the raw u64 timestamps/occlusion counts and unmaps the
    /// buffer afterwards. Panics (inside wgpu) if the buffer is not
    /// currently mapped.
    pub fn read_results(&self) -> Vec<u64> {
        let slice = self.read_buffer.slice(..);
        let data = slice.get_mapped_range();
        let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before unmap is legal.
        drop(data);
        self.read_buffer.unmap();
        results
    }
}
232
233// =============================================================================
234// ProfileRegion
235// =============================================================================
236
/// A handle to a profiling region.
///
/// Created by `GpuProfiler::begin_region` and consumed by `GpuProfiler::end_region`.
#[derive(Debug)]
pub struct ProfileRegion {
    // Human-readable label supplied to begin_region.
    label: String,
    // Index of the timestamp query written when the region began.
    start_query: u32,
}
245
246// =============================================================================
247// GpuProfiler
248// =============================================================================
249
250/// High-level GPU profiler for measuring execution times.
251///
252/// This profiler uses timestamp queries to measure GPU execution time
253/// for different regions of your rendering code.
254///
255/// # Requirements
256///
257/// - Device must support `TIMESTAMP_QUERY` feature
258/// - Must call `begin_frame()` at the start of each frame
259/// - Must call `resolve()` before submitting commands
260///
261/// # Example
262///
263/// ```ignore
264/// let mut profiler = GpuProfiler::new(context.clone(), 256);
265///
266/// // Each frame:
267/// profiler.begin_frame();
268///
269/// let region = profiler.begin_region(&mut encoder, "My Pass");
270/// // ... do rendering ...
271/// profiler.end_region(&mut encoder, region);
272///
273/// profiler.resolve(&mut encoder);
274///
275/// // Read results (may be from previous frame)
276/// for (label, duration_ms) in profiler.read_results() {
277/// println!("{}: {:.2}ms", label, duration_ms);
278/// }
279/// ```
280impl RenderCapability for GpuProfiler {
281 fn requirements() -> GpuRequirements {
282 GpuRequirements::new().require_features(GpuFeatures::TIMESTAMP_QUERY)
283 }
284
285 fn name() -> &'static str {
286 "GpuProfiler"
287 }
288}
289
290pub struct GpuProfiler {
291 context: Arc<GraphicsContext>,
292 query_set: QuerySet,
293 result_buffer: QueryResultBuffer,
294 /// Current query index for the frame
295 current_query: u32,
296 /// Maximum queries per frame
297 max_queries: u32,
298 /// Regions from the current frame (label, start_query, end_query)
299 regions: Vec<(String, u32, u32)>,
300 /// Cached results from the previous frame
301 cached_results: Vec<(String, f64)>,
302 /// Timestamp period in nanoseconds per tick
303 timestamp_period: f32,
304}
305
306impl GpuProfiler {
307 /// Create a new GPU profiler.
308 ///
309 /// # Arguments
310 ///
311 /// * `context` - Graphics context (must support TIMESTAMP_QUERY)
312 /// * `max_queries` - Maximum number of timestamp queries per frame
313 ///
314 /// # Panics
315 ///
316 /// Panics if the device doesn't support timestamp queries.
317 pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
318 let timestamp_period = context.queue().get_timestamp_period();
319
320 let query_set = QuerySet::new(
321 context.device(),
322 Some("GPU Profiler Queries"),
323 QueryType::Timestamp,
324 max_queries,
325 );
326
327 let result_buffer = QueryResultBuffer::new(
328 context.device(),
329 Some("GPU Profiler Results"),
330 max_queries,
331 );
332
333 Self {
334 context,
335 query_set,
336 result_buffer,
337 current_query: 0,
338 max_queries,
339 regions: Vec::new(),
340 cached_results: Vec::new(),
341 timestamp_period,
342 }
343 }
344
345 /// Begin a new frame.
346 ///
347 /// Call this at the start of each frame before recording any regions.
348 pub fn begin_frame(&mut self) {
349 self.current_query = 0;
350 self.regions.clear();
351 }
352
353 /// Begin a profiling region.
354 ///
355 /// # Arguments
356 ///
357 /// * `encoder` - Command encoder to write the timestamp
358 /// * `label` - Human-readable label for this region
359 ///
360 /// # Returns
361 ///
362 /// A `ProfileRegion` handle that must be passed to `end_region`.
363 pub fn begin_region(
364 &mut self,
365 encoder: &mut wgpu::CommandEncoder,
366 label: &str,
367 ) -> Option<ProfileRegion> {
368 if self.current_query >= self.max_queries {
369 return None;
370 }
371
372 let start_query = self.current_query;
373 encoder.write_timestamp(&self.query_set.query_set, start_query);
374 self.current_query += 1;
375
376 Some(ProfileRegion {
377 label: label.to_string(),
378 start_query,
379 })
380 }
381
382 /// End a profiling region.
383 ///
384 /// # Arguments
385 ///
386 /// * `encoder` - Command encoder to write the timestamp
387 /// * `region` - The region handle from `begin_region`
388 pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
389 if self.current_query >= self.max_queries {
390 return;
391 }
392
393 let end_query = self.current_query;
394 encoder.write_timestamp(&self.query_set.query_set, end_query);
395 self.current_query += 1;
396
397 self.regions
398 .push((region.label, region.start_query, end_query));
399 }
400
401 /// Resolve all queries from this frame.
402 ///
403 /// Call this after all regions have been recorded, before submitting commands.
404 pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
405 if self.current_query == 0 {
406 return;
407 }
408
409 self.result_buffer.resolve(
410 encoder,
411 &self.query_set,
412 0..self.current_query,
413 0,
414 );
415 self.result_buffer.copy_to_readable(encoder);
416 }
417
418 /// Read profiling results synchronously.
419 ///
420 /// This blocks until the results are available from the GPU.
421 /// For non-blocking reads, consider using double-buffering or
422 /// reading results from the previous frame.
423 ///
424 /// # Returns
425 ///
426 /// A vector of (label, duration_ms) pairs for each completed region.
427 pub fn read_results(&mut self) -> &[(String, f64)] {
428 if self.regions.is_empty() {
429 return &self.cached_results;
430 }
431
432 let device = self.context.device();
433
434 // Map the buffer
435 let slice = self.result_buffer.read_buffer().slice(..);
436 let (tx, rx) = std::sync::mpsc::channel();
437
438 slice.map_async(wgpu::MapMode::Read, move |result| {
439 let _ = tx.send(result);
440 });
441
442 // Wait for the buffer to be mapped (blocking)
443 let _ = device.poll(wgpu::PollType::Wait {
444 submission_index: None,
445 timeout: None,
446 });
447
448 // Wait for the callback
449 if rx.recv().is_ok() {
450 let data = slice.get_mapped_range();
451 let timestamps: &[u64] = bytemuck::cast_slice(&data);
452
453 self.cached_results.clear();
454
455 for (label, start, end) in &self.regions {
456 let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
457 let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
458
459 // Convert ticks to milliseconds
460 let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
461 * self.timestamp_period as f64;
462 let duration_ms = duration_ns / 1_000_000.0;
463
464 self.cached_results.push((label.clone(), duration_ms));
465 }
466
467 drop(data);
468 self.result_buffer.read_buffer().unmap();
469 }
470
471 &self.cached_results
472 }
473
474 /// Try to read profiling results without blocking.
475 ///
476 /// Returns None if the results are not yet available.
477 /// This is useful when you want to display results from the previous frame.
478 ///
479 /// # Returns
480 ///
481 /// Some reference to the cached results if new data was read, or the existing cached results.
482 pub fn try_read_results(&mut self) -> &[(String, f64)] {
483 if self.regions.is_empty() {
484 return &self.cached_results;
485 }
486
487 let device = self.context.device();
488
489 // Try to map the buffer
490 let slice = self.result_buffer.read_buffer().slice(..);
491 let (tx, rx) = std::sync::mpsc::channel();
492
493 slice.map_async(wgpu::MapMode::Read, move |result| {
494 let _ = tx.send(result);
495 });
496
497 // Non-blocking poll
498 let _ = device.poll(wgpu::PollType::Poll);
499
500 // Check if mapping succeeded
501 if let Ok(Ok(())) = rx.try_recv() {
502 let data = slice.get_mapped_range();
503 let timestamps: &[u64] = bytemuck::cast_slice(&data);
504
505 self.cached_results.clear();
506
507 for (label, start, end) in &self.regions {
508 let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
509 let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
510
511 // Convert ticks to milliseconds
512 let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
513 * self.timestamp_period as f64;
514 let duration_ms = duration_ns / 1_000_000.0;
515
516 self.cached_results.push((label.clone(), duration_ms));
517 }
518
519 drop(data);
520 self.result_buffer.read_buffer().unmap();
521 }
522
523 &self.cached_results
524 }
525
526 /// Get the number of queries used this frame.
527 #[inline]
528 pub fn queries_used(&self) -> u32 {
529 self.current_query
530 }
531
532 /// Get the maximum queries per frame.
533 #[inline]
534 pub fn max_queries(&self) -> u32 {
535 self.max_queries
536 }
537
538 /// Get the timestamp period in nanoseconds per tick.
539 #[inline]
540 pub fn timestamp_period(&self) -> f32 {
541 self.timestamp_period
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_type_conversion() {
        // Conversion must not panic for any variant.
        for ty in [QueryType::Timestamp, QueryType::Occlusion] {
            let _ = ty.to_wgpu();
        }
    }

    #[test]
    fn test_profile_region_debug() {
        // ProfileRegion must be Debug-formattable.
        let region = ProfileRegion {
            label: String::from("Test"),
            start_query: 0,
        };
        let _ = format!("{:?}", region);
    }
}
565}