//! Eager `DataFrame` backed by Arrow record batches (alopex_dataframe/dataframe/dataframe.rs).

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use arrow::datatypes::{Field, Schema, SchemaRef};
5use arrow::record_batch::RecordBatch;
6
7use crate::ops::{FillNull, JoinKeys, JoinType, SortOptions};
8use crate::{DataFrameError, Expr, Result, Series};
9
/// An eager table backed by one or more Arrow `RecordBatch` values.
#[derive(Debug, Clone)]
pub struct DataFrame {
    // Shared schema; `from_batches` enforces that every batch conforms to it.
    schema: SchemaRef,
    // Row data, possibly split across several batches; total rows = sum of batch rows.
    batches: Vec<RecordBatch>,
}
16
17impl DataFrame {
18    /// Construct a `DataFrame` from a list of `Series`.
19    ///
20    /// Chunk boundaries do not need to align across series as long as total lengths match.
21    pub fn new(columns: Vec<Series>) -> Result<Self> {
22        if columns.is_empty() {
23            return Ok(Self::empty());
24        }
25
26        let mut seen_names = HashSet::with_capacity(columns.len());
27        for c in &columns {
28            if !seen_names.insert(c.name().to_string()) {
29                return Err(DataFrameError::schema_mismatch(format!(
30                    "duplicate column name '{}'",
31                    c.name()
32                )));
33            }
34        }
35
36        let expected_len = columns[0].len();
37        for c in &columns[1..] {
38            if c.len() != expected_len {
39                return Err(DataFrameError::schema_mismatch(format!(
40                    "column length mismatch: '{}' has length {}, expected {}",
41                    c.name(),
42                    c.len(),
43                    expected_len
44                )));
45            }
46        }
47
48        let fields: Vec<Field> = columns
49            .iter()
50            .map(|c| Field::new(c.name(), c.dtype(), true))
51            .collect();
52        let schema: SchemaRef = Arc::new(Schema::new(fields));
53
54        let arrays = columns
55            .iter()
56            .map(|c| {
57                if c.chunks().is_empty() {
58                    Ok(arrow::array::new_empty_array(&c.dtype()))
59                } else if c.chunks().len() == 1 {
60                    Ok(c.chunks()[0].clone())
61                } else {
62                    let arrays = c
63                        .chunks()
64                        .iter()
65                        .map(|a| a.as_ref() as &dyn arrow::array::Array)
66                        .collect::<Vec<_>>();
67                    arrow::compute::concat(&arrays)
68                        .map_err(|source| DataFrameError::Arrow { source })
69                }
70            })
71            .collect::<Result<Vec<_>>>()?;
72
73        let batch = RecordBatch::try_new(schema.clone(), arrays).map_err(|e| {
74            DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
75        })?;
76
77        Ok(Self {
78            schema,
79            batches: vec![batch],
80        })
81    }
82
83    /// Construct a `DataFrame` from Arrow record batches (all batches must share the same schema).
84    pub fn from_batches(batches: Vec<RecordBatch>) -> Result<Self> {
85        if batches.is_empty() {
86            return Ok(Self::empty());
87        }
88
89        let schema = batches[0].schema();
90        for (i, b) in batches.iter().enumerate().skip(1) {
91            if b.schema().as_ref() != schema.as_ref() {
92                return Err(DataFrameError::schema_mismatch(format!(
93                    "schema mismatch between batches: batch 0 != batch {i}"
94                )));
95            }
96        }
97
98        Ok(Self { schema, batches })
99    }
100
101    /// Alias for `DataFrame::new`.
102    pub fn from_series(series: Vec<Series>) -> Result<Self> {
103        Self::new(series)
104    }
105
106    /// Return an empty `DataFrame` (no columns, no rows).
107    pub fn empty() -> Self {
108        Self {
109            schema: Arc::new(Schema::empty()),
110            batches: Vec::new(),
111        }
112    }
113
114    /// Return the number of rows.
115    pub fn height(&self) -> usize {
116        self.batches.iter().map(|b| b.num_rows()).sum()
117    }
118
119    /// Return the number of columns.
120    pub fn width(&self) -> usize {
121        self.schema.fields().len()
122    }
123
124    /// Return the Arrow schema.
125    pub fn schema(&self) -> SchemaRef {
126        self.schema.clone()
127    }
128
129    /// Get a column by name (case-sensitive).
130    pub fn column(&self, name: &str) -> Result<Series> {
131        let idx = self
132            .schema
133            .fields()
134            .iter()
135            .position(|f| f.name() == name)
136            .ok_or_else(|| DataFrameError::column_not_found(name.to_string()))?;
137
138        let chunks = self
139            .batches
140            .iter()
141            .map(|b| b.column(idx).clone())
142            .collect::<Vec<_>>();
143        Ok(Series::from_arrow_unchecked(name, chunks))
144    }
145
146    /// Return all columns in construction order.
147    pub fn columns(&self) -> Vec<Series> {
148        self.schema
149            .fields()
150            .iter()
151            .enumerate()
152            .map(|(idx, f)| {
153                let chunks = self
154                    .batches
155                    .iter()
156                    .map(|b| b.column(idx).clone())
157                    .collect::<Vec<_>>();
158                Series::from_arrow_unchecked(f.name(), chunks)
159            })
160            .collect()
161    }
162
163    /// Return the underlying Arrow batches.
164    pub fn to_arrow(&self) -> Vec<RecordBatch> {
165        self.batches.clone()
166    }
167
168    /// Convert this eager `DataFrame` to a `LazyFrame` for query planning/execution.
169    pub fn lazy(&self) -> crate::LazyFrame {
170        crate::LazyFrame::from_dataframe(self.clone())
171    }
172
173    /// Eager `select`, implemented by delegating to `LazyFrame`.
174    pub fn select(&self, exprs: Vec<Expr>) -> Result<Self> {
175        self.clone().lazy().select(exprs).collect()
176    }
177
178    /// Eager `filter`, implemented by delegating to `LazyFrame`.
179    pub fn filter(&self, predicate: Expr) -> Result<Self> {
180        self.clone().lazy().filter(predicate).collect()
181    }
182
183    /// Eager `with_columns`, implemented by delegating to `LazyFrame`.
184    pub fn with_columns(&self, exprs: Vec<Expr>) -> Result<Self> {
185        self.clone().lazy().with_columns(exprs).collect()
186    }
187
188    /// Start a group-by aggregation (eager API).
189    pub fn group_by(&self, by: Vec<Expr>) -> GroupBy {
190        GroupBy {
191            df: self.clone(),
192            by,
193        }
194    }
195
196    /// Join with another `DataFrame` using provided join keys.
197    pub fn join<K: Into<JoinKeys>>(
198        &self,
199        other: &DataFrame,
200        keys: K,
201        how: JoinType,
202    ) -> Result<Self> {
203        self.clone()
204            .lazy()
205            .join(other.clone().lazy(), keys, how)
206            .collect()
207    }
208
209    /// Sort by one or more columns.
210    pub fn sort(&self, by: Vec<String>, descending: Vec<bool>) -> Result<Self> {
211        let options = SortOptions {
212            by,
213            descending,
214            nulls_last: true,
215            stable: true,
216        };
217        self.clone().lazy().sort(options).collect()
218    }
219
220    /// Return the first `n` rows.
221    pub fn head(&self, n: usize) -> Result<Self> {
222        self.clone().lazy().head(n).collect()
223    }
224
225    /// Return the last `n` rows.
226    pub fn tail(&self, n: usize) -> Result<Self> {
227        self.clone().lazy().tail(n).collect()
228    }
229
230    /// Remove duplicate rows.
231    pub fn unique(&self, subset: Option<Vec<String>>) -> Result<Self> {
232        self.clone().lazy().unique(subset).collect()
233    }
234
235    /// Fill null values using a scalar or strategy.
236    pub fn fill_null<T: Into<FillNull>>(&self, fill: T) -> Result<Self> {
237        self.clone().lazy().fill_null(fill).collect()
238    }
239
240    /// Drop rows containing null values.
241    pub fn drop_nulls(&self, subset: Option<Vec<String>>) -> Result<Self> {
242        self.clone().lazy().drop_nulls(subset).collect()
243    }
244
245    /// Count null values per column.
246    pub fn null_count(&self) -> Result<Self> {
247        self.clone().lazy().null_count().collect()
248    }
249}
250
/// Eager group-by handle that delegates execution to `LazyFrame`.
#[derive(Debug, Clone)]
pub struct GroupBy {
    // Source frame the aggregation will run against.
    df: DataFrame,
    // Grouping key expressions, in the order given to `DataFrame::group_by`.
    by: Vec<Expr>,
}
257
258impl GroupBy {
259    /// Perform aggregations for this group-by.
260    pub fn agg(self, aggs: Vec<Expr>) -> Result<DataFrame> {
261        self.df.lazy().group_by(self.by).agg(aggs).collect()
262    }
263
264    /// Return the underlying `DataFrame`.
265    pub fn into_df(self) -> DataFrame {
266        self.df
267    }
268}
269
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    use super::DataFrame;
    use crate::{DataFrameError, Series};

    /// Build an Int32 `Series` whose chunk layout mirrors the nested input vectors.
    fn s_i32(name: &str, chunks: Vec<Vec<i32>>) -> Series {
        let arrays = chunks
            .into_iter()
            .map(|values| Arc::new(Int32Array::from(values)) as ArrayRef)
            .collect::<Vec<_>>();
        Series::from_arrow(name, arrays).unwrap()
    }

    #[test]
    fn dataframe_new_accepts_misaligned_chunks_by_normalizing() {
        // Chunk boundaries differ between the columns: [2, 1] vs [1, 2].
        let left = s_i32("a", vec![vec![1, 2], vec![3]]);
        let right = s_i32("b", vec![vec![10], vec![20, 30]]);

        let df = DataFrame::new(vec![left, right]).unwrap();

        assert_eq!((df.height(), df.width()), (3, 2));
        let schema = df.schema();
        assert_eq!(schema.fields()[0].name(), "a");
        assert_eq!(schema.fields()[1].name(), "b");

        // Construction normalizes everything into one 3-row batch.
        let batches = df.to_arrow();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 3);
    }

    #[test]
    fn dataframe_new_rejects_duplicate_column_names() {
        let first = s_i32("a", vec![vec![1]]);
        let second = s_i32("a", vec![vec![2]]);
        assert!(matches!(
            DataFrame::new(vec![first, second]).unwrap_err(),
            DataFrameError::SchemaMismatch { .. }
        ));
    }

    #[test]
    fn dataframe_new_rejects_length_mismatch() {
        let two_rows = s_i32("a", vec![vec![1, 2]]);
        let one_row = s_i32("b", vec![vec![10]]);
        assert!(matches!(
            DataFrame::new(vec![two_rows, one_row]).unwrap_err(),
            DataFrameError::SchemaMismatch { .. }
        ));
    }

    #[test]
    fn dataframe_new_accepts_different_chunk_counts() {
        // Three single-row chunks vs one three-row chunk.
        let chunky = s_i32("a", vec![vec![1], vec![2], vec![3]]);
        let solid = s_i32("b", vec![vec![10, 20, 30]]);
        let df = DataFrame::new(vec![chunky, solid]).unwrap();
        assert_eq!(df.height(), 3);
        assert_eq!(df.to_arrow().len(), 1);
    }

    #[test]
    fn dataframe_column_is_case_sensitive() {
        let df = DataFrame::new(vec![s_i32("a", vec![vec![1]])]).unwrap();
        let err = df.column("A").unwrap_err();
        assert!(matches!(err, DataFrameError::ColumnNotFound { .. }));
    }

    #[test]
    fn dataframe_from_batches_rejects_schema_mismatch() {
        let int_col: ArrayRef = Arc::new(Int32Array::from(vec![1]));
        let str_col: ArrayRef = Arc::new(StringArray::from(vec!["x"]));

        let int_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let str_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));

        let int_batch = RecordBatch::try_new(int_schema, vec![int_col]).unwrap();
        let str_batch = RecordBatch::try_new(str_schema, vec![str_col]).unwrap();

        assert!(matches!(
            DataFrame::from_batches(vec![int_batch, str_batch]).unwrap_err(),
            DataFrameError::SchemaMismatch { .. }
        ));
    }

    #[test]
    fn dataframe_columns_preserves_schema_order() {
        let a = s_i32("a", vec![vec![1], vec![2]]);
        let b = s_i32("b", vec![vec![10], vec![20]]);
        // Construct with "b" first so schema order differs from name order.
        let df = DataFrame::new(vec![b, a]).unwrap();

        let cols = df.columns();
        assert_eq!(cols[0].name(), "b");
        assert_eq!(cols[1].name(), "a");
        assert!(cols.iter().all(|c| c.len() == 2));
    }
}