liquid_ml/dataframe/
local_dataframe.rs

1//! Defines functionality for a `LocalDataFrame`
2use crate::dataframe::{Row, Rower, Schema};
3use crate::error::LiquidError;
4use crossbeam_utils::thread;
5use deepsize::DeepSizeOf;
6use serde::{Deserialize, Serialize};
7use sorer::dataframe::{from_file, Column, Data};
8use sorer::schema::{infer_schema, DataType};
9use std::cmp::Ordering;
10use std::convert::TryInto;
11
12/// Represents a local data frame which contains data stored in a columnar
13/// format and a well-defined `Schema`. Is useful for data sets that fit into
14/// memory or for testing/debugging purposes.
15#[derive(Serialize, Deserialize, PartialEq, Clone, Debug, DeepSizeOf)]
16pub struct LocalDataFrame {
17    /// The `Schema` of this data frame
18    pub schema: Schema,
19    /// The data of this data frame, in columnar format
20    pub data: Vec<Column>,
21    /// Number of threads for this computer
22    pub n_threads: usize,
23    /// Current row index for implementing the `Iterator` trait
24    cur_row_idx: usize,
25}
26
/// Generates a typed cell setter on `LocalDataFrame`.
///
/// `setter!(name, rust_type, Variant)` expands to a method `name` that
/// overwrites one cell of a column whose schema type is `DataType::Variant`
/// with a value of type `rust_type`.
macro_rules! setter {
    ($func_name:ident, $type:ty, $sorer_type:ident) => {
        /// Overwrites the value stored at `col_idx`, `row_idx` with `data`.
        ///
        /// # Errors
        /// - `LiquidError::ColIndexOutOfBounds` if there is no column at
        ///   `col_idx`
        /// - `LiquidError::TypeMismatch` if the column at `col_idx` does not
        ///   hold values of this setter's type
        /// - `LiquidError::RowIndexOutOfBounds` if the column has no row at
        ///   `row_idx`
        ///
        /// # Panics
        /// If the schema and the stored column disagree about the type of
        /// the column at `col_idx`.
        pub fn $func_name(
            &mut self,
            col_idx: usize,
            row_idx: usize,
            data: $type,
        ) -> Result<(), LiquidError> {
            // Tell a missing column apart from a column of the wrong type;
            // a catch-all arm here would report both as `TypeMismatch`.
            match self.schema.schema.get(col_idx) {
                Some(DataType::$sorer_type) => {}
                Some(_) => return Err(LiquidError::TypeMismatch),
                None => return Err(LiquidError::ColIndexOutOfBounds),
            }

            match self.data.get_mut(col_idx) {
                Some(Column::$sorer_type(col)) => {
                    let cell = col
                        .get_mut(row_idx)
                        .ok_or(LiquidError::RowIndexOutOfBounds)?;
                    *cell = Some(data);
                    Ok(())
                }
                // The schema lists this column but no data is stored for it
                None => Err(LiquidError::ColIndexOutOfBounds),
                // The schema and the stored column disagree on the type
                _ => panic!("Something is horribly wrong"),
            }
        }
    };
}
58
59/// An implementation for a `LocalDataFrame`, inspired by the data frames used
60/// in `pandas` and `R`.
61impl LocalDataFrame {
62    /// Creates a new `LocalDataFrame` from the given file by reading it in
63    /// parallel using the number of cores available on this machine. Only
64    /// reads `len` bytes of the file, starting at the given byte offset
65    /// `from`.
66    pub fn from_sor(file_name: &str, from: usize, len: usize) -> Self {
67        let schema = Schema::from(infer_schema(file_name).expect("Could not infer schema for {file_name:?}"));
68        let n_threads = num_cpus::get();
69        let data =
70            from_file(file_name, schema.schema.clone(), from, len, n_threads);
71        LocalDataFrame {
72            schema,
73            data,
74            n_threads,
75            cur_row_idx: 0,
76        }
77    }
78
79    /// Creates an empty `LocalDataFrame` from the given `Schema`. The
80    /// `LocalDataFrame` is created with no rows, but the names of the columns
81    /// in the given `schema` are cloned.
82    pub fn new(schema: &Schema) -> Self {
83        let mut data = Vec::new();
84        for data_type in &schema.schema {
85            match data_type {
86                DataType::Bool => data.push(Column::Bool(Vec::new())),
87                DataType::Int => data.push(Column::Int(Vec::new())),
88                DataType::Float => data.push(Column::Float(Vec::new())),
89                DataType::String => data.push(Column::String(Vec::new())),
90            }
91        }
92        let schema = Schema {
93            schema: schema.schema.clone(),
94            col_names: schema.col_names.clone(),
95        };
96
97        LocalDataFrame {
98            schema,
99            data,
100            n_threads: num_cpus::get(),
101            cur_row_idx: 0,
102        }
103    }
104
    /// Returns a shared reference to the `Schema` describing the columns of
    /// this `LocalDataFrame`.
    pub fn get_schema(&self) -> &Schema {
        &self.schema
    }
109
110    /// Adds a `Column` to this `LocalDataFrame` with an optional name. Returns
111    /// a `LiquidError::NameAlreadyExists` if the given `name` is not unique.
112    pub fn add_column(
113        &mut self,
114        col: Column,
115        name: Option<String>,
116    ) -> Result<(), LiquidError> {
117        match &col {
118            Column::Int(_) => self.schema.add_column(DataType::Int, name),
119            Column::Bool(_) => self.schema.add_column(DataType::Bool, name),
120            Column::Float(_) => self.schema.add_column(DataType::Float, name),
121            Column::String(_) => self.schema.add_column(DataType::String, name),
122        }?;
123
124        match self.n_rows().cmp(&col.len()) {
125            Ordering::Equal => self.data.push(col),
126            Ordering::Less => {
127                // our data is shorter than `col`, must add Data::Null to
128                // all of our columns until they are equal length w/`col`
129                for j in 0..self.n_cols() - 1 {
130                    let c = self.data.get_mut(j).unwrap();
131                    for _ in 0..col.len() - c.len() {
132                        match c {
133                            Column::Bool(x) => x.push(None),
134                            Column::Int(x) => x.push(None),
135                            Column::Float(x) => x.push(None),
136                            Column::String(x) => x.push(None),
137                        }
138                    }
139                }
140                self.data.push(col)
141            }
142            Ordering::Greater => {
143                // our data is longer than `col`, we must add Data::Null to
144                // `col` until it matches the len of our data
145                let diff = self.n_rows() - col.len();
146                // note that vec![] must be done inside match so types are
147                // correct. a for loop also doesn't work, i tried
148                // Also I know this is ugly but trust me i tried a lot of shit
149                // and this is the only thing that worked
150                match col {
151                    Column::Bool(mut x) => {
152                        let nones = vec![None; diff];
153                        x.extend(nones.into_iter());
154                        self.data.push(Column::Bool(x))
155                    }
156                    Column::Int(mut x) => {
157                        let nones = vec![None; diff];
158                        x.extend(nones.into_iter());
159                        self.data.push(Column::Int(x))
160                    }
161                    Column::Float(mut x) => {
162                        let nones = vec![None; diff];
163                        x.extend(nones.into_iter());
164                        self.data.push(Column::Float(x))
165                    }
166                    Column::String(mut x) => {
167                        let nones = vec![None; diff];
168                        x.extend(nones.into_iter());
169                        self.data.push(Column::String(x))
170                    }
171                }
172            }
173        }
174
175        Ok(())
176    }
177
178    /// Get the `Data` at the given `col_idx`, `row_idx` offsets.
179    pub fn get(
180        &self,
181        col_idx: usize,
182        row_idx: usize,
183    ) -> Result<Data, LiquidError> {
184        // Note that yes this is really ugly, but no it can't be abstracted
185        // without macros (must match on the types) and it is for performance
186        // so that we don't have to box/unbox values when constructing the
187        // DataFrame and mapping over it
188        match self.data.get(col_idx) {
189            Some(Column::Int(col)) => match col.get(row_idx) {
190                Some(optional_data) => match optional_data {
191                    Some(data) => Ok(Data::Int(*data)),
192                    None => Ok(Data::Null),
193                },
194                None => Err(LiquidError::RowIndexOutOfBounds),
195            },
196            Some(Column::Bool(col)) => match col.get(row_idx) {
197                Some(optional_data) => match optional_data {
198                    Some(data) => Ok(Data::Bool(*data)),
199                    None => Ok(Data::Null),
200                },
201                None => Err(LiquidError::RowIndexOutOfBounds),
202            },
203            Some(Column::Float(col)) => match col.get(row_idx) {
204                Some(optional_data) => match optional_data {
205                    Some(data) => Ok(Data::Float(*data)),
206                    None => Ok(Data::Null),
207                },
208                None => Err(LiquidError::RowIndexOutOfBounds),
209            },
210            Some(Column::String(col)) => match col.get(row_idx) {
211                Some(optional_data) => match optional_data {
212                    Some(data) => Ok(Data::String(data.clone())),
213                    None => Ok(Data::Null),
214                },
215                None => Err(LiquidError::RowIndexOutOfBounds),
216            },
217            None => Err(LiquidError::ColIndexOutOfBounds),
218        }
219    }
220
    /// Looks up the index of the `Column` named `col_name` by delegating to
    /// the schema. Returns `Some(index)` if a `Column` with the given name
    /// exists, or `None` otherwise.
    pub fn get_col_idx(&self, col_name: &str) -> Option<usize> {
        self.schema.col_idx(col_name)
    }
226
    /// Given a column index, returns its name, as reported by the schema.
    ///
    /// NOTE(review): presumably `Ok(None)` means the column exists but is
    /// unnamed and `Err` means `col_idx` is out of bounds; confirm against
    /// `Schema::col_name`.
    pub fn col_name(
        &self,
        col_idx: usize,
    ) -> Result<Option<&str>, LiquidError> {
        self.schema.col_name(col_idx)
    }
234
    // Typed cell setters, one per column type, generated by `setter!`
    setter!(set_string, String, String);
    setter!(set_bool, bool, Bool);
    setter!(set_float, f64, Float);
    setter!(set_int, i64, Int);
239
240    /// Set the fields of the given `Row` struct with values from this
241    /// `DataFrame` at the given `row_index`.
242    ///
243    /// If the `row` does not have the same schema as this `DataFrame`, a
244    /// `LiquidError::TypeMismatch` error will be returned.
245    pub fn fill_row(
246        &self,
247        row_index: usize,
248        row: &mut Row,
249    ) -> Result<(), LiquidError> {
250        for (c_idx, col) in self.data.iter().enumerate() {
251            match col {
252                Column::Int(c) => match c.get(row_index).unwrap() {
253                    Some(x) => row.set_int(c_idx, *x)?,
254                    None => row.set_null(c_idx)?,
255                },
256                Column::Float(c) => match c.get(row_index).unwrap() {
257                    Some(x) => row.set_float(c_idx, *x)?,
258                    None => row.set_null(c_idx)?,
259                },
260                Column::Bool(c) => match c.get(row_index).unwrap() {
261                    Some(x) => row.set_bool(c_idx, *x)?,
262                    None => row.set_null(c_idx)?,
263                },
264                Column::String(c) => match c.get(row_index).unwrap() {
265                    Some(x) => row.set_string(c_idx, x.clone())?,
266                    None => row.set_null(c_idx)?,
267                },
268            };
269        }
270        row.set_idx(row_index);
271        Ok(())
272    }
273
274    /// Add a `Row` at the end of this `DataFrame`.
275    ///
276    /// If the `row` does not have the same schema as this `DataFrame`, a
277    /// `LiquidError::TypeMismatch` error will be returned.
278    pub fn add_row(&mut self, row: &Row) -> Result<(), LiquidError> {
279        if row.schema != self.schema {
280            return Err(LiquidError::TypeMismatch);
281        }
282
283        for (data, column) in row.data.iter().zip(self.data.iter_mut()) {
284            match (data, column) {
285                (Data::Int(n), Column::Int(l)) => l.push(Some(*n)),
286                (Data::Float(n), Column::Float(l)) => l.push(Some(*n)),
287                (Data::Bool(n), Column::Bool(l)) => l.push(Some(*n)),
288                (Data::String(n), Column::String(l)) => l.push(Some(n.clone())),
289                (Data::Null, Column::Int(l)) => l.push(None),
290                (Data::Null, Column::Float(l)) => l.push(None),
291                (Data::Null, Column::Bool(l)) => l.push(None),
292                (Data::Null, Column::String(l)) => l.push(None),
293                (_, _) => unreachable!("Something is horribly wrong"),
294            };
295        }
296
297        Ok(())
298    }
299
300    /// Applies the given `rower` synchronously to every row in this
301    /// `LocalDataFrame`
302    ///
303    /// Since `map` takes an immutable reference to `self`, the `rower` can
304    /// not mutate this `DataFrame`. If mutation is desired, the `rower` must
305    /// create its own `DataFrame` internally, clone each `Row` from this
306    /// `DataFrame` as it visits them, and mutate the cloned row during each
307    /// visit.
308    pub fn map<T: Rower>(&self, rower: T) -> T {
309        map_helper(self, rower, 0, self.n_rows())
310    }
311
312    /// Applies the given `rower` to every row sequentially in this `DataFrame`
313    /// The `rower` is cloned `n_threads` times, according to the value of
314    /// `n_threads` for this `LocalDataFrame`. Each `rower` gets operates on a
315    /// chunk of this `LocalDataFrame` and are run in parallel.
316    ///
317    /// Since `pmap` takes an immutable reference to `self`, the `rower` can
318    /// not mutate this `LocalDataFrame`. If mutation is desired, the `rower`
319    /// must create its own `LocalDataFrame` internally by building one up
320    /// as it visit rows, and mutates that.
321    ///
322    /// `n_threads` defaults to the number of cores available on this machine.
323    pub fn pmap<T: Rower + Clone + Send>(&self, rower: T) -> T {
324        let rowers = vec![rower; self.n_threads];
325        let mut new_rowers = Vec::new();
326        let step = self.n_rows() / self.n_threads;
327        let mut from = 0;
328        thread::scope(|s| {
329            let mut threads = Vec::new();
330            let mut i = 0;
331            for r in rowers {
332                i += 1;
333                let to = if i == self.n_threads {
334                    self.n_rows()
335                } else {
336                    from + step
337                };
338                threads.push(s.spawn(move |_| map_helper(&self, r, from, to)));
339                from += step;
340            }
341            for thread in threads {
342                new_rowers.push(thread.join().unwrap());
343            }
344        })
345        .unwrap();
346        let acc = new_rowers.pop().unwrap();
347        new_rowers
348            .into_iter()
349            .rev()
350            .fold(acc, |prev, x| x.join(prev))
351    }
352
353    /// Creates a new `LocalDataFrame` by applying the given `rower` to every
354    /// row sequentially in this `LocalDataFrame` and cloning rows for which
355    /// the given `rower` returns true from its `accept` method. Is run
356    /// synchronously.
357    pub fn filter<T: Rower>(&self, rower: &mut T) -> Self {
358        filter_helper(self, rower, 0, self.n_rows())
359    }
360
361    /// Creates a new `LocalDataFrame` by applying the given `rower` to every
362    /// row in this data frame sequentially, and cloning rows for which the
363    /// given `rower` returns true from its `accept` method. The `rower` is
364    /// cloned `n_threads` times, according to the value of `n_threads` for
365    /// this `LocalDataFrame`. Each `rower` gets operates on a chunk of this
366    /// `LocalDataFrame` and are run in parallel.
367    ///
368    /// `n_threads` defaults to the number of cores available on this machine.
369    pub fn pfilter<T: Rower + Clone + Send>(&self, rower: &mut T) -> Self {
370        let mut rowers = Vec::new();
371        for _ in 0..self.n_threads {
372            rowers.push(rower.clone());
373        }
374        // ok.... the below syntax doesn't work, not sure why
375        //    let rowers = vec![*rower; self.n_threads];
376        let mut new_dfs = Vec::new();
377        let step = self.n_rows() / self.n_threads;
378        let mut from = 0;
379        thread::scope(|s| {
380            let mut threads = Vec::new();
381            let mut i = 0;
382            for mut r in rowers {
383                i += 1;
384                let to = if i == self.n_threads {
385                    self.n_rows()
386                } else {
387                    from + step
388                };
389                threads.push(
390                    s.spawn(move |_| filter_helper(&self, &mut r, from, to)),
391                );
392                from += step;
393            }
394            for thread in threads {
395                new_dfs.push(thread.join().unwrap());
396            }
397        })
398        .unwrap();
399        let acc = new_dfs.pop().unwrap();
400        new_dfs
401            .into_iter()
402            .rev()
403            .fold(acc, |prev, x| x.combine(prev).unwrap())
404    }
405
406    /// Consumes this `LocalDataFrame` and the other given `LocalDataFrame`,
407    /// returning a combined `LocalDataFrame` if successful.
408    ///
409    /// - The columns names and the number of threads for the resulting
410    ///   `LocalDataFrame` are from this `LocalDataFrame` and the column names
411    ///   and `n_threads` in `other` are ignored
412    /// - The data of `other` is appended to the data of this `LocalDataFrame`
413    ///
414    /// # Errors
415    /// If the schema of this `LocalDataFrame` and `other` have different
416    /// `DataType`s
417    pub fn combine(mut self, other: Self) -> Result<Self, LiquidError> {
418        if self.get_schema().schema != other.get_schema().schema {
419            return Err(LiquidError::TypeMismatch);
420        }
421
422        for (col_idx, col) in other.data.into_iter().enumerate() {
423            match self.data.get_mut(col_idx).unwrap() {
424                Column::Bool(result_col) => {
425                    let x: Vec<Option<bool>> = col.try_into().unwrap();
426                    result_col.extend(x.into_iter())
427                }
428                Column::Int(result_col) => {
429                    let x: Vec<Option<i64>> = col.try_into().unwrap();
430                    result_col.extend(x.into_iter())
431                }
432                Column::Float(result_col) => {
433                    let x: Vec<Option<f64>> = col.try_into().unwrap();
434                    result_col.extend(x.into_iter())
435                }
436                Column::String(result_col) => {
437                    let x: Vec<Option<String>> = col.try_into().unwrap();
438                    result_col.extend(x.into_iter())
439                }
440            }
441        }
442
443        Ok(self)
444    }
445
446    /// Return the number of rows in this `DataFrame`.
447    pub fn n_rows(&self) -> usize {
448        if self.data.is_empty() {
449            0
450        } else {
451            self.data[0].len()
452        }
453    }
454
    /// Return the number of columns in this `DataFrame`, as reported by the
    /// width of its schema.
    pub fn n_cols(&self) -> usize {
        self.schema.width()
    }
459}
460
461fn filter_helper<T: Rower>(
462    df: &LocalDataFrame,
463    r: &mut T,
464    start: usize,
465    end: usize,
466) -> LocalDataFrame {
467    let mut df2 = LocalDataFrame::new(&df.schema);
468    let mut row = Row::new(&df.schema);
469
470    for i in start..end {
471        df.fill_row(i, &mut row).unwrap();
472        if r.visit(&row) {
473            df2.add_row(&row).unwrap();
474        }
475    }
476
477    df2
478}
479
480fn map_helper<T: Rower>(
481    df: &LocalDataFrame,
482    mut rower: T,
483    start: usize,
484    end: usize,
485) -> T {
486    let mut row = Row::new(&df.schema);
487    // NOTE: IS THIS THE ~10% slower way to do counted loop???? @tom
488    for i in start..end {
489        df.fill_row(i, &mut row).unwrap();
490        rower.visit(&row);
491    }
492    rower
493}
494
495impl From<Column> for LocalDataFrame {
496    /// Construct a new `DataFrame` with the given `column`.
497    fn from(column: Column) -> Self {
498        LocalDataFrame::from(vec![column])
499    }
500}
501
502impl From<Vec<Column>> for LocalDataFrame {
503    /// Construct a new `DataFrame` with the given `columns`.
504    fn from(data: Vec<Column>) -> Self {
505        let mut schema = Schema::new();
506        for column in &data {
507            match &column {
508                Column::Bool(_) => {
509                    schema.add_column(DataType::Bool, None).unwrap()
510                }
511                Column::Int(_) => {
512                    schema.add_column(DataType::Int, None).unwrap()
513                }
514                Column::Float(_) => {
515                    schema.add_column(DataType::Float, None).unwrap()
516                }
517                Column::String(_) => {
518                    schema.add_column(DataType::String, None).unwrap()
519                }
520            };
521        }
522        let n_threads = num_cpus::get();
523        LocalDataFrame {
524            schema,
525            n_threads,
526            data,
527            cur_row_idx: 0,
528        }
529    }
530}
531
532impl From<Data> for LocalDataFrame {
533    /// Construct a new `DataFrame` with the given `scalar` value.
534    fn from(scalar: Data) -> Self {
535        let c = match scalar {
536            Data::Bool(x) => Column::Bool(vec![Some(x)]),
537            Data::Int(x) => Column::Int(vec![Some(x)]),
538            Data::Float(x) => Column::Float(vec![Some(x)]),
539            Data::String(x) => Column::String(vec![Some(x)]),
540            Data::Null => panic!("Can't make a DataFrame from a null value"),
541        };
542        LocalDataFrame::from(c)
543    }
544}
545
546impl std::fmt::Display for LocalDataFrame {
547    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548        for i in 0..self.n_rows() {
549            for j in 0..self.n_cols() {
550                write!(f, "<{}>", self.get(j, i).unwrap())?;
551            }
552            writeln!(f)?;
553        }
554        Ok(())
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataframe::{Row, Rower};

    /// Test `Rower` that sums the non-negative integers found in column 0
    /// and rejects rows holding a negative integer.
    #[derive(Clone)]
    struct PosIntSummer {
        sum: i64,
    }

    impl Rower for PosIntSummer {
        fn visit(&mut self, r: &Row) -> bool {
            match r.get(0).unwrap() {
                Data::Int(val) if *val < 0 => false,
                Data::Int(val) => {
                    self.sum += *val;
                    true
                }
                _ => panic!(),
            }
        }

        fn join(mut self, other: Self) -> Self {
            self.sum += other.sum;
            self
        }
    }

    /// Builds a one-column data frame with 1000 rows: row `i` holds `-i`
    /// when `i` is even and `i` when `i` is odd.
    fn init() -> LocalDataFrame {
        let s = Schema::from(vec![DataType::Int]);
        let mut r = Row::new(&s);
        let mut df = LocalDataFrame::new(&s);

        for i in 0..1000 {
            let value = if i % 2 == 0 { -i } else { i };
            r.set_int(0, value).unwrap();
            df.add_row(&r).unwrap();
        }

        df
    }

    #[test]
    fn test_combine_err_case() {
        let int_frame =
            LocalDataFrame::new(&Schema::from(vec![DataType::Int]));
        let bool_frame =
            LocalDataFrame::new(&Schema::from(vec![DataType::Bool]));
        assert!(int_frame.combine(bool_frame).is_err());
    }

    #[test]
    fn test_combine() {
        let s = Schema::from(vec![]);

        let mut df1 = LocalDataFrame::new(&s);
        df1.add_column(
            Column::Int(vec![Some(1), Some(2), Some(3)]),
            Some("col1".to_string()),
        )
        .unwrap();
        df1.add_column(
            Column::Bool(vec![Some(false), Some(false), Some(false)]),
            None,
        )
        .unwrap();

        let mut df2 = LocalDataFrame::new(&s);
        df2.add_column(Column::Int(vec![Some(4), Some(5), Some(6)]), None)
            .unwrap();
        df2.add_column(
            Column::Bool(vec![Some(true), Some(true), Some(true)]),
            None,
        )
        .unwrap();

        let res = df1.combine(df2);
        assert!(res.is_ok());
        let combined = res.unwrap();

        // The result keeps the column names of the receiver
        let mut expected = Schema::from(vec![DataType::Int, DataType::Bool]);
        expected.col_names.insert("col1".to_string(), 0);
        assert_eq!(combined.get_schema(), &expected);

        let summed = combined.map(PosIntSummer { sum: 0 });
        assert_eq!(summed.sum, 21);
    }

    #[test]
    fn test_map() {
        let df = init();
        let rower = df.map(PosIntSummer { sum: 0 });
        assert_eq!(1000 * 1000 / 4, rower.sum);
        assert_eq!(1000, df.n_rows());
    }

    #[test]
    fn test_pmap() {
        let df = init();
        let rower = df.pmap(PosIntSummer { sum: 0 });
        assert_eq!(1000 * 1000 / 4, rower.sum);
        assert_eq!(1000, df.n_rows());
    }

    #[test]
    fn test_pmap_w_1_thread() {
        let mut df = init();
        df.n_threads = 1;
        let rower = df.pmap(PosIntSummer { sum: 0 });
        assert_eq!(1000 * 1000 / 4, rower.sum);
        assert_eq!(1000, df.n_rows());
    }

    #[test]
    fn test_filter() {
        let df = init();
        let mut rower = PosIntSummer { sum: 0 };
        let kept = df.filter(&mut rower);
        assert_eq!(kept.n_rows(), 501);
        assert_eq!(kept.n_cols(), 1);
        assert_eq!(kept.get(0, 10).unwrap(), Data::Int(19));
    }

    #[test]
    fn test_pfilter() {
        let df = init();
        let mut rower = PosIntSummer { sum: 0 };
        let kept = df.pfilter(&mut rower);
        assert_eq!(kept.n_rows(), 501);
        assert_eq!(kept.n_cols(), 1);
        assert_eq!(kept.get(0, 10).unwrap(), Data::Int(19));
    }
}
685}