Skip to main content

rgpui_component/chart/
line_chart.rs

1use std::rc::Rc;
2
3use num_traits::{Num, ToPrimitive};
4use rgpui::{App, Bounds, Hsla, Pixels, SharedString, Window, px};
5use rgpui_component_macros::IntoPlot;
6
7use crate::{
8    ActiveTheme,
9    plot::{
10        AXIS_GAP, Grid, Plot, PlotAxis, StrokeStyle,
11        scale::{Scale, ScaleLinear, ScalePoint, Sealed},
12        shape::Line,
13    },
14};
15
16use super::build_point_x_labels;
17
/// A line chart over a list of data items of type `T`.
///
/// Each item is turned into an x value (`X`, convertible into a
/// `SharedString`) and a numeric y value (`Y`) by the accessor closures
/// configured through the builder methods. Painting is a no-op until both
/// accessors have been provided.
#[derive(IntoPlot)]
pub struct LineChart<T, X, Y>
where
    T: 'static,
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    /// The data items to plot, in the order they were supplied.
    data: Vec<T>,
    /// Accessor extracting the x value from a data item; `None` until set.
    x: Option<Rc<dyn Fn(&T) -> X>>,
    /// Accessor extracting the y value from a data item; `None` until set.
    y: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Explicit line color; when `None`, painting falls back to the theme's
    /// `chart_2` color.
    stroke: Option<Hsla>,
    /// Stroke style handed to the underlying `Line` shape.
    stroke_style: StrokeStyle,
    /// Whether dots are enabled on the line when painting.
    dot: bool,
    /// Tick margin forwarded to `build_point_x_labels` when building the
    /// x-axis labels.
    tick_margin: usize,
    /// Whether the x axis (line and labels) is drawn; when it is, `AXIS_GAP`
    /// is subtracted from the plot height.
    x_axis: bool,
    /// Whether the dashed horizontal grid lines are drawn.
    grid: bool,
}
35
36impl<T, X, Y> LineChart<T, X, Y>
37where
38    X: PartialEq + Into<SharedString> + 'static,
39    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
40{
41    pub fn new<I>(data: I) -> Self
42    where
43        I: IntoIterator<Item = T>,
44    {
45        Self {
46            data: data.into_iter().collect(),
47            stroke: None,
48            stroke_style: Default::default(),
49            dot: false,
50            x: None,
51            y: None,
52            tick_margin: 1,
53            x_axis: true,
54            grid: true,
55        }
56    }
57
58    pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
59        self.x = Some(Rc::new(x));
60        self
61    }
62
63    pub fn y(mut self, y: impl Fn(&T) -> Y + 'static) -> Self {
64        self.y = Some(Rc::new(y));
65        self
66    }
67
68    pub fn stroke(mut self, stroke: impl Into<Hsla>) -> Self {
69        self.stroke = Some(stroke.into());
70        self
71    }
72
73    pub fn natural(mut self) -> Self {
74        self.stroke_style = StrokeStyle::Natural;
75        self
76    }
77
78    pub fn linear(mut self) -> Self {
79        self.stroke_style = StrokeStyle::Linear;
80        self
81    }
82
83    pub fn step_after(mut self) -> Self {
84        self.stroke_style = StrokeStyle::StepAfter;
85        self
86    }
87
88    pub fn dot(mut self) -> Self {
89        self.dot = true;
90        self
91    }
92
93    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
94        self.tick_margin = tick_margin;
95        self
96    }
97
98    /// Show or hide the x-axis line and labels.
99    ///
100    /// Default is true.
101    pub fn x_axis(mut self, x_axis: bool) -> Self {
102        self.x_axis = x_axis;
103        self
104    }
105
106    pub fn grid(mut self, grid: bool) -> Self {
107        self.grid = grid;
108        self
109    }
110}
111
112impl<T, X, Y> Plot for LineChart<T, X, Y>
113where
114    X: PartialEq + Into<SharedString> + 'static,
115    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
116{
117    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
118        let (Some(x_fn), Some(y_fn)) = (self.x.as_ref(), self.y.as_ref()) else {
119            return;
120        };
121
122        let width = bounds.size.width.as_f32();
123        let axis_gap = if self.x_axis { AXIS_GAP } else { 0. };
124        let height = bounds.size.height.as_f32() - axis_gap;
125
126        // X scale
127        let x = ScalePoint::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width]);
128
129        // Y scale, ensure start from 0.
130        let y = ScaleLinear::new(
131            self.data
132                .iter()
133                .map(|v| y_fn(v))
134                .chain(Some(Y::zero()))
135                .collect(),
136            vec![height, 10.],
137        );
138
139        // Draw X axis
140        let mut axis = PlotAxis::new().stroke(cx.theme().border);
141        if self.x_axis {
142            let labels = build_point_x_labels(
143                &self.data,
144                x_fn.as_ref(),
145                &x,
146                self.tick_margin,
147                cx.theme().muted_foreground,
148            );
149            axis = axis.x(height).x_label(labels);
150        }
151        axis.paint(&bounds, window, cx);
152
153        // Draw grid
154        if self.grid {
155            Grid::new()
156                .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
157                .stroke(cx.theme().border)
158                .dash_array(&[px(4.), px(2.)])
159                .paint(&bounds, window);
160        }
161
162        // Draw line
163        let stroke = self.stroke.unwrap_or(cx.theme().chart_2);
164        let x_fn = x_fn.clone();
165        let y_fn = y_fn.clone();
166        let mut line = Line::new()
167            .data(&self.data)
168            .x(move |d| x.tick(&x_fn(d)))
169            .y(move |d| y.tick(&y_fn(d)))
170            .stroke(stroke)
171            .stroke_style(self.stroke_style)
172            .stroke_width(2.);
173
174        if self.dot {
175            line = line.dot().dot_size(8.).dot_fill_color(stroke);
176        }
177
178        line.paint(&bounds, window);
179    }
180}