Skip to main content

nexcore_dataframe/
sort.rs

//! Sorting and row-slicing (`head` / `tail`) operations on DataFrames.
2
3use crate::column::Column;
4use crate::dataframe::DataFrame;
5use crate::error::DataFrameError;
6use crate::scalar::Scalar;
7
8impl DataFrame {
9    /// Sort by a single column. Nulls sort last.
10    pub fn sort(&self, by: &str, descending: bool) -> Result<Self, DataFrameError> {
11        let col = self.column(by)?;
12        let mut indices: Vec<usize> = (0..self.height()).collect();
13
14        indices.sort_by(|&a, &b| {
15            let va = col.get(a).unwrap_or(Scalar::Null);
16            let vb = col.get(b).unwrap_or(Scalar::Null);
17            let ord = va.compare(&vb);
18            if descending { ord.reverse() } else { ord }
19        });
20
21        let columns: Vec<Column> = self.columns().iter().map(|c| c.take(&indices)).collect();
22        Ok(Self::from_columns_unchecked(columns))
23    }
24
25    /// Take the first n rows.
26    #[must_use]
27    pub fn head(&self, n: usize) -> Self {
28        let take = n.min(self.height());
29        let indices: Vec<usize> = (0..take).collect();
30        let columns: Vec<Column> = self.columns().iter().map(|c| c.take(&indices)).collect();
31        Self::from_columns_unchecked(columns)
32    }
33
34    /// Take the last n rows.
35    #[must_use]
36    pub fn tail(&self, n: usize) -> Self {
37        let take = n.min(self.height());
38        let start = self.height().saturating_sub(take);
39        let indices: Vec<usize> = (start..self.height()).collect();
40        let columns: Vec<Column> = self.columns().iter().map(|c| c.take(&indices)).collect();
41        Self::from_columns_unchecked(columns)
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a frame from `columns`; the fixtures in this module are expected
    /// to always be accepted by `DataFrame::new`.
    fn frame(columns: Vec<Column>) -> DataFrame {
        DataFrame::new(columns).unwrap_or_else(|_| unreachable!())
    }

    /// Read the value stored at `row` in column `name` of `df`.
    fn cell(df: &DataFrame, name: &str, row: usize) -> Option<Scalar> {
        df.column(name).unwrap_or_else(|_| unreachable!()).get(row)
    }

    #[test]
    fn sort_ascending() {
        let df = frame(vec![
            Column::from_strs("name", &["c", "a", "b"]),
            Column::from_i64s("val", vec![3, 1, 2]),
        ]);

        let sorted = df.sort("val", false).unwrap_or_else(|_| unreachable!());

        // Sorting by `val` must carry the `name` column along with it.
        assert_eq!(cell(&sorted, "name", 0), Some(Scalar::String("a".into())));
        assert_eq!(cell(&sorted, "name", 2), Some(Scalar::String("c".into())));
    }

    #[test]
    fn sort_descending() {
        let df = frame(vec![Column::from_i64s("x", vec![1, 3, 2])]);

        let sorted = df.sort("x", true).unwrap_or_else(|_| unreachable!());

        assert_eq!(cell(&sorted, "x", 0), Some(Scalar::Int64(3)));
        assert_eq!(cell(&sorted, "x", 2), Some(Scalar::Int64(1)));
    }

    #[test]
    fn sort_with_nulls() {
        let df = frame(vec![Column::new_i64("x", vec![Some(2), None, Some(1)])]);

        let sorted = df.sort("x", false).unwrap_or_else(|_| unreachable!());

        // Ascending order with the null pushed to the end.
        let expected = vec![Scalar::Int64(1), Scalar::Int64(2), Scalar::Null];
        for (row, want) in expected.into_iter().enumerate() {
            assert_eq!(cell(&sorted, "x", row), Some(want));
        }
    }

    #[test]
    fn head_and_tail() {
        let df = frame(vec![Column::from_i64s("x", vec![1, 2, 3, 4, 5])]);

        let first_three = df.head(3);
        assert_eq!(first_three.height(), 3);
        assert_eq!(cell(&first_three, "x", 0), Some(Scalar::Int64(1)));

        let last_two = df.tail(2);
        assert_eq!(last_two.height(), 2);
        assert_eq!(cell(&last_two, "x", 0), Some(Scalar::Int64(4)));
    }

    #[test]
    fn head_exceeds_length() {
        let df = frame(vec![Column::from_i64s("x", vec![1, 2])]);

        // Asking for more rows than exist returns the whole frame.
        assert_eq!(df.head(100).height(), 2);
    }
}