// woocraft 0.4.5
//
// GPUI components lib for Woocraft design system.
// Documentation
use std::rc::Rc;

use gpui::{
  App, Bounds, IntoElement, Pixels, RenderOnce, SharedString, Styled, TextAlign, Window, canvas, px,
};
use num_traits::{Num, ToPrimitive};

use crate::{
  AXIS_GAP, ActiveTheme, AxisText, Grid, PixelsExt, Plot, PlotAxis,
  label::Text,
  scale::{Scale, ScaleBand, ScaleLinear, Sealed},
  shape::Bar,
};

// Shared, reference-counted accessor closures. `BarChart` stores them behind
// `Rc` so `paint` can clone them into the `move` closures it hands to `Bar`.

/// Reads the x value (of type `X`) from a data item.
type XAccessor<T, X> = Rc<dyn Fn(&T) -> X>;
/// Reads the y value (of type `Y`) from a data item.
type YAccessor<T, Y> = Rc<dyn Fn(&T) -> Y>;
/// Picks the fill color for a data item.
type FillAccessor<T> = Rc<dyn Fn(&T) -> gpui::Hsla>;
/// Produces the label text for a data item.
type LabelAccessor<T> = Rc<dyn Fn(&T) -> SharedString>;

/// Bar chart element configured from a list of data items and accessor
/// closures.
///
/// * `T` is the data item type.
/// * `X` is the value placed on the x axis; it is converted into a
///   `SharedString` for the axis labels.
/// * `Y` is the numeric value placed on the y axis.
///
/// Nothing is painted until both the `x` and `y` accessors have been set.
#[derive(IntoElement)]
pub struct BarChart<T, X, Y>
where
  T: 'static,
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static, {
  /// Items to plot, in the order they were supplied.
  data: Vec<T>,
  /// Accessor for the x value of an item; `None` until `x()` is called.
  x: Option<XAccessor<T, X>>,
  /// Accessor for the y value of an item; `None` until `y()` is called.
  y: Option<YAccessor<T, Y>>,
  /// Optional per-item fill color; when absent `paint` uses the theme's
  /// `primary` color.
  fill: Option<FillAccessor<T>>,
  /// Only every `tick_margin`-th item gets an x-axis label. Kept at 1 or more
  /// by the constructor and the `tick_margin()` setter.
  tick_margin: usize,
  /// Optional per-item label text forwarded to the bars.
  label: Option<LabelAccessor<T>>,
}

impl<T, X, Y> BarChart<T, X, Y>
where
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
  pub fn new<I>(data: I) -> Self
  where
    I: IntoIterator<Item = T>, {
    Self {
      data: data.into_iter().collect(),
      x: None,
      y: None,
      fill: None,
      tick_margin: 1,
      label: None,
    }
  }

  pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
    self.x = Some(Rc::new(x));
    self
  }

  pub fn y(mut self, y: impl Fn(&T) -> Y + 'static) -> Self {
    self.y = Some(Rc::new(y));
    self
  }

  pub fn fill<H>(mut self, fill: impl Fn(&T) -> H + 'static) -> Self
  where
    H: Into<gpui::Hsla> + 'static, {
    self.fill = Some(Rc::new(move |t| fill(t).into()));
    self
  }

  pub fn tick_margin(mut self, tick_margin: usize) -> Self {
    self.tick_margin = tick_margin.max(1);
    self
  }

  pub fn label<S>(mut self, label: impl Fn(&T) -> S + 'static) -> Self
  where
    S: Into<SharedString> + 'static, {
    self.label = Some(Rc::new(move |t| label(t).into()));
    self
  }
}

impl<T, X, Y> Plot for BarChart<T, X, Y>
where
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
  fn paint(&self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
    let (Some(x_fn), Some(y_fn)) = (self.x.as_ref(), self.y.as_ref()) else {
      return;
    };

    let width = bounds.size.width.as_f32();
    let height = bounds.size.height.as_f32() - AXIS_GAP;

    let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
      .padding_inner(0.4)
      .padding_outer(0.2);
    let band_width = x.band_width();

    let y = ScaleLinear::new(
      self
        .data
        .iter()
        .map(|v| y_fn(v))
        .chain(Some(Y::zero()))
        .collect(),
      vec![height, 10.],
    );

    let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
      if (i + 1) % self.tick_margin == 0 {
        x.tick(&x_fn(d)).map(|x_tick| {
          AxisText::new(
            x_fn(d).into(),
            x_tick + band_width / 2.,
            cx.theme().muted_foreground,
          )
          .align(TextAlign::Center)
        })
      } else {
        None
      }
    });

    PlotAxis::new()
      .x(height)
      .x_label(x_label)
      .stroke(cx.theme().border)
      .paint(&bounds, window, cx);

    Grid::new()
      .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
      .stroke(cx.theme().border)
      .dash_array(&[px(4.), px(2.)])
      .paint(&bounds, window);

    let x_fn = x_fn.clone();
    let y_fn = y_fn.clone();
    let default_fill = cx.theme().primary;
    let fill = self.fill.clone();
    let label_color = cx.theme().foreground;
    let mut bar = Bar::new()
      .data(&self.data)
      .band_width(band_width)
      .x(move |d| x.tick(&x_fn(d)))
      .y0(move |_| height)
      .y1(move |d| y.tick(&y_fn(d)))
      .fill(move |d| fill.as_ref().map(|f| f(d)).unwrap_or(default_fill));

    if let Some(label) = self.label.as_ref() {
      let label = label.clone();
      bar =
        bar.label(move |d, p| vec![Text::new(label(d), p, label_color).align(TextAlign::Center)]);
    }

    bar.paint(&bounds, window, cx);
  }
}

impl<T, X, Y> RenderOnce for BarChart<T, X, Y>
where
  T: 'static,
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
  /// Renders the chart as a full-size `canvas` whose paint callback delegates
  /// to `Plot::paint`; the prepaint callback does no work.
  fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
    let chart = self;
    canvas(
      move |_bounds, _window, _cx| (),
      move |bounds, _prepaint, window, cx| chart.paint(bounds, window, cx),
    )
    .size_full()
  }
}