Skip to main content

astrelis_render/
query.rs

1//! GPU query and profiling support.
2//!
3//! This module provides wrappers for GPU queries (timestamps, occlusion) and
4//! a high-level profiler for measuring GPU execution times.
5//!
6//! # Features Required
7//!
8//! - `TIMESTAMP_QUERY` - Required for timestamp queries and GPU profiling
9//!
10//! # Example
11//!
12//! ```ignore
13//! use astrelis_render::{GpuProfiler, GraphicsContext};
14//!
15//! // Create profiler (requires TIMESTAMP_QUERY feature)
16//! let mut profiler = GpuProfiler::new(context.clone(), 256);
17//!
18//! // In render loop:
19//! profiler.begin_frame();
20//!
//! // `begin_region` returns `None` once the per-frame query budget is spent.
//! if let Some(region) = profiler.begin_region(&mut encoder, "Shadow Pass") {
//!     // ... render shadow pass ...
//!     profiler.end_region(&mut encoder, region);
//! }
26//!
27//! profiler.resolve(&mut encoder);
28//!
29//! // Later, read results
30//! for (label, duration_ms) in profiler.read_results() {
31//!     println!("{}: {:.2}ms", label, duration_ms);
32//! }
33//! ```
34
35use std::sync::Arc;
36
37use crate::capability::{GpuRequirements, RenderCapability};
38use crate::context::GraphicsContext;
39use crate::features::GpuFeatures;
40
41// =============================================================================
42// Query Types
43// =============================================================================
44
45/// Types of GPU queries.
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
47pub enum QueryType {
48    /// Timestamp query for measuring GPU execution time.
49    /// Requires `TIMESTAMP_QUERY` feature.
50    Timestamp,
51    /// Occlusion query for counting visible fragments.
52    Occlusion,
53}
54
55impl QueryType {
56    /// Convert to wgpu query type.
57    pub fn to_wgpu(self) -> wgpu::QueryType {
58        match self {
59            QueryType::Timestamp => wgpu::QueryType::Timestamp,
60            QueryType::Occlusion => wgpu::QueryType::Occlusion,
61        }
62    }
63}
64
65// =============================================================================
66// QuerySet
67// =============================================================================
68
69/// A wrapper around wgpu::QuerySet with metadata.
70pub struct QuerySet {
71    query_set: wgpu::QuerySet,
72    query_type: QueryType,
73    count: u32,
74}
75
76impl QuerySet {
77    /// Create a new query set.
78    ///
79    /// # Arguments
80    ///
81    /// * `device` - The wgpu device
82    /// * `label` - Optional debug label
83    /// * `query_type` - Type of queries in this set
84    /// * `count` - Number of queries in the set
85    pub fn new(
86        device: &wgpu::Device,
87        label: Option<&str>,
88        query_type: QueryType,
89        count: u32,
90    ) -> Self {
91        let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
92            label,
93            ty: query_type.to_wgpu(),
94            count,
95        });
96
97        Self {
98            query_set,
99            query_type,
100            count,
101        }
102    }
103
104    /// Get the underlying wgpu query set.
105    #[inline]
106    pub fn query_set(&self) -> &wgpu::QuerySet {
107        &self.query_set
108    }
109
110    /// Get the query type.
111    #[inline]
112    pub fn query_type(&self) -> QueryType {
113        self.query_type
114    }
115
116    /// Get the number of queries in the set.
117    #[inline]
118    pub fn count(&self) -> u32 {
119        self.count
120    }
121}
122
123// =============================================================================
124// QueryResultBuffer
125// =============================================================================
126
/// Buffer pair for resolving GPU query results and reading them back on the CPU.
///
/// Query results cannot be mapped directly: they are first resolved into a
/// `QUERY_RESOLVE` buffer on the GPU, then copied into a `MAP_READ` staging
/// buffer that the CPU can map.
pub struct QueryResultBuffer {
    // GPU-side target for `resolve_query_set` (QUERY_RESOLVE | COPY_SRC).
    resolve_buffer: wgpu::Buffer,
    // CPU-mappable staging buffer (COPY_DST | MAP_READ).
    read_buffer: wgpu::Buffer,
    // Number of u64 result slots both buffers are sized for.
    count: u32,
}

impl QueryResultBuffer {
    /// Create a new query result buffer.
    ///
    /// # Arguments
    ///
    /// * `device` - The wgpu device
    /// * `label` - Optional debug label
    /// * `count` - Number of query results to store
    pub fn new(device: &wgpu::Device, label: Option<&str>, count: u32) -> Self {
        // Every query result (timestamp tick or occlusion count) is one u64.
        let size = (count as u64) * std::mem::size_of::<u64>() as u64;

        let resolve_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            // The temporary `Option<String>` lives to the end of this statement,
            // so `as_deref` safely yields the `Option<&str>` the descriptor wants.
            label: label.map(|l| format!("{} Resolve", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let read_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: label.map(|l| format!("{} Read", l)).as_deref(),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        Self {
            resolve_buffer,
            read_buffer,
            count,
        }
    }

    /// Get the resolve buffer (used for query resolution).
    #[inline]
    pub fn resolve_buffer(&self) -> &wgpu::Buffer {
        &self.resolve_buffer
    }

    /// Get the read buffer (used for CPU readback).
    #[inline]
    pub fn read_buffer(&self) -> &wgpu::Buffer {
        &self.read_buffer
    }

    /// Get the number of results this buffer can hold.
    #[inline]
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Resolve queries from a query set into this buffer.
    ///
    /// `destination_offset` is measured in result slots (u64s), not bytes.
    pub fn resolve(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        query_set: &QuerySet,
        query_range: std::ops::Range<u32>,
        destination_offset: u32,
    ) {
        encoder.resolve_query_set(
            query_set.query_set(),
            query_range,
            &self.resolve_buffer,
            // Convert the slot index into a byte offset.
            (destination_offset as u64) * std::mem::size_of::<u64>() as u64,
        );
    }

    /// Copy resolved results to the readable buffer.
    pub fn copy_to_readable(&self, encoder: &mut wgpu::CommandEncoder) {
        let size = (self.count as u64) * std::mem::size_of::<u64>() as u64;
        encoder.copy_buffer_to_buffer(&self.resolve_buffer, 0, &self.read_buffer, 0, size);
    }

    /// Map the read buffer for CPU access.
    ///
    /// Returns a future that completes when the buffer is mapped.
    ///
    /// NOTE(review): the returned future calls `rx.recv()`, which BLOCKS the
    /// polling thread until the map callback fires. The device must be polled
    /// from another thread for this future to make progress — confirm this is
    /// the intended usage before awaiting it on a single-threaded executor.
    pub fn map_async(&self) -> impl std::future::Future<Output = Result<(), wgpu::BufferAsyncError>> {
        let slice = self.read_buffer.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();

        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });

        // Outer Result is the channel recv (sender dropped => error mapped to
        // BufferAsyncError); the inner Result is the map outcome itself.
        async move { rx.recv().map_err(|_| wgpu::BufferAsyncError)? }
    }

    /// Read the query results (must be mapped first).
    ///
    /// Returns the raw u64 timestamps/occlusion counts.
    ///
    /// NOTE(review): this will panic inside wgpu if the buffer is not
    /// currently mapped; callers must await the map first.
    pub fn read_results(&self) -> Vec<u64> {
        let slice = self.read_buffer.slice(..);
        let data = slice.get_mapped_range();
        let results: Vec<u64> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before `unmap`.
        drop(data);
        self.read_buffer.unmap();
        results
    }
}
232
233// =============================================================================
234// ProfileRegion
235// =============================================================================
236
/// A handle to an in-flight profiling region.
///
/// Created by `GpuProfiler::begin_region` and consumed by `GpuProfiler::end_region`.
#[derive(Debug)]
pub struct ProfileRegion {
    // Human-readable label, echoed back in the profiler's results.
    label: String,
    // Index of the start timestamp query inside the profiler's query set.
    start_query: u32,
}
245
246// =============================================================================
247// GpuProfiler
248// =============================================================================
249
250/// High-level GPU profiler for measuring execution times.
251///
252/// This profiler uses timestamp queries to measure GPU execution time
253/// for different regions of your rendering code.
254///
255/// # Requirements
256///
257/// - Device must support `TIMESTAMP_QUERY` feature
258/// - Must call `begin_frame()` at the start of each frame
259/// - Must call `resolve()` before submitting commands
260///
261/// # Example
262///
263/// ```ignore
264/// let mut profiler = GpuProfiler::new(context.clone(), 256);
265///
266/// // Each frame:
267/// profiler.begin_frame();
268///
269/// let region = profiler.begin_region(&mut encoder, "My Pass");
270/// // ... do rendering ...
271/// profiler.end_region(&mut encoder, region);
272///
273/// profiler.resolve(&mut encoder);
274///
275/// // Read results (may be from previous frame)
276/// for (label, duration_ms) in profiler.read_results() {
277///     println!("{}: {:.2}ms", label, duration_ms);
278/// }
279/// ```
280impl RenderCapability for GpuProfiler {
281    fn requirements() -> GpuRequirements {
282        GpuRequirements::new().require_features(GpuFeatures::TIMESTAMP_QUERY)
283    }
284
285    fn name() -> &'static str {
286        "GpuProfiler"
287    }
288}
289
290pub struct GpuProfiler {
291    context: Arc<GraphicsContext>,
292    query_set: QuerySet,
293    result_buffer: QueryResultBuffer,
294    /// Current query index for the frame
295    current_query: u32,
296    /// Maximum queries per frame
297    max_queries: u32,
298    /// Regions from the current frame (label, start_query, end_query)
299    regions: Vec<(String, u32, u32)>,
300    /// Cached results from the previous frame
301    cached_results: Vec<(String, f64)>,
302    /// Timestamp period in nanoseconds per tick
303    timestamp_period: f32,
304}
305
306impl GpuProfiler {
307    /// Create a new GPU profiler.
308    ///
309    /// # Arguments
310    ///
311    /// * `context` - Graphics context (must support TIMESTAMP_QUERY)
312    /// * `max_queries` - Maximum number of timestamp queries per frame
313    ///
314    /// # Panics
315    ///
316    /// Panics if the device doesn't support timestamp queries.
317    pub fn new(context: Arc<GraphicsContext>, max_queries: u32) -> Self {
318        let timestamp_period = context.queue().get_timestamp_period();
319
320        let query_set = QuerySet::new(
321            context.device(),
322            Some("GPU Profiler Queries"),
323            QueryType::Timestamp,
324            max_queries,
325        );
326
327        let result_buffer = QueryResultBuffer::new(
328            context.device(),
329            Some("GPU Profiler Results"),
330            max_queries,
331        );
332
333        Self {
334            context,
335            query_set,
336            result_buffer,
337            current_query: 0,
338            max_queries,
339            regions: Vec::new(),
340            cached_results: Vec::new(),
341            timestamp_period,
342        }
343    }
344
345    /// Begin a new frame.
346    ///
347    /// Call this at the start of each frame before recording any regions.
348    pub fn begin_frame(&mut self) {
349        self.current_query = 0;
350        self.regions.clear();
351    }
352
353    /// Begin a profiling region.
354    ///
355    /// # Arguments
356    ///
357    /// * `encoder` - Command encoder to write the timestamp
358    /// * `label` - Human-readable label for this region
359    ///
360    /// # Returns
361    ///
362    /// A `ProfileRegion` handle that must be passed to `end_region`.
363    pub fn begin_region(
364        &mut self,
365        encoder: &mut wgpu::CommandEncoder,
366        label: &str,
367    ) -> Option<ProfileRegion> {
368        if self.current_query >= self.max_queries {
369            return None;
370        }
371
372        let start_query = self.current_query;
373        encoder.write_timestamp(&self.query_set.query_set, start_query);
374        self.current_query += 1;
375
376        Some(ProfileRegion {
377            label: label.to_string(),
378            start_query,
379        })
380    }
381
382    /// End a profiling region.
383    ///
384    /// # Arguments
385    ///
386    /// * `encoder` - Command encoder to write the timestamp
387    /// * `region` - The region handle from `begin_region`
388    pub fn end_region(&mut self, encoder: &mut wgpu::CommandEncoder, region: ProfileRegion) {
389        if self.current_query >= self.max_queries {
390            return;
391        }
392
393        let end_query = self.current_query;
394        encoder.write_timestamp(&self.query_set.query_set, end_query);
395        self.current_query += 1;
396
397        self.regions
398            .push((region.label, region.start_query, end_query));
399    }
400
401    /// Resolve all queries from this frame.
402    ///
403    /// Call this after all regions have been recorded, before submitting commands.
404    pub fn resolve(&self, encoder: &mut wgpu::CommandEncoder) {
405        if self.current_query == 0 {
406            return;
407        }
408
409        self.result_buffer.resolve(
410            encoder,
411            &self.query_set,
412            0..self.current_query,
413            0,
414        );
415        self.result_buffer.copy_to_readable(encoder);
416    }
417
418    /// Read profiling results synchronously.
419    ///
420    /// This blocks until the results are available from the GPU.
421    /// For non-blocking reads, consider using double-buffering or
422    /// reading results from the previous frame.
423    ///
424    /// # Returns
425    ///
426    /// A vector of (label, duration_ms) pairs for each completed region.
427    pub fn read_results(&mut self) -> &[(String, f64)] {
428        if self.regions.is_empty() {
429            return &self.cached_results;
430        }
431
432        let device = self.context.device();
433
434        // Map the buffer
435        let slice = self.result_buffer.read_buffer().slice(..);
436        let (tx, rx) = std::sync::mpsc::channel();
437
438        slice.map_async(wgpu::MapMode::Read, move |result| {
439            let _ = tx.send(result);
440        });
441
442        // Wait for the buffer to be mapped (blocking)
443        let _ = device.poll(wgpu::PollType::Wait {
444            submission_index: None,
445            timeout: None,
446        });
447
448        // Wait for the callback
449        if rx.recv().is_ok() {
450            let data = slice.get_mapped_range();
451            let timestamps: &[u64] = bytemuck::cast_slice(&data);
452
453            self.cached_results.clear();
454
455            for (label, start, end) in &self.regions {
456                let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
457                let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
458
459                // Convert ticks to milliseconds
460                let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
461                    * self.timestamp_period as f64;
462                let duration_ms = duration_ns / 1_000_000.0;
463
464                self.cached_results.push((label.clone(), duration_ms));
465            }
466
467            drop(data);
468            self.result_buffer.read_buffer().unmap();
469        }
470
471        &self.cached_results
472    }
473
474    /// Try to read profiling results without blocking.
475    ///
476    /// Returns None if the results are not yet available.
477    /// This is useful when you want to display results from the previous frame.
478    ///
479    /// # Returns
480    ///
481    /// Some reference to the cached results if new data was read, or the existing cached results.
482    pub fn try_read_results(&mut self) -> &[(String, f64)] {
483        if self.regions.is_empty() {
484            return &self.cached_results;
485        }
486
487        let device = self.context.device();
488
489        // Try to map the buffer
490        let slice = self.result_buffer.read_buffer().slice(..);
491        let (tx, rx) = std::sync::mpsc::channel();
492
493        slice.map_async(wgpu::MapMode::Read, move |result| {
494            let _ = tx.send(result);
495        });
496
497        // Non-blocking poll
498        let _ = device.poll(wgpu::PollType::Poll);
499
500        // Check if mapping succeeded
501        if let Ok(Ok(())) = rx.try_recv() {
502            let data = slice.get_mapped_range();
503            let timestamps: &[u64] = bytemuck::cast_slice(&data);
504
505            self.cached_results.clear();
506
507            for (label, start, end) in &self.regions {
508                let start_ts = timestamps.get(*start as usize).copied().unwrap_or(0);
509                let end_ts = timestamps.get(*end as usize).copied().unwrap_or(0);
510
511                // Convert ticks to milliseconds
512                let duration_ns = (end_ts.saturating_sub(start_ts)) as f64
513                    * self.timestamp_period as f64;
514                let duration_ms = duration_ns / 1_000_000.0;
515
516                self.cached_results.push((label.clone(), duration_ms));
517            }
518
519            drop(data);
520            self.result_buffer.read_buffer().unmap();
521        }
522
523        &self.cached_results
524    }
525
526    /// Get the number of queries used this frame.
527    #[inline]
528    pub fn queries_used(&self) -> u32 {
529        self.current_query
530    }
531
532    /// Get the maximum queries per frame.
533    #[inline]
534    pub fn max_queries(&self) -> u32 {
535        self.max_queries
536    }
537
538    /// Get the timestamp period in nanoseconds per tick.
539    #[inline]
540    pub fn timestamp_period(&self) -> f32 {
541        self.timestamp_period
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_type_conversion() {
        // Conversion must not panic for either variant.
        let _ = QueryType::Timestamp.to_wgpu();
        let _ = QueryType::Occlusion.to_wgpu();
    }

    #[test]
    fn test_profile_region_debug() {
        // ProfileRegion must be Debug-formattable.
        let region = ProfileRegion {
            label: String::from("Test"),
            start_query: 0,
        };
        let _ = format!("{:?}", region);
    }
}
565}