Skip to main content

rgpui_component/chart/
candlestick_chart.rs

1use std::rc::Rc;
2
3use num_traits::{Num, ToPrimitive};
4use rgpui::{App, Bounds, Hsla, PathBuilder, Pixels, SharedString, Window, fill, px};
5use rgpui_component_macros::IntoPlot;
6
7use crate::{
8    ActiveTheme,
9    plot::{
10        AXIS_GAP, Grid, Plot, PlotAxis, origin_point,
11        scale::{Scale, ScaleBand, ScaleLinear, Sealed},
12    },
13};
14
15use super::build_band_labels;
16
17#[derive(IntoPlot)]
18pub struct CandlestickChart<T, X, Y>
19where
20    T: 'static,
21    X: PartialEq + Into<SharedString> + 'static,
22    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
23{
24    data: Vec<T>,
25    x: Option<Rc<dyn Fn(&T) -> X>>,
26    open: Option<Rc<dyn Fn(&T) -> Y>>,
27    high: Option<Rc<dyn Fn(&T) -> Y>>,
28    low: Option<Rc<dyn Fn(&T) -> Y>>,
29    close: Option<Rc<dyn Fn(&T) -> Y>>,
30    tick_margin: usize,
31    body_width_ratio: f32,
32    x_axis: bool,
33    grid: bool,
34}
35
36impl<T, X, Y> CandlestickChart<T, X, Y>
37where
38    X: PartialEq + Into<SharedString> + 'static,
39    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
40{
41    pub fn new<I>(data: I) -> Self
42    where
43        I: IntoIterator<Item = T>,
44    {
45        Self {
46            data: data.into_iter().collect(),
47            x: None,
48            open: None,
49            high: None,
50            low: None,
51            close: None,
52            tick_margin: 1,
53            body_width_ratio: 0.8,
54            x_axis: true,
55            grid: true,
56        }
57    }
58
59    pub fn x(mut self, x: impl Fn(&T) -> X + 'static) -> Self {
60        self.x = Some(Rc::new(x));
61        self
62    }
63
64    pub fn open(mut self, open: impl Fn(&T) -> Y + 'static) -> Self {
65        self.open = Some(Rc::new(open));
66        self
67    }
68
69    pub fn high(mut self, high: impl Fn(&T) -> Y + 'static) -> Self {
70        self.high = Some(Rc::new(high));
71        self
72    }
73
74    pub fn low(mut self, low: impl Fn(&T) -> Y + 'static) -> Self {
75        self.low = Some(Rc::new(low));
76        self
77    }
78
79    pub fn close(mut self, close: impl Fn(&T) -> Y + 'static) -> Self {
80        self.close = Some(Rc::new(close));
81        self
82    }
83
84    pub fn tick_margin(mut self, tick_margin: usize) -> Self {
85        self.tick_margin = tick_margin;
86        self
87    }
88
89    pub fn body_width_ratio(mut self, ratio: f32) -> Self {
90        self.body_width_ratio = ratio;
91        self
92    }
93
94    /// Show or hide the x-axis line and labels.
95    ///
96    /// Default is true.
97    pub fn x_axis(mut self, x_axis: bool) -> Self {
98        self.x_axis = x_axis;
99        self
100    }
101
102    pub fn grid(mut self, grid: bool) -> Self {
103        self.grid = grid;
104        self
105    }
106}
107
108impl<T, X, Y> Plot for CandlestickChart<T, X, Y>
109where
110    X: PartialEq + Into<SharedString> + 'static,
111    Y: Copy + PartialOrd + Num + ToPrimitive + Sealed + 'static,
112{
113    fn paint(&mut self, bounds: Bounds<Pixels>, window: &mut Window, cx: &mut App) {
114        let (Some(x_fn), Some(open_fn), Some(high_fn), Some(low_fn), Some(close_fn)) = (
115            self.x.as_ref(),
116            self.open.as_ref(),
117            self.high.as_ref(),
118            self.low.as_ref(),
119            self.close.as_ref(),
120        ) else {
121            return;
122        };
123
124        let width = bounds.size.width.as_f32();
125        let axis_gap = if self.x_axis { AXIS_GAP } else { 0. };
126        let height = bounds.size.height.as_f32() - axis_gap;
127
128        // X scale
129        let x = ScaleBand::new(self.data.iter().map(|v| x_fn(v)).collect(), vec![0., width])
130            .padding_inner(0.4)
131            .padding_outer(0.2);
132        let band_width = x.band_width();
133
134        // Y scale
135        let all_values: Vec<Y> = self
136            .data
137            .iter()
138            .flat_map(|d| vec![high_fn(d), low_fn(d), open_fn(d), close_fn(d)])
139            .collect();
140        let y = ScaleLinear::new(all_values, vec![height, 10.]);
141
142        // Draw X axis
143        let mut axis = PlotAxis::new().stroke(cx.theme().border);
144        if self.x_axis {
145            let labels = build_band_labels(
146                &self.data,
147                x_fn.as_ref(),
148                &x,
149                band_width,
150                self.tick_margin,
151                cx.theme().muted_foreground,
152            );
153            axis = axis.x(height).x_label(labels);
154        }
155        axis.paint(&bounds, window, cx);
156
157        // Draw grid
158        if self.grid {
159            Grid::new()
160                .y((0..=3).map(|i| height * i as f32 / 4.0).collect())
161                .stroke(cx.theme().border)
162                .dash_array(&[px(4.), px(2.)])
163                .paint(&bounds, window);
164        }
165
166        // Draw candlesticks
167        let origin = bounds.origin;
168        let x_fn = x_fn.clone();
169        let open_fn = open_fn.clone();
170        let high_fn = high_fn.clone();
171        let low_fn = low_fn.clone();
172        let close_fn = close_fn.clone();
173
174        for d in &self.data {
175            let x_tick = x.tick(&x_fn(d));
176            let Some(x_tick) = x_tick else {
177                continue;
178            };
179
180            // Get OHLC values for the current data point
181            let open = open_fn(d);
182            let high = high_fn(d);
183            let low = low_fn(d);
184            let close = close_fn(d);
185
186            // Convert values to pixel coordinates
187            let open_y = y.tick(&open);
188            let high_y = y.tick(&high);
189            let low_y = y.tick(&low);
190            let close_y = y.tick(&close);
191
192            let (Some(open_y), Some(high_y), Some(low_y), Some(close_y)) =
193                (open_y, high_y, low_y, close_y)
194            else {
195                continue;
196            };
197
198            // Determine if bullish (close > open) or bearish (close < open)
199            let is_bullish = close > open;
200            let color: Hsla = if is_bullish {
201                cx.theme().chart_bullish
202            } else {
203                cx.theme().chart_bearish
204            };
205
206            // Calculate candlestick body dimensions
207            let center_x = x_tick + band_width / 2.;
208            let body_width = band_width * self.body_width_ratio;
209            let body_left = center_x - body_width / 2.;
210            let body_right = center_x + body_width / 2.;
211
212            // Draw wick (high to low line)
213            let mut wick_builder = PathBuilder::stroke(px(1.));
214            wick_builder.move_to(origin_point(px(center_x), px(high_y), origin));
215            wick_builder.line_to(origin_point(px(center_x), px(low_y), origin));
216
217            if let Ok(path) = wick_builder.build() {
218                window.paint_path(path, color);
219            }
220
221            // Draw body (open to close rectangle)
222            // For bullish: top is close, bottom is open
223            // For bearish: top is open, bottom is close
224            let (top, bottom) = if is_bullish {
225                (close_y, open_y)
226            } else {
227                (open_y, close_y)
228            };
229
230            let body_bounds = Bounds::from_corners(
231                origin_point(px(body_left), px(top), origin),
232                origin_point(px(body_right), px(bottom), origin),
233            );
234
235            window.paint_quad(fill(body_bounds, color));
236        }
237    }
238}