Skip to main content

astrelis_render/
point_renderer.rs

//! Fast instanced point renderer with GPU-based coordinate transformation.
//!
//! Renders thousands of points efficiently using GPU instancing.
//! Points are stored in data coordinates, and the GPU transforms
//! them to screen coordinates using a transformation matrix.
//!
//! This means pan/zoom only updates a small uniform buffer, not all point data.
8
9use astrelis_core::profiling::profile_scope;
10use crate::capability::{GpuRequirements, RenderCapability};
11use crate::transform::{DataTransform, TransformUniform};
12use crate::{Color, GraphicsContext, Viewport};
13use bytemuck::{Pod, Zeroable};
14use glam::Vec2;
15use std::sync::Arc;
16use wgpu::util::DeviceExt;
17
18/// A point for batch rendering.
19///
20/// Coordinates can be in any space - use `render_with_data_transform()` to map
21/// them to screen coordinates.
22#[derive(Debug, Clone, Copy)]
23pub struct Point {
24    pub position: Vec2,
25    pub size: f32,
26    pub color: Color,
27}
28
29impl Point {
30    pub fn new(position: Vec2, size: f32, color: Color) -> Self {
31        Self { position, size, color }
32    }
33}
34
35/// GPU instance data for a point.
36#[repr(C)]
37#[derive(Debug, Clone, Copy, Pod, Zeroable)]
38struct PointInstance {
39    position: [f32; 2],
40    size: f32,
41    color: [f32; 4],
42    _padding: f32,
43}
44
45impl PointInstance {
46    fn new(point: &Point) -> Self {
47        Self {
48            position: [point.position.x, point.position.y],
49            size: point.size,
50            color: [point.color.r, point.color.g, point.color.b, point.color.a],
51            _padding: 0.0,
52        }
53    }
54}
55
56impl RenderCapability for PointRenderer {
57    fn requirements() -> GpuRequirements {
58        GpuRequirements::none()
59    }
60
61    fn name() -> &'static str {
62        "PointRenderer"
63    }
64}
65
66/// Fast batched point renderer using GPU instancing.
67///
68/// Optimized for scatter charts with large datasets. Key features:
69/// - Points stored in data coordinates
70/// - GPU transforms data → screen (pan/zoom is cheap)
71/// - Only rebuild instance buffer when data actually changes
72pub struct PointRenderer {
73    context: Arc<GraphicsContext>,
74    pipeline: wgpu::RenderPipeline,
75    vertex_buffer: wgpu::Buffer,
76    transform_buffer: wgpu::Buffer,
77    transform_bind_group: wgpu::BindGroup,
78    instance_buffer: Option<wgpu::Buffer>,
79    instance_count: u32,
80    /// Pending points
81    pending_points: Vec<Point>,
82    /// Whether points need to be re-uploaded
83    data_dirty: bool,
84}
85
86impl PointRenderer {
87    /// Create a new point renderer with the given target texture format.
88    ///
89    /// The `target_format` must match the render target this renderer will draw into.
90    /// For window surfaces, use the format from `WindowContext::format()`.
91    pub fn new(context: Arc<GraphicsContext>, target_format: wgpu::TextureFormat) -> Self {
92        // Create transform uniform buffer
93        let transform_buffer = context.device().create_buffer(&wgpu::BufferDescriptor {
94            label: Some("Point Renderer Transform Buffer"),
95            size: std::mem::size_of::<TransformUniform>() as u64,
96            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
97            mapped_at_creation: false,
98        });
99
100        // Bind group layout
101        let bind_group_layout =
102            context
103                .device()
104                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
105                    label: Some("Point Renderer Bind Group Layout"),
106                    entries: &[wgpu::BindGroupLayoutEntry {
107                        binding: 0,
108                        visibility: wgpu::ShaderStages::VERTEX,
109                        ty: wgpu::BindingType::Buffer {
110                            ty: wgpu::BufferBindingType::Uniform,
111                            has_dynamic_offset: false,
112                            min_binding_size: None,
113                        },
114                        count: None,
115                    }],
116                });
117
118        let transform_bind_group = context.device().create_bind_group(&wgpu::BindGroupDescriptor {
119            label: Some("Point Renderer Transform Bind Group"),
120            layout: &bind_group_layout,
121            entries: &[wgpu::BindGroupEntry {
122                binding: 0,
123                resource: transform_buffer.as_entire_binding(),
124            }],
125        });
126
127        // Shader
128        let shader = context
129            .device()
130            .create_shader_module(wgpu::ShaderModuleDescriptor {
131                label: Some("Point Renderer Shader"),
132                source: wgpu::ShaderSource::Wgsl(POINT_SHADER.into()),
133            });
134
135        // Pipeline
136        let pipeline_layout =
137            context
138                .device()
139                .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
140                    label: Some("Point Renderer Pipeline Layout"),
141                    bind_group_layouts: &[&bind_group_layout],
142                    push_constant_ranges: &[],
143                });
144
145        let pipeline = context
146            .device()
147            .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
148                label: Some("Point Renderer Pipeline"),
149                layout: Some(&pipeline_layout),
150                vertex: wgpu::VertexState {
151                    module: &shader,
152                    entry_point: Some("vs_main"),
153                    buffers: &[
154                        // Unit quad vertices
155                        wgpu::VertexBufferLayout {
156                            array_stride: 8,
157                            step_mode: wgpu::VertexStepMode::Vertex,
158                            attributes: &[wgpu::VertexAttribute {
159                                format: wgpu::VertexFormat::Float32x2,
160                                offset: 0,
161                                shader_location: 0,
162                            }],
163                        },
164                        // Point instances
165                        wgpu::VertexBufferLayout {
166                            array_stride: std::mem::size_of::<PointInstance>() as u64,
167                            step_mode: wgpu::VertexStepMode::Instance,
168                            attributes: &[
169                                wgpu::VertexAttribute {
170                                    format: wgpu::VertexFormat::Float32x2,
171                                    offset: 0,
172                                    shader_location: 1,
173                                },
174                                wgpu::VertexAttribute {
175                                    format: wgpu::VertexFormat::Float32,
176                                    offset: 8,
177                                    shader_location: 2,
178                                },
179                                wgpu::VertexAttribute {
180                                    format: wgpu::VertexFormat::Float32x4,
181                                    offset: 12,
182                                    shader_location: 3,
183                                },
184                            ],
185                        },
186                    ],
187                    compilation_options: wgpu::PipelineCompilationOptions::default(),
188                },
189                fragment: Some(wgpu::FragmentState {
190                    module: &shader,
191                    entry_point: Some("fs_main"),
192                    targets: &[Some(wgpu::ColorTargetState {
193                        format: target_format,
194                        blend: Some(wgpu::BlendState::ALPHA_BLENDING),
195                        write_mask: wgpu::ColorWrites::ALL,
196                    })],
197                    compilation_options: wgpu::PipelineCompilationOptions::default(),
198                }),
199                primitive: wgpu::PrimitiveState {
200                    topology: wgpu::PrimitiveTopology::TriangleStrip,
201                    cull_mode: None,
202                    ..Default::default()
203                },
204                depth_stencil: None,
205                multisample: wgpu::MultisampleState::default(),
206                multiview: None,
207                cache: None,
208            });
209
210        // Unit quad (for rendering circles as billboards)
211        let quad_vertices: [[f32; 2]; 4] = [
212            [-0.5, -0.5],
213            [0.5, -0.5],
214            [-0.5, 0.5],
215            [0.5, 0.5],
216        ];
217
218        let vertex_buffer = context
219            .device()
220            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
221                label: Some("Point Renderer Vertex Buffer"),
222                contents: bytemuck::cast_slice(&quad_vertices),
223                usage: wgpu::BufferUsages::VERTEX,
224            });
225
226        Self {
227            context,
228            pipeline,
229            vertex_buffer,
230            transform_buffer,
231            transform_bind_group,
232            instance_buffer: None,
233            instance_count: 0,
234            pending_points: Vec::with_capacity(1024),
235            data_dirty: false,
236        }
237    }
238
239    /// Clear all points. Call this when data changes.
240    pub fn clear(&mut self) {
241        self.pending_points.clear();
242        self.data_dirty = true;
243    }
244
245    /// Add a point.
246    #[inline]
247    pub fn add_point(&mut self, position: Vec2, size: f32, color: Color) {
248        self.pending_points.push(Point::new(position, size, color));
249        self.data_dirty = true;
250    }
251
252    /// Add a point.
253    #[inline]
254    pub fn add(&mut self, point: Point) {
255        self.pending_points.push(point);
256        self.data_dirty = true;
257    }
258
259    /// Get the number of points.
260    pub fn point_count(&self) -> usize {
261        self.pending_points.len()
262    }
263
264    /// Prepare GPU buffers. Only uploads data if it changed.
265    pub fn prepare(&mut self) {
266        profile_scope!("point_renderer_prepare");
267
268        if !self.data_dirty {
269            return; // No data change, skip upload
270        }
271
272        if self.pending_points.is_empty() {
273            self.instance_buffer = None;
274            self.instance_count = 0;
275            self.data_dirty = false;
276            return;
277        }
278
279        tracing::trace!("Uploading {} points to GPU", self.pending_points.len());
280
281        // Convert to GPU format
282        let instances: Vec<PointInstance> = {
283            profile_scope!("convert_instances");
284            self.pending_points.iter().map(PointInstance::new).collect()
285        };
286
287        // Create buffer
288        {
289            profile_scope!("create_instance_buffer");
290            self.instance_buffer = Some(
291                self.context
292                    .device()
293                    .create_buffer_init(&wgpu::util::BufferInitDescriptor {
294                        label: Some("Point Renderer Instance Buffer"),
295                        contents: bytemuck::cast_slice(&instances),
296                        usage: wgpu::BufferUsages::VERTEX,
297                    }),
298            );
299        }
300
301        self.instance_count = self.pending_points.len() as u32;
302        self.data_dirty = false;
303    }
304
305    /// Render points with identity transform (data coords = screen coords).
306    pub fn render(&self, pass: &mut wgpu::RenderPass, viewport: Viewport) {
307        let transform = DataTransform::identity(viewport);
308        self.render_transformed(pass, &transform);
309    }
310
311    /// Render points with a [`DataTransform`].
312    ///
313    /// This is the preferred method for rendering with data-to-screen mapping.
314    /// The transform is cheap to update (32 bytes), so pan/zoom only updates
315    /// the transform, not the point data.
316    ///
317    /// # Example
318    ///
319    /// ```ignore
320    /// let transform = DataTransform::from_data_range(viewport, DataRangeParams {
321    ///     plot_x: 80.0, plot_y: 20.0,
322    ///     plot_width: 600.0, plot_height: 400.0,
323    ///     data_x_min: 0.0, data_x_max: 100.0,
324    ///     data_y_min: 0.0, data_y_max: 50.0,
325    /// });
326    /// point_renderer.render_transformed(pass, &transform);
327    /// ```
328    pub fn render_transformed(&self, pass: &mut wgpu::RenderPass, transform: &DataTransform) {
329        self.render_with_uniform(pass, transform.uniform());
330    }
331
332    /// Render points with a data-to-screen transformation.
333    ///
334    /// **Deprecated:** Use [`render_transformed`](Self::render_transformed) with a
335    /// [`DataTransform`] instead for a cleaner API.
336    ///
337    /// This is the fast path for charts: data doesn't change on pan/zoom,
338    /// only the transform does.
339    pub fn render_with_data_transform(
340        &self,
341        pass: &mut wgpu::RenderPass,
342        viewport: Viewport,
343        plot_x: f32,
344        plot_y: f32,
345        plot_width: f32,
346        plot_height: f32,
347        data_x_min: f64,
348        data_x_max: f64,
349        data_y_min: f64,
350        data_y_max: f64,
351    ) {
352        let transform = DataTransform::from_data_range(
353            viewport,
354            crate::transform::DataRangeParams::new(
355                plot_x,
356                plot_y,
357                plot_width,
358                plot_height,
359                data_x_min,
360                data_x_max,
361                data_y_min,
362                data_y_max,
363            ),
364        );
365        self.render_transformed(pass, &transform);
366    }
367
368    /// Render with a specific transform uniform.
369    fn render_with_uniform(&self, pass: &mut wgpu::RenderPass, transform: &TransformUniform) {
370        profile_scope!("point_renderer_render");
371
372        if self.instance_count == 0 {
373            return;
374        }
375
376        let Some(instance_buffer) = &self.instance_buffer else {
377            return;
378        };
379
380        // Upload transform
381        self.context.queue().write_buffer(
382            &self.transform_buffer,
383            0,
384            bytemuck::cast_slice(&[*transform]),
385        );
386
387        // Draw
388        pass.push_debug_group("PointRenderer::render");
389        pass.set_pipeline(&self.pipeline);
390        pass.set_bind_group(0, &self.transform_bind_group, &[]);
391        pass.set_vertex_buffer(0, self.vertex_buffer.slice(..));
392        pass.set_vertex_buffer(1, instance_buffer.slice(..));
393        pass.draw(0..4, 0..self.instance_count);
394        pass.pop_debug_group();
395    }
396}
397
/// WGSL shader for points with circle rendering and data coordinate transformation.
///
/// - `vs_main` maps each instance position with `scale`/`offset`, expands the unit
///   quad by the point size, then applies `projection`.
/// - `fs_main` turns the quad into a soft-edged circle and discards fragments
///   whose coverage falls below 0.01.
///
/// The `Transform` uniform is a `mat4x4<f32>` followed by two `vec2<f32>`
/// (80 bytes); the Rust-side uniform bound at group 0, binding 0 must match it.
const POINT_SHADER: &str = r#"
struct Transform {
    projection: mat4x4<f32>,
    scale: vec2<f32>,
    offset: vec2<f32>,
}

@group(0) @binding(0)
var<uniform> transform: Transform;

struct VertexInput {
    @location(0) quad_pos: vec2<f32>,
    @location(1) point_position: vec2<f32>,
    @location(2) point_size: f32,
    @location(3) color: vec4<f32>,
}

struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) color: vec4<f32>,
    @location(1) uv: vec2<f32>,
}

@vertex
fn vs_main(input: VertexInput) -> VertexOutput {
    var output: VertexOutput;

    // Transform data coordinates to screen coordinates
    let screen_pos = input.point_position * transform.scale + transform.offset;

    // Offset quad by point size (in screen pixels)
    let world_pos = screen_pos + input.quad_pos * input.point_size;

    output.position = transform.projection * vec4<f32>(world_pos, 0.0, 1.0);
    output.color = input.color;
    output.uv = input.quad_pos + 0.5; // UV from 0 to 1

    return output;
}

@fragment
fn fs_main(input: VertexOutput) -> @location(0) vec4<f32> {
    // Render as circle: distance from center
    let center = vec2<f32>(0.5, 0.5);
    let dist = distance(input.uv, center);

    // Smooth edge for anti-aliasing
    let alpha = 1.0 - smoothstep(0.4, 0.5, dist);

    if alpha < 0.01 {
        discard;
    }

    return vec4<f32>(input.color.rgb, input.color.a * alpha);
}
"#;