cherry_query/
lib.rs

1use anyhow::{anyhow, Context, Result};
2use arrow::array::{
3    Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
4    Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
5    UInt32Array, UInt64Array, UInt8Array,
6};
7use arrow::buffer::BooleanBuffer;
8use arrow::compute;
9use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
10use arrow::record_batch::RecordBatch;
11use arrow::row::{RowConverter, SortField};
12use hashbrown::HashTable;
13use rayon::prelude::*;
14use std::collections::btree_map::Entry;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use xxhash_rust::xxh3::xxh3_64;
18
/// Name of a table; used as the key of the data, selection and field maps.
type TableName = String;
/// Name of a column (field) inside a table.
type FieldName = String;
21
22pub struct Query {
23    pub selection: BTreeMap<TableName, Vec<TableSelection>>,
24    pub fields: BTreeMap<TableName, Vec<FieldName>>,
25}
26
27pub struct TableSelection {
28    pub filters: BTreeMap<FieldName, Filter>,
29    pub include: Vec<Include>,
30}
31
32pub struct Include {
33    pub other_table_name: TableName,
34    pub field_names: Vec<FieldName>,
35    pub other_table_field_names: Vec<FieldName>,
36}
37
38pub enum Filter {
39    Contains(Contains),
40    Bool(bool),
41}
42
43impl Filter {
44    pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
45        Ok(Self::Contains(Contains::new(arr)?))
46    }
47
48    pub fn bool(b: bool) -> Self {
49        Self::Bool(b)
50    }
51
52    fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
53        match self {
54            Self::Contains(ct) => ct.contains(arr),
55            Self::Bool(b) => {
56                let arr = arr
57                    .as_any()
58                    .downcast_ref::<BooleanArray>()
59                    .context("cast array to boolean array")?;
60
61                let mut filter = if *b {
62                    arr.clone()
63                } else {
64                    compute::not(arr).context("negate array")?
65                };
66
67                if let Some(nulls) = filter.nulls() {
68                    if nulls.null_count() > 0 {
69                        let nulls = BooleanArray::from(nulls.inner().clone());
70                        filter = compute::and(&filter, &nulls).unwrap();
71                    }
72                }
73
74                Ok(filter)
75            }
76        }
77    }
78}
79
80pub struct Contains {
81    array: Arc<dyn Array>,
82    hash_table: Option<HashTable<usize>>,
83}
84
85impl Contains {
86    fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
87        assert!(!arr.is_nullable());
88
89        let mut ht = HashTable::with_capacity(arr.len());
90
91        for (i, v) in arr.values().iter().enumerate() {
92            ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
93                xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
94            });
95        }
96
97        ht
98    }
99
100    fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
101        arr: &GenericByteArray<T>,
102    ) -> HashTable<usize> {
103        assert!(!arr.is_nullable());
104
105        let mut ht = HashTable::with_capacity(arr.len());
106
107        for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
108            ht.insert_unique(xxh3_64(v), i, |i| {
109                xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
110            });
111        }
112
113        ht
114    }
115
116    fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
117        let ht = match *array.data_type() {
118            DataType::UInt8 => {
119                let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
120                Self::ht_from_primitive(array)
121            }
122            DataType::UInt16 => {
123                let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
124                Self::ht_from_primitive(array)
125            }
126            DataType::UInt32 => {
127                let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
128                Self::ht_from_primitive(array)
129            }
130            DataType::UInt64 => {
131                let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
132                Self::ht_from_primitive(array)
133            }
134            DataType::Int8 => {
135                let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
136                Self::ht_from_primitive(array)
137            }
138            DataType::Int16 => {
139                let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
140                Self::ht_from_primitive(array)
141            }
142            DataType::Int32 => {
143                let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
144                Self::ht_from_primitive(array)
145            }
146            DataType::Int64 => {
147                let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
148                Self::ht_from_primitive(array)
149            }
150            DataType::Binary => {
151                let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
152                Self::ht_from_bytes(array)
153            }
154            DataType::Utf8 => {
155                let array = array.as_any().downcast_ref::<StringArray>().unwrap();
156                Self::ht_from_bytes(array)
157            }
158            _ => {
159                return Err(anyhow!("unsupported data type: {}", array.data_type()));
160            }
161        };
162
163        Ok(ht)
164    }
165
166    pub fn new(array: Arc<dyn Array>) -> Result<Self> {
167        if array.is_nullable() {
168            return Err(anyhow!(
169                "cannot construct contains filter with a nullable array"
170            ));
171        }
172
173        // only use a hash table if there are more than 128 elements
174        let hash_table = if array.len() >= 128 {
175            Some(Self::ht_from_array(&array).context("construct hash table")?)
176        } else {
177            None
178        };
179
180        Ok(Self { hash_table, array })
181    }
182
183    fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
184        if arr.data_type() != self.array.data_type() {
185            return Err(anyhow!(
186                "filter array is of type {} but array to be filtered is of type {}",
187                self.array.data_type(),
188                arr.data_type(),
189            ));
190        }
191        assert!(!self.array.is_nullable());
192
193        let filter = match *arr.data_type() {
194            DataType::UInt8 => {
195                let self_arr = self.array.as_any().downcast_ref::<UInt8Array>().unwrap();
196                let other_arr = arr.as_any().downcast_ref().unwrap();
197                self.contains_primitive(self_arr, other_arr)
198            }
199            DataType::UInt16 => {
200                let self_arr = self.array.as_any().downcast_ref::<UInt16Array>().unwrap();
201                let other_arr = arr.as_any().downcast_ref().unwrap();
202                self.contains_primitive(self_arr, other_arr)
203            }
204            DataType::UInt32 => {
205                let self_arr = self.array.as_any().downcast_ref::<UInt32Array>().unwrap();
206                let other_arr = arr.as_any().downcast_ref().unwrap();
207                self.contains_primitive(self_arr, other_arr)
208            }
209            DataType::UInt64 => {
210                let self_arr = self.array.as_any().downcast_ref::<UInt64Array>().unwrap();
211                let other_arr = arr.as_any().downcast_ref().unwrap();
212                self.contains_primitive(self_arr, other_arr)
213            }
214            DataType::Int8 => {
215                let self_arr = self.array.as_any().downcast_ref::<Int8Array>().unwrap();
216                let other_arr = arr.as_any().downcast_ref().unwrap();
217                self.contains_primitive(self_arr, other_arr)
218            }
219            DataType::Int16 => {
220                let self_arr = self.array.as_any().downcast_ref::<Int16Array>().unwrap();
221                let other_arr = arr.as_any().downcast_ref().unwrap();
222                self.contains_primitive(self_arr, other_arr)
223            }
224            DataType::Int32 => {
225                let self_arr = self.array.as_any().downcast_ref::<Int32Array>().unwrap();
226                let other_arr = arr.as_any().downcast_ref().unwrap();
227                self.contains_primitive(self_arr, other_arr)
228            }
229            DataType::Int64 => {
230                let self_arr = self.array.as_any().downcast_ref::<Int64Array>().unwrap();
231                let other_arr = arr.as_any().downcast_ref().unwrap();
232                self.contains_primitive(self_arr, other_arr)
233            }
234            DataType::Binary => {
235                let self_arr = self.array.as_any().downcast_ref::<BinaryArray>().unwrap();
236                let other_arr = arr.as_any().downcast_ref().unwrap();
237                self.contains_bytes(self_arr, other_arr)
238            }
239            DataType::Utf8 => {
240                let self_arr = self.array.as_any().downcast_ref::<StringArray>().unwrap();
241                let other_arr = arr.as_any().downcast_ref().unwrap();
242                self.contains_bytes(self_arr, other_arr)
243            }
244            _ => {
245                return Err(anyhow!("unsupported data type: {}", arr.data_type()));
246            }
247        };
248
249        let mut filter = filter;
250
251        if let Some(nulls) = arr.nulls() {
252            if nulls.null_count() > 0 {
253                let nulls = BooleanArray::from(nulls.inner().clone());
254                filter = compute::and(&filter, &nulls).unwrap();
255            }
256        }
257
258        Ok(filter)
259    }
260
261    fn contains_primitive<T: ArrowPrimitiveType>(
262        &self,
263        self_arr: &PrimitiveArray<T>,
264        other_arr: &PrimitiveArray<T>,
265    ) -> BooleanArray {
266        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
267
268        if let Some(ht) = self.hash_table.as_ref() {
269            let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
270
271            for v in other_arr.values().iter() {
272                let c = ht
273                    .find(hash_one(v), |idx| unsafe {
274                        self_arr.values().get_unchecked(*idx) == v
275                    })
276                    .is_some();
277                filter.append_value(c);
278            }
279        } else {
280            for v in other_arr.values().iter() {
281                filter.append_value(self_arr.values().iter().any(|x| x == v));
282            }
283        }
284
285        filter.finish()
286    }
287
288    fn contains_bytes<T: ByteArrayType<Offset = i32>>(
289        &self,
290        self_arr: &GenericByteArray<T>,
291        other_arr: &GenericByteArray<T>,
292    ) -> BooleanArray {
293        let mut filter = BooleanBuilder::with_capacity(other_arr.len());
294
295        if let Some(ht) = self.hash_table.as_ref() {
296            for v in iter_byte_array_without_validity(other_arr) {
297                let c = ht
298                    .find(xxh3_64(v), |idx| unsafe {
299                        byte_array_get_unchecked(self_arr, *idx) == v
300                    })
301                    .is_some();
302                filter.append_value(c);
303            }
304        } else {
305            for v in iter_byte_array_without_validity(other_arr) {
306                filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
307            }
308        }
309
310        filter.finish()
311    }
312}
313
314// Taken from arrow-rs
315// https://docs.rs/arrow-array/54.2.1/src/arrow_array/array/byte_array.rs.html#278
316unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
317    arr: &GenericByteArray<T>,
318    i: usize,
319) -> &[u8] {
320    let end = *arr.value_offsets().get_unchecked(i + 1);
321    let start = *arr.value_offsets().get_unchecked(i);
322
323    std::slice::from_raw_parts(
324        arr.value_data()
325            .as_ptr()
326            .offset(isize::try_from(start).unwrap()),
327        usize::try_from(end - start).unwrap(),
328    )
329}
330
331fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
332    arr: &GenericByteArray<T>,
333) -> impl Iterator<Item = &[u8]> {
334    (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
335}
336
337pub fn run_query(
338    data: &BTreeMap<TableName, RecordBatch>,
339    query: &Query,
340) -> Result<BTreeMap<TableName, RecordBatch>> {
341    let filters = query
342        .selection
343        .par_iter()
344        .map(|(table_name, selections)| {
345            selections
346                .par_iter()
347                .enumerate()
348                .map(|(i, selection)| {
349                    run_table_selection(data, table_name, selection).with_context(|| {
350                        format!("run table selection no:{} for table {}", i, table_name)
351                    })
352                })
353                .collect::<Result<Vec<_>>>()
354        })
355        .collect::<Result<Vec<_>>>()?;
356
357    let data = select_fields(data, &query.fields).context("select fields")?;
358
359    data.par_iter()
360        .filter_map(|(table_name, table_data)| {
361            let mut combined_filter: Option<BooleanArray> = None;
362
363            for f in filters.iter() {
364                for f in f.iter() {
365                    let filter = match f.get(table_name) {
366                        Some(f) => f,
367                        None => continue,
368                    };
369
370                    match combined_filter.as_ref() {
371                        Some(e) => {
372                            let f = compute::or(e, filter)
373                                .with_context(|| format!("combine filters for {}", table_name));
374                            let f = match f {
375                                Ok(v) => v,
376                                Err(err) => return Some(Err(err)),
377                            };
378                            combined_filter = Some(f);
379                        }
380                        None => {
381                            combined_filter = Some(filter.clone());
382                        }
383                    }
384                }
385            }
386
387            let combined_filter = match combined_filter {
388                Some(f) => f,
389                None => return None,
390            };
391
392            let table_data = compute::filter_record_batch(table_data, &combined_filter)
393                .context("filter record batch");
394            let table_data = match table_data {
395                Ok(v) => v,
396                Err(err) => return Some(Err(err)),
397            };
398
399            Some(Ok((table_name.to_owned(), table_data)))
400        })
401        .collect()
402}
403
404fn select_fields(
405    data: &BTreeMap<TableName, RecordBatch>,
406    fields: &BTreeMap<TableName, Vec<FieldName>>,
407) -> Result<BTreeMap<TableName, RecordBatch>> {
408    let mut out = BTreeMap::new();
409
410    for (table_name, field_names) in fields.iter() {
411        let table_data = data
412            .get(table_name)
413            .with_context(|| format!("get data for table {}", table_name))?;
414
415        let indices = field_names
416            .iter()
417            .map(|n| {
418                table_data
419                    .schema_ref()
420                    .index_of(n)
421                    .with_context(|| format!("find index of field {} in table {}", n, table_name))
422            })
423            .collect::<Result<Vec<usize>>>()?;
424
425        let table_data = table_data
426            .project(&indices)
427            .with_context(|| format!("project table {}", table_name))?;
428        out.insert(table_name.to_owned(), table_data);
429    }
430
431    Ok(out)
432}
433
434fn run_table_selection(
435    data: &BTreeMap<TableName, RecordBatch>,
436    table_name: &str,
437    selection: &TableSelection,
438) -> Result<BTreeMap<TableName, BooleanArray>> {
439    let mut out = BTreeMap::new();
440
441    let table_data = data.get(table_name).context("get table data")?;
442    let mut combined_filter = None;
443    for (field_name, filter) in selection.filters.iter() {
444        let col = table_data
445            .column_by_name(field_name)
446            .with_context(|| format!("get field {}", field_name))?;
447
448        let f = filter
449            .check(&col)
450            .with_context(|| format!("check filter for column {}", field_name))?;
451
452        match combined_filter {
453            Some(cf) => {
454                combined_filter = Some(
455                    compute::and(&cf, &f)
456                        .with_context(|| format!("combine filter for column {}", field_name))?,
457                );
458            }
459            None => {
460                combined_filter = Some(f);
461            }
462        }
463    }
464
465    let combined_filter = match combined_filter {
466        Some(cf) => cf,
467        None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
468    };
469
470    out.insert(table_name.to_owned(), combined_filter.clone());
471
472    let mut filtered_cache = BTreeMap::new();
473
474    for (i, inc) in selection.include.iter().enumerate() {
475        if inc.other_table_field_names.len() != inc.field_names.len() {
476            return Err(anyhow!(
477                "field names are different for self table and other table while processing include no: {}. {} {}",
478                i,
479                inc.field_names.len(),
480                inc.other_table_field_names.len(),
481            ));
482        }
483
484        let other_table_data = data.get(&inc.other_table_name).with_context(|| {
485            format!(
486                "get data for table {} as other table data",
487                inc.other_table_name
488            )
489        })?;
490
491        let self_arr = columns_to_binary_array(table_data, &inc.field_names)
492            .context("get row format binary arr for self")?;
493
494        let self_arr = match filtered_cache.entry(inc.field_names.clone()) {
495            Entry::Vacant(entry) => {
496                let self_arr = compute::filter(&self_arr, &combined_filter)
497                    .context("apply combined filter to self arr")?;
498                entry.insert(self_arr.clone());
499                self_arr
500            }
501            Entry::Occupied(entry) => Arc::clone(entry.get()),
502        };
503
504        let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
505            .with_context(|| {
506                format!(
507                    "get row format binary arr for other table {}",
508                    inc.other_table_name
509                )
510            })?;
511
512        let contains = Contains::new(Arc::new(self_arr)).context("create contains filter")?;
513
514        let f = contains
515            .contains(&other_arr)
516            .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
517
518        match out.entry(inc.other_table_name.clone()) {
519            Entry::Vacant(entry) => {
520                entry.insert(f);
521            }
522            Entry::Occupied(mut entry) => {
523                let new = compute::or(entry.get(), &f).with_context(|| {
524                    format!("or include filters for table {}", inc.other_table_name)
525                })?;
526                entry.insert(new);
527            }
528        }
529    }
530
531    Ok(out)
532}
533
534fn columns_to_binary_array(
535    table_data: &RecordBatch,
536    column_names: &[String],
537) -> Result<BinaryArray> {
538    let fields = column_names
539        .iter()
540        .map(|field_name| {
541            let f = table_data
542                .schema_ref()
543                .field_with_name(field_name)
544                .with_context(|| format!("get field {} from schema", field_name))?;
545            Ok(SortField::new(f.data_type().clone()))
546        })
547        .collect::<Result<Vec<_>>>()?;
548    let conv = RowConverter::new(fields).context("create row converter")?;
549
550    let columns = column_names
551        .iter()
552        .map(|field_name| {
553            let c = table_data
554                .column_by_name(field_name)
555                .with_context(|| format!("get data for column {}", field_name))?;
556            let c = Arc::clone(c);
557            Ok(c)
558        })
559        .collect::<Result<Vec<_>>>()?;
560
561    let rows = conv
562        .convert_columns(&columns)
563        .context("convert columns to row format")?;
564    let out = rows
565        .try_into_binary()
566        .context("convert row format to binary array")?;
567
568    Ok(out)
569}
570
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// Builds a batch with one nullable Utf8 column followed by two nullable
    /// UInt64 columns, named after `column_names`.
    fn people_batch(
        column_names: [&str; 3],
        names: Vec<&str>,
        ages: Vec<u64>,
        heights: Vec<u64>,
    ) -> RecordBatch {
        let [name_col, age_col, height_col] = column_names;

        let schema = Schema::new(vec![
            Arc::new(Field::new(name_col, DataType::Utf8, true)),
            Arc::new(Field::new(age_col, DataType::UInt64, true)),
            Arc::new(Field::new(height_col, DataType::UInt64, true)),
        ]);

        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(StringArray::from_iter_values(names)),
            Arc::new(UInt64Array::from_iter(ages)),
            Arc::new(UInt64Array::from_iter(heights)),
        ];

        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
    }

    /// Converts string literals into owned field names.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn basic_test_cherry_query() {
        let team_a = people_batch(
            ["name", "age", "height"],
            vec!["kamil", "mahmut", "qwe", "kazim"],
            vec![11, 12, 13, 31],
            vec![50, 60, 70, 60],
        );
        let team_b = people_batch(
            ["name2", "age2", "height2"],
            vec!["yusuf", "abuzer", "asd"],
            vec![11, 12, 13],
            vec![50, 61, 70],
        );

        // Select "kamil" and "mahmut" from team_a, then pull in team_b rows
        // with the same (age, height) pair and team_a rows with the same height.
        let name_filter = Filter::Contains(
            Contains::new(Arc::new(StringArray::from_iter_values(["kamil", "mahmut"]))).unwrap(),
        );

        let selection = TableSelection {
            filters: BTreeMap::from([("name".to_owned(), name_filter)]),
            include: vec![
                Include {
                    other_table_name: "team_b".to_owned(),
                    field_names: owned(&["age", "height"]),
                    other_table_field_names: owned(&["age2", "height2"]),
                },
                Include {
                    other_table_name: "team_a".to_owned(),
                    field_names: owned(&["height"]),
                    other_table_field_names: owned(&["height"]),
                },
            ],
        };

        let query = Query {
            selection: BTreeMap::from([("team_a".to_owned(), vec![selection])]),
            fields: BTreeMap::from([
                ("team_a".to_owned(), owned(&["name"])),
                ("team_b".to_owned(), owned(&["name2"])),
            ]),
        };

        let data = BTreeMap::from([("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]);

        let res = run_query(&data, &query).unwrap();

        let team_a = res.get("team_a").unwrap();
        let team_b = res.get("team_b").unwrap();

        assert_eq!(res.len(), 2);

        let name = team_a.column_by_name("name").unwrap();
        let name2 = team_b.column_by_name("name2").unwrap();

        assert_eq!(team_a.num_columns(), 1);
        assert_eq!(team_b.num_columns(), 1);

        // "kazim" shares height 60 with "mahmut"; "yusuf" shares (11, 50) with "kamil".
        assert_eq!(
            name.as_string(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(name2.as_string(), &StringArray::from_iter_values(["yusuf"]));
    }
}