1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
use polars_core::prelude::*;
use polars_lazy::prelude::*;
/// Bins the values of `s` into half-open intervals delimited by `bins`.
///
/// A lookup frame is built with one row per bin: its upper break point and a
/// categorical label. One extra bin with an upper break point of `+inf` is
/// always appended, so `bins.len() + 1` bins exist in total. The input values
/// are cast to `Float64`, sorted, and matched to the first break point that is
/// greater than or equal to the value via a forward as-of join.
///
/// # Arguments
///
/// * `s` - values to bin; its name is used as the join key and output column.
/// * `bins` - upper break points of the bins. NOTE(review): these are not
///   validated or sorted here; the as-of join presumably expects them in
///   ascending order — confirm against the polars `join_asof` contract.
/// * `labels` - optional label per bin; must contain `bins.len() + 1` entries.
///   When `None`, labels of the form `"(lower, upper]"` are generated, with
///   `-inf` as the lower bound of the first bin.
/// * `break_point_label` - name of the break point column
///   (defaults to `"break_point"`).
/// * `category_label` - name of the category column (defaults to `"category"`).
///
/// # Returns
///
/// A `DataFrame` holding the sorted values joined with the break point and
/// category columns of the lookup frame.
///
/// # Errors
///
/// Returns `PolarsError::ShapeMisMatch` when `labels` is given but its length
/// is not `bins.len() + 1`. Any error raised while building, casting or
/// joining the frames is propagated unchanged.
pub fn cut(
    s: Series,
    bins: Vec<f32>,
    labels: Option<Vec<&str>>,
    break_point_label: Option<&str>,
    category_label: Option<&str>,
) -> PolarsResult<DataFrame> {
    let var_name = s.name();
    let breakpoint_str = break_point_label.unwrap_or("break_point");
    let category_str = category_label.unwrap_or("category");

    // The appended +inf break point makes the last bin catch every value
    // above the largest user-supplied bin.
    let cuts_df = df![
        breakpoint_str => Series::new(breakpoint_str, &bins)
            .extend_constant(AnyValue::Float64(f64::INFINITY), 1)?
    ]?;

    let cuts_df = if let Some(labels) = labels {
        // One label is required per bin, including the trailing +inf bin.
        let expected_labels = bins.len() + 1;
        if labels.len() != expected_labels {
            return Err(PolarsError::ShapeMisMatch(
                format!(
                    "Labels count ({}) must equal bins count + 1 ({})",
                    labels.len(),
                    expected_labels
                )
                .into(),
            ));
        }
        cuts_df
            .lazy()
            .with_column(lit(Series::new(category_str, labels)))
    } else {
        // Default labels: "(previous break point, break point]", where the
        // first bin has no predecessor and therefore starts at -inf.
        cuts_df.lazy().with_column(
            format_str(
                "({}, {}]",
                [
                    col(breakpoint_str).shift_and_fill(1, lit(f64::NEG_INFINITY)),
                    col(breakpoint_str),
                ],
            )?
            .alias(category_str),
        )
    }
    .collect()?;

    // Store the labels as a categorical column.
    let cuts = cuts_df
        .lazy()
        .with_columns([col(category_str).cast(DataType::Categorical(None))])
        .collect()?;

    // Forward strategy: each value is matched with the first break point that
    // is >= the value, i.e. the upper edge of the bin containing it.
    // NOTE(review): `sort(false)` is presumably an ascending sort — confirm.
    s.cast(&DataType::Float64)?
        .sort(false)
        .into_frame()
        .join_asof(
            &cuts,
            var_name,
            breakpoint_str,
            AsofStrategy::Forward,
            None,
            None,
        )
}
// Verifies `cut` with default column names and generated interval labels:
// twelve values spanning -3.0..=2.5 are binned against the break points
// -1.0 and 1.0 (plus the implicit +inf bin).
#[test]
fn test_cut() -> PolarsResult<()> {
    // Evenly spaced inputs: -3.0, -2.5, ..., 2.5.
    let input: Vec<f32> = (0..12).map(|step| -3.0 + step as f32 * 0.5).collect();
    let binned = cut(Series::new("a", input), vec![-1.0, 1.0], None, None, None)?;

    // Five values fall in (-inf, -1.0], four in (-1.0, 1.0], three in (1.0, inf].
    let break_points: Vec<f64> =
        [vec![-1.0; 5], vec![1.0; 4], vec![f64::INFINITY; 3]].concat();
    let categories: Vec<&str> = [
        vec!["(-inf, -1.0]"; 5],
        vec!["(-1.0, 1.0]"; 4],
        vec!["(1.0, inf]"; 3],
    ]
    .concat();

    let expected = df!(
        "a" => [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5],
        "break_point" => break_points,
        "category" => categories
    )?;

    assert!(binned.frame_equal_missing(&expected));
    Ok(())
}