alopex_dataframe/dataframe/
dataframe.rs1use std::collections::HashSet;
2use std::sync::Arc;
3
4use arrow::datatypes::{Field, Schema, SchemaRef};
5use arrow::record_batch::RecordBatch;
6
7use crate::ops::{FillNull, JoinKeys, JoinType, SortOptions};
8use crate::{DataFrameError, Expr, Result, Series};
9
/// An eager, columnar table backed by Arrow `RecordBatch`es.
///
/// All batches share the same `schema`; this invariant is enforced by the
/// constructors (`new` builds a single batch, `from_batches` verifies schema
/// equality across batches).
#[derive(Debug, Clone)]
pub struct DataFrame {
    // Schema every batch in `batches` conforms to.
    schema: SchemaRef,
    // Row data, split into zero or more schema-identical batches.
    batches: Vec<RecordBatch>,
}
16
17impl DataFrame {
18 pub fn new(columns: Vec<Series>) -> Result<Self> {
22 if columns.is_empty() {
23 return Ok(Self::empty());
24 }
25
26 let mut seen_names = HashSet::with_capacity(columns.len());
27 for c in &columns {
28 if !seen_names.insert(c.name().to_string()) {
29 return Err(DataFrameError::schema_mismatch(format!(
30 "duplicate column name '{}'",
31 c.name()
32 )));
33 }
34 }
35
36 let expected_len = columns[0].len();
37 for c in &columns[1..] {
38 if c.len() != expected_len {
39 return Err(DataFrameError::schema_mismatch(format!(
40 "column length mismatch: '{}' has length {}, expected {}",
41 c.name(),
42 c.len(),
43 expected_len
44 )));
45 }
46 }
47
48 let fields: Vec<Field> = columns
49 .iter()
50 .map(|c| Field::new(c.name(), c.dtype(), true))
51 .collect();
52 let schema: SchemaRef = Arc::new(Schema::new(fields));
53
54 let arrays = columns
55 .iter()
56 .map(|c| {
57 if c.chunks().is_empty() {
58 Ok(arrow::array::new_empty_array(&c.dtype()))
59 } else if c.chunks().len() == 1 {
60 Ok(c.chunks()[0].clone())
61 } else {
62 let arrays = c
63 .chunks()
64 .iter()
65 .map(|a| a.as_ref() as &dyn arrow::array::Array)
66 .collect::<Vec<_>>();
67 arrow::compute::concat(&arrays)
68 .map_err(|source| DataFrameError::Arrow { source })
69 }
70 })
71 .collect::<Result<Vec<_>>>()?;
72
73 let batch = RecordBatch::try_new(schema.clone(), arrays).map_err(|e| {
74 DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
75 })?;
76
77 Ok(Self {
78 schema,
79 batches: vec![batch],
80 })
81 }
82
83 pub fn from_batches(batches: Vec<RecordBatch>) -> Result<Self> {
85 if batches.is_empty() {
86 return Ok(Self::empty());
87 }
88
89 let schema = batches[0].schema();
90 for (i, b) in batches.iter().enumerate().skip(1) {
91 if b.schema().as_ref() != schema.as_ref() {
92 return Err(DataFrameError::schema_mismatch(format!(
93 "schema mismatch between batches: batch 0 != batch {i}"
94 )));
95 }
96 }
97
98 Ok(Self { schema, batches })
99 }
100
101 pub fn from_series(series: Vec<Series>) -> Result<Self> {
103 Self::new(series)
104 }
105
106 pub fn empty() -> Self {
108 Self {
109 schema: Arc::new(Schema::empty()),
110 batches: Vec::new(),
111 }
112 }
113
114 pub fn height(&self) -> usize {
116 self.batches.iter().map(|b| b.num_rows()).sum()
117 }
118
119 pub fn width(&self) -> usize {
121 self.schema.fields().len()
122 }
123
124 pub fn schema(&self) -> SchemaRef {
126 self.schema.clone()
127 }
128
129 pub fn column(&self, name: &str) -> Result<Series> {
131 let idx = self
132 .schema
133 .fields()
134 .iter()
135 .position(|f| f.name() == name)
136 .ok_or_else(|| DataFrameError::column_not_found(name.to_string()))?;
137
138 let chunks = self
139 .batches
140 .iter()
141 .map(|b| b.column(idx).clone())
142 .collect::<Vec<_>>();
143 Ok(Series::from_arrow_unchecked(name, chunks))
144 }
145
146 pub fn columns(&self) -> Vec<Series> {
148 self.schema
149 .fields()
150 .iter()
151 .enumerate()
152 .map(|(idx, f)| {
153 let chunks = self
154 .batches
155 .iter()
156 .map(|b| b.column(idx).clone())
157 .collect::<Vec<_>>();
158 Series::from_arrow_unchecked(f.name(), chunks)
159 })
160 .collect()
161 }
162
163 pub fn to_arrow(&self) -> Vec<RecordBatch> {
165 self.batches.clone()
166 }
167
168 pub fn lazy(&self) -> crate::LazyFrame {
170 crate::LazyFrame::from_dataframe(self.clone())
171 }
172
173 pub fn select(&self, exprs: Vec<Expr>) -> Result<Self> {
175 self.clone().lazy().select(exprs).collect()
176 }
177
178 pub fn filter(&self, predicate: Expr) -> Result<Self> {
180 self.clone().lazy().filter(predicate).collect()
181 }
182
183 pub fn with_columns(&self, exprs: Vec<Expr>) -> Result<Self> {
185 self.clone().lazy().with_columns(exprs).collect()
186 }
187
188 pub fn group_by(&self, by: Vec<Expr>) -> GroupBy {
190 GroupBy {
191 df: self.clone(),
192 by,
193 }
194 }
195
196 pub fn join<K: Into<JoinKeys>>(
198 &self,
199 other: &DataFrame,
200 keys: K,
201 how: JoinType,
202 ) -> Result<Self> {
203 self.clone()
204 .lazy()
205 .join(other.clone().lazy(), keys, how)
206 .collect()
207 }
208
209 pub fn sort(&self, by: Vec<String>, descending: Vec<bool>) -> Result<Self> {
211 let options = SortOptions {
212 by,
213 descending,
214 nulls_last: true,
215 stable: true,
216 };
217 self.clone().lazy().sort(options).collect()
218 }
219
220 pub fn head(&self, n: usize) -> Result<Self> {
222 self.clone().lazy().head(n).collect()
223 }
224
225 pub fn tail(&self, n: usize) -> Result<Self> {
227 self.clone().lazy().tail(n).collect()
228 }
229
230 pub fn unique(&self, subset: Option<Vec<String>>) -> Result<Self> {
232 self.clone().lazy().unique(subset).collect()
233 }
234
235 pub fn fill_null<T: Into<FillNull>>(&self, fill: T) -> Result<Self> {
237 self.clone().lazy().fill_null(fill).collect()
238 }
239
240 pub fn drop_nulls(&self, subset: Option<Vec<String>>) -> Result<Self> {
242 self.clone().lazy().drop_nulls(subset).collect()
243 }
244
245 pub fn null_count(&self) -> Result<Self> {
247 self.clone().lazy().null_count().collect()
248 }
249}
250
/// A deferred group-by: holds the source frame and grouping expressions
/// until an aggregation is requested via `agg`.
#[derive(Debug, Clone)]
pub struct GroupBy {
    // Frame whose rows will be grouped.
    df: DataFrame,
    // Expressions defining the grouping keys.
    by: Vec<Expr>,
}
257
258impl GroupBy {
259 pub fn agg(self, aggs: Vec<Expr>) -> Result<DataFrame> {
261 self.df.lazy().group_by(self.by).agg(aggs).collect()
262 }
263
264 pub fn into_df(self) -> DataFrame {
266 self.df
267 }
268}
269
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    use super::DataFrame;
    use crate::{DataFrameError, Series};

    /// Builds an i32 `Series` called `name` with the given chunk layout.
    fn s_i32(name: &str, chunks: Vec<Vec<i32>>) -> Series {
        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(chunks.len());
        for values in chunks {
            arrays.push(Arc::new(Int32Array::from(values)) as ArrayRef);
        }
        Series::from_arrow(name, arrays).unwrap()
    }

    #[test]
    fn dataframe_new_accepts_misaligned_chunks_by_normalizing() {
        // Chunk boundaries differ between the two columns; `new` should
        // concatenate them into a single aligned batch.
        let left = s_i32("a", vec![vec![1, 2], vec![3]]);
        let right = s_i32("b", vec![vec![10], vec![20, 30]]);

        let df = DataFrame::new(vec![left, right]).unwrap();
        assert_eq!(df.height(), 3);
        assert_eq!(df.width(), 2);

        let schema = df.schema();
        assert_eq!(schema.fields()[0].name(), "a");
        assert_eq!(schema.fields()[1].name(), "b");

        let batches = df.to_arrow();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 3);
    }

    #[test]
    fn dataframe_new_rejects_duplicate_column_names() {
        let first = s_i32("a", vec![vec![1]]);
        let second = s_i32("a", vec![vec![2]]);
        let err = DataFrame::new(vec![first, second]).unwrap_err();
        assert!(matches!(err, DataFrameError::SchemaMismatch { .. }));
    }

    #[test]
    fn dataframe_new_rejects_length_mismatch() {
        let longer = s_i32("a", vec![vec![1, 2]]);
        let shorter = s_i32("b", vec![vec![10]]);
        let err = DataFrame::new(vec![longer, shorter]).unwrap_err();
        assert!(matches!(err, DataFrameError::SchemaMismatch { .. }));
    }

    #[test]
    fn dataframe_new_accepts_different_chunk_counts() {
        let many_chunks = s_i32("a", vec![vec![1], vec![2], vec![3]]);
        let one_chunk = s_i32("b", vec![vec![10, 20, 30]]);
        let df = DataFrame::new(vec![many_chunks, one_chunk]).unwrap();
        assert_eq!(df.height(), 3);
        assert_eq!(df.to_arrow().len(), 1);
    }

    #[test]
    fn dataframe_column_is_case_sensitive() {
        let df = DataFrame::new(vec![s_i32("a", vec![vec![1]])]).unwrap();
        let err = df.column("A").unwrap_err();
        assert!(matches!(err, DataFrameError::ColumnNotFound { .. }));
    }

    #[test]
    fn dataframe_from_batches_rejects_schema_mismatch() {
        let int_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let str_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));

        let int_col: ArrayRef = Arc::new(Int32Array::from(vec![1]));
        let str_col: ArrayRef = Arc::new(StringArray::from(vec!["x"]));

        let int_batch = RecordBatch::try_new(int_schema, vec![int_col]).unwrap();
        let str_batch = RecordBatch::try_new(str_schema, vec![str_col]).unwrap();

        let err = DataFrame::from_batches(vec![int_batch, str_batch]).unwrap_err();
        assert!(matches!(err, DataFrameError::SchemaMismatch { .. }));
    }

    #[test]
    fn dataframe_columns_preserves_schema_order() {
        let a = s_i32("a", vec![vec![1], vec![2]]);
        let b = s_i32("b", vec![vec![10], vec![20]]);
        // Insert "b" first so schema order differs from insertion-name order.
        let df = DataFrame::new(vec![b.clone(), a.clone()]).unwrap();

        let cols = df.columns();
        assert_eq!(cols[0].name(), "b");
        assert_eq!(cols[1].name(), "a");
        assert!(cols.iter().all(|c| c.len() == 2));
    }
}