// astrelis_render/query.rs
1//! GPU query and profiling support.
2//!
3//! This module provides wrappers for GPU queries (timestamps, occlusion) and
4//! a high-level profiler for measuring GPU execution times.
5//!
6//! # Features Required
7//!
8//! - `TIMESTAMP_QUERY` - Required for timestamp queries and GPU profiling
9//!
10//! # Example
11//!
12//! ```ignore
13//! use astrelis_render::{GpuProfiler, GraphicsContext};
14//!
15//! // Create profiler (requires TIMESTAMP_QUERY feature)
16//! let mut profiler = GpuProfiler::new(context.clone(), 256);
17//!
18//! // In render loop:
19//! profiler.begin_frame();
20//!
//! if let Some(region) = profiler.begin_region(&mut encoder, "Shadow Pass") {
//!     // ... render shadow pass ...
//!     profiler.end_region(&mut encoder, region);
//! }
26//!
27//! profiler.resolve(&mut encoder);
28//!
29//! // Later, read results
30//! for (label, duration_ms) in profiler.read_results() {
31//! println!("{}: {:.2}ms", label, duration_ms);
32//! }
33//! ```
34
35use std::sync::Arc;
36
37use crate::capability::{GpuRequirements, RenderCapability};
38use crate::context::GraphicsContext;
39use crate::features::GpuFeatures;
40
41// =============================================================================
42// Query Types
43// =============================================================================
44
/// Types of GPU queries.
///
/// Converted to the underlying wgpu type via [`QueryType::to_wgpu`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// Timestamp query for measuring GPU execution time.
    /// Requires `TIMESTAMP_QUERY` feature.
    Timestamp,
    /// Occlusion query for counting visible fragments.
    Occlusion,
}
54
55impl QueryType {
56 /// Convert to wgpu query type.
57 pub fn to_wgpu(self) -> wgpu::QueryType {
58 match self {
59 QueryType::Timestamp => wgpu::QueryType::Timestamp,
60 QueryType::Occlusion => wgpu::QueryType::Occlusion,
61 }
62 }
63}
64
65// =============================================================================
66// QuerySet
67// =============================================================================
68
/// A wrapper around wgpu::QuerySet with metadata.
pub struct QuerySet {
    /// Underlying GPU query set object.
    query_set: wgpu::QuerySet,
    /// Kind of queries stored in the set (timestamp or occlusion).
    query_type: QueryType,
    /// Number of query slots allocated in the set.
    count: u32,
}
75
76impl QuerySet {
77 /// Create a new query set.
78 ///
79 /// # Arguments
80 ///
81 /// * `device` - The wgpu device
82 /// * `label` - Optional debug label
83 /// * `query_type` - Type of queries in this set
84 /// * `count` - Number of queries in the set
85 pub fn new(
86 device: &wgpu::Device,
87 label: Option<&str>,
88 query_type: QueryType,
89 count: u32,
90 ) -> Self {
91 let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
92 label,
93 ty: query_type.to_wgpu(),
94 count,
95 });
96
97 Self {
98 query_set,
99 query_type,
100 count,
101 }
102 }
103
104 /// Get the underlying wgpu query set.
105 #[inline]
106 pub fn query_set(&self) -> &wgpu::QuerySet {
107 &self.query_set
108 }
109
110 /// Get the query type.
111 #[inline]
112 pub fn query_type(&self) -> QueryType {
113 self.query_type
114 }
115
116 /// Get the number of queries in the set.
117 #[inline]
118 pub fn count(&self) -> u32 {
119 self.count
120 }
121}
122
123// =============================================================================
124// QueryResultBuffer
125// =============================================================================
126
127/// Buffer for storing and reading query results.
128pub struct QueryResultBuffer {
129 resolve_buffer: wgpu::Buffer,
130 read_buffer: wgpu::Buffer,
131 count: u32,
132}
133
impl QueryResultBuffer {
    /// Create a new query result buffer.
    ///
    /// Allocates both the resolve buffer and the readback buffer, each sized
    /// for `count` 64-bit query results.
    ///
    /// # Arguments
    ///
    /// * `device` - The wgpu device
    /// * `label` - Optional debug label
    /// * `count` - Number of query results to store
    pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
        // Each query result is one u64 (8 bytes).
        let size = (count as u64) * std::mem::size_of::<u64>() as u64;

        // GPU-side buffer that query results are resolved into.
        let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Resolve", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        // CPU-mappable buffer the resolved results are copied into.
        let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Read", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        Self {
            resolve_buffer,
            read_buffer,
            count,
        }
    }

    /// Get the resolve buffer (used for query resolution).
    #[inline]
    pub fn resolve_buffer(&self) -> &wgpu::Buffer {
        &self.resolve_buffer
    }

    /// Get the read buffer (used for CPU readback).
    #[inline]
    pub fn read_buffer(&self) -> &wgpu::Buffer {
        &self.read_buffer
    }

    /// Get the number of results this buffer can hold.
    #[inline]
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Resolve queries from a query set into this buffer.
    ///
    /// * `query_range` - Range of query indices within `query_set` to resolve
    /// * `destination_offset` - Offset into this buffer in query slots (not
    ///   bytes); converted to a byte offset below
    pub fn resolve(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        query_set: &QuerySet,
        query_range: std::ops::Range<u32>,
        destination_offset: u32,
    ) {
        encoder.resolve_query_set(
            query_set.query_set(),
            query_range,
            &self.resolve_buffer,
            // Slots -> bytes (8 bytes per u64 result).
            (destination_offset as u64) * std::mem::size_of::<u64>() as u64,
        );
    }

    /// Copy resolved results to the readable buffer.
    pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
        let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
        encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
    }

    /// Map the read buffer for CPU access.
    ///
    /// Returns a future that completes when the buffer is mapped.
    ///
    /// NOTE(review): the returned future blocks on `rx.recv()` when polled,
    /// stalling the polling thread until the map callback fires — presumably
    /// intended for blocking callers; confirm before driving this from an
    /// async executor.
    pub fn map_async(
        &self,
    ) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
        let slice = self.read_buffer.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();

        slice.map_async(wgpu::MapMode::Read, move |result| {
            // The callback forwards the map result to the future below.
            let _ = tx.send(result);
        });

        // Channel disconnection (callback dropped unfired) is reported as a
        // BufferAsyncError.
        async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
    }

    /// Read the query results (must be mapped first).
    ///
    /// Returns the raw u64 timestamps/occlusion counts.
    ///
    /// NOTE(review): assumes the buffer was successfully mapped beforehand
    /// (e.g. via `map_async`); `get_mapped_range` panics otherwise — confirm
    /// callers uphold this. The buffer is unmapped before returning.
    pub fn read_results(&self) -> Vec<u64> {
        let slice = self.read_buffer.slice(..);
        let data = slice.get_mapped_range();
        // Copy out before unmapping; the mapped-range borrow must end first.
        let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        self.read_buffer.unmap();
        results
    }
}
234
235// =============================================================================
236// ProfileRegion
237// =============================================================================
238
/// A handle to a profiling region.
///
/// Created by `GpuProfiler::begin_region` and consumed by `GpuProfiler::end_region`.
#[derive(Debug)]
pub struct ProfileRegion {
    /// Human-readable label supplied to `begin_region`.
    label: String,
    /// Index of the timestamp query written when the region began.
    start_query: u32,
}
247
248// =============================================================================
249// GpuProfiler
250// =============================================================================
251
252/// High-level GPU profiler for measuring execution times.
253///
254/// This profiler uses timestamp queries to measure GPU execution time
255/// for different regions of your rendering code.
256///
257/// # Requirements
258///
259/// - Device must support `TIMESTAMP_QUERY` feature
260/// - Must call `begin_frame()` at the start of each frame
261/// - Must call `resolve()` before submitting commands
262///
263/// # Example
264///
265/// ```ignore
266/// let mut profiler = GpuProfiler::new(context.clone(), 256);
267///
268/// // Each frame:
269/// profiler.begin_frame();
270///
271/// let region = profiler.begin_region(&mut encoder, "My Pass");
272/// // ... do rendering ...
273/// profiler.end_region(&mut encoder, region);
274///
275/// profiler.resolve(&mut encoder);
276///
277/// // Read results (may be from previous frame)
278/// for (label, duration_ms) in profiler.read_results() {
279/// println!("{}: {:.2}ms", label, duration_ms);
280/// }
281/// ```
282impl RenderCapability for GpuProfiler {
283 fn requirements() -> GpuRequirements {
284 GpuRequirements::new().require_features(GpuFeatures::TIMESTAMP_QUERY)
285 }
286
287 fn name() -> &'static str {
288 "GpuProfiler"
289 }
290}
291
292pub struct GpuProfiler {
293 context: Arc<GraphicsContext>,
294 query_set: QuerySet,
295 result_buffer: QueryResultBuffer,
296 /// Current query index for the frame
297 current_query: u32,
298 /// Maximum queries per frame
299 max_queries: u32,
300 /// Regions from the current frame (label, start_query, end_query)
301 regions: Vec<(String, u32, u32)>,
302 /// Cached results from the previous frame
303 cached_results: Vec<(String, f64)>,
304 /// Timestamp period in nanoseconds per tick
305 timestamp_period: f32,
306}
307
308impl GpuProfiler {
309 /// Create a new GPU profiler.
310 ///
311 /// # Arguments
312 ///
313 /// * `context` - Graphics context (must support TIMESTAMP_QUERY)
314 /// * `max_queries` - Maximum number of timestamp queries per frame
315 ///
316 /// # Panics
317 ///
318 /// Panics if the device doesn't support timestamp queries.
319 pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
320 let timestamp_period = context.queue().get_timestamp_period();
321
322 let query_set = QuerySet::new(
323 context.device(),
324 Some("GPU Profiler Queries"),
325 QueryType::Timestamp,
326 max_queries,
327 );
328
329 let result_buffer =
330 QueryResultBuffer::new(context.device(), Some("GPU Profiler Results"), max_queries);
331
332 Self {
333 context,
334 query_set,
335 result_buffer,
336 current_query: 0,
337 max_queries,
338 regions: Vec::new(),
339 cached_results: Vec::new(),
340 timestamp_period,
341 }
342 }
343
344 /// Begin a new frame.
345 ///
346 /// Call this at the start of each frame before recording any regions.
347 pub fn begin_frame(&mut self) {
348 self.current_query = 0;
349 self.regions.clear();
350 }
351
352 /// Begin a profiling region.
353 ///
354 /// # Arguments
355 ///
356 /// * `encoder` - Command encoder to write the timestamp
357 /// * `label` - Human-readable label for this region
358 ///
359 /// # Returns
360 ///
361 /// A `ProfileRegion` handle that must be passed to `end_region`.
362 pub fn begin_region(
363 &mut self,
364 encoder: &mut wgpu::CommandEncoder,
365 label: &str,
366 ) -> Option<ProfileRegion> {
367 if self.current_query >= self.max_queries {
368 return None;
369 }
370
371 let start_query = self.current_query;
372 encoder.write_timestamp(&self.query_set.query_set, start_query);
373 self.current_query += 1;
374
375 Some(ProfileRegion {
376 label: label.to_string(),
377 start_query,
378 })
379 }
380
381 /// End a profiling region.
382 ///
383 /// # Arguments
384 ///
385 /// * `encoder` - Command encoder to write the timestamp
386 /// * `region` - The region handle from `begin_region`
387 pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
388 if self.current_query >= self.max_queries {
389 return;
390 }
391
392 let end_query = self.current_query;
393 encoder.write_timestamp(&self.query_set.query_set, end_query);
394 self.current_query += 1;
395
396 self.regions
397 .push((region.label, region.start_query, end_query));
398 }
399
400 /// Resolve all queries from this frame.
401 ///
402 /// Call this after all regions have been recorded, before submitting commands.
403 pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
404 if self.current_query == 0 {
405 return;
406 }
407
408 self.result_buffer
409 .resolve(encoder, &self.query_set, 0..self.current_query, 0);
410 self.result_buffer.copy_to_readable(encoder);
411 }
412
413 /// Read profiling results synchronously.
414 ///
415 /// This blocks until the results are available from the GPU.
416 /// For non-blocking reads, consider using double-buffering or
417 /// reading results from the previous frame.
418 ///
419 /// # Returns
420 ///
421 /// A vector of (label, duration_ms) pairs for each completed region.
422 pub fn read_results(&mut self) -> &[(String, f64)] {
423 if self.regions.is_empty() {
424 return &self.cached_results;
425 }
426
427 let device = self.context.device();
428
429 // Map the buffer
430 let slice = self.result_buffer.read_buffer().slice(..);
431 let (tx, rx) = std::sync::mpsc::channel();
432
433 slice.map_async(wgpu::MapMode::Read, move |result| {
434 let _ = tx.send(result);
435 });
436
437 // Wait for the buffer to be mapped (blocking)
438 let _ = device.poll(wgpu::PollType::Wait {
439 submission_index: None,
440 timeout: None,
441 });
442
443 // Wait for the callback
444 if rx.recv().is_ok() {
445 let data = slice.get_mapped_range();
446 let timestamps: &[u64] = bytemuck::cast_slice(&data);
447
448 self.cached_results.clear();
449
450 for (label, start, end) in &self.regions {
451 let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
452 let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
453
454 // Convert ticks to milliseconds
455 let duration_ns =
456 (end_ts.saturating_sub(start_ts)) as f64 * self.timestamp_period as f64;
457 let duration_ms = duration_ns / 1_000_000.0;
458
459 self.cached_results.push((label.clone(), duration_ms));
460 }
461
462 drop(data);
463 self.result_buffer.read_buffer().unmap();
464 }
465
466 &self.cached_results
467 }
468
469 /// Try to read profiling results without blocking.
470 ///
471 /// Returns None if the results are not yet available.
472 /// This is useful when you want to display results from the previous frame.
473 ///
474 /// # Returns
475 ///
476 /// Some reference to the cached results if new data was read, or the existing cached results.
477 pub fn try_read_results(&mut self) -> &[(String, f64)] {
478 if self.regions.is_empty() {
479 return &self.cached_results;
480 }
481
482 let device = self.context.device();
483
484 // Try to map the buffer
485 let slice = self.result_buffer.read_buffer().slice(..);
486 let (tx, rx) = std::sync::mpsc::channel();
487
488 slice.map_async(wgpu::MapMode::Read, move |result| {
489 let _ = tx.send(result);
490 });
491
492 // Non-blocking poll
493 let _ = device.poll(wgpu::PollType::Poll);
494
495 // Check if mapping succeeded
496 if let Ok(Ok(())) = rx.try_recv() {
497 let data = slice.get_mapped_range();
498 let timestamps: &[u64] = bytemuck::cast_slice(&data);
499
500 self.cached_results.clear();
501
502 for (label, start, end) in &self.regions {
503 let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
504 let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
505
506 // Convert ticks to milliseconds
507 let duration_ns =
508 (end_ts.saturating_sub(start_ts)) as f64 * self.timestamp_period as f64;
509 let duration_ms = duration_ns / 1_000_000.0;
510
511 self.cached_results.push((label.clone(), duration_ms));
512 }
513
514 drop(data);
515 self.result_buffer.read_buffer().unmap();
516 }
517
518 &self.cached_results
519 }
520
521 /// Get the number of queries used this frame.
522 #[inline]
523 pub fn queries_used(&self) -> u32 {
524 self.current_query
525 }
526
527 /// Get the maximum queries per frame.
528 #[inline]
529 pub fn max_queries(&self) -> u32 {
530 self.max_queries
531 }
532
533 /// Get the timestamp period in nanoseconds per tick.
534 #[inline]
535 pub fn timestamp_period(&self) -> f32 {
536 self.timestamp_period
537 }
538}
539
540#[cfg(test)]
541mod tests {
542 use super::*;
543
544 #[test]
545 fn test_query_type_conversion() {
546 // Just verify conversion doesn't panic
547 let _ = QueryType::Timestamp.to_wgpu();
548 let _ = QueryType::Occlusion.to_wgpu();
549 }
550
551 #[test]
552 fn test_profile_region_debug() {
553 let region = ProfileRegion {
554 label: "Test".to_string(),
555 start_query: 0,
556 };
557 // Just ensure Debug is implemented
558 let _ = format!("{:?}", region);
559 }
560}