1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
use polars_core::prelude::*;
use polars_lazy::prelude::*;

/// Bins the values of a [`Series`] into right-closed intervals delimited by `bins`.
///
/// The input is cast to `Float64`, sorted, and matched against the breakpoints with a
/// forward as-of join, so each value is paired with the first breakpoint that is greater
/// than or equal to it. A trailing `+inf` breakpoint is appended to `bins` so that values
/// above the last edge still receive a category.
///
/// # Arguments
///
/// * `s` - Values to bin; its name is used as the name of the value column in the result.
/// * `bins` - Interval edges. They are used in the order given and are not validated here;
///   presumably they must be sorted ascending for the as-of join to be meaningful.
/// * `labels` - Optional category names, one per interval. Because of the extra `+inf`
///   breakpoint there must be exactly `bins.len() + 1` of them. When `None`, labels of the
///   form `"(lower, upper]"` are generated from the breakpoints.
/// * `break_point_label` - Name of the breakpoint column; defaults to `"break_point"`.
/// * `category_label` - Name of the category column; defaults to `"category"`.
///
/// # Errors
///
/// Returns [`PolarsError::ShapeMisMatch`] when `labels` is provided and its length is not
/// `bins.len() + 1`. Any error raised by the underlying polars operations is propagated.
pub fn cut(
    s: Series,
    bins: Vec<f32>,
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
) -> PolarsResult<DataFrame> {
    let var_name = s.name();
    let breakpoint_str = break_point_label.unwrap_or("break_point");
    let category_str = category_label.unwrap_or("category");

    // Reject a mismatched label list up front, before any frame is built.
    // One label is needed per interval, including the open-ended `(last, inf]` one.
    if let Some(labels) = &labels {
        if labels.len() != bins.len() + 1 {
            return Err(PolarsError::ShapeMisMatch(
                "Labels count must equal bins count + 1".into(),
            ));
        }
    }

    // Breakpoints plus a trailing +inf sentinel that catches everything above the last edge.
    let cuts_df = df![
        breakpoint_str => Series::new(breakpoint_str, &bins)
            .extend_constant(AnyValue::Float64(f64::INFINITY), 1)?
    ]?;

    let category_expr = match labels {
        // Caller-supplied labels are attached positionally, one per breakpoint row.
        Some(labels) => lit(Series::new(category_str, labels)),
        // Otherwise derive "(previous, current]" from consecutive breakpoints; the first
        // interval has no predecessor, so its lower bound is filled with -inf.
        None => format_str(
            "({}, {}]",
            [
                col(breakpoint_str).shift_and_fill(1, lit(f64::NEG_INFINITY)),
                col(breakpoint_str),
            ],
        )?
        .alias(category_str),
    };

    // Single lazy plan: add the category column, then turn it into a categorical.
    let cuts = cuts_df
        .lazy()
        .with_column(category_expr)
        .with_columns([col(category_str).cast(DataType::Categorical(None))])
        .collect()?;

    // NOTE(review): the breakpoints originate from `f32` bins while the values are cast to
    // `Float64`; confirm the join keys' dtypes line up as intended.
    s.cast(&DataType::Float64)?
        .sort(false)
        .into_frame()
        .join_asof(
            &cuts,
            var_name,
            breakpoint_str,
            AsofStrategy::Forward,
            None,
            None,
        )
}

// Checks `cut` with two breakpoints, generated labels and default column names.
#[test]
fn test_cut() -> PolarsResult<()> {
    // Twelve evenly spaced values from -3.0 to 2.5 in steps of 0.5.
    let values: Vec<f32> = (-6..6).map(|half_steps| half_steps as f32 / 2.0).collect();
    let input = Series::new("a", values);

    let out = cut(input, vec![-1.0, 1.0], None, None, None)?;

    // Expected layout: 5 values up to -1.0, 4 values up to 1.0, 3 values above 1.0.
    let expected_values: Vec<f64> = (-6..6).map(|half_steps| f64::from(half_steps) / 2.0).collect();
    let expected_breaks: Vec<f64> = std::iter::repeat(-1.0)
        .take(5)
        .chain(std::iter::repeat(1.0).take(4))
        .chain(std::iter::repeat(f64::INFINITY).take(3))
        .collect();
    let expected_categories: Vec<&str> = std::iter::repeat("(-inf, -1.0]")
        .take(5)
        .chain(std::iter::repeat("(-1.0, 1.0]").take(4))
        .chain(std::iter::repeat("(1.0, inf]").take(3))
        .collect();

    let expected = df![
        "a" => expected_values,
        "break_point" => expected_breaks,
        "category" => expected_categories,
    ]?;

    assert!(out.frame_equal_missing(&expected));

    Ok(())
}