1use crate::builder::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder};
19use crate::types::{ArrowDictionaryKeyType, ByteArrayType, GenericBinaryType, GenericStringType};
20use crate::{Array, ArrayRef, DictionaryArray, GenericByteArray};
21use arrow_buffer::ArrowNativeType;
22use arrow_schema::{ArrowError, DataType};
23use hashbrown::HashTable;
24use std::any::Any;
25use std::sync::Arc;
26
27#[derive(Debug)]
33pub struct GenericByteDictionaryBuilder<K, T>
34where
35 K: ArrowDictionaryKeyType,
36 T: ByteArrayType,
37{
38 state: ahash::RandomState,
39 dedup: HashTable<usize>,
40
41 keys_builder: PrimitiveBuilder<K>,
42 values_builder: GenericByteBuilder<T>,
43}
44
45impl<K, T> Default for GenericByteDictionaryBuilder<K, T>
46where
47 K: ArrowDictionaryKeyType,
48 T: ByteArrayType,
49{
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl<K, T> GenericByteDictionaryBuilder<K, T>
56where
57 K: ArrowDictionaryKeyType,
58 T: ByteArrayType,
59{
60 pub fn new() -> Self {
62 let keys_builder = PrimitiveBuilder::new();
63 let values_builder = GenericByteBuilder::<T>::new();
64 Self {
65 state: Default::default(),
66 dedup: HashTable::with_capacity(keys_builder.capacity()),
67 keys_builder,
68 values_builder,
69 }
70 }
71
72 pub fn with_capacity(
78 keys_capacity: usize,
79 value_capacity: usize,
80 data_capacity: usize,
81 ) -> Self {
82 Self {
83 state: Default::default(),
84 dedup: Default::default(),
85 keys_builder: PrimitiveBuilder::with_capacity(keys_capacity),
86 values_builder: GenericByteBuilder::<T>::with_capacity(value_capacity, data_capacity),
87 }
88 }
89
90 pub fn new_with_dictionary(
114 keys_capacity: usize,
115 dictionary_values: &GenericByteArray<T>,
116 ) -> Result<Self, ArrowError> {
117 let state = ahash::RandomState::default();
118 let dict_len = dictionary_values.len();
119
120 let mut dedup = HashTable::with_capacity(dict_len);
121
122 let values_len = dictionary_values.value_data().len();
123 let mut values_builder = GenericByteBuilder::<T>::with_capacity(dict_len, values_len);
124
125 K::Native::from_usize(dictionary_values.len())
126 .ok_or(ArrowError::DictionaryKeyOverflowError)?;
127
128 for (idx, maybe_value) in dictionary_values.iter().enumerate() {
129 match maybe_value {
130 Some(value) => {
131 let value_bytes: &[u8] = value.as_ref();
132 let hash = state.hash_one(value_bytes);
133
134 dedup
135 .entry(
136 hash,
137 |idx: &usize| value_bytes == get_bytes(&values_builder, *idx),
138 |idx: &usize| state.hash_one(get_bytes(&values_builder, *idx)),
139 )
140 .or_insert(idx);
141
142 values_builder.append_value(value);
143 }
144 None => values_builder.append_null(),
145 }
146 }
147
148 Ok(Self {
149 state,
150 dedup,
151 keys_builder: PrimitiveBuilder::with_capacity(keys_capacity),
152 values_builder,
153 })
154 }
155}
156
157impl<K, T> ArrayBuilder for GenericByteDictionaryBuilder<K, T>
158where
159 K: ArrowDictionaryKeyType,
160 T: ByteArrayType,
161{
162 fn as_any(&self) -> &dyn Any {
164 self
165 }
166
167 fn as_any_mut(&mut self) -> &mut dyn Any {
169 self
170 }
171
172 fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
174 self
175 }
176
177 fn len(&self) -> usize {
179 self.keys_builder.len()
180 }
181
182 fn finish(&mut self) -> ArrayRef {
184 Arc::new(self.finish())
185 }
186
187 fn finish_cloned(&self) -> ArrayRef {
189 Arc::new(self.finish_cloned())
190 }
191}
192
193impl<K, T> GenericByteDictionaryBuilder<K, T>
194where
195 K: ArrowDictionaryKeyType,
196 T: ByteArrayType,
197{
198 fn get_or_insert_key(&mut self, value: impl AsRef<T::Native>) -> Result<K::Native, ArrowError> {
199 let value_native: &T::Native = value.as_ref();
200 let value_bytes: &[u8] = value_native.as_ref();
201
202 let state = &self.state;
203 let storage = &mut self.values_builder;
204 let hash = state.hash_one(value_bytes);
205
206 let idx = *self
207 .dedup
208 .entry(
209 hash,
210 |idx| value_bytes == get_bytes(storage, *idx),
211 |idx| state.hash_one(get_bytes(storage, *idx)),
212 )
213 .or_insert_with(|| {
214 let idx = storage.len();
215 storage.append_value(value);
216 idx
217 })
218 .get();
219
220 let key = K::Native::from_usize(idx).ok_or(ArrowError::DictionaryKeyOverflowError)?;
221
222 Ok(key)
223 }
224
225 pub fn append(&mut self, value: impl AsRef<T::Native>) -> Result<K::Native, ArrowError> {
231 let key = self.get_or_insert_key(value)?;
232 self.keys_builder.append_value(key);
233 Ok(key)
234 }
235
236 pub fn append_n(
241 &mut self,
242 value: impl AsRef<T::Native>,
243 count: usize,
244 ) -> Result<K::Native, ArrowError> {
245 let key = self.get_or_insert_key(value)?;
246 self.keys_builder.append_value_n(key, count);
247 Ok(key)
248 }
249
250 pub fn append_value(&mut self, value: impl AsRef<T::Native>) {
256 self.append(value).expect("dictionary key overflow");
257 }
258
259 pub fn append_values(&mut self, value: impl AsRef<T::Native>, count: usize) {
266 self.append_n(value, count)
267 .expect("dictionary key overflow");
268 }
269
270 #[inline]
272 pub fn append_null(&mut self) {
273 self.keys_builder.append_null()
274 }
275
276 #[inline]
278 pub fn append_nulls(&mut self, n: usize) {
279 self.keys_builder.append_nulls(n)
280 }
281
282 #[inline]
288 pub fn append_option(&mut self, value: Option<impl AsRef<T::Native>>) {
289 match value {
290 None => self.append_null(),
291 Some(v) => self.append_value(v),
292 };
293 }
294
295 pub fn append_options(&mut self, value: Option<impl AsRef<T::Native>>, count: usize) {
302 match value {
303 None => self.keys_builder.append_nulls(count),
304 Some(v) => self.append_values(v, count),
305 };
306 }
307
308 pub fn finish(&mut self) -> DictionaryArray<K> {
310 self.dedup.clear();
311 let values = self.values_builder.finish();
312 let keys = self.keys_builder.finish();
313
314 let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE));
315
316 let builder = keys
317 .into_data()
318 .into_builder()
319 .data_type(data_type)
320 .child_data(vec![values.into_data()]);
321
322 DictionaryArray::from(unsafe { builder.build_unchecked() })
323 }
324
325 pub fn finish_cloned(&self) -> DictionaryArray<K> {
327 let values = self.values_builder.finish_cloned();
328 let keys = self.keys_builder.finish_cloned();
329
330 let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE));
331
332 let builder = keys
333 .into_data()
334 .into_builder()
335 .data_type(data_type)
336 .child_data(vec![values.into_data()]);
337
338 DictionaryArray::from(unsafe { builder.build_unchecked() })
339 }
340
341 pub fn validity_slice(&self) -> Option<&[u8]> {
343 self.keys_builder.validity_slice()
344 }
345}
346
347impl<K: ArrowDictionaryKeyType, T: ByteArrayType, V: AsRef<T::Native>> Extend<Option<V>>
348 for GenericByteDictionaryBuilder<K, T>
349{
350 #[inline]
351 fn extend<I: IntoIterator<Item = Option<V>>>(&mut self, iter: I) {
352 for v in iter {
353 self.append_option(v)
354 }
355 }
356}
357
358fn get_bytes<T: ByteArrayType>(values: &GenericByteBuilder<T>, idx: usize) -> &[u8] {
359 let offsets = values.offsets_slice();
360 let values = values.values_slice();
361
362 let end_offset = offsets[idx + 1].as_usize();
363 let start_offset = offsets[idx].as_usize();
364
365 &values[start_offset..end_offset]
366}
367
368pub type StringDictionaryBuilder<K> = GenericByteDictionaryBuilder<K, GenericStringType<i32>>;
401
402pub type LargeStringDictionaryBuilder<K> = GenericByteDictionaryBuilder<K, GenericStringType<i64>>;
404
405pub type BinaryDictionaryBuilder<K> = GenericByteDictionaryBuilder<K, GenericBinaryType<i32>>;
439
440pub type LargeBinaryDictionaryBuilder<K> = GenericByteDictionaryBuilder<K, GenericBinaryType<i64>>;
442
#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::Int8Array;
    use crate::types::{Int16Type, Int32Type, Int8Type, Utf8Type};
    use crate::{BinaryArray, StringArray};

    /// Appends two distinct values interleaved with a null and repeats, then
    /// checks the keys and the deduplicated dictionary values.
    fn test_bytes_dictionary_builder<T>(values: Vec<&T::Native>)
    where
        T: ByteArrayType,
        <T as ByteArrayType>::Native: PartialEq,
        <T as ByteArrayType>::Native: AsRef<<T as ByteArrayType>::Native>,
    {
        let (first, second) = (values[0], values[1]);

        let mut builder = GenericByteDictionaryBuilder::<Int8Type, T>::new();
        builder.append(first).unwrap();
        builder.append_null();
        builder.append(second).unwrap();
        builder.append(second).unwrap();
        builder.append(first).unwrap();
        let array = builder.finish();

        let expected_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]);
        assert_eq!(array.keys(), &expected_keys);

        // The dictionary holds each distinct value once, in first-seen order.
        let dict_values = array.values();
        let typed = dict_values
            .as_any()
            .downcast_ref::<GenericByteArray<T>>()
            .unwrap();

        assert_eq!(typed.value(0), first);
        assert_eq!(typed.value(1), second);
    }

    #[test]
    fn test_string_dictionary_builder() {
        test_bytes_dictionary_builder::<GenericStringType<i32>>(vec!["abc", "def"]);
    }

    #[test]
    fn test_binary_dictionary_builder() {
        test_bytes_dictionary_builder::<GenericBinaryType<i32>>(vec![b"abc", b"def"]);
    }

    /// Checks that `finish_cloned` snapshots the builder without resetting it:
    /// appends made afterwards extend the same keys and dictionary.
    fn test_bytes_dictionary_builder_finish_cloned<T>(values: Vec<&T::Native>)
    where
        T: ByteArrayType,
        <T as ByteArrayType>::Native: PartialEq,
        <T as ByteArrayType>::Native: AsRef<<T as ByteArrayType>::Native>,
    {
        let mut builder = GenericByteDictionaryBuilder::<Int8Type, T>::new();

        builder.append(values[0]).unwrap();
        builder.append_null();
        builder.append(values[1]).unwrap();
        builder.append(values[1]).unwrap();
        builder.append(values[0]).unwrap();
        let snapshot = builder.finish_cloned();

        let snapshot_keys = Int8Array::from(vec![Some(0), None, Some(1), Some(1), Some(0)]);
        assert_eq!(snapshot.keys(), &snapshot_keys);

        let snapshot_values = snapshot.values();
        let snapshot_typed = snapshot_values
            .as_any()
            .downcast_ref::<GenericByteArray<T>>()
            .unwrap();

        assert_eq!(snapshot_typed.value(0), values[0]);
        assert_eq!(snapshot_typed.value(1), values[1]);

        // The builder still holds everything appended before the snapshot.
        builder.append(values[0]).unwrap();
        builder.append(values[2]).unwrap();
        builder.append(values[1]).unwrap();

        let finished = builder.finish();

        let finished_keys = Int8Array::from(vec![
            Some(0),
            None,
            Some(1),
            Some(1),
            Some(0),
            Some(0),
            Some(2),
            Some(1),
        ]);
        assert_eq!(finished.keys(), &finished_keys);

        let finished_values = finished.values();
        let finished_typed = finished_values
            .as_any()
            .downcast_ref::<GenericByteArray<T>>()
            .unwrap();

        assert_eq!(finished_typed.value(0), values[0]);
        assert_eq!(finished_typed.value(1), values[1]);
        assert_eq!(finished_typed.value(2), values[2]);
    }

    #[test]
    fn test_string_dictionary_builder_finish_cloned() {
        test_bytes_dictionary_builder_finish_cloned::<GenericStringType<i32>>(vec![
            "abc", "def", "ghi",
        ]);
    }

    #[test]
    fn test_binary_dictionary_builder_finish_cloned() {
        test_bytes_dictionary_builder_finish_cloned::<GenericBinaryType<i32>>(vec![
            b"abc", b"def", b"ghi",
        ]);
    }

    /// Seeds the builder with an existing dictionary and checks that appended
    /// values reuse the seeded positions, while unseen values get new ones.
    fn test_bytes_dictionary_builder_with_existing_dictionary<T>(
        dictionary: GenericByteArray<T>,
        values: Vec<&T::Native>,
    ) where
        T: ByteArrayType,
        <T as ByteArrayType>::Native: PartialEq,
        <T as ByteArrayType>::Native: AsRef<<T as ByteArrayType>::Native>,
    {
        let mut builder =
            GenericByteDictionaryBuilder::<Int8Type, T>::new_with_dictionary(6, &dictionary)
                .unwrap();
        builder.append(values[0]).unwrap();
        builder.append_null();
        builder.append(values[1]).unwrap();
        builder.append(values[1]).unwrap();
        builder.append(values[0]).unwrap();
        builder.append(values[2]).unwrap();
        let array = builder.finish();

        let expected_keys =
            Int8Array::from(vec![Some(2), None, Some(1), Some(1), Some(2), Some(3)]);
        assert_eq!(array.keys(), &expected_keys);

        let dict_values = array.values();
        let typed = dict_values
            .as_any()
            .downcast_ref::<GenericByteArray<T>>()
            .unwrap();

        // Position 0 is the null entry of the seeded dictionary.
        assert!(!typed.is_valid(0));
        assert_eq!(typed.value(1), values[1]);
        assert_eq!(typed.value(2), values[0]);
        assert_eq!(typed.value(3), values[2]);
    }

    #[test]
    fn test_string_dictionary_builder_with_existing_dictionary() {
        test_bytes_dictionary_builder_with_existing_dictionary::<GenericStringType<i32>>(
            StringArray::from(vec![None, Some("def"), Some("abc")]),
            vec!["abc", "def", "ghi"],
        );
    }

    #[test]
    fn test_binary_dictionary_builder_with_existing_dictionary() {
        let values: Vec<Option<&[u8]>> = vec![None, Some(b"def"), Some(b"abc")];
        test_bytes_dictionary_builder_with_existing_dictionary::<GenericBinaryType<i32>>(
            BinaryArray::from(values),
            vec![b"abc", b"def", b"ghi"],
        );
    }

    /// Seeds the builder with a dictionary holding only a null entry and checks
    /// that real values are assigned the positions after it.
    fn test_bytes_dictionary_builder_with_reserved_null_value<T>(
        dictionary: GenericByteArray<T>,
        values: Vec<&T::Native>,
    ) where
        T: ByteArrayType,
        <T as ByteArrayType>::Native: PartialEq,
        <T as ByteArrayType>::Native: AsRef<<T as ByteArrayType>::Native>,
    {
        let mut builder =
            GenericByteDictionaryBuilder::<Int16Type, T>::new_with_dictionary(4, &dictionary)
                .unwrap();
        builder.append(values[0]).unwrap();
        builder.append_null();
        builder.append(values[1]).unwrap();
        builder.append(values[0]).unwrap();
        let array = builder.finish();

        assert!(array.is_null(1));
        assert!(!array.is_valid(1));

        let keys = array.keys();

        assert_eq!(keys.value(0), 1);
        assert!(keys.is_null(1));
        // NOTE(review): the raw key under a null slot is expected to read as zero;
        // presumably this relies on zero-initialised buffers - confirm.
        assert_eq!(keys.value(1), 0);
        assert_eq!(keys.value(2), 2);
        assert_eq!(keys.value(3), 1);
    }

    #[test]
    fn test_string_dictionary_builder_with_reserved_null_value() {
        let v: Vec<Option<&str>> = vec![None];
        test_bytes_dictionary_builder_with_reserved_null_value::<GenericStringType<i32>>(
            StringArray::from(v),
            vec!["abc", "def"],
        );
    }

    #[test]
    fn test_binary_dictionary_builder_with_reserved_null_value() {
        let values: Vec<Option<&[u8]>> = vec![None];
        test_bytes_dictionary_builder_with_reserved_null_value::<GenericBinaryType<i32>>(
            BinaryArray::from(values),
            vec![b"abc", b"def"],
        );
    }

    /// `extend` must deduplicate across calls, not just within one call.
    #[test]
    fn test_extend() {
        let mut builder = GenericByteDictionaryBuilder::<Int32Type, Utf8Type>::new();
        builder.extend(["a", "b", "c", "a", "b", "c"].into_iter().map(Some));
        builder.extend(["c", "d", "a"].into_iter().map(Some));
        let dict = builder.finish();
        assert_eq!(dict.keys().values(), &[0, 1, 2, 0, 1, 2, 2, 3, 0]);
        assert_eq!(dict.values().len(), 4);
    }
}