1use crate::dataframe::{Row, Rower, Schema};
3use crate::error::LiquidError;
4use crossbeam_utils::thread;
5use deepsize::DeepSizeOf;
6use serde::{Deserialize, Serialize};
7use sorer::dataframe::{from_file, Column, Data};
8use sorer::schema::{infer_schema, DataType};
9use std::cmp::Ordering;
10use std::convert::TryInto;
11
/// A dataframe held entirely in local memory, stored column by column.
///
/// Derives `Serialize`/`Deserialize` (serde), `PartialEq`, `Clone`, `Debug`
/// and `DeepSizeOf` (deepsize). The derived serialization may depend on field
/// order (e.g. for non-self-describing formats), so keep the field order as is.
#[derive(Serialize, Deserialize, PartialEq, Clone, Debug, DeepSizeOf)]
pub struct LocalDataFrame {
    /// Column types (`schema.schema`) and column names of this dataframe.
    pub schema: Schema,
    /// The column data: `data[i]` holds the values of column `i`, with `None`
    /// marking a null cell (reported as `Data::Null` by `get`).
    pub data: Vec<Column>,
    /// Number of contiguous row chunks (one scoped thread each) that `pmap`
    /// and `pfilter` split the work into. The constructors in this file
    /// initialize it to `num_cpus::get()`.
    pub n_threads: usize,
    /// NOTE(review): only ever initialized to 0 in this file and never read
    /// here — confirm whether it is still needed.
    cur_row_idx: usize,
}
26
/// Generates a typed setter `$func_name(col_idx, row_idx, data)` that stores
/// `Some(data)` in the cell at (`col_idx`, `row_idx`).
///
/// The generated method returns:
/// - `Err(ColIndexOutOfBounds)` if `col_idx` is not a column of the dataframe,
/// - `Err(TypeMismatch)` if the column's type is not `$sorer_type`,
/// - `Err(RowIndexOutOfBounds)` if `row_idx` is past the end of the column.
///
/// It panics if the schema and the stored column disagree about the column's
/// type, which would mean the dataframe's internal invariant is broken.
macro_rules! setter {
    ($func_name:ident, $type:ty, $sorer_type:ident) => {
        pub fn $func_name(
            &mut self,
            col_idx: usize,
            row_idx: usize,
            data: $type,
        ) -> Result<(), LiquidError> {
            // Validate the column against the schema first; an index past the
            // last column is an out-of-bounds error, not a type mismatch.
            match self.schema.schema.get(col_idx) {
                None => return Err(LiquidError::ColIndexOutOfBounds),
                Some(DataType::$sorer_type) => {}
                Some(_) => return Err(LiquidError::TypeMismatch),
            }
            match self.data.get_mut(col_idx) {
                Some(Column::$sorer_type(col)) => {
                    let cell = col
                        .get_mut(row_idx)
                        .ok_or(LiquidError::RowIndexOutOfBounds)?;
                    *cell = Some(data);
                    Ok(())
                }
                None => Err(LiquidError::ColIndexOutOfBounds),
                Some(_) => panic!("Something is horribly wrong"),
            }
        }
    };
}
58
59impl LocalDataFrame {
62 pub fn from_sor(file_name: &str, from: usize, len: usize) -> Self {
67 let schema = Schema::from(infer_schema(file_name).expect("Could not infer schema for {file_name:?}"));
68 let n_threads = num_cpus::get();
69 let data =
70 from_file(file_name, schema.schema.clone(), from, len, n_threads);
71 LocalDataFrame {
72 schema,
73 data,
74 n_threads,
75 cur_row_idx: 0,
76 }
77 }
78
79 pub fn new(schema: &Schema) -> Self {
83 let mut data = Vec::new();
84 for data_type in &schema.schema {
85 match data_type {
86 DataType::Bool => data.push(Column::Bool(Vec::new())),
87 DataType::Int => data.push(Column::Int(Vec::new())),
88 DataType::Float => data.push(Column::Float(Vec::new())),
89 DataType::String => data.push(Column::String(Vec::new())),
90 }
91 }
92 let schema = Schema {
93 schema: schema.schema.clone(),
94 col_names: schema.col_names.clone(),
95 };
96
97 LocalDataFrame {
98 schema,
99 data,
100 n_threads: num_cpus::get(),
101 cur_row_idx: 0,
102 }
103 }
104
105 pub fn get_schema(&self) -> &Schema {
107 &self.schema
108 }
109
110 pub fn add_column(
113 &mut self,
114 col: Column,
115 name: Option<String>,
116 ) -> Result<(), LiquidError> {
117 match &col {
118 Column::Int(_) => self.schema.add_column(DataType::Int, name),
119 Column::Bool(_) => self.schema.add_column(DataType::Bool, name),
120 Column::Float(_) => self.schema.add_column(DataType::Float, name),
121 Column::String(_) => self.schema.add_column(DataType::String, name),
122 }?;
123
124 match self.n_rows().cmp(&col.len()) {
125 Ordering::Equal => self.data.push(col),
126 Ordering::Less => {
127 for j in 0..self.n_cols() - 1 {
130 let c = self.data.get_mut(j).unwrap();
131 for _ in 0..col.len() - c.len() {
132 match c {
133 Column::Bool(x) => x.push(None),
134 Column::Int(x) => x.push(None),
135 Column::Float(x) => x.push(None),
136 Column::String(x) => x.push(None),
137 }
138 }
139 }
140 self.data.push(col)
141 }
142 Ordering::Greater => {
143 let diff = self.n_rows() - col.len();
146 match col {
151 Column::Bool(mut x) => {
152 let nones = vec![None; diff];
153 x.extend(nones.into_iter());
154 self.data.push(Column::Bool(x))
155 }
156 Column::Int(mut x) => {
157 let nones = vec![None; diff];
158 x.extend(nones.into_iter());
159 self.data.push(Column::Int(x))
160 }
161 Column::Float(mut x) => {
162 let nones = vec![None; diff];
163 x.extend(nones.into_iter());
164 self.data.push(Column::Float(x))
165 }
166 Column::String(mut x) => {
167 let nones = vec![None; diff];
168 x.extend(nones.into_iter());
169 self.data.push(Column::String(x))
170 }
171 }
172 }
173 }
174
175 Ok(())
176 }
177
178 pub fn get(
180 &self,
181 col_idx: usize,
182 row_idx: usize,
183 ) -> Result<Data, LiquidError> {
184 match self.data.get(col_idx) {
189 Some(Column::Int(col)) => match col.get(row_idx) {
190 Some(optional_data) => match optional_data {
191 Some(data) => Ok(Data::Int(*data)),
192 None => Ok(Data::Null),
193 },
194 None => Err(LiquidError::RowIndexOutOfBounds),
195 },
196 Some(Column::Bool(col)) => match col.get(row_idx) {
197 Some(optional_data) => match optional_data {
198 Some(data) => Ok(Data::Bool(*data)),
199 None => Ok(Data::Null),
200 },
201 None => Err(LiquidError::RowIndexOutOfBounds),
202 },
203 Some(Column::Float(col)) => match col.get(row_idx) {
204 Some(optional_data) => match optional_data {
205 Some(data) => Ok(Data::Float(*data)),
206 None => Ok(Data::Null),
207 },
208 None => Err(LiquidError::RowIndexOutOfBounds),
209 },
210 Some(Column::String(col)) => match col.get(row_idx) {
211 Some(optional_data) => match optional_data {
212 Some(data) => Ok(Data::String(data.clone())),
213 None => Ok(Data::Null),
214 },
215 None => Err(LiquidError::RowIndexOutOfBounds),
216 },
217 None => Err(LiquidError::ColIndexOutOfBounds),
218 }
219 }
220
221 pub fn get_col_idx(&self, col_name: &str) -> Option<usize> {
224 self.schema.col_idx(col_name)
225 }
226
227 pub fn col_name(
229 &self,
230 col_idx: usize,
231 ) -> Result<Option<&str>, LiquidError> {
232 self.schema.col_name(col_idx)
233 }
234
235 setter!(set_string, String, String);
236 setter!(set_bool, bool, Bool);
237 setter!(set_float, f64, Float);
238 setter!(set_int, i64, Int);
239
240 pub fn fill_row(
246 &self,
247 row_index: usize,
248 row: &mut Row,
249 ) -> Result<(), LiquidError> {
250 for (c_idx, col) in self.data.iter().enumerate() {
251 match col {
252 Column::Int(c) => match c.get(row_index).unwrap() {
253 Some(x) => row.set_int(c_idx, *x)?,
254 None => row.set_null(c_idx)?,
255 },
256 Column::Float(c) => match c.get(row_index).unwrap() {
257 Some(x) => row.set_float(c_idx, *x)?,
258 None => row.set_null(c_idx)?,
259 },
260 Column::Bool(c) => match c.get(row_index).unwrap() {
261 Some(x) => row.set_bool(c_idx, *x)?,
262 None => row.set_null(c_idx)?,
263 },
264 Column::String(c) => match c.get(row_index).unwrap() {
265 Some(x) => row.set_string(c_idx, x.clone())?,
266 None => row.set_null(c_idx)?,
267 },
268 };
269 }
270 row.set_idx(row_index);
271 Ok(())
272 }
273
274 pub fn add_row(&mut self, row: &Row) -> Result<(), LiquidError> {
279 if row.schema != self.schema {
280 return Err(LiquidError::TypeMismatch);
281 }
282
283 for (data, column) in row.data.iter().zip(self.data.iter_mut()) {
284 match (data, column) {
285 (Data::Int(n), Column::Int(l)) => l.push(Some(*n)),
286 (Data::Float(n), Column::Float(l)) => l.push(Some(*n)),
287 (Data::Bool(n), Column::Bool(l)) => l.push(Some(*n)),
288 (Data::String(n), Column::String(l)) => l.push(Some(n.clone())),
289 (Data::Null, Column::Int(l)) => l.push(None),
290 (Data::Null, Column::Float(l)) => l.push(None),
291 (Data::Null, Column::Bool(l)) => l.push(None),
292 (Data::Null, Column::String(l)) => l.push(None),
293 (_, _) => unreachable!("Something is horribly wrong"),
294 };
295 }
296
297 Ok(())
298 }
299
300 pub fn map<T: Rower>(&self, rower: T) -> T {
309 map_helper(self, rower, 0, self.n_rows())
310 }
311
312 pub fn pmap<T: Rower + Clone + Send>(&self, rower: T) -> T {
324 let rowers = vec![rower; self.n_threads];
325 let mut new_rowers = Vec::new();
326 let step = self.n_rows() / self.n_threads;
327 let mut from = 0;
328 thread::scope(|s| {
329 let mut threads = Vec::new();
330 let mut i = 0;
331 for r in rowers {
332 i += 1;
333 let to = if i == self.n_threads {
334 self.n_rows()
335 } else {
336 from + step
337 };
338 threads.push(s.spawn(move |_| map_helper(&self, r, from, to)));
339 from += step;
340 }
341 for thread in threads {
342 new_rowers.push(thread.join().unwrap());
343 }
344 })
345 .unwrap();
346 let acc = new_rowers.pop().unwrap();
347 new_rowers
348 .into_iter()
349 .rev()
350 .fold(acc, |prev, x| x.join(prev))
351 }
352
353 pub fn filter<T: Rower>(&self, rower: &mut T) -> Self {
358 filter_helper(self, rower, 0, self.n_rows())
359 }
360
361 pub fn pfilter<T: Rower + Clone + Send>(&self, rower: &mut T) -> Self {
370 let mut rowers = Vec::new();
371 for _ in 0..self.n_threads {
372 rowers.push(rower.clone());
373 }
374 let mut new_dfs = Vec::new();
377 let step = self.n_rows() / self.n_threads;
378 let mut from = 0;
379 thread::scope(|s| {
380 let mut threads = Vec::new();
381 let mut i = 0;
382 for mut r in rowers {
383 i += 1;
384 let to = if i == self.n_threads {
385 self.n_rows()
386 } else {
387 from + step
388 };
389 threads.push(
390 s.spawn(move |_| filter_helper(&self, &mut r, from, to)),
391 );
392 from += step;
393 }
394 for thread in threads {
395 new_dfs.push(thread.join().unwrap());
396 }
397 })
398 .unwrap();
399 let acc = new_dfs.pop().unwrap();
400 new_dfs
401 .into_iter()
402 .rev()
403 .fold(acc, |prev, x| x.combine(prev).unwrap())
404 }
405
406 pub fn combine(mut self, other: Self) -> Result<Self, LiquidError> {
418 if self.get_schema().schema != other.get_schema().schema {
419 return Err(LiquidError::TypeMismatch);
420 }
421
422 for (col_idx, col) in other.data.into_iter().enumerate() {
423 match self.data.get_mut(col_idx).unwrap() {
424 Column::Bool(result_col) => {
425 let x: Vec<Option<bool>> = col.try_into().unwrap();
426 result_col.extend(x.into_iter())
427 }
428 Column::Int(result_col) => {
429 let x: Vec<Option<i64>> = col.try_into().unwrap();
430 result_col.extend(x.into_iter())
431 }
432 Column::Float(result_col) => {
433 let x: Vec<Option<f64>> = col.try_into().unwrap();
434 result_col.extend(x.into_iter())
435 }
436 Column::String(result_col) => {
437 let x: Vec<Option<String>> = col.try_into().unwrap();
438 result_col.extend(x.into_iter())
439 }
440 }
441 }
442
443 Ok(self)
444 }
445
446 pub fn n_rows(&self) -> usize {
448 if self.data.is_empty() {
449 0
450 } else {
451 self.data[0].len()
452 }
453 }
454
455 pub fn n_cols(&self) -> usize {
457 self.schema.width()
458 }
459}
460
461fn filter_helper<T: Rower>(
462 df: &LocalDataFrame,
463 r: &mut T,
464 start: usize,
465 end: usize,
466) -> LocalDataFrame {
467 let mut df2 = LocalDataFrame::new(&df.schema);
468 let mut row = Row::new(&df.schema);
469
470 for i in start..end {
471 df.fill_row(i, &mut row).unwrap();
472 if r.visit(&row) {
473 df2.add_row(&row).unwrap();
474 }
475 }
476
477 df2
478}
479
480fn map_helper<T: Rower>(
481 df: &LocalDataFrame,
482 mut rower: T,
483 start: usize,
484 end: usize,
485) -> T {
486 let mut row = Row::new(&df.schema);
487 for i in start..end {
489 df.fill_row(i, &mut row).unwrap();
490 rower.visit(&row);
491 }
492 rower
493}
494
495impl From<Column> for LocalDataFrame {
496 fn from(column: Column) -> Self {
498 LocalDataFrame::from(vec![column])
499 }
500}
501
502impl From<Vec<Column>> for LocalDataFrame {
503 fn from(data: Vec<Column>) -> Self {
505 let mut schema = Schema::new();
506 for column in &data {
507 match &column {
508 Column::Bool(_) => {
509 schema.add_column(DataType::Bool, None).unwrap()
510 }
511 Column::Int(_) => {
512 schema.add_column(DataType::Int, None).unwrap()
513 }
514 Column::Float(_) => {
515 schema.add_column(DataType::Float, None).unwrap()
516 }
517 Column::String(_) => {
518 schema.add_column(DataType::String, None).unwrap()
519 }
520 };
521 }
522 let n_threads = num_cpus::get();
523 LocalDataFrame {
524 schema,
525 n_threads,
526 data,
527 cur_row_idx: 0,
528 }
529 }
530}
531
532impl From<Data> for LocalDataFrame {
533 fn from(scalar: Data) -> Self {
535 let c = match scalar {
536 Data::Bool(x) => Column::Bool(vec![Some(x)]),
537 Data::Int(x) => Column::Int(vec![Some(x)]),
538 Data::Float(x) => Column::Float(vec![Some(x)]),
539 Data::String(x) => Column::String(vec![Some(x)]),
540 Data::Null => panic!("Can't make a DataFrame from a null value"),
541 };
542 LocalDataFrame::from(c)
543 }
544}
545
546impl std::fmt::Display for LocalDataFrame {
547 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548 for i in 0..self.n_rows() {
549 for j in 0..self.n_cols() {
550 write!(f, "<{}>", self.get(j, i).unwrap())?;
551 }
552 writeln!(f)?;
553 }
554 Ok(())
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataframe::{Row, Rower};

    /// Sums non-negative values of column 0 and keeps only those rows.
    #[derive(Clone)]
    struct PosIntSummer {
        sum: i64,
    }

    impl Rower for PosIntSummer {
        fn visit(&mut self, r: &Row) -> bool {
            let i = r.get(0).unwrap();
            match i {
                Data::Int(val) => {
                    if *val < 0 {
                        return false;
                    }
                    self.sum += *val;
                    true
                }
                _ => panic!(),
            }
        }

        fn join(mut self, other: Self) -> Self {
            self.sum += other.sum;
            self
        }
    }

    /// Builds a one-column Int dataframe of 1000 rows where row `i` holds `-i`
    /// for even `i` and `i` for odd `i`. The non-negative rows are 0 and the
    /// 500 odd numbers, whose sum is 500^2 = 1000 * 1000 / 4.
    fn init() -> LocalDataFrame {
        let s = Schema::from(vec![DataType::Int]);
        let mut r = Row::new(&s);
        let mut df = LocalDataFrame::new(&s);

        for i in 0..1000 {
            if i % 2 == 0 {
                r.set_int(0, -i).unwrap();
            } else {
                r.set_int(0, i).unwrap();
            }
            df.add_row(&r).unwrap();
        }

        df
    }

    #[test]
    fn test_combine_err_case() {
        let s = Schema::from(vec![DataType::Int]);
        let df1 = LocalDataFrame::new(&s);
        let s = Schema::from(vec![DataType::Bool]);
        let df2 = LocalDataFrame::new(&s);
        assert!(df1.combine(df2).is_err());
    }

    #[test]
    fn test_combine() {
        let s = Schema::from(vec![]);
        let mut df1 = LocalDataFrame::new(&s);
        let mut df2 = LocalDataFrame::new(&s);
        let col1 = Column::Int(vec![Some(1), Some(2), Some(3)]);
        let col2 = Column::Bool(vec![Some(false), Some(false), Some(false)]);
        df1.add_column(col1, Some("col1".to_string())).unwrap();
        df1.add_column(col2, None).unwrap();
        let col3 = Column::Int(vec![Some(4), Some(5), Some(6)]);
        let col4 = Column::Bool(vec![Some(true), Some(true), Some(true)]);
        df2.add_column(col3, None).unwrap();
        df2.add_column(col4, None).unwrap();
        let res = df1.combine(df2);
        assert!(res.is_ok());
        let combined = res.unwrap();
        let mut res_schema = Schema::from(vec![DataType::Int, DataType::Bool]);
        res_schema.col_names.insert("col1".to_string(), 0);
        assert_eq!(combined.get_schema(), &res_schema);
        let r = PosIntSummer { sum: 0 };
        assert_eq!(combined.map(r).sum, 21);
    }

    #[test]
    fn test_add_column_pads_shorter_columns() {
        let mut df = LocalDataFrame::from(Column::Int(vec![Some(1)]));
        // Longer new column: the existing Int column is padded with a null.
        df.add_column(Column::Bool(vec![Some(true), Some(false)]), None)
            .unwrap();
        assert_eq!(df.n_rows(), 2);
        assert_eq!(df.get(0, 1).unwrap(), Data::Null);
        // Shorter new column: it is padded up to the existing row count.
        df.add_column(Column::Float(vec![]), None).unwrap();
        assert_eq!(df.n_cols(), 3);
        assert_eq!(df.n_rows(), 2);
        assert_eq!(df.get(2, 0).unwrap(), Data::Null);
        assert_eq!(df.get(2, 1).unwrap(), Data::Null);
        assert_eq!(df.get(1, 0).unwrap(), Data::Bool(true));
    }

    #[test]
    fn test_setters_and_get_bounds() {
        let mut df = LocalDataFrame::from(Column::Int(vec![Some(1), None]));
        assert_eq!(df.get(0, 1).unwrap(), Data::Null);
        df.set_int(0, 1, 7).unwrap();
        assert_eq!(df.get(0, 1).unwrap(), Data::Int(7));
        // Wrong type, row out of range, column out of range.
        assert!(df.set_bool(0, 0, true).is_err());
        assert!(df.set_int(0, 2, 3).is_err());
        assert!(df.set_int(1, 0, 3).is_err());
        assert!(df.get(0, 2).is_err());
        assert!(df.get(1, 0).is_err());
    }

    #[test]
    fn test_map() {
        let df = init();
        let mut rower = PosIntSummer { sum: 0 };
        rower = df.map(rower);
        assert_eq!(1000 * 1000 / 4, rower.sum);
        assert_eq!(1000, df.n_rows());
    }

    #[test]
    fn test_pmap() {
        let df = init();
        let mut rower = PosIntSummer { sum: 0 };
        rower = df.pmap(rower);
        assert_eq!(1000 * 1000 / 4, rower.sum);
        assert_eq!(1000, df.n_rows());
    }

    #[test]
    fn test_pmap_w_1_thread() {
        let mut df = init();
        df.n_threads = 1;
        let mut rower = PosIntSummer { sum: 0 };
        rower = df.pmap(rower);
        assert_eq!(1000 * 1000 / 4, rower.sum);
        assert_eq!(1000, df.n_rows());
    }

    #[test]
    fn test_pmap_more_threads_than_rows() {
        let mut df =
            LocalDataFrame::from(Column::Int(vec![Some(1), Some(2), Some(3)]));
        df.n_threads = 8;
        let rower = df.pmap(PosIntSummer { sum: 0 });
        assert_eq!(rower.sum, 6);
    }

    #[test]
    fn test_filter() {
        let df = init();
        let mut rower = PosIntSummer { sum: 0 };
        let df2 = df.filter(&mut rower);
        assert_eq!(df2.n_rows(), 501);
        assert_eq!(df2.n_cols(), 1);
        assert_eq!(df2.get(0, 10).unwrap(), Data::Int(19));
    }

    #[test]
    fn test_pfilter() {
        let df = init();
        let mut rower = PosIntSummer { sum: 0 };
        let df2 = df.pfilter(&mut rower);
        assert_eq!(df2.n_rows(), 501);
        assert_eq!(df2.n_cols(), 1);
        assert_eq!(df2.get(0, 10).unwrap(), Data::Int(19));
        // Chunks are concatenated in row order, so the last kept row is 999.
        assert_eq!(df2.get(0, 500).unwrap(), Data::Int(999));
    }

    #[test]
    fn test_pfilter_w_1_thread() {
        let mut df = init();
        df.n_threads = 1;
        let mut rower = PosIntSummer { sum: 0 };
        let df2 = df.pfilter(&mut rower);
        assert_eq!(df2.n_rows(), 501);
        assert_eq!(df2.get(0, 10).unwrap(), Data::Int(19));
    }
}