// astrelis_render/point_renderer.rs
//! Fast instanced point renderer with GPU-based coordinate transformation.
//!
//! Renders thousands of points efficiently using GPU instancing.
//! Points are stored in data coordinates, and the GPU transforms
//! them to screen coordinates using a transformation matrix.
//!
//! This means pan/zoom only updates a small uniform buffer, not all point data.

use crate::capability::{GpuRequirements, RenderCapability};
use crate::transform::{DataTransform, TransformUniform};
use crate::{Color, GraphicsContext, Viewport};
use astrelis_core::profiling::profile_scope;
use bytemuck::{Pod, Zeroable};
use glam::Vec2;
use std::sync::Arc;
use wgpu::util::DeviceExt;
17
/// A point for batch rendering.
///
/// Coordinates can be in any space - use `render_transformed()` to map
/// them to screen coordinates.
#[derive(Debug, Clone, Copy)]
pub struct Point {
    /// Position in data coordinates (mapped to screen by the GPU transform).
    pub position: Vec2,
    /// Quad width in screen pixels; the shader scales the unit quad by this,
    /// so it is effectively the circle's diameter.
    pub size: f32,
    /// Fill color; the fragment shader anti-aliases the edge via alpha.
    pub color: Color,
}
28
29impl Point {
30    pub fn new(position: Vec2, size: f32, color: Color) -> Self {
31        Self {
32            position,
33            size,
34            color,
35        }
36    }
37}
38
/// GPU instance data for a point.
///
/// `#[repr(C)]` layout must stay in sync with the instance
/// `VertexBufferLayout` in `PointRenderer::new`: position at byte offset 0,
/// size at 8, color at 12. `_padding` rounds the stride up to 32 bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
struct PointInstance {
    position: [f32; 2], // shader_location 1
    size: f32,          // shader_location 2
    color: [f32; 4],    // shader_location 3 (rgba)
    _padding: f32,      // keeps stride at 32 bytes; always 0.0
}
48
49impl PointInstance {
50    fn new(point: &Point) -> Self {
51        Self {
52            position: [point.position.x, point.position.y],
53            size: point.size,
54            color: [point.color.r, point.color.g, point.color.b, point.color.a],
55            _padding: 0.0,
56        }
57    }
58}
59
/// Capability registration: the point renderer needs nothing beyond a
/// baseline device, so it declares no extra GPU requirements.
impl RenderCapability for PointRenderer {
    fn requirements() -> GpuRequirements {
        // No optional features or raised limits required.
        GpuRequirements::none()
    }

    fn name() -> &'static str {
        "PointRenderer"
    }
}
69
/// Fast batched point renderer using GPU instancing.
///
/// Optimized for scatter charts with large datasets. Key features:
/// - Points stored in data coordinates
/// - GPU transforms data → screen (pan/zoom is cheap)
/// - Only rebuild instance buffer when data actually changes
pub struct PointRenderer {
    /// Shared handle to the GPU device/queue.
    context: Arc<GraphicsContext>,
    /// Render pipeline compiled against the target format given at construction.
    pipeline: wgpu::RenderPipeline,
    /// Static unit-quad corners (4 vertices, drawn as a triangle strip).
    vertex_buffer: wgpu::Buffer,
    /// Uniform buffer holding the current data→screen `TransformUniform`.
    transform_buffer: wgpu::Buffer,
    /// Bind group exposing `transform_buffer` at @group(0) @binding(0).
    transform_bind_group: wgpu::BindGroup,
    /// Per-point instance data; `None` until `prepare()` uploads points.
    instance_buffer: Option<wgpu::Buffer>,
    /// Number of instances currently uploaded to `instance_buffer`.
    instance_count: u32,
    /// Pending points (CPU side, awaiting upload in `prepare()`)
    pending_points: Vec<Point>,
    /// Whether points need to be re-uploaded
    data_dirty: bool,
}
89
impl PointRenderer {
    /// Create a new point renderer with the given target texture format.
    ///
    /// The `target_format` must match the render target this renderer will draw into.
    /// For window surfaces, use the format from `WindowContext::format()`.
    pub fn new(context: Arc<GraphicsContext>, target_format: wgpu::TextureFormat) -> Self {
        // Create transform uniform buffer. COPY_DST lets render calls refresh
        // it with `Queue::write_buffer` (see `render_with_uniform`).
        let transform_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
            label: Some("Point Renderer Transform Buffer"),
            size: std::mem::size_of::<TransformUniform>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Bind group layout: one uniform buffer, vertex stage only
        // (matches `@group(0) @binding(0) var<uniform> transform` in POINT_SHADER).
        let bind_group_layout =
            context
                .device()
                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some("Point Renderer Bind Group Layout"),
                    entries: &[wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::VERTEX,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    }],
                });

        let transform_bind_group = context
            .device()
            .create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("Point Renderer Transform Bind Group"),
                layout: &bind_group_layout,
                entries: &[wgpu::BindGroupEntry {
                    binding: 0,
                    resource: transform_buffer.as_entire_binding(),
                }],
            });

        // Shader (WGSL source embedded below as POINT_SHADER)
        let shader = context
            .device()
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("Point Renderer Shader"),
                source: wgpu::ShaderSource::Wgsl(POINT_SHADER.into()),
            });

        // Pipeline
        let pipeline_layout =
            context
                .device()
                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("Point Renderer Pipeline Layout"),
                    bind_group_layouts: &[&bind_group_layout],
                    push_constant_ranges: &[],
                });

        let pipeline = context
            .device()
            .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
                label: Some("Point Renderer Pipeline"),
                layout: Some(&pipeline_layout),
                vertex: wgpu::VertexState {
                    module: &shader,
                    entry_point: Some("vs_main"),
                    buffers: &[
                        // Slot 0: unit quad vertices (vec2 per vertex)
                        wgpu::VertexBufferLayout {
                            array_stride: 8,
                            step_mode: wgpu::VertexStepMode::Vertex,
                            attributes: &[wgpu::VertexAttribute {
                                format: wgpu::VertexFormat::Float32x2,
                                offset: 0,
                                shader_location: 0,
                            }],
                        },
                        // Slot 1: point instances. Offsets must mirror the
                        // `#[repr(C)]` field layout of `PointInstance`.
                        wgpu::VertexBufferLayout {
                            array_stride: std::mem::size_of::<PointInstance>() as u64,
                            step_mode: wgpu::VertexStepMode::Instance,
                            attributes: &[
                                wgpu::VertexAttribute {
                                    format: wgpu::VertexFormat::Float32x2,
                                    offset: 0, // position
                                    shader_location: 1,
                                },
                                wgpu::VertexAttribute {
                                    format: wgpu::VertexFormat::Float32,
                                    offset: 8, // size
                                    shader_location: 2,
                                },
                                wgpu::VertexAttribute {
                                    format: wgpu::VertexFormat::Float32x4,
                                    offset: 12, // color (rgba)
                                    shader_location: 3,
                                },
                            ],
                        },
                    ],
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                },
                fragment: Some(wgpu::FragmentState {
                    module: &shader,
                    entry_point: Some("fs_main"),
                    targets: &[Some(wgpu::ColorTargetState {
                        format: target_format,
                        // Alpha blending so the smoothstep'd circle edge composites.
                        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
                        write_mask: wgpu::ColorWrites::ALL,
                    })],
                    compilation_options: wgpu::PipelineCompilationOptions::default(),
                }),
                primitive: wgpu::PrimitiveState {
                    // 4-vertex strip forms the quad billboard; no culling since
                    // the quad always faces the camera.
                    topology: wgpu::PrimitiveTopology::TriangleStrip,
                    cull_mode: None,
                    ..Default::default()
                },
                depth_stencil: None,
                multisample: wgpu::MultisampleState::default(),
                multiview: None,
                cache: None,
            });

        // Unit quad (for rendering circles as billboards), in strip order:
        // bottom-left, bottom-right, top-left, top-right.
        let quad_vertices: [[f32; 2]; 4] = [[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]];

        let vertex_buffer =
            context
                .device()
                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                    label: Some("Point Renderer Vertex Buffer"),
                    contents: bytemuck::cast_slice(&quad_vertices),
                    usage: wgpu::BufferUsages::VERTEX,
                });

        Self {
            context,
            pipeline,
            vertex_buffer,
            transform_buffer,
            transform_bind_group,
            instance_buffer: None,
            instance_count: 0,
            pending_points: Vec::with_capacity(1024),
            data_dirty: false,
        }
    }

    /// Clear all points. Call this when data changes.
    pub fn clear(&mut self) {
        self.pending_points.clear();
        self.data_dirty = true;
    }

    /// Add a point from its components.
    #[inline]
    pub fn add_point(&mut self, position: Vec2, size: f32, color: Color) {
        self.pending_points.push(Point::new(position, size, color));
        self.data_dirty = true;
    }

    /// Add an already-constructed [`Point`].
    #[inline]
    pub fn add(&mut self, point: Point) {
        self.pending_points.push(point);
        self.data_dirty = true;
    }

    /// Get the number of points.
    pub fn point_count(&self) -> usize {
        self.pending_points.len()
    }

    /// Prepare GPU buffers. Only uploads data if it changed.
    ///
    /// Call after mutating points and before `render*`; otherwise the draw
    /// uses the previously uploaded instances (or draws nothing).
    pub fn prepare(&mut self) {
        profile_scope!("point_renderer_prepare");

        if !self.data_dirty {
            return; // No data change, skip upload
        }

        if self.pending_points.is_empty() {
            // Nothing to draw: drop the old buffer and zero the count.
            self.instance_buffer = None;
            self.instance_count = 0;
            self.data_dirty = false;
            return;
        }

        tracing::trace!("Uploading {} points to GPU", self.pending_points.len());

        // Convert to GPU format
        let instances: Vec<PointInstance> = {
            profile_scope!("convert_instances");
            self.pending_points.iter().map(PointInstance::new).collect()
        };

        // Create buffer.
        // NOTE(review): a fresh buffer is allocated on every data change;
        // reusing a COPY_DST buffer when capacity suffices would avoid the
        // reallocation — confirm whether churn here shows up in profiles.
        {
            profile_scope!("create_instance_buffer");
            self.instance_buffer = Some(self.context.device().create_buffer_init(
                &wgpu::util::BufferInitDescriptor {
                    label: Some("Point Renderer Instance Buffer"),
                    contents: bytemuck::cast_slice(&instances),
                    usage: wgpu::BufferUsages::VERTEX,
                },
            ));
        }

        self.instance_count = self.pending_points.len() as u32;
        self.data_dirty = false;
    }

    /// Render points with identity transform (data coords = screen coords).
    pub fn render(&self, pass: &mut wgpu::RenderPass, viewport: Viewport) {
        let transform = DataTransform::identity(viewport);
        self.render_transformed(pass, &transform);
    }

    /// Render points with a [`DataTransform`].
    ///
    /// This is the preferred method for rendering with data-to-screen mapping.
    /// The transform is cheap to update (32 bytes), so pan/zoom only updates
    /// the transform, not the point data.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let transform = DataTransform::from_data_range(viewport, DataRangeParams {
    ///     plot_x: 80.0, plot_y: 20.0,
    ///     plot_width: 600.0, plot_height: 400.0,
    ///     data_x_min: 0.0, data_x_max: 100.0,
    ///     data_y_min: 0.0, data_y_max: 50.0,
    /// });
    /// point_renderer.render_transformed(pass, &transform);
    /// ```
    pub fn render_transformed(&self, pass: &mut wgpu::RenderPass, transform: &DataTransform) {
        self.render_with_uniform(pass, transform.uniform());
    }

    /// Render with a specific transform uniform.
    ///
    /// NOTE(review): the transform is uploaded with `Queue::write_buffer`,
    /// which is flushed at submit — multiple draws recorded with different
    /// transforms before one submit would all read the last-written value.
    /// Confirm callers draw at most once per submit with this renderer.
    fn render_with_uniform(&self, pass: &mut wgpu::RenderPass, transform: &TransformUniform) {
        profile_scope!("point_renderer_render");

        if self.instance_count == 0 {
            return;
        }

        let Some(instance_buffer) = &self.instance_buffer else {
            return;
        };

        // Upload transform (32 bytes; cheap pan/zoom path)
        self.context.queue().write_buffer(
            &self.transform_buffer,
            0,
            bytemuck::cast_slice(&[*transform]),
        );

        // Draw: 4 strip vertices per instance, one instance per point.
        pass.push_debug_group("PointRenderer::render");
        pass.set_pipeline(&self.pipeline);
        pass.set_bind_group(0, &self.transform_bind_group, &[]);
        pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
        pass.set_vertex_buffer(1, instance_buffer.slice(..));
        pass.draw(0..4, 0..self.instance_count);
        pass.pop_debug_group();
    }
}
361
/// WGSL shader for points with circle rendering and data coordinate transformation.
///
/// Vertex stage: `screen = data_pos * scale + offset`, then the unit quad is
/// expanded by `point_size` (screen pixels) and projected via `projection`.
/// Fragment stage: distance from the quad center with a smoothstep edge turns
/// the quad into an anti-aliased circle; near-transparent fragments discard.
const POINT_SHADER: &str = r#"
struct Transform {
    projection: mat4x4<f32>,
    scale: vec2<f32>,
    offset: vec2<f32>,
}

@group(0) @binding(0)
var<uniform> transform: Transform;

struct VertexInput {
    @location(0) quad_pos: vec2<f32>,
    @location(1) point_position: vec2<f32>,
    @location(2) point_size: f32,
    @location(3) color: vec4<f32>,
}

struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) color: vec4<f32>,
    @location(1) uv: vec2<f32>,
}

@vertex
fn vs_main(input: VertexInput) -> VertexOutput {
    var output: VertexOutput;

    // Transform data coordinates to screen coordinates
    let screen_pos = input.point_position * transform.scale + transform.offset;

    // Offset quad by point size (in screen pixels)
    let world_pos = screen_pos + input.quad_pos * input.point_size;

    output.position = transform.projection * vec4<f32>(world_pos, 0.0, 1.0);
    output.color = input.color;
    output.uv = input.quad_pos + 0.5; // UV from 0 to 1

    return output;
}

@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4<f32> {
    // Render as circle: distance from center
    let center = vec2<f32>(0.5, 0.5);
    let dist = distance(input.uv, center);

    // Smooth edge for anti-aliasing
    let alpha = 1.0 - smoothstep(0.4, 0.5, dist);

    if alpha < 0.01 {
        discard;
    }

    return vec4<f32>(input.color.rgb, input.color.a * alpha);
}
"#;