1use anyhow::{anyhow, Context, Result};
2use arrow::array::{
3 Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
4 Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
5 UInt32Array, UInt64Array, UInt8Array,
6};
7use arrow::buffer::BooleanBuffer;
8use arrow::compute;
9use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
10use arrow::record_batch::RecordBatch;
11use arrow::row::{RowConverter, SortField};
12use hashbrown::HashTable;
13use rayon::prelude::*;
14use std::collections::btree_map::Entry;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use xxhash_rust::xxh3::xxh3_64;
18
/// Key of a table in the `BTreeMap<TableName, RecordBatch>` maps used throughout this module.
type TableName = String;
/// Name of a column in a table's schema.
type FieldName = String;
21
/// A multi-table query: which rows to select and which columns to return.
#[derive(Clone)]
pub struct Query {
    /// Row selections, keyed by the table each one starts from. `run_query` evaluates every
    /// selection and ORs together all masks produced for the same table; tables that receive
    /// no mask are left out of the result.
    pub selection: Arc<BTreeMap<TableName, Vec<TableSelection>>>,
    /// Columns to keep for each table. `run_query` projects tables to these columns (via
    /// `select_fields`) before filtering them; tables not listed here are not returned.
    pub fields: BTreeMap<TableName, Vec<FieldName>>,
}
27
28impl Query {
29 pub fn add_request_and_include_fields(&mut self) -> Result<()> {
30 for (table_name, selections) in self.selection.iter() {
31 for selection in selections.iter() {
32 for col_name in selection.filters.keys() {
33 let table_fields = self
34 .fields
35 .get_mut(table_name)
36 .with_context(|| format!("get fields for table {}", table_name))?;
37 table_fields.push(col_name.to_owned());
38 }
39
40 for include in selection.include.iter() {
41 let other_table_fields = self
42 .fields
43 .get_mut(&include.other_table_name)
44 .with_context(|| {
45 format!("get fields for other table {}", include.other_table_name)
46 })?;
47 other_table_fields.extend_from_slice(&include.other_table_field_names);
48 let table_fields = self
49 .fields
50 .get_mut(table_name)
51 .with_context(|| format!("get fields for table {}", table_name))?;
52 table_fields.extend_from_slice(&include.field_names);
53 }
54 }
55 }
56
57 Ok(())
58 }
59}
60
/// One selection on a table: column filters plus includes into other tables.
pub struct TableSelection {
    /// Filters keyed by column name. A row is selected when all of them match (their masks
    /// are ANDed); with no filters every row is selected.
    pub filters: BTreeMap<FieldName, Filter>,
    /// Includes evaluated against the rows selected by `filters`.
    pub include: Vec<Include>,
}
65
/// Pulls rows of another table into the result by matching key columns against the
/// selected rows of the including table.
///
/// A row of `other_table_name` is included when its `other_table_field_names` values equal
/// the `field_names` values of at least one selected row. Columns are paired by position and
/// compared through their arrow row-format encoding.
pub struct Include {
    /// Table whose rows are pulled in (may be the including table itself).
    pub other_table_name: TableName,
    /// Key columns of the including table. Must have the same length as
    /// `other_table_field_names`.
    pub field_names: Vec<FieldName>,
    /// Key columns of `other_table_name`, paired positionally with `field_names`.
    pub other_table_field_names: Vec<FieldName>,
}
71
/// A predicate evaluated against one column of a table.
pub enum Filter {
    /// Matches rows whose value occurs in the wrapped value set.
    Contains(Contains),
    /// Matches rows of a `Boolean` column whose value equals the wrapped `bool`.
    Bool(bool),
}
76
77impl Filter {
78 pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
79 Ok(Self::Contains(Contains::new(arr)?))
80 }
81
82 pub fn bool(b: bool) -> Self {
83 Self::Bool(b)
84 }
85
86 fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
87 match self {
88 Self::Contains(ct) => ct.contains(arr),
89 Self::Bool(b) => {
90 let arr = arr
91 .as_any()
92 .downcast_ref::<BooleanArray>()
93 .context("cast array to boolean array")?;
94
95 let mut filter = if *b {
96 arr.clone()
97 } else {
98 compute::not(arr).context("negate array")?
99 };
100
101 if let Some(nulls) = filter.nulls() {
102 if nulls.null_count() > 0 {
103 let nulls = BooleanArray::from(nulls.inner().clone());
104 filter = compute::and(&filter, &nulls).unwrap();
105 }
106 }
107
108 Ok(filter)
109 }
110 }
111 }
112}
113
/// Set-membership filter: matches values that occur in a fixed, null-free array.
pub struct Contains {
    /// The values to match against; `Contains::new` rejects arrays that contain nulls.
    array: Arc<dyn Array>,
    /// Indices into `array` keyed by the xxh3 hash of each value. `Contains::new` builds it
    /// only for arrays of at least 128 values; otherwise it is `None` and lookups scan
    /// `array` linearly.
    hash_table: Option<HashTable<usize>>,
}
118
119impl Contains {
120 fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
121 assert!(!arr.is_nullable());
122
123 let mut ht = HashTable::with_capacity(arr.len());
124
125 for (i, v) in arr.values().iter().enumerate() {
126 ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
127 xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
128 });
129 }
130
131 ht
132 }
133
134 fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
135 arr: &GenericByteArray<T>,
136 ) -> HashTable<usize> {
137 assert!(!arr.is_nullable());
138
139 let mut ht = HashTable::with_capacity(arr.len());
140
141 for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
142 ht.insert_unique(xxh3_64(v), i, |i| {
143 xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
144 });
145 }
146
147 ht
148 }
149
150 fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
151 let ht = match *array.data_type() {
152 DataType::UInt8 => {
153 let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
154 Self::ht_from_primitive(array)
155 }
156 DataType::UInt16 => {
157 let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
158 Self::ht_from_primitive(array)
159 }
160 DataType::UInt32 => {
161 let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
162 Self::ht_from_primitive(array)
163 }
164 DataType::UInt64 => {
165 let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
166 Self::ht_from_primitive(array)
167 }
168 DataType::Int8 => {
169 let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
170 Self::ht_from_primitive(array)
171 }
172 DataType::Int16 => {
173 let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
174 Self::ht_from_primitive(array)
175 }
176 DataType::Int32 => {
177 let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
178 Self::ht_from_primitive(array)
179 }
180 DataType::Int64 => {
181 let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
182 Self::ht_from_primitive(array)
183 }
184 DataType::Binary => {
185 let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
186 Self::ht_from_bytes(array)
187 }
188 DataType::Utf8 => {
189 let array = array.as_any().downcast_ref::<StringArray>().unwrap();
190 Self::ht_from_bytes(array)
191 }
192 _ => {
193 return Err(anyhow!("unsupported data type: {}", array.data_type()));
194 }
195 };
196
197 Ok(ht)
198 }
199
200 pub fn new(array: Arc<dyn Array>) -> Result<Self> {
201 if array.is_nullable() {
202 return Err(anyhow!(
203 "cannot construct contains filter with a nullable array"
204 ));
205 }
206
207 let hash_table = if array.len() >= 128 {
209 Some(Self::ht_from_array(&array).context("construct hash table")?)
210 } else {
211 None
212 };
213
214 Ok(Self { hash_table, array })
215 }
216
217 fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
218 if arr.data_type() != self.array.data_type() {
219 return Err(anyhow!(
220 "filter array is of type {} but array to be filtered is of type {}",
221 self.array.data_type(),
222 arr.data_type(),
223 ));
224 }
225 assert!(!self.array.is_nullable());
226
227 let filter = match *arr.data_type() {
228 DataType::UInt8 => {
229 let self_arr = self.array.as_any().downcast_ref::<UInt8Array>().unwrap();
230 let other_arr = arr.as_any().downcast_ref().unwrap();
231 self.contains_primitive(self_arr, other_arr)
232 }
233 DataType::UInt16 => {
234 let self_arr = self.array.as_any().downcast_ref::<UInt16Array>().unwrap();
235 let other_arr = arr.as_any().downcast_ref().unwrap();
236 self.contains_primitive(self_arr, other_arr)
237 }
238 DataType::UInt32 => {
239 let self_arr = self.array.as_any().downcast_ref::<UInt32Array>().unwrap();
240 let other_arr = arr.as_any().downcast_ref().unwrap();
241 self.contains_primitive(self_arr, other_arr)
242 }
243 DataType::UInt64 => {
244 let self_arr = self.array.as_any().downcast_ref::<UInt64Array>().unwrap();
245 let other_arr = arr.as_any().downcast_ref().unwrap();
246 self.contains_primitive(self_arr, other_arr)
247 }
248 DataType::Int8 => {
249 let self_arr = self.array.as_any().downcast_ref::<Int8Array>().unwrap();
250 let other_arr = arr.as_any().downcast_ref().unwrap();
251 self.contains_primitive(self_arr, other_arr)
252 }
253 DataType::Int16 => {
254 let self_arr = self.array.as_any().downcast_ref::<Int16Array>().unwrap();
255 let other_arr = arr.as_any().downcast_ref().unwrap();
256 self.contains_primitive(self_arr, other_arr)
257 }
258 DataType::Int32 => {
259 let self_arr = self.array.as_any().downcast_ref::<Int32Array>().unwrap();
260 let other_arr = arr.as_any().downcast_ref().unwrap();
261 self.contains_primitive(self_arr, other_arr)
262 }
263 DataType::Int64 => {
264 let self_arr = self.array.as_any().downcast_ref::<Int64Array>().unwrap();
265 let other_arr = arr.as_any().downcast_ref().unwrap();
266 self.contains_primitive(self_arr, other_arr)
267 }
268 DataType::Binary => {
269 let self_arr = self.array.as_any().downcast_ref::<BinaryArray>().unwrap();
270 let other_arr = arr.as_any().downcast_ref().unwrap();
271 self.contains_bytes(self_arr, other_arr)
272 }
273 DataType::Utf8 => {
274 let self_arr = self.array.as_any().downcast_ref::<StringArray>().unwrap();
275 let other_arr = arr.as_any().downcast_ref().unwrap();
276 self.contains_bytes(self_arr, other_arr)
277 }
278 _ => {
279 return Err(anyhow!("unsupported data type: {}", arr.data_type()));
280 }
281 };
282
283 let mut filter = filter;
284
285 if let Some(nulls) = arr.nulls() {
286 if nulls.null_count() > 0 {
287 let nulls = BooleanArray::from(nulls.inner().clone());
288 filter = compute::and(&filter, &nulls).unwrap();
289 }
290 }
291
292 Ok(filter)
293 }
294
295 fn contains_primitive<T: ArrowPrimitiveType>(
296 &self,
297 self_arr: &PrimitiveArray<T>,
298 other_arr: &PrimitiveArray<T>,
299 ) -> BooleanArray {
300 let mut filter = BooleanBuilder::with_capacity(other_arr.len());
301
302 if let Some(ht) = self.hash_table.as_ref() {
303 let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
304
305 for v in other_arr.values().iter() {
306 let c = ht
307 .find(hash_one(v), |idx| unsafe {
308 self_arr.values().get_unchecked(*idx) == v
309 })
310 .is_some();
311 filter.append_value(c);
312 }
313 } else {
314 for v in other_arr.values().iter() {
315 filter.append_value(self_arr.values().iter().any(|x| x == v));
316 }
317 }
318
319 filter.finish()
320 }
321
322 fn contains_bytes<T: ByteArrayType<Offset = i32>>(
323 &self,
324 self_arr: &GenericByteArray<T>,
325 other_arr: &GenericByteArray<T>,
326 ) -> BooleanArray {
327 let mut filter = BooleanBuilder::with_capacity(other_arr.len());
328
329 if let Some(ht) = self.hash_table.as_ref() {
330 for v in iter_byte_array_without_validity(other_arr) {
331 let c = ht
332 .find(xxh3_64(v), |idx| unsafe {
333 byte_array_get_unchecked(self_arr, *idx) == v
334 })
335 .is_some();
336 filter.append_value(c);
337 }
338 } else {
339 for v in iter_byte_array_without_validity(other_arr) {
340 filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
341 }
342 }
343
344 filter.finish()
345 }
346}
347
348unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
351 arr: &GenericByteArray<T>,
352 i: usize,
353) -> &[u8] {
354 let end = *arr.value_offsets().get_unchecked(i + 1);
355 let start = *arr.value_offsets().get_unchecked(i);
356
357 std::slice::from_raw_parts(
358 arr.value_data()
359 .as_ptr()
360 .offset(isize::try_from(start).unwrap()),
361 usize::try_from(end - start).unwrap(),
362 )
363}
364
365fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
366 arr: &GenericByteArray<T>,
367) -> impl Iterator<Item = &[u8]> {
368 (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
369}
370
371pub fn run_query(
372 data: &BTreeMap<TableName, RecordBatch>,
373 query: &Query,
374) -> Result<BTreeMap<TableName, RecordBatch>> {
375 let filters = query
376 .selection
377 .par_iter()
378 .map(|(table_name, selections)| {
379 selections
380 .par_iter()
381 .enumerate()
382 .map(|(i, selection)| {
383 run_table_selection(data, table_name, selection).with_context(|| {
384 format!("run table selection no:{} for table {}", i, table_name)
385 })
386 })
387 .collect::<Result<Vec<_>>>()
388 })
389 .collect::<Result<Vec<_>>>()?;
390
391 let data = select_fields(data, &query.fields).context("select fields")?;
392
393 data.par_iter()
394 .filter_map(|(table_name, table_data)| {
395 let mut combined_filter: Option<BooleanArray> = None;
396
397 for f in filters.iter() {
398 for f in f.iter() {
399 let filter = match f.get(table_name) {
400 Some(f) => f,
401 None => continue,
402 };
403
404 match combined_filter.as_ref() {
405 Some(e) => {
406 let f = compute::or(e, filter)
407 .with_context(|| format!("combine filters for {}", table_name));
408 let f = match f {
409 Ok(v) => v,
410 Err(err) => return Some(Err(err)),
411 };
412 combined_filter = Some(f);
413 }
414 None => {
415 combined_filter = Some(filter.clone());
416 }
417 }
418 }
419 }
420
421 let combined_filter = match combined_filter {
422 Some(f) => f,
423 None => return None,
424 };
425
426 let table_data = compute::filter_record_batch(table_data, &combined_filter)
427 .context("filter record batch");
428 let table_data = match table_data {
429 Ok(v) => v,
430 Err(err) => return Some(Err(err)),
431 };
432
433 Some(Ok((table_name.to_owned(), table_data)))
434 })
435 .collect()
436}
437
438pub fn select_fields(
439 data: &BTreeMap<TableName, RecordBatch>,
440 fields: &BTreeMap<TableName, Vec<FieldName>>,
441) -> Result<BTreeMap<TableName, RecordBatch>> {
442 let mut out = BTreeMap::new();
443
444 for (table_name, field_names) in fields.iter() {
445 let table_data = data
446 .get(table_name)
447 .with_context(|| format!("get data for table {}", table_name))?;
448
449 let indices = table_data
450 .schema_ref()
451 .fields()
452 .iter()
453 .enumerate()
454 .filter(|(_, field)| field_names.contains(field.name()))
455 .map(|(i, _)| i)
456 .collect::<Vec<usize>>();
457
458 let table_data = table_data
459 .project(&indices)
460 .with_context(|| format!("project table {}", table_name))?;
461 out.insert(table_name.to_owned(), table_data);
462 }
463
464 Ok(out)
465}
466
467fn run_table_selection(
468 data: &BTreeMap<TableName, RecordBatch>,
469 table_name: &str,
470 selection: &TableSelection,
471) -> Result<BTreeMap<TableName, BooleanArray>> {
472 let mut out = BTreeMap::new();
473
474 let table_data = data.get(table_name).context("get table data")?;
475 let mut combined_filter = None;
476 for (field_name, filter) in selection.filters.iter() {
477 let col = table_data
478 .column_by_name(field_name)
479 .with_context(|| format!("get field {}", field_name))?;
480
481 let f = filter
482 .check(&col)
483 .with_context(|| format!("check filter for column {}", field_name))?;
484
485 match combined_filter {
486 Some(cf) => {
487 combined_filter = Some(
488 compute::and(&cf, &f)
489 .with_context(|| format!("combine filter for column {}", field_name))?,
490 );
491 }
492 None => {
493 combined_filter = Some(f);
494 }
495 }
496 }
497
498 let combined_filter = match combined_filter {
499 Some(cf) => cf,
500 None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
501 };
502
503 out.insert(table_name.to_owned(), combined_filter.clone());
504
505 let mut filtered_cache = BTreeMap::new();
506
507 for (i, inc) in selection.include.iter().enumerate() {
508 if inc.other_table_field_names.len() != inc.field_names.len() {
509 return Err(anyhow!(
510 "field names are different for self table and other table while processing include no: {}. {} {}",
511 i,
512 inc.field_names.len(),
513 inc.other_table_field_names.len(),
514 ));
515 }
516
517 let other_table_data = data.get(&inc.other_table_name).with_context(|| {
518 format!(
519 "get data for table {} as other table data",
520 inc.other_table_name
521 )
522 })?;
523
524 let self_arr = columns_to_binary_array(table_data, &inc.field_names)
525 .context("get row format binary arr for self")?;
526
527 let contains = match filtered_cache.entry(inc.field_names.clone()) {
528 Entry::Vacant(entry) => {
529 let self_arr = compute::filter(&self_arr, &combined_filter)
530 .context("apply combined filter to self arr")?;
531 let contains =
532 Contains::new(Arc::new(self_arr)).context("create contains filter")?;
533 let contains = Arc::new(contains);
534 entry.insert(Arc::clone(&contains));
535 contains
536 }
537 Entry::Occupied(entry) => Arc::clone(entry.get()),
538 };
539
540 let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
541 .with_context(|| {
542 format!(
543 "get row format binary arr for other table {}",
544 inc.other_table_name
545 )
546 })?;
547
548 let f = contains
549 .contains(&other_arr)
550 .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
551
552 match out.entry(inc.other_table_name.clone()) {
553 Entry::Vacant(entry) => {
554 entry.insert(f);
555 }
556 Entry::Occupied(mut entry) => {
557 let new = compute::or(entry.get(), &f).with_context(|| {
558 format!("or include filters for table {}", inc.other_table_name)
559 })?;
560 entry.insert(new);
561 }
562 }
563 }
564
565 Ok(out)
566}
567
568fn columns_to_binary_array(
569 table_data: &RecordBatch,
570 column_names: &[String],
571) -> Result<BinaryArray> {
572 let fields = column_names
573 .iter()
574 .map(|field_name| {
575 let f = table_data
576 .schema_ref()
577 .field_with_name(field_name)
578 .with_context(|| format!("get field {} from schema", field_name))?;
579 Ok(SortField::new(f.data_type().clone()))
580 })
581 .collect::<Result<Vec<_>>>()?;
582 let conv = RowConverter::new(fields).context("create row converter")?;
583
584 let columns = column_names
585 .iter()
586 .map(|field_name| {
587 let c = table_data
588 .column_by_name(field_name)
589 .with_context(|| format!("get data for column {}", field_name))?;
590 let c = Arc::clone(c);
591 Ok(c)
592 })
593 .collect::<Result<Vec<_>>>()?;
594
595 let rows = conv
596 .convert_columns(&columns)
597 .context("convert columns to row format")?;
598 let out = rows
599 .try_into_binary()
600 .context("convert row format to binary array")?;
601
602 Ok(out)
603}
604
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// Builds a batch with one nullable `Utf8` column followed by two nullable `UInt64`
    /// columns, named by `columns` in that order.
    fn three_column_batch(
        columns: [&str; 3],
        labels: &[&str],
        first: &[u64],
        second: &[u64],
    ) -> RecordBatch {
        let [label_col, first_col, second_col] = columns;
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new(label_col, DataType::Utf8, true)),
            Arc::new(Field::new(first_col, DataType::UInt64, true)),
            Arc::new(Field::new(second_col, DataType::UInt64, true)),
        ]));

        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from_iter_values(labels.iter().copied())),
                Arc::new(UInt64Array::from_iter(first.iter().copied())),
                Arc::new(UInt64Array::from_iter(second.iter().copied())),
            ],
        )
        .unwrap()
    }

    #[test]
    fn basic_test_cherry_query() {
        let team_a = three_column_batch(
            ["name", "age", "height"],
            &["kamil", "mahmut", "qwe", "kazim"],
            &[11, 12, 13, 31],
            &[50, 60, 70, 60],
        );
        let team_b = three_column_batch(
            ["name2", "age2", "height2"],
            &["yusuf", "abuzer", "asd"],
            &[11, 12, 13],
            &[50, 61, 70],
        );
        let data: BTreeMap<_, _> = [("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]
            .into_iter()
            .collect();

        // Select the team_a rows named "kamil" or "mahmut".
        let wanted_names =
            Contains::new(Arc::new(StringArray::from_iter_values(["kamil", "mahmut"]))).unwrap();

        let selection = TableSelection {
            filters: [("name".to_owned(), Filter::Contains(wanted_names))]
                .into_iter()
                .collect(),
            include: vec![
                // team_b rows whose (age2, height2) equals (age, height) of a selected row:
                // kamil's (11, 50) matches yusuf.
                Include {
                    field_names: vec!["age".to_owned(), "height".to_owned()],
                    other_table_field_names: vec!["age2".to_owned(), "height2".to_owned()],
                    other_table_name: "team_b".to_owned(),
                },
                // team_a rows sharing a height with a selected row: kazim shares mahmut's 60.
                Include {
                    field_names: vec!["height".to_owned()],
                    other_table_field_names: vec!["height".to_owned()],
                    other_table_name: "team_a".to_owned(),
                },
            ],
        };

        let query = Query {
            fields: [
                ("team_a".to_owned(), vec!["name".to_owned()]),
                ("team_b".to_owned(), vec!["name2".to_owned()]),
            ]
            .into_iter()
            .collect(),
            selection: Arc::new([("team_a".to_owned(), vec![selection])].into_iter().collect()),
        };

        let res = run_query(&data, &query).unwrap();
        assert_eq!(res.len(), 2);

        let team_a = res.get("team_a").unwrap();
        let team_b = res.get("team_b").unwrap();
        assert_eq!(team_a.num_columns(), 1);
        assert_eq!(team_b.num_columns(), 1);

        assert_eq!(
            team_a.column_by_name("name").unwrap().as_string(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(
            team_b.column_by_name("name2").unwrap().as_string(),
            &StringArray::from_iter_values(["yusuf"])
        );
    }
}