1use crate::error::ChartError;
4use crate::{
5 DEFAULT_COLOR, DEFAULT_HEIGHT, DEFAULT_PADDING_FRACTION, DEFAULT_TITLE_FONT_SIZE,
6 DEFAULT_WIDTH, TITLE_AREA_HEIGHT, ScaleType, extent_padded, validate_data_array,
7 validate_data_length, validate_dimensions, validate_positive,
8};
9use d3rs::color::D3Color;
10use d3rs::scale::{LinearScale, LogScale};
11use d3rs::shape::{BarConfig, BarDatum, render_bars};
12use d3rs::text::{VectorFontConfig, render_vector_text};
13use gpui::prelude::*;
14use gpui::*;
15
/// Builder for a bar chart: one bar per category, with bar heights taken from `values`.
///
/// Create one with [`bar`], adjust it with the builder methods, then call
/// [`BarChart::build`] to validate the data and produce a GPUI element.
#[derive(Debug, Clone)]
pub struct BarChart {
    /// Category labels, one per bar.
    categories: Vec<String>,
    /// Bar values; `build` checks that the length matches `categories`.
    values: Vec<f64>,
    /// Optional title rendered above the plot area.
    title: Option<String>,
    /// Bar fill colour as a hex value, converted with `D3Color::from_hex`.
    color: u32,
    /// Bar fill opacity; the `opacity` builder clamps it to `[0.0, 1.0]`.
    opacity: f32,
    /// Gap between adjacent bars, forwarded to `BarConfig::bar_gap`.
    bar_gap: f32,
    /// Corner radius of each bar, forwarded to `BarConfig::border_radius`.
    border_radius: f32,
    /// Total chart width in pixels.
    width: f32,
    /// Total chart height in pixels, including the title area when a title is set.
    height: f32,
    /// Scale used for the y-axis (linear or logarithmic).
    y_scale_type: ScaleType,
}
30
31impl BarChart {
32 pub fn title(mut self, title: impl Into<String>) -> Self {
34 self.title = Some(title.into());
35 self
36 }
37
38 pub fn color(mut self, hex: u32) -> Self {
48 self.color = hex;
49 self
50 }
51
52 pub fn opacity(mut self, opacity: f32) -> Self {
54 self.opacity = opacity.clamp(0.0, 1.0);
55 self
56 }
57
58 pub fn bar_gap(mut self, gap: f32) -> Self {
60 self.bar_gap = gap;
61 self
62 }
63
64 pub fn border_radius(mut self, radius: f32) -> Self {
66 self.border_radius = radius;
67 self
68 }
69
70 pub fn size(mut self, width: f32, height: f32) -> Self {
72 self.width = width;
73 self.height = height;
74 self
75 }
76
77 pub fn y_scale(mut self, scale: ScaleType) -> Self {
87 self.y_scale_type = scale;
88 self
89 }
90
91 pub fn build(self) -> Result<impl IntoElement, ChartError> {
93 if self.categories.is_empty() {
95 return Err(ChartError::EmptyData {
96 field: "categories",
97 });
98 }
99 validate_data_array(&self.values, "values")?;
100 validate_data_length(
101 self.categories.len(),
102 self.values.len(),
103 "categories",
104 "values",
105 )?;
106 validate_dimensions(self.width, self.height)?;
107
108 if self.y_scale_type == ScaleType::Log {
110 validate_positive(&self.values, "values")?;
111 }
112
113 let title_height = if self.title.is_some() {
115 TITLE_AREA_HEIGHT
116 } else {
117 0.0
118 };
119 let plot_height = self.height - title_height;
120
121 let (mut y_min, mut y_max) = extent_padded(&self.values, DEFAULT_PADDING_FRACTION);
123
124 if self.y_scale_type == ScaleType::Linear {
127 y_min = y_min.min(0.0);
128 y_max = y_max.max(0.0);
129 }
130
131 let x_scale = LinearScale::new()
133 .domain(0.0, self.categories.len() as f64)
134 .range(0.0, self.width as f64);
135
136 let data: Vec<BarDatum> = self
138 .categories
139 .iter()
140 .zip(self.values.iter())
141 .map(|(cat, &val)| BarDatum::new(cat.clone(), val))
142 .collect();
143
144 let config = BarConfig::new()
146 .fill_color(D3Color::from_hex(self.color))
147 .opacity(self.opacity)
148 .bar_gap(self.bar_gap)
149 .border_radius(self.border_radius);
150
151 let bar_element: AnyElement = match self.y_scale_type {
153 ScaleType::Linear => {
154 let y_scale = LinearScale::new()
155 .domain(y_min, y_max)
156 .range(plot_height as f64, 0.0);
157 render_bars(&x_scale, &y_scale, &data, self.width, plot_height, &config)
158 .into_any_element()
159 }
160 ScaleType::Log => {
161 let y_scale = LogScale::new()
162 .domain(y_min.max(1e-10), y_max)
163 .range(plot_height as f64, 0.0);
164 render_bars(&x_scale, &y_scale, &data, self.width, plot_height, &config)
165 .into_any_element()
166 }
167 };
168
169 let mut container = div()
171 .w(px(self.width))
172 .h(px(self.height))
173 .relative()
174 .flex()
175 .flex_col();
176
177 if let Some(title) = &self.title {
179 let font_config =
180 VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
181 container = container.child(
182 div()
183 .w_full()
184 .h(px(title_height))
185 .flex()
186 .justify_center()
187 .items_center()
188 .child(render_vector_text(title, &font_config)),
189 );
190 }
191
192 container = container.child(
194 div()
195 .w(px(self.width))
196 .h(px(plot_height))
197 .relative()
198 .child(bar_element),
199 );
200
201 Ok(container)
202 }
203}
204
205pub fn bar<S: AsRef<str>>(categories: &[S], values: &[f64]) -> BarChart {
222 BarChart {
223 categories: categories.iter().map(|s| s.as_ref().to_string()).collect(),
224 values: values.to_vec(),
225 title: None,
226 color: DEFAULT_COLOR,
227 opacity: 0.8,
228 bar_gap: 2.0,
229 border_radius: 2.0,
230 width: DEFAULT_WIDTH,
231 height: DEFAULT_HEIGHT,
232 y_scale_type: ScaleType::Linear,
233 }
234}
235
#[cfg(test)]
mod tests {
    //! Validation and smoke tests for the bar chart builder.

    use super::*;

    /// Reason expected when a log-scale chart receives a zero or negative value.
    const LOG_NON_POSITIVE: &str = "contains non-positive values for log scale";

    #[test]
    fn test_bar_empty_categories() {
        let no_labels: [&str; 0] = [];
        let outcome = bar(&no_labels, &[1.0, 2.0, 3.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::EmptyData { field: "categories" })
        ));
    }

    #[test]
    fn test_bar_empty_values() {
        let outcome = bar(&["A", "B", "C"], &[]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::EmptyData { field: "values" })
        ));
    }

    #[test]
    fn test_bar_data_length_mismatch() {
        let outcome = bar(&["A", "B"], &[1.0, 2.0, 3.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "categories",
                y_field: "values",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    #[test]
    fn test_bar_invalid_value_nan() {
        let outcome = bar(&["A", "B", "C"], &[1.0, f64::NAN, 3.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "values",
                reason: "contains NaN or Infinity",
            })
        ));
    }

    #[test]
    fn test_bar_successful_build() {
        let outcome = bar(&["A", "B", "C", "D"], &[10.0, 25.0, 15.0, 30.0])
            .title("Test Bar Chart")
            .color(0x2ca02c)
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_bar_negative_values() {
        let outcome = bar(&["A", "B", "C"], &[-5.0, 10.0, -3.0]).build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_bar_builder_chain() {
        let outcome = bar(&["X", "Y", "Z"], &[1.0, 2.0, 3.0])
            .title("My Bar Chart")
            .color(0xff0000)
            .opacity(0.9)
            .bar_gap(5.0)
            .border_radius(4.0)
            .size(800.0, 600.0)
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_bar_log_y_scale() {
        let outcome = bar(&["A", "B", "C", "D"], &[10.0, 100.0, 1000.0, 10000.0])
            .y_scale(ScaleType::Log)
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_bar_log_y_scale_zero_value() {
        let outcome = bar(&["A", "B", "C"], &[0.0, 10.0, 100.0])
            .y_scale(ScaleType::Log)
            .build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "values",
                reason: LOG_NON_POSITIVE,
            })
        ));
    }

    #[test]
    fn test_bar_log_y_scale_negative_value() {
        let outcome = bar(&["A", "B", "C"], &[-5.0, 10.0, 100.0])
            .y_scale(ScaleType::Log)
            .build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "values",
                reason: LOG_NON_POSITIVE,
            })
        ));
    }

    #[test]
    fn test_bar_log_scale_with_title() {
        let outcome = bar(&["Low", "Medium", "High"], &[10.0, 100.0, 1000.0])
            .title("Log Scale Bar Chart")
            .y_scale(ScaleType::Log)
            .color(0x2ca02c)
            .build();
        assert!(outcome.is_ok());
    }
}