// astrelis_render/point_renderer.rs

//! Fast instanced point renderer with GPU-based coordinate transformation.
//!
//! Renders thousands of points efficiently using GPU instancing.
//! Points are stored in data coordinates, and the GPU transforms
//! them to screen coordinates using a transformation matrix.
//!
//! This means pan/zoom only updates a small uniform buffer, not all point data.

use std::sync::Arc;

use bytemuck::{Pod, Zeroable};
use glam::Vec2;
use wgpu::util::DeviceExt;

use astrelis_core::profiling::profile_scope;

use crate::capability::{GpuRequirements, RenderCapability};
use crate::transform::{DataTransform, TransformUniform};
use crate::{Color, GraphicsContext, Viewport};

18/// A point for batch rendering.
19///
20/// Coordinates can be in any space - use `render_with_data_transform()` to map
21/// them to screen coordinates.
22#[derive(Debug, Clone, Copy)]
23pub struct Point {
24    pub position: Vec2,
25    pub size: f32,
26    pub color: Color,
27}
28
29impl Point {
30    pub fn new(position: Vec2, size: f32, color: Color) -> Self {
31        Self {
32            position,
33            size,
34            color,
35        }
36    }
37}
38
39/// GPU instance data for a point.
40#[repr(C)]
41#[derive(Debug, Clone, Copy, Pod, Zeroable)]
42struct PointInstance {
43    position: [f32; 2],
44    size: f32,
45    color: [f32; 4],
46    _padding: f32,
47}
48
49impl PointInstance {
50    fn new(point: &Point) -> Self {
51        Self {
52            position: [point.position.x, point.position.y],
53            size: point.size,
54            color: [point.color.r, point.color.g, point.color.b, point.color.a],
55            _padding: 0.0,
56        }
57    }
58}
59
60impl RenderCapability for PointRenderer {
61    fn requirements() -> GpuRequirements {
62        GpuRequirements::none()
63    }
64
65    fn name() -> &'static str {
66        "PointRenderer"
67    }
68}
69
70/// Fast batched point renderer using GPU instancing.
71///
72/// Optimized for scatter charts with large datasets. Key features:
73/// - Points stored in data coordinates
74/// - GPU transforms data → screen (pan/zoom is cheap)
75/// - Only rebuild instance buffer when data actually changes
76pub struct PointRenderer {
77    context: Arc<GraphicsContext>,
78    pipeline: wgpu::RenderPipeline,
79    vertex_buffer: wgpu::Buffer,
80    transform_buffer: wgpu::Buffer,
81    transform_bind_group: wgpu::BindGroup,
82    instance_buffer: Option<wgpu::Buffer>,
83    instance_count: u32,
84    /// Pending points
85    pending_points: Vec<Point>,
86    /// Whether points need to be re-uploaded
87    data_dirty: bool,
88}
89
90impl PointRenderer {
91    /// Create a new point renderer with the given target texture format.
92    ///
93    /// The `target_format` must match the render target this renderer will draw into.
94    /// For window surfaces, use the format from `WindowContext::format()`.
95    pub fn new(context: Arc<GraphicsContext>, target_format: wgpu::TextureFormat) -> Self {
96        // Create transform uniform buffer
97        let transform_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
98            label: Some("Point Renderer Transform Buffer"),
99            size: std::mem::size_of::<TransformUniform>() as u64,
100            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
101            mapped_at_creation: false,
102        });
103
104        // Bind group layout
105        let bind_group_layout =
106            context
107                .device()
108                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
109                    label: Some("Point Renderer Bind Group Layout"),
110                    entries: &[wgpu::BindGroupLayoutEntry {
111                        binding: 0,
112                        visibility: wgpu::ShaderStages::VERTEX,
113                        ty: wgpu::BindingType::Buffer {
114                            ty: wgpu::BufferBindingType::Uniform,
115                            has_dynamic_offset: false,
116                            min_binding_size: None,
117                        },
118                        count: None,
119                    }],
120                });
121
122        let transform_bind_group = context
123            .device()
124            .create_bind_group(&wgpu::BindGroupDescriptor {
125                label: Some("Point Renderer Transform Bind Group"),
126                layout: &bind_group_layout,
127                entries: &[wgpu::BindGroupEntry {
128                    binding: 0,
129                    resource: transform_buffer.as_entire_binding(),
130                }],
131            });
132
133        // Shader
134        let shader = context
135            .device()
136            .create_shader_module(wgpu::ShaderModuleDescriptor {
137                label: Some("Point Renderer Shader"),
138                source: wgpu::ShaderSource::Wgsl(POINT_SHADER.into()),
139            });
140
141        // Pipeline
142        let pipeline_layout =
143            context
144                .device()
145                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
146                    label: Some("Point Renderer Pipeline Layout"),
147                    bind_group_layouts: &[&bind_group_layout],
148                    push_constant_ranges: &[],
149                });
150
151        let pipeline = context
152            .device()
153            .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
154                label: Some("Point Renderer Pipeline"),
155                layout: Some(&pipeline_layout),
156                vertex: wgpu::VertexState {
157                    module: &shader,
158                    entry_point: Some("vs_main"),
159                    buffers: &[
160                        // Unit quad vertices
161                        wgpu::VertexBufferLayout {
162                            array_stride: 8,
163                            step_mode: wgpu::VertexStepMode::Vertex,
164                            attributes: &[wgpu::VertexAttribute {
165                                format: wgpu::VertexFormat::Float32x2,
166                                offset: 0,
167                                shader_location: 0,
168                            }],
169                        },
170                        // Point instances
171                        wgpu::VertexBufferLayout {
172                            array_stride: std::mem::size_of::<PointInstance>() as u64,
173                            step_mode: wgpu::VertexStepMode::Instance,
174                            attributes: &[
175                                wgpu::VertexAttribute {
176                                    format: wgpu::VertexFormat::Float32x2,
177                                    offset: 0,
178                                    shader_location: 1,
179                                },
180                                wgpu::VertexAttribute {
181                                    format: wgpu::VertexFormat::Float32,
182                                    offset: 8,
183                                    shader_location: 2,
184                                },
185                                wgpu::VertexAttribute {
186                                    format: wgpu::VertexFormat::Float32x4,
187                                    offset: 12,
188                                    shader_location: 3,
189                                },
190                            ],
191                        },
192                    ],
193                    compilation_options: wgpu::PipelineCompilationOptions::default(),
194                },
195                fragment: Some(wgpu::FragmentState {
196                    module: &shader,
197                    entry_point: Some("fs_main"),
198                    targets: &[Some(wgpu::ColorTargetState {
199                        format: target_format,
200                        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
201                        write_mask: wgpu::ColorWrites::ALL,
202                    })],
203                    compilation_options: wgpu::PipelineCompilationOptions::default(),
204                }),
205                primitive: wgpu::PrimitiveState {
206                    topology: wgpu::PrimitiveTopology::TriangleStrip,
207                    cull_mode: None,
208                    ..Default::default()
209                },
210                depth_stencil: None,
211                multisample: wgpu::MultisampleState::default(),
212                multiview: None,
213                cache: None,
214            });
215
216        // Unit quad (for rendering circles as billboards)
217        let quad_vertices: [[f32; 2]; 4] = [[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]];
218
219        let vertex_buffer =
220            context
221                .device()
222                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
223                    label: Some("Point Renderer Vertex Buffer"),
224                    contents: bytemuck::cast_slice(&quad_vertices),
225                    usage: wgpu::BufferUsages::VERTEX,
226                });
227
228        Self {
229            context,
230            pipeline,
231            vertex_buffer,
232            transform_buffer,
233            transform_bind_group,
234            instance_buffer: None,
235            instance_count: 0,
236            pending_points: Vec::with_capacity(1024),
237            data_dirty: false,
238        }
239    }
240
241    /// Clear all points. Call this when data changes.
242    pub fn clear(&mut self) {
243        self.pending_points.clear();
244        self.data_dirty = true;
245    }
246
247    /// Add a point.
248    #[inline]
249    pub fn add_point(&mut self, position: Vec2, size: f32, color: Color) {
250        self.pending_points.push(Point::new(position, size, color));
251        self.data_dirty = true;
252    }
253
254    /// Add a point.
255    #[inline]
256    pub fn add(&mut self, point: Point) {
257        self.pending_points.push(point);
258        self.data_dirty = true;
259    }
260
261    /// Get the number of points.
262    pub fn point_count(&self) -> usize {
263        self.pending_points.len()
264    }
265
266    /// Prepare GPU buffers. Only uploads data if it changed.
267    pub fn prepare(&mut self) {
268        profile_scope!("point_renderer_prepare");
269
270        if !self.data_dirty {
271            return; // No data change, skip upload
272        }
273
274        if self.pending_points.is_empty() {
275            self.instance_buffer = None;
276            self.instance_count = 0;
277            self.data_dirty = false;
278            return;
279        }
280
281        tracing::trace!("Uploading {} points to GPU", self.pending_points.len());
282
283        // Convert to GPU format
284        let instances: Vec<PointInstance> = {
285            profile_scope!("convert_instances");
286            self.pending_points.iter().map(PointInstance::new).collect()
287        };
288
289        // Create buffer
290        {
291            profile_scope!("create_instance_buffer");
292            self.instance_buffer = Some(self.context.device().create_buffer_init(
293                &wgpu::util::BufferInitDescriptor {
294                    label: Some("Point Renderer Instance Buffer"),
295                    contents: bytemuck::cast_slice(&instances),
296                    usage: wgpu::BufferUsages::VERTEX,
297                },
298            ));
299        }
300
301        self.instance_count = self.pending_points.len() as u32;
302        self.data_dirty = false;
303    }
304
305    /// Render points with identity transform (data coords = screen coords).
306    pub fn render(&self, pass: &mut wgpu::RenderPass, viewport: Viewport) {
307        let transform = DataTransform::identity(viewport);
308        self.render_transformed(pass, &transform);
309    }
310
311    /// Render points with a [`DataTransform`].
312    ///
313    /// This is the preferred method for rendering with data-to-screen mapping.
314    /// The transform is cheap to update (32 bytes), so pan/zoom only updates
315    /// the transform, not the point data.
316    ///
317    /// # Example
318    ///
319    /// ```ignore
320    /// let transform = DataTransform::from_data_range(viewport, DataRangeParams {
321    ///     plot_x: 80.0, plot_y: 20.0,
322    ///     plot_width: 600.0, plot_height: 400.0,
323    ///     data_x_min: 0.0, data_x_max: 100.0,
324    ///     data_y_min: 0.0, data_y_max: 50.0,
325    /// });
326    /// point_renderer.render_transformed(pass, &transform);
327    /// ```
328    pub fn render_transformed(&self, pass: &mut wgpu::RenderPass, transform: &DataTransform) {
329        self.render_with_uniform(pass, transform.uniform());
330    }
331
332    /// Render points with a data-to-screen transformation.
333    ///
334    /// **Deprecated:** Use [`render_transformed`](Self::render_transformed) with a
335    /// [`DataTransform`] instead for a cleaner API.
336    ///
337    /// This is the fast path for charts: data doesn't change on pan/zoom,
338    /// only the transform does.
339    pub fn render_with_data_transform(
340        &self,
341        pass: &mut wgpu::RenderPass,
342        viewport: Viewport,
343        plot_x: f32,
344        plot_y: f32,
345        plot_width: f32,
346        plot_height: f32,
347        data_x_min: f64,
348        data_x_max: f64,
349        data_y_min: f64,
350        data_y_max: f64,
351    ) {
352        let transform = DataTransform::from_data_range(
353            viewport,
354            crate::transform::DataRangeParams::new(
355                plot_x,
356                plot_y,
357                plot_width,
358                plot_height,
359                data_x_min,
360                data_x_max,
361                data_y_min,
362                data_y_max,
363            ),
364        );
365        self.render_transformed(pass, &transform);
366    }
367
368    /// Render with a specific transform uniform.
369    fn render_with_uniform(&self, pass: &mut wgpu::RenderPass, transform: &TransformUniform) {
370        profile_scope!("point_renderer_render");
371
372        if self.instance_count == 0 {
373            return;
374        }
375
376        let Some(instance_buffer) = &self.instance_buffer else {
377            return;
378        };
379
380        // Upload transform
381        self.context.queue().write_buffer(
382            &self.transform_buffer,
383            0,
384            bytemuck::cast_slice(&[*transform]),
385        );
386
387        // Draw
388        pass.push_debug_group("PointRenderer::render");
389        pass.set_pipeline(&self.pipeline);
390        pass.set_bind_group(0, &self.transform_bind_group, &[]);
391        pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
392        pass.set_vertex_buffer(1, instance_buffer.slice(..));
393        pass.draw(0..4, 0..self.instance_count);
394        pass.pop_debug_group();
395    }
396}
397
/// WGSL shader for points with circle rendering and data coordinate transformation.
///
/// The vertex stage maps data coordinates to screen space (`scale`/`offset`),
/// billboards a unit quad by point size, and projects to clip space; the
/// fragment stage carves an anti-aliased circle out of the quad.
const POINT_SHADER: &str = r#"
struct Transform {
    projection: mat4x4<f32>,
    scale: vec2<f32>,
    offset: vec2<f32>,
}

@group(0) @binding(0)
var<uniform> transform: Transform;

struct VertexInput {
    @location(0) quad_pos: vec2<f32>,
    @location(1) point_position: vec2<f32>,
    @location(2) point_size: f32,
    @location(3) color: vec4<f32>,
}

struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) color: vec4<f32>,
    @location(1) uv: vec2<f32>,
}

@vertex
fn vs_main(input: VertexInput) -> VertexOutput {
    var output: VertexOutput;

    // Transform data coordinates to screen coordinates
    let screen_pos = input.point_position * transform.scale + transform.offset;

    // Offset quad by point size (in screen pixels)
    let world_pos = screen_pos + input.quad_pos * input.point_size;

    output.position = transform.projection * vec4<f32>(world_pos, 0.0, 1.0);
    output.color = input.color;
    output.uv = input.quad_pos + 0.5; // UV from 0 to 1

    return output;
}

@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4<f32> {
    // Render as circle: distance from center
    let center = vec2<f32>(0.5, 0.5);
    let dist = distance(input.uv, center);

    // Smooth edge for anti-aliasing
    let alpha = 1.0 - smoothstep(0.4, 0.5, dist);

    if alpha < 0.01 {
        discard;
    }

    return vec4<f32>(input.color.rgb, input.color.a * alpha);
}
"#;