Skip to main content

charts_rs/charts/
waterfall_chart.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use super::Canvas;
14use super::canvas;
15use super::color::*;
16use super::common::*;
17use super::component::*;
18use super::params::*;
19use super::theme::{DEFAULT_Y_AXIS_WIDTH, Theme, get_default_theme_name, get_theme};
20use super::util::*;
21use crate::charts::measure_text_width_family;
22use charts_rs_derive::Chart;
23use std::sync::Arc;
24
25// ── Data types ────────────────────────────────────────────────────────────────
26
/// A single bar in the waterfall chart.
///
/// This is plain data, so it is `Copy`. Two bars compare equal when both
/// `value` and `is_total` match.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct WaterfallData {
    /// The numeric value for this bar (positive = increase, negative = decrease).
    ///
    /// For `is_total = true` bars:
    /// - a non-zero value is used as the displayed total;
    /// - exactly `0.0` means "use the running sum", which is computed automatically.
    pub value: f32,
    /// When `true` this bar is rendered as a "total" bar.
    ///
    /// A total bar starts at `0`, spans to the current running sum, and is
    /// colored with the chart's `total_color`.
    pub is_total: bool,
}
38
39impl From<f32> for WaterfallData {
40    fn from(value: f32) -> Self {
41        WaterfallData {
42            value,
43            is_total: false,
44        }
45    }
46}
47
48impl From<(f32, bool)> for WaterfallData {
49    fn from(v: (f32, bool)) -> Self {
50        WaterfallData {
51            value: v.0,
52            is_total: v.1,
53        }
54    }
55}
56
57// ── WaterfallChart ────────────────────────────────────────────────────────────
58
59#[derive(Clone, Debug, Default, Chart)]
60pub struct WaterfallChart {
61    pub width: f32,
62    pub height: f32,
63    pub x: f32,
64    pub y: f32,
65    pub margin: Box,
66    // dummy – required by #[derive(Chart)]
67    pub series_list: Vec<Series>,
68    pub font_family: String,
69    pub background_color: Color,
70    pub is_light: bool,
71
72    // title
73    pub title_text: String,
74    pub title_font_size: f32,
75    pub title_font_color: Color,
76    pub title_font_weight: Option<String>,
77    pub title_margin: Option<Box>,
78    pub title_align: Align,
79    pub title_height: f32,
80
81    // sub title
82    pub sub_title_text: String,
83    pub sub_title_font_size: f32,
84    pub sub_title_font_color: Color,
85    pub sub_title_font_weight: Option<String>,
86    pub sub_title_margin: Option<Box>,
87    pub sub_title_align: Align,
88    pub sub_title_height: f32,
89
90    // legend (required by derive – not shown by default)
91    pub legend_font_size: f32,
92    pub legend_font_color: Color,
93    pub legend_font_weight: Option<String>,
94    pub legend_align: Align,
95    pub legend_margin: Option<Box>,
96    pub legend_category: LegendCategory,
97    pub legend_show: Option<bool>,
98
99    // x axis
100    pub x_axis_data: Vec<String>,
101    pub x_axis_height: f32,
102    pub x_axis_stroke_color: Color,
103    pub x_axis_font_size: f32,
104    pub x_axis_font_color: Color,
105    pub x_axis_font_weight: Option<String>,
106    pub x_axis_name_gap: f32,
107    pub x_axis_name_rotate: f32,
108    pub x_axis_margin: Option<Box>,
109    pub x_axis_hidden: bool,
110    pub x_boundary_gap: Option<bool>,
111
112    // y axis
113    pub y_axis_hidden: bool,
114    y_axis_configs: Vec<YAxisConfig>,
115
116    // grid
117    grid_stroke_color: Color,
118    grid_stroke_width: f32,
119
120    // series (required by derive)
121    pub series_stroke_width: f32,
122    pub series_label_font_color: Color,
123    pub series_label_font_size: f32,
124    pub series_label_font_weight: Option<String>,
125    pub series_label_formatter: String,
126    pub series_colors: Vec<Color>,
127    pub series_symbol: Option<Symbol>,
128    pub series_smooth: bool,
129    pub series_fill: bool,
130
131    // ── Waterfall-specific fields ─────────────────────────────────────────────
132    /// The data points.  Each value is an increment/decrement, except for
133    /// entries where `is_total = true` which reset to 0 and show the running sum.
134    pub data: Vec<WaterfallData>,
135
136    /// Bar color for positive increments.  Defaults to the first `series_colors` entry.
137    pub increase_color: Color,
138
139    /// Bar color for negative increments.  Defaults to a warm red.
140    pub decrease_color: Color,
141
142    /// Bar color for "total" bars.  Defaults to the second `series_colors` entry.
143    pub total_color: Color,
144
145    /// Whether to show value labels above/below each bar (default: true).
146    pub label_show: bool,
147
148    /// Whether to draw a dashed connector line between adjacent bars (default: true).
149    pub connector_line_show: bool,
150
151    /// Fraction of each x-unit occupied by a bar (0..1, default: 0.6).
152    pub bar_width_ratio: f32,
153}
154
155impl WaterfallChart {
156    fn fill_default(&mut self) {
157        // legend hidden by default (no series names in the usual sense)
158        if self.legend_show.is_none() {
159            self.legend_show = Some(false);
160        }
161        if self.bar_width_ratio <= 0.0 {
162            self.bar_width_ratio = 0.6;
163        }
164        if self.increase_color.is_zero() {
165            self.increase_color = get_color(&self.series_colors, 0);
166        }
167        if self.total_color.is_zero() {
168            self.total_color = get_color(&self.series_colors, 1);
169        }
170        if self.decrease_color.is_zero() {
171            // warm red not in the default palette – hard-coded
172            self.decrease_color = (238, 102, 102).into(); // #EE6666
173        }
174    }
175
176    /// Creates a waterfall chart with default theme.
177    pub fn new(data: Vec<WaterfallData>, x_axis_data: Vec<String>) -> WaterfallChart {
178        WaterfallChart::new_with_theme(data, x_axis_data, &get_default_theme_name())
179    }
180
181    /// Creates a waterfall chart with a custom theme.
182    pub fn new_with_theme(
183        data: Vec<WaterfallData>,
184        x_axis_data: Vec<String>,
185        theme: &str,
186    ) -> WaterfallChart {
187        let mut c = WaterfallChart {
188            data,
189            x_axis_data,
190            label_show: true,
191            connector_line_show: true,
192            ..Default::default()
193        };
194        c.fill_theme(get_theme(theme));
195        c.fill_default();
196        c
197    }
198
199    /// Creates a waterfall chart from a JSON string.
200    pub fn from_json(json: &str) -> canvas::Result<WaterfallChart> {
201        let mut c = WaterfallChart {
202            label_show: true,
203            connector_line_show: true,
204            ..Default::default()
205        };
206        let value = c.fill_option(json)?;
207
208        if let Some(b) = get_bool_from_value(&value, "x_axis_hidden") {
209            c.x_axis_hidden = b;
210        }
211        if let Some(b) = get_bool_from_value(&value, "y_axis_hidden") {
212            c.y_axis_hidden = b;
213        }
214        if let Some(b) = get_bool_from_value(&value, "label_show") {
215            c.label_show = b;
216        }
217        if let Some(b) = get_bool_from_value(&value, "connector_line_show") {
218            c.connector_line_show = b;
219        }
220        if let Some(v) = get_f32_from_value(&value, "bar_width_ratio") {
221            c.bar_width_ratio = v;
222        }
223        if let Some(col) = get_color_from_value(&value, "increase_color") {
224            c.increase_color = col;
225        }
226        if let Some(col) = get_color_from_value(&value, "decrease_color") {
227            c.decrease_color = col;
228        }
229        if let Some(col) = get_color_from_value(&value, "total_color") {
230            c.total_color = col;
231        }
232
233        // parse data: array of either [value, is_total] or bare numbers
234        if let Some(arr) = value.get("data").and_then(|v| v.as_array()) {
235            let mut items = vec![];
236            for item in arr {
237                if let Some(pair) = item.as_array() {
238                    let val = pair.first().and_then(|v| v.as_f64()).unwrap_or(0.0) as f32;
239                    let is_total = pair.get(1).and_then(|v| v.as_bool()).unwrap_or(false);
240                    items.push(WaterfallData {
241                        value: val,
242                        is_total,
243                    });
244                } else if let Some(v) = item.as_f64() {
245                    items.push(WaterfallData {
246                        value: v as f32,
247                        is_total: false,
248                    });
249                }
250            }
251            c.data = items;
252        }
253        if let Some(x) = get_string_slice_from_value(&value, "x_axis_data") {
254            c.x_axis_data = x;
255        }
256
257        c.fill_default();
258        Ok(c)
259    }
260
261    /// Computes the cumulative running sum at each bar position.
262    ///
263    /// For a normal bar, the cumulative *before* is the sum of all previous
264    /// non-total bars.  For a total bar, the running sum is the value to display
265    /// (we auto-compute it unless the user supplied a non-zero explicit value).
266    fn compute_cumulative(&self) -> Vec<(f32, f32)> {
267        // Returns (bar_bottom, bar_top) in data coordinates for every bar.
268        // (bar_bottom, bar_top) with bar_top >= bar_bottom means visually the
269        // filled region runs from bar_bottom upward to bar_top.
270        let mut cum: f32 = 0.0;
271        let mut result = Vec::with_capacity(self.data.len());
272
273        for item in &self.data {
274            if item.is_total {
275                // Total bar: visual range is [0, cumulative_sum].
276                // If the user set an explicit value, use it; otherwise auto-compute.
277                let display = if item.value != 0.0 { item.value } else { cum };
278                result.push((0.0_f32, display));
279                // Totals reset-accumulate to match the display value for subsequent deltas
280                cum = display;
281            } else {
282                let bottom = cum;
283                let top = cum + item.value;
284                result.push((bottom, top));
285                cum = top;
286            }
287        }
288        result
289    }
290
291    /// Renders the waterfall chart to an SVG string.
292    pub fn svg(&self) -> canvas::Result<String> {
293        if self.data.is_empty() {
294            return Err(canvas::Error::Params {
295                message: "data is empty".to_string(),
296            });
297        }
298
299        let mut c = Canvas::new_width_xy(self.width, self.height, self.x, self.y);
300        self.render_background(c.child(Box::default()));
301
302        let mut x_axis_height = self.x_axis_height;
303        if self.x_axis_hidden {
304            x_axis_height = 0.0;
305        }
306        c.margin = self.margin.clone();
307
308        let title_height = self.render_title(c.child(Box::default()));
309        let legend_height = self.render_legend(c.child(Box::default()));
310        let axis_top = title_height.max(legend_height);
311
312        // ── Compute axis values ───────────────────────────────────────────────
313        let cum = self.compute_cumulative();
314        // Collect all boundary values for the y-axis range
315        let all_vals: Vec<f32> = cum.iter().flat_map(|(b, t)| [*b, *t]).collect();
316
317        let y_axis_config = &self.y_axis_configs[0];
318        let y_axis_values = get_axis_values(AxisValueParams {
319            data_list: all_vals,
320            split_number: y_axis_config.axis_split_number,
321            reverse: Some(true),
322            min: y_axis_config.axis_min,
323            max: y_axis_config.axis_max,
324            thousands_format: y_axis_config
325                .axis_formatter
326                .as_deref()
327                .unwrap_or("")
328                .contains(THOUSANDS_FORMAT_LABEL),
329            scale: y_axis_config.axis_scale.clone(),
330        });
331
332        let mut y_axis_width = if self.y_axis_hidden {
333            0.0
334        } else if let Some(w) = y_axis_config.axis_width {
335            w
336        } else {
337            let formatter = y_axis_config.axis_formatter.clone().unwrap_or_default();
338            let longest = y_axis_values
339                .data
340                .iter()
341                .max_by_key(|s| s.len())
342                .map(|s| s.as_str())
343                .unwrap_or("");
344            let label = format_string(longest, &formatter);
345            measure_text_width_family(&self.font_family, y_axis_config.axis_font_size, &label)
346                .map(|b| b.width() + 5.0)
347                .unwrap_or(DEFAULT_Y_AXIS_WIDTH)
348        };
349        if self.y_axis_hidden {
350            y_axis_width = 0.0;
351        }
352
353        let axis_height = c.height() - x_axis_height - axis_top;
354        let axis_width = c.width() - y_axis_width;
355
356        if axis_top > 0.0 {
357            c = c.child(Box {
358                top: axis_top,
359                ..Default::default()
360            });
361        }
362
363        // ── Render grid / axes ────────────────────────────────────────────────
364        self.render_grid(
365            c.child(Box {
366                left: y_axis_width,
367                ..Default::default()
368            }),
369            axis_width,
370            axis_height,
371        );
372
373        if y_axis_width > 0.0 {
374            self.render_y_axis(
375                c.child(Box::default()),
376                y_axis_values.data.clone(),
377                axis_height,
378                y_axis_width,
379                0,
380            );
381        }
382
383        if !self.x_axis_hidden {
384            self.render_x_axis(
385                c.child(Box {
386                    top: c.height() - x_axis_height,
387                    left: y_axis_width,
388                    ..Default::default()
389                }),
390                self.x_axis_data.clone(),
391                axis_width,
392            );
393        }
394
395        // ── Render bars ───────────────────────────────────────────────────────
396        let n = self.data.len();
397        let max_height = c.height() - x_axis_height;
398        let unit_w = axis_width / n as f32;
399        let bar_w = unit_w * self.bar_width_ratio;
400        let bar_margin = (unit_w - bar_w) / 2.0;
401
402        // Label format – {c} = value, {a} = series name (empty here)
403        let formatter = if self.series_label_formatter.is_empty() {
404            "{c}".to_string()
405        } else {
406            self.series_label_formatter.clone()
407        };
408
409        let mut draw_c = c.child(Box {
410            left: y_axis_width,
411            ..Default::default()
412        });
413
414        let zero_y = y_axis_values.get_offset_height(0.0, max_height);
415
416        for (i, item) in self.data.iter().enumerate() {
417            let (bar_bot_val, bar_top_val) = cum[i];
418
419            // Visual top is the larger value (smaller y-pixel)
420            let high_val = bar_bot_val.max(bar_top_val);
421            let low_val = bar_bot_val.min(bar_top_val);
422
423            let y_high = y_axis_values.get_offset_height(high_val, max_height);
424            let y_low = y_axis_values.get_offset_height(low_val, max_height);
425            let bar_h = (y_low - y_high).max(1.0);
426
427            let x_left = i as f32 * unit_w + bar_margin;
428
429            let color = if item.is_total {
430                self.total_color
431            } else if item.value >= 0.0 {
432                self.increase_color
433            } else {
434                self.decrease_color
435            };
436
437            draw_c.rect(Rect {
438                color: Some(color),
439                fill: Some(color.into()),
440                left: x_left,
441                top: y_high,
442                width: bar_w,
443                height: bar_h,
444                rx: Some(2.0),
445                ry: Some(2.0),
446                ..Default::default()
447            });
448
449            // ── Value label ───────────────────────────────────────────────────
450            if self.label_show {
451                let label_opt = LabelOption {
452                    value: item.value.abs(),
453                    formatter: formatter.clone(),
454                    ..Default::default()
455                };
456                let label_text = label_opt.format();
457                let label_y = if item.value >= 0.0 || item.is_total {
458                    y_high - 4.0 // above bar
459                } else {
460                    y_low + self.series_label_font_size + 2.0 // below bar
461                };
462                let mut label_x = x_left + bar_w / 2.0;
463                if let Ok(b) = measure_text_width_family(
464                    &self.font_family,
465                    self.series_label_font_size,
466                    &label_text,
467                ) {
468                    label_x -= b.width() / 2.0;
469                }
470                draw_c.text(Text {
471                    text: label_text,
472                    font_family: Some(self.font_family.clone()),
473                    font_color: Some(self.series_label_font_color),
474                    font_size: Some(self.series_label_font_size),
475                    font_weight: self.series_label_font_weight.clone(),
476                    x: Some(label_x),
477                    y: Some(label_y),
478                    ..Default::default()
479                });
480            }
481
482            // ── Connector line to next bar ────────────────────────────────────
483            if self.connector_line_show && i + 1 < n {
484                // Connect at the top of the running cumulative after this bar
485                let connector_y = y_axis_values.get_offset_height(bar_top_val, max_height);
486                let x_right = x_left + bar_w;
487                let next_x_left = (i + 1) as f32 * unit_w + bar_margin;
488
489                draw_c.line(Line {
490                    color: Some(self.grid_stroke_color),
491                    stroke_width: 1.0,
492                    stroke_dash_array: Some("4,4".to_string()),
493                    left: x_right,
494                    top: connector_y,
495                    right: next_x_left,
496                    bottom: connector_y,
497                });
498            }
499        }
500
501        // Zero baseline (solid line at y=0 when there are negative values)
502        let has_negative = self.data.iter().any(|d| d.value < 0.0);
503        if has_negative {
504            draw_c.line(Line {
505                color: Some(self.x_axis_stroke_color),
506                stroke_width: 1.0,
507                left: 0.0,
508                top: zero_y,
509                right: axis_width,
510                bottom: zero_y,
511                ..Default::default()
512            });
513        }
514
515        c.svg()
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::{WaterfallChart, WaterfallData};
    use pretty_assertions::assert_eq;

    /// Shared fixture for the waterfall tests.
    ///
    /// Eight increments/decrements are followed by one auto-computed total,
    /// with one x-axis label per bar.
    fn make_data() -> (Vec<WaterfallData>, Vec<String>) {
        let raw: [(f32, bool); 9] = [
            (900.0, false),
            (345.0, false),
            (393.0, false),
            (-108.0, false),
            (-154.0, false),
            (135.0, false),
            (-333.0, false),
            (548.0, false),
            // value 0.0 on a total bar: the running sum is used instead
            (0.0, true),
        ];
        let data: Vec<WaterfallData> = raw.iter().copied().map(WaterfallData::from).collect();

        let names = [
            "Initial",
            "Product Revenue",
            "Service Revenue",
            "Purchases",
            "Marketing",
            "Other Income",
            "Payroll",
            "Other Expenses",
            "Profit",
        ];
        let labels: Vec<String> = names.iter().map(|name| name.to_string()).collect();

        (data, labels)
    }

    #[test]
    fn waterfall_chart_basic() {
        let (data, labels) = make_data();
        let rendered = WaterfallChart::new(data, labels).svg().unwrap();
        assert_eq!(
            include_str!("../../asset/waterfall_chart/basic.svg"),
            rendered
        );
    }

    #[test]
    fn waterfall_chart_basic_json() {
        let json = r#"{
                "title_text": "Waterfall Chart",
                "x_axis_data": ["Initial","Revenue","Services","Purchases","Marketing","Profit"],
                "data": [
                    [900, false],
                    [345, false],
                    [393, false],
                    [-108, false],
                    [-154, false],
                    [0, true]
                ]
            }"#;
        let chart = WaterfallChart::from_json(json).unwrap();
        let rendered = chart.svg().unwrap();
        assert_eq!(
            include_str!("../../asset/waterfall_chart/basic_json.svg"),
            rendered
        );
    }
}