cherry_query/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use arrow::array::{
3    Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
4    Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
5    UInt32Array, UInt64Array, UInt8Array,
6};
7use arrow::buffer::BooleanBuffer;
8use arrow::compute;
9use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
10use arrow::record_batch::RecordBatch;
11use arrow::row::{RowConverter, SortField};
12use hashbrown::HashTable;
13use rayon::prelude::*;
14use std::collections::btree_map::Entry;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use xxhash_rust::xxh3::xxh3_64;
18
19type TableName = String;
20type FieldName = String;
21
22#[derive(Clone)]
23pub struct Query {
24    pub selection: Arc<BTreeMap<TableName, Vec<TableSelection>>>,
25    pub fields: BTreeMap<TableName, Vec<FieldName>>,
26}
27
28impl Query {
29    pub fn add_request_and_include_fields(&mut self) -> Result<()> {
30        for (table_name, selections) in self.selection.iter() {
31            for selection in selections.iter() {
32                for col_name in selection.filters.keys() {
33                    let table_fields = self
34                        .fields
35                        .get_mut(table_name)
36                        .with_context(|| format!("get fields for table {}", table_name))?;
37                    table_fields.push(col_name.to_owned());
38                }
39
40                for include in selection.include.iter() {
41                    let other_table_fields = self
42                        .fields
43                        .get_mut(&include.other_table_name)
44                        .with_context(|| {
45                            format!("get fields for other table {}", include.other_table_name)
46                        })?;
47                    other_table_fields.extend_from_slice(&include.other_table_field_names);
48                    let table_fields = self
49                        .fields
50                        .get_mut(table_name)
51                        .with_context(|| format!("get fields for table {}", table_name))?;
52                    table_fields.extend_from_slice(&include.field_names);
53                }
54            }
55        }
56
57        Ok(())
58    }
59}
60
61pub struct TableSelection {
62    pub filters: BTreeMap<FieldName, Filter>,
63    pub include: Vec<Include>,
64}
65
66pub struct Include {
67    pub other_table_name: TableName,
68    pub field_names: Vec<FieldName>,
69    pub other_table_field_names: Vec<FieldName>,
70}
71
72pub enum Filter {
73    Contains(Contains),
74    Bool(bool),
75}
76
77impl Filter {
78    pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
79        Ok(Self::Contains(Contains::new(arr)?))
80    }
81
82    pub fn bool(b: bool) -> Self {
83        Self::Bool(b)
84    }
85
86    fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
87        match self {
88            Self::Contains(ct) => ct.contains(arr),
89            Self::Bool(b) => {
90                let arr = arr
91                    .as_any()
92                    .downcast_ref::<BooleanArray>()
93                    .context("cast array to boolean array")?;
94
95                let mut filter = if *b {
96                    arr.clone()
97                } else {
98                    compute::not(arr).context("negate array")?
99                };
100
101                if let Some(nulls) = filter.nulls() {
102                    if nulls.null_count() > 0 {
103                        let nulls = BooleanArray::from(nulls.inner().clone());
104                        filter = compute::and(&filter, &nulls).unwrap();
105                    }
106                }
107
108                Ok(filter)
109            }
110        }
111    }
112}
113
114pub struct Contains {
115    array: Arc<dyn Array>,
116    hash_table: Option<HashTable<usize>>,
117}
118
119impl Contains {
120    fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
121        assert!(!arr.is_nullable());
122
123        let mut ht = HashTable::with_capacity(arr.len());
124
125        for (i, v) in arr.values().iter().enumerate() {
126            ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
127                xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
128            });
129        }
130
131        ht
132    }
133
134    fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
135        arr: &GenericByteArray<T>,
136    ) -> HashTable<usize> {
137        assert!(!arr.is_nullable());
138
139        let mut ht = HashTable::with_capacity(arr.len());
140
141        for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
142            ht.insert_unique(xxh3_64(v), i, |i| {
143                xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
144            });
145        }
146
147        ht
148    }
149
150    fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
151        let ht = match *array.data_type() {
152            DataType::UInt8 => {
153                let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
154                Self::ht_from_primitive(array)
155            }
156            DataType::UInt16 => {
157                let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
158                Self::ht_from_primitive(array)
159            }
160            DataType::UInt32 => {
161                let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
162                Self::ht_from_primitive(array)
163            }
164            DataType::UInt64 => {
165                let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
166                Self::ht_from_primitive(array)
167            }
168            DataType::Int8 => {
169                let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
170                Self::ht_from_primitive(array)
171            }
172            DataType::Int16 => {
173                let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
174                Self::ht_from_primitive(array)
175            }
176            DataType::Int32 => {
177                let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
178                Self::ht_from_primitive(array)
179            }
180            DataType::Int64 => {
181                let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
182                Self::ht_from_primitive(array)
183            }
184            DataType::Binary => {
185                let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
186                Self::ht_from_bytes(array)
187            }
188            DataType::Utf8 => {
189                let array = array.as_any().downcast_ref::<StringArray>().unwrap();
190                Self::ht_from_bytes(array)
191            }
192            _ => {
193                return Err(anyhow!("unsupported data type: {}", array.data_type()));
194            }
195        };
196
197        Ok(ht)
198    }
199
200    pub fn new(array: Arc<dyn Array>) -> Result<Self> {
201        if array.is_nullable() {
202            return Err(anyhow!(
203                "cannot construct contains filter with a nullable array"
204            ));
205        }
206
207        // only use a hash table if there are more than 128 elements
208        let hash_table = if array.len() >= 128 {
209            Some(Self::ht_from_array(&array).context("construct hash table")?)
210        } else {
211            None
212        };
213
214        Ok(Self { hash_table, array })
215    }
216
217    fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
218        if arr.data_type() != self.array.data_type() {
219            return Err(anyhow!(
220                "filter array is of type {} but array to be filtered is of type {}",
221                self.array.data_type(),
222                arr.data_type(),
223            ));
224        }
225        assert!(!self.array.is_nullable());
226
227        let filter = match *arr.data_type() {
228            DataType::UInt8 => {
229                let self_arr = self.array.as_any().downcast_ref::<UInt8Array>().unwrap();
230                let other_arr = arr.as_any().downcast_ref().unwrap();
231                self.contains_primitive(self_arr, other_arr)
232            }
233            DataType::UInt16 => {
234                let self_arr = self.array.as_any().downcast_ref::<UInt16Array>().unwrap();
235                let other_arr = arr.as_any().downcast_ref().unwrap();
236                self.contains_primitive(self_arr, other_arr)
237            }
238            DataType::UInt32 => {
239                let self_arr = self.array.as_any().downcast_ref::<UInt32Array>().unwrap();
240                let other_arr = arr.as_any().downcast_ref().unwrap();
241                self.contains_primitive(self_arr, other_arr)
242            }
243            DataType::UInt64 => {
244                let self_arr = self.array.as_any().downcast_ref::<UInt64Array>().unwrap();
245                let other_arr = arr.as_any().downcast_ref().unwrap();
246                self.contains_primitive(self_arr, other_arr)
247            }
248            DataType::Int8 => {
249                let self_arr = self.array.as_any().downcast_ref::<Int8Array>().unwrap();
250                let other_arr = arr.as_any().downcast_ref().unwrap();
251                self.contains_primitive(self_arr, other_arr)
252            }
253            DataType::Int16 => {
254                let self_arr = self.array.as_any().downcast_ref::<Int16Array>().unwrap();
255                let other_arr = arr.as_any().downcast_ref().unwrap();
256                self.contains_primitive(self_arr, other_arr)
257            }
258            DataType::Int32 => {
259                let self_arr = self.array.as_any().downcast_ref::<Int32Array>().unwrap();
260                let other_arr = arr.as_any().downcast_ref().unwrap();
261                self.contains_primitive(self_arr, other_arr)
262            }
263            DataType::Int64 => {
264                let self_arr = self.array.as_any().downcast_ref::<Int64Array>().unwrap();
265                let other_arr = arr.as_any().downcast_ref().unwrap();
266                self.contains_primitive(self_arr, other_arr)
267            }
268            DataType::Binary => {
269                let self_arr = self.array.as_any().downcast_ref::<BinaryArray>().unwrap();
270                let other_arr = arr.as_any().downcast_ref().unwrap();
271                self.contains_bytes(self_arr, other_arr)
272            }
273            DataType::Utf8 => {
274                let self_arr = self.array.as_any().downcast_ref::<StringArray>().unwrap();
275                let other_arr = arr.as_any().downcast_ref().unwrap();
276                self.contains_bytes(self_arr, other_arr)
277            }
278            _ => {
279                return Err(anyhow!("unsupported data type: {}", arr.data_type()));
280            }
281        };
282
283        let mut filter = filter;
284
285        if let Some(nulls) = arr.nulls() {
286            if nulls.null_count() > 0 {
287                let nulls = BooleanArray::from(nulls.inner().clone());
288                filter = compute::and(&filter, &nulls).unwrap();
289            }
290        }
291
292        Ok(filter)
293    }
294
295    fn contains_primitive<T: ArrowPrimitiveType>(
296        &self,
297        self_arr: &PrimitiveArray<T>,
298        other_arr: &PrimitiveArray<T>,
299    ) -> BooleanArray {
300        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
301
302        if let Some(ht) = self.hash_table.as_ref() {
303            let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
304
305            for v in other_arr.values().iter() {
306                let c = ht
307                    .find(hash_one(v), |idx| unsafe {
308                        self_arr.values().get_unchecked(*idx) == v
309                    })
310                    .is_some();
311                filter.append_value(c);
312            }
313        } else {
314            for v in other_arr.values().iter() {
315                filter.append_value(self_arr.values().iter().any(|x| x == v));
316            }
317        }
318
319        filter.finish()
320    }
321
322    fn contains_bytes<T: ByteArrayType<Offset = i32>>(
323        &self,
324        self_arr: &GenericByteArray<T>,
325        other_arr: &GenericByteArray<T>,
326    ) -> BooleanArray {
327        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
328
329        if let Some(ht) = self.hash_table.as_ref() {
330            for v in iter_byte_array_without_validity(other_arr) {
331                let c = ht
332                    .find(xxh3_64(v), |idx| unsafe {
333                        byte_array_get_unchecked(self_arr, *idx) == v
334                    })
335                    .is_some();
336                filter.append_value(c);
337            }
338        } else {
339            for v in iter_byte_array_without_validity(other_arr) {
340                filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
341            }
342        }
343
344        filter.finish()
345    }
346}
347
348// Taken from arrow-rs
349// https://docs.rs/arrow-array/54.2.1/src/arrow_array/array/byte_array.rs.html#278
350unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
351    arr: &GenericByteArray<T>,
352    i: usize,
353) -> &[u8] {
354    let end = *arr.value_offsets().get_unchecked(i + 1);
355    let start = *arr.value_offsets().get_unchecked(i);
356
357    std::slice::from_raw_parts(
358        arr.value_data()
359            .as_ptr()
360            .offset(isize::try_from(start).unwrap()),
361        usize::try_from(end - start).unwrap(),
362    )
363}
364
365fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
366    arr: &GenericByteArray<T>,
367) -> impl Iterator<Item = &[u8]> {
368    (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
369}
370
371pub fn run_query(
372    data: &BTreeMap<TableName, RecordBatch>,
373    query: &Query,
374) -> Result<BTreeMap<TableName, RecordBatch>> {
375    let filters = query
376        .selection
377        .par_iter()
378        .map(|(table_name, selections)| {
379            selections
380                .par_iter()
381                .enumerate()
382                .map(|(i, selection)| {
383                    run_table_selection(data, table_name, selection).with_context(|| {
384                        format!("run table selection no:{} for table {}", i, table_name)
385                    })
386                })
387                .collect::<Result<Vec<_>>>()
388        })
389        .collect::<Result<Vec<_>>>()?;
390
391    let data = select_fields(data, &query.fields).context("select fields")?;
392
393    data.par_iter()
394        .filter_map(|(table_name, table_data)| {
395            let mut combined_filter: Option<BooleanArray> = None;
396
397            for f in filters.iter() {
398                for f in f.iter() {
399                    let filter = match f.get(table_name) {
400                        Some(f) => f,
401                        None => continue,
402                    };
403
404                    match combined_filter.as_ref() {
405                        Some(e) => {
406                            let f = compute::or(e, filter)
407                                .with_context(|| format!("combine filters for {}", table_name));
408                            let f = match f {
409                                Ok(v) => v,
410                                Err(err) => return Some(Err(err)),
411                            };
412                            combined_filter = Some(f);
413                        }
414                        None => {
415                            combined_filter = Some(filter.clone());
416                        }
417                    }
418                }
419            }
420
421            let combined_filter = match combined_filter {
422                Some(f) => f,
423                None => return None,
424            };
425
426            let table_data = compute::filter_record_batch(table_data, &combined_filter)
427                .context("filter record batch");
428            let table_data = match table_data {
429                Ok(v) => v,
430                Err(err) => return Some(Err(err)),
431            };
432
433            Some(Ok((table_name.to_owned(), table_data)))
434        })
435        .collect()
436}
437
438pub fn select_fields(
439    data: &BTreeMap<TableName, RecordBatch>,
440    fields: &BTreeMap<TableName, Vec<FieldName>>,
441) -> Result<BTreeMap<TableName, RecordBatch>> {
442    let mut out = BTreeMap::new();
443
444    for (table_name, field_names) in fields.iter() {
445        let table_data = data
446            .get(table_name)
447            .with_context(|| format!("get data for table {}", table_name))?;
448
449        let indices = table_data
450            .schema_ref()
451            .fields()
452            .iter()
453            .enumerate()
454            .filter(|(_, field)| field_names.contains(field.name()))
455            .map(|(i, _)| i)
456            .collect::<Vec<usize>>();
457
458        let table_data = table_data
459            .project(&indices)
460            .with_context(|| format!("project table {}", table_name))?;
461        out.insert(table_name.to_owned(), table_data);
462    }
463
464    Ok(out)
465}
466
467fn run_table_selection(
468    data: &BTreeMap<TableName, RecordBatch>,
469    table_name: &str,
470    selection: &TableSelection,
471) -> Result<BTreeMap<TableName, BooleanArray>> {
472    let mut out = BTreeMap::new();
473
474    let table_data = data.get(table_name).context("get table data")?;
475    let mut combined_filter = None;
476    for (field_name, filter) in selection.filters.iter() {
477        let col = table_data
478            .column_by_name(field_name)
479            .with_context(|| format!("get field {}", field_name))?;
480
481        let f = filter
482            .check(&col)
483            .with_context(|| format!("check filter for column {}", field_name))?;
484
485        match combined_filter {
486            Some(cf) => {
487                combined_filter = Some(
488                    compute::and(&cf, &f)
489                        .with_context(|| format!("combine filter for column {}", field_name))?,
490                );
491            }
492            None => {
493                combined_filter = Some(f);
494            }
495        }
496    }
497
498    let combined_filter = match combined_filter {
499        Some(cf) => cf,
500        None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
501    };
502
503    out.insert(table_name.to_owned(), combined_filter.clone());
504
505    let mut filtered_cache = BTreeMap::new();
506
507    for (i, inc) in selection.include.iter().enumerate() {
508        if inc.other_table_field_names.len() != inc.field_names.len() {
509            return Err(anyhow!(
510                "field names are different for self table and other table while processing include no: {}. {} {}",
511                i,
512                inc.field_names.len(),
513                inc.other_table_field_names.len(),
514            ));
515        }
516
517        let other_table_data = data.get(&inc.other_table_name).with_context(|| {
518            format!(
519                "get data for table {} as other table data",
520                inc.other_table_name
521            )
522        })?;
523
524        let self_arr = columns_to_binary_array(table_data, &inc.field_names)
525            .context("get row format binary arr for self")?;
526
527        let contains = match filtered_cache.entry(inc.field_names.clone()) {
528            Entry::Vacant(entry) => {
529                let self_arr = compute::filter(&self_arr, &combined_filter)
530                    .context("apply combined filter to self arr")?;
531                let contains =
532                    Contains::new(Arc::new(self_arr)).context("create contains filter")?;
533                let contains = Arc::new(contains);
534                entry.insert(Arc::clone(&contains));
535                contains
536            }
537            Entry::Occupied(entry) => Arc::clone(entry.get()),
538        };
539
540        let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
541            .with_context(|| {
542                format!(
543                    "get row format binary arr for other table {}",
544                    inc.other_table_name
545                )
546            })?;
547
548        let f = contains
549            .contains(&other_arr)
550            .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
551
552        match out.entry(inc.other_table_name.clone()) {
553            Entry::Vacant(entry) => {
554                entry.insert(f);
555            }
556            Entry::Occupied(mut entry) => {
557                let new = compute::or(entry.get(), &f).with_context(|| {
558                    format!("or include filters for table {}", inc.other_table_name)
559                })?;
560                entry.insert(new);
561            }
562        }
563    }
564
565    Ok(out)
566}
567
568fn columns_to_binary_array(
569    table_data: &RecordBatch,
570    column_names: &[String],
571) -> Result<BinaryArray> {
572    let fields = column_names
573        .iter()
574        .map(|field_name| {
575            let f = table_data
576                .schema_ref()
577                .field_with_name(field_name)
578                .with_context(|| format!("get field {} from schema", field_name))?;
579            Ok(SortField::new(f.data_type().clone()))
580        })
581        .collect::<Result<Vec<_>>>()?;
582    let conv = RowConverter::new(fields).context("create row converter")?;
583
584    let columns = column_names
585        .iter()
586        .map(|field_name| {
587            let c = table_data
588                .column_by_name(field_name)
589                .with_context(|| format!("get data for column {}", field_name))?;
590            let c = Arc::clone(c);
591            Ok(c)
592        })
593        .collect::<Result<Vec<_>>>()?;
594
595    let rows = conv
596        .convert_columns(&columns)
597        .context("convert columns to row format")?;
598    let out = rows
599        .try_into_binary()
600        .context("convert row format to binary array")?;
601
602    Ok(out)
603}
604
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// Builds a batch with one nullable `Utf8` column followed by two nullable `UInt64`
    /// columns, named by `columns` and filled from the given values.
    fn people(
        columns: [&str; 3],
        names: &[&str],
        ages: &[u64],
        heights: &[u64],
    ) -> RecordBatch {
        let [name_col, age_col, height_col] = columns;
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new(name_col, DataType::Utf8, true)),
            Arc::new(Field::new(age_col, DataType::UInt64, true)),
            Arc::new(Field::new(height_col, DataType::UInt64, true)),
        ]));
        let arrays: Vec<Arc<dyn Array>> = vec![
            Arc::new(StringArray::from_iter_values(names.iter().copied())),
            Arc::new(UInt64Array::from_iter(ages.iter().copied())),
            Arc::new(UInt64Array::from_iter(heights.iter().copied())),
        ];
        RecordBatch::try_new(schema, arrays).unwrap()
    }

    /// Converts string literals into owned field names.
    fn owned_names(list: &[&str]) -> Vec<FieldName> {
        list.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn basic_test_cherry_query() {
        let team_a = people(
            ["name", "age", "height"],
            &["kamil", "mahmut", "qwe", "kazim"],
            &[11, 12, 13, 31],
            &[50, 60, 70, 60],
        );
        let team_b = people(
            ["name2", "age2", "height2"],
            &["yusuf", "abuzer", "asd"],
            &[11, 12, 13],
            &[50, 61, 70],
        );

        // Keep team_a rows named kamil or mahmut, then pull in:
        // - team_b rows whose (age2, height2) equals (age, height) of a kept row;
        // - team_a rows whose height equals the height of a kept row.
        let selection = TableSelection {
            filters: BTreeMap::from([(
                "name".to_owned(),
                Filter::contains(Arc::new(StringArray::from_iter_values(["kamil", "mahmut"])))
                    .unwrap(),
            )]),
            include: vec![
                Include {
                    field_names: owned_names(&["age", "height"]),
                    other_table_field_names: owned_names(&["age2", "height2"]),
                    other_table_name: "team_b".to_owned(),
                },
                Include {
                    field_names: owned_names(&["height"]),
                    other_table_field_names: owned_names(&["height"]),
                    other_table_name: "team_a".to_owned(),
                },
            ],
        };

        let query = Query {
            fields: BTreeMap::from([
                ("team_a".to_owned(), owned_names(&["name"])),
                ("team_b".to_owned(), owned_names(&["name2"])),
            ]),
            selection: Arc::new(BTreeMap::from([("team_a".to_owned(), vec![selection])])),
        };

        let data = BTreeMap::from([("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]);

        let res = run_query(&data, &query).unwrap();
        assert_eq!(res.len(), 2);

        let team_a = &res["team_a"];
        let team_b = &res["team_b"];

        assert_eq!(team_a.num_columns(), 1);
        assert_eq!(team_b.num_columns(), 1);

        let name = team_a.column_by_name("name").unwrap();
        let name2 = team_b.column_by_name("name2").unwrap();

        assert_eq!(
            name.as_string(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(name2.as_string(), &StringArray::from_iter_values(["yusuf"]));
    }
}