Skip to main content

rgpui_component/chart/
bar_chart.rs

1use std::{ops::RangeInclusive, rc::Rc};
2
3use num_traits::{Num, ToPrimitive};
4use rgpui::{
5    App, Background, Bounds, Corners, Hsla, LinearColorStop, Pixels, Point, SharedString, Size,
6    TextAlign, Window, linear_gradient, px,
7};
8use rgpui_component_macros::IntoPlot;
9
10use crate::{
11    ActiveTheme,
12    plot::{
13        AXIS_GAP, AxisLabelSide, Grid, Plot, PlotAxis,
14        label::{TEXT_GAP, TEXT_SIZE, Text, measure_text_width},
15        scale::{Scale, ScaleBand, ScaleLinear, Sealed},
16        shape::{Bar, BarAlignment},
17    },
18};
19
20use super::build_band_labels;
21
/// A categorical bar chart: one bar per datum, placed on a band (categorical)
/// axis and sized along a linear value axis.
///
/// Configure it through the builder methods on this type. Painting is a
/// no-op until both the `band` and `value` accessors have been set.
#[derive(IntoPlot)]
pub struct BarChart<T, B, V>
where
    T: 'static,
    B: PartialEq + Into<SharedString> + 'static,
    V: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    /// The data points; one bar is drawn per element.
    data: Vec<T>,
    /// Maps a datum to its band (category) key on the band axis.
    band: Option<Rc<dyn Fn(&T) -> B>>,
    /// Maps a datum to its numeric value on the value axis.
    value: Option<Rc<dyn Fn(&T) -> V>>,
    /// Verbatim per-datum fill (datum, bar bounds, chart bounds, alignment).
    /// Setting it through the builder clears `fill_gradient`.
    fill: Option<Rc<dyn Fn(&T, Bounds<f32>, Bounds<f32>, BarAlignment) -> Background>>,
    /// Per-datum two-stop gradient (datum, chart value range, chart-to-bar
    /// position remap). Setting it through the builder clears `fill`.
    // The boxed callback signature is inherently long; it trips clippy's
    // `type_complexity` lint, which is silenced here on purpose.
    #[allow(clippy::type_complexity)]
    fill_gradient:
        Option<Rc<dyn Fn(&T, RangeInclusive<f32>, &dyn Fn(f32) -> f32) -> [LinearColorStop; 2]>>,
    /// Forwarded to `build_band_labels` when the band-axis labels are built.
    tick_margin: usize,
    /// Optional per-datum text drawn alongside each bar.
    label: Option<Rc<dyn Fn(&T) -> SharedString>>,
    /// Whether the band axis labels are drawn (and space reserved for them).
    label_axis: bool,
    /// Whether the dashed grid lines across the value axis are drawn.
    grid: bool,
    /// Which edge the bars grow from (and hence chart orientation).
    alignment: BarAlignment,
    /// Corner radii applied to every bar rectangle.
    corner_radii: Corners<Pixels>,
}
43
44impl<T, B, V> BarChart<T, B, V>
45where
46    B: PartialEq + Into<SharedString> + 'static,
47    V: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
48{
49    pub fn new<I>(data: I) -> Self
50    where
51        I: IntoIterator<Item = T>,
52    {
53        Self {
54            data: data.into_iter().collect(),
55            band: None,
56            value: None,
57            fill: None,
58            fill_gradient: None,
59            tick_margin: 1,
60            label: None,
61            label_axis: true,
62            grid: true,
63            alignment: BarAlignment::default(),
64            corner_radii: Corners::all(px(0.)),
65        }
66    }
67
68    /// Map each datum to its band-axis value (the categorical/ordinal axis).
69    pub fn band(mut self, band: impl Fn(&T) -> B + 'static) -> Self {
70        self.band = Some(Rc::new(band));
71        self
72    }
73
74    /// Map each datum to its numeric value along the value axis.
75    pub fn value(mut self, value: impl Fn(&T) -> V + 'static) -> Self {
76        self.value = Some(Rc::new(value));
77        self
78    }
79
80    /// Set a per-datum verbatim fill.
81    ///
82    /// The closure receives:
83    ///
84    /// 1. the datum,
85    /// 2. the **bar's bounds** in pixel space, expressed relative to the
86    ///    chart's origin (i.e. the bar's painted rectangle within the chart),
87    /// 3. the **chart's bounds** in pixel space with origin `(0, 0)` and size
88    ///    equal to the full chart extent, and
89    /// 4. the bar's [`BarAlignment`] (so callers can branch on orientation,
90    ///    e.g. flip a gradient angle).
91    ///
92    /// Both rectangles share the same coordinate system, so callers can
93    /// implement arbitrary chart-aware backgrounds — bar-local gradients,
94    /// chart-wide gradients, patterns, sampled colormaps, etc. — without any
95    /// help from the library.
96    ///
97    /// Accepts any type convertible to [`Background`]. Setting this clears any
98    /// previously set [`BarChart::fill_gradient`].
99    pub fn fill<Bg>(
100        mut self,
101        fill: impl Fn(&T, Bounds<f32>, Bounds<f32>, BarAlignment) -> Bg + 'static,
102    ) -> Self
103    where
104        Bg: Into<Background> + 'static,
105    {
106        self.fill = Some(Rc::new(move |t, bar_bounds, chart_bounds, alignment| {
107            fill(t, bar_bounds, chart_bounds, alignment).into()
108        }));
109        self.fill_gradient = None;
110        self
111    }
112
113    /// Set a per-datum auto-oriented linear gradient fill.
114    ///
115    /// The closure receives the datum, the chart's full data range
116    /// (`chart_range`, derived from all data values), and a `chart_to_bar`
117    /// remap helper that maps a chart-value coordinate to a bar-local
118    /// gradient position (where `0.0` is the bar's base and `1.0` is its tip).
119    ///
120    /// Use bar-local positions directly for per-bar gradients (every bar
121    /// looks the same regardless of its value):
122    ///
123    /// ```ignore
124    /// .fill_gradient(|_, _, _| [
125    ///     linear_color_stop(c.opacity(0.3), 0.0),
126    ///     linear_color_stop(c, 1.0),
127    /// ])
128    /// ```
129    ///
130    /// Or use `chart_to_bar` to position stops at chart-relative values, so
131    /// each bar shows the slice of a chart-wide gradient corresponding to
132    /// its own `[base, value]` span:
133    ///
134    /// ```ignore
135    /// .fill_gradient(|_, chart_range, chart_to_bar| [
136    ///     linear_color_stop(c.opacity(0.3), chart_to_bar(*chart_range.start())),
137    ///     linear_color_stop(c,              chart_to_bar(*chart_range.end())),
138    /// ])
139    /// ```
140    ///
141    /// Stop positions returned outside `[0, 1]` are clipped to the bar; the
142    /// library interpolates colors at the clip points so the on-bar gradient
143    /// still matches the chart-wide one.
144    ///
145    /// The gradient angle is derived from [`BarAlignment`] so stop-0 is at the
146    /// base and stop-1 at the tip. Setting this clears any previously set
147    /// [`BarChart::fill`].
148    pub fn fill_gradient(
149        mut self,
150        fill: impl Fn(&T, RangeInclusive<f32>, &dyn Fn(f32) -> f32) -> [LinearColorStop; 2] + 'static,
151    ) -> Self {
152        self.fill_gradient = Some(Rc::new(fill));
153        self.fill = None;
154        self
155    }
156
157    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
158        self.tick_margin = tick_margin;
159        self
160    }
161
162    pub fn label<S>(mut self, label: impl Fn(&T) -> S + 'static) -> Self
163    where
164        S: Into<SharedString> + 'static,
165    {
166        self.label = Some(Rc::new(move |t| label(t).into()));
167        self
168    }
169
170    /// Show or hide the band-axis line and labels.
171    ///
172    /// Default is true.
173    pub fn label_axis(mut self, label_axis: bool) -> Self {
174        self.label_axis = label_axis;
175        self
176    }
177
178    pub fn grid(mut self, grid: bool) -> Self {
179        self.grid = grid;
180        self
181    }
182
183    /// Set the bar alignment.
184    ///
185    /// Default is [`BarAlignment::Bottom`].
186    pub fn alignment(mut self, alignment: BarAlignment) -> Self {
187        self.alignment = alignment;
188        self
189    }
190
191    /// Set the corner radii applied to every bar rectangle.
192    ///
193    /// Use [`Corners::all`] for uniform rounding, or construct [`Corners`] manually
194    /// to round only specific corners (e.g. just the tip end of each bar).
195    pub fn corner_radii(mut self, corner_radii: impl Into<Corners<Pixels>>) -> Self {
196        self.corner_radii = corner_radii.into();
197        self
198    }
199}
200
201impl<T, B, V> Plot for BarChart<T, B, V>
202where
203    B: PartialEq + Into<SharedString> + 'static,
204    V: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
205{
206    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
207        let (Some(band_fn), Some(value_fn)) = (self.band.as_ref(), self.value.as_ref()) else {
208            return;
209        };
210
211        let total_width = bounds.size.width.as_f32();
212        let total_height = bounds.size.height.as_f32();
213        let axis_gap = if self.label_axis { AXIS_GAP } else { 0. };
214        let alignment = self.alignment;
215        let is_horizontal = alignment.is_horizontal();
216
217        // Band scale spans the full extent perpendicular to the value axis.
218        let band_extent = if is_horizontal {
219            total_height
220        } else {
221            total_width
222        };
223        let band_scale = ScaleBand::new(
224            self.data.iter().map(|v| band_fn(v)).collect(),
225            vec![0., band_extent],
226        )
227        .padding_inner(0.4)
228        .padding_outer(0.2);
229        let band_width = band_scale.band_width();
230
231        let value_dim = if is_horizontal {
232            total_width
233        } else {
234            total_height
235        };
236        // For horizontal charts the band labels (category names) are rendered
237        // along the value axis and can be arbitrarily wide, so we measure the
238        // actual maximum label width instead of using a fixed constant.
239        // Similarly, value labels (numbers) at the bar ends are measured so the
240        // scale range is always shrunk by exactly the right amount.
241        let (band_gap, value_end_gap) = if is_horizontal {
242            let font_size = px(TEXT_SIZE);
243            let band_gap = if self.label_axis {
244                let max_w = self
245                    .data
246                    .iter()
247                    .map(|v| {
248                        let s: SharedString = band_fn(v).into();
249                        measure_text_width(&s, font_size, window)
250                    })
251                    .fold(0f32, f32::max);
252                // TEXT_GAP: space between axis line and label start/end.
253                max_w + TEXT_GAP * 2.
254            } else {
255                0.
256            };
257            let value_end_gap = if let Some(label_fn) = self.label.as_ref() {
258                let max_w = self
259                    .data
260                    .iter()
261                    .map(|v| measure_text_width(&label_fn(v), font_size, window))
262                    .fold(0f32, f32::max);
263                max_w + TEXT_GAP * 2.
264            } else {
265                TEXT_GAP * 4.
266            };
267            (band_gap, value_end_gap)
268        } else {
269            (axis_gap, 10.)
270        };
271        let (range, baseline) = match alignment {
272            BarAlignment::Bottom => {
273                let baseline = value_dim - axis_gap;
274                (vec![baseline, 10.], baseline)
275            }
276            BarAlignment::Top => {
277                let baseline = axis_gap;
278                (vec![baseline, value_dim - 10.], baseline)
279            }
280            BarAlignment::Left => {
281                let baseline = band_gap;
282                (vec![baseline, value_dim - value_end_gap], baseline)
283            }
284            BarAlignment::Right => {
285                let baseline = value_dim - band_gap;
286                (vec![baseline, value_end_gap], baseline)
287            }
288        };
289        let value_scale = ScaleLinear::new(
290            self.data
291                .iter()
292                .map(|v| value_fn(v))
293                .chain(Some(V::zero()))
294                .collect(),
295            range,
296        );
297
298        // Draw band axis (with categorical labels).
299        let mut axis = PlotAxis::new().stroke(cx.theme().border);
300        if self.label_axis {
301            let labels = build_band_labels(
302                &self.data,
303                band_fn.as_ref(),
304                &band_scale,
305                band_width,
306                self.tick_margin,
307                cx.theme().muted_foreground,
308            );
309            axis = match alignment {
310                BarAlignment::Bottom => axis.x(baseline).x_label(labels),
311                BarAlignment::Top => axis
312                    .x(baseline)
313                    .x_label_side(AxisLabelSide::Start)
314                    .x_label(labels),
315                BarAlignment::Left => axis
316                    .y(baseline)
317                    .y_label_side(AxisLabelSide::Start)
318                    .y_label(labels.into_iter().map(|t| t.align(TextAlign::Right))),
319                BarAlignment::Right => axis
320                    .y(baseline)
321                    .y_label(labels.into_iter().map(|t| t.align(TextAlign::Left))),
322            };
323        }
324        axis.paint(&bounds, window, cx);
325
326        // Far edge of the value axis in pixel space (opposite the baseline).
327        let far = match alignment {
328            BarAlignment::Bottom => 10.,
329            BarAlignment::Top => value_dim - 10.,
330            BarAlignment::Left => value_dim - value_end_gap,
331            BarAlignment::Right => value_end_gap,
332        };
333
334        // Draw grid: lines perpendicular to the value axis, evenly spaced
335        // across the value range and excluding the line at the baseline.
336        if self.grid {
337            let grid_steps: Vec<f32> = (0..4)
338                .map(|i| far + (baseline - far) * i as f32 / 4.0)
339                .collect();
340            let grid = Grid::new()
341                .stroke(cx.theme().border)
342                .dash_array(&[px(4.), px(2.)]);
343            let grid = if is_horizontal {
344                grid.x(grid_steps)
345            } else {
346                grid.y(grid_steps)
347            };
348            grid.paint(&bounds, window);
349        }
350
351        // Draw bars.
352        let band_fn_cloned = band_fn.clone();
353        let value_fn_cloned = value_fn.clone();
354        let default_fill: Background = cx.theme().chart_2.into();
355        let fill = self.fill.clone();
356        let fill_gradient = self.fill_gradient.clone();
357        let label_color = cx.theme().foreground;
358
359        // Chart bounds in pixel space, with origin (0, 0) and size equal to
360        // the full chart extent. Passed to user `fill` closures so they can
361        // position chart-wide backgrounds (gradients, patterns, etc.).
362        let chart_bounds: Bounds<f32> = Bounds {
363            origin: Point::new(0., 0.),
364            size: Size::new(total_width, total_height),
365        };
366
367        // Chart data range in f32 — passed to `fill_gradient` callers and used
368        // by the `chart_to_bar` remap helper.
369        let chart_range = {
370            let mut lo = 0.0_f32;
371            let mut hi = 0.0_f32;
372            for v in &self.data {
373                if let Some(f) = value_fn(v).to_f32() {
374                    lo = lo.min(f);
375                    hi = hi.max(f);
376                }
377            }
378            lo..=hi
379        };
380
381        let mut bar = Bar::new()
382            .data(&self.data)
383            .alignment(alignment)
384            .band_width(band_width)
385            .cross(move |d| band_scale.tick(&band_fn_cloned(d)))
386            .base(move |_| baseline)
387            .value(move |d| value_scale.tick(&value_fn_cloned(d)))
388            .corner_radii(self.corner_radii);
389
390        bar = match (fill, fill_gradient) {
391            (_, Some(fg)) => {
392                let value_fn_for_grad = value_fn.clone();
393                bar.fill(move |d, _frame, alignment| {
394                    let v = value_fn_for_grad(d).to_f32().unwrap_or(0.);
395                    let base_v = 0.0_f32;
396                    let bar_lo = base_v.min(v);
397                    let bar_hi = base_v.max(v);
398                    let bar_span = (bar_hi - bar_lo).max(f32::EPSILON);
399                    let chart_to_bar = |chart_value: f32| (chart_value - bar_lo) / bar_span;
400                    let stops = fg(d, chart_range.clone(), &chart_to_bar);
401                    let [s0, s1] = clip_stops_to_bar(stops);
402                    let bg: Background = linear_gradient(alignment.gradient_angle(), s0, s1);
403                    bg
404                })
405            }
406            (Some(f), _) => {
407                bar.fill(move |d, frame, alignment| f(d, frame, chart_bounds, alignment))
408            }
409            _ => bar.fill(move |_, _, _| default_fill),
410        };
411
412        if let Some(label) = self.label.as_ref() {
413            let label = label.clone();
414            let text_align = match alignment {
415                BarAlignment::Bottom | BarAlignment::Top => TextAlign::Center,
416                BarAlignment::Left => TextAlign::Left,
417                BarAlignment::Right => TextAlign::Right,
418            };
419            bar =
420                bar.label(move |d, p| vec![Text::new(label(d), p, label_color).align(text_align)]);
421        }
422
423        bar.paint(&bounds, window, cx);
424    }
425}
426
427/// Clip a two-stop gradient to bar-local `[0, 1]`, interpolating colors at the
428/// clip points so the on-bar gradient matches the (possibly broader) gradient
429/// the caller defined.
430///
431/// When a stop position falls outside `[0, 1]` (e.g. because `chart_to_bar`
432/// returned a value past the bar's edge for a chart-relative gradient),
433/// gpui's renderer would clamp the position and lose the gradient effect.
434/// This function instead replaces such a stop with the color sampled along
435/// the line through both stops at position `0.0` or `1.0`, preserving the
436/// visual slice.
437fn clip_stops_to_bar(stops: [LinearColorStop; 2]) -> [LinearColorStop; 2] {
438    let [a, b] = stops;
439    let p0 = a.percentage;
440    let p1 = b.percentage;
441    let lerp = |t: f32| -> Hsla {
442        Hsla {
443            h: a.color.h + (b.color.h - a.color.h) * t,
444            s: a.color.s + (b.color.s - a.color.s) * t,
445            l: a.color.l + (b.color.l - a.color.l) * t,
446            a: a.color.a + (b.color.a - a.color.a) * t,
447        }
448    };
449    let span = p1 - p0;
450    let sample = |target: f32| -> Hsla {
451        if span.abs() < f32::EPSILON {
452            a.color
453        } else {
454            lerp((target - p0) / span)
455        }
456    };
457    let new_a = if (0. ..=1.).contains(&p0) {
458        a
459    } else {
460        LinearColorStop {
461            color: sample(p0.clamp(0., 1.)),
462            percentage: p0.clamp(0., 1.),
463        }
464    };
465    let new_b = if (0. ..=1.).contains(&p1) {
466        b
467    } else {
468        LinearColorStop {
469            color: sample(p1.clamp(0., 1.)),
470            percentage: p1.clamp(0., 1.),
471        }
472    };
473    [new_a, new_b]
474}