// rustframes/dataframe/arrow.rs — Apache Arrow / Parquet / IPC interop for DataFrame.

1use super::{DataFrame, Series};
2use arrow::array::{
3    Array as ArrowArray, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray,
4};
5use arrow::datatypes::{DataType, Field, Schema};
6use arrow::record_batch::RecordBatch;
7use parquet::arrow::ArrowWriter;
8use std::sync::Arc;
9
10impl DataFrame {
11    /// Convert DataFrame to Apache Arrow RecordBatch
12    pub fn to_arrow(&self) -> Result<RecordBatch, Box<dyn std::error::Error>> {
13        let mut fields = Vec::new();
14        let mut arrays: Vec<ArrayRef> = Vec::new();
15
16        for (i, column_name) in self.columns.iter().enumerate() {
17            match &self.data[i] {
18                Series::Int64(values) => {
19                    fields.push(Field::new(column_name, DataType::Int64, false));
20                    let array = Int64Array::from(values.clone());
21                    arrays.push(Arc::new(array));
22                }
23                Series::Float64(values) => {
24                    fields.push(Field::new(column_name, DataType::Float64, false));
25                    let array = Float64Array::from(values.clone());
26                    arrays.push(Arc::new(array));
27                }
28                Series::Bool(values) => {
29                    fields.push(Field::new(column_name, DataType::Boolean, false));
30                    let array = BooleanArray::from(values.clone());
31                    arrays.push(Arc::new(array));
32                }
33                Series::Utf8(values) => {
34                    fields.push(Field::new(column_name, DataType::Utf8, false));
35                    let array = StringArray::from(values.clone());
36                    arrays.push(Arc::new(array));
37                }
38            }
39        }
40
41        let schema = Arc::new(Schema::new(fields));
42        let record_batch = RecordBatch::try_new(schema, arrays)?;
43        Ok(record_batch)
44    }
45
46    /// Create DataFrame from Apache Arrow RecordBatch
47    pub fn from_arrow(batch: &RecordBatch) -> Result<Self, Box<dyn std::error::Error>> {
48        let schema = batch.schema();
49        let mut columns = Vec::new();
50        let mut data = Vec::new();
51
52        for (i, field) in schema.fields().iter().enumerate() {
53            let column_name = field.name().clone();
54            let array = batch.column(i);
55
56            let series = match field.data_type() {
57                DataType::Int64 => {
58                    let int_array = array
59                        .as_any()
60                        .downcast_ref::<Int64Array>()
61                        .ok_or("Failed to downcast to Int64Array")?;
62                    let values: Vec<i64> =
63                        (0..int_array.len()).map(|i| int_array.value(i)).collect();
64                    Series::Int64(values)
65                }
66                DataType::Float64 => {
67                    let float_array = array
68                        .as_any()
69                        .downcast_ref::<Float64Array>()
70                        .ok_or("Failed to downcast to Float64Array")?;
71                    let values: Vec<f64> = (0..float_array.len())
72                        .map(|i| float_array.value(i))
73                        .collect();
74                    Series::Float64(values)
75                }
76                DataType::Boolean => {
77                    let bool_array = array
78                        .as_any()
79                        .downcast_ref::<BooleanArray>()
80                        .ok_or("Failed to downcast to BooleanArray")?;
81                    let values: Vec<bool> =
82                        (0..bool_array.len()).map(|i| bool_array.value(i)).collect();
83                    Series::Bool(values)
84                }
85                DataType::Utf8 => {
86                    let string_array = array
87                        .as_any()
88                        .downcast_ref::<StringArray>()
89                        .ok_or("Failed to downcast to StringArray")?;
90                    let values: Vec<String> = (0..string_array.len())
91                        .map(|i| string_array.value(i).to_string())
92                        .collect();
93                    Series::Utf8(values)
94                }
95                _ => return Err(format!("Unsupported data type: {:?}", field.data_type()).into()),
96            };
97
98            columns.push(column_name);
99            data.push(series);
100        }
101
102        Ok(DataFrame { columns, data })
103    }
104
105    /// Read Parquet file using Arrow
106    pub fn from_parquet(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
107        use std::fs::File;
108
109        let file = File::open(path)?;
110        let mut arrow_reader =
111            parquet::arrow::arrow_reader::ArrowReaderBuilder::try_new(file)?.build()?;
112
113        if let Some(batch_result) = arrow_reader.next() {
114            let batch = batch_result?;
115            Self::from_arrow(&batch)
116        } else {
117            Err("No data in Parquet file".into())
118        }
119    }
120
121    /// Write DataFrame to Parquet file
122    pub fn to_parquet(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
123        use std::fs::File;
124
125        let batch = self.to_arrow()?;
126        let file = File::create(path)?;
127        let mut writer = ArrowWriter::try_new(file, batch.schema(), None)?;
128
129        writer.write(&batch)?;
130        writer.close()?;
131
132        Ok(())
133    }
134
135    /// Create DataFrame from Arrow IPC (Feather) file
136    pub fn from_ipc(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
137        use arrow::ipc::reader::FileReader;
138        use std::fs::File;
139
140        let file = File::open(path)?;
141        let mut reader = FileReader::try_new(file, None)?;
142
143        if let Some(batch_result) = reader.next() {
144            let batch = batch_result?;
145            Self::from_arrow(&batch)
146        } else {
147            Err("No data in IPC file".into())
148        }
149    }
150
151    /// Write DataFrame to Arrow IPC (Feather) file
152    pub fn to_ipc(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
153        use arrow::ipc::writer::FileWriter;
154        use std::fs::File;
155
156        let batch = self.to_arrow()?;
157        let file = File::create(path)?;
158        let mut writer = FileWriter::try_new(file, &batch.schema())?;
159
160        // propagate any error from write
161        writer.write(&batch)?;
162        writer.finish()?;
163
164        Ok(())
165    }
166
167    /// Convert to Arrow and perform operations using Arrow Compute
168    pub fn arrow_filter(
169        &self,
170        column: &str,
171        predicate: ArrowPredicate,
172    ) -> Result<DataFrame, Box<dyn std::error::Error>> {
173        use arrow::array::{BooleanArray, Float64Array, Int64Array};
174        use arrow::compute;
175        use arrow::datatypes::DataType;
176
177        let batch = self.to_arrow()?;
178        let col_index = batch
179            .schema()
180            .column_with_name(column)
181            .ok_or("Column not found")?
182            .0;
183        let array = batch.column(col_index);
184
185        let filter_array: BooleanArray = match predicate {
186            ArrowPredicate::Gt(value) => match array.data_type() {
187                DataType::Float64 => {
188                    let float_array = array.as_any().downcast_ref::<Float64Array>().unwrap();
189                    let mut mask: Vec<bool> = Vec::with_capacity(float_array.len());
190                    for i in 0..float_array.len() {
191                        mask.push(float_array.value(i) > value);
192                    }
193                    BooleanArray::from(mask)
194                }
195                DataType::Int64 => {
196                    let int_array = array.as_any().downcast_ref::<Int64Array>().unwrap();
197                    let mut mask: Vec<bool> = Vec::with_capacity(int_array.len());
198                    for i in 0..int_array.len() {
199                        mask.push((int_array.value(i) as f64) > value);
200                    }
201                    BooleanArray::from(mask)
202                }
203                _ => return Err("Unsupported type for comparison".into()),
204            },
205            ArrowPredicate::Lt(value) => match array.data_type() {
206                DataType::Float64 => {
207                    let float_array = array.as_any().downcast_ref::<Float64Array>().unwrap();
208                    let mut mask: Vec<bool> = Vec::with_capacity(float_array.len());
209                    for i in 0..float_array.len() {
210                        mask.push(float_array.value(i) < value);
211                    }
212                    BooleanArray::from(mask)
213                }
214                DataType::Int64 => {
215                    let int_array = array.as_any().downcast_ref::<Int64Array>().unwrap();
216                    let mut mask: Vec<bool> = Vec::with_capacity(int_array.len());
217                    for i in 0..int_array.len() {
218                        mask.push((int_array.value(i) as f64) < value);
219                    }
220                    BooleanArray::from(mask)
221                }
222                _ => return Err("Unsupported type for comparison".into()),
223            },
224            ArrowPredicate::Eq(value) => match array.data_type() {
225                DataType::Float64 => {
226                    let float_array = array.as_any().downcast_ref::<Float64Array>().unwrap();
227                    let mut mask: Vec<bool> = Vec::with_capacity(float_array.len());
228                    for i in 0..float_array.len() {
229                        mask.push(float_array.value(i) == value);
230                    }
231                    BooleanArray::from(mask)
232                }
233                DataType::Int64 => {
234                    let int_array = array.as_any().downcast_ref::<Int64Array>().unwrap();
235                    let mut mask: Vec<bool> = Vec::with_capacity(int_array.len());
236                    for i in 0..int_array.len() {
237                        mask.push((int_array.value(i) as f64) == value);
238                    }
239                    BooleanArray::from(mask)
240                }
241                _ => return Err("Unsupported type for comparison".into()),
242            },
243        };
244
245        let filtered_arrays: Result<Vec<ArrayRef>, _> = batch
246            .columns()
247            .iter()
248            .map(|array| compute::filter(array, &filter_array))
249            .collect();
250
251        let filtered_batch = RecordBatch::try_new(batch.schema(), filtered_arrays?)?;
252        Self::from_arrow(&filtered_batch)
253    }
254
255    /// Aggregation using Arrow compute
256    pub fn arrow_agg(
257        &self,
258        column: &str,
259        agg_func: ArrowAggFunc,
260    ) -> Result<f64, Box<dyn std::error::Error>> {
261        use arrow::compute;
262
263        let batch = self.to_arrow()?;
264        let col_index = batch
265            .schema()
266            .column_with_name(column)
267            .ok_or("Column not found")?
268            .0;
269        let array = batch.column(col_index);
270
271        let result = match agg_func {
272            ArrowAggFunc::Sum => match array.data_type() {
273                DataType::Float64 => {
274                    let float_array = array.as_any().downcast_ref::<Float64Array>().unwrap();
275                    compute::sum(float_array).unwrap_or(0.0)
276                }
277                DataType::Int64 => {
278                    let int_array = array.as_any().downcast_ref::<Int64Array>().unwrap();
279                    compute::sum(int_array).unwrap_or(0) as f64
280                }
281                _ => return Err("Sum not supported for this type".into()),
282            },
283            ArrowAggFunc::Min => match array.data_type() {
284                DataType::Float64 => {
285                    let float_array = array.as_any().downcast_ref::<Float64Array>().unwrap();
286                    compute::min(float_array).unwrap_or(f64::NAN)
287                }
288                DataType::Int64 => {
289                    let int_array = array.as_any().downcast_ref::<Int64Array>().unwrap();
290                    compute::min(int_array).unwrap_or(0) as f64
291                }
292                _ => return Err("Min not supported for this type".into()),
293            },
294            ArrowAggFunc::Max => match array.data_type() {
295                DataType::Float64 => {
296                    let float_array = array.as_any().downcast_ref::<Float64Array>().unwrap();
297                    compute::max(float_array).unwrap_or(f64::NAN)
298                }
299                DataType::Int64 => {
300                    let int_array = array.as_any().downcast_ref::<Int64Array>().unwrap();
301                    compute::max(int_array).unwrap_or(0) as f64
302                }
303                _ => return Err("Max not supported for this type".into()),
304            },
305        };
306
307        Ok(result)
308    }
309
310    /// Zero-copy slice using Arrow
311    pub fn arrow_slice(
312        &self,
313        offset: usize,
314        length: usize,
315    ) -> Result<DataFrame, Box<dyn std::error::Error>> {
316        let batch = self.to_arrow()?;
317        let sliced_arrays: Vec<ArrayRef> = batch
318            .columns()
319            .iter()
320            .map(|array| array.slice(offset, length))
321            .collect();
322
323        let sliced_batch = RecordBatch::try_new(batch.schema(), sliced_arrays)?;
324        Self::from_arrow(&sliced_batch)
325    }
326}
327
/// Comparison predicate used by `DataFrame::arrow_filter`.
///
/// The operand is always an `f64`; Int64 columns are widened to `f64`
/// before comparison. `Copy` and `PartialEq` are derived because the
/// payload is a plain float.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ArrowPredicate {
    /// Keep rows strictly greater than the operand.
    Gt(f64),
    /// Keep rows strictly less than the operand.
    Lt(f64),
    /// Keep rows exactly equal to the operand (exact float equality).
    Eq(f64),
}
334
/// Aggregation selector for `DataFrame::arrow_agg`.
///
/// Fieldless, so `Copy`, `Eq` and `Hash` come for free and make the enum
/// usable as a map key or cheap by-value argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArrowAggFunc {
    /// Sum of all values (falls back to 0 on an empty column).
    Sum,
    /// Minimum value.
    Min,
    /// Maximum value.
    Max,
}
341
// Integration with NumPy — compiled only with the `python` feature, which
// pulls in the pyo3 and numpy crates.
#[cfg(feature = "python")]
pub mod numpy_interop {
    use super::*;
    use numpy::{PyArray, PyReadonlyArray1};
    use pyo3::prelude::*;
    // NOTE(review): `PyArray1` normally lives in the `numpy` crate, not
    // `pyo3::types` — confirm this path resolves for the pinned pyo3 version.
    use pyo3::types::PyArray1;

    impl DataFrame {
        /// Convert a single numeric column to a NumPy `float64` array.
        ///
        /// Float64 columns are copied directly; Int64 columns are widened to
        /// f64 first. Non-numeric columns raise a Python `TypeError`, and an
        /// unknown column name raises a `ValueError`.
        pub fn series_to_numpy<'py>(
            &self,
            py: Python<'py>,
            column: &str,
        ) -> PyResult<&'py PyArray1<f64>> {
            let series = self
                .get_column(column)
                .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("Column not found"))?;

            match series {
                // Direct slice copy into a NumPy-owned buffer.
                Series::Float64(values) => Ok(PyArray::from_slice(py, values)),
                Series::Int64(values) => {
                    // Widen i64 -> f64; note values above 2^53 lose precision.
                    let float_values: Vec<f64> = values.iter().map(|&x| x as f64).collect();
                    Ok(PyArray::from_vec(py, float_values))
                }
                _ => Err(pyo3::exceptions::PyTypeError::new_err(
                    "Only numeric columns can be converted to NumPy arrays",
                )),
            }
        }

        /// Build a single-column DataFrame from a 1-D NumPy `float64` array.
        ///
        /// NOTE(review): `as_slice().unwrap()` panics if the array is not
        /// contiguous — confirm callers always pass C-contiguous arrays, or
        /// surface this as a Python exception instead.
        pub fn from_numpy(array: PyReadonlyArray1<f64>, column_name: &str) -> Self {
            let values: Vec<f64> = array.as_slice().unwrap().to_vec();
            DataFrame::new(vec![(column_name.to_string(), Series::Float64(values))])
        }
    }
}
380
381// Memory mapping for large files
382pub mod memory_mapped {
383    use super::*;
384    use memmap2::MmapOptions;
385    use std::fs::File;
386
387    impl DataFrame {
388        /// Memory-mapped CSV reading for large files
389        pub fn from_csv_map(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
390            let file = File::open(path)?;
391            let mmap = unsafe { MmapOptions::new().map(&file)? };
392
393            let mut rdr = csv::ReaderBuilder::new()
394                .has_headers(true)
395                .from_reader(&mmap[..]);
396
397            let headers = rdr.headers()?.clone();
398            let mut raw_data: Vec<Vec<String>> = vec![Vec::new(); headers.len()];
399
400            for result in rdr.records() {
401                let record = result?;
402                for (i, field) in record.iter().enumerate() {
403                    if i < raw_data.len() {
404                        raw_data[i].push(field.to_string());
405                    }
406                }
407            }
408
409            let mut series_data = Vec::new();
410            for col_data in raw_data {
411                let col_type = Self::infer_column_type(&col_data);
412                let series = match col_type {
413                    crate::dataframe::core::SeriesType::Float64 => {
414                        let parsed: Vec<f64> = col_data
415                            .iter()
416                            .map(|s| s.trim().parse().unwrap_or(0.0))
417                            .collect();
418                        Series::Float64(parsed)
419                    }
420                    crate::dataframe::core::SeriesType::Int64 => {
421                        let parsed: Vec<i64> = col_data
422                            .iter()
423                            .map(|s| s.trim().parse().unwrap_or(0))
424                            .collect();
425                        Series::Int64(parsed)
426                    }
427                    crate::dataframe::core::SeriesType::Bool => {
428                        let parsed: Vec<bool> = col_data
429                            .iter()
430                            .map(|s| Self::parse_bool(s.trim()).unwrap_or(false))
431                            .collect();
432                        Series::Bool(parsed)
433                    }
434                    crate::dataframe::core::SeriesType::Utf8 => Series::Utf8(col_data),
435                };
436                series_data.push(series);
437            }
438
439            let column_names: Vec<String> = headers.iter().map(|h| h.to_string()).collect();
440            Ok(DataFrame::new(
441                column_names.into_iter().zip(series_data).collect(),
442            ))
443        }
444    }
445}