Skip to main content

astrelis_geometry/chart/
gpu.rs

//! GPU state management for chart rendering.
//!
//! This module manages the GPU buffers used to draw charts:
//! - Per-series vertex buffers for line segments, reused across updates and
//!   reallocated with spare capacity only when they become too small
//! - A transform uniform buffer for pan/zoom
//!
//! # Architecture
//!
//! ```text
//! ChartGpuState
//!   ├── ChartTransform (uniform buffer)
//!   │     └── view_proj, data_range, plot_area, viewport_size
//!   └── SeriesGpuBuffers (per series)
//!         ├── line_buffer
//!         ├── marker_buffer
//!         └── dirty_range
//! ```
20
21use super::cache::ChartDirtyFlags;
22use super::rect::Rect;
23use super::streaming::PrepareResult;
24use super::types::{Chart, DataPoint};
25use astrelis_render::{Color, GraphicsContext, wgpu};
26use bytemuck::{Pod, Zeroable};
27use glam::{Mat4, Vec2};
28use std::sync::Arc;
29
30/// GPU vertex for line rendering.
31#[repr(C)]
32#[derive(Debug, Clone, Copy, Pod, Zeroable)]
33pub struct LineVertex {
34    /// Position (x, y)
35    pub position: [f32; 2],
36    /// Color (r, g, b, a)
37    pub color: [f32; 4],
38    /// Line direction for thickness calculation
39    pub direction: [f32; 2],
40    /// Thickness
41    pub thickness: f32,
42    /// Padding for alignment
43    _padding: f32,
44}
45
46impl LineVertex {
47    /// Create a new line vertex.
48    pub fn new(position: Vec2, color: Color, direction: Vec2, thickness: f32) -> Self {
49        Self {
50            position: position.to_array(),
51            color: [color.r, color.g, color.b, color.a],
52            direction: direction.to_array(),
53            thickness,
54            _padding: 0.0,
55        }
56    }
57}
58
59/// GPU vertex for marker rendering.
60#[repr(C)]
61#[derive(Debug, Clone, Copy, Pod, Zeroable)]
62pub struct MarkerVertex {
63    /// Position (x, y)
64    pub position: [f32; 2],
65    /// Color (r, g, b, a)
66    pub color: [f32; 4],
67    /// Size
68    pub size: f32,
69    /// Shape type (0 = circle, 1 = square, etc.)
70    pub shape: u32,
71}
72
73impl MarkerVertex {
74    /// Create a new marker vertex.
75    pub fn new(position: Vec2, color: Color, size: f32, shape: u32) -> Self {
76        Self {
77            position: position.to_array(),
78            color: [color.r, color.g, color.b, color.a],
79            size,
80            shape,
81        }
82    }
83}
84
85/// Chart transform uniform data.
86#[repr(C)]
87#[derive(Debug, Clone, Copy, Pod, Zeroable)]
88pub struct ChartTransform {
89    /// View-projection matrix
90    pub view_proj: [[f32; 4]; 4],
91    /// Data range (x_min, x_max, y_min, y_max)
92    pub data_range: [f32; 4],
93    /// Plot area in screen coords (x, y, width, height)
94    pub plot_area: [f32; 4],
95    /// Viewport size
96    pub viewport_size: [f32; 2],
97    /// Padding
98    _padding: [f32; 2],
99}
100
101impl Default for ChartTransform {
102    fn default() -> Self {
103        Self {
104            view_proj: Mat4::IDENTITY.to_cols_array_2d(),
105            data_range: [0.0, 1.0, 0.0, 1.0],
106            plot_area: [0.0, 0.0, 1.0, 1.0],
107            viewport_size: [800.0, 600.0],
108            _padding: [0.0; 2],
109        }
110    }
111}
112
113impl ChartTransform {
114    /// Create a transform from chart state.
115    pub fn from_chart(chart: &Chart, bounds: &Rect, viewport_size: Vec2) -> Self {
116        let (x_min, x_max) = chart.x_range();
117        let (y_min, y_max) = chart.y_range();
118
119        // Create orthographic projection
120        let view_proj =
121            Mat4::orthographic_rh(0.0, viewport_size.x, viewport_size.y, 0.0, -1.0, 1.0);
122
123        Self {
124            view_proj: view_proj.to_cols_array_2d(),
125            data_range: [x_min as f32, x_max as f32, y_min as f32, y_max as f32],
126            plot_area: [bounds.x, bounds.y, bounds.width, bounds.height],
127            viewport_size: viewport_size.to_array(),
128            _padding: [0.0; 2],
129        }
130    }
131
132    /// Convert data coordinates to normalized coordinates.
133    pub fn data_to_normalized(&self, x: f64, y: f64) -> Vec2 {
134        let x_norm = ((x as f32 - self.data_range[0]) / (self.data_range[1] - self.data_range[0]))
135            .clamp(0.0, 1.0);
136        let y_norm = ((y as f32 - self.data_range[2]) / (self.data_range[3] - self.data_range[2]))
137            .clamp(0.0, 1.0);
138        Vec2::new(x_norm, 1.0 - y_norm) // Flip Y for screen coords
139    }
140
141    /// Convert normalized coordinates to screen coordinates.
142    pub fn normalized_to_screen(&self, normalized: Vec2) -> Vec2 {
143        Vec2::new(
144            self.plot_area[0] + normalized.x * self.plot_area[2],
145            self.plot_area[1] + normalized.y * self.plot_area[3],
146        )
147    }
148
149    /// Convert data coordinates to screen coordinates.
150    pub fn data_to_screen(&self, x: f64, y: f64) -> Vec2 {
151        let normalized = self.data_to_normalized(x, y);
152        self.normalized_to_screen(normalized)
153    }
154}
155
/// Half-open index range `[start, end)` of data that needs GPU upload.
///
/// Any range with `start >= end` counts as empty.
#[derive(Debug, Clone, Copy, Default)]
pub struct DirtyRange {
    /// First dirty index (inclusive).
    pub start: usize,
    /// One past the last dirty index (exclusive).
    pub end: usize,
}

impl DirtyRange {
    /// Build a range from explicit bounds; no normalization is applied.
    pub fn new(start: usize, end: usize) -> Self {
        DirtyRange { start, end }
    }

    /// Range spanning `0..len`, i.e. everything.
    pub fn all(len: usize) -> Self {
        Self::new(0, len)
    }

    /// `true` when the range covers no indices (`start >= end`).
    pub fn is_empty(&self) -> bool {
        self.end <= self.start
    }

    /// Number of covered indices; zero for empty or inverted ranges.
    pub fn len(&self) -> usize {
        if self.is_empty() {
            0
        } else {
            self.end - self.start
        }
    }

    /// Smallest range covering both `self` and `other`.
    ///
    /// Empty ranges are ignored: if `self` is empty, `other` is returned as-is
    /// (even when it is empty too). Otherwise, if `other` is empty, `self` is
    /// returned. Two disjoint non-empty ranges merge into one span that also
    /// covers the gap between them.
    pub fn merge(&self, other: &DirtyRange) -> DirtyRange {
        match (self.is_empty(), other.is_empty()) {
            (true, _) => *other,
            (false, true) => *self,
            (false, false) => {
                DirtyRange::new(self.start.min(other.start), self.end.max(other.end))
            }
        }
    }
}
200
201/// GPU buffers for a single series.
202#[derive(Debug, Default)]
203pub struct SeriesGpuBuffers {
204    /// Vertex buffer for line segments
205    pub line_buffer: Option<wgpu::Buffer>,
206    /// Current line vertex count
207    pub line_vertex_count: usize,
208    /// Vertex buffer for markers
209    pub marker_buffer: Option<wgpu::Buffer>,
210    /// Current marker vertex count
211    pub marker_vertex_count: usize,
212    /// Dirty range for partial updates
213    pub dirty_range: Option<DirtyRange>,
214    /// Data version when buffers were last updated
215    pub data_version: u64,
216}
217
218/// GPU state for chart rendering.
219///
220/// Manages all GPU resources needed for efficient chart rendering:
221/// - Transform uniform buffer
222/// - Per-series vertex buffers
223/// - Grid line buffers
224pub struct ChartGpuState {
225    /// Graphics context
226    context: Arc<GraphicsContext>,
227    /// Transform uniform buffer
228    transform_buffer: wgpu::Buffer,
229    /// Transform bind group layout
230    transform_bind_group_layout: wgpu::BindGroupLayout,
231    /// Transform bind group
232    transform_bind_group: wgpu::BindGroup,
233    /// Per-series GPU buffers
234    series_buffers: Vec<SeriesGpuBuffers>,
235    /// Current transform data
236    current_transform: ChartTransform,
237    /// Version counter for tracking changes
238    version: u64,
239}
240
241impl ChartGpuState {
242    /// Create a new GPU state.
243    pub fn new(context: Arc<GraphicsContext>) -> Self {
244        let device = context.device();
245
246        // Create transform uniform buffer
247        let transform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
248            label: Some("Chart Transform Buffer"),
249            size: std::mem::size_of::<ChartTransform>() as u64,
250            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
251            mapped_at_creation: false,
252        });
253
254        // Create bind group layout
255        let transform_bind_group_layout =
256            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
257                label: Some("Chart Transform Bind Group Layout"),
258                entries: &[wgpu::BindGroupLayoutEntry {
259                    binding: 0,
260                    visibility: wgpu::ShaderStages::VERTEX,
261                    ty: wgpu::BindingType::Buffer {
262                        ty: wgpu::BufferBindingType::Uniform,
263                        has_dynamic_offset: false,
264                        min_binding_size: None,
265                    },
266                    count: None,
267                }],
268            });
269
270        // Create bind group
271        let transform_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
272            label: Some("Chart Transform Bind Group"),
273            layout: &transform_bind_group_layout,
274            entries: &[wgpu::BindGroupEntry {
275                binding: 0,
276                resource: transform_buffer.as_entire_binding(),
277            }],
278        });
279
280        Self {
281            context,
282            transform_buffer,
283            transform_bind_group_layout,
284            transform_bind_group,
285            series_buffers: Vec::new(),
286            current_transform: ChartTransform::default(),
287            version: 0,
288        }
289    }
290
291    /// Update GPU state from chart and cache.
292    pub fn update(
293        &mut self,
294        chart: &Chart,
295        bounds: &Rect,
296        viewport_size: Vec2,
297        dirty_flags: ChartDirtyFlags,
298    ) {
299        // Update transform if view changed
300        if dirty_flags.intersects(
301            ChartDirtyFlags::VIEW_CHANGED
302                | ChartDirtyFlags::BOUNDS_CHANGED
303                | ChartDirtyFlags::AXES_CHANGED,
304        ) {
305            self.current_transform = ChartTransform::from_chart(chart, bounds, viewport_size);
306            let queue = self.context.queue();
307            queue.write_buffer(
308                &self.transform_buffer,
309                0,
310                bytemuck::bytes_of(&self.current_transform),
311            );
312        }
313
314        // Ensure we have enough series buffers
315        while self.series_buffers.len() < chart.series.len() {
316            self.series_buffers.push(SeriesGpuBuffers::default());
317        }
318
319        // Update series buffers
320        let needs_data_update =
321            dirty_flags.intersects(ChartDirtyFlags::DATA_CHANGED | ChartDirtyFlags::DATA_APPENDED);
322
323        if needs_data_update {
324            // Clone transform to avoid borrow issues
325            let transform = self.current_transform;
326            for (idx, series) in chart.series.iter().enumerate() {
327                self.update_series_buffer(idx, &series.data, &transform);
328            }
329        }
330
331        self.version = self.version.wrapping_add(1);
332    }
333
334    /// Update GPU state from streaming chart prepare result.
335    pub fn update_from_prepare_result(
336        &mut self,
337        chart: &Chart,
338        bounds: &Rect,
339        viewport_size: Vec2,
340        result: &PrepareResult,
341    ) {
342        // Always update transform
343        self.current_transform = ChartTransform::from_chart(chart, bounds, viewport_size);
344        {
345            let queue = self.context.queue();
346            queue.write_buffer(
347                &self.transform_buffer,
348                0,
349                bytemuck::bytes_of(&self.current_transform),
350            );
351        }
352
353        // Ensure we have enough series buffers
354        while self.series_buffers.len() < chart.series.len() {
355            self.series_buffers.push(SeriesGpuBuffers::default());
356        }
357
358        // Clone transform to avoid borrow issues
359        let transform = self.current_transform;
360
361        // Update only series that changed
362        for update in &result.series_updates {
363            if update.full_rebuild || update.new_points > 0 {
364                let series = &chart.series[update.index];
365                self.update_series_buffer(update.index, &series.data, &transform);
366            }
367        }
368
369        self.version = self.version.wrapping_add(1);
370    }
371
372    /// Update a single series buffer.
373    fn update_series_buffer(
374        &mut self,
375        series_idx: usize,
376        data: &[DataPoint],
377        transform: &ChartTransform,
378    ) {
379        if data.len() < 2 {
380            self.series_buffers[series_idx].line_vertex_count = 0;
381            return;
382        }
383
384        // Generate line vertices
385        let mut vertices = Vec::with_capacity(data.len() * 2);
386
387        for i in 0..data.len() - 1 {
388            let p0 = transform.data_to_screen(data[i].x, data[i].y);
389            let p1 = transform.data_to_screen(data[i + 1].x, data[i + 1].y);
390
391            let dir = (p1 - p0).normalize_or_zero();
392
393            vertices.push(LineVertex::new(p0, Color::WHITE, dir, 1.0));
394            vertices.push(LineVertex::new(p1, Color::WHITE, dir, 1.0));
395        }
396
397        let device = self.context.device();
398        let queue = self.context.queue();
399
400        let buffer_size = (vertices.len() * std::mem::size_of::<LineVertex>()) as u64;
401
402        // Check if we need to recreate the buffer
403        let needs_new_buffer = self.series_buffers[series_idx]
404            .line_buffer
405            .as_ref()
406            .is_none_or(|b| b.size() < buffer_size);
407
408        if needs_new_buffer {
409            // Create new buffer with some extra capacity
410            let capacity = (buffer_size * 3 / 2).max(4096);
411            let buffer = device.create_buffer(&wgpu::BufferDescriptor {
412                label: Some(&format!("Chart Series {} Line Buffer", series_idx)),
413                size: capacity,
414                usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
415                mapped_at_creation: false,
416            });
417            self.series_buffers[series_idx].line_buffer = Some(buffer);
418        }
419
420        // Write data
421        if let Some(buffer) = &self.series_buffers[series_idx].line_buffer {
422            queue.write_buffer(buffer, 0, bytemuck::cast_slice(&vertices));
423        }
424
425        self.series_buffers[series_idx].line_vertex_count = vertices.len();
426    }
427
428    /// Get the transform bind group.
429    pub fn transform_bind_group(&self) -> &wgpu::BindGroup {
430        &self.transform_bind_group
431    }
432
433    /// Get the transform bind group layout.
434    pub fn transform_bind_group_layout(&self) -> &wgpu::BindGroupLayout {
435        &self.transform_bind_group_layout
436    }
437
438    /// Get the current transform.
439    pub fn transform(&self) -> &ChartTransform {
440        &self.current_transform
441    }
442
443    /// Get series buffers.
444    pub fn series_buffers(&self) -> &[SeriesGpuBuffers] {
445        &self.series_buffers
446    }
447
448    /// Get the version counter.
449    pub fn version(&self) -> u64 {
450        self.version
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Transform mapping data `[0, 100] x [0, 100]` onto a 400x300 plot at (50, 50).
    fn sample_transform() -> ChartTransform {
        ChartTransform {
            data_range: [0.0, 100.0, 0.0, 100.0],
            plot_area: [50.0, 50.0, 400.0, 300.0],
            ..Default::default()
        }
    }

    #[test]
    fn test_chart_transform() {
        let transform = sample_transform();

        // Test data to normalized conversion
        let norm = transform.data_to_normalized(50.0, 50.0);
        assert!((norm.x - 0.5).abs() < 0.001);
        assert!((norm.y - 0.5).abs() < 0.001);

        // Test normalized to screen conversion
        let screen = transform.normalized_to_screen(Vec2::new(0.5, 0.5));
        assert!((screen.x - 250.0).abs() < 0.001);
        assert!((screen.y - 200.0).abs() < 0.001);
    }

    #[test]
    fn test_data_to_screen_flips_y_and_clamps() {
        let transform = sample_transform();

        // (x_min, y_max) is the top-left corner of the plot area.
        let top_left = transform.data_to_screen(0.0, 100.0);
        assert!((top_left.x - 50.0).abs() < 0.001);
        assert!((top_left.y - 50.0).abs() < 0.001);

        // Points outside the data range are clamped onto the plot edge.
        let clamped = transform.data_to_screen(150.0, -20.0);
        assert!((clamped.x - 450.0).abs() < 0.001);
        assert!((clamped.y - 350.0).abs() < 0.001);
    }

    #[test]
    fn test_dirty_range() {
        let r1 = DirtyRange::new(10, 20);
        let r2 = DirtyRange::new(15, 30);

        let merged = r1.merge(&r2);
        assert_eq!(merged.start, 10);
        assert_eq!(merged.end, 30);
        assert_eq!(merged.len(), 20);
    }

    #[test]
    fn test_dirty_range_empty_and_inverted() {
        let empty = DirtyRange::default();
        assert!(empty.is_empty());

        // Merging with an empty range yields the non-empty one unchanged.
        let r = DirtyRange::new(3, 7);
        let merged = empty.merge(&r);
        assert_eq!((merged.start, merged.end), (3, 7));
        let merged = r.merge(&DirtyRange::new(9, 9));
        assert_eq!((merged.start, merged.end), (3, 7));

        // Inverted ranges are empty and report zero length.
        let inverted = DirtyRange::new(8, 2);
        assert!(inverted.is_empty());
        assert_eq!(inverted.len(), 0);

        let all = DirtyRange::all(5);
        assert_eq!((all.start, all.end, all.len()), (0, 5, 5));
    }
}