ella_tensor/
frame.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
mod data_frame;
mod print;

pub use data_frame::DataFrame;
pub use print::print_frames;

use crate::{column::array_to_column, tensor_schema, NamedColumn};
use arrow::{
    datatypes::Schema,
    record_batch::{RecordBatch, RecordBatchOptions},
};
use std::sync::Arc;

/// A column-oriented table accessed through shared references.
///
/// Columns are addressed by position; [`Frame::columns`] walks them in order.
pub trait Frame {
    /// Number of columns in the frame.
    fn ncols(&self) -> usize;

    /// Number of rows in the frame.
    fn nrows(&self) -> usize;

    /// Returns the column at position `i`.
    ///
    /// Indices `0..self.ncols()` are the valid ones; what happens for an
    /// out-of-range index is left to the implementor.
    fn column(&self, i: usize) -> &NamedColumn;

    /// Iterates over the columns in positional order, from index `0` up to
    /// (but not including) `self.ncols()`.
    fn columns(&self) -> FrameColIter<'_, Self> {
        // Start before the first column; the iterator advances `index` itself.
        FrameColIter { index: 0, frame: self }
    }
}

/// Lets a shared reference to a frame be used wherever a [`Frame`] is
/// expected; every method forwards to the referenced frame.
impl<'a, F: Frame> Frame for &'a F {
    fn ncols(&self) -> usize {
        // Name `F` explicitly so the call cannot resolve back to this impl.
        F::ncols(*self)
    }

    fn nrows(&self) -> usize {
        F::nrows(*self)
    }

    fn column(&self, i: usize) -> &NamedColumn {
        F::column(*self, i)
    }
}

/// Converts every column of `rb` into a [`NamedColumn`], taking each name from
/// the matching field of the batch's schema and keeping the batch's column order.
///
/// # Errors
///
/// Returns the first error produced by `array_to_column`; columns after the
/// failing one are not converted.
pub(crate) fn batch_to_columns(rb: &RecordBatch) -> crate::Result<Arc<[NamedColumn]>> {
    let schema = rb.schema();
    rb.columns()
        .iter()
        .zip(schema.fields())
        .map(|(array, field)| -> crate::Result<NamedColumn> {
            // `array` is a shared Arrow array handle; cloning it only bumps a refcount.
            let column = array_to_column(field, Arc::clone(array))?;
            Ok(NamedColumn::new(field.name().clone(), column))
        })
        .collect()
}

pub(crate) fn frame_to_batch<F: Frame>(frame: &F) -> RecordBatch {
    let columns = frame.columns().map(|c| c.to_arrow()).collect::<Vec<_>>();
    RecordBatch::try_new(Arc::new(frame_to_schema(frame)), columns).unwrap()
}

/// Builds the Arrow [`Schema`] describing `frame`.
///
/// The schema has one field per column, in column order. Each field comes from
/// `tensor_schema`, called with the column's name, tensor type, row shape and
/// nullability.
pub(crate) fn frame_to_schema<F: Frame>(frame: &F) -> Schema {
    let mut fields = Vec::new();
    for column in frame.columns() {
        let name = column.name().to_string();
        fields.push(tensor_schema(
            name,
            column.tensor_type(),
            column.row_shape(),
            column.nullable(),
        ));
    }
    Schema::new(fields)
}

/// Iterator over the columns of a [`Frame`], returned by [`Frame::columns`].
///
/// Yields a reference to each column for indices `0..frame.ncols()`, in order.
pub struct FrameColIter<'a, F: ?Sized> {
    /// The frame whose columns are being walked.
    frame: &'a F,
    /// Position of the next column to yield.
    index: usize,
}

impl<'a, F: Frame> Iterator for FrameColIter<'a, F> {
    type Item = &'a NamedColumn;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.frame.ncols() {
            let col = self.frame.column(self.index);
            self.index += 1;
            Some(col)
        } else {
            None
        }
    }

    /// Reports the exact number of columns left, so consumers such as
    /// `collect::<Vec<_>>()` can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Saturate rather than underflow if a frame ever reports fewer columns
        // than have already been yielded.
        let remaining = self.frame.ncols().saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

/// `size_hint` is exact, so `len()` is available on the column iterator.
impl<'a, F: Frame> ExactSizeIterator for FrameColIter<'a, F> {}

/// Builds a `DataFrame` from `name = column` pairs.
///
/// `frame!()` expands to `DataFrame::new()`. Otherwise each `name = expr`
/// pair becomes a `NamedColumn`:
/// - its name is the stringified (possibly dot-separated) name tokens;
/// - its column is `Arc::new(expr)` cast to `ColumnRef`.
///
/// The columns are collected into a `DataFrame` in the order written, and a
/// trailing comma is accepted.
///
/// The expansion uses absolute paths (`::std`, `$crate`) and
/// `FromIterator::from_iter` rather than the array `.into_iter()` method.
/// As a result it depends neither on the caller's module layout nor on how
/// the edition resolves `into_iter` for arrays.
#[macro_export]
macro_rules! frame {
    () => {
        $crate::DataFrame::new()
    };
    ($($($name:tt).+ = $col:expr),+ $(,)?) => {
        <$crate::DataFrame as ::std::iter::FromIterator<_>>::from_iter([$(
            $crate::NamedColumn::new(
                ::std::string::String::from(stringify!($($name).+)),
                ::std::sync::Arc::new($col) as $crate::ColumnRef,
            )
        ),+])
    };
}