1use anyhow::{anyhow, Context, Result};
2use arrow::array::{
3 Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
4 Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
5 UInt32Array, UInt64Array, UInt8Array,
6};
7use arrow::buffer::BooleanBuffer;
8use arrow::compute;
9use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
10use arrow::record_batch::RecordBatch;
11use arrow::row::{RowConverter, SortField};
12use hashbrown::HashTable;
13use rayon::prelude::*;
14use std::collections::btree_map::Entry;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use xxhash_rust::xxh3::xxh3_64;
18
19type TableName = String;
20type FieldName = String;
21
/// A query over a set of named tables.
#[derive(Clone)]
pub struct Query {
    /// Row selections, keyed by the table they start from. `run_query` evaluates every
    /// selection and ORs together all filters that target the same table.
    pub selection: Arc<BTreeMap<TableName, Vec<TableSelection>>>,
    /// Column names to project for each table (see `select_fields`).
    pub fields: BTreeMap<TableName, Vec<FieldName>>,
}
27
28impl Query {
29 pub fn add_request_and_include_fields(&mut self) -> Result<()> {
30 for (table_name, selections) in self.selection.iter() {
31 for selection in selections.iter() {
32 for col_name in selection.filters.keys() {
33 let table_fields = self
34 .fields
35 .get_mut(table_name)
36 .with_context(|| format!("get fields for table {}", table_name))?;
37 table_fields.push(col_name.to_owned());
38 }
39
40 for include in selection.include.iter() {
41 let other_table_fields = self
42 .fields
43 .get_mut(&include.other_table_name)
44 .with_context(|| {
45 format!("get fields for other table {}", include.other_table_name)
46 })?;
47 other_table_fields.extend_from_slice(&include.other_table_field_names);
48 let table_fields = self
49 .fields
50 .get_mut(table_name)
51 .with_context(|| format!("get fields for table {}", table_name))?;
52 table_fields.extend_from_slice(&include.field_names);
53 }
54 }
55 }
56
57 Ok(())
58 }
59}
60
/// A single selection on one table: column filters plus optional includes of related rows.
pub struct TableSelection {
    /// Filters keyed by column name. `run_table_selection` ANDs the results of all of them;
    /// with no filters every row of the table is selected.
    pub filters: BTreeMap<FieldName, Filter>,
    /// Rows of other tables to pull in based on the rows selected by `filters`.
    pub include: Vec<Include>,
}
65
/// Describes rows of another table to include, matched against the selected rows of the
/// selecting table.
pub struct Include {
    /// Table whose rows should be included.
    pub other_table_name: TableName,
    /// Columns of the selecting table whose (filtered) values form the set to match against.
    pub field_names: Vec<FieldName>,
    /// Columns of `other_table_name` compared against `field_names`.
    /// Must have the same length as `field_names` (checked in `run_table_selection`).
    pub other_table_field_names: Vec<FieldName>,
}
71
/// A filter applied to a single column.
pub enum Filter {
    /// Keeps rows whose value is present in the array held by the `Contains`.
    Contains(Contains),
    /// Keeps rows of a boolean column whose value equals the given flag.
    Bool(bool),
}
76
77impl Filter {
78 pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
79 Ok(Self::Contains(Contains::new(arr)?))
80 }
81
82 pub fn bool(b: bool) -> Self {
83 Self::Bool(b)
84 }
85
86 fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
87 match self {
88 Self::Contains(ct) => ct.contains(arr),
89 Self::Bool(b) => {
90 let arr = arr
91 .as_any()
92 .downcast_ref::<BooleanArray>()
93 .context("cast array to boolean array")?;
94
95 let mut filter = if *b {
96 arr.clone()
97 } else {
98 compute::not(arr).context("negate array")?
99 };
100
101 if let Some(nulls) = filter.nulls() {
102 if nulls.null_count() > 0 {
103 let nulls = BooleanArray::from(nulls.inner().clone());
104 filter = compute::and(&filter, &nulls).unwrap();
105 }
106 }
107
108 Ok(filter)
109 }
110 }
111 }
112}
113
/// Membership filter: tests whether values of another array occur in `array`.
pub struct Contains {
    /// The set of values to test against. `Contains::new` rejects arrays that report nulls.
    array: Arc<dyn Array>,
    /// Row indices of `array` hashed by value; only built by `Contains::new` when `array`
    /// has at least 128 elements, otherwise lookups scan `array` linearly.
    hash_table: Option<HashTable<usize>>,
}
118
119impl Contains {
120 fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
121 assert!(!arr.is_nullable());
122
123 let mut ht = HashTable::with_capacity(arr.len());
124
125 for (i, v) in arr.values().iter().enumerate() {
126 ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
127 xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
128 });
129 }
130
131 ht
132 }
133
134 fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
135 arr: &GenericByteArray<T>,
136 ) -> HashTable<usize> {
137 assert!(!arr.is_nullable());
138
139 let mut ht = HashTable::with_capacity(arr.len());
140
141 for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
142 ht.insert_unique(xxh3_64(v), i, |i| {
143 xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
144 });
145 }
146
147 ht
148 }
149
150 fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
151 let ht = match *array.data_type() {
152 DataType::UInt8 => {
153 let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
154 Self::ht_from_primitive(array)
155 }
156 DataType::UInt16 => {
157 let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
158 Self::ht_from_primitive(array)
159 }
160 DataType::UInt32 => {
161 let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
162 Self::ht_from_primitive(array)
163 }
164 DataType::UInt64 => {
165 let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
166 Self::ht_from_primitive(array)
167 }
168 DataType::Int8 => {
169 let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
170 Self::ht_from_primitive(array)
171 }
172 DataType::Int16 => {
173 let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
174 Self::ht_from_primitive(array)
175 }
176 DataType::Int32 => {
177 let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
178 Self::ht_from_primitive(array)
179 }
180 DataType::Int64 => {
181 let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
182 Self::ht_from_primitive(array)
183 }
184 DataType::Binary => {
185 let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
186 Self::ht_from_bytes(array)
187 }
188 DataType::Utf8 => {
189 let array = array.as_any().downcast_ref::<StringArray>().unwrap();
190 Self::ht_from_bytes(array)
191 }
192 _ => {
193 return Err(anyhow!("unsupported data type: {}", array.data_type()));
194 }
195 };
196
197 Ok(ht)
198 }
199
200 pub fn new(array: Arc<dyn Array>) -> Result<Self> {
201 if array.is_nullable() {
202 return Err(anyhow!(
203 "cannot construct contains filter with a nullable array"
204 ));
205 }
206
207 let hash_table = if array.len() >= 128 {
209 Some(Self::ht_from_array(&array).context("construct hash table")?)
210 } else {
211 None
212 };
213
214 Ok(Self { hash_table, array })
215 }
216
217 fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
218 if arr.data_type() != self.array.data_type() {
219 return Err(anyhow!(
220 "filter array is of type {} but array to be filtered is of type {}",
221 self.array.data_type(),
222 arr.data_type(),
223 ));
224 }
225 assert!(!self.array.is_nullable());
226
227 let filter = match *arr.data_type() {
228 DataType::UInt8 => {
229 let self_arr = self.array.as_any().downcast_ref::<UInt8Array>().unwrap();
230 let other_arr = arr.as_any().downcast_ref().unwrap();
231 self.contains_primitive(self_arr, other_arr)
232 }
233 DataType::UInt16 => {
234 let self_arr = self.array.as_any().downcast_ref::<UInt16Array>().unwrap();
235 let other_arr = arr.as_any().downcast_ref().unwrap();
236 self.contains_primitive(self_arr, other_arr)
237 }
238 DataType::UInt32 => {
239 let self_arr = self.array.as_any().downcast_ref::<UInt32Array>().unwrap();
240 let other_arr = arr.as_any().downcast_ref().unwrap();
241 self.contains_primitive(self_arr, other_arr)
242 }
243 DataType::UInt64 => {
244 let self_arr = self.array.as_any().downcast_ref::<UInt64Array>().unwrap();
245 let other_arr = arr.as_any().downcast_ref().unwrap();
246 self.contains_primitive(self_arr, other_arr)
247 }
248 DataType::Int8 => {
249 let self_arr = self.array.as_any().downcast_ref::<Int8Array>().unwrap();
250 let other_arr = arr.as_any().downcast_ref().unwrap();
251 self.contains_primitive(self_arr, other_arr)
252 }
253 DataType::Int16 => {
254 let self_arr = self.array.as_any().downcast_ref::<Int16Array>().unwrap();
255 let other_arr = arr.as_any().downcast_ref().unwrap();
256 self.contains_primitive(self_arr, other_arr)
257 }
258 DataType::Int32 => {
259 let self_arr = self.array.as_any().downcast_ref::<Int32Array>().unwrap();
260 let other_arr = arr.as_any().downcast_ref().unwrap();
261 self.contains_primitive(self_arr, other_arr)
262 }
263 DataType::Int64 => {
264 let self_arr = self.array.as_any().downcast_ref::<Int64Array>().unwrap();
265 let other_arr = arr.as_any().downcast_ref().unwrap();
266 self.contains_primitive(self_arr, other_arr)
267 }
268 DataType::Binary => {
269 let self_arr = self.array.as_any().downcast_ref::<BinaryArray>().unwrap();
270 let other_arr = arr.as_any().downcast_ref().unwrap();
271 self.contains_bytes(self_arr, other_arr)
272 }
273 DataType::Utf8 => {
274 let self_arr = self.array.as_any().downcast_ref::<StringArray>().unwrap();
275 let other_arr = arr.as_any().downcast_ref().unwrap();
276 self.contains_bytes(self_arr, other_arr)
277 }
278 _ => {
279 return Err(anyhow!("unsupported data type: {}", arr.data_type()));
280 }
281 };
282
283 let mut filter = filter;
284
285 if let Some(nulls) = arr.nulls() {
286 if nulls.null_count() > 0 {
287 let nulls = BooleanArray::from(nulls.inner().clone());
288 filter = compute::and(&filter, &nulls).unwrap();
289 }
290 }
291
292 Ok(filter)
293 }
294
295 fn contains_primitive<T: ArrowPrimitiveType>(
296 &self,
297 self_arr: &PrimitiveArray<T>,
298 other_arr: &PrimitiveArray<T>,
299 ) -> BooleanArray {
300 let mut filter = BooleanBuilder::with_capacity(other_arr.len());
301
302 if let Some(ht) = self.hash_table.as_ref() {
303 let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
304
305 for v in other_arr.values().iter() {
306 let c = ht
307 .find(hash_one(v), |idx| unsafe {
308 self_arr.values().get_unchecked(*idx) == v
309 })
310 .is_some();
311 filter.append_value(c);
312 }
313 } else {
314 for v in other_arr.values().iter() {
315 filter.append_value(self_arr.values().iter().any(|x| x == v));
316 }
317 }
318
319 filter.finish()
320 }
321
322 fn contains_bytes<T: ByteArrayType<Offset = i32>>(
323 &self,
324 self_arr: &GenericByteArray<T>,
325 other_arr: &GenericByteArray<T>,
326 ) -> BooleanArray {
327 let mut filter = BooleanBuilder::with_capacity(other_arr.len());
328
329 if let Some(ht) = self.hash_table.as_ref() {
330 for v in iter_byte_array_without_validity(other_arr) {
331 let c = ht
332 .find(xxh3_64(v), |idx| unsafe {
333 byte_array_get_unchecked(self_arr, *idx) == v
334 })
335 .is_some();
336 filter.append_value(c);
337 }
338 } else {
339 for v in iter_byte_array_without_validity(other_arr) {
340 filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
341 }
342 }
343
344 filter.finish()
345 }
346}
347
348unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
351 arr: &GenericByteArray<T>,
352 i: usize,
353) -> &[u8] {
354 let end = *arr.value_offsets().get_unchecked(i + 1);
355 let start = *arr.value_offsets().get_unchecked(i);
356
357 std::slice::from_raw_parts(
358 arr.value_data()
359 .as_ptr()
360 .offset(isize::try_from(start).unwrap()),
361 usize::try_from(end - start).unwrap(),
362 )
363}
364
365fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
366 arr: &GenericByteArray<T>,
367) -> impl Iterator<Item = &[u8]> {
368 (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
369}
370
371pub fn run_query(
372 data: &BTreeMap<TableName, RecordBatch>,
373 query: &Query,
374) -> Result<BTreeMap<TableName, RecordBatch>> {
375 let filters = query
376 .selection
377 .par_iter()
378 .map(|(table_name, selections)| {
379 selections
380 .par_iter()
381 .enumerate()
382 .map(|(i, selection)| {
383 run_table_selection(data, table_name, selection).with_context(|| {
384 format!("run table selection no:{} for table {}", i, table_name)
385 })
386 })
387 .collect::<Result<Vec<_>>>()
388 })
389 .collect::<Result<Vec<_>>>()?;
390
391 let data = select_fields(data, &query.fields).context("select fields")?;
392
393 data.par_iter()
394 .filter_map(|(table_name, table_data)| {
395 let mut combined_filter: Option<BooleanArray> = None;
396
397 for f in filters.iter() {
398 for f in f.iter() {
399 let filter = match f.get(table_name) {
400 Some(f) => f,
401 None => continue,
402 };
403
404 match combined_filter.as_ref() {
405 Some(e) => {
406 let f = compute::or(e, filter)
407 .with_context(|| format!("combine filters for {}", table_name));
408 let f = match f {
409 Ok(v) => v,
410 Err(err) => return Some(Err(err)),
411 };
412 combined_filter = Some(f);
413 }
414 None => {
415 combined_filter = Some(filter.clone());
416 }
417 }
418 }
419 }
420
421 let combined_filter = match combined_filter {
422 Some(f) => f,
423 None => return None,
424 };
425
426 let table_data = compute::filter_record_batch(table_data, &combined_filter)
427 .context("filter record batch");
428 let table_data = match table_data {
429 Ok(v) => v,
430 Err(err) => return Some(Err(err)),
431 };
432
433 Some(Ok((table_name.to_owned(), table_data)))
434 })
435 .collect()
436}
437
438pub fn select_fields(
439 data: &BTreeMap<TableName, RecordBatch>,
440 fields: &BTreeMap<TableName, Vec<FieldName>>,
441) -> Result<BTreeMap<TableName, RecordBatch>> {
442 let mut out = BTreeMap::new();
443
444 for (table_name, field_names) in fields.iter() {
445 let table_data = data
446 .get(table_name)
447 .with_context(|| format!("get data for table {}", table_name))?;
448
449 let indices = field_names
450 .iter()
451 .map(|n| {
452 table_data
453 .schema_ref()
454 .index_of(n)
455 .with_context(|| format!("find index of field {} in table {}", n, table_name))
456 })
457 .collect::<Result<Vec<usize>>>()?;
458
459 let table_data = table_data
460 .project(&indices)
461 .with_context(|| format!("project table {}", table_name))?;
462 out.insert(table_name.to_owned(), table_data);
463 }
464
465 Ok(out)
466}
467
468fn run_table_selection(
469 data: &BTreeMap<TableName, RecordBatch>,
470 table_name: &str,
471 selection: &TableSelection,
472) -> Result<BTreeMap<TableName, BooleanArray>> {
473 let mut out = BTreeMap::new();
474
475 let table_data = data.get(table_name).context("get table data")?;
476 let mut combined_filter = None;
477 for (field_name, filter) in selection.filters.iter() {
478 let col = table_data
479 .column_by_name(field_name)
480 .with_context(|| format!("get field {}", field_name))?;
481
482 let f = filter
483 .check(&col)
484 .with_context(|| format!("check filter for column {}", field_name))?;
485
486 match combined_filter {
487 Some(cf) => {
488 combined_filter = Some(
489 compute::and(&cf, &f)
490 .with_context(|| format!("combine filter for column {}", field_name))?,
491 );
492 }
493 None => {
494 combined_filter = Some(f);
495 }
496 }
497 }
498
499 let combined_filter = match combined_filter {
500 Some(cf) => cf,
501 None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
502 };
503
504 out.insert(table_name.to_owned(), combined_filter.clone());
505
506 let mut filtered_cache = BTreeMap::new();
507
508 for (i, inc) in selection.include.iter().enumerate() {
509 if inc.other_table_field_names.len() != inc.field_names.len() {
510 return Err(anyhow!(
511 "field names are different for self table and other table while processing include no: {}. {} {}",
512 i,
513 inc.field_names.len(),
514 inc.other_table_field_names.len(),
515 ));
516 }
517
518 let other_table_data = data.get(&inc.other_table_name).with_context(|| {
519 format!(
520 "get data for table {} as other table data",
521 inc.other_table_name
522 )
523 })?;
524
525 let self_arr = columns_to_binary_array(table_data, &inc.field_names)
526 .context("get row format binary arr for self")?;
527
528 let contains = match filtered_cache.entry(inc.field_names.clone()) {
529 Entry::Vacant(entry) => {
530 let self_arr = compute::filter(&self_arr, &combined_filter)
531 .context("apply combined filter to self arr")?;
532 let contains =
533 Contains::new(Arc::new(self_arr)).context("create contains filter")?;
534 let contains = Arc::new(contains);
535 entry.insert(Arc::clone(&contains));
536 contains
537 }
538 Entry::Occupied(entry) => Arc::clone(entry.get()),
539 };
540
541 let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
542 .with_context(|| {
543 format!(
544 "get row format binary arr for other table {}",
545 inc.other_table_name
546 )
547 })?;
548
549 let f = contains
550 .contains(&other_arr)
551 .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
552
553 match out.entry(inc.other_table_name.clone()) {
554 Entry::Vacant(entry) => {
555 entry.insert(f);
556 }
557 Entry::Occupied(mut entry) => {
558 let new = compute::or(entry.get(), &f).with_context(|| {
559 format!("or include filters for table {}", inc.other_table_name)
560 })?;
561 entry.insert(new);
562 }
563 }
564 }
565
566 Ok(out)
567}
568
569fn columns_to_binary_array(
570 table_data: &RecordBatch,
571 column_names: &[String],
572) -> Result<BinaryArray> {
573 let fields = column_names
574 .iter()
575 .map(|field_name| {
576 let f = table_data
577 .schema_ref()
578 .field_with_name(field_name)
579 .with_context(|| format!("get field {} from schema", field_name))?;
580 Ok(SortField::new(f.data_type().clone()))
581 })
582 .collect::<Result<Vec<_>>>()?;
583 let conv = RowConverter::new(fields).context("create row converter")?;
584
585 let columns = column_names
586 .iter()
587 .map(|field_name| {
588 let c = table_data
589 .column_by_name(field_name)
590 .with_context(|| format!("get data for column {}", field_name))?;
591 let c = Arc::clone(c);
592 Ok(c)
593 })
594 .collect::<Result<Vec<_>>>()?;
595
596 let rows = conv
597 .convert_columns(&columns)
598 .context("convert columns to row format")?;
599 let out = rows
600 .try_into_binary()
601 .context("convert row format to binary array")?;
602
603 Ok(out)
604}
605
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// End-to-end check of `run_query`: one `Contains` filter on `team_a.name` plus two
    /// includes (one into `team_b`, one back into `team_a`).
    #[test]
    fn basic_test_cherry_query() {
        // team_a rows: (kamil, 11, 50), (mahmut, 12, 60), (qwe, 13, 70), (kazim, 31, 60)
        let team_a = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Arc::new(Field::new("name", DataType::Utf8, true)),
                Arc::new(Field::new("age", DataType::UInt64, true)),
                Arc::new(Field::new("height", DataType::UInt64, true)),
            ])),
            vec![
                Arc::new(StringArray::from_iter_values(
                    vec!["kamil", "mahmut", "qwe", "kazim"].into_iter(),
                )),
                Arc::new(UInt64Array::from_iter(vec![11, 12, 13, 31].into_iter())),
                Arc::new(UInt64Array::from_iter(vec![50, 60, 70, 60].into_iter())),
            ],
        )
        .unwrap();
        // team_b rows: (yusuf, 11, 50), (abuzer, 12, 61), (asd, 13, 70)
        let team_b = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Arc::new(Field::new("name2", DataType::Utf8, true)),
                Arc::new(Field::new("age2", DataType::UInt64, true)),
                Arc::new(Field::new("height2", DataType::UInt64, true)),
            ])),
            vec![
                Arc::new(StringArray::from_iter_values(vec![
                    "yusuf", "abuzer", "asd",
                ])),
                Arc::new(UInt64Array::from_iter(vec![11, 12, 13].into_iter())),
                Arc::new(UInt64Array::from_iter(vec![50, 61, 70].into_iter())),
            ],
        )
        .unwrap();

        let query = Query {
            // Only the name columns are projected into the output.
            fields: [
                ("team_a".to_owned(), vec!["name".to_owned()]),
                ("team_b".to_owned(), vec!["name2".to_owned()]),
            ]
            .into_iter()
            .collect(),
            selection: Arc::new(
                [(
                    "team_a".to_owned(),
                    vec![TableSelection {
                        // Selects the team_a rows named kamil or mahmut.
                        filters: [(
                            "name".to_owned(),
                            Filter::Contains(
                                Contains::new(Arc::new(StringArray::from_iter_values(
                                    vec!["kamil", "mahmut"].into_iter(),
                                )))
                                .unwrap(),
                            ),
                        )]
                        .into_iter()
                        .collect(),
                        include: vec![
                            // team_b rows whose (age2, height2) equals the (age, height)
                            // of a selected team_a row: only yusuf (11, 50) matches.
                            Include {
                                field_names: vec!["age".to_owned(), "height".to_owned()],
                                other_table_field_names: vec![
                                    "age2".to_owned(),
                                    "height2".to_owned(),
                                ],
                                other_table_name: "team_b".to_owned(),
                            },
                            // team_a rows sharing a height with a selected row (50 or 60):
                            // this additionally pulls in kazim (60).
                            Include {
                                field_names: vec!["height".to_owned()],
                                other_table_field_names: vec!["height".to_owned()],
                                other_table_name: "team_a".to_owned(),
                            },
                        ],
                    }],
                )]
                .into_iter()
                .collect(),
            ),
        };

        let data = [("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]
            .into_iter()
            .collect::<BTreeMap<_, _>>();

        let res = run_query(&data, &query).unwrap();

        let team_a = res.get("team_a").unwrap();
        let team_b = res.get("team_b").unwrap();

        assert_eq!(res.len(), 2);

        let name = team_a.column_by_name("name").unwrap();
        let name2 = team_b.column_by_name("name2").unwrap();

        // The projection keeps only the requested column of each table.
        assert_eq!(team_a.num_columns(), 1);
        assert_eq!(team_b.num_columns(), 1);

        assert_eq!(
            name.as_string(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(name2.as_string(), &StringArray::from_iter_values(["yusuf"]));
    }
}