1use std::{ops::RangeInclusive, rc::Rc};
2
3use num_traits::{Num, ToPrimitive};
4use rgpui::{
5 App, Background, Bounds, Corners, Hsla, LinearColorStop, Pixels, Point, SharedString, Size,
6 TextAlign, Window, linear_gradient, px,
7};
8use rgpui_component_macros::IntoPlot;
9
10use crate::{
11 ActiveTheme,
12 plot::{
13 AXIS_GAP, AxisLabelSide, Grid, Plot, PlotAxis,
14 label::{TEXT_GAP, TEXT_SIZE, Text, measure_text_width},
15 scale::{Scale, ScaleBand, ScaleLinear, Sealed},
16 shape::{Bar, BarAlignment},
17 },
18};
19
20use super::build_band_labels;
21
22#[derive(IntoPlot)]
23pub struct BarChart<T, B, V>
24where
25 T: 'static,
26 B: PartialEq + Into<SharedString> + 'static,
27 V: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
28{
29 data: Vec<T>,
30 band: Option<Rc<dyn Fn(&T) -> B>>,
31 value: Option<Rc<dyn Fn(&T) -> V>>,
32 fill: Option<Rc<dyn Fn(&T, Bounds<f32>, Bounds<f32>, BarAlignment) -> Background>>,
33 #[allow(clippy::type_complexity)]
34 fill_gradient:
35 Option<Rc<dyn Fn(&T, RangeInclusive<f32>, &dyn Fn(f32) -> f32) -> [LinearColorStop; 2]>>,
36 tick_margin: usize,
37 label: Option<Rc<dyn Fn(&T) -> SharedString>>,
38 label_axis: bool,
39 grid: bool,
40 alignment: BarAlignment,
41 corner_radii: Corners<Pixels>,
42}
43
44impl<T, B, V> BarChart<T, B, V>
45where
46 B: PartialEq + Into<SharedString> + 'static,
47 V: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
48{
49 pub fn new<I>(data: I) -> Self
50 where
51 I: IntoIterator<Item = T>,
52 {
53 Self {
54 data: data.into_iter().collect(),
55 band: None,
56 value: None,
57 fill: None,
58 fill_gradient: None,
59 tick_margin: 1,
60 label: None,
61 label_axis: true,
62 grid: true,
63 alignment: BarAlignment::default(),
64 corner_radii: Corners::all(px(0.)),
65 }
66 }
67
68 pub fn band(mut self, band: impl Fn(&T) -> B + 'static) -> Self {
70 self.band = Some(Rc::new(band));
71 self
72 }
73
74 pub fn value(mut self, value: impl Fn(&T) -> V + 'static) -> Self {
76 self.value = Some(Rc::new(value));
77 self
78 }
79
80 pub fn fill<Bg>(
100 mut self,
101 fill: impl Fn(&T, Bounds<f32>, Bounds<f32>, BarAlignment) -> Bg + 'static,
102 ) -> Self
103 where
104 Bg: Into<Background> + 'static,
105 {
106 self.fill = Some(Rc::new(move |t, bar_bounds, chart_bounds, alignment| {
107 fill(t, bar_bounds, chart_bounds, alignment).into()
108 }));
109 self.fill_gradient = None;
110 self
111 }
112
113 pub fn fill_gradient(
149 mut self,
150 fill: impl Fn(&T, RangeInclusive<f32>, &dyn Fn(f32) -> f32) -> [LinearColorStop; 2] + 'static,
151 ) -> Self {
152 self.fill_gradient = Some(Rc::new(fill));
153 self.fill = None;
154 self
155 }
156
157 pub fn tick_margin(mut self, tick_margin: usize) -> Self {
158 self.tick_margin = tick_margin;
159 self
160 }
161
162 pub fn label<S>(mut self, label: impl Fn(&T) -> S + 'static) -> Self
163 where
164 S: Into<SharedString> + 'static,
165 {
166 self.label = Some(Rc::new(move |t| label(t).into()));
167 self
168 }
169
170 pub fn label_axis(mut self, label_axis: bool) -> Self {
174 self.label_axis = label_axis;
175 self
176 }
177
178 pub fn grid(mut self, grid: bool) -> Self {
179 self.grid = grid;
180 self
181 }
182
183 pub fn alignment(mut self, alignment: BarAlignment) -> Self {
187 self.alignment = alignment;
188 self
189 }
190
191 pub fn corner_radii(mut self, corner_radii: impl Into<Corners<Pixels>>) -> Self {
196 self.corner_radii = corner_radii.into();
197 self
198 }
199}
200
201impl<T, B, V> Plot for BarChart<T, B, V>
202where
203 B: PartialEq + Into<SharedString> + 'static,
204 V: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
205{
206 fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
207 let (Some(band_fn), Some(value_fn)) = (self.band.as_ref(), self.value.as_ref()) else {
208 return;
209 };
210
211 let total_width = bounds.size.width.as_f32();
212 let total_height = bounds.size.height.as_f32();
213 let axis_gap = if self.label_axis { AXIS_GAP } else { 0. };
214 let alignment = self.alignment;
215 let is_horizontal = alignment.is_horizontal();
216
217 let band_extent = if is_horizontal {
219 total_height
220 } else {
221 total_width
222 };
223 let band_scale = ScaleBand::new(
224 self.data.iter().map(|v| band_fn(v)).collect(),
225 vec![0., band_extent],
226 )
227 .padding_inner(0.4)
228 .padding_outer(0.2);
229 let band_width = band_scale.band_width();
230
231 let value_dim = if is_horizontal {
232 total_width
233 } else {
234 total_height
235 };
236 let (band_gap, value_end_gap) = if is_horizontal {
242 let font_size = px(TEXT_SIZE);
243 let band_gap = if self.label_axis {
244 let max_w = self
245 .data
246 .iter()
247 .map(|v| {
248 let s: SharedString = band_fn(v).into();
249 measure_text_width(&s, font_size, window)
250 })
251 .fold(0f32, f32::max);
252 max_w + TEXT_GAP * 2.
254 } else {
255 0.
256 };
257 let value_end_gap = if let Some(label_fn) = self.label.as_ref() {
258 let max_w = self
259 .data
260 .iter()
261 .map(|v| measure_text_width(&label_fn(v), font_size, window))
262 .fold(0f32, f32::max);
263 max_w + TEXT_GAP * 2.
264 } else {
265 TEXT_GAP * 4.
266 };
267 (band_gap, value_end_gap)
268 } else {
269 (axis_gap, 10.)
270 };
271 let (range, baseline) = match alignment {
272 BarAlignment::Bottom => {
273 let baseline = value_dim - axis_gap;
274 (vec![baseline, 10.], baseline)
275 }
276 BarAlignment::Top => {
277 let baseline = axis_gap;
278 (vec![baseline, value_dim - 10.], baseline)
279 }
280 BarAlignment::Left => {
281 let baseline = band_gap;
282 (vec![baseline, value_dim - value_end_gap], baseline)
283 }
284 BarAlignment::Right => {
285 let baseline = value_dim - band_gap;
286 (vec![baseline, value_end_gap], baseline)
287 }
288 };
289 let value_scale = ScaleLinear::new(
290 self.data
291 .iter()
292 .map(|v| value_fn(v))
293 .chain(Some(V::zero()))
294 .collect(),
295 range,
296 );
297
298 let mut axis = PlotAxis::new().stroke(cx.theme().border);
300 if self.label_axis {
301 let labels = build_band_labels(
302 &self.data,
303 band_fn.as_ref(),
304 &band_scale,
305 band_width,
306 self.tick_margin,
307 cx.theme().muted_foreground,
308 );
309 axis = match alignment {
310 BarAlignment::Bottom => axis.x(baseline).x_label(labels),
311 BarAlignment::Top => axis
312 .x(baseline)
313 .x_label_side(AxisLabelSide::Start)
314 .x_label(labels),
315 BarAlignment::Left => axis
316 .y(baseline)
317 .y_label_side(AxisLabelSide::Start)
318 .y_label(labels.into_iter().map(|t| t.align(TextAlign::Right))),
319 BarAlignment::Right => axis
320 .y(baseline)
321 .y_label(labels.into_iter().map(|t| t.align(TextAlign::Left))),
322 };
323 }
324 axis.paint(&bounds, window, cx);
325
326 let far = match alignment {
328 BarAlignment::Bottom => 10.,
329 BarAlignment::Top => value_dim - 10.,
330 BarAlignment::Left => value_dim - value_end_gap,
331 BarAlignment::Right => value_end_gap,
332 };
333
334 if self.grid {
337 let grid_steps: Vec<f32> = (0..4)
338 .map(|i| far + (baseline - far) * i as f32 / 4.0)
339 .collect();
340 let grid = Grid::new()
341 .stroke(cx.theme().border)
342 .dash_array(&[px(4.), px(2.)]);
343 let grid = if is_horizontal {
344 grid.x(grid_steps)
345 } else {
346 grid.y(grid_steps)
347 };
348 grid.paint(&bounds, window);
349 }
350
351 let band_fn_cloned = band_fn.clone();
353 let value_fn_cloned = value_fn.clone();
354 let default_fill: Background = cx.theme().chart_2.into();
355 let fill = self.fill.clone();
356 let fill_gradient = self.fill_gradient.clone();
357 let label_color = cx.theme().foreground;
358
359 let chart_bounds: Bounds<f32> = Bounds {
363 origin: Point::new(0., 0.),
364 size: Size::new(total_width, total_height),
365 };
366
367 let chart_range = {
370 let mut lo = 0.0_f32;
371 let mut hi = 0.0_f32;
372 for v in &self.data {
373 if let Some(f) = value_fn(v).to_f32() {
374 lo = lo.min(f);
375 hi = hi.max(f);
376 }
377 }
378 lo..=hi
379 };
380
381 let mut bar = Bar::new()
382 .data(&self.data)
383 .alignment(alignment)
384 .band_width(band_width)
385 .cross(move |d| band_scale.tick(&band_fn_cloned(d)))
386 .base(move |_| baseline)
387 .value(move |d| value_scale.tick(&value_fn_cloned(d)))
388 .corner_radii(self.corner_radii);
389
390 bar = match (fill, fill_gradient) {
391 (_, Some(fg)) => {
392 let value_fn_for_grad = value_fn.clone();
393 bar.fill(move |d, _frame, alignment| {
394 let v = value_fn_for_grad(d).to_f32().unwrap_or(0.);
395 let base_v = 0.0_f32;
396 let bar_lo = base_v.min(v);
397 let bar_hi = base_v.max(v);
398 let bar_span = (bar_hi - bar_lo).max(f32::EPSILON);
399 let chart_to_bar = |chart_value: f32| (chart_value - bar_lo) / bar_span;
400 let stops = fg(d, chart_range.clone(), &chart_to_bar);
401 let [s0, s1] = clip_stops_to_bar(stops);
402 let bg: Background = linear_gradient(alignment.gradient_angle(), s0, s1);
403 bg
404 })
405 }
406 (Some(f), _) => {
407 bar.fill(move |d, frame, alignment| f(d, frame, chart_bounds, alignment))
408 }
409 _ => bar.fill(move |_, _, _| default_fill),
410 };
411
412 if let Some(label) = self.label.as_ref() {
413 let label = label.clone();
414 let text_align = match alignment {
415 BarAlignment::Bottom | BarAlignment::Top => TextAlign::Center,
416 BarAlignment::Left => TextAlign::Left,
417 BarAlignment::Right => TextAlign::Right,
418 };
419 bar =
420 bar.label(move |d, p| vec![Text::new(label(d), p, label_color).align(text_align)]);
421 }
422
423 bar.paint(&bounds, window, cx);
424 }
425}
426
427fn clip_stops_to_bar(stops: [LinearColorStop; 2]) -> [LinearColorStop; 2] {
438 let [a, b] = stops;
439 let p0 = a.percentage;
440 let p1 = b.percentage;
441 let lerp = |t: f32| -> Hsla {
442 Hsla {
443 h: a.color.h + (b.color.h - a.color.h) * t,
444 s: a.color.s + (b.color.s - a.color.s) * t,
445 l: a.color.l + (b.color.l - a.color.l) * t,
446 a: a.color.a + (b.color.a - a.color.a) * t,
447 }
448 };
449 let span = p1 - p0;
450 let sample = |target: f32| -> Hsla {
451 if span.abs() < f32::EPSILON {
452 a.color
453 } else {
454 lerp((target - p0) / span)
455 }
456 };
457 let new_a = if (0. ..=1.).contains(&p0) {
458 a
459 } else {
460 LinearColorStop {
461 color: sample(p0.clamp(0., 1.)),
462 percentage: p0.clamp(0., 1.),
463 }
464 };
465 let new_b = if (0. ..=1.).contains(&p1) {
466 b
467 } else {
468 LinearColorStop {
469 color: sample(p1.clamp(0., 1.)),
470 percentage: p1.clamp(0., 1.),
471 }
472 };
473 [new_a, new_b]
474}