// woocraft 0.4.5 — documentation
//
// GPUI component library for the Woocraft design system.
use std::rc::Rc;

use gpui::{
  App, Bounds, IntoElement, Pixels, RenderOnce, SharedString, Styled, TextAlign, Window, canvas,
  fill, px,
};
use num_traits::{Num, ToPrimitive};

use crate::{
  AXIS_GAP, ActiveTheme, AxisText, Grid, PixelsExt, Plot, PlotAxis, origin_point,
  scale::{Scale, ScaleBand, ScaleLinear, Sealed},
};

/// Reference-counted closure that reads the x-axis (category) value `X` out of a datum `T`.
type XAccessor<T, X> = Rc<dyn Fn(&T) -> X>;
/// Reference-counted closure that reads one numeric value `Y` (open, high, low or close) out of
/// a datum `T`.
type YAccessor<T, Y> = Rc<dyn Fn(&T) -> Y>;

/// Candlestick (open/high/low/close) chart element.
///
/// `T` is the datum type, `X` the category shown on the x-axis (convertible into a
/// `SharedString` for its label) and `Y` the numeric type of the four price values.
///
/// The chart is configured through builder methods. All five accessors (`x`, `open`, `high`,
/// `low`, `close`) must be set; painting is skipped entirely while any of them is missing.
#[derive(IntoElement)]
pub struct CandlestickChart<T, X, Y>
where
  T: 'static,
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static, {
  // Data points, one candle each, kept in the order they were supplied.
  data: Vec<T>,
  // Accessor for the x-axis category of a datum; `None` until `x()` is called.
  x: Option<XAccessor<T, X>>,
  // Accessor for the opening value; `None` until `open()` is called.
  open: Option<YAccessor<T, Y>>,
  // Accessor for the highest value; `None` until `high()` is called.
  high: Option<YAccessor<T, Y>>,
  // Accessor for the lowest value; `None` until `low()` is called.
  low: Option<YAccessor<T, Y>>,
  // Accessor for the closing value; `None` until `close()` is called.
  close: Option<YAccessor<T, Y>>,
  // Label every `tick_margin`-th datum on the x-axis (defaults to 1, never 0).
  tick_margin: usize,
  // Candle body width as a fraction of the band width (defaults to 0.8).
  body_width_ratio: f32,
}

impl<T, X, Y> CandlestickChart<T, X, Y>
where
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
  /// Creates a chart over `data` with no accessors configured.
  ///
  /// Defaults: every datum gets an x-axis label (`tick_margin` of 1) and candle bodies take
  /// 80% of the band width. The chart paints nothing until `x`, `open`, `high`, `low` and
  /// `close` have all been set.
  pub fn new<I>(data: I) -> Self
  where
    I: IntoIterator<Item = T>, {
    Self {
      data: data.into_iter().collect(),
      x: None,
      open: None,
      high: None,
      low: None,
      close: None,
      tick_margin: 1,
      body_width_ratio: 0.8,
    }
  }

  /// Sets the accessor that yields the x-axis category of a datum.
  pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
    self.x = Some(Rc::new(x));
    self
  }

  /// Sets the accessor that yields the opening value of a datum.
  pub fn open(mut self, open: impl Fn(&T) -> Y + 'static) -> Self {
    self.open = Some(Rc::new(open));
    self
  }

  /// Sets the accessor that yields the highest value of a datum.
  pub fn high(mut self, high: impl Fn(&T) -> Y + 'static) -> Self {
    self.high = Some(Rc::new(high));
    self
  }

  /// Sets the accessor that yields the lowest value of a datum.
  pub fn low(mut self, low: impl Fn(&T) -> Y + 'static) -> Self {
    self.low = Some(Rc::new(low));
    self
  }

  /// Sets the accessor that yields the closing value of a datum.
  pub fn close(mut self, close: impl Fn(&T) -> Y + 'static) -> Self {
    self.close = Some(Rc::new(close));
    self
  }

  /// Labels only every `tick_margin`-th datum on the x-axis.
  ///
  /// A value of 0 is raised to 1 so the modulo used while painting can never divide by zero.
  pub fn tick_margin(mut self, tick_margin: usize) -> Self {
    self.tick_margin = tick_margin.max(1);
    self
  }

  /// Sets the candle body width as a fraction of the band width, clamped to `0.1..=1.0`.
  ///
  /// A NaN `ratio` is ignored and the previously configured ratio is kept.
  pub fn body_width_ratio(mut self, ratio: f32) -> Self {
    // `f32::clamp` returns NaN for a NaN input, which would turn every body rectangle
    // into NaN geometry at paint time; keep the current value instead.
    if !ratio.is_nan() {
      self.body_width_ratio = ratio.clamp(0.1, 1.0);
    }
    self
  }
}

impl<T, X, Y> Plot for CandlestickChart<T, X, Y>
where
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
  fn paint(&self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
    let (Some(x_fn), Some(open_fn), Some(high_fn), Some(low_fn), Some(close_fn)) = (
      self.x.as_ref(),
      self.open.as_ref(),
      self.high.as_ref(),
      self.low.as_ref(),
      self.close.as_ref(),
    ) else {
      return;
    };

    let width = bounds.size.width.as_f32();
    let height = bounds.size.height.as_f32() - AXIS_GAP;

    let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
      .padding_inner(0.4)
      .padding_outer(0.2);
    let band_width = x.band_width();

    let all_values: Vec<Y> = self
      .data
      .iter()
      .flat_map(|d| [high_fn(d), low_fn(d), open_fn(d), close_fn(d)])
      .collect();
    let y = ScaleLinear::new(all_values, vec![height, 10.]);

    let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
      if (i + 1) % self.tick_margin == 0 {
        x.tick(&x_fn(d)).map(|x_tick| {
          AxisText::new(
            x_fn(d).into(),
            x_tick + band_width / 2.,
            cx.theme().muted_foreground,
          )
          .align(TextAlign::Center)
        })
      } else {
        None
      }
    });

    PlotAxis::new()
      .x(height)
      .x_label(x_label)
      .stroke(cx.theme().border)
      .paint(&bounds, window, cx);

    Grid::new()
      .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
      .stroke(cx.theme().border)
      .dash_array(&[px(4.), px(2.)])
      .paint(&bounds, window);

    let origin = bounds.origin;

    for d in &self.data {
      let Some(x_tick) = x.tick(&x_fn(d)) else {
        continue;
      };

      let open = open_fn(d);
      let high = high_fn(d);
      let low = low_fn(d);
      let close = close_fn(d);

      let (Some(open_y), Some(high_y), Some(low_y), Some(close_y)) =
        (y.tick(&open), y.tick(&high), y.tick(&low), y.tick(&close))
      else {
        continue;
      };

      let is_bullish = close >= open;
      let color = if is_bullish {
        cx.theme().success
      } else {
        cx.theme().danger
      };

      let center_x = x_tick + band_width / 2.;
      let body_width = band_width * self.body_width_ratio;
      let body_left = center_x - body_width / 2.;
      let body_right = center_x + body_width / 2.;

      let mut wick_builder = gpui::PathBuilder::stroke(px(1.));
      wick_builder.move_to(origin_point(px(center_x), px(high_y), origin));
      wick_builder.line_to(origin_point(px(center_x), px(low_y), origin));

      if let Ok(path) = wick_builder.build() {
        window.paint_path(path, color);
      }

      let (top, bottom) = if is_bullish {
        (close_y, open_y)
      } else {
        (open_y, close_y)
      };

      let body_bounds = Bounds::from_corners(
        origin_point(px(body_left), px(top), origin),
        origin_point(px(body_right), px(bottom), origin),
      );

      window.paint_quad(fill(body_bounds, color));
    }
  }
}

impl<T, X, Y> RenderOnce for CandlestickChart<T, X, Y>
where
  T: 'static,
  X: PartialEq + Into<SharedString> + 'static,
  Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
  fn render(self, _: &mut Window, _: &mut App) -> impl IntoElement {
    canvas(
      move |_, _, _| {},
      move |bounds, _, window, cx| self.paint(bounds, window, cx),
    )
    .size_full()
  }
}