// gpui_component/chart/candlestick_chart.rs

1use std::rc::Rc;
2
3use gpui::{App, Bounds, Hsla, PathBuilder, Pixels, SharedString, TextAlign, Window, fill, px};
4use gpui_component_macros::IntoPlot;
5use num_traits::{Num, ToPrimitive};
6
7use crate::{
8    ActiveTheme, PixelsExt,
9    plot::{
10        AXIS_GAP, AxisText, Grid, Plot, PlotAxis, origin_point,
11        scale::{Scale, ScaleBand, ScaleLinear, Sealed},
12    },
13};
14
15#[derive(IntoPlot)]
16pub struct CandlestickChart<T, X, Y>
17where
18    T: 'static,
19    X: PartialEq + Into<SharedString> + 'static,
20    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
21{
22    data: Vec<T>,
23    x: Option<Rc<dyn Fn(&T) -> X>>,
24    open: Option<Rc<dyn Fn(&T) -> Y>>,
25    high: Option<Rc<dyn Fn(&T) -> Y>>,
26    low: Option<Rc<dyn Fn(&T) -> Y>>,
27    close: Option<Rc<dyn Fn(&T) -> Y>>,
28    tick_margin: usize,
29    body_width_ratio: f32,
30}
31
32impl<T, X, Y> CandlestickChart<T, X, Y>
33where
34    X: PartialEq + Into<SharedString> + 'static,
35    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
36{
37    pub fn new<I>(data: I) -> Self
38    where
39        I: IntoIterator<Item = T>,
40    {
41        Self {
42            data: data.into_iter().collect(),
43            x: None,
44            open: None,
45            high: None,
46            low: None,
47            close: None,
48            tick_margin: 1,
49            body_width_ratio: 0.8,
50        }
51    }
52
53    pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
54        self.x = Some(Rc::new(x));
55        self
56    }
57
58    pub fn open(mut self, open: impl Fn(&T) -> Y + 'static) -> Self {
59        self.open = Some(Rc::new(open));
60        self
61    }
62
63    pub fn high(mut self, high: impl Fn(&T) -> Y + 'static) -> Self {
64        self.high = Some(Rc::new(high));
65        self
66    }
67
68    pub fn low(mut self, low: impl Fn(&T) -> Y + 'static) -> Self {
69        self.low = Some(Rc::new(low));
70        self
71    }
72
73    pub fn close(mut self, close: impl Fn(&T) -> Y + 'static) -> Self {
74        self.close = Some(Rc::new(close));
75        self
76    }
77
78    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
79        self.tick_margin = tick_margin;
80        self
81    }
82
83    pub fn body_width_ratio(mut self, ratio: f32) -> Self {
84        self.body_width_ratio = ratio;
85        self
86    }
87}
88
89impl<T, X, Y> Plot for CandlestickChart<T, X, Y>
90where
91    X: PartialEq + Into<SharedString> + 'static,
92    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
93{
94    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
95        let (Some(x_fn), Some(open_fn), Some(high_fn), Some(low_fn), Some(close_fn)) = (
96            self.x.as_ref(),
97            self.open.as_ref(),
98            self.high.as_ref(),
99            self.low.as_ref(),
100            self.close.as_ref(),
101        ) else {
102            return;
103        };
104
105        let width = bounds.size.width.as_f32();
106        let height = bounds.size.height.as_f32() - AXIS_GAP;
107
108        // X scale
109        let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
110            .padding_inner(0.4)
111            .padding_outer(0.2);
112        let band_width = x.band_width();
113
114        // Y scale
115        let all_values: Vec<Y> = self
116            .data
117            .iter()
118            .flat_map(|d| vec![high_fn(d), low_fn(d), open_fn(d), close_fn(d)])
119            .collect();
120        let y = ScaleLinear::new(all_values, vec![height, 10.]);
121
122        // Draw X axis
123        let x_label = self.data.iter().enumerate().filter_map(|(i, d)| {
124            if (i + 1) % self.tick_margin == 0 {
125                x.tick(&x_fn(d)).map(|x_tick| {
126                    AxisText::new(
127                        x_fn(d).into(),
128                        x_tick + band_width / 2.,
129                        cx.theme().muted_foreground,
130                    )
131                    .align(TextAlign::Center)
132                })
133            } else {
134                None
135            }
136        });
137
138        PlotAxis::new()
139            .x(height)
140            .x_label(x_label)
141            .stroke(cx.theme().border)
142            .paint(&bounds, window, cx);
143
144        // Draw grid
145        Grid::new()
146            .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
147            .stroke(cx.theme().border)
148            .dash_array(&[px(4.), px(2.)])
149            .paint(&bounds, window);
150
151        // Draw candlesticks
152        let origin = bounds.origin;
153        let x_fn = x_fn.clone();
154        let open_fn = open_fn.clone();
155        let high_fn = high_fn.clone();
156        let low_fn = low_fn.clone();
157        let close_fn = close_fn.clone();
158
159        for d in &self.data {
160            let x_tick = x.tick(&x_fn(d));
161            let Some(x_tick) = x_tick else {
162                continue;
163            };
164
165            // Get OHLC values for the current data point
166            let open = open_fn(d);
167            let high = high_fn(d);
168            let low = low_fn(d);
169            let close = close_fn(d);
170
171            // Convert values to pixel coordinates
172            let open_y = y.tick(&open);
173            let high_y = y.tick(&high);
174            let low_y = y.tick(&low);
175            let close_y = y.tick(&close);
176
177            let (Some(open_y), Some(high_y), Some(low_y), Some(close_y)) =
178                (open_y, high_y, low_y, close_y)
179            else {
180                continue;
181            };
182
183            // Determine if bullish (close > open) or bearish (close < open)
184            let is_bullish = close > open;
185            let color: Hsla = if is_bullish {
186                cx.theme().bullish
187            } else {
188                cx.theme().bearish
189            };
190
191            // Calculate candlestick body dimensions
192            let center_x = x_tick + band_width / 2.;
193            let body_width = band_width * self.body_width_ratio;
194            let body_left = center_x - body_width / 2.;
195            let body_right = center_x + body_width / 2.;
196
197            // Draw wick (high to low line)
198            let mut wick_builder = PathBuilder::stroke(px(1.));
199            wick_builder.move_to(origin_point(px(center_x), px(high_y), origin));
200            wick_builder.line_to(origin_point(px(center_x), px(low_y), origin));
201
202            if let Ok(path) = wick_builder.build() {
203                window.paint_path(path, color);
204            }
205
206            // Draw body (open to close rectangle)
207            // For bullish: top is close, bottom is open
208            // For bearish: top is open, bottom is close
209            let (top, bottom) = if is_bullish {
210                (close_y, open_y)
211            } else {
212                (open_y, close_y)
213            };
214
215            let body_bounds = Bounds::from_corners(
216                origin_point(px(body_left), px(top), origin),
217                origin_point(px(body_right), px(bottom), origin),
218            );
219
220            window.paint_quad(fill(body_bounds, color));
221        }
222    }
223}