// astrelis_render/query.rs
1//! GPU query and profiling support.
2//!
3//! This module provides wrappers for GPU queries (timestamps, occlusion) and
4//! a high-level profiler for measuring GPU execution times.
5//!
6//! # Features Required
7//!
8//! - `TIMESTAMP_QUERY` - Required for timestamp queries and GPU profiling
9//!
10//! # Example
11//!
12//! ```ignore
13//! use astrelis_render::{GpuProfiler, GraphicsContext};
14//!
15//! // Create profiler (requires TIMESTAMP_QUERY feature)
16//! let mut profiler = GpuProfiler::new(context.clone(), 256);
17//!
18//! // In render loop:
19//! profiler.begin_frame();
20//!
21//! {
//!     if let Some(region) = profiler.begin_region(&mut encoder, "Shadow Pass") {
//!         // ... render shadow pass ...
//!         profiler.end_region(&mut encoder, region);
//!     }
25//! }
26//!
27//! profiler.resolve(&mut encoder);
28//!
29//! // Later, read results
30//! for (label, duration_ms) in profiler.read_results() {
31//!     println!("{}: {:.2}ms", label, duration_ms);
32//! }
33//! ```
34
35use std::sync::Arc;
36
37use crate::capability::{GpuRequirements, RenderCapability};
38use crate::context::GraphicsContext;
39use crate::features::GpuFeatures;
40
41// =============================================================================
42// Query Types
43// =============================================================================
44
45/// Types of GPU queries.
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
47pub enum QueryType {
48    /// Timestamp query for measuring GPU execution time.
49    /// Requires `TIMESTAMP_QUERY` feature.
50    Timestamp,
51    /// Occlusion query for counting visible fragments.
52    Occlusion,
53}
54
55impl QueryType {
56    /// Convert to wgpu query type.
57    pub fn to_wgpu(self) -> wgpu::QueryType {
58        match self {
59            QueryType::Timestamp => wgpu::QueryType::Timestamp,
60            QueryType::Occlusion => wgpu::QueryType::Occlusion,
61        }
62    }
63}
64
65// =============================================================================
66// QuerySet
67// =============================================================================
68
69/// A wrapper around wgpu::QuerySet with metadata.
70pub struct QuerySet {
71    query_set: wgpu::QuerySet,
72    query_type: QueryType,
73    count: u32,
74}
75
76impl QuerySet {
77    /// Create a new query set.
78    ///
79    /// # Arguments
80    ///
81    /// * `device` - The wgpu device
82    /// * `label` - Optional debug label
83    /// * `query_type` - Type of queries in this set
84    /// * `count` - Number of queries in the set
85    pub fn new(
86        device: &wgpu::Device,
87        label: Option<&str>,
88        query_type: QueryType,
89        count: u32,
90    ) -> Self {
91        let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
92            label,
93            ty: query_type.to_wgpu(),
94            count,
95        });
96
97        Self {
98            query_set,
99            query_type,
100            count,
101        }
102    }
103
104    /// Get the underlying wgpu query set.
105    #[inline]
106    pub fn query_set(&self) -> &wgpu::QuerySet {
107        &self.query_set
108    }
109
110    /// Get the query type.
111    #[inline]
112    pub fn query_type(&self) -> QueryType {
113        self.query_type
114    }
115
116    /// Get the number of queries in the set.
117    #[inline]
118    pub fn count(&self) -> u32 {
119        self.count
120    }
121}
122
123// =============================================================================
124// QueryResultBuffer
125// =============================================================================
126
127/// Buffer for storing and reading query results.
128pub struct QueryResultBuffer {
129    resolve_buffer: wgpu::Buffer,
130    read_buffer: wgpu::Buffer,
131    count: u32,
132}
133
134impl QueryResultBuffer {
135    /// Create a new query result buffer.
136    ///
137    /// # Arguments
138    ///
139    /// * `device` - The wgpu device
140    /// * `label` - Optional debug label
141    /// * `count` - Number of query results to store
142    pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
143        let size = (count as u64) * std::mem::size_of::<u64>() as u64;
144
145        let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
146            label: label.map(|l| format!("{} Resolve", l)).as_deref(),
147            size,
148            usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
149            mapped_at_creation: false,
150        });
151
152        let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
153            label: label.map(|l| format!("{} Read", l)).as_deref(),
154            size,
155            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
156            mapped_at_creation: false,
157        });
158
159        Self {
160            resolve_buffer,
161            read_buffer,
162            count,
163        }
164    }
165
166    /// Get the resolve buffer (used for query resolution).
167    #[inline]
168    pub fn resolve_buffer(&self) -> &wgpu::Buffer {
169        &self.resolve_buffer
170    }
171
172    /// Get the read buffer (used for CPU readback).
173    #[inline]
174    pub fn read_buffer(&self) -> &wgpu::Buffer {
175        &self.read_buffer
176    }
177
178    /// Get the number of results this buffer can hold.
179    #[inline]
180    pub fn count(&self) -> u32 {
181        self.count
182    }
183
184    /// Resolve queries from a query set into this buffer.
185    pub fn resolve(
186        &self,
187        encoder: &mut wgpu::CommandEncoder,
188        query_set: &QuerySet,
189        query_range: std::ops::Range<u32>,
190        destination_offset: u32,
191    ) {
192        encoder.resolve_query_set(
193            query_set.query_set(),
194            query_range,
195            &self.resolve_buffer,
196            (destination_offset as u64) * std::mem::size_of::<u64>() as u64,
197        );
198    }
199
200    /// Copy resolved results to the readable buffer.
201    pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
202        let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
203        encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
204    }
205
206    /// Map the read buffer for CPU access.
207    ///
208    /// Returns a future that completes when the buffer is mapped.
209    pub fn map_async(
210        &self,
211    ) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
212        let slice = self.read_buffer.slice(..);
213        let (tx, rx) = std::sync::mpsc::channel();
214
215        slice.map_async(wgpu::MapMode::Read, move |result| {
216            let _ = tx.send(result);
217        });
218
219        async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
220    }
221
222    /// Read the query results (must be mapped first).
223    ///
224    /// Returns the raw u64 timestamps/occlusion counts.
225    pub fn read_results(&self) -> Vec<u64> {
226        let slice = self.read_buffer.slice(..);
227        let data = slice.get_mapped_range();
228        let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
229        drop(data);
230        self.read_buffer.unmap();
231        results
232    }
233}
234
235// =============================================================================
236// ProfileRegion
237// =============================================================================
238
/// A handle to a profiling region.
///
/// Created by `GpuProfiler::begin_region` and consumed by `GpuProfiler::end_region`.
#[derive(Debug)]
pub struct ProfileRegion {
    // Human-readable label supplied to `begin_region`; reported alongside the duration.
    label: String,
    // Index of the timestamp query written when the region began.
    start_query: u32,
}
247
248// =============================================================================
249// GpuProfiler
250// =============================================================================
251
252/// High-level GPU profiler for measuring execution times.
253///
254/// This profiler uses timestamp queries to measure GPU execution time
255/// for different regions of your rendering code.
256///
257/// # Requirements
258///
259/// - Device must support `TIMESTAMP_QUERY` feature
260/// - Must call `begin_frame()` at the start of each frame
261/// - Must call `resolve()` before submitting commands
262///
263/// # Example
264///
265/// ```ignore
266/// let mut profiler = GpuProfiler::new(context.clone(), 256);
267///
268/// // Each frame:
269/// profiler.begin_frame();
270///
271/// let region = profiler.begin_region(&mut encoder, "My Pass");
272/// // ... do rendering ...
273/// profiler.end_region(&mut encoder, region);
274///
275/// profiler.resolve(&mut encoder);
276///
277/// // Read results (may be from previous frame)
278/// for (label, duration_ms) in profiler.read_results() {
279///     println!("{}: {:.2}ms", label, duration_ms);
280/// }
281/// ```
282impl RenderCapability for GpuProfiler {
283    fn requirements() -> GpuRequirements {
284        GpuRequirements::new().require_features(GpuFeatures::TIMESTAMP_QUERY)
285    }
286
287    fn name() -> &'static str {
288        "GpuProfiler"
289    }
290}
291
292pub struct GpuProfiler {
293    context: Arc<GraphicsContext>,
294    query_set: QuerySet,
295    result_buffer: QueryResultBuffer,
296    /// Current query index for the frame
297    current_query: u32,
298    /// Maximum queries per frame
299    max_queries: u32,
300    /// Regions from the current frame (label, start_query, end_query)
301    regions: Vec<(String, u32, u32)>,
302    /// Cached results from the previous frame
303    cached_results: Vec<(String, f64)>,
304    /// Timestamp period in nanoseconds per tick
305    timestamp_period: f32,
306}
307
308impl GpuProfiler {
309    /// Create a new GPU profiler.
310    ///
311    /// # Arguments
312    ///
313    /// * `context` - Graphics context (must support TIMESTAMP_QUERY)
314    /// * `max_queries` - Maximum number of timestamp queries per frame
315    ///
316    /// # Panics
317    ///
318    /// Panics if the device doesn't support timestamp queries.
319    pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
320        let timestamp_period = context.queue().get_timestamp_period();
321
322        let query_set = QuerySet::new(
323            context.device(),
324            Some("GPU Profiler Queries"),
325            QueryType::Timestamp,
326            max_queries,
327        );
328
329        let result_buffer =
330            QueryResultBuffer::new(context.device(), Some("GPU Profiler Results"), max_queries);
331
332        Self {
333            context,
334            query_set,
335            result_buffer,
336            current_query: 0,
337            max_queries,
338            regions: Vec::new(),
339            cached_results: Vec::new(),
340            timestamp_period,
341        }
342    }
343
344    /// Begin a new frame.
345    ///
346    /// Call this at the start of each frame before recording any regions.
347    pub fn begin_frame(&mut self) {
348        self.current_query = 0;
349        self.regions.clear();
350    }
351
352    /// Begin a profiling region.
353    ///
354    /// # Arguments
355    ///
356    /// * `encoder` - Command encoder to write the timestamp
357    /// * `label` - Human-readable label for this region
358    ///
359    /// # Returns
360    ///
361    /// A `ProfileRegion` handle that must be passed to `end_region`.
362    pub fn begin_region(
363        &mut self,
364        encoder: &mut wgpu::CommandEncoder,
365        label: &str,
366    ) -> Option<ProfileRegion> {
367        if self.current_query >= self.max_queries {
368            return None;
369        }
370
371        let start_query = self.current_query;
372        encoder.write_timestamp(&self.query_set.query_set, start_query);
373        self.current_query += 1;
374
375        Some(ProfileRegion {
376            label: label.to_string(),
377            start_query,
378        })
379    }
380
381    /// End a profiling region.
382    ///
383    /// # Arguments
384    ///
385    /// * `encoder` - Command encoder to write the timestamp
386    /// * `region` - The region handle from `begin_region`
387    pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
388        if self.current_query >= self.max_queries {
389            return;
390        }
391
392        let end_query = self.current_query;
393        encoder.write_timestamp(&self.query_set.query_set, end_query);
394        self.current_query += 1;
395
396        self.regions
397            .push((region.label, region.start_query, end_query));
398    }
399
400    /// Resolve all queries from this frame.
401    ///
402    /// Call this after all regions have been recorded, before submitting commands.
403    pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
404        if self.current_query == 0 {
405            return;
406        }
407
408        self.result_buffer
409            .resolve(encoder, &self.query_set, 0..self.current_query, 0);
410        self.result_buffer.copy_to_readable(encoder);
411    }
412
413    /// Read profiling results synchronously.
414    ///
415    /// This blocks until the results are available from the GPU.
416    /// For non-blocking reads, consider using double-buffering or
417    /// reading results from the previous frame.
418    ///
419    /// # Returns
420    ///
421    /// A vector of (label, duration_ms) pairs for each completed region.
422    pub fn read_results(&mut self) -> &[(String, f64)] {
423        if self.regions.is_empty() {
424            return &self.cached_results;
425        }
426
427        let device = self.context.device();
428
429        // Map the buffer
430        let slice = self.result_buffer.read_buffer().slice(..);
431        let (tx, rx) = std::sync::mpsc::channel();
432
433        slice.map_async(wgpu::MapMode::Read, move |result| {
434            let _ = tx.send(result);
435        });
436
437        // Wait for the buffer to be mapped (blocking)
438        let _ = device.poll(wgpu::PollType::Wait {
439            submission_index: None,
440            timeout: None,
441        });
442
443        // Wait for the callback
444        if rx.recv().is_ok() {
445            let data = slice.get_mapped_range();
446            let timestamps: &[u64] = bytemuck::cast_slice(&data);
447
448            self.cached_results.clear();
449
450            for (label, start, end) in &self.regions {
451                let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
452                let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
453
454                // Convert ticks to milliseconds
455                let duration_ns =
456                    (end_ts.saturating_sub(start_ts)) as f64 * self.timestamp_period as f64;
457                let duration_ms = duration_ns / 1_000_000.0;
458
459                self.cached_results.push((label.clone(), duration_ms));
460            }
461
462            drop(data);
463            self.result_buffer.read_buffer().unmap();
464        }
465
466        &self.cached_results
467    }
468
469    /// Try to read profiling results without blocking.
470    ///
471    /// Returns None if the results are not yet available.
472    /// This is useful when you want to display results from the previous frame.
473    ///
474    /// # Returns
475    ///
476    /// Some reference to the cached results if new data was read, or the existing cached results.
477    pub fn try_read_results(&mut self) -> &[(String, f64)] {
478        if self.regions.is_empty() {
479            return &self.cached_results;
480        }
481
482        let device = self.context.device();
483
484        // Try to map the buffer
485        let slice = self.result_buffer.read_buffer().slice(..);
486        let (tx, rx) = std::sync::mpsc::channel();
487
488        slice.map_async(wgpu::MapMode::Read, move |result| {
489            let _ = tx.send(result);
490        });
491
492        // Non-blocking poll
493        let _ = device.poll(wgpu::PollType::Poll);
494
495        // Check if mapping succeeded
496        if let Ok(Ok(())) = rx.try_recv() {
497            let data = slice.get_mapped_range();
498            let timestamps: &[u64] = bytemuck::cast_slice(&data);
499
500            self.cached_results.clear();
501
502            for (label, start, end) in &self.regions {
503                let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
504                let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
505
506                // Convert ticks to milliseconds
507                let duration_ns =
508                    (end_ts.saturating_sub(start_ts)) as f64 * self.timestamp_period as f64;
509                let duration_ms = duration_ns / 1_000_000.0;
510
511                self.cached_results.push((label.clone(), duration_ms));
512            }
513
514            drop(data);
515            self.result_buffer.read_buffer().unmap();
516        }
517
518        &self.cached_results
519    }
520
521    /// Get the number of queries used this frame.
522    #[inline]
523    pub fn queries_used(&self) -> u32 {
524        self.current_query
525    }
526
527    /// Get the maximum queries per frame.
528    #[inline]
529    pub fn max_queries(&self) -> u32 {
530        self.max_queries
531    }
532
533    /// Get the timestamp period in nanoseconds per tick.
534    #[inline]
535    pub fn timestamp_period(&self) -> f32 {
536        self.timestamp_period
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_type_conversion() {
        // Conversion to the wgpu enum must not panic for any variant.
        for query_type in [QueryType::Timestamp, QueryType::Occlusion] {
            let _ = query_type.to_wgpu();
        }
    }

    #[test]
    fn test_profile_region_debug() {
        // ProfileRegion must be Debug-formattable for diagnostics.
        let region = ProfileRegion {
            label: String::from("Test"),
            start_query: 0,
        };
        let _ = format!("{:?}", region);
    }
}