1use std::rc::Rc;
2
3use num_traits::{Num, ToPrimitive};
4use rgpui::{App, Bounds, Hsla, PathBuilder, Pixels, SharedString, Window, fill, px};
5use rgpui_component_macros::IntoPlot;
6
7use crate::{
8 ActiveTheme,
9 plot::{
10 AXIS_GAP, Grid, Plot, PlotAxis, origin_point,
11 scale::{Scale, ScaleBand, ScaleLinear, Sealed},
12 },
13};
14
15use super::build_band_labels;
16
/// A candlestick (OHLC) chart.
///
/// Each item of `data` is drawn as one candle. The candle has a vertical wick
/// from its `high` value to its `low` value and a filled body between its
/// `open` and `close` values. The candle uses the theme's bullish color when
/// `close > open` and its bearish color otherwise. All five accessors must be
/// set for anything to be painted.
#[derive(IntoPlot)]
pub struct CandlestickChart<T, X, Y>
where
    T: 'static,
    X: PartialEq + Into<SharedString> + 'static,
    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
{
    /// The items to plot, one candle per item.
    data: Vec<T>,
    /// Maps an item to its band (category) on the x axis.
    x: Option<Rc<dyn Fn(&T) -> X>>,
    /// Maps an item to its opening value.
    open: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Maps an item to its highest value (top of the wick).
    high: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Maps an item to its lowest value (bottom of the wick).
    low: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Maps an item to its closing value.
    close: Option<Rc<dyn Fn(&T) -> Y>>,
    /// Forwarded to `build_band_labels` when building the x-axis labels;
    /// presumably controls how many bands are skipped between labels (TODO confirm).
    tick_margin: usize,
    /// Width of each candle body as a fraction of the band width.
    body_width_ratio: f32,
    /// Whether to reserve space for and draw the x axis with its labels.
    x_axis: bool,
    /// Whether to draw the dashed horizontal grid lines.
    grid: bool,
}
35
36impl<T, X, Y> CandlestickChart<T, X, Y>
37where
38 X: PartialEq + Into<SharedString> + 'static,
39 Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
40{
41 pub fn new<I>(data: I) -> Self
42 where
43 I: IntoIterator<Item = T>,
44 {
45 Self {
46 data: data.into_iter().collect(),
47 x: None,
48 open: None,
49 high: None,
50 low: None,
51 close: None,
52 tick_margin: 1,
53 body_width_ratio: 0.8,
54 x_axis: true,
55 grid: true,
56 }
57 }
58
59 pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
60 self.x = Some(Rc::new(x));
61 self
62 }
63
64 pub fn open(mut self, open: impl Fn(&T) -> Y + 'static) -> Self {
65 self.open = Some(Rc::new(open));
66 self
67 }
68
69 pub fn high(mut self, high: impl Fn(&T) -> Y + 'static) -> Self {
70 self.high = Some(Rc::new(high));
71 self
72 }
73
74 pub fn low(mut self, low: impl Fn(&T) -> Y + 'static) -> Self {
75 self.low = Some(Rc::new(low));
76 self
77 }
78
79 pub fn close(mut self, close: impl Fn(&T) -> Y + 'static) -> Self {
80 self.close = Some(Rc::new(close));
81 self
82 }
83
84 pub fn tick_margin(mut self, tick_margin: usize) -> Self {
85 self.tick_margin = tick_margin;
86 self
87 }
88
89 pub fn body_width_ratio(mut self, ratio: f32) -> Self {
90 self.body_width_ratio = ratio;
91 self
92 }
93
94 pub fn x_axis(mut self, x_axis: bool) -> Self {
98 self.x_axis = x_axis;
99 self
100 }
101
102 pub fn grid(mut self, grid: bool) -> Self {
103 self.grid = grid;
104 self
105 }
106}
107
108impl<T, X, Y> Plot for CandlestickChart<T, X, Y>
109where
110 X: PartialEq + Into<SharedString> + 'static,
111 Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
112{
113 fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
114 let (Some(x_fn), Some(open_fn), Some(high_fn), Some(low_fn), Some(close_fn)) = (
115 self.x.as_ref(),
116 self.open.as_ref(),
117 self.high.as_ref(),
118 self.low.as_ref(),
119 self.close.as_ref(),
120 ) else {
121 return;
122 };
123
124 let width = bounds.size.width.as_f32();
125 let axis_gap = if self.x_axis { AXIS_GAP } else { 0. };
126 let height = bounds.size.height.as_f32() - axis_gap;
127
128 let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
130 .padding_inner(0.4)
131 .padding_outer(0.2);
132 let band_width = x.band_width();
133
134 let all_values: Vec<Y> = self
136 .data
137 .iter()
138 .flat_map(|d| vec![high_fn(d), low_fn(d), open_fn(d), close_fn(d)])
139 .collect();
140 let y = ScaleLinear::new(all_values, vec![height, 10.]);
141
142 let mut axis = PlotAxis::new().stroke(cx.theme().border);
144 if self.x_axis {
145 let labels = build_band_labels(
146 &self.data,
147 x_fn.as_ref(),
148 &x,
149 band_width,
150 self.tick_margin,
151 cx.theme().muted_foreground,
152 );
153 axis = axis.x(height).x_label(labels);
154 }
155 axis.paint(&bounds, window, cx);
156
157 if self.grid {
159 Grid::new()
160 .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
161 .stroke(cx.theme().border)
162 .dash_array(&[px(4.), px(2.)])
163 .paint(&bounds, window);
164 }
165
166 let origin = bounds.origin;
168 let x_fn = x_fn.clone();
169 let open_fn = open_fn.clone();
170 let high_fn = high_fn.clone();
171 let low_fn = low_fn.clone();
172 let close_fn = close_fn.clone();
173
174 for d in &self.data {
175 let x_tick = x.tick(&x_fn(d));
176 let Some(x_tick) = x_tick else {
177 continue;
178 };
179
180 let open = open_fn(d);
182 let high = high_fn(d);
183 let low = low_fn(d);
184 let close = close_fn(d);
185
186 let open_y = y.tick(&open);
188 let high_y = y.tick(&high);
189 let low_y = y.tick(&low);
190 let close_y = y.tick(&close);
191
192 let (Some(open_y), Some(high_y), Some(low_y), Some(close_y)) =
193 (open_y, high_y, low_y, close_y)
194 else {
195 continue;
196 };
197
198 let is_bullish = close > open;
200 let color: Hsla = if is_bullish {
201 cx.theme().chart_bullish
202 } else {
203 cx.theme().chart_bearish
204 };
205
206 let center_x = x_tick + band_width / 2.;
208 let body_width = band_width * self.body_width_ratio;
209 let body_left = center_x - body_width / 2.;
210 let body_right = center_x + body_width / 2.;
211
212 let mut wick_builder = PathBuilder::stroke(px(1.));
214 wick_builder.move_to(origin_point(px(center_x), px(high_y), origin));
215 wick_builder.line_to(origin_point(px(center_x), px(low_y), origin));
216
217 if let Ok(path) = wick_builder.build() {
218 window.paint_path(path, color);
219 }
220
221 let (top, bottom) = if is_bullish {
225 (close_y, open_y)
226 } else {
227 (open_y, close_y)
228 };
229
230 let body_bounds = Bounds::from_corners(
231 origin_point(px(body_left), px(top), origin),
232 origin_point(px(body_right), px(bottom), origin),
233 );
234
235 window.paint_quad(fill(body_bounds, color));
236 }
237 }
238}