1use polars::{
2 frame::DataFrame,
3 prelude::{col, lit, DataType, IntoLazy},
4};
5
6use crate::io::PlotlarsError;
7
8#[doc(hidden)]
9pub fn get_unique_groups(
10 data: &DataFrame,
11 group_col: &str,
12 sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
13) -> Vec<String> {
14 let unique_groups = data
15 .column(group_col)
16 .unwrap()
17 .unique()
18 .unwrap()
19 .cast(&DataType::String)
20 .unwrap();
21
22 let mut groups: Vec<String> = unique_groups
23 .str()
24 .unwrap()
25 .iter()
26 .map(|x| x.unwrap().to_string())
27 .collect();
28
29 if let Some(sort_fn) = sort_groups_by {
30 groups.sort_by(|a, b| sort_fn(a, b));
31 } else {
32 groups.sort();
33 }
34
35 groups
36}
37
38pub(crate) fn filter_data_by_group(
39 data: &DataFrame,
40 group_col: &str,
41 group_name: &str,
42) -> DataFrame {
43 data.clone()
44 .lazy()
45 .filter(col(group_col).cast(DataType::String).eq(lit(group_name)))
46 .collect()
47 .unwrap()
48}
49
50pub(crate) fn get_numeric_column(data: &DataFrame, column_name: &str) -> Vec<Option<f32>> {
51 try_get_numeric_column(data, column_name).unwrap()
52}
53
54pub(crate) fn get_string_column(data: &DataFrame, column_name: &str) -> Vec<Option<String>> {
55 try_get_string_column(data, column_name).unwrap()
56}
57
58pub(crate) fn try_get_numeric_column(
59 data: &DataFrame,
60 column_name: &str,
61) -> Result<Vec<Option<f32>>, PlotlarsError> {
62 let column = data
63 .column(column_name)
64 .map_err(|_| PlotlarsError::ColumnNotFound {
65 column: column_name.to_string(),
66 available: data
67 .get_column_names()
68 .iter()
69 .map(|s| s.to_string())
70 .collect(),
71 })?;
72
73 let casted =
74 column
75 .clone()
76 .cast(&DataType::Float32)
77 .map_err(|_| PlotlarsError::TypeMismatch {
78 column: column_name.to_string(),
79 expected: "numeric".to_string(),
80 actual: column.dtype().to_string(),
81 })?;
82
83 Ok(casted.f32().unwrap().to_vec())
84}
85
86pub(crate) fn try_get_string_column(
87 data: &DataFrame,
88 column_name: &str,
89) -> Result<Vec<Option<String>>, PlotlarsError> {
90 let column = data
91 .column(column_name)
92 .map_err(|_| PlotlarsError::ColumnNotFound {
93 column: column_name.to_string(),
94 available: data
95 .get_column_names()
96 .iter()
97 .map(|s| s.to_string())
98 .collect(),
99 })?;
100
101 let casted =
102 column
103 .clone()
104 .cast(&DataType::String)
105 .map_err(|_| PlotlarsError::TypeMismatch {
106 column: column_name.to_string(),
107 expected: "string-castable".to_string(),
108 actual: column.dtype().to_string(),
109 })?;
110
111 Ok(casted
112 .str()
113 .unwrap()
114 .iter()
115 .map(|x| x.map(|s| s.to_string()))
116 .collect())
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    // --- get_unique_groups -------------------------------------------------

    /// Without a comparator the groups come back de-duplicated and in
    /// lexicographic order.
    #[test]
    fn test_get_unique_groups_sorted() {
        let frame = df!("g" => ["b", "a", "c", "a"]).unwrap();
        let groups = get_unique_groups(&frame, "g", None);
        assert_eq!(groups, ["a", "b", "c"]);
    }

    /// A caller-supplied comparator decides the order of the groups.
    #[test]
    fn test_get_unique_groups_custom_sort() {
        let reverse_order: fn(&str, &str) -> std::cmp::Ordering = |a, b| b.cmp(a);
        let frame = df!("g" => ["b", "a", "c"]).unwrap();
        let groups = get_unique_groups(&frame, "g", Some(reverse_order));
        assert_eq!(groups, ["c", "b", "a"]);
    }

    /// A column holding one repeated value yields exactly one group.
    #[test]
    fn test_get_unique_groups_single_value() {
        let frame = df!("g" => ["x", "x", "x"]).unwrap();
        let groups = get_unique_groups(&frame, "g", None);
        assert_eq!(groups, ["x"]);
    }

    /// Integer group columns are reported by their string form.
    #[test]
    fn test_get_unique_groups_numeric_cast() {
        let frame = df!("g" => [1i32, 2, 1]).unwrap();
        let groups = get_unique_groups(&frame, "g", None);
        assert_eq!(groups, ["1", "2"]);
    }

    // --- filter_data_by_group ----------------------------------------------

    /// Only rows whose group equals the requested name are kept.
    #[test]
    fn test_filter_matching() {
        let frame = df!("g" => ["a", "b", "a"], "v" => [1, 2, 3]).unwrap();
        let kept_rows = filter_data_by_group(&frame, "g", "a").height();
        assert_eq!(kept_rows, 2);
    }

    /// An unknown group name produces an empty frame.
    #[test]
    fn test_filter_no_match() {
        let frame = df!("g" => ["a", "b"], "v" => [1, 2]).unwrap();
        let kept_rows = filter_data_by_group(&frame, "g", "z").height();
        assert_eq!(kept_rows, 0);
    }

    /// Integer group columns are matched through their string form.
    #[test]
    fn test_filter_numeric_cast() {
        let frame = df!("g" => [1i32, 2, 1], "v" => [10, 20, 30]).unwrap();
        let kept_rows = filter_data_by_group(&frame, "g", "1").height();
        assert_eq!(kept_rows, 2);
    }

    // --- get_numeric_column ------------------------------------------------

    /// Integer columns are converted to `f32` values.
    #[test]
    fn test_get_numeric_integers() {
        let frame = df!("x" => [1i32, 2, 3]).unwrap();
        let values = get_numeric_column(&frame, "x");
        let expected: Vec<Option<f32>> = vec![Some(1.0), Some(2.0), Some(3.0)];
        assert_eq!(values, expected);
    }

    /// Null entries stay in place and come back as `None`.
    #[test]
    fn test_get_numeric_with_nulls() {
        let series = Series::new("x".into(), &[Some(1.0f64), None, Some(3.0)]);
        let frame = DataFrame::new(3, vec![series.into()]).unwrap();
        let values = get_numeric_column(&frame, "x");

        // One flag per entry: checks the length and the null position at once.
        let is_present: Vec<bool> = values.iter().map(Option::is_some).collect();
        assert_eq!(is_present, [true, false, true]);
    }

    /// `f64` columns are narrowed to `f32` within a small tolerance.
    #[test]
    fn test_get_numeric_floats() {
        let frame = df!("x" => [1.5f64, 2.5]).unwrap();
        let values = get_numeric_column(&frame, "x");

        let first = values[0].unwrap();
        let second = values[1].unwrap();
        assert!((first - 1.5).abs() < 0.01);
        assert!((second - 2.5).abs() < 0.01);
    }

    // --- get_string_column -------------------------------------------------

    /// String columns are returned as owned strings.
    #[test]
    fn test_get_string_basic() {
        let frame = df!("s" => ["a", "b"]).unwrap();
        let values = get_string_column(&frame, "s");
        let expected: Vec<Option<String>> =
            ["a", "b"].iter().map(|s| Some(s.to_string())).collect();
        assert_eq!(values, expected);
    }

    /// Null entries stay in place and come back as `None`.
    #[test]
    fn test_get_string_with_nulls() {
        let series = Series::new("s".into(), &[Some("a"), None::<&str>]);
        let frame = DataFrame::new(2, vec![series.into()]).unwrap();
        let values = get_string_column(&frame, "s");
        assert_eq!(values, vec![Some("a".to_string()), None]);
    }
}