Skip to main content

plotlars_core/
data.rs

1use polars::{
2    frame::DataFrame,
3    prelude::{col, lit, DataType, IntoLazy},
4};
5
6use crate::io::PlotlarsError;
7
8#[doc(hidden)]
9pub fn get_unique_groups(
10    data: &DataFrame,
11    group_col: &str,
12    sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
13) -> Vec<String> {
14    let unique_groups = data
15        .column(group_col)
16        .unwrap()
17        .unique()
18        .unwrap()
19        .cast(&DataType::String)
20        .unwrap();
21
22    let mut groups: Vec<String> = unique_groups
23        .str()
24        .unwrap()
25        .iter()
26        .map(|x| x.unwrap().to_string())
27        .collect();
28
29    if let Some(sort_fn) = sort_groups_by {
30        groups.sort_by(|a, b| sort_fn(a, b));
31    } else {
32        groups.sort();
33    }
34
35    groups
36}
37
38pub(crate) fn filter_data_by_group(
39    data: &DataFrame,
40    group_col: &str,
41    group_name: &str,
42) -> DataFrame {
43    data.clone()
44        .lazy()
45        .filter(col(group_col).cast(DataType::String).eq(lit(group_name)))
46        .collect()
47        .unwrap()
48}
49
50pub(crate) fn get_numeric_column(data: &DataFrame, column_name: &str) -> Vec<Option<f32>> {
51    try_get_numeric_column(data, column_name).unwrap()
52}
53
54pub(crate) fn get_string_column(data: &DataFrame, column_name: &str) -> Vec<Option<String>> {
55    try_get_string_column(data, column_name).unwrap()
56}
57
58pub(crate) fn try_get_numeric_column(
59    data: &DataFrame,
60    column_name: &str,
61) -> Result<Vec<Option<f32>>, PlotlarsError> {
62    let column = data
63        .column(column_name)
64        .map_err(|_| PlotlarsError::ColumnNotFound {
65            column: column_name.to_string(),
66            available: data
67                .get_column_names()
68                .iter()
69                .map(|s| s.to_string())
70                .collect(),
71        })?;
72
73    let casted =
74        column
75            .clone()
76            .cast(&DataType::Float32)
77            .map_err(|_| PlotlarsError::TypeMismatch {
78                column: column_name.to_string(),
79                expected: "numeric".to_string(),
80                actual: column.dtype().to_string(),
81            })?;
82
83    Ok(casted.f32().unwrap().to_vec())
84}
85
86pub(crate) fn try_get_string_column(
87    data: &DataFrame,
88    column_name: &str,
89) -> Result<Vec<Option<String>>, PlotlarsError> {
90    let column = data
91        .column(column_name)
92        .map_err(|_| PlotlarsError::ColumnNotFound {
93            column: column_name.to_string(),
94            available: data
95                .get_column_names()
96                .iter()
97                .map(|s| s.to_string())
98                .collect(),
99        })?;
100
101    let casted =
102        column
103            .clone()
104            .cast(&DataType::String)
105            .map_err(|_| PlotlarsError::TypeMismatch {
106                column: column_name.to_string(),
107                expected: "string-castable".to_string(),
108                actual: column.dtype().to_string(),
109            })?;
110
111    Ok(casted
112        .str()
113        .unwrap()
114        .iter()
115        .map(|x| x.map(|s| s.to_string()))
116        .collect())
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use polars::prelude::*;

    #[test]
    fn test_get_unique_groups_sorted() {
        let df = df!["g" => ["b", "a", "c", "a"]].unwrap();
        let result = get_unique_groups(&df, "g", None);
        assert_eq!(result, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_get_unique_groups_custom_sort() {
        let df = df!["g" => ["b", "a", "c"]].unwrap();
        let result = get_unique_groups(&df, "g", Some(|a: &str, b: &str| b.cmp(a)));
        assert_eq!(result, vec!["c", "b", "a"]);
    }

    #[test]
    fn test_get_unique_groups_single_value() {
        let df = df!["g" => ["x", "x", "x"]].unwrap();
        let result = get_unique_groups(&df, "g", None);
        assert_eq!(result, vec!["x"]);
    }

    #[test]
    fn test_get_unique_groups_numeric_cast() {
        let df = df!["g" => [1i32, 2, 1]].unwrap();
        let result = get_unique_groups(&df, "g", None);
        assert_eq!(result, vec!["1", "2"]);
    }

    #[test]
    fn test_filter_matching() {
        let df = df!["g" => ["a", "b", "a"], "v" => [1, 2, 3]].unwrap();
        let filtered = filter_data_by_group(&df, "g", "a");
        assert_eq!(filtered.height(), 2);
    }

    #[test]
    fn test_filter_no_match() {
        let df = df!["g" => ["a", "b"], "v" => [1, 2]].unwrap();
        let filtered = filter_data_by_group(&df, "g", "z");
        assert_eq!(filtered.height(), 0);
    }

    #[test]
    fn test_filter_numeric_cast() {
        let df = df!["g" => [1i32, 2, 1], "v" => [10, 20, 30]].unwrap();
        let filtered = filter_data_by_group(&df, "g", "1");
        assert_eq!(filtered.height(), 2);
    }

    #[test]
    fn test_get_numeric_integers() {
        let df = df!["x" => [1i32, 2, 3]].unwrap();
        let result = get_numeric_column(&df, "x");
        assert_eq!(result, vec![Some(1.0f32), Some(2.0), Some(3.0)]);
    }

    #[test]
    fn test_get_numeric_with_nulls() {
        let s = Series::new("x".into(), &[Some(1.0f64), None, Some(3.0)]);
        let df = DataFrame::new(3, vec![s.into()]).unwrap();
        let result = get_numeric_column(&df, "x");
        // 1.0 and 3.0 are exactly representable in f32, so exact equality is safe.
        assert_eq!(result, vec![Some(1.0f32), None, Some(3.0)]);
    }

    #[test]
    fn test_get_numeric_floats() {
        let df = df!["x" => [1.5f64, 2.5]].unwrap();
        let result = get_numeric_column(&df, "x");
        assert!((result[0].unwrap() - 1.5).abs() < 0.01);
        assert!((result[1].unwrap() - 2.5).abs() < 0.01);
    }

    #[test]
    fn test_get_string_basic() {
        let df = df!["s" => ["a", "b"]].unwrap();
        let result = get_string_column(&df, "s");
        assert_eq!(result, vec![Some("a".to_string()), Some("b".to_string())]);
    }

    #[test]
    fn test_get_string_with_nulls() {
        let s = Series::new("s".into(), &[Some("a"), None::<&str>]);
        let df = DataFrame::new(2, vec![s.into()]).unwrap();
        let result = get_string_column(&df, "s");
        assert_eq!(result, vec![Some("a".to_string()), None]);
    }

    #[test]
    fn test_try_get_numeric_missing_column() {
        let df = df!["x" => [1i32, 2]].unwrap();
        let err = try_get_numeric_column(&df, "y").unwrap_err();
        assert!(matches!(
            err,
            PlotlarsError::ColumnNotFound { ref column, .. } if column == "y"
        ));
    }

    #[test]
    fn test_try_get_string_missing_column() {
        let df = df!["s" => ["a"]].unwrap();
        let err = try_get_string_column(&df, "t").unwrap_err();
        assert!(matches!(
            err,
            PlotlarsError::ColumnNotFound { ref column, .. } if column == "t"
        ));
    }
}