1use std::rc::Rc;
2
3use gpui::{App, Bounds, Hsla, PathBuilder, Pixels, SharedString, TextAlign, Window, fill, px};
4use gpui_component_macros::IntoPlot;
5use num_traits::{Num, ToPrimitive};
6
7use crate::{
8 ActiveTheme, PixelsExt,
9 plot::{
10 AXIS_GAP, AxisText, Grid, Plot, PlotAxis, origin_point,
11 scale::{Scale, ScaleBand, ScaleLinear, Sealed},
12 },
13};
14
/// A candlestick (open/high/low/close) chart.
///
/// Each datum `T` is mapped through user-supplied accessor closures to an
/// x category (`X`) and four price values (`Y`). Build one with
/// [`CandlestickChart::new`] and the builder methods. All five accessors must be
/// set before anything is drawn.
#[derive(IntoPlot)]
pub struct CandlestickChart<T, X, Y>
where
    T: 'static,
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    /// The rows to plot, one candle per element.
    data: Vec<T>,
    /// Extracts the band key for a datum. It is also converted into the x-axis label text.
    x: Option<Rc<dyn Fn(&T) -> X>>,
    /// Extracts the opening price.
    open: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Extracts the highest price (top of the wick).
    high: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Extracts the lowest price (bottom of the wick).
    low: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Extracts the closing price.
    close: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Only every `tick_margin`-th band gets an x-axis label.
    tick_margin: usize,
    /// Fraction of each band's width used for the candle body.
    body_width_ratio: f32,
}
31
32impl<T, X, Y> CandlestickChart<T, X, Y>
33where
34 X: PartialEq + Into<SharedString> + 'static,
35 Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
36{
37 pub fn new<I>(data: I) -> Self
38 where
39 I: IntoIterator<Item = T>,
40 {
41 Self {
42 data: data.into_iter().collect(),
43 x: None,
44 open: None,
45 high: None,
46 low: None,
47 close: None,
48 tick_margin: 1,
49 body_width_ratio: 0.8,
50 }
51 }
52
53 pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
54 self.x = Some(Rc::new(x));
55 self
56 }
57
58 pub fn open(mut self, open: impl Fn(&T) -> Y + 'static) -> Self {
59 self.open = Some(Rc::new(open));
60 self
61 }
62
63 pub fn high(mut self, high: impl Fn(&T) -> Y + 'static) -> Self {
64 self.high = Some(Rc::new(high));
65 self
66 }
67
68 pub fn low(mut self, low: impl Fn(&T) -> Y + 'static) -> Self {
69 self.low = Some(Rc::new(low));
70 self
71 }
72
73 pub fn close(mut self, close: impl Fn(&T) -> Y + 'static) -> Self {
74 self.close = Some(Rc::new(close));
75 self
76 }
77
78 pub fn tick_margin(mut self, tick_margin: usize) -> Self {
79 self.tick_margin = tick_margin;
80 self
81 }
82
83 pub fn body_width_ratio(mut self, ratio: f32) -> Self {
84 self.body_width_ratio = ratio;
85 self
86 }
87}
88
89impl<T, X, Y> Plot for CandlestickChart<T, X, Y>
90where
91 X: PartialEq + Into<SharedString> + 'static,
92 Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
93{
94 fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
95 let (Some(x_fn), Some(open_fn), Some(high_fn), Some(low_fn), Some(close_fn)) = (
96 self.x.as_ref(),
97 self.open.as_ref(),
98 self.high.as_ref(),
99 self.low.as_ref(),
100 self.close.as_ref(),
101 ) else {
102 return;
103 };
104
105 let width = bounds.size.width.as_f32();
106 let height = bounds.size.height.as_f32() - AXIS_GAP;
107
108 let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
110 .padding_inner(0.4)
111 .padding_outer(0.2);
112 let band_width = x.band_width();
113
114 let all_values: Vec<Y> = self
116 .data
117 .iter()
118 .flat_map(|d| vec![high_fn(d), low_fn(d), open_fn(d), close_fn(d)])
119 .collect();
120 let y = ScaleLinear::new(all_values, vec![height, 10.]);
121
122 let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
124 if (i + 1) % self.tick_margin == 0 {
125 x.tick(&x_fn(d)).map(|x_tick| {
126 AxisText::new(
127 x_fn(d).into(),
128 x_tick + band_width / 2.,
129 cx.theme().muted_foreground,
130 )
131 .align(TextAlign::Center)
132 })
133 } else {
134 None
135 }
136 });
137
138 PlotAxis::new()
139 .x(height)
140 .x_label(x_label)
141 .stroke(cx.theme().border)
142 .paint(&bounds, window, cx);
143
144 Grid::new()
146 .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
147 .stroke(cx.theme().border)
148 .dash_array(&[px(4.), px(2.)])
149 .paint(&bounds, window);
150
151 let origin = bounds.origin;
153 let x_fn = x_fn.clone();
154 let open_fn = open_fn.clone();
155 let high_fn = high_fn.clone();
156 let low_fn = low_fn.clone();
157 let close_fn = close_fn.clone();
158
159 for d in &self.data {
160 let x_tick = x.tick(&x_fn(d));
161 let Some(x_tick) = x_tick else {
162 continue;
163 };
164
165 let open = open_fn(d);
167 let high = high_fn(d);
168 let low = low_fn(d);
169 let close = close_fn(d);
170
171 let open_y = y.tick(&open);
173 let high_y = y.tick(&high);
174 let low_y = y.tick(&low);
175 let close_y = y.tick(&close);
176
177 let (Some(open_y), Some(high_y), Some(low_y), Some(close_y)) =
178 (open_y, high_y, low_y, close_y)
179 else {
180 continue;
181 };
182
183 let is_bullish = close > open;
185 let color: Hsla = if is_bullish {
186 cx.theme().bullish
187 } else {
188 cx.theme().bearish
189 };
190
191 let center_x = x_tick + band_width / 2.;
193 let body_width = band_width * self.body_width_ratio;
194 let body_left = center_x - body_width / 2.;
195 let body_right = center_x + body_width / 2.;
196
197 let mut wick_builder = PathBuilder::stroke(px(1.));
199 wick_builder.move_to(origin_point(px(center_x), px(high_y), origin));
200 wick_builder.line_to(origin_point(px(center_x), px(low_y), origin));
201
202 if let Ok(path) = wick_builder.build() {
203 window.paint_path(path, color);
204 }
205
206 let (top, bottom) = if is_bullish {
210 (close_y, open_y)
211 } else {
212 (open_y, close_y)
213 };
214
215 let body_bounds = Bounds::from_corners(
216 origin_point(px(body_left), px(top), origin),
217 origin_point(px(body_right), px(bottom), origin),
218 );
219
220 window.paint_quad(fill(body_bounds, color));
221 }
222 }
223}