Skip to main content

astrelis_render/
quad_renderer.rs

1//! Fast instanced quad renderer with GPU-based coordinate transformation.
2//!
3//! Renders thousands of quads (rectangles) efficiently using GPU instancing.
4//! Quads are stored in data coordinates, and the GPU transforms
5//! them to screen coordinates using a transformation matrix.
6//!
7//! This is primarily used for bar charts but can be used for any axis-aligned
8//! rectangle rendering where data-to-screen transformation is needed.
9
10use astrelis_core::profiling::profile_scope;
11use crate::capability::{GpuRequirements, RenderCapability};
12use crate::transform::{DataTransform, TransformUniform};
13use crate::{Color, GraphicsContext, Viewport};
14use bytemuck::{Pod, Zeroable};
15use glam::Vec2;
16use std::sync::Arc;
17use wgpu::util::DeviceExt;
18
19/// A quad (axis-aligned rectangle) for batch rendering.
20///
21/// The quad is defined by two data coordinates (min and max) which define
22/// the corners. For bar charts, x_min/x_max define the bar width and
23/// y_min/y_max define the bar height (typically y_min = baseline).
24#[derive(Debug, Clone, Copy)]
25pub struct Quad {
26    /// Minimum corner (typically bottom-left in data coords)
27    pub min: Vec2,
28    /// Maximum corner (typically top-right in data coords)
29    pub max: Vec2,
30    /// Fill color
31    pub color: Color,
32}
33
34impl Quad {
35    pub fn new(min: Vec2, max: Vec2, color: Color) -> Self {
36        Self { min, max, color }
37    }
38
39    /// Create a quad from center, width, and height.
40    pub fn from_center(center: Vec2, width: f32, height: f32, color: Color) -> Self {
41        let half = Vec2::new(width * 0.5, height * 0.5);
42        Self {
43            min: center - half,
44            max: center + half,
45            color,
46        }
47    }
48
49    /// Create a bar from x center, width, y_bottom, and y_top.
50    pub fn bar(x_center: f32, width: f32, y_bottom: f32, y_top: f32, color: Color) -> Self {
51        Self {
52            min: Vec2::new(x_center - width * 0.5, y_bottom),
53            max: Vec2::new(x_center + width * 0.5, y_top),
54            color,
55        }
56    }
57}
58
59/// GPU instance data for a quad.
60#[repr(C)]
61#[derive(Debug, Clone, Copy, Pod, Zeroable)]
62struct QuadInstance {
63    min: [f32; 2],
64    max: [f32; 2],
65    color: [f32; 4],
66}
67
68impl QuadInstance {
69    fn new(quad: &Quad) -> Self {
70        Self {
71            min: [quad.min.x, quad.min.y],
72            max: [quad.max.x, quad.max.y],
73            color: [quad.color.r, quad.color.g, quad.color.b, quad.color.a],
74        }
75    }
76}
77
78impl RenderCapability for QuadRenderer {
79    fn requirements() -> GpuRequirements {
80        GpuRequirements::none()
81    }
82
83    fn name() -> &'static str {
84        "QuadRenderer"
85    }
86}
87
/// Fast batched quad renderer using GPU instancing.
///
/// Optimized for bar charts with large datasets. Key features:
/// - Quads stored in data coordinates
/// - GPU transforms data → screen (pan/zoom is cheap)
/// - Only rebuild instance buffer when data actually changes
///
/// Typical flow: queue quads with `add`/`add_quad`/`add_bar`, call `prepare`
/// to upload them, then call one of the `render*` methods inside a render pass.
pub struct QuadRenderer {
    /// Shared graphics context; provides the device and queue used for all GPU work.
    context: Arc<GraphicsContext>,
    /// Render pipeline built in `new` for the caller-supplied target format.
    pipeline: wgpu::RenderPipeline,
    /// Four-vertex unit quad (0,0 to 1,1) shared by every instance.
    vertex_buffer: wgpu::Buffer,
    /// Uniform buffer sized for one `TransformUniform`; written by the render methods.
    transform_buffer: wgpu::Buffer,
    /// Bind group exposing `transform_buffer` at group 0, binding 0.
    transform_bind_group: wgpu::BindGroup,
    /// Per-instance data; `None` until `prepare` uploads a non-empty set of quads.
    instance_buffer: Option<wgpu::Buffer>,
    /// Number of instances recorded by the last `prepare` that uploaded data.
    instance_count: u32,
    /// Pending quads
    pending_quads: Vec<Quad>,
    /// Whether quads need to be re-uploaded
    data_dirty: bool,
}
107
108impl QuadRenderer {
109    /// Create a new quad renderer with the given target texture format.
110    ///
111    /// The `target_format` must match the render target this renderer will draw into.
112    /// For window surfaces, use the format from `WindowContext::format()`.
113    pub fn new(context: Arc<GraphicsContext>, target_format: wgpu::TextureFormat) -> Self {
114        // Create transform uniform buffer
115        let transform_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
116            label: Some("Quad Renderer Transform Buffer"),
117            size: std::mem::size_of::<TransformUniform>() as u64,
118            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
119            mapped_at_creation: false,
120        });
121
122        // Bind group layout
123        let bind_group_layout =
124            context
125                .device()
126                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
127                    label: Some("Quad Renderer Bind Group Layout"),
128                    entries: &[wgpu::BindGroupLayoutEntry {
129                        binding: 0,
130                        visibility: wgpu::ShaderStages::VERTEX,
131                        ty: wgpu::BindingType::Buffer {
132                            ty: wgpu::BufferBindingType::Uniform,
133                            has_dynamic_offset: false,
134                            min_binding_size: None,
135                        },
136                        count: None,
137                    }],
138                });
139
140        let transform_bind_group = context.device().create_bind_group(&wgpu::BindGroupDescriptor {
141            label: Some("Quad Renderer Transform Bind Group"),
142            layout: &bind_group_layout,
143            entries: &[wgpu::BindGroupEntry {
144                binding: 0,
145                resource: transform_buffer.as_entire_binding(),
146            }],
147        });
148
149        // Shader
150        let shader = context
151            .device()
152            .create_shader_module(wgpu::ShaderModuleDescriptor {
153                label: Some("Quad Renderer Shader"),
154                source: wgpu::ShaderSource::Wgsl(QUAD_SHADER.into()),
155            });
156
157        // Pipeline
158        let pipeline_layout =
159            context
160                .device()
161                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
162                    label: Some("Quad Renderer Pipeline Layout"),
163                    bind_group_layouts: &[&bind_group_layout],
164                    push_constant_ranges: &[],
165                });
166
167        let pipeline = context
168            .device()
169            .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
170                label: Some("Quad Renderer Pipeline"),
171                layout: Some(&pipeline_layout),
172                vertex: wgpu::VertexState {
173                    module: &shader,
174                    entry_point: Some("vs_main"),
175                    buffers: &[
176                        // Unit quad vertices
177                        wgpu::VertexBufferLayout {
178                            array_stride: 8,
179                            step_mode: wgpu::VertexStepMode::Vertex,
180                            attributes: &[wgpu::VertexAttribute {
181                                format: wgpu::VertexFormat::Float32x2,
182                                offset: 0,
183                                shader_location: 0,
184                            }],
185                        },
186                        // Quad instances
187                        wgpu::VertexBufferLayout {
188                            array_stride: std::mem::size_of::<QuadInstance>() as u64,
189                            step_mode: wgpu::VertexStepMode::Instance,
190                            attributes: &[
191                                wgpu::VertexAttribute {
192                                    format: wgpu::VertexFormat::Float32x2,
193                                    offset: 0,
194                                    shader_location: 1,
195                                },
196                                wgpu::VertexAttribute {
197                                    format: wgpu::VertexFormat::Float32x2,
198                                    offset: 8,
199                                    shader_location: 2,
200                                },
201                                wgpu::VertexAttribute {
202                                    format: wgpu::VertexFormat::Float32x4,
203                                    offset: 16,
204                                    shader_location: 3,
205                                },
206                            ],
207                        },
208                    ],
209                    compilation_options: wgpu::PipelineCompilationOptions::default(),
210                },
211                fragment: Some(wgpu::FragmentState {
212                    module: &shader,
213                    entry_point: Some("fs_main"),
214                    targets: &[Some(wgpu::ColorTargetState {
215                        format: target_format,
216                        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
217                        write_mask: wgpu::ColorWrites::ALL,
218                    })],
219                    compilation_options: wgpu::PipelineCompilationOptions::default(),
220                }),
221                primitive: wgpu::PrimitiveState {
222                    topology: wgpu::PrimitiveTopology::TriangleStrip,
223                    cull_mode: None,
224                    ..Default::default()
225                },
226                depth_stencil: None,
227                multisample: wgpu::MultisampleState::default(),
228                multiview: None,
229                cache: None,
230            });
231
232        // Unit quad (0,0 to 1,1)
233        let quad_vertices: [[f32; 2]; 4] = [
234            [0.0, 0.0],
235            [1.0, 0.0],
236            [0.0, 1.0],
237            [1.0, 1.0],
238        ];
239
240        let vertex_buffer = context
241            .device()
242            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
243                label: Some("Quad Renderer Vertex Buffer"),
244                contents: bytemuck::cast_slice(&quad_vertices),
245                usage: wgpu::BufferUsages::VERTEX,
246            });
247
248        Self {
249            context,
250            pipeline,
251            vertex_buffer,
252            transform_buffer,
253            transform_bind_group,
254            instance_buffer: None,
255            instance_count: 0,
256            pending_quads: Vec::with_capacity(1024),
257            data_dirty: false,
258        }
259    }
260
261    /// Clear all quads. Call this when data changes.
262    pub fn clear(&mut self) {
263        self.pending_quads.clear();
264        self.data_dirty = true;
265    }
266
267    /// Add a quad.
268    #[inline]
269    pub fn add_quad(&mut self, min: Vec2, max: Vec2, color: Color) {
270        self.pending_quads.push(Quad::new(min, max, color));
271        self.data_dirty = true;
272    }
273
274    /// Add a bar from center x, width, y range.
275    #[inline]
276    pub fn add_bar(&mut self, x_center: f32, width: f32, y_bottom: f32, y_top: f32, color: Color) {
277        self.pending_quads.push(Quad::bar(x_center, width, y_bottom, y_top, color));
278        self.data_dirty = true;
279    }
280
281    /// Add a quad.
282    #[inline]
283    pub fn add(&mut self, quad: Quad) {
284        self.pending_quads.push(quad);
285        self.data_dirty = true;
286    }
287
288    /// Get the number of quads.
289    pub fn quad_count(&self) -> usize {
290        self.pending_quads.len()
291    }
292
293    /// Prepare GPU buffers. Only uploads data if it changed.
294    pub fn prepare(&mut self) {
295        profile_scope!("quad_renderer_prepare");
296
297        if !self.data_dirty {
298            return; // No data change, skip upload
299        }
300
301        if self.pending_quads.is_empty() {
302            self.instance_buffer = None;
303            self.instance_count = 0;
304            self.data_dirty = false;
305            return;
306        }
307
308        tracing::trace!("Uploading {} quads to GPU", self.pending_quads.len());
309
310        // Convert to GPU format
311        let instances: Vec<QuadInstance> = {
312            profile_scope!("convert_instances");
313            self.pending_quads.iter().map(QuadInstance::new).collect()
314        };
315
316        // Create buffer
317        {
318            profile_scope!("create_instance_buffer");
319            self.instance_buffer = Some(
320                self.context
321                    .device()
322                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
323                        label: Some("Quad Renderer Instance Buffer"),
324                        contents: bytemuck::cast_slice(&instances),
325                        usage: wgpu::BufferUsages::VERTEX,
326                    }),
327            );
328        }
329
330        self.instance_count = self.pending_quads.len() as u32;
331        self.data_dirty = false;
332    }
333
334    /// Render quads with identity transform (data coords = screen coords).
335    pub fn render(&self, pass: &mut wgpu::RenderPass, viewport: Viewport) {
336        let transform = DataTransform::identity(viewport);
337        self.render_transformed(pass, &transform);
338    }
339
340    /// Render quads with a [`DataTransform`].
341    ///
342    /// This is the preferred method for rendering with data-to-screen mapping.
343    /// The transform is cheap to update (32 bytes), so pan/zoom only updates
344    /// the transform, not the quad data.
345    ///
346    /// # Example
347    ///
348    /// ```ignore
349    /// let transform = DataTransform::from_data_range(viewport, DataRangeParams {
350    ///     plot_x: 80.0, plot_y: 20.0,
351    ///     plot_width: 600.0, plot_height: 400.0,
352    ///     data_x_min: 0.0, data_x_max: 100.0,
353    ///     data_y_min: 0.0, data_y_max: 50.0,
354    /// });
355    /// quad_renderer.render_transformed(pass, &transform);
356    /// ```
357    pub fn render_transformed(&self, pass: &mut wgpu::RenderPass, transform: &DataTransform) {
358        self.render_with_uniform(pass, transform.uniform());
359    }
360
361    /// Render quads with a data-to-screen transformation.
362    ///
363    /// **Deprecated:** Use [`render_transformed`](Self::render_transformed) with a
364    /// [`DataTransform`] instead for a cleaner API.
365    ///
366    /// This is the fast path for charts: data doesn't change on pan/zoom,
367    /// only the transform does.
368    pub fn render_with_data_transform(
369        &self,
370        pass: &mut wgpu::RenderPass,
371        viewport: Viewport,
372        plot_x: f32,
373        plot_y: f32,
374        plot_width: f32,
375        plot_height: f32,
376        data_x_min: f64,
377        data_x_max: f64,
378        data_y_min: f64,
379        data_y_max: f64,
380    ) {
381        let transform = DataTransform::from_data_range(
382            viewport,
383            crate::transform::DataRangeParams::new(
384                plot_x,
385                plot_y,
386                plot_width,
387                plot_height,
388                data_x_min,
389                data_x_max,
390                data_y_min,
391                data_y_max,
392            ),
393        );
394        self.render_transformed(pass, &transform);
395    }
396
397    /// Render with a specific transform uniform.
398    fn render_with_uniform(&self, pass: &mut wgpu::RenderPass, transform: &TransformUniform) {
399        profile_scope!("quad_renderer_render");
400
401        if self.instance_count == 0 {
402            return;
403        }
404
405        let Some(instance_buffer) = &self.instance_buffer else {
406            return;
407        };
408
409        // Upload transform
410        self.context.queue().write_buffer(
411            &self.transform_buffer,
412            0,
413            bytemuck::cast_slice(&[*transform]),
414        );
415
416        // Draw
417        pass.push_debug_group("QuadRenderer::render");
418        pass.set_pipeline(&self.pipeline);
419        pass.set_bind_group(0, &self.transform_bind_group, &[]);
420        pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
421        pass.set_vertex_buffer(1, instance_buffer.slice(..));
422        pass.draw(0..4, 0..self.instance_count);
423        pass.pop_debug_group();
424    }
425}
426
/// WGSL shader for quads with data coordinate transformation.
///
/// Interface, as declared in the source below:
/// - `@group(0) @binding(0)`: uniform `Transform` (`projection`, `scale`, `offset`).
/// - Vertex input location 0: unit-quad corner in the 0-1 range.
/// - Vertex input locations 1-3: per-quad `rect_min`, `rect_max`, and `color`.
///
/// `vs_main` picks the corner with `mix(rect_min, rect_max, quad_pos)`, applies
/// `scale` and `offset`, then multiplies by `projection`. `fs_main` returns the
/// color passed through from the vertex stage unchanged.
const QUAD_SHADER: &str = r#"
struct Transform {
    projection: mat4x4<f32>,
    scale: vec2<f32>,
    offset: vec2<f32>,
}

@group(0) @binding(0)
var<uniform> transform: Transform;

struct VertexInput {
    @location(0) quad_pos: vec2<f32>,  // 0-1 range unit quad
    @location(1) rect_min: vec2<f32>,  // data coords
    @location(2) rect_max: vec2<f32>,  // data coords
    @location(3) color: vec4<f32>,
}

struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) color: vec4<f32>,
}

@vertex
fn vs_main(input: VertexInput) -> VertexOutput {
    var output: VertexOutput;

    // Interpolate between min and max based on quad position (0-1)
    let data_pos = mix(input.rect_min, input.rect_max, input.quad_pos);

    // Transform data coordinates to screen coordinates
    let screen_pos = data_pos * transform.scale + transform.offset;

    output.position = transform.projection * vec4<f32>(screen_pos, 0.0, 1.0);
    output.color = input.color;

    return output;
}

@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4<f32> {
    return input.color;
}
"#;