cherry_query/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use arrow::array::{
3    Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
4    Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
5    UInt32Array, UInt64Array, UInt8Array,
6};
7use arrow::buffer::BooleanBuffer;
8use arrow::compute;
9use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
10use arrow::record_batch::RecordBatch;
11use arrow::row::{RowConverter, SortField};
12use hashbrown::HashTable;
13use rayon::prelude::*;
14use std::collections::btree_map::Entry;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use xxhash_rust::xxh3::xxh3_64;
18
/// Name of a table, used as the key of the input and output table maps.
type TableName = String;
/// Name of a column within a table's schema.
type FieldName = String;
21
22#[derive(Clone)]
23pub struct Query {
24    pub selection: Arc<BTreeMap<TableName, Vec<TableSelection>>>,
25    pub fields: BTreeMap<TableName, Vec<FieldName>>,
26}
27
28impl Query {
29    pub fn add_request_and_include_fields(&mut self) -> Result<()> {
30        for (table_name, selections) in self.selection.iter() {
31            for selection in selections.iter() {
32                for col_name in selection.filters.keys() {
33                    let table_fields = self
34                        .fields
35                        .get_mut(table_name)
36                        .with_context(|| format!("get fields for table {}", table_name))?;
37                    table_fields.push(col_name.to_owned());
38                }
39
40                for include in selection.include.iter() {
41                    let other_table_fields = self
42                        .fields
43                        .get_mut(&include.other_table_name)
44                        .with_context(|| {
45                            format!("get fields for other table {}", include.other_table_name)
46                        })?;
47                    other_table_fields.extend_from_slice(&include.other_table_field_names);
48                    let table_fields = self
49                        .fields
50                        .get_mut(table_name)
51                        .with_context(|| format!("get fields for table {}", table_name))?;
52                    table_fields.extend_from_slice(&include.field_names);
53                }
54            }
55        }
56
57        Ok(())
58    }
59}
60
61pub struct TableSelection {
62    pub filters: BTreeMap<FieldName, Filter>,
63    pub include: Vec<Include>,
64}
65
66pub struct Include {
67    pub other_table_name: TableName,
68    pub field_names: Vec<FieldName>,
69    pub other_table_field_names: Vec<FieldName>,
70}
71
72pub enum Filter {
73    Contains(Contains),
74    Bool(bool),
75}
76
77impl Filter {
78    pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
79        Ok(Self::Contains(Contains::new(arr)?))
80    }
81
82    pub fn bool(b: bool) -> Self {
83        Self::Bool(b)
84    }
85
86    fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
87        match self {
88            Self::Contains(ct) => ct.contains(arr),
89            Self::Bool(b) => {
90                let arr = arr
91                    .as_any()
92                    .downcast_ref::<BooleanArray>()
93                    .context("cast array to boolean array")?;
94
95                let mut filter = if *b {
96                    arr.clone()
97                } else {
98                    compute::not(arr).context("negate array")?
99                };
100
101                if let Some(nulls) = filter.nulls() {
102                    if nulls.null_count() > 0 {
103                        let nulls = BooleanArray::from(nulls.inner().clone());
104                        filter = compute::and(&filter, &nulls).unwrap();
105                    }
106                }
107
108                Ok(filter)
109            }
110        }
111    }
112}
113
114pub struct Contains {
115    array: Arc<dyn Array>,
116    hash_table: Option<HashTable<usize>>,
117}
118
119impl Contains {
120    fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
121        assert!(!arr.is_nullable());
122
123        let mut ht = HashTable::with_capacity(arr.len());
124
125        for (i, v) in arr.values().iter().enumerate() {
126            ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
127                xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
128            });
129        }
130
131        ht
132    }
133
134    fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
135        arr: &GenericByteArray<T>,
136    ) -> HashTable<usize> {
137        assert!(!arr.is_nullable());
138
139        let mut ht = HashTable::with_capacity(arr.len());
140
141        for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
142            ht.insert_unique(xxh3_64(v), i, |i| {
143                xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
144            });
145        }
146
147        ht
148    }
149
150    fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
151        let ht = match *array.data_type() {
152            DataType::UInt8 => {
153                let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
154                Self::ht_from_primitive(array)
155            }
156            DataType::UInt16 => {
157                let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
158                Self::ht_from_primitive(array)
159            }
160            DataType::UInt32 => {
161                let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
162                Self::ht_from_primitive(array)
163            }
164            DataType::UInt64 => {
165                let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
166                Self::ht_from_primitive(array)
167            }
168            DataType::Int8 => {
169                let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
170                Self::ht_from_primitive(array)
171            }
172            DataType::Int16 => {
173                let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
174                Self::ht_from_primitive(array)
175            }
176            DataType::Int32 => {
177                let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
178                Self::ht_from_primitive(array)
179            }
180            DataType::Int64 => {
181                let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
182                Self::ht_from_primitive(array)
183            }
184            DataType::Binary => {
185                let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
186                Self::ht_from_bytes(array)
187            }
188            DataType::Utf8 => {
189                let array = array.as_any().downcast_ref::<StringArray>().unwrap();
190                Self::ht_from_bytes(array)
191            }
192            _ => {
193                return Err(anyhow!("unsupported data type: {}", array.data_type()));
194            }
195        };
196
197        Ok(ht)
198    }
199
200    pub fn new(array: Arc<dyn Array>) -> Result<Self> {
201        if array.is_nullable() {
202            return Err(anyhow!(
203                "cannot construct contains filter with a nullable array"
204            ));
205        }
206
207        // only use a hash table if there are more than 128 elements
208        let hash_table = if array.len() >= 128 {
209            Some(Self::ht_from_array(&array).context("construct hash table")?)
210        } else {
211            None
212        };
213
214        Ok(Self { hash_table, array })
215    }
216
217    fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
218        if arr.data_type() != self.array.data_type() {
219            return Err(anyhow!(
220                "filter array is of type {} but array to be filtered is of type {}",
221                self.array.data_type(),
222                arr.data_type(),
223            ));
224        }
225        assert!(!self.array.is_nullable());
226
227        let filter = match *arr.data_type() {
228            DataType::UInt8 => {
229                let self_arr = self.array.as_any().downcast_ref::<UInt8Array>().unwrap();
230                let other_arr = arr.as_any().downcast_ref().unwrap();
231                self.contains_primitive(self_arr, other_arr)
232            }
233            DataType::UInt16 => {
234                let self_arr = self.array.as_any().downcast_ref::<UInt16Array>().unwrap();
235                let other_arr = arr.as_any().downcast_ref().unwrap();
236                self.contains_primitive(self_arr, other_arr)
237            }
238            DataType::UInt32 => {
239                let self_arr = self.array.as_any().downcast_ref::<UInt32Array>().unwrap();
240                let other_arr = arr.as_any().downcast_ref().unwrap();
241                self.contains_primitive(self_arr, other_arr)
242            }
243            DataType::UInt64 => {
244                let self_arr = self.array.as_any().downcast_ref::<UInt64Array>().unwrap();
245                let other_arr = arr.as_any().downcast_ref().unwrap();
246                self.contains_primitive(self_arr, other_arr)
247            }
248            DataType::Int8 => {
249                let self_arr = self.array.as_any().downcast_ref::<Int8Array>().unwrap();
250                let other_arr = arr.as_any().downcast_ref().unwrap();
251                self.contains_primitive(self_arr, other_arr)
252            }
253            DataType::Int16 => {
254                let self_arr = self.array.as_any().downcast_ref::<Int16Array>().unwrap();
255                let other_arr = arr.as_any().downcast_ref().unwrap();
256                self.contains_primitive(self_arr, other_arr)
257            }
258            DataType::Int32 => {
259                let self_arr = self.array.as_any().downcast_ref::<Int32Array>().unwrap();
260                let other_arr = arr.as_any().downcast_ref().unwrap();
261                self.contains_primitive(self_arr, other_arr)
262            }
263            DataType::Int64 => {
264                let self_arr = self.array.as_any().downcast_ref::<Int64Array>().unwrap();
265                let other_arr = arr.as_any().downcast_ref().unwrap();
266                self.contains_primitive(self_arr, other_arr)
267            }
268            DataType::Binary => {
269                let self_arr = self.array.as_any().downcast_ref::<BinaryArray>().unwrap();
270                let other_arr = arr.as_any().downcast_ref().unwrap();
271                self.contains_bytes(self_arr, other_arr)
272            }
273            DataType::Utf8 => {
274                let self_arr = self.array.as_any().downcast_ref::<StringArray>().unwrap();
275                let other_arr = arr.as_any().downcast_ref().unwrap();
276                self.contains_bytes(self_arr, other_arr)
277            }
278            _ => {
279                return Err(anyhow!("unsupported data type: {}", arr.data_type()));
280            }
281        };
282
283        let mut filter = filter;
284
285        if let Some(nulls) = arr.nulls() {
286            if nulls.null_count() > 0 {
287                let nulls = BooleanArray::from(nulls.inner().clone());
288                filter = compute::and(&filter, &nulls).unwrap();
289            }
290        }
291
292        Ok(filter)
293    }
294
295    fn contains_primitive<T: ArrowPrimitiveType>(
296        &self,
297        self_arr: &PrimitiveArray<T>,
298        other_arr: &PrimitiveArray<T>,
299    ) -> BooleanArray {
300        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
301
302        if let Some(ht) = self.hash_table.as_ref() {
303            let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
304
305            for v in other_arr.values().iter() {
306                let c = ht
307                    .find(hash_one(v), |idx| unsafe {
308                        self_arr.values().get_unchecked(*idx) == v
309                    })
310                    .is_some();
311                filter.append_value(c);
312            }
313        } else {
314            for v in other_arr.values().iter() {
315                filter.append_value(self_arr.values().iter().any(|x| x == v));
316            }
317        }
318
319        filter.finish()
320    }
321
322    fn contains_bytes<T: ByteArrayType<Offset = i32>>(
323        &self,
324        self_arr: &GenericByteArray<T>,
325        other_arr: &GenericByteArray<T>,
326    ) -> BooleanArray {
327        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
328
329        if let Some(ht) = self.hash_table.as_ref() {
330            for v in iter_byte_array_without_validity(other_arr) {
331                let c = ht
332                    .find(xxh3_64(v), |idx| unsafe {
333                        byte_array_get_unchecked(self_arr, *idx) == v
334                    })
335                    .is_some();
336                filter.append_value(c);
337            }
338        } else {
339            for v in iter_byte_array_without_validity(other_arr) {
340                filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
341            }
342        }
343
344        filter.finish()
345    }
346}
347
348// Taken from arrow-rs
349// https://docs.rs/arrow-array/54.2.1/src/arrow_array/array/byte_array.rs.html#278
350unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
351    arr: &GenericByteArray<T>,
352    i: usize,
353) -> &[u8] {
354    let end = *arr.value_offsets().get_unchecked(i + 1);
355    let start = *arr.value_offsets().get_unchecked(i);
356
357    std::slice::from_raw_parts(
358        arr.value_data()
359            .as_ptr()
360            .offset(isize::try_from(start).unwrap()),
361        usize::try_from(end - start).unwrap(),
362    )
363}
364
365fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
366    arr: &GenericByteArray<T>,
367) -> impl Iterator<Item = &[u8]> {
368    (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
369}
370
371pub fn run_query(
372    data: &BTreeMap<TableName, RecordBatch>,
373    query: &Query,
374) -> Result<BTreeMap<TableName, RecordBatch>> {
375    let filters = query
376        .selection
377        .par_iter()
378        .map(|(table_name, selections)| {
379            selections
380                .par_iter()
381                .enumerate()
382                .map(|(i, selection)| {
383                    run_table_selection(data, table_name, selection).with_context(|| {
384                        format!("run table selection no:{} for table {}", i, table_name)
385                    })
386                })
387                .collect::<Result<Vec<_>>>()
388        })
389        .collect::<Result<Vec<_>>>()?;
390
391    let data = select_fields(data, &query.fields).context("select fields")?;
392
393    data.par_iter()
394        .filter_map(|(table_name, table_data)| {
395            let mut combined_filter: Option<BooleanArray> = None;
396
397            for f in filters.iter() {
398                for f in f.iter() {
399                    let filter = match f.get(table_name) {
400                        Some(f) => f,
401                        None => continue,
402                    };
403
404                    match combined_filter.as_ref() {
405                        Some(e) => {
406                            let f = compute::or(e, filter)
407                                .with_context(|| format!("combine filters for {}", table_name));
408                            let f = match f {
409                                Ok(v) => v,
410                                Err(err) => return Some(Err(err)),
411                            };
412                            combined_filter = Some(f);
413                        }
414                        None => {
415                            combined_filter = Some(filter.clone());
416                        }
417                    }
418                }
419            }
420
421            let combined_filter = match combined_filter {
422                Some(f) => f,
423                None => return None,
424            };
425
426            let table_data = compute::filter_record_batch(table_data, &combined_filter)
427                .context("filter record batch");
428            let table_data = match table_data {
429                Ok(v) => v,
430                Err(err) => return Some(Err(err)),
431            };
432
433            Some(Ok((table_name.to_owned(), table_data)))
434        })
435        .collect()
436}
437
438pub fn select_fields(
439    data: &BTreeMap<TableName, RecordBatch>,
440    fields: &BTreeMap<TableName, Vec<FieldName>>,
441) -> Result<BTreeMap<TableName, RecordBatch>> {
442    let mut out = BTreeMap::new();
443
444    for (table_name, field_names) in fields.iter() {
445        let table_data = data
446            .get(table_name)
447            .with_context(|| format!("get data for table {}", table_name))?;
448
449        let indices = field_names
450            .iter()
451            .map(|n| {
452                table_data
453                    .schema_ref()
454                    .index_of(n)
455                    .with_context(|| format!("find index of field {} in table {}", n, table_name))
456            })
457            .collect::<Result<Vec<usize>>>()?;
458
459        let table_data = table_data
460            .project(&indices)
461            .with_context(|| format!("project table {}", table_name))?;
462        out.insert(table_name.to_owned(), table_data);
463    }
464
465    Ok(out)
466}
467
468fn run_table_selection(
469    data: &BTreeMap<TableName, RecordBatch>,
470    table_name: &str,
471    selection: &TableSelection,
472) -> Result<BTreeMap<TableName, BooleanArray>> {
473    let mut out = BTreeMap::new();
474
475    let table_data = data.get(table_name).context("get table data")?;
476    let mut combined_filter = None;
477    for (field_name, filter) in selection.filters.iter() {
478        let col = table_data
479            .column_by_name(field_name)
480            .with_context(|| format!("get field {}", field_name))?;
481
482        let f = filter
483            .check(&col)
484            .with_context(|| format!("check filter for column {}", field_name))?;
485
486        match combined_filter {
487            Some(cf) => {
488                combined_filter = Some(
489                    compute::and(&cf, &f)
490                        .with_context(|| format!("combine filter for column {}", field_name))?,
491                );
492            }
493            None => {
494                combined_filter = Some(f);
495            }
496        }
497    }
498
499    let combined_filter = match combined_filter {
500        Some(cf) => cf,
501        None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
502    };
503
504    out.insert(table_name.to_owned(), combined_filter.clone());
505
506    let mut filtered_cache = BTreeMap::new();
507
508    for (i, inc) in selection.include.iter().enumerate() {
509        if inc.other_table_field_names.len() != inc.field_names.len() {
510            return Err(anyhow!(
511                "field names are different for self table and other table while processing include no: {}. {} {}",
512                i,
513                inc.field_names.len(),
514                inc.other_table_field_names.len(),
515            ));
516        }
517
518        let other_table_data = data.get(&inc.other_table_name).with_context(|| {
519            format!(
520                "get data for table {} as other table data",
521                inc.other_table_name
522            )
523        })?;
524
525        let self_arr = columns_to_binary_array(table_data, &inc.field_names)
526            .context("get row format binary arr for self")?;
527
528        let contains = match filtered_cache.entry(inc.field_names.clone()) {
529            Entry::Vacant(entry) => {
530                let self_arr = compute::filter(&self_arr, &combined_filter)
531                    .context("apply combined filter to self arr")?;
532                let contains =
533                    Contains::new(Arc::new(self_arr)).context("create contains filter")?;
534                let contains = Arc::new(contains);
535                entry.insert(Arc::clone(&contains));
536                contains
537            }
538            Entry::Occupied(entry) => Arc::clone(entry.get()),
539        };
540
541        let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
542            .with_context(|| {
543                format!(
544                    "get row format binary arr for other table {}",
545                    inc.other_table_name
546                )
547            })?;
548
549        let f = contains
550            .contains(&other_arr)
551            .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
552
553        match out.entry(inc.other_table_name.clone()) {
554            Entry::Vacant(entry) => {
555                entry.insert(f);
556            }
557            Entry::Occupied(mut entry) => {
558                let new = compute::or(entry.get(), &f).with_context(|| {
559                    format!("or include filters for table {}", inc.other_table_name)
560                })?;
561                entry.insert(new);
562            }
563        }
564    }
565
566    Ok(out)
567}
568
569fn columns_to_binary_array(
570    table_data: &RecordBatch,
571    column_names: &[String],
572) -> Result<BinaryArray> {
573    let fields = column_names
574        .iter()
575        .map(|field_name| {
576            let f = table_data
577                .schema_ref()
578                .field_with_name(field_name)
579                .with_context(|| format!("get field {} from schema", field_name))?;
580            Ok(SortField::new(f.data_type().clone()))
581        })
582        .collect::<Result<Vec<_>>>()?;
583    let conv = RowConverter::new(fields).context("create row converter")?;
584
585    let columns = column_names
586        .iter()
587        .map(|field_name| {
588            let c = table_data
589                .column_by_name(field_name)
590                .with_context(|| format!("get data for column {}", field_name))?;
591            let c = Arc::clone(c);
592            Ok(c)
593        })
594        .collect::<Result<Vec<_>>>()?;
595
596    let rows = conv
597        .convert_columns(&columns)
598        .context("convert columns to row format")?;
599    let out = rows
600        .try_into_binary()
601        .context("convert row format to binary array")?;
602
603    Ok(out)
604}
605
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// Reads a mask as plain booleans, treating null entries as "not selected".
    fn selected(mask: &BooleanArray) -> Vec<bool> {
        (0..mask.len())
            .map(|i| mask.is_valid(i) && mask.value(i))
            .collect()
    }

    #[test]
    fn basic_test_cherry_query() {
        let team_a = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Arc::new(Field::new("name", DataType::Utf8, true)),
                Arc::new(Field::new("age", DataType::UInt64, true)),
                Arc::new(Field::new("height", DataType::UInt64, true)),
            ])),
            vec![
                Arc::new(StringArray::from_iter_values(
                    vec!["kamil", "mahmut", "qwe", "kazim"].into_iter(),
                )),
                Arc::new(UInt64Array::from_iter(vec![11, 12, 13, 31].into_iter())),
                Arc::new(UInt64Array::from_iter(vec![50, 60, 70, 60].into_iter())),
            ],
        )
        .unwrap();
        let team_b = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Arc::new(Field::new("name2", DataType::Utf8, true)),
                Arc::new(Field::new("age2", DataType::UInt64, true)),
                Arc::new(Field::new("height2", DataType::UInt64, true)),
            ])),
            vec![
                Arc::new(StringArray::from_iter_values(vec![
                    "yusuf", "abuzer", "asd",
                ])),
                Arc::new(UInt64Array::from_iter(vec![11, 12, 13].into_iter())),
                Arc::new(UInt64Array::from_iter(vec![50, 61, 70].into_iter())),
            ],
        )
        .unwrap();

        let query = Query {
            fields: [
                ("team_a".to_owned(), vec!["name".to_owned()]),
                ("team_b".to_owned(), vec!["name2".to_owned()]),
            ]
            .into_iter()
            .collect(),
            selection: Arc::new(
                [(
                    "team_a".to_owned(),
                    vec![TableSelection {
                        filters: [(
                            "name".to_owned(),
                            Filter::Contains(
                                Contains::new(Arc::new(StringArray::from_iter_values(
                                    vec!["kamil", "mahmut"].into_iter(),
                                )))
                                .unwrap(),
                            ),
                        )]
                        .into_iter()
                        .collect(),
                        include: vec![
                            Include {
                                field_names: vec!["age".to_owned(), "height".to_owned()],
                                other_table_field_names: vec![
                                    "age2".to_owned(),
                                    "height2".to_owned(),
                                ],
                                other_table_name: "team_b".to_owned(),
                            },
                            Include {
                                field_names: vec!["height".to_owned()],
                                other_table_field_names: vec!["height".to_owned()],
                                other_table_name: "team_a".to_owned(),
                            },
                        ],
                    }],
                )]
                .into_iter()
                .collect(),
            ),
        };

        let data = [("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]
            .into_iter()
            .collect::<BTreeMap<_, _>>();

        let res = run_query(&data, &query).unwrap();

        let team_a = res.get("team_a").unwrap();
        let team_b = res.get("team_b").unwrap();

        assert_eq!(res.len(), 2);

        let name = team_a.column_by_name("name").unwrap();
        let name2 = team_b.column_by_name("name2").unwrap();

        assert_eq!(team_a.num_columns(), 1);
        assert_eq!(team_b.num_columns(), 1);

        assert_eq!(
            name.as_string(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(name2.as_string(), &StringArray::from_iter_values(["yusuf"]));
    }

    #[test]
    fn contains_rejects_nullable_values() {
        let values = UInt64Array::from(vec![Some(1), None]);
        assert!(Contains::new(Arc::new(values)).is_err());
    }

    #[test]
    fn contains_primitive_hash_table_path() {
        let contains = Contains::new(Arc::new(UInt64Array::from_iter_values(0..200))).unwrap();
        assert!(contains.hash_table.is_some());

        let probe = UInt64Array::from_iter_values([5, 199, 200, 1000]);
        let mask = contains.contains(&probe).unwrap();
        assert_eq!(selected(&mask), vec![true, true, false, false]);
    }

    #[test]
    fn contains_bytes_hash_table_path_masks_nulls() {
        let values: Vec<String> = (0..150).map(|i| format!("v{i}")).collect();
        let contains =
            Contains::new(Arc::new(StringArray::from_iter_values(values.iter()))).unwrap();
        assert!(contains.hash_table.is_some());

        let probe = StringArray::from(vec![Some("v3"), None, Some("v149"), Some("v150")]);
        let mask = contains.contains(&probe).unwrap();
        assert_eq!(selected(&mask), vec![true, false, true, false]);
    }

    #[test]
    fn bool_filter_never_selects_nulls() {
        let col = BooleanArray::from(vec![Some(true), None, Some(false)]);

        let mask = Filter::bool(true).check(&col).unwrap();
        assert_eq!(selected(&mask), vec![true, false, false]);

        let mask = Filter::bool(false).check(&col).unwrap();
        assert_eq!(selected(&mask), vec![false, false, true]);
    }
}