// datafusion_physical_plan/aggregates/topk/hash_table.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index
19
20use crate::aggregates::group_values::HashValue;
21use crate::aggregates::topk::heap::Comparable;
22use ahash::RandomState;
23use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
24use arrow::array::{
25    builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef,
26    ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray,
27};
28use arrow::datatypes::{i256, DataType};
29use datafusion_common::DataFusionError;
30use datafusion_common::Result;
31use half::f16;
32use hashbrown::raw::RawTable;
33use std::fmt::Debug;
34use std::sync::Arc;
35
/// A "type alias" for Keys which are stored in our map
pub trait KeyType: Clone + Comparable + Debug {}

/// Blanket impl: any `Clone + Comparable + Debug` type can serve as a map key
impl<T> KeyType for T where T: Clone + Comparable + Debug {}
40
/// An entry in our hash table that:
/// 1. memoizes the hash
/// 2. contains the key (ID)
/// 3. contains the value (heap_idx - an index into the corresponding heap)
pub struct HashTableItem<ID: KeyType> {
    /// Memoized hash of `id`, so growing the table never re-hashes keys
    hash: u64,
    /// The grouping key this entry represents
    pub id: ID,
    /// Index of this entry's value in the corresponding heap
    pub heap_idx: usize,
}
50
/// A custom wrapper around `hashbrown::RawTable` that:
/// 1. limits the number of entries to the top K
/// 2. Allocates a capacity greater than top K to maintain a low-fill factor and prevent resizing
/// 3. Tracks indexes to allow corresponding heap to refer to entries by index vs hash
/// 4. Catches resize events to allow the corresponding heap to update its indexes
struct TopKHashTable<ID: KeyType> {
    /// Backing table; entries are addressed by raw bucket index
    map: RawTable<HashTableItem<ID>>,
    /// Maximum number of entries to retain (the K in top-K)
    limit: usize,
}
60
/// An interface to hide the generic type signature of TopKHashTable behind arrow arrays
pub trait ArrowHashTable {
    /// Replaces the batch of key values that `find_or_insert` reads from
    fn set_batch(&mut self, ids: ArrayRef);
    /// Number of entries currently stored in the table
    fn len(&self) -> usize;
    // JUSTIFICATION
    //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
    //  Soundness: the caller must provide valid indexes
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]);
    // JUSTIFICATION
    //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
    //  Soundness: the caller must provide a valid index
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize;
    /// Drains the table, returning the keys at `indexes` as an arrow array.
    ///
    /// # Safety
    /// The caller must provide indexes of occupied buckets.
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef;

    // JUSTIFICATION
    //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
    //  Soundness: the caller must provide valid indexes
    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        map: &mut Vec<(usize, usize)>,
    ) -> (usize, bool);
}
85
// An implementation of ArrowHashTable for String keys
pub struct StringHashTable {
    /// The current batch of keys, installed via `set_batch`
    owned: ArrayRef,
    /// Keys stored as owned strings; `None` represents a null key
    map: TopKHashTable<Option<String>>,
    /// Hash state; must be used for both lookups and inserts so hashes agree
    rnd: RandomState,
    /// One of `Utf8`, `LargeUtf8`, or `Utf8View` (enforced by `new`)
    data_type: DataType,
}
93
// An implementation of ArrowHashTable for any `ArrowPrimitiveType` key
struct PrimitiveHashTable<VAL: ArrowPrimitiveType>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
    /// The current batch of keys, installed via `set_batch`
    owned: ArrayRef,
    /// Keys stored by native value; `None` represents a null key
    map: TopKHashTable<Option<VAL::Native>>,
    /// Hash state; must be used for both lookups and inserts so hashes agree
    rnd: RandomState,
    /// The logical arrow type of the keys, retained so emitted arrays keep
    /// full type info (e.g. a timestamp's timezone), not just the native type
    kt: DataType,
}
104
105impl StringHashTable {
106    pub fn new(limit: usize, data_type: DataType) -> Self {
107        let vals: Vec<&str> = Vec::new();
108        let owned: ArrayRef = match data_type {
109            DataType::Utf8 => Arc::new(StringArray::from(vals)),
110            DataType::Utf8View => Arc::new(StringViewArray::from(vals)),
111            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(vals)),
112            _ => panic!("Unsupported data type"),
113        };
114
115        Self {
116            owned,
117            map: TopKHashTable::new(limit, limit * 10),
118            rnd: RandomState::default(),
119            data_type,
120        }
121    }
122}
123
impl ArrowHashTable for StringHashTable {
    fn set_batch(&mut self, ids: ArrayRef) {
        self.owned = ids;
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        self.map.update_heap_idx(mapper);
    }

    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        self.map.heap_idx_at(map_idx)
    }

    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
        let ids = self.map.take_all(indexes);
        // Re-emit the same string flavor this table was constructed with
        match self.data_type {
            DataType::Utf8 => Arc::new(StringArray::from(ids)),
            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)),
            DataType::Utf8View => Arc::new(StringViewArray::from(ids)),
            // `StringHashTable::new` panics on any other type
            _ => unreachable!(),
        }
    }

    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool) {
        // Borrow the key at `row_idx` from the current batch (None == null)
        let id = match self.data_type {
            DataType::Utf8 => {
                let ids = self
                    .owned
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .expect("Expected StringArray for DataType::Utf8");
                if ids.is_null(row_idx) {
                    None
                } else {
                    Some(ids.value(row_idx))
                }
            }
            DataType::LargeUtf8 => {
                let ids = self
                    .owned
                    .as_any()
                    .downcast_ref::<LargeStringArray>()
                    .expect("Expected LargeStringArray for DataType::LargeUtf8");
                if ids.is_null(row_idx) {
                    None
                } else {
                    Some(ids.value(row_idx))
                }
            }
            DataType::Utf8View => {
                let ids = self
                    .owned
                    .as_any()
                    .downcast_ref::<StringViewArray>()
                    .expect("Expected StringViewArray for DataType::Utf8View");
                if ids.is_null(row_idx) {
                    None
                } else {
                    Some(ids.value(row_idx))
                }
            }
            _ => panic!("Unsupported data type"),
        };

        // Hashing the borrowed `Option<&str>` agrees with the owned
        // `Option<String>` stored in the map, since `String` hashes as `str`
        let hash = self.rnd.hash_one(id);
        if let Some(map_idx) = self
            .map
            .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str()))
        {
            return (map_idx, false);
        }

        // we're full and this is a better value, so remove the worst
        let heap_idx = self.map.remove_if_full(replace_idx);

        // add the new group, copying the key out of the borrowed batch
        let id = id.map(|id| id.to_string());
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
        (map_idx, true)
    }
}
214
215impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL>
216where
217    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
218    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
219{
220    pub fn new(limit: usize, kt: DataType) -> Self {
221        let owned = Arc::new(
222            PrimitiveArray::<VAL>::builder(0)
223                .with_data_type(kt.clone())
224                .finish(),
225        );
226        Self {
227            owned,
228            map: TopKHashTable::new(limit, limit * 10),
229            rnd: RandomState::default(),
230            kt,
231        }
232    }
233}
234
impl<VAL: ArrowPrimitiveType> ArrowHashTable for PrimitiveHashTable<VAL>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
{
    fn set_batch(&mut self, ids: ArrayRef) {
        self.owned = ids;
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        self.map.update_heap_idx(mapper);
    }

    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        self.map.heap_idx_at(map_idx)
    }

    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
        let ids = self.map.take_all(indexes);
        // Rebuild with `self.kt` so logical type info (e.g. timezone) survives
        let mut builder: PrimitiveBuilder<VAL> =
            PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone());
        for id in ids.into_iter() {
            match id {
                None => builder.append_null(),
                Some(id) => builder.append_value(id),
            }
        }
        let ids = builder.finish();
        Arc::new(ids)
    }

    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool) {
        // Read the key at `row_idx` from the current batch (None == null)
        let ids = self.owned.as_primitive::<VAL>();
        let id: Option<VAL::Native> = if ids.is_null(row_idx) {
            None
        } else {
            Some(ids.value(row_idx))
        };

        // Hash via the `HashValue` impl so lookups and inserts agree
        let hash: u64 = id.hash(&self.rnd);
        if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) {
            return (map_idx, false);
        }

        // we're full and this is a better value, so remove the worst
        let heap_idx = self.map.remove_if_full(replace_idx);

        // add the new group
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
        (map_idx, true)
    }
}
296
impl<ID: KeyType> TopKHashTable<ID> {
    /// Creates a table holding at most `limit` entries, pre-allocating
    /// `capacity` slots (callers over-allocate to make resizes rare).
    pub fn new(limit: usize, capacity: usize) -> Self {
        Self {
            map: RawTable::with_capacity(capacity),
            limit,
        }
    }

    /// Returns the bucket index of the entry matching `hash` and `eq`, if any.
    pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option<usize> {
        let bucket = self.map.find(hash, |mi| eq(&mi.id))?;
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: getting the index of a bucket we just found
        let idx = unsafe { self.map.bucket_index(&bucket) };
        Some(idx)
    }

    /// Returns the heap index stored at bucket `map_idx`.
    ///
    /// # Safety
    /// `map_idx` must refer to an occupied bucket.
    pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        let bucket = unsafe { self.map.bucket(map_idx) };
        bucket.as_ref().heap_idx
    }

    /// If the table is at its limit, erases the entry at `replace_idx` and
    /// returns 0 (the new entry will replace the heap's top node); otherwise
    /// returns the current length (the new entry appends to the heap's end).
    ///
    /// # Safety
    /// When the table is full, `replace_idx` must refer to an occupied bucket.
    pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize {
        if self.map.len() >= self.limit {
            self.map.erase(self.map.bucket(replace_idx));
            0 // if full, always replace top node
        } else {
            self.map.len() // if we're not full, always append to end
        }
    }

    /// Overwrites `heap_idx` for each `(map_idx, heap_idx)` pair in `mapper`.
    ///
    /// # Safety
    /// Every `map_idx` must refer to an occupied bucket.
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        for (m, h) in mapper {
            self.map.bucket(*m).as_mut().heap_idx = *h
        }
    }

    /// Inserts a new entry and returns its bucket index.
    ///
    /// If the insert forces the backing table to grow, every existing entry
    /// may move, so all `(heap_idx, new_map_idx)` pairs are appended to
    /// `mapper` so the corresponding heap can be re-pointed at the new
    /// bucket indexes.
    pub fn insert(
        &mut self,
        hash: u64,
        id: ID,
        heap_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> usize {
        let mi = HashTableItem::new(hash, id, heap_idx);
        // Fast path: no-grow insert usually succeeds thanks to over-allocation
        let bucket = self.map.try_insert_no_grow(hash, mi);
        let bucket = match bucket {
            Ok(bucket) => bucket,
            Err(new_item) => {
                // Out of capacity: grow the table, then report every entry's
                // (possibly changed) bucket index via `mapper`
                let bucket = self.map.insert(hash, new_item, |mi| mi.hash);
                // JUSTIFICATION
                //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
                //  Soundness: we're getting indexes of buckets, not dereferencing them
                unsafe {
                    for bucket in self.map.iter() {
                        let heap_idx = bucket.as_ref().heap_idx;
                        let map_idx = self.map.bucket_index(&bucket);
                        mapper.push((heap_idx, map_idx));
                    }
                }
                bucket
            }
        };
        // JUSTIFICATION
        //  Benefit:  ~15% speedup + required to index into RawTable from binary heap
        //  Soundness: we're getting indexes of buckets, not dereferencing them
        unsafe { self.map.bucket_index(&bucket) }
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Clones out the keys at `idxs` (in order) and clears the table.
    ///
    /// # Safety
    /// Every index in `idxs` must refer to an occupied bucket.
    pub unsafe fn take_all(&mut self, idxs: Vec<usize>) -> Vec<ID> {
        let ids = idxs
            .into_iter()
            .map(|idx| self.map.bucket(idx).as_ref().id.clone())
            .collect();
        self.map.clear();
        ids
    }
}
379
380impl<ID: KeyType> HashTableItem<ID> {
381    pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self {
382        Self { hash, id, heap_idx }
383    }
384}
385
impl HashValue for Option<String> {
    /// Hashes the `Option` as a whole (discriminant + contents), matching the
    /// hash of the borrowed `Option<&str>` used for lookups, since `String`
    /// hashes identically to `str`.
    fn hash(&self, state: &RandomState) -> u64 {
        state.hash_one(self)
    }
}
391
392macro_rules! hash_float {
393    ($($t:ty),+) => {
394        $(impl HashValue for Option<$t> {
395            fn hash(&self, state: &RandomState) -> u64 {
396                self.map(|me| me.hash(state)).unwrap_or(0)
397            }
398        })+
399    };
400}
401
402macro_rules! has_integer {
403    ($($t:ty),+) => {
404        $(impl HashValue for Option<$t> {
405            fn hash(&self, state: &RandomState) -> u64 {
406                self.map(|me| me.hash(state)).unwrap_or(0)
407            }
408        })+
409    };
410}
411
412has_integer!(i8, i16, i32, i64, i128, i256);
413has_integer!(u8, u16, u32, u64);
414has_integer!(IntervalDayTime, IntervalMonthDayNano);
415hash_float!(f16, f32, f64);
416
/// Returns a `limit`-bounded [`ArrowHashTable`] specialized for key type `kt`.
///
/// # Errors
/// Returns an execution error if `kt` is not a supported primitive or
/// string type.
pub fn new_hash_table(
    limit: usize,
    kt: DataType,
) -> Result<Box<dyn ArrowHashTable + Send>> {
    // Invoked by `downcast_primitive!` once per matching primitive arrow type
    macro_rules! downcast_helper {
        ($kt:ty, $d:ident) => {
            return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit, kt)))
        };
    }

    downcast_primitive! {
        kt => (downcast_helper, kt),
        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))),
        DataType::LargeUtf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::LargeUtf8))),
        DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))),
        _ => {}
    }

    Err(DataFusionError::Execution(format!(
        "Can't create HashTable for type: {kt:?}"
    )))
}
439
440#[cfg(test)]
441mod tests {
442    use super::*;
443    use arrow::array::TimestampMillisecondArray;
444    use arrow_schema::TimeUnit;
445    use std::collections::BTreeMap;
446
    #[test]
    fn should_emit_correct_type() -> Result<()> {
        // A timestamp with a timezone: the emitted array must preserve the
        // full logical type (including "UTC"), not just the physical i64
        let ids =
            TimestampMillisecondArray::from(vec![1000]).with_timezone("UTC".to_string());
        let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()));
        let mut ht = new_hash_table(1, dt.clone())?;
        ht.set_batch(Arc::new(ids));
        let mut mapper = vec![];
        // SAFETY: row 0 and map index 0 are valid for this single-row batch
        let ids = unsafe {
            ht.find_or_insert(0, 0, &mut mapper);
            ht.take_all(vec![0])
        };
        assert_eq!(ids.data_type(), &dt);

        Ok(())
    }
463
    #[test]
    fn should_resize_properly() -> Result<()> {
        let mut heap_to_map = BTreeMap::<usize, usize>::new();
        // Capacity 3 with limit 5: the 4th insert should force the backing
        // table to grow and report every entry's new bucket index via `mapper`
        let mut map = TopKHashTable::<Option<String>>::new(5, 3);
        for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() {
            let mut mapper = vec![];
            let hash = heap_idx as u64;
            let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper);
            let _ = heap_to_map.insert(heap_idx, map_idx);
            if heap_idx == 3 {
                // Resize pass: mapper must list all 4 existing entries
                assert_eq!(
                    mapper,
                    vec![(0, 0), (1, 1), (2, 2), (3, 3)],
                    "Pass {heap_idx} resized incorrectly!"
                );
                for (heap_idx, map_idx) in mapper {
                    let _ = heap_to_map.insert(heap_idx, map_idx);
                }
            } else {
                assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!");
            }
        }

        let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip();
        // SAFETY: all indexes were returned by `insert`/`mapper` above
        let ids = unsafe { map.take_all(map_idxs) };
        assert_eq!(
            format!("{ids:?}"),
            r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"#
        );
        assert_eq!(map.len(), 0, "Map should have been cleared!");

        Ok(())
    }
497}