// charts_rs/charts/funnel_chart.rs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

13use super::Canvas;
14use super::canvas;
15use super::color::*;
16use super::common::*;
17use super::component::*;
18use super::params::*;
19use super::theme::{DEFAULT_Y_AXIS_WIDTH, Theme, get_default_theme_name, get_theme};
20use super::util::*;
21use crate::charts::measure_text_width_family;
22use charts_rs_derive::Chart;
23use std::sync::Arc;
24
/// A funnel chart.
///
/// Each series becomes one stage whose value is the sum of its data points.
/// Stages are drawn top to bottom as horizontal trapezoids whose widths are
/// proportional to the stage value relative to the largest stage.
#[charts_rs_derive::chart_common_fields]
#[derive(Clone, Debug, Default, Chart)]
pub struct FunnelChart {
    // x axis (required by derive – not rendered)
    /// Not rendered by the funnel chart; present for the `Chart` derive.
    pub x_axis_data: Vec<String>,
    pub x_axis_height: f32,
    pub x_axis_stroke_color: Color,
    pub x_axis_font_size: f32,
    pub x_axis_font_color: Color,
    pub x_axis_font_weight: Option<String>,
    pub x_axis_name_gap: f32,
    pub x_axis_name_rotate: f32,
    pub x_axis_margin: Option<Box>,
    pub x_boundary_gap: Option<bool>,

    // y axis (required by derive)
    // Not read by the funnel rendering in this file.
    y_axis_configs: Vec<YAxisConfig>,

    // grid (required by derive)
    // Not read by the funnel rendering in this file.
    grid_stroke_color: Color,
    grid_stroke_width: f32,

    // series (required by derive)
    /// Not read by the funnel rendering in this file.
    pub series_stroke_width: f32,
    /// Color of the stage labels.
    pub series_label_font_color: Color,
    /// Font size of the stage labels (also used to measure label width).
    pub series_label_font_size: f32,
    /// Optional font weight of the stage labels.
    pub series_label_font_weight: Option<String>,
    /// Label template (e.g. `{a}` series name, `{c}` value); when empty the
    /// renderer falls back to `"{a}: {c}"`.
    pub series_label_formatter: String,
    /// Label position: `"inside"`, `"left"`, or `"right"` (default).
    /// Any other value is treated like `"right"`.
    pub series_label_position: Option<String>,
    /// Palette looked up with each series' `index` (or its list position when
    /// `index` is unset).
    pub series_colors: Vec<Color>,
    /// Not read by the funnel rendering in this file.
    pub series_symbol: Option<Symbol>,
    /// Not read by the funnel rendering in this file.
    pub series_smooth: bool,
    /// Not read by the funnel rendering in this file.
    pub series_fill: bool,

    // ── Funnel-specific fields ────────────────────────────────────────────────
    /// Vertical gap between trapezoids in pixels (default: 2).
    pub funnel_gap: f32,

    /// Horizontal alignment of all trapezoids (default: `Center`).
    /// Any variant other than `Left`/`Right` is drawn centered.
    pub funnel_align: Align,

    /// If `true`, smallest value at top; if `false` (default), largest at top.
    pub sort_ascending: bool,

    /// Minimum trapezoid width for the narrowest end, in pixels (default: 20).
    pub min_width: f32,

    /// Optional fade-in animation for the trapezoids and their labels. The
    /// `delay` field is not used (all stages fade in together).
    pub animation: Option<AnimationConfig>,
}

78impl FunnelChart {
79    fn fill_default(&mut self) {
80        if self.funnel_gap <= 0.0 {
81            self.funnel_gap = 2.0;
82        }
83        if self.min_width <= 0.0 {
84            self.min_width = 20.0;
85        }
86        // default label position
87        if self.series_label_position.is_none() {
88            self.series_label_position = Some("right".to_string());
89        }
90    }
91
92    /// Creates a funnel chart with default theme.
93    pub fn new(series_list: Vec<Series>) -> FunnelChart {
94        FunnelChart::new_with_theme(series_list, &get_default_theme_name())
95    }
96
97    /// Creates a funnel chart with a custom theme.
98    pub fn new_with_theme(series_list: Vec<Series>, theme: &str) -> FunnelChart {
99        let mut c = FunnelChart {
100            series_list,
101            ..Default::default()
102        };
103        c.fill_theme(get_theme(theme));
104        c.fill_default();
105        c
106    }
107
108    /// Creates a funnel chart from a JSON string.
109    pub fn from_json(json: &str) -> canvas::Result<FunnelChart> {
110        let mut c = FunnelChart {
111            ..Default::default()
112        };
113        let value = c.fill_option(json)?;
114        if let Some(v) = get_f32_from_value(&value, "funnel_gap") {
115            c.funnel_gap = v;
116        }
117        if let Some(v) = get_f32_from_value(&value, "min_width") {
118            c.min_width = v;
119        }
120        if let Some(b) = get_bool_from_value(&value, "sort_ascending") {
121            c.sort_ascending = b;
122        }
123        if let Some(s) = get_string_from_value(&value, "series_label_position") {
124            c.series_label_position = Some(s);
125        }
126        if let Some(a) = get_align_from_value(&value, "funnel_align") {
127            c.funnel_align = a;
128        }
129        if let Some(anim) = value.get("animation")
130            && !anim.is_null()
131        {
132            let mut config = AnimationConfig::default();
133            if let Some(d) = get_usize_from_value(anim, "duration") {
134                config.duration = d as u32;
135            }
136            if let Some(e) = get_string_from_value(anim, "easing") {
137                config.easing = e;
138            }
139            if let Some(d) = get_usize_from_value(anim, "delay") {
140                config.delay = d as u32;
141            }
142            c.animation = Some(config);
143        }
144        c.fill_default();
145        Ok(c)
146    }
147
148    /// Renders the funnel chart to an SVG string.
149    pub fn svg(&self) -> canvas::Result<String> {
150        if self.series_list.is_empty() {
151            return Err(canvas::Error::Params {
152                message: "series_list is empty".to_string(),
153            });
154        }
155
156        let mut c = Canvas::new_width_xy(self.width, self.height, self.x, self.y);
157        self.render_background(c.child(Box::default()));
158        c.margin = self.margin.clone();
159
160        let title_height = self.render_title(c.child(Box::default()));
161        let legend_height = self.render_legend(c.child(Box::default()));
162        let axis_top = title_height.max(legend_height);
163
164        if axis_top > 0.0 {
165            c = c.child(Box {
166                top: axis_top,
167                ..Default::default()
168            });
169        }
170
171        let funnel_width = c.width();
172        let funnel_height = c.height();
173
174        // ── Collect & sort series ─────────────────────────────────────────────
175        // Tuple: (color_index, value, name).
176        let mut stages: Vec<(usize, f32, String)> = self
177            .series_list
178            .iter()
179            .enumerate()
180            .map(|(i, s)| {
181                let val: f32 = s.data_values().iter().copied().sum();
182                (s.index.unwrap_or(i), val, s.name.clone())
183            })
184            .collect();
185
186        if self.sort_ascending {
187            stages.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
188        } else {
189            stages.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
190        }
191
192        let max_val = stages
193            .iter()
194            .map(|(_, v, _)| *v)
195            .fold(f32::NEG_INFINITY, f32::max);
196        if max_val <= 0.0 {
197            return c.svg();
198        }
199        let total: f32 = stages.iter().map(|(_, v, _)| *v).sum();
200
201        let n = stages.len();
202        let gap = self.funnel_gap;
203        let stage_h = (funnel_height - (n as f32 - 1.0) * gap) / n as f32;
204
205        let label_pos = self.series_label_position.as_deref().unwrap_or("right");
206        let label_font_size = self.series_label_font_size;
207        let label_color = self.series_label_font_color;
208        let mut formatter = self.series_label_formatter.clone();
209        if formatter.is_empty() {
210            formatter = "{a}: {c}".to_string();
211        }
212        let anim_class = self.animation.as_ref().map(|_| "funnel-anim".to_string());
213
214        for (stage_idx, (color_idx, val, name)) in stages.iter().enumerate() {
215            let top_w = (val / max_val) * funnel_width;
216            // bottom width = next stage's width (or min_width for last stage)
217            let bot_w = if stage_idx + 1 < n {
218                let next_val = stages[stage_idx + 1].1;
219                ((next_val / max_val) * funnel_width).max(self.min_width)
220            } else {
221                self.min_width
222            };
223
224            let y_top = stage_idx as f32 * (stage_h + gap);
225            let y_bot = y_top + stage_h;
226
227            // horizontal offset per alignment
228            let (x_left_top, x_left_bot) = match self.funnel_align {
229                Align::Left => (0.0, 0.0),
230                Align::Right => (funnel_width - top_w, funnel_width - bot_w),
231                _ => ((funnel_width - top_w) / 2.0, (funnel_width - bot_w) / 2.0),
232            };
233            let x_right_top = x_left_top + top_w;
234            let x_right_bot = x_left_bot + bot_w;
235
236            let color = get_color(&self.series_colors, *color_idx);
237
238            c.polygon(Polygon {
239                color: Some(color),
240                fill: Some(color),
241                points: vec![
242                    (x_left_top, y_top).into(),
243                    (x_right_top, y_top).into(),
244                    (x_right_bot, y_bot).into(),
245                    (x_left_bot, y_bot).into(),
246                ],
247                class: anim_class.clone(),
248                ..Default::default()
249            });
250
251            let label_option = LabelOption {
252                series_name: name.clone(),
253                value: *val,
254                percentage: if total > 0.0 { val / total } else { 0.0 },
255                formatter: formatter.clone(),
256                ..Default::default()
257            };
258            let label_text = label_option.format();
259
260            let mid_y = (y_top + y_bot) / 2.0;
261
262            match label_pos {
263                "inside" => {
264                    // Center text inside the trapezoid
265                    let mid_x = (x_left_top + x_right_top) / 2.0;
266                    let mut text_x = mid_x;
267                    if let Ok(b) =
268                        measure_text_width_family(&self.font_family, label_font_size, &label_text)
269                    {
270                        text_x -= b.width() / 2.0;
271                    }
272                    c.text(Text {
273                        text: label_text,
274                        font_family: Some(self.font_family.clone()),
275                        font_color: Some(label_color),
276                        font_size: Some(label_font_size),
277                        font_weight: self.series_label_font_weight.clone(),
278                        dominant_baseline: Some("central".to_string()),
279                        x: Some(text_x),
280                        y: Some(mid_y),
281                        class: anim_class.clone(),
282                        ..Default::default()
283                    });
284                }
285                "left" => {
286                    let x_edge = x_left_top.min(x_left_bot) - 5.0;
287                    let mut text_x = x_edge;
288                    if let Ok(b) =
289                        measure_text_width_family(&self.font_family, label_font_size, &label_text)
290                    {
291                        text_x -= b.width();
292                    }
293                    c.text(Text {
294                        text: label_text,
295                        font_family: Some(self.font_family.clone()),
296                        font_color: Some(label_color),
297                        font_size: Some(label_font_size),
298                        font_weight: self.series_label_font_weight.clone(),
299                        dominant_baseline: Some("central".to_string()),
300                        x: Some(text_x.max(0.0)),
301                        y: Some(mid_y),
302                        class: anim_class.clone(),
303                        ..Default::default()
304                    });
305                }
306                _ => {
307                    // "right" (default): to the right of the widest edge
308                    let x_edge = x_right_top.max(x_right_bot) + 5.0;
309                    c.text(Text {
310                        text: label_text,
311                        font_family: Some(self.font_family.clone()),
312                        font_color: Some(label_color),
313                        font_size: Some(label_font_size),
314                        font_weight: self.series_label_font_weight.clone(),
315                        dominant_baseline: Some("central".to_string()),
316                        x: Some(x_edge),
317                        y: Some(mid_y),
318                        class: anim_class.clone(),
319                        ..Default::default()
320                    });
321                }
322            }
323        }
324
325        if let Some(ref anim) = self.animation {
326            let css = format!(
327                "@keyframes funnel-fade{{from{{opacity:0}}to{{opacity:1}}}} \
328                 .funnel-anim{{animation:funnel-fade {}ms {} both}}",
329                anim.duration, anim.easing
330            );
331            c.svg_with_style(&css)
332        } else {
333            c.svg()
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::FunnelChart;
    use crate::Series;
    use pretty_assertions::assert_eq;

    /// Stage names and values shared by every test, largest stage first.
    const STAGES: [(&str, f32); 5] = [
        ("Impression", 60000.0),
        ("Click", 40000.0),
        ("Inquiry", 20000.0),
        ("Order", 8000.0),
        ("Re-order", 2000.0),
    ];

    /// Builds one single-point series per entry of `STAGES`.
    fn funnel_series() -> Vec<Series> {
        STAGES
            .iter()
            .map(|&(name, value)| -> Series { (name, vec![value]).into() })
            .collect()
    }

    /// Asserts that `chart` renders exactly to the `expected` SVG snapshot.
    fn assert_svg(expected: &str, chart: &FunnelChart) {
        assert_eq!(expected, chart.svg().unwrap());
    }

    #[test]
    fn funnel_chart_basic() {
        let chart = FunnelChart::new(funnel_series());
        assert_svg(include_str!("../../asset/funnel_chart/basic.svg"), &chart);
    }

    #[test]
    fn funnel_chart_inside_label() {
        let chart = FunnelChart {
            title_text: "Conversion Funnel".to_string(),
            series_label_position: Some("inside".to_string()),
            ..FunnelChart::new(funnel_series())
        };
        assert_svg(
            include_str!("../../asset/funnel_chart/inside_label.svg"),
            &chart,
        );
    }

    #[test]
    fn funnel_chart_basic_json() {
        let json = r##"{
                "title_text": "Funnel Chart",
                "series_label_position": "inside",
                "funnel_gap": 4,
                "series_list": [
                    {"name": "Impression", "data": [60000]},
                    {"name": "Click",      "data": [40000]},
                    {"name": "Inquiry",    "data": [20000]},
                    {"name": "Order",      "data": [8000]},
                    {"name": "Re-order",   "data": [2000]}
                ]
            }"##;
        let chart = FunnelChart::from_json(json).unwrap();
        assert_svg(
            include_str!("../../asset/funnel_chart/basic_json.svg"),
            &chart,
        );
    }

    #[test]
    fn funnel_chart_animation() {
        let chart = FunnelChart {
            animation: Some(super::AnimationConfig {
                duration: 800,
                easing: "ease-in".to_string(),
                delay: 0,
            }),
            ..FunnelChart::new(funnel_series())
        };
        let svg = chart.svg().unwrap();
        // (fragment that must appear, failure message)
        let expectations = [
            ("funnel-fade", "missing @keyframes funnel-fade"),
            (r#"class="funnel-anim""#, "missing class on trapezoid"),
            ("800ms ease-in", "missing duration/easing"),
        ];
        for (needle, message) in expectations {
            assert!(svg.contains(needle), "{}", message);
        }
    }
}