gpui_component/chart/line_chart.rs

1use std::rc::Rc;
2
3use gpui::{px, App, Bounds, Hsla, Pixels, SharedString, TextAlign, Window};
4use gpui_component_macros::IntoPlot;
5use num_traits::{Num, ToPrimitive};
6
7use crate::{
8    plot::{
9        scale::{Scale, ScaleLinear, ScalePoint, Sealed},
10        shape::Line,
11        Axis, AxisText, Grid, Plot, StrokeStyle, AXIS_GAP,
12    },
13    ActiveTheme, PixelsExt,
14};
15
/// A line chart that plots one point per item of `data`.
///
/// X positions come from a point scale over the values returned by the `x`
/// accessor; Y positions come from a linear scale over the values returned by
/// the `y` accessor, with zero always included in the Y domain.
#[derive(IntoPlot)]
pub struct LineChart<T, X, Y>
where
    T: 'static,
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    /// The items to plot, in drawing order.
    data: Vec<T>,
    /// Maps an item to its X-axis category; the same value is used as the axis label text.
    x: Option<Rc<dyn Fn(&T) -> X>>,
    /// Maps an item to its numeric Y value.
    y: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Line (and dot fill) colour; `None` falls back to the theme's `chart_2` colour.
    stroke: Option<Hsla>,
    /// How consecutive points are joined (natural curve, straight segments, or step-after).
    stroke_style: StrokeStyle,
    /// Whether to draw a dot marker at every data point.
    dot: bool,
    /// An X-axis label is shown only for every `tick_margin`-th item.
    tick_margin: usize,
}
31
32impl<T, X, Y> LineChart<T, X, Y>
33where
34    X: PartialEq + Into<SharedString> + 'static,
35    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
36{
37    pub fn new<I>(data: I) -> Self
38    where
39        I: IntoIterator<Item = T>,
40    {
41        Self {
42            data: data.into_iter().collect(),
43            stroke: None,
44            stroke_style: Default::default(),
45            dot: false,
46            x: None,
47            y: None,
48            tick_margin: 1,
49        }
50    }
51
52    pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
53        self.x = Some(Rc::new(x));
54        self
55    }
56
57    pub fn y(mut self, y: impl Fn(&T) -> Y + 'static) -> Self {
58        self.y = Some(Rc::new(y));
59        self
60    }
61
62    pub fn natural(mut self) -> Self {
63        self.stroke_style = StrokeStyle::Natural;
64        self
65    }
66
67    pub fn linear(mut self) -> Self {
68        self.stroke_style = StrokeStyle::Linear;
69        self
70    }
71
72    pub fn step_after(mut self) -> Self {
73        self.stroke_style = StrokeStyle::StepAfter;
74        self
75    }
76
77    pub fn dot(mut self) -> Self {
78        self.dot = true;
79        self
80    }
81
82    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
83        self.tick_margin = tick_margin;
84        self
85    }
86}
87
88impl<T, X, Y> Plot for LineChart<T, X, Y>
89where
90    X: PartialEq + Into<SharedString> + 'static,
91    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
92{
93    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
94        let (Some(x_fn), Some(y_fn)) = (self.x.as_ref(), self.y.as_ref()) else {
95            return;
96        };
97
98        let width = bounds.size.width.as_f32();
99        let height = bounds.size.height.as_f32() - AXIS_GAP;
100
101        // X scale
102        let x = ScalePoint::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width]);
103
104        // Y scale, ensure start from 0.
105        let y = ScaleLinear::new(
106            self.data
107                .iter()
108                .map(|v| y_fn(v))
109                .chain(Some(Y::zero()))
110                .collect(),
111            vec![height, 10.],
112        );
113
114        // Draw X axis
115        let data_len = self.data.len();
116        let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
117            if (i + 1) % self.tick_margin == 0 {
118                x.tick(&x_fn(d)).map(|x_tick| {
119                    let align = match i {
120                        0 => {
121                            if data_len == 1 {
122                                TextAlign::Center
123                            } else {
124                                TextAlign::Left
125                            }
126                        }
127                        i if i == data_len - 1 => TextAlign::Right,
128                        _ => TextAlign::Center,
129                    };
130                    AxisText::new(x_fn(d).into(), x_tick, cx.theme().muted_foreground).align(align)
131                })
132            } else {
133                None
134            }
135        });
136
137        Axis::new()
138            .x(height)
139            .x_label(x_label)
140            .stroke(cx.theme().border)
141            .paint(&bounds, window, cx);
142
143        // Draw grid
144        Grid::new()
145            .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
146            .stroke(cx.theme().border)
147            .dash_array(&[px(4.), px(2.)])
148            .paint(&bounds, window);
149
150        // Draw line
151        let stroke = self.stroke.unwrap_or(cx.theme().chart_2);
152        let x_fn = x_fn.clone();
153        let y_fn = y_fn.clone();
154        let mut line = Line::new()
155            .data(&self.data)
156            .x(move |d| x.tick(&x_fn(d)))
157            .y(move |d| y.tick(&y_fn(d)))
158            .stroke(stroke)
159            .stroke_style(self.stroke_style)
160            .stroke_width(2.);
161
162        if self.dot {
163            line = line.dot().dot_size(8.).dot_fill_color(stroke);
164        }
165
166        line.paint(&bounds, window);
167    }
168}