1use crate::aggregates::group_values::HashValue;
21use crate::aggregates::topk::heap::Comparable;
22use ahash::RandomState;
23use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
24use arrow::array::{
25 builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef,
26 ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray,
27};
28use arrow::datatypes::{i256, DataType};
29use datafusion_common::DataFusionError;
30use datafusion_common::Result;
31use half::f16;
32use hashbrown::raw::RawTable;
33use std::fmt::Debug;
34use std::sync::Arc;
35
/// Marker trait for hash-table key types: cloneable (so keys can be handed
/// out by `take_all`), `Comparable` for heap ordering, and `Debug`gable.
pub trait KeyType: Clone + Comparable + Debug {}

// Blanket impl: any type satisfying the bounds is automatically a key type.
impl<T> KeyType for T where T: Clone + Comparable + Debug {}
40
/// One entry of the hash table: a key plus its cached hash and the index of
/// its counterpart entry in the companion TopK heap.
pub struct HashTableItem<ID: KeyType> {
    // cached so rehashing on growth need not recompute it
    hash: u64,
    pub id: ID,
    // position of this key in the external heap; kept in sync via `update_heap_idx`
    pub heap_idx: usize,
}
50
/// A size-capped hash table of `HashTableItem`s built on hashbrown's
/// `RawTable` so entries can be addressed by raw bucket index.
struct TopKHashTable<ID: KeyType> {
    map: RawTable<HashTableItem<ID>>,
    // maximum number of entries; reaching it triggers eviction in
    // `remove_if_full`
    limit: usize,
}
60
/// A hash table whose keys are read from an Arrow array batch, with each
/// entry tracking its position in a companion TopK heap.
pub trait ArrowHashTable {
    /// Replaces the array that per-row keys are read from.
    fn set_batch(&mut self, ids: ArrayRef);
    /// Number of entries currently stored.
    fn len(&self) -> usize;
    /// Applies `(map_idx, heap_idx)` pairs, rewriting each entry's heap index.
    ///
    /// # Safety
    /// Every `map_idx` in `mapper` must be a valid, occupied bucket index.
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]);
    /// Returns the heap index stored in the entry at `map_idx`.
    ///
    /// # Safety
    /// `map_idx` must be a valid, occupied bucket index.
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize;
    /// Drains the table and returns the keys at `indexes` as an Arrow array.
    ///
    /// # Safety
    /// Every index in `indexes` must be a valid, occupied bucket index.
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef;

    /// Looks up the key at `row_idx` of the current batch, inserting it if
    /// absent (evicting the entry at `replace_idx` when at the limit).
    /// Returns `(map_idx, inserted)`. If the insertion grows the table, the
    /// resulting bucket moves are reported through `map` as
    /// `(heap_idx, map_idx)` pairs.
    ///
    /// # Safety
    /// `row_idx` must be in bounds for the current batch, and `replace_idx`
    /// must be a valid, occupied bucket index whenever the table is full.
    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        map: &mut Vec<(usize, usize)>,
    ) -> (usize, bool);
}
85
/// A top-K hash table of optional string keys supporting `Utf8`,
/// `LargeUtf8`, and `Utf8View` key arrays.
pub struct StringHashTable {
    // current batch of keys; `find_or_insert` borrows values out of it
    owned: ArrayRef,
    // NULL keys are represented as `None`
    map: TopKHashTable<Option<String>>,
    // single hasher shared by lookup and insert so hashes stay consistent
    rnd: RandomState,
    // which string flavor `owned` holds and `take_all` must emit
    data_type: DataType,
}
93
/// A top-K hash table of optional primitive keys, keyed by the Arrow native
/// value of `VAL`.
struct PrimitiveHashTable<VAL: ArrowPrimitiveType>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
    // current batch of keys; values are read per-row in `find_or_insert`
    owned: ArrayRef,
    // NULL keys are represented as `None`
    map: TopKHashTable<Option<VAL::Native>>,
    rnd: RandomState,
    // exact key DataType; retains parameters (e.g. timestamp timezone) that
    // `VAL` alone does not capture
    kt: DataType,
}
104
105impl StringHashTable {
106 pub fn new(limit: usize, data_type: DataType) -> Self {
107 let vals: Vec<&str> = Vec::new();
108 let owned: ArrayRef = match data_type {
109 DataType::Utf8 => Arc::new(StringArray::from(vals)),
110 DataType::Utf8View => Arc::new(StringViewArray::from(vals)),
111 DataType::LargeUtf8 => Arc::new(LargeStringArray::from(vals)),
112 _ => panic!("Unsupported data type"),
113 };
114
115 Self {
116 owned,
117 map: TopKHashTable::new(limit, limit * 10),
118 rnd: RandomState::default(),
119 data_type,
120 }
121 }
122}
123
impl ArrowHashTable for StringHashTable {
    fn set_batch(&mut self, ids: ArrayRef) {
        self.owned = ids;
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    // SAFETY: forwards the caller's obligation that every `map_idx` in
    // `mapper` is a valid, occupied bucket index.
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        self.map.update_heap_idx(mapper);
    }

    // SAFETY: forwards the caller's obligation that `map_idx` is a valid,
    // occupied bucket index.
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        self.map.heap_idx_at(map_idx)
    }

    // Drains the map and re-materializes the owned keys as an array of the
    // same string flavor this table was created with.
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
        let ids = self.map.take_all(indexes);
        match self.data_type {
            DataType::Utf8 => Arc::new(StringArray::from(ids)),
            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)),
            DataType::Utf8View => Arc::new(StringViewArray::from(ids)),
            // `new` rejects any other data type, so this arm is unreachable
            _ => unreachable!(),
        }
    }

    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool) {
        // Borrow the key for `row_idx` from the current batch without
        // copying; NULL array slots become `None` keys.
        let id = match self.data_type {
            DataType::Utf8 => {
                let ids = self
                    .owned
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .expect("Expected StringArray for DataType::Utf8");
                if ids.is_null(row_idx) {
                    None
                } else {
                    Some(ids.value(row_idx))
                }
            }
            DataType::LargeUtf8 => {
                let ids = self
                    .owned
                    .as_any()
                    .downcast_ref::<LargeStringArray>()
                    .expect("Expected LargeStringArray for DataType::LargeUtf8");
                if ids.is_null(row_idx) {
                    None
                } else {
                    Some(ids.value(row_idx))
                }
            }
            DataType::Utf8View => {
                let ids = self
                    .owned
                    .as_any()
                    .downcast_ref::<StringViewArray>()
                    .expect("Expected StringViewArray for DataType::Utf8View");
                if ids.is_null(row_idx) {
                    None
                } else {
                    Some(ids.value(row_idx))
                }
            }
            _ => panic!("Unsupported data type"),
        };

        // Hash the borrowed `Option<&str>`; lookups compare against the
        // owned `Option<String>` stored in the map via `as_str`, so the
        // equality check below matches what was hashed.
        let hash = self.rnd.hash_one(id);
        if let Some(map_idx) = self
            .map
            .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str()))
        {
            // Key already present: report its bucket, no insertion happened.
            return (map_idx, false);
        }

        // If the table is at its limit, evict the entry at `replace_idx`
        // first; `remove_if_full` yields the heap index the new entry should
        // take (0 after an eviction, otherwise the current length).
        let heap_idx = self.map.remove_if_full(replace_idx);

        // Only allocate an owned copy of the key once insertion is certain.
        let id = id.map(|id| id.to_string());
        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
        (map_idx, true)
    }
}
214
215impl<VAL: ArrowPrimitiveType> PrimitiveHashTable<VAL>
216where
217 Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
218 Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
219{
220 pub fn new(limit: usize, kt: DataType) -> Self {
221 let owned = Arc::new(
222 PrimitiveArray::<VAL>::builder(0)
223 .with_data_type(kt.clone())
224 .finish(),
225 );
226 Self {
227 owned,
228 map: TopKHashTable::new(limit, limit * 10),
229 rnd: RandomState::default(),
230 kt,
231 }
232 }
233}
234
impl<VAL: ArrowPrimitiveType> ArrowHashTable for PrimitiveHashTable<VAL>
where
    Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
    Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
{
    fn set_batch(&mut self, ids: ArrayRef) {
        self.owned = ids;
    }

    fn len(&self) -> usize {
        self.map.len()
    }

    // SAFETY: forwards the caller's obligation that every `map_idx` in
    // `mapper` is a valid, occupied bucket index.
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        self.map.update_heap_idx(mapper);
    }

    // SAFETY: forwards the caller's obligation that `map_idx` is a valid,
    // occupied bucket index.
    unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        self.map.heap_idx_at(map_idx)
    }

    // Drains the map and rebuilds a primitive array using the stored `kt`,
    // so type parameters such as a timestamp timezone are preserved (see
    // `should_emit_correct_type` in the tests).
    unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
        let ids = self.map.take_all(indexes);
        let mut builder: PrimitiveBuilder<VAL> =
            PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone());
        for id in ids.into_iter() {
            match id {
                None => builder.append_null(),
                Some(id) => builder.append_value(id),
            }
        }
        let ids = builder.finish();
        Arc::new(ids)
    }

    unsafe fn find_or_insert(
        &mut self,
        row_idx: usize,
        replace_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> (usize, bool) {
        // Read the key for `row_idx`; NULL array slots become `None` keys.
        let ids = self.owned.as_primitive::<VAL>();
        let id: Option<VAL::Native> = if ids.is_null(row_idx) {
            None
        } else {
            Some(ids.value(row_idx))
        };

        // Hash via the crate's `HashValue` impl for `Option<Native>`.
        let hash: u64 = id.hash(&self.rnd);
        if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) {
            // Key already present: report its bucket, no insertion happened.
            return (map_idx, false);
        }

        // If the table is at its limit, evict the entry at `replace_idx`
        // first; `remove_if_full` yields the heap index the new entry should
        // take (0 after an eviction, otherwise the current length).
        let heap_idx = self.map.remove_if_full(replace_idx);

        let map_idx = self.map.insert(hash, id, heap_idx, mapper);
        (map_idx, true)
    }
}
296
impl<ID: KeyType> TopKHashTable<ID> {
    /// Creates a table holding at most `limit` entries, pre-allocating
    /// `capacity` buckets.
    pub fn new(limit: usize, capacity: usize) -> Self {
        Self {
            map: RawTable::with_capacity(capacity),
            limit,
        }
    }

    /// Returns the bucket index of the entry with matching `hash` whose id
    /// satisfies `eq`, or `None` if absent.
    pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option<usize> {
        let bucket = self.map.find(hash, |mi| eq(&mi.id))?;
        // SAFETY: the bucket was just returned by `find` on this same table,
        // so it is valid and occupied.
        let idx = unsafe { self.map.bucket_index(&bucket) };
        Some(idx)
    }

    /// Returns the heap index stored at `map_idx`.
    ///
    /// # Safety
    /// `map_idx` must be a valid, occupied bucket index.
    pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize {
        let bucket = unsafe { self.map.bucket(map_idx) };
        bucket.as_ref().heap_idx
    }

    /// If the table has reached its limit, erases the entry at `replace_idx`
    /// and returns 0; otherwise returns the current length. The returned
    /// value is the heap index the next inserted entry should occupy.
    ///
    /// # Safety
    /// When the table is full, `replace_idx` must be a valid, occupied
    /// bucket index.
    pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize {
        if self.map.len() >= self.limit {
            self.map.erase(self.map.bucket(replace_idx));
            0
        } else {
            self.map.len()
        }
    }

    // Rewrites each entry's heap index from `(map_idx, heap_idx)` pairs.
    // SAFETY (caller): every `m` must be a valid, occupied bucket index.
    unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) {
        for (m, h) in mapper {
            self.map.bucket(*m).as_mut().heap_idx = *h
        }
    }

    /// Inserts `id` with its precomputed `hash` and `heap_idx`, returning
    /// the new entry's bucket index. If no spare bucket is left, the table
    /// grows and rehashes — which moves every bucket — so in that case all
    /// current `(heap_idx, map_idx)` pairs are pushed into `mapper` so the
    /// companion heap can be re-pointed.
    pub fn insert(
        &mut self,
        hash: u64,
        id: ID,
        heap_idx: usize,
        mapper: &mut Vec<(usize, usize)>,
    ) -> usize {
        let mi = HashTableItem::new(hash, id, heap_idx);
        // Fast path: reuse a spare bucket without growing.
        let bucket = self.map.try_insert_no_grow(hash, mi);
        let bucket = match bucket {
            Ok(bucket) => bucket,
            Err(new_item) => {
                // Grow and rehash; all previously handed-out bucket indexes
                // are now stale.
                let bucket = self.map.insert(hash, new_item, |mi| mi.hash);
                // SAFETY: iterating buckets of this table while holding
                // `&mut self`; all buckets yielded are valid and occupied.
                unsafe {
                    for bucket in self.map.iter() {
                        let heap_idx = bucket.as_ref().heap_idx;
                        let map_idx = self.map.bucket_index(&bucket);
                        mapper.push((heap_idx, map_idx));
                    }
                }
                bucket
            }
        };
        // SAFETY: `bucket` was just returned by an insert on this table.
        unsafe { self.map.bucket_index(&bucket) }
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Clones out the ids at `idxs` in order, then clears the table.
    ///
    /// # Safety
    /// Every index in `idxs` must be a valid, occupied bucket index.
    pub unsafe fn take_all(&mut self, idxs: Vec<usize>) -> Vec<ID> {
        let ids = idxs
            .into_iter()
            .map(|idx| self.map.bucket(idx).as_ref().id.clone())
            .collect();
        self.map.clear();
        ids
    }
}
379
380impl<ID: KeyType> HashTableItem<ID> {
381 pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self {
382 Self { hash, id, heap_idx }
383 }
384}
385
// Hashes the optional string with ahash's `hash_one`. Because `String` and
// `str` share one `Hash` impl, this agrees with hashing the borrowed
// `Option<&str>` form used by `StringHashTable::find_or_insert`.
impl HashValue for Option<String> {
    fn hash(&self, state: &RandomState) -> u64 {
        state.hash_one(self)
    }
}
391
// Implements `HashValue` for nullable float native types. `None` hashes to
// 0; collisions are resolved by the equality check at lookup time.
// NOTE(review): `me.hash(state)` presumably resolves to the crate's
// `HashValue` impl for the inner type (std `Hash::hash` takes a `Hasher`,
// not a `RandomState`) — confirm against `group_values::HashValue`.
macro_rules! hash_float {
    ($($t:ty),+) => {
        $(impl HashValue for Option<$t> {
            fn hash(&self, state: &RandomState) -> u64 {
                self.map(|me| me.hash(state)).unwrap_or(0)
            }
        })+
    };
}
401
// Implements `HashValue` for nullable integer-like native types. `None`
// hashes to 0; collisions are resolved by the equality check at lookup time.
// NOTE(review): the macro name looks like a typo of `hash_integer`; renaming
// it would also require updating the invocations below, so it is left as-is.
macro_rules! has_integer {
    ($($t:ty),+) => {
        $(impl HashValue for Option<$t> {
            fn hash(&self, state: &RandomState) -> u64 {
                self.map(|me| me.hash(state)).unwrap_or(0)
            }
        })+
    };
}
411
// Provide `HashValue` for every nullable native key type the primitive hash
// table can store.
has_integer!(i8, i16, i32, i64, i128, i256);
has_integer!(u8, u16, u32, u64);
has_integer!(IntervalDayTime, IntervalMonthDayNano);
hash_float!(f16, f32, f64);
416
/// Builds a heap-backed hash table for key type `kt`: a `PrimitiveHashTable`
/// for Arrow primitive types, or a `StringHashTable` for the three string
/// flavors.
///
/// # Errors
/// Returns `DataFusionError::Execution` when `kt` is not a supported key
/// type.
pub fn new_hash_table(
    limit: usize,
    kt: DataType,
) -> Result<Box<dyn ArrowHashTable + Send>> {
    // Invoked by `downcast_primitive!` below with the concrete
    // `ArrowPrimitiveType` matching each primitive `DataType`.
    macro_rules! downcast_helper {
        ($kt:ty, $d:ident) => {
            return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit, kt)))
        };
    }

    // Expands to a match over all primitive DataTypes, plus the explicit
    // string arms; falls through for anything unsupported.
    downcast_primitive! {
        kt => (downcast_helper, kt),
        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))),
        DataType::LargeUtf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::LargeUtf8))),
        DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))),
        _ => {}
    }

    Err(DataFusionError::Execution(format!(
        "Can't create HashTable for type: {kt:?}"
    )))
}
439
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::TimestampMillisecondArray;
    use arrow_schema::TimeUnit;
    use std::collections::BTreeMap;

    // `take_all` must reproduce the exact key DataType, including the
    // timezone parameter that the plain `ArrowPrimitiveType` cannot carry.
    #[test]
    fn should_emit_correct_type() -> Result<()> {
        let ids =
            TimestampMillisecondArray::from(vec![1000]).with_timezone("UTC".to_string());
        let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()));
        let mut ht = new_hash_table(1, dt.clone())?;
        ht.set_batch(Arc::new(ids));
        let mut mapper = vec![];
        let ids = unsafe {
            ht.find_or_insert(0, 0, &mut mapper);
            ht.take_all(vec![0])
        };
        assert_eq!(ids.data_type(), &dt);

        Ok(())
    }

    // With capacity 3 and limit 5, the 4th insert (heap_idx == 3) forces a
    // grow+rehash; the mapper must then report every (heap_idx, map_idx)
    // move, and no other insert should produce mapper entries.
    #[test]
    fn should_resize_properly() -> Result<()> {
        let mut heap_to_map = BTreeMap::<usize, usize>::new();
        let mut map = TopKHashTable::<Option<String>>::new(5, 3);
        for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() {
            let mut mapper = vec![];
            let hash = heap_idx as u64;
            let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper);
            let _ = heap_to_map.insert(heap_idx, map_idx);
            if heap_idx == 3 {
                assert_eq!(
                    mapper,
                    vec![(0, 0), (1, 1), (2, 2), (3, 3)],
                    "Pass {heap_idx} resized incorrectly!"
                );
                for (heap_idx, map_idx) in mapper {
                    let _ = heap_to_map.insert(heap_idx, map_idx);
                }
            } else {
                assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!");
            }
        }

        let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip();
        let ids = unsafe { map.take_all(map_idxs) };
        assert_eq!(
            format!("{ids:?}"),
            r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"#
        );
        assert_eq!(map.len(), 0, "Map should have been cleared!");

        Ok(())
    }
}