1use anyhow::{anyhow, Context, Result};
2use arrow::array::{
3 Array, ArrowPrimitiveType, BinaryArray, BooleanArray, BooleanBuilder, GenericByteArray,
4 Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, StringArray, UInt16Array,
5 UInt32Array, UInt64Array, UInt8Array,
6};
7use arrow::buffer::BooleanBuffer;
8use arrow::compute;
9use arrow::datatypes::{ByteArrayType, DataType, ToByteSlice};
10use arrow::record_batch::RecordBatch;
11use arrow::row::{RowConverter, SortField};
12use hashbrown::HashTable;
13use rayon::prelude::*;
14use std::collections::btree_map::Entry;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use xxhash_rust::xxh3::xxh3_64;
18
/// Name of a table; used as the key of the input and output table maps.
type TableName = String;
/// Name of a column within a table.
type FieldName = String;
21
/// A query over a set of named tables (`RecordBatch`es).
///
/// Rows are chosen per table by `selection`; the columns returned for each
/// table are chosen by `fields`.
pub struct Query {
    /// Row selections, grouped by the table whose column filters they run on.
    /// The row masks produced by all selections are OR-ed together per table.
    pub selection: BTreeMap<TableName, Vec<TableSelection>>,
    /// Columns to return for each table. Only tables listed here can appear in
    /// the result, and only if some selection produced a row mask for them.
    pub fields: BTreeMap<TableName, Vec<FieldName>>,
}
26
/// One selection on a table: column filters plus follow-up includes.
pub struct TableSelection {
    /// Per-column filters; a row is selected only if it passes all of them.
    /// An empty map selects every row of the table.
    pub filters: BTreeMap<FieldName, Filter>,
    /// Rows of other tables to select, keyed off the rows chosen by `filters`.
    pub include: Vec<Include>,
}
31
/// Selects the rows of another table whose key columns match the key columns
/// of at least one selected row of the current table.
///
/// `field_names` and `other_table_field_names` are compared position by
/// position and must have the same length.
pub struct Include {
    /// Table whose rows this include selects (may be the current table).
    pub other_table_name: TableName,
    /// Key columns of the current table.
    pub field_names: Vec<FieldName>,
    /// Key columns of `other_table_name`, matched positionally against
    /// `field_names`.
    pub other_table_field_names: Vec<FieldName>,
}
37
/// A predicate evaluated against a single column.
pub enum Filter {
    /// Keeps rows whose value occurs in the wrapped set of values.
    Contains(Contains),
    /// Keeps rows of a boolean column whose value equals the flag.
    Bool(bool),
}
42
43impl Filter {
44 pub fn contains(arr: Arc<dyn Array>) -> Result<Self> {
45 Ok(Self::Contains(Contains::new(arr)?))
46 }
47
48 pub fn bool(b: bool) -> Self {
49 Self::Bool(b)
50 }
51
52 fn check(&self, arr: &dyn Array) -> Result<BooleanArray> {
53 match self {
54 Self::Contains(ct) => ct.contains(arr),
55 Self::Bool(b) => {
56 let arr = arr
57 .as_any()
58 .downcast_ref::<BooleanArray>()
59 .context("cast array to boolean array")?;
60
61 let mut filter = if *b {
62 arr.clone()
63 } else {
64 compute::not(arr).context("negate array")?
65 };
66
67 if let Some(nulls) = filter.nulls() {
68 if nulls.null_count() > 0 {
69 let nulls = BooleanArray::from(nulls.inner().clone());
70 filter = compute::and(&filter, &nulls).unwrap();
71 }
72 }
73
74 Ok(filter)
75 }
76 }
77 }
78}
79
/// A set-membership filter: tests each value of an array for presence in
/// `array`.
pub struct Contains {
    /// The set of values to match against; `new` rejects arrays with nulls.
    array: Arc<dyn Array>,
    /// xxh3 hash of a value -> index into `array`. Built in `new` only when
    /// `array` has at least 128 elements; otherwise lookups scan `array`.
    hash_table: Option<HashTable<usize>>,
}
84
85impl Contains {
86 fn ht_from_primitive<T: ArrowPrimitiveType>(arr: &PrimitiveArray<T>) -> HashTable<usize> {
87 assert!(!arr.is_nullable());
88
89 let mut ht = HashTable::with_capacity(arr.len());
90
91 for (i, v) in arr.values().iter().enumerate() {
92 ht.insert_unique(xxh3_64(v.to_byte_slice()), i, |i| {
93 xxh3_64(unsafe { arr.value_unchecked(*i).to_byte_slice() })
94 });
95 }
96
97 ht
98 }
99
100 fn ht_from_bytes<T: ByteArrayType<Offset = i32>>(
101 arr: &GenericByteArray<T>,
102 ) -> HashTable<usize> {
103 assert!(!arr.is_nullable());
104
105 let mut ht = HashTable::with_capacity(arr.len());
106
107 for (i, v) in iter_byte_array_without_validity(arr).enumerate() {
108 ht.insert_unique(xxh3_64(v), i, |i| {
109 xxh3_64(unsafe { byte_array_get_unchecked(arr, *i) })
110 });
111 }
112
113 ht
114 }
115
116 fn ht_from_array(array: &dyn Array) -> Result<HashTable<usize>> {
117 let ht = match *array.data_type() {
118 DataType::UInt8 => {
119 let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
120 Self::ht_from_primitive(array)
121 }
122 DataType::UInt16 => {
123 let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
124 Self::ht_from_primitive(array)
125 }
126 DataType::UInt32 => {
127 let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
128 Self::ht_from_primitive(array)
129 }
130 DataType::UInt64 => {
131 let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
132 Self::ht_from_primitive(array)
133 }
134 DataType::Int8 => {
135 let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
136 Self::ht_from_primitive(array)
137 }
138 DataType::Int16 => {
139 let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
140 Self::ht_from_primitive(array)
141 }
142 DataType::Int32 => {
143 let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
144 Self::ht_from_primitive(array)
145 }
146 DataType::Int64 => {
147 let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
148 Self::ht_from_primitive(array)
149 }
150 DataType::Binary => {
151 let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
152 Self::ht_from_bytes(array)
153 }
154 DataType::Utf8 => {
155 let array = array.as_any().downcast_ref::<StringArray>().unwrap();
156 Self::ht_from_bytes(array)
157 }
158 _ => {
159 return Err(anyhow!("unsupported data type: {}", array.data_type()));
160 }
161 };
162
163 Ok(ht)
164 }
165
166 pub fn new(array: Arc<dyn Array>) -> Result<Self> {
167 if array.is_nullable() {
168 return Err(anyhow!(
169 "cannot construct contains filter with a nullable array"
170 ));
171 }
172
173 let hash_table = if array.len() >= 128 {
175 Some(Self::ht_from_array(&array).context("construct hash table")?)
176 } else {
177 None
178 };
179
180 Ok(Self { hash_table, array })
181 }
182
183 fn contains(&self, arr: &dyn Array) -> Result<BooleanArray> {
184 if arr.data_type() != self.array.data_type() {
185 return Err(anyhow!(
186 "filter array is of type {} but array to be filtered is of type {}",
187 self.array.data_type(),
188 arr.data_type(),
189 ));
190 }
191 assert!(!self.array.is_nullable());
192
193 let filter = match *arr.data_type() {
194 DataType::UInt8 => {
195 let self_arr = self.array.as_any().downcast_ref::<UInt8Array>().unwrap();
196 let other_arr = arr.as_any().downcast_ref().unwrap();
197 self.contains_primitive(self_arr, other_arr)
198 }
199 DataType::UInt16 => {
200 let self_arr = self.array.as_any().downcast_ref::<UInt16Array>().unwrap();
201 let other_arr = arr.as_any().downcast_ref().unwrap();
202 self.contains_primitive(self_arr, other_arr)
203 }
204 DataType::UInt32 => {
205 let self_arr = self.array.as_any().downcast_ref::<UInt32Array>().unwrap();
206 let other_arr = arr.as_any().downcast_ref().unwrap();
207 self.contains_primitive(self_arr, other_arr)
208 }
209 DataType::UInt64 => {
210 let self_arr = self.array.as_any().downcast_ref::<UInt64Array>().unwrap();
211 let other_arr = arr.as_any().downcast_ref().unwrap();
212 self.contains_primitive(self_arr, other_arr)
213 }
214 DataType::Int8 => {
215 let self_arr = self.array.as_any().downcast_ref::<Int8Array>().unwrap();
216 let other_arr = arr.as_any().downcast_ref().unwrap();
217 self.contains_primitive(self_arr, other_arr)
218 }
219 DataType::Int16 => {
220 let self_arr = self.array.as_any().downcast_ref::<Int16Array>().unwrap();
221 let other_arr = arr.as_any().downcast_ref().unwrap();
222 self.contains_primitive(self_arr, other_arr)
223 }
224 DataType::Int32 => {
225 let self_arr = self.array.as_any().downcast_ref::<Int32Array>().unwrap();
226 let other_arr = arr.as_any().downcast_ref().unwrap();
227 self.contains_primitive(self_arr, other_arr)
228 }
229 DataType::Int64 => {
230 let self_arr = self.array.as_any().downcast_ref::<Int64Array>().unwrap();
231 let other_arr = arr.as_any().downcast_ref().unwrap();
232 self.contains_primitive(self_arr, other_arr)
233 }
234 DataType::Binary => {
235 let self_arr = self.array.as_any().downcast_ref::<BinaryArray>().unwrap();
236 let other_arr = arr.as_any().downcast_ref().unwrap();
237 self.contains_bytes(self_arr, other_arr)
238 }
239 DataType::Utf8 => {
240 let self_arr = self.array.as_any().downcast_ref::<StringArray>().unwrap();
241 let other_arr = arr.as_any().downcast_ref().unwrap();
242 self.contains_bytes(self_arr, other_arr)
243 }
244 _ => {
245 return Err(anyhow!("unsupported data type: {}", arr.data_type()));
246 }
247 };
248
249 let mut filter = filter;
250
251 if let Some(nulls) = arr.nulls() {
252 if nulls.null_count() > 0 {
253 let nulls = BooleanArray::from(nulls.inner().clone());
254 filter = compute::and(&filter, &nulls).unwrap();
255 }
256 }
257
258 Ok(filter)
259 }
260
261 fn contains_primitive<T: ArrowPrimitiveType>(
262 &self,
263 self_arr: &PrimitiveArray<T>,
264 other_arr: &PrimitiveArray<T>,
265 ) -> BooleanArray {
266 let mut filter = BooleanBuilder::with_capacity(other_arr.len());
267
268 if let Some(ht) = self.hash_table.as_ref() {
269 let hash_one = |v: &T::Native| -> u64 { xxh3_64(v.to_byte_slice()) };
270
271 for v in other_arr.values().iter() {
272 let c = ht
273 .find(hash_one(v), |idx| unsafe {
274 self_arr.values().get_unchecked(*idx) == v
275 })
276 .is_some();
277 filter.append_value(c);
278 }
279 } else {
280 for v in other_arr.values().iter() {
281 filter.append_value(self_arr.values().iter().any(|x| x == v));
282 }
283 }
284
285 filter.finish()
286 }
287
288 fn contains_bytes<T: ByteArrayType<Offset = i32>>(
289 &self,
290 self_arr: &GenericByteArray<T>,
291 other_arr: &GenericByteArray<T>,
292 ) -> BooleanArray {
293 let mut filter = BooleanBuilder::with_capacity(other_arr.len());
294
295 if let Some(ht) = self.hash_table.as_ref() {
296 for v in iter_byte_array_without_validity(other_arr) {
297 let c = ht
298 .find(xxh3_64(v), |idx| unsafe {
299 byte_array_get_unchecked(self_arr, *idx) == v
300 })
301 .is_some();
302 filter.append_value(c);
303 }
304 } else {
305 for v in iter_byte_array_without_validity(other_arr) {
306 filter.append_value(iter_byte_array_without_validity(self_arr).any(|x| x == v));
307 }
308 }
309
310 filter.finish()
311 }
312}
313
314unsafe fn byte_array_get_unchecked<T: ByteArrayType<Offset = i32>>(
317 arr: &GenericByteArray<T>,
318 i: usize,
319) -> &[u8] {
320 let end = *arr.value_offsets().get_unchecked(i + 1);
321 let start = *arr.value_offsets().get_unchecked(i);
322
323 std::slice::from_raw_parts(
324 arr.value_data()
325 .as_ptr()
326 .offset(isize::try_from(start).unwrap()),
327 usize::try_from(end - start).unwrap(),
328 )
329}
330
331fn iter_byte_array_without_validity<T: ByteArrayType<Offset = i32>>(
332 arr: &GenericByteArray<T>,
333) -> impl Iterator<Item = &[u8]> {
334 (0..arr.len()).map(|i| unsafe { byte_array_get_unchecked(arr, i) })
335}
336
337pub fn run_query(
338 data: &BTreeMap<TableName, RecordBatch>,
339 query: &Query,
340) -> Result<BTreeMap<TableName, RecordBatch>> {
341 let filters = query
342 .selection
343 .par_iter()
344 .map(|(table_name, selections)| {
345 selections
346 .par_iter()
347 .enumerate()
348 .map(|(i, selection)| {
349 run_table_selection(data, table_name, selection).with_context(|| {
350 format!("run table selection no:{} for table {}", i, table_name)
351 })
352 })
353 .collect::<Result<Vec<_>>>()
354 })
355 .collect::<Result<Vec<_>>>()?;
356
357 let data = select_fields(data, &query.fields).context("select fields")?;
358
359 data.par_iter()
360 .filter_map(|(table_name, table_data)| {
361 let mut combined_filter: Option<BooleanArray> = None;
362
363 for f in filters.iter() {
364 for f in f.iter() {
365 let filter = match f.get(table_name) {
366 Some(f) => f,
367 None => continue,
368 };
369
370 match combined_filter.as_ref() {
371 Some(e) => {
372 let f = compute::or(e, filter)
373 .with_context(|| format!("combine filters for {}", table_name));
374 let f = match f {
375 Ok(v) => v,
376 Err(err) => return Some(Err(err)),
377 };
378 combined_filter = Some(f);
379 }
380 None => {
381 combined_filter = Some(filter.clone());
382 }
383 }
384 }
385 }
386
387 let combined_filter = match combined_filter {
388 Some(f) => f,
389 None => return None,
390 };
391
392 let table_data = compute::filter_record_batch(table_data, &combined_filter)
393 .context("filter record batch");
394 let table_data = match table_data {
395 Ok(v) => v,
396 Err(err) => return Some(Err(err)),
397 };
398
399 Some(Ok((table_name.to_owned(), table_data)))
400 })
401 .collect()
402}
403
404fn select_fields(
405 data: &BTreeMap<TableName, RecordBatch>,
406 fields: &BTreeMap<TableName, Vec<FieldName>>,
407) -> Result<BTreeMap<TableName, RecordBatch>> {
408 let mut out = BTreeMap::new();
409
410 for (table_name, field_names) in fields.iter() {
411 let table_data = data
412 .get(table_name)
413 .with_context(|| format!("get data for table {}", table_name))?;
414
415 let indices = field_names
416 .iter()
417 .map(|n| {
418 table_data
419 .schema_ref()
420 .index_of(n)
421 .with_context(|| format!("find index of field {} in table {}", n, table_name))
422 })
423 .collect::<Result<Vec<usize>>>()?;
424
425 let table_data = table_data
426 .project(&indices)
427 .with_context(|| format!("project table {}", table_name))?;
428 out.insert(table_name.to_owned(), table_data);
429 }
430
431 Ok(out)
432}
433
434fn run_table_selection(
435 data: &BTreeMap<TableName, RecordBatch>,
436 table_name: &str,
437 selection: &TableSelection,
438) -> Result<BTreeMap<TableName, BooleanArray>> {
439 let mut out = BTreeMap::new();
440
441 let table_data = data.get(table_name).context("get table data")?;
442 let mut combined_filter = None;
443 for (field_name, filter) in selection.filters.iter() {
444 let col = table_data
445 .column_by_name(field_name)
446 .with_context(|| format!("get field {}", field_name))?;
447
448 let f = filter
449 .check(&col)
450 .with_context(|| format!("check filter for column {}", field_name))?;
451
452 match combined_filter {
453 Some(cf) => {
454 combined_filter = Some(
455 compute::and(&cf, &f)
456 .with_context(|| format!("combine filter for column {}", field_name))?,
457 );
458 }
459 None => {
460 combined_filter = Some(f);
461 }
462 }
463 }
464
465 let combined_filter = match combined_filter {
466 Some(cf) => cf,
467 None => BooleanArray::new(BooleanBuffer::new_set(table_data.num_rows()), None),
468 };
469
470 out.insert(table_name.to_owned(), combined_filter.clone());
471
472 let mut filtered_cache = BTreeMap::new();
473
474 for (i, inc) in selection.include.iter().enumerate() {
475 if inc.other_table_field_names.len() != inc.field_names.len() {
476 return Err(anyhow!(
477 "field names are different for self table and other table while processing include no: {}. {} {}",
478 i,
479 inc.field_names.len(),
480 inc.other_table_field_names.len(),
481 ));
482 }
483
484 let other_table_data = data.get(&inc.other_table_name).with_context(|| {
485 format!(
486 "get data for table {} as other table data",
487 inc.other_table_name
488 )
489 })?;
490
491 let self_arr = columns_to_binary_array(table_data, &inc.field_names)
492 .context("get row format binary arr for self")?;
493
494 let self_arr = match filtered_cache.entry(inc.field_names.clone()) {
495 Entry::Vacant(entry) => {
496 let self_arr = compute::filter(&self_arr, &combined_filter)
497 .context("apply combined filter to self arr")?;
498 entry.insert(self_arr.clone());
499 self_arr
500 }
501 Entry::Occupied(entry) => Arc::clone(entry.get()),
502 };
503
504 let other_arr = columns_to_binary_array(other_table_data, &inc.other_table_field_names)
505 .with_context(|| {
506 format!(
507 "get row format binary arr for other table {}",
508 inc.other_table_name
509 )
510 })?;
511
512 let contains = Contains::new(Arc::new(self_arr)).context("create contains filter")?;
513
514 let f = contains
515 .contains(&other_arr)
516 .with_context(|| format!("run contains for other table {}", inc.other_table_name))?;
517
518 match out.entry(inc.other_table_name.clone()) {
519 Entry::Vacant(entry) => {
520 entry.insert(f);
521 }
522 Entry::Occupied(mut entry) => {
523 let new = compute::or(entry.get(), &f).with_context(|| {
524 format!("or include filters for table {}", inc.other_table_name)
525 })?;
526 entry.insert(new);
527 }
528 }
529 }
530
531 Ok(out)
532}
533
534fn columns_to_binary_array(
535 table_data: &RecordBatch,
536 column_names: &[String],
537) -> Result<BinaryArray> {
538 let fields = column_names
539 .iter()
540 .map(|field_name| {
541 let f = table_data
542 .schema_ref()
543 .field_with_name(field_name)
544 .with_context(|| format!("get field {} from schema", field_name))?;
545 Ok(SortField::new(f.data_type().clone()))
546 })
547 .collect::<Result<Vec<_>>>()?;
548 let conv = RowConverter::new(fields).context("create row converter")?;
549
550 let columns = column_names
551 .iter()
552 .map(|field_name| {
553 let c = table_data
554 .column_by_name(field_name)
555 .with_context(|| format!("get data for column {}", field_name))?;
556 let c = Arc::clone(c);
557 Ok(c)
558 })
559 .collect::<Result<Vec<_>>>()?;
560
561 let rows = conv
562 .convert_columns(&columns)
563 .context("convert columns to row format")?;
564 let out = rows
565 .try_into_binary()
566 .context("convert row format to binary array")?;
567
568 Ok(out)
569}
570
#[cfg(test)]
mod tests {
    use arrow::{
        array::AsArray,
        datatypes::{Field, Schema},
    };

    use super::*;

    /// End-to-end check of `run_query` on two small tables.
    ///
    /// - The `Contains` filter on `team_a.name` selects rows 0 and 1 ("kamil",
    ///   "mahmut"), whose (age, height) are (11, 50) and (12, 60).
    /// - The include into `team_b` on (age, height) == (age2, height2) matches
    ///   only (11, 50), i.e. "yusuf"; (12, 61) does not match (12, 60).
    /// - The include of `team_a` into itself on `height` matches heights 50 and
    ///   60, i.e. rows 0, 1 and 3. OR-ed with the table's own mask this gives
    ///   "kamil", "mahmut", "kazim".
    #[test]
    fn basic_test_cherry_query() {
        let team_a = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Arc::new(Field::new("name", DataType::Utf8, true)),
                Arc::new(Field::new("age", DataType::UInt64, true)),
                Arc::new(Field::new("height", DataType::UInt64, true)),
            ])),
            vec![
                Arc::new(StringArray::from_iter_values(
                    vec!["kamil", "mahmut", "qwe", "kazim"].into_iter(),
                )),
                Arc::new(UInt64Array::from_iter(vec![11, 12, 13, 31].into_iter())),
                Arc::new(UInt64Array::from_iter(vec![50, 60, 70, 60].into_iter())),
            ],
        )
        .unwrap();
        let team_b = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Arc::new(Field::new("name2", DataType::Utf8, true)),
                Arc::new(Field::new("age2", DataType::UInt64, true)),
                Arc::new(Field::new("height2", DataType::UInt64, true)),
            ])),
            vec![
                Arc::new(StringArray::from_iter_values(vec![
                    "yusuf", "abuzer", "asd",
                ])),
                Arc::new(UInt64Array::from_iter(vec![11, 12, 13].into_iter())),
                Arc::new(UInt64Array::from_iter(vec![50, 61, 70].into_iter())),
            ],
        )
        .unwrap();

        // Return only the name column of each table.
        let query = Query {
            fields: [
                ("team_a".to_owned(), vec!["name".to_owned()]),
                ("team_b".to_owned(), vec!["name2".to_owned()]),
            ]
            .into_iter()
            .collect(),
            selection: [(
                "team_a".to_owned(),
                vec![TableSelection {
                    filters: [(
                        "name".to_owned(),
                        Filter::Contains(
                            Contains::new(Arc::new(StringArray::from_iter_values(
                                vec!["kamil", "mahmut"].into_iter(),
                            )))
                            .unwrap(),
                        ),
                    )]
                    .into_iter()
                    .collect(),
                    include: vec![
                        // Two-column key into another table.
                        Include {
                            field_names: vec!["age".to_owned(), "height".to_owned()],
                            other_table_field_names: vec!["age2".to_owned(), "height2".to_owned()],
                            other_table_name: "team_b".to_owned(),
                        },
                        // Self-include: widens team_a's own selection.
                        Include {
                            field_names: vec!["height".to_owned()],
                            other_table_field_names: vec!["height".to_owned()],
                            other_table_name: "team_a".to_owned(),
                        },
                    ],
                }],
            )]
            .into_iter()
            .collect(),
        };

        let data = [("team_a".to_owned(), team_a), ("team_b".to_owned(), team_b)]
            .into_iter()
            .collect::<BTreeMap<_, _>>();

        let res = run_query(&data, &query).unwrap();

        let team_a = res.get("team_a").unwrap();
        let team_b = res.get("team_b").unwrap();

        assert_eq!(res.len(), 2);

        let name = team_a.column_by_name("name").unwrap();
        let name2 = team_b.column_by_name("name2").unwrap();

        // Projection kept only the requested column.
        assert_eq!(team_a.num_columns(), 1);
        assert_eq!(team_b.num_columns(), 1);

        assert_eq!(
            name.as_string(),
            &StringArray::from_iter_values(["kamil", "mahmut", "kazim"])
        );
        assert_eq!(name2.as_string(), &StringArray::from_iter_values(["yusuf"]));
    }
}