polars_rows_iter/iter_from_column/
iter_from_column_series.rs

1use super::*;
2use iter_from_column_trait::IterFromColumn;
3use polars::prelude::*;
4
5impl<'a> IterFromColumn<'a> for Series {
6    type RawInner = Series;
7    fn create_iter(column: &'a Column) -> PolarsResult<Box<dyn Iterator<Item = Option<Series>> + 'a>> {
8        create_iter(column)
9    }
10
11    #[inline]
12    fn get_value(polars_value: Option<Series>, column_name: &str, _dtype: &DataType) -> PolarsResult<Self>
13    where
14        Self: Sized,
15    {
16        polars_value.ok_or_else(|| <&'a str as IterFromColumn<'a>>::unexpected_null_value_error(column_name))
17    }
18}
19
20impl<'a> IterFromColumn<'a> for Option<Series> {
21    type RawInner = Series;
22    fn create_iter(column: &'a Column) -> PolarsResult<Box<dyn Iterator<Item = Option<Series>> + 'a>> {
23        create_iter(column)
24    }
25
26    #[inline]
27    fn get_value(polars_value: Option<Series>, _column_name: &str, _dtype: &DataType) -> PolarsResult<Self>
28    where
29        Self: Sized,
30    {
31        Ok(polars_value)
32    }
33}
34
35pub fn create_iter<'a>(column: &'a Column) -> PolarsResult<Box<dyn Iterator<Item = Option<Series>> + 'a>> {
36    let iter: Box<dyn Iterator<Item = Option<Series>>> = match column.dtype() {
37        // remove Box::new(), when https://github.com/rust-lang/rust/issues/65991 is stable
38        DataType::List(_) => Box::new(column.list()?.into_iter()),
39        dtype => {
40            let column_name = column.name().as_str();
41            return Err(
42                polars_err!(SchemaMismatch: "Cannot get Series from column '{column_name}' with dtype : {dtype}"),
43            );
44        }
45    };
46
47    Ok(iter)
48}
49
// NOTE(review): both tests below read `&str` / `Option<&str>` fields; nothing in
// this module exercises the `Series` implementations defined in this file.
// Confirm that coverage exists elsewhere.
#[cfg(test)]
mod tests {
    use crate::*;
    use itertools::{izip, Itertools};
    use polars::prelude::*;
    use rand::{rngs::StdRng, SeedableRng};
    use shared_test_helpers::*;

    /// Number of rows generated for every test column.
    const ROW_COUNT: usize = 64;

    /// Rows produced by `rows_iter` must equal the values read directly from
    /// a pair of string columns (one created with the flag `false`, one with `true`).
    #[test]
    fn str_test() {
        #[derive(Debug, FromDataFrameRow, PartialEq)]
        struct TestRow<'a> {
            col: &'a str,
            col_opt: Option<&'a str>,
        }

        // Fixed seed keeps the generated columns reproducible.
        let mut rng = StdRng::seed_from_u64(0);
        let dtype = DataType::String;

        let col = create_column("col", dtype.clone(), false, ROW_COUNT, &mut rng);
        let col_opt = create_column("col_opt", dtype, true, ROW_COUNT, &mut rng);

        // Take owned copies of the values before the columns move into the frame.
        let col_values = col
            .str()
            .unwrap()
            .iter()
            .map(|value| value.unwrap().to_owned())
            .collect_vec();
        let col_opt_values = col_opt
            .str()
            .unwrap()
            .iter()
            .map(|value| value.map(str::to_owned))
            .collect_vec();

        let df = DataFrame::new(vec![col, col_opt]).unwrap();

        let expected_rows = izip!(col_values.iter(), col_opt_values.iter())
            .map(|(col, col_opt)| TestRow {
                col: col.as_str(),
                col_opt: col_opt.as_deref(),
            })
            .collect_vec();

        let rows = df
            .rows_iter::<TestRow>()
            .unwrap()
            .map(|row| row.unwrap())
            .collect_vec();

        assert_eq!(rows, expected_rows);
    }

    /// Same check as `str_test`, but the columns use a categorical dtype and
    /// the reference values are read through `iter_str`.
    #[cfg(feature = "dtype-categorical")]
    #[test]
    fn cat_test() {
        #[derive(Debug, FromDataFrameRow, PartialEq)]
        struct TestRow<'a> {
            col: &'a str,
            col_opt: Option<&'a str>,
        }

        // Fixed seed keeps the generated columns reproducible.
        let mut rng = StdRng::seed_from_u64(0);

        let cats = Categories::new(PlSmallStr::EMPTY, PlSmallStr::EMPTY, CategoricalPhysical::U32);
        let dtype = DataType::from_categories(cats);

        let col = create_column("col", dtype.clone(), false, ROW_COUNT, &mut rng);
        let col_opt = create_column("col_opt", dtype, true, ROW_COUNT, &mut rng);

        // Take owned copies of the values before the columns move into the frame.
        let col_values = col
            .cat::<Categorical32Type>()
            .unwrap()
            .iter_str()
            .map(|value| value.unwrap().to_owned())
            .collect_vec();
        let col_opt_values = col_opt
            .cat::<Categorical32Type>()
            .unwrap()
            .iter_str()
            .map(|value| value.map(str::to_owned))
            .collect_vec();

        let df = DataFrame::new(vec![col, col_opt]).unwrap();

        let expected_rows = izip!(col_values.iter(), col_opt_values.iter())
            .map(|(col, col_opt)| TestRow {
                col: col.as_str(),
                col_opt: col_opt.as_deref(),
            })
            .collect_vec();

        let rows = df
            .rows_iter::<TestRow>()
            .unwrap()
            .map(|row| row.unwrap())
            .collect_vec();

        assert_eq!(rows, expected_rows);
    }
}