astrelis_render/
query.rs

1//! GPU query and profiling support.
2//!
3//! This module provides wrappers for GPU queries (timestamps, occlusion) and
4//! a high-level profiler for measuring GPU execution times.
5//!
6//! # Features Required
7//!
8//! - `TIMESTAMP_QUERY` - Required for timestamp queries and GPU profiling
9//!
10//! # Example
11//!
12//! ```ignore
13//! use astrelis_render::{GpuProfiler, GraphicsContext, GraphicsContextExt};
14//!
15//! // Create profiler (requires TIMESTAMP_QUERY feature)
16//! let mut profiler = GpuProfiler::new(context.clone(), 256);
17//!
18//! // In render loop:
19//! profiler.begin_frame();
20//!
//! if let Some(region) = profiler.begin_region(&mut encoder, "Shadow Pass") {
//!     // ... render shadow pass ...
//!     profiler.end_region(&mut encoder, region);
//! }
26//!
27//! profiler.resolve(&mut encoder);
28//!
29//! // Later, read results
30//! for (label, duration_ms) in profiler.read_results() {
31//!     println!("{}: {:.2}ms", label, duration_ms);
32//! }
33//! ```
34
35use std::sync::Arc;
36
37use crate::context::GraphicsContext;
38use crate::extension::GraphicsContextExt;
39
40// =============================================================================
41// Query Types
42// =============================================================================
43
/// Types of GPU queries.
///
/// Convert to the underlying wgpu type with [`QueryType::to_wgpu`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// Timestamp query for measuring GPU execution time.
    /// Requires `TIMESTAMP_QUERY` feature.
    Timestamp,
    /// Occlusion query for counting visible fragments.
    Occlusion,
}
53
54impl QueryType {
55    /// Convert to wgpu query type.
56    pub fn to_wgpu(self) -> wgpu::QueryType {
57        match self {
58            QueryType::Timestamp => wgpu::QueryType::Timestamp,
59            QueryType::Occlusion => wgpu::QueryType::Occlusion,
60        }
61    }
62}
63
64// =============================================================================
65// QuerySet
66// =============================================================================
67
/// A wrapper around wgpu::QuerySet with metadata.
pub struct QuerySet {
    /// The underlying wgpu query set.
    query_set: wgpu::QuerySet,
    /// Type of queries stored in this set (timestamp or occlusion).
    query_type: QueryType,
    /// Number of query slots in the set.
    count: u32,
}
74
75impl QuerySet {
76    /// Create a new query set.
77    ///
78    /// # Arguments
79    ///
80    /// * `device` - The wgpu device
81    /// * `label` - Optional debug label
82    /// * `query_type` - Type of queries in this set
83    /// * `count` - Number of queries in the set
84    pub fn new(
85        device: &wgpu::Device,
86        label: Option<&str>,
87        query_type: QueryType,
88        count: u32,
89    ) -> Self {
90        let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
91            label,
92            ty: query_type.to_wgpu(),
93            count,
94        });
95
96        Self {
97            query_set,
98            query_type,
99            count,
100        }
101    }
102
103    /// Get the underlying wgpu query set.
104    #[inline]
105    pub fn query_set(&self) -> &wgpu::QuerySet {
106        &self.query_set
107    }
108
109    /// Get the query type.
110    #[inline]
111    pub fn query_type(&self) -> QueryType {
112        self.query_type
113    }
114
115    /// Get the number of queries in the set.
116    #[inline]
117    pub fn count(&self) -> u32 {
118        self.count
119    }
120}
121
122// =============================================================================
123// QueryResultBuffer
124// =============================================================================
125
/// Buffer for storing and reading query results.
///
/// Holds two buffers: `resolve_buffer` is the GPU-side target for query
/// resolution, and `read_buffer` is a CPU-mappable copy filled by
/// `copy_to_readable` for readback.
pub struct QueryResultBuffer {
    /// GPU-side buffer queries are resolved into (QUERY_RESOLVE | COPY_SRC).
    resolve_buffer: wgpu::Buffer,
    /// CPU-mappable readback buffer (COPY_DST | MAP_READ).
    read_buffer: wgpu::Buffer,
    /// Number of u64 results the buffers can hold.
    count: u32,
}
132
impl QueryResultBuffer {
    /// Create a new query result buffer.
    ///
    /// Two buffers are allocated because the resolution target and the
    /// CPU-readable copy use different usage flags (`QUERY_RESOLVE` vs
    /// `MAP_READ`).
    ///
    /// # Arguments
    ///
    /// * `device` - The wgpu device
    /// * `label` - Optional debug label
    /// * `count` - Number of query results to store
    pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
        // One u64 per query result (timestamp tick or occlusion count).
        let size = (count as u64) * std::mem::size_of::<u64>() as u64;

        let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Resolve", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Read", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        Self {
            resolve_buffer,
            read_buffer,
            count,
        }
    }

    /// Get the resolve buffer (used for query resolution).
    #[inline]
    pub fn resolve_buffer(&self) -> &wgpu::Buffer {
        &self.resolve_buffer
    }

    /// Get the read buffer (used for CPU readback).
    #[inline]
    pub fn read_buffer(&self) -> &wgpu::Buffer {
        &self.read_buffer
    }

    /// Get the number of results this buffer can hold.
    #[inline]
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Resolve queries from a query set into this buffer.
    ///
    /// `destination_offset` is measured in query slots (u64 elements), not
    /// bytes; it is converted to a byte offset here.
    pub fn resolve(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        query_set: &QuerySet,
        query_range: std::ops::Range<u32>,
        destination_offset: u32,
    ) {
        encoder.resolve_query_set(
            query_set.query_set(),
            query_range,
            &self.resolve_buffer,
            (destination_offset as u64) * std::mem::size_of::<u64>() as u64,
        );
    }

    /// Copy resolved results to the readable buffer.
    ///
    /// Record this after `resolve` so the readback copy sees the resolved
    /// values. Always copies the full `count * size_of::<u64>()` bytes.
    pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
        let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
        encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
    }

    /// Map the read buffer for CPU access.
    ///
    /// Returns a future that completes when the buffer is mapped.
    ///
    /// NOTE(review): when polled, the returned future calls `rx.recv()`,
    /// which blocks that thread until the map callback fires. wgpu map
    /// callbacks only run while the device is polled, so awaiting this on
    /// the same thread that drives `device.poll` could deadlock — confirm
    /// the intended executor/polling setup.
    pub fn map_async(&self) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
        let slice = self.read_buffer.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();

        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });

        // A dropped sender (callback never delivered) surfaces as BufferAsyncError.
        async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
    }

    /// Read the query results (must be mapped first).
    ///
    /// Returns the raw u64 timestamps/occlusion counts and unmaps the buffer
    /// afterwards.
    pub fn read_results(&self) -> Vec<u64> {
        let slice = self.read_buffer.slice(..);
        let data = slice.get_mapped_range();
        // Copy out before unmapping; the mapped view must not outlive the map.
        let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        self.read_buffer.unmap();
        results
    }
}
231
232// =============================================================================
233// ProfileRegion
234// =============================================================================
235
/// A handle to a profiling region.
///
/// Created by `GpuProfiler::begin_region` and consumed by `GpuProfiler::end_region`.
#[derive(Debug)]
pub struct ProfileRegion {
    /// Human-readable label supplied to `begin_region`.
    label: String,
    /// Index of the timestamp query written at the start of the region.
    start_query: u32,
}
244
245// =============================================================================
246// GpuProfiler
247// =============================================================================
248
249/// High-level GPU profiler for measuring execution times.
250///
251/// This profiler uses timestamp queries to measure GPU execution time
252/// for different regions of your rendering code.
253///
254/// # Requirements
255///
256/// - Device must support `TIMESTAMP_QUERY` feature
257/// - Must call `begin_frame()` at the start of each frame
258/// - Must call `resolve()` before submitting commands
259///
260/// # Example
261///
262/// ```ignore
263/// let mut profiler = GpuProfiler::new(context.clone(), 256);
264///
265/// // Each frame:
266/// profiler.begin_frame();
267///
/// if let Some(region) = profiler.begin_region(&mut encoder, "My Pass") {
///     // ... do rendering ...
///     profiler.end_region(&mut encoder, region);
/// }
271///
272/// profiler.resolve(&mut encoder);
273///
274/// // Read results (may be from previous frame)
275/// for (label, duration_ms) in profiler.read_results() {
276///     println!("{}: {:.2}ms", label, duration_ms);
277/// }
278/// ```
pub struct GpuProfiler {
    /// Graphics context providing the device and queue.
    context: Arc<GraphicsContext>,
    /// Timestamp query set shared by all regions in a frame.
    query_set: QuerySet,
    /// Resolve/readback buffers for the query results.
    result_buffer: QueryResultBuffer,
    /// Current query index for the frame
    current_query: u32,
    /// Maximum queries per frame
    max_queries: u32,
    /// Regions from the current frame (label, start_query, end_query)
    regions: Vec<(String, u32, u32)>,
    /// Cached results from the previous frame
    cached_results: Vec<(String, f64)>,
    /// Timestamp period in nanoseconds per tick
    timestamp_period: f32,
}
294
295impl GpuProfiler {
296    /// Create a new GPU profiler.
297    ///
298    /// # Arguments
299    ///
300    /// * `context` - Graphics context (must support TIMESTAMP_QUERY)
301    /// * `max_queries` - Maximum number of timestamp queries per frame
302    ///
303    /// # Panics
304    ///
305    /// Panics if the device doesn't support timestamp queries.
306    pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
307        let timestamp_period = context.queue().get_timestamp_period();
308
309        let query_set = QuerySet::new(
310            context.device(),
311            Some("GPU Profiler Queries"),
312            QueryType::Timestamp,
313            max_queries,
314        );
315
316        let result_buffer = QueryResultBuffer::new(
317            context.device(),
318            Some("GPU Profiler Results"),
319            max_queries,
320        );
321
322        Self {
323            context,
324            query_set,
325            result_buffer,
326            current_query: 0,
327            max_queries,
328            regions: Vec::new(),
329            cached_results: Vec::new(),
330            timestamp_period,
331        }
332    }
333
334    /// Begin a new frame.
335    ///
336    /// Call this at the start of each frame before recording any regions.
337    pub fn begin_frame(&mut self) {
338        self.current_query = 0;
339        self.regions.clear();
340    }
341
342    /// Begin a profiling region.
343    ///
344    /// # Arguments
345    ///
346    /// * `encoder` - Command encoder to write the timestamp
347    /// * `label` - Human-readable label for this region
348    ///
349    /// # Returns
350    ///
351    /// A `ProfileRegion` handle that must be passed to `end_region`.
352    pub fn begin_region(
353        &mut self,
354        encoder: &mut wgpu::CommandEncoder,
355        label: &str,
356    ) -> Option<ProfileRegion> {
357        if self.current_query >= self.max_queries {
358            return None;
359        }
360
361        let start_query = self.current_query;
362        encoder.write_timestamp(&self.query_set.query_set, start_query);
363        self.current_query += 1;
364
365        Some(ProfileRegion {
366            label: label.to_string(),
367            start_query,
368        })
369    }
370
371    /// End a profiling region.
372    ///
373    /// # Arguments
374    ///
375    /// * `encoder` - Command encoder to write the timestamp
376    /// * `region` - The region handle from `begin_region`
377    pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
378        if self.current_query >= self.max_queries {
379            return;
380        }
381
382        let end_query = self.current_query;
383        encoder.write_timestamp(&self.query_set.query_set, end_query);
384        self.current_query += 1;
385
386        self.regions
387            .push((region.label, region.start_query, end_query));
388    }
389
390    /// Resolve all queries from this frame.
391    ///
392    /// Call this after all regions have been recorded, before submitting commands.
393    pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
394        if self.current_query == 0 {
395            return;
396        }
397
398        self.result_buffer.resolve(
399            encoder,
400            &self.query_set,
401            0..self.current_query,
402            0,
403        );
404        self.result_buffer.copy_to_readable(encoder);
405    }
406
407    /// Read profiling results synchronously.
408    ///
409    /// This blocks until the results are available from the GPU.
410    /// For non-blocking reads, consider using double-buffering or
411    /// reading results from the previous frame.
412    ///
413    /// # Returns
414    ///
415    /// A vector of (label, duration_ms) pairs for each completed region.
416    pub fn read_results(&mut self) -> &[(String, f64)] {
417        if self.regions.is_empty() {
418            return &self.cached_results;
419        }
420
421        let device = self.context.device();
422
423        // Map the buffer
424        let slice = self.result_buffer.read_buffer().slice(..);
425        let (tx, rx) = std::sync::mpsc::channel();
426
427        slice.map_async(wgpu::MapMode::Read, move |result| {
428            let _ = tx.send(result);
429        });
430
431        // Wait for the buffer to be mapped (blocking)
432        let _ = device.poll(wgpu::PollType::Wait {
433            submission_index: None,
434            timeout: None,
435        });
436
437        // Wait for the callback
438        if rx.recv().is_ok() {
439            let data = slice.get_mapped_range();
440            let timestamps: &[u64] = bytemuck::cast_slice(&data);
441
442            self.cached_results.clear();
443
444            for (label, start, end) in &self.regions {
445                let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
446                let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
447
448                // Convert ticks to milliseconds
449                let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
450                    * self.timestamp_period as f64;
451                let duration_ms = duration_ns / 1_000_000.0;
452
453                self.cached_results.push((label.clone(), duration_ms));
454            }
455
456            drop(data);
457            self.result_buffer.read_buffer().unmap();
458        }
459
460        &self.cached_results
461    }
462
463    /// Try to read profiling results without blocking.
464    ///
465    /// Returns None if the results are not yet available.
466    /// This is useful when you want to display results from the previous frame.
467    ///
468    /// # Returns
469    ///
470    /// Some reference to the cached results if new data was read, or the existing cached results.
471    pub fn try_read_results(&mut self) -> &[(String, f64)] {
472        if self.regions.is_empty() {
473            return &self.cached_results;
474        }
475
476        let device = self.context.device();
477
478        // Try to map the buffer
479        let slice = self.result_buffer.read_buffer().slice(..);
480        let (tx, rx) = std::sync::mpsc::channel();
481
482        slice.map_async(wgpu::MapMode::Read, move |result| {
483            let _ = tx.send(result);
484        });
485
486        // Non-blocking poll
487        let _ = device.poll(wgpu::PollType::Poll);
488
489        // Check if mapping succeeded
490        if let Ok(Ok(())) = rx.try_recv() {
491            let data = slice.get_mapped_range();
492            let timestamps: &[u64] = bytemuck::cast_slice(&data);
493
494            self.cached_results.clear();
495
496            for (label, start, end) in &self.regions {
497                let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
498                let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
499
500                // Convert ticks to milliseconds
501                let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
502                    * self.timestamp_period as f64;
503                let duration_ms = duration_ns / 1_000_000.0;
504
505                self.cached_results.push((label.clone(), duration_ms));
506            }
507
508            drop(data);
509            self.result_buffer.read_buffer().unmap();
510        }
511
512        &self.cached_results
513    }
514
515    /// Get the number of queries used this frame.
516    #[inline]
517    pub fn queries_used(&self) -> u32 {
518        self.current_query
519    }
520
521    /// Get the maximum queries per frame.
522    #[inline]
523    pub fn max_queries(&self) -> u32 {
524        self.max_queries
525    }
526
527    /// Get the timestamp period in nanoseconds per tick.
528    #[inline]
529    pub fn timestamp_period(&self) -> f32 {
530        self.timestamp_period
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Conversion to the wgpu query type must be total and panic-free.
    #[test]
    fn test_query_type_conversion() {
        for ty in [QueryType::Timestamp, QueryType::Occlusion] {
            let _ = ty.to_wgpu();
        }
    }

    /// `ProfileRegion` must remain formattable via `Debug` for diagnostics.
    #[test]
    fn test_profile_region_debug() {
        let region = ProfileRegion {
            label: String::from("Test"),
            start_query: 0,
        };
        let _ = format!("{:?}", region);
    }
}