gpui_component/chart/bar_chart.rs

1use std::rc::Rc;
2
3use gpui::{px, App, Bounds, Hsla, Pixels, SharedString, TextAlign, Window};
4use gpui_component_macros::IntoPlot;
5use num_traits::{Num, ToPrimitive};
6
7use crate::{
8    plot::{
9        label::Text,
10        scale::{Scale, ScaleBand, ScaleLinear, Sealed},
11        shape::Bar,
12        Axis, AxisText, Grid, Plot, AXIS_GAP,
13    },
14    ActiveTheme, PixelsExt,
15};
16
17#[derive(IntoPlot)]
18pub struct BarChart<T, X, Y>
19where
20    T: 'static,
21    X: PartialEq + Into<SharedString> + 'static,
22    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
23{
24    data: Vec<T>,
25    x: Option<Rc<dyn Fn(&T) -> X>>,
26    y: Option<Rc<dyn Fn(&T) -> Y>>,
27    fill: Option<Rc<dyn Fn(&T) -> Hsla>>,
28    tick_margin: usize,
29    label: Option<Rc<dyn Fn(&T) -> SharedString>>,
30}
31
32impl<T, X, Y> BarChart<T, X, Y>
33where
34    X: PartialEq + Into<SharedString> + 'static,
35    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
36{
37    pub fn new<I>(data: I) -> Self
38    where
39        I: IntoIterator<Item = T>,
40    {
41        Self {
42            data: data.into_iter().collect(),
43            x: None,
44            y: None,
45            fill: None,
46            tick_margin: 1,
47            label: None,
48        }
49    }
50
51    pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
52        self.x = Some(Rc::new(x));
53        self
54    }
55
56    pub fn y(mut self, y: impl Fn(&T) -> Y + 'static) -> Self {
57        self.y = Some(Rc::new(y));
58        self
59    }
60
61    pub fn fill<H>(mut self, fill: impl Fn(&T) -> H + 'static) -> Self
62    where
63        H: Into<Hsla> + 'static,
64    {
65        self.fill = Some(Rc::new(move |t| fill(t).into()));
66        self
67    }
68
69    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
70        self.tick_margin = tick_margin;
71        self
72    }
73
74    pub fn label<S>(mut self, label: impl Fn(&T) -> S + 'static) -> Self
75    where
76        S: Into<SharedString> + 'static,
77    {
78        self.label = Some(Rc::new(move |t| label(t).into()));
79        self
80    }
81}
82
83impl<T, X, Y> Plot for BarChart<T, X, Y>
84where
85    X: PartialEq + Into<SharedString> + 'static,
86    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
87{
88    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
89        let (Some(x_fn), Some(y_fn)) = (self.x.as_ref(), self.y.as_ref()) else {
90            return;
91        };
92
93        let width = bounds.size.width.as_f32();
94        let height = bounds.size.height.as_f32() - AXIS_GAP;
95
96        // X scale
97        let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
98            .padding_inner(0.4)
99            .padding_outer(0.2);
100        let band_width = x.band_width();
101
102        // Y scale, ensure start from 0.
103        let y = ScaleLinear::new(
104            self.data
105                .iter()
106                .map(|v| y_fn(v))
107                .chain(Some(Y::zero()))
108                .collect(),
109            vec![height, 10.],
110        );
111
112        // Draw X axis
113        let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
114            if (i + 1) % self.tick_margin == 0 {
115                x.tick(&x_fn(d)).map(|x_tick| {
116                    AxisText::new(
117                        x_fn(d).into(),
118                        x_tick + band_width / 2.,
119                        cx.theme().muted_foreground,
120                    )
121                    .align(TextAlign::Center)
122                })
123            } else {
124                None
125            }
126        });
127
128        Axis::new()
129            .x(height)
130            .x_label(x_label)
131            .stroke(cx.theme().border)
132            .paint(&bounds, window, cx);
133
134        // Draw grid
135        Grid::new()
136            .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
137            .stroke(cx.theme().border)
138            .dash_array(&[px(4.), px(2.)])
139            .paint(&bounds, window);
140
141        // Draw bars
142        let x_fn = x_fn.clone();
143        let y_fn = y_fn.clone();
144        let default_fill = cx.theme().chart_2;
145        let fill = self.fill.clone();
146        let label_color = cx.theme().foreground;
147        let mut bar = Bar::new()
148            .data(&self.data)
149            .band_width(band_width)
150            .x(move |d| x.tick(&x_fn(d)))
151            .y0(height)
152            .y1(move |d| y.tick(&y_fn(d)))
153            .fill(move |d| fill.as_ref().map(|f| f(d)).unwrap_or(default_fill));
154
155        if let Some(label) = self.label.as_ref() {
156            let label = label.clone();
157            bar = bar.label(move |d, p| vec![Text::new(label(d), p, label_color)]);
158        }
159
160        bar.paint(&bounds, window, cx);
161    }
162}