Skip to main content

lance_encoding/encodings/logical/primitive/
dict.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{collections::HashMap, sync::Arc};
5
/// Bits per value for FixedWidth dictionary values (legacy default for 128-bit values)
pub const DICT_FIXED_WIDTH_BITS_PER_VALUE: u64 = 128;
/// Bits per index for dictionary indices stored as `i32`.
///
/// Fixed-width inputs and variable-width inputs with 32-bit offsets produce `i32` indices.
/// Variable-width inputs with 64-bit offsets produce `i64` (64-bit) indices instead.
pub const DICT_INDICES_BITS_PER_VALUE: u64 = 32;
10
11use arrow_array::{
12    cast::AsArray,
13    types::{
14        ArrowDictionaryKeyType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
15        UInt64Type, UInt8Type,
16    },
17    Array, DictionaryArray, PrimitiveArray, UInt64Array,
18};
19use arrow_buffer::ArrowNativeType;
20use arrow_schema::DataType;
21use arrow_select::take::TakeOptions;
22use lance_core::{error::LanceOptionExt, utils::hash::U8SliceKey, Error, Result};
23use snafu::location;
24
25use crate::{
26    buffer::LanceBuffer,
27    data::{BlockInfo, DataBlock, FixedWidthDataBlock, VariableWidthBlock},
28    statistics::{ComputeStat, GetStat, Stat},
29};
30
31// Helper function for normalize_dict_nulls
32fn normalize_dict_nulls_impl<K: ArrowDictionaryKeyType>(
33    array: Arc<dyn Array>,
34) -> Result<Arc<dyn Array>> {
35    // TODO: Fast path when there is only one null index? (common case)
36
37    let dict_array = array.as_dictionary_opt::<K>().expect_ok()?;
38
39    if dict_array.values().null_count() == 0 {
40        return Ok(array);
41    }
42
43    let mut mapping = vec![None; dict_array.values().len()];
44    let mut skipped = 0;
45    let mut valid_indices = Vec::with_capacity(dict_array.values().len());
46    for (old_idx, is_valid) in dict_array.values().nulls().expect_ok()?.iter().enumerate() {
47        if is_valid {
48            // Should be safe since we are only decreasing K values (e.g. won't overflow u8 keys into u16)
49            mapping[old_idx] = Some(K::Native::from_usize(old_idx - skipped).expect_ok()?);
50            valid_indices.push(old_idx as u64);
51        } else {
52            skipped += 1;
53            mapping[old_idx] = None;
54        }
55    }
56
57    let mut keys_builder = PrimitiveArray::<K>::builder(dict_array.keys().len());
58    for key in dict_array.keys().iter() {
59        if let Some(key) = key {
60            if let Some(mapped) = mapping[key.to_usize().expect_ok()?] {
61                // Valid item
62                keys_builder.append_value(mapped);
63            } else {
64                // Null via values
65                keys_builder.append_null();
66            }
67        } else {
68            // Null via keys
69            keys_builder.append_null();
70        }
71    }
72    let keys = keys_builder.finish();
73
74    let valid_indices = UInt64Array::from(valid_indices);
75    let values = arrow_select::take::take(
76        dict_array.values(),
77        &valid_indices,
78        Some(TakeOptions {
79            check_bounds: false,
80        }),
81    )?;
82
83    Ok(Arc::new(DictionaryArray::new(keys, values)) as Arc<dyn Array>)
84}
85
86/// In Arrow a dictionary array can have nulls in two different places:
87/// 1. The keys can be null
88/// 2. The values can be null
89///
90/// We want to normalize this so that all nulls are in the keys.  This way we can store
91/// the nulls with the keys as rep-def values the same as any other array.
92pub fn normalize_dict_nulls(array: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
93    match array.data_type() {
94        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
95            DataType::UInt8 => normalize_dict_nulls_impl::<UInt8Type>(array),
96            DataType::UInt16 => normalize_dict_nulls_impl::<UInt16Type>(array),
97            DataType::UInt32 => normalize_dict_nulls_impl::<UInt32Type>(array),
98            DataType::UInt64 => normalize_dict_nulls_impl::<UInt64Type>(array),
99            DataType::Int8 => normalize_dict_nulls_impl::<Int8Type>(array),
100            DataType::Int16 => normalize_dict_nulls_impl::<Int16Type>(array),
101            DataType::Int32 => normalize_dict_nulls_impl::<Int32Type>(array),
102            DataType::Int64 => normalize_dict_nulls_impl::<Int64Type>(array),
103            _ => Err(Error::NotSupported {
104                source: format!("Unsupported dictionary key type: {}", key_type).into(),
105                location: location!(),
106            }),
107        },
108        _ => Err(Error::Internal {
109            message: format!("Data type is not a dictionary: {}", array.data_type()),
110            location: location!(),
111        }),
112    }
113}
114
115/// Dictionary encodes a data block
116///
117/// Currently only supported for some common cases (string / binary / 64-bit / 128-bit)
118///
119/// Returns a block of indices (will always be a fixed width data block) and a block of dictionary
120pub fn dictionary_encode(mut data_block: DataBlock) -> (DataBlock, DataBlock) {
121    let cardinality = data_block
122        .get_stat(Stat::Cardinality)
123        .unwrap()
124        .as_primitive::<UInt64Type>()
125        .value(0);
126    match data_block {
127        DataBlock::FixedWidth(ref mut fixed_width_data_block) => {
128            match fixed_width_data_block.bits_per_value {
129                64 => {
130                    let mut map = HashMap::new();
131                    let u64_slice = fixed_width_data_block.data.borrow_to_typed_slice::<u64>();
132                    let u64_slice = u64_slice.as_ref();
133                    let mut dictionary_buffer = Vec::with_capacity(cardinality as usize);
134                    let mut indices_buffer =
135                        Vec::with_capacity(fixed_width_data_block.num_values as usize);
136                    let mut curr_idx: i32 = 0;
137                    u64_slice.iter().for_each(|&value| {
138                        let idx = *map.entry(value).or_insert_with(|| {
139                            dictionary_buffer.push(value);
140                            curr_idx += 1;
141                            curr_idx - 1
142                        });
143                        indices_buffer.push(idx);
144                    });
145                    let mut dictionary_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
146                        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
147                        bits_per_value: 64,
148                        num_values: curr_idx as u64,
149                        block_info: BlockInfo::default(),
150                    });
151                    dictionary_data_block.compute_stat();
152                    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
153                        data: LanceBuffer::reinterpret_vec(indices_buffer),
154                        bits_per_value: DICT_INDICES_BITS_PER_VALUE,
155                        num_values: fixed_width_data_block.num_values,
156                        block_info: BlockInfo::default(),
157                    });
158                    indices_data_block.compute_stat();
159                    (indices_data_block, dictionary_data_block)
160                }
161                128 => {
162                    // TODO: a follow up PR to support `FixedWidth DataBlock with bits_per_value == 256`.
163                    let mut map = HashMap::new();
164                    let u128_slice = fixed_width_data_block.data.borrow_to_typed_slice::<u128>();
165                    let u128_slice = u128_slice.as_ref();
166                    let mut dictionary_buffer = Vec::with_capacity(cardinality as usize);
167                    let mut indices_buffer =
168                        Vec::with_capacity(fixed_width_data_block.num_values as usize);
169                    let mut curr_idx: i32 = 0;
170                    u128_slice.iter().for_each(|&value| {
171                        let idx = *map.entry(value).or_insert_with(|| {
172                            dictionary_buffer.push(value);
173                            curr_idx += 1;
174                            curr_idx - 1
175                        });
176                        indices_buffer.push(idx);
177                    });
178                    let mut dictionary_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
179                        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
180                        bits_per_value: DICT_FIXED_WIDTH_BITS_PER_VALUE,
181                        num_values: curr_idx as u64,
182                        block_info: BlockInfo::default(),
183                    });
184                    dictionary_data_block.compute_stat();
185                    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
186                        data: LanceBuffer::reinterpret_vec(indices_buffer),
187                        bits_per_value: DICT_INDICES_BITS_PER_VALUE,
188                        num_values: fixed_width_data_block.num_values,
189                        block_info: BlockInfo::default(),
190                    });
191                    indices_data_block.compute_stat();
192                    (indices_data_block, dictionary_data_block)
193                }
194                other => unreachable!(
195                    "dictionary encode called with FixedWidth DataBlock bits_per_value={}",
196                    other
197                ),
198            }
199        }
200        DataBlock::VariableWidth(ref mut variable_width_data_block) => {
201            match variable_width_data_block.bits_per_offset {
202                32 => {
203                    let mut map = HashMap::new();
204                    let offsets = variable_width_data_block
205                        .offsets
206                        .borrow_to_typed_slice::<u32>();
207                    let offsets = offsets.as_ref();
208
209                    let max_len = variable_width_data_block.get_stat(Stat::MaxLength).expect(
210                        "VariableWidth DataBlock should have valid `Stat::DataSize` statistics",
211                    );
212                    let max_len = max_len.as_primitive::<UInt64Type>().value(0);
213
214                    let mut dictionary_buffer: Vec<u8> =
215                        Vec::with_capacity((max_len * cardinality) as usize);
216                    let mut dictionary_offsets_buffer = vec![0];
217                    let mut curr_idx = 0;
218                    let mut indices_buffer =
219                        Vec::with_capacity(variable_width_data_block.num_values as usize);
220
221                    offsets
222                        .iter()
223                        .zip(offsets.iter().skip(1))
224                        .for_each(|(&start, &end)| {
225                            let key = &variable_width_data_block.data[start as usize..end as usize];
226                            let idx: i32 = *map.entry(U8SliceKey(key)).or_insert_with(|| {
227                                dictionary_buffer.extend_from_slice(key);
228                                dictionary_offsets_buffer.push(dictionary_buffer.len() as u32);
229                                curr_idx += 1;
230                                curr_idx - 1
231                            });
232                            indices_buffer.push(idx);
233                        });
234
235                    let dictionary_data_block = DataBlock::VariableWidth(VariableWidthBlock {
236                        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
237                        offsets: LanceBuffer::reinterpret_vec(dictionary_offsets_buffer),
238                        bits_per_offset: 32,
239                        num_values: curr_idx as u64,
240                        block_info: BlockInfo::default(),
241                    });
242
243                    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
244                        data: LanceBuffer::reinterpret_vec(indices_buffer),
245                        bits_per_value: 32,
246                        num_values: variable_width_data_block.num_values,
247                        block_info: BlockInfo::default(),
248                    });
249                    // Todo: if we decide to do eager statistics computing, wrap statistics computing
250                    // in DataBlock constructor.
251                    indices_data_block.compute_stat();
252
253                    (indices_data_block, dictionary_data_block)
254                }
255                64 => {
256                    let mut map = HashMap::new();
257                    let offsets = variable_width_data_block
258                        .offsets
259                        .borrow_to_typed_slice::<u64>();
260                    let offsets = offsets.as_ref();
261
262                    let max_len = variable_width_data_block.get_stat(Stat::MaxLength).expect(
263                        "VariableWidth DataBlock should have valid `Stat::DataSize` statistics",
264                    );
265                    let max_len = max_len.as_primitive::<UInt64Type>().value(0);
266
267                    let mut dictionary_buffer: Vec<u8> =
268                        Vec::with_capacity((max_len * cardinality) as usize);
269                    let mut dictionary_offsets_buffer = vec![0];
270                    let mut curr_idx = 0;
271                    let mut indices_buffer =
272                        Vec::with_capacity(variable_width_data_block.num_values as usize);
273
274                    offsets
275                        .iter()
276                        .zip(offsets.iter().skip(1))
277                        .for_each(|(&start, &end)| {
278                            let key = &variable_width_data_block.data[start as usize..end as usize];
279                            let idx: i64 = *map.entry(U8SliceKey(key)).or_insert_with(|| {
280                                dictionary_buffer.extend_from_slice(key);
281                                dictionary_offsets_buffer.push(dictionary_buffer.len() as u64);
282                                curr_idx += 1;
283                                curr_idx - 1
284                            });
285                            indices_buffer.push(idx);
286                        });
287
288                    let dictionary_data_block = DataBlock::VariableWidth(VariableWidthBlock {
289                        data: LanceBuffer::reinterpret_vec(dictionary_buffer),
290                        offsets: LanceBuffer::reinterpret_vec(dictionary_offsets_buffer),
291                        bits_per_offset: 64,
292                        num_values: curr_idx as u64,
293                        block_info: BlockInfo::default(),
294                    });
295
296                    let mut indices_data_block = DataBlock::FixedWidth(FixedWidthDataBlock {
297                        data: LanceBuffer::reinterpret_vec(indices_buffer),
298                        bits_per_value: 64,
299                        num_values: variable_width_data_block.num_values,
300                        block_info: BlockInfo::default(),
301                    });
302                    // Todo: if we decide to do eager statistics computing, wrap statistics computing
303                    // in DataBlock constructor.
304                    indices_data_block.compute_stat();
305
306                    (indices_data_block, dictionary_data_block)
307                }
308                _ => {
309                    unreachable!()
310                }
311            }
312        }
313        _ => {
314            unreachable!("dictionary encode called with data block {:?}", data_block)
315        }
316    }
317}