Skip to main content

lance_encoding/encodings/logical/primitive/
dict.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, sync::Arc};
5
/// Width, in bits, of each value in a 128-bit FixedWidth dictionary (the legacy default).
pub const DICT_FIXED_WIDTH_BITS_PER_VALUE: u64 = u128::BITS as u64;
/// Width, in bits, of each dictionary index; indices are always stored as `i32`.
pub const DICT_INDICES_BITS_PER_VALUE: u64 = i32::BITS as u64;
10
11use arrow_array::{
12    Array, DictionaryArray, PrimitiveArray, UInt64Array,
13    cast::AsArray,
14    types::{
15        ArrowDictionaryKeyType, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type,
16        UInt32Type, UInt64Type,
17    },
18};
19use arrow_buffer::ArrowNativeType;
20use arrow_schema::DataType;
21use arrow_select::take::TakeOptions;
22use lance_core::{Error, Result, error::LanceOptionExt, utils::hash::U8SliceKey};
23
24use crate::{
25    buffer::LanceBuffer,
26    data::{BlockInfo, DataBlock, FixedWidthDataBlock, VariableWidthBlock},
27    statistics::{ComputeStat, GetStat, Stat},
28};
29
30// Helper function for normalize_dict_nulls
31fn normalize_dict_nulls_impl<K: ArrowDictionaryKeyType>(
32    array: Arc<dyn Array>,
33) -> Result<Arc<dyn Array>> {
34    // TODO: Fast path when there is only one null index? (common case)
35
36    let dict_array = array.as_dictionary_opt::<K>().expect_ok()?;
37
38    if dict_array.values().null_count() == 0 {
39        return Ok(array);
40    }
41
42    let mut mapping = vec![None; dict_array.values().len()];
43    let mut skipped = 0;
44    let mut valid_indices = Vec::with_capacity(dict_array.values().len());
45    for (old_idx, is_valid) in dict_array.values().nulls().expect_ok()?.iter().enumerate() {
46        if is_valid {
47            // Should be safe since we are only decreasing K values (e.g. won't overflow u8 keys into u16)
48            mapping[old_idx] = Some(K::Native::from_usize(old_idx - skipped).expect_ok()?);
49            valid_indices.push(old_idx as u64);
50        } else {
51            skipped += 1;
52            mapping[old_idx] = None;
53        }
54    }
55
56    let mut keys_builder = PrimitiveArray::<K>::builder(dict_array.keys().len());
57    for key in dict_array.keys().iter() {
58        if let Some(key) = key {
59            if let Some(mapped) = mapping[key.to_usize().expect_ok()?] {
60                // Valid item
61                keys_builder.append_value(mapped);
62            } else {
63                // Null via values
64                keys_builder.append_null();
65            }
66        } else {
67            // Null via keys
68            keys_builder.append_null();
69        }
70    }
71    let keys = keys_builder.finish();
72
73    let valid_indices = UInt64Array::from(valid_indices);
74    let values = arrow_select::take::take(
75        dict_array.values(),
76        &valid_indices,
77        Some(TakeOptions {
78            check_bounds: false,
79        }),
80    )?;
81
82    Ok(Arc::new(DictionaryArray::new(keys, values)) as Arc<dyn Array>)
83}
84
85/// In Arrow a dictionary array can have nulls in two different places:
86/// 1. The keys can be null
87/// 2. The values can be null
88///
89/// We want to normalize this so that all nulls are in the keys.  This way we can store
90/// the nulls with the keys as rep-def values the same as any other array.
91pub fn normalize_dict_nulls(array: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
92    match array.data_type() {
93        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
94            DataType::UInt8 => normalize_dict_nulls_impl::<UInt8Type>(array),
95            DataType::UInt16 => normalize_dict_nulls_impl::<UInt16Type>(array),
96            DataType::UInt32 => normalize_dict_nulls_impl::<UInt32Type>(array),
97            DataType::UInt64 => normalize_dict_nulls_impl::<UInt64Type>(array),
98            DataType::Int8 => normalize_dict_nulls_impl::<Int8Type>(array),
99            DataType::Int16 => normalize_dict_nulls_impl::<Int16Type>(array),
100            DataType::Int32 => normalize_dict_nulls_impl::<Int32Type>(array),
101            DataType::Int64 => normalize_dict_nulls_impl::<Int64Type>(array),
102            _ => Err(Error::not_supported_source(
103                format!("Unsupported dictionary key type: {}", key_type).into(),
104            )),
105        },
106        _ => Err(Error::internal(format!(
107            "Data type is not a dictionary: {}",
108            array.data_type()
109        ))),
110    }
111}
112
113fn dict_encode_variable_width<T>(
114    variable_width_data_block: &VariableWidthBlock,
115    bits_per_offset: u8,
116    max_dict_entries: u32,
117    max_encoded_size: usize,
118) -> Option<(DataBlock, DataBlock)>
119where
120    T: ArrowNativeType,
121    usize: TryFrom<T>,
122{
123    use std::collections::hash_map::Entry;
124    let mut map = HashMap::new();
125    let offsets = variable_width_data_block
126        .offsets
127        .borrow_to_typed_slice::<T>();
128    let offsets = offsets.as_ref();
129
130    let max_len = variable_width_data_block
131        .get_stat(Stat::MaxLength)
132        .expect("VariableWidth DataBlock should have valid `Stat::MaxLength` statistics");
133    let max_len = max_len.as_primitive::<UInt64Type>().value(0);
134
135    let max_dict_data_len = variable_width_data_block.data.len();
136    let max_len: usize = max_len.try_into().unwrap_or(usize::MAX);
137    let dict_data_capacity = max_len
138        .saturating_mul(32)
139        .max(1024)
140        .min(max_dict_data_len)
141        .min(max_encoded_size);
142
143    let mut dictionary_buffer: Vec<u8> = Vec::with_capacity(dict_data_capacity);
144    let mut dictionary_offsets_buffer = vec![T::default()];
145    let mut curr_idx = 0;
146    let mut indices_buffer = Vec::with_capacity(variable_width_data_block.num_values as usize);
147    let bytes_per_offset = (bits_per_offset / 8) as usize;
148
149    for window in offsets.windows(2) {
150        let start = usize::try_from(window[0]).ok()?;
151        let end = usize::try_from(window[1]).ok()?;
152        if start > end || end > variable_width_data_block.data.len() {
153            return None;
154        }
155
156        let key = &variable_width_data_block.data[start..end];
157
158        let idx = match map.entry(U8SliceKey(key)) {
159            Entry::Occupied(entry) => *entry.get(),
160            Entry::Vacant(entry) => {
161                if max_dict_entries == 0 || curr_idx as u32 >= max_dict_entries {
162                    return None;
163                }
164                if curr_idx == i32::MAX {
165                    return None;
166                }
167                dictionary_buffer.extend_from_slice(key);
168                let dict_offset = T::from_usize(dictionary_buffer.len())?;
169                dictionary_offsets_buffer.push(dict_offset);
170                let idx = curr_idx;
171                entry.insert(idx);
172                curr_idx += 1;
173                idx
174            }
175        };
176
177        indices_buffer.push(idx);
178
179        let indices_bytes = indices_buffer
180            .len()
181            .saturating_mul(DICT_INDICES_BITS_PER_VALUE as usize / 8);
182        let offsets_bytes = dictionary_offsets_buffer
183            .len()
184            .saturating_mul(bytes_per_offset);
185        let encoded_size = dictionary_buffer
186            .len()
187            .saturating_add(indices_bytes)
188            .saturating_add(offsets_bytes);
189        if encoded_size > max_encoded_size {
190            return None;
191        }
192    }
193
194    let mut dictionary_data_block = DataBlock::VariableWidth(VariableWidthBlock {
195        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
196        offsets: LanceBuffer::reinterpret_vec(dictionary_offsets_buffer),
197        bits_per_offset,
198        num_values: curr_idx as u64,
199        block_info: BlockInfo::default(),
200    });
201    dictionary_data_block.compute_stat();
202
203    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
204        data: LanceBuffer::reinterpret_vec(indices_buffer),
205        bits_per_value: DICT_INDICES_BITS_PER_VALUE,
206        num_values: variable_width_data_block.num_values,
207        block_info: BlockInfo::default(),
208    });
209    indices_data_block.compute_stat();
210
211    Some((indices_data_block, dictionary_data_block))
212}
213
214/// Dictionary encodes a data block
215///
216/// Currently only supported for some common cases (string / binary / 64-bit / 128-bit)
217///
218/// Returns a block of indices (will always be a fixed width data block) and a block of dictionary
219pub fn dictionary_encode(
220    data_block: &DataBlock,
221    max_dict_entries: u32,
222    max_encoded_size: usize,
223) -> Option<(DataBlock, DataBlock)> {
224    match data_block {
225        DataBlock::FixedWidth(fixed_width_data_block) => {
226            use std::collections::hash_map::Entry;
227
228            let bytes_per_value = match fixed_width_data_block.bits_per_value {
229                64 => 8usize,
230                128 => 16usize,
231                _ => return None,
232            };
233
234            match fixed_width_data_block.bits_per_value {
235                64 => {
236                    let mut map = HashMap::new();
237                    let u64_slice = fixed_width_data_block.data.borrow_to_typed_slice::<u64>();
238                    let u64_slice = u64_slice.as_ref();
239                    let mut dictionary_buffer =
240                        Vec::with_capacity((fixed_width_data_block.num_values as usize).min(1024));
241                    let mut indices_buffer =
242                        Vec::with_capacity(fixed_width_data_block.num_values as usize);
243                    let mut curr_idx: i32 = 0;
244
245                    for &value in u64_slice.iter() {
246                        let idx = match map.entry(value) {
247                            Entry::Occupied(entry) => *entry.get(),
248                            Entry::Vacant(entry) => {
249                                if max_dict_entries == 0 || curr_idx as u32 >= max_dict_entries {
250                                    return None;
251                                }
252                                if curr_idx == i32::MAX {
253                                    return None;
254                                }
255                                dictionary_buffer.push(value);
256                                let idx = curr_idx;
257                                entry.insert(idx);
258                                curr_idx += 1;
259                                idx
260                            }
261                        };
262                        indices_buffer.push(idx);
263                        let dict_bytes = dictionary_buffer.len().saturating_mul(bytes_per_value);
264                        let indices_bytes = indices_buffer
265                            .len()
266                            .saturating_mul(DICT_INDICES_BITS_PER_VALUE as usize / 8);
267                        let encoded_size = dict_bytes.saturating_add(indices_bytes);
268                        if encoded_size > max_encoded_size {
269                            return None;
270                        }
271                    }
272
273                    let mut dictionary_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
274                        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
275                        bits_per_value: 64,
276                        num_values: curr_idx as u64,
277                        block_info: BlockInfo::default(),
278                    });
279                    dictionary_data_block.compute_stat();
280                    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
281                        data: LanceBuffer::reinterpret_vec(indices_buffer),
282                        bits_per_value: DICT_INDICES_BITS_PER_VALUE,
283                        num_values: fixed_width_data_block.num_values,
284                        block_info: BlockInfo::default(),
285                    });
286                    indices_data_block.compute_stat();
287
288                    Some((indices_data_block, dictionary_data_block))
289                }
290                128 => {
291                    // TODO: a follow up PR to support `FixedWidth DataBlock with bits_per_value == 256`.
292                    let mut map = HashMap::new();
293                    let u128_slice = fixed_width_data_block.data.borrow_to_typed_slice::<u128>();
294                    let u128_slice = u128_slice.as_ref();
295                    let mut dictionary_buffer =
296                        Vec::with_capacity((fixed_width_data_block.num_values as usize).min(1024));
297                    let mut indices_buffer =
298                        Vec::with_capacity(fixed_width_data_block.num_values as usize);
299                    let mut curr_idx: i32 = 0;
300
301                    for &value in u128_slice.iter() {
302                        let idx = match map.entry(value) {
303                            Entry::Occupied(entry) => *entry.get(),
304                            Entry::Vacant(entry) => {
305                                if max_dict_entries == 0 || curr_idx as u32 >= max_dict_entries {
306                                    return None;
307                                }
308                                if curr_idx == i32::MAX {
309                                    return None;
310                                }
311                                dictionary_buffer.push(value);
312                                let idx = curr_idx;
313                                entry.insert(idx);
314                                curr_idx += 1;
315                                idx
316                            }
317                        };
318                        indices_buffer.push(idx);
319                        let dict_bytes = dictionary_buffer.len().saturating_mul(bytes_per_value);
320                        let indices_bytes = indices_buffer
321                            .len()
322                            .saturating_mul(DICT_INDICES_BITS_PER_VALUE as usize / 8);
323                        let encoded_size = dict_bytes.saturating_add(indices_bytes);
324                        if encoded_size > max_encoded_size {
325                            return None;
326                        }
327                    }
328
329                    let mut dictionary_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
330                        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
331                        bits_per_value: DICT_FIXED_WIDTH_BITS_PER_VALUE,
332                        num_values: curr_idx as u64,
333                        block_info: BlockInfo::default(),
334                    });
335                    dictionary_data_block.compute_stat();
336                    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
337                        data: LanceBuffer::reinterpret_vec(indices_buffer),
338                        bits_per_value: DICT_INDICES_BITS_PER_VALUE,
339                        num_values: fixed_width_data_block.num_values,
340                        block_info: BlockInfo::default(),
341                    });
342                    indices_data_block.compute_stat();
343
344                    Some((indices_data_block, dictionary_data_block))
345                }
346                _ => None,
347            }
348        }
349        DataBlock::VariableWidth(variable_width_data_block) => {
350            match variable_width_data_block.bits_per_offset {
351                32 => dict_encode_variable_width::<u32>(
352                    variable_width_data_block,
353                    32,
354                    max_dict_entries,
355                    max_encoded_size,
356                ),
357                64 => dict_encode_variable_width::<u64>(
358                    variable_width_data_block,
359                    64,
360                    max_dict_entries,
361                    max_encoded_size,
362                ),
363                _ => None,
364            }
365        }
366        _ => None,
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        buffer::LanceBuffer,
        data::{BlockInfo, FixedWidthDataBlock},
    };
    use arrow_array::{Array, StringArray};
    use std::sync::Arc;

    /// Wraps `values` in a 128-bit fixed-width block and computes its statistics.
    fn u128_block(values: Vec<u128>) -> DataBlock {
        let num_values = values.len() as u64;
        let mut block = DataBlock::FixedWidth(FixedWidthDataBlock {
            bits_per_value: DICT_FIXED_WIDTH_BITS_PER_VALUE,
            data: LanceBuffer::reinterpret_vec(values),
            num_values,
            block_info: BlockInfo::default(),
        });
        block.compute_stat();
        block
    }

    /// Builds a string block from `values` (`from_array` already computes stats).
    fn string_block(values: Vec<String>) -> DataBlock {
        DataBlock::from_array(Arc::new(StringArray::from(values)) as Arc<dyn Array>)
    }

    /// Size of the unencoded block, used as the "dictionary must not be larger" budget.
    fn unencoded_size(block: &DataBlock) -> usize {
        usize::try_from(block.data_size()).unwrap_or(usize::MAX)
    }

    #[test]
    fn test_dictionary_encode_abort_fixed_width() {
        // Every u128 is unique, so dictionary + indices is larger than the original data
        let block = u128_block((0..120u64).map(|i| i as u128).collect());

        let encoded = dictionary_encode(&block, 1000, unencoded_size(&block));
        assert!(
            encoded.is_none(),
            "Dictionary encoding should abort for high cardinality u128 data"
        );
    }

    #[test]
    fn test_dictionary_encode_success_fixed_width() {
        // Only a handful of distinct u128 values, so dictionary encoding pays off
        let num_values = 120u64;
        let cardinality = 3u64;
        let block = u128_block(
            (0..num_values)
                .map(|i| (i % cardinality) as u128)
                .collect(),
        );

        let (indices, dictionary) = dictionary_encode(&block, 1000, unencoded_size(&block))
            .expect("Dictionary encoding should succeed for low cardinality u128 data");

        let DataBlock::FixedWidth(indices_block) = indices else {
            panic!("Expected FixedWidth indices block");
        };
        assert_eq!(indices_block.num_values, num_values);
        assert_eq!(indices_block.bits_per_value, DICT_INDICES_BITS_PER_VALUE);

        let DataBlock::FixedWidth(dict_block) = dictionary else {
            panic!("Expected FixedWidth dictionary block");
        };
        assert_eq!(dict_block.num_values, cardinality);
        assert_eq!(dict_block.bits_per_value, DICT_FIXED_WIDTH_BITS_PER_VALUE);
    }

    #[test]
    fn test_dictionary_encode_abort_variable_width() {
        // 120 distinct strings against an entry limit of 10
        let block = string_block(
            (0..120u64)
                .map(|i| format!("unique_value_{:04}", i))
                .collect(),
        );

        let encoded = dictionary_encode(&block, 10, unencoded_size(&block));
        assert!(
            encoded.is_none(),
            "Dictionary encoding should abort for high cardinality string data"
        );
    }

    #[test]
    fn test_dictionary_encode_success_low_cardinality() {
        // Only a handful of distinct strings, so dictionary encoding pays off
        let num_values = 120u64;
        let cardinality = 3u64;
        let block = string_block(
            (0..num_values)
                .map(|i| format!("value_{}", i % cardinality))
                .collect(),
        );

        let (indices, dictionary) = dictionary_encode(&block, 100, unencoded_size(&block))
            .expect("Dictionary encoding should succeed for low cardinality data");

        let DataBlock::FixedWidth(indices_block) = indices else {
            panic!("Expected FixedWidth indices block");
        };
        assert_eq!(indices_block.num_values, num_values);
        assert_eq!(indices_block.bits_per_value, DICT_INDICES_BITS_PER_VALUE);

        let DataBlock::VariableWidth(dict_block) = dictionary else {
            panic!("Expected VariableWidth dictionary block");
        };
        assert_eq!(dict_block.num_values, cardinality);
    }

    #[test]
    fn test_dictionary_encode_invalid_offset_width_returns_none() {
        let block = string_block(["a", "b", "c", "a"].into_iter().map(String::from).collect());

        // Claim an offset width that the encoder does not handle
        let invalid_block = match block {
            DataBlock::VariableWidth(mut var) => {
                var.bits_per_offset = 16;
                DataBlock::VariableWidth(var)
            }
            other => panic!("Expected VariableWidth data block, got {:?}", other),
        };

        let budget = unencoded_size(&invalid_block);
        assert!(dictionary_encode(&invalid_block, 100, budget).is_none());
    }

    #[test]
    fn test_dictionary_encode_respects_size_limit() {
        let num_values = 10_000u64;
        let cardinality = 50u64;
        let block = string_block(
            (0..num_values)
                .map(|i| format!("value_{:08}", i % cardinality))
                .collect(),
        );

        let full_size = unencoded_size(&block);
        let too_small_limit = full_size / 10;
        assert!(dictionary_encode(&block, 1000, too_small_limit).is_none());
        assert!(dictionary_encode(&block, 1000, full_size).is_some());
    }

    #[test]
    fn test_dictionary_encode_respects_entry_limit() {
        let num_values = 10_000u64;
        let cardinality = 200u64;
        let block = string_block(
            (0..num_values)
                .map(|i| format!("value_{:08}", i % cardinality))
                .collect(),
        );

        let budget = unencoded_size(&block);
        assert!(dictionary_encode(&block, 10, budget).is_none());
        assert!(dictionary_encode(&block, 500, budget).is_some());
    }
}