nexcore_dataframe/
sort.rs1use crate::column::Column;
4use crate::dataframe::DataFrame;
5use crate::error::DataFrameError;
6use crate::scalar::Scalar;
7
8impl DataFrame {
9 pub fn sort(&self, by: &str, descending: bool) -> Result<Self, DataFrameError> {
11 let col = self.column(by)?;
12 let mut indices: Vec<usize> = (0..self.height()).collect();
13
14 indices.sort_by(|&a, &b| {
15 let va = col.get(a).unwrap_or(Scalar::Null);
16 let vb = col.get(b).unwrap_or(Scalar::Null);
17 let ord = va.compare(&vb);
18 if descending { ord.reverse() } else { ord }
19 });
20
21 let columns: Vec<Column> = self.columns().iter().map(|c| c.take(&indices)).collect();
22 Ok(Self::from_columns_unchecked(columns))
23 }
24
25 #[must_use]
27 pub fn head(&self, n: usize) -> Self {
28 let take = n.min(self.height());
29 let indices: Vec<usize> = (0..take).collect();
30 let columns: Vec<Column> = self.columns().iter().map(|c| c.take(&indices)).collect();
31 Self::from_columns_unchecked(columns)
32 }
33
34 #[must_use]
36 pub fn tail(&self, n: usize) -> Self {
37 let take = n.min(self.height());
38 let start = self.height().saturating_sub(take);
39 let indices: Vec<usize> = (start..self.height()).collect();
40 let columns: Vec<Column> = self.columns().iter().map(|c| c.take(&indices)).collect();
41 Self::from_columns_unchecked(columns)
42 }
43}
44
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads every value of column `name` in row order.
    fn values(df: &DataFrame, name: &str) -> Vec<Option<Scalar>> {
        let col = df.column(name).unwrap_or_else(|_| unreachable!());
        (0..df.height()).map(|row| col.get(row)).collect()
    }

    /// Expected `values` output for a column of non-null `i64`s.
    fn ints(xs: &[i64]) -> Vec<Option<Scalar>> {
        xs.iter().map(|&x| Some(Scalar::Int64(x))).collect()
    }

    /// Expected `values` output for a column of non-null strings.
    fn strs(xs: &[&str]) -> Vec<Option<Scalar>> {
        xs.iter().map(|&s| Some(Scalar::String(s.into()))).collect()
    }

    #[test]
    fn sort_ascending() {
        let df = DataFrame::new(vec![
            Column::from_strs("name", &["c", "a", "b"]),
            Column::from_i64s("val", vec![3, 1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let sorted = df.sort("val", false).unwrap_or_else(|_| unreachable!());
        // Every column is permuted together, not just the sort key.
        assert_eq!(values(&sorted, "name"), strs(&["a", "b", "c"]));
        assert_eq!(values(&sorted, "val"), ints(&[1, 2, 3]));
    }

    #[test]
    fn sort_descending() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1, 3, 2])])
            .unwrap_or_else(|_| unreachable!());
        let sorted = df.sort("x", true).unwrap_or_else(|_| unreachable!());
        assert_eq!(values(&sorted, "x"), ints(&[3, 2, 1]));
    }

    #[test]
    fn sort_with_nulls() {
        let df = DataFrame::new(vec![Column::new_i64("x", vec![Some(2), None, Some(1)])])
            .unwrap_or_else(|_| unreachable!());
        let sorted = df.sort("x", false).unwrap_or_else(|_| unreachable!());
        assert_eq!(
            values(&sorted, "x"),
            vec![
                Some(Scalar::Int64(1)),
                Some(Scalar::Int64(2)),
                Some(Scalar::Null),
            ]
        );
    }

    #[test]
    fn sort_is_stable_in_both_directions() {
        let df = DataFrame::new(vec![
            Column::from_i64s("key", vec![1, 0, 1, 0]),
            Column::from_strs("tag", &["a", "b", "c", "d"]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let asc = df.sort("key", false).unwrap_or_else(|_| unreachable!());
        assert_eq!(values(&asc, "tag"), strs(&["b", "d", "a", "c"]));

        let desc = df.sort("key", true).unwrap_or_else(|_| unreachable!());
        assert_eq!(values(&desc, "tag"), strs(&["a", "c", "b", "d"]));
    }

    #[test]
    fn sort_unknown_column_is_error() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1, 2])])
            .unwrap_or_else(|_| unreachable!());
        assert!(df.sort("missing", false).is_err());
    }

    #[test]
    fn head_and_tail() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1, 2, 3, 4, 5])])
            .unwrap_or_else(|_| unreachable!());

        let h = df.head(3);
        assert_eq!(h.height(), 3);
        assert_eq!(values(&h, "x"), ints(&[1, 2, 3]));

        let t = df.tail(2);
        assert_eq!(t.height(), 2);
        assert_eq!(values(&t, "x"), ints(&[4, 5]));
    }

    #[test]
    fn head_exceeds_length() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1, 2])])
            .unwrap_or_else(|_| unreachable!());
        let h = df.head(100);
        assert_eq!(h.height(), 2);
        assert_eq!(values(&h, "x"), ints(&[1, 2]));
    }

    #[test]
    fn tail_exceeds_length() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1, 2])])
            .unwrap_or_else(|_| unreachable!());
        let t = df.tail(100);
        assert_eq!(t.height(), 2);
        assert_eq!(values(&t, "x"), ints(&[1, 2]));
    }

    #[test]
    fn head_and_tail_zero() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1, 2, 3])])
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(df.head(0).height(), 0);
        assert_eq!(df.tail(0).height(), 0);
    }
}