// lance_encoding/previous/encodings/physical/dictionary.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5use std::vec;
6
7use arrow_array::builder::{ArrayBuilder, StringBuilder};
8use arrow_array::cast::AsArray;
9use arrow_array::types::UInt8Type;
10use arrow_array::{
11    Array, ArrayRef, DictionaryArray, StringArray, UInt8Array, make_array, new_null_array,
12};
13use arrow_schema::DataType;
14use futures::{FutureExt, future::BoxFuture};
15use lance_arrow::DataTypeExt;
16use lance_core::{Error, Result};
17use std::collections::HashMap;
18
19use crate::buffer::LanceBuffer;
20use crate::data::{
21    BlockInfo, DataBlock, DictionaryDataBlock, FixedWidthDataBlock, NullableDataBlock,
22    VariableWidthBlock,
23};
24use crate::format::ProtobufUtils;
25use crate::previous::decoder::LogicalPageDecoder;
26use crate::previous::encodings::logical::primitive::PrimitiveFieldDecoder;
27use crate::{
28    EncodingsIo,
29    decoder::{PageScheduler, PrimitivePageDecoder},
30    previous::encoder::{ArrayEncoder, EncodedArray},
31};
32
/// Schedules decoding of a dictionary-encoded page by scheduling the index
/// page and the item (dictionary value) page separately.
#[derive(Debug)]
pub struct DictionaryPageScheduler {
    // Scheduler for the page of dictionary indices
    indices_scheduler: Arc<dyn PageScheduler>,
    // Scheduler for the page of dictionary items (the values)
    items_scheduler: Arc<dyn PageScheduler>,
    // The number of items in the dictionary
    num_dictionary_items: u32,
    // If true, decode the dictionary items.  If false, leave them dictionary encoded (e.g. the
    // output type is probably a dictionary type)
    should_decode_dict: bool,
}
43
impl DictionaryPageScheduler {
    /// Creates a new scheduler.
    ///
    /// * `indices_scheduler` - schedules the dictionary indices page
    /// * `items_scheduler` - schedules the dictionary items page
    /// * `num_dictionary_items` - total number of values in the dictionary
    /// * `should_decode_dict` - if true the dictionary is materialized during
    ///   decode; if false the data is returned still dictionary encoded
    pub fn new(
        indices_scheduler: Arc<dyn PageScheduler>,
        items_scheduler: Arc<dyn PageScheduler>,
        num_dictionary_items: u32,
        should_decode_dict: bool,
    ) -> Self {
        Self {
            indices_scheduler,
            items_scheduler,
            num_dictionary_items,
            should_decode_dict,
        }
    }
}
59
60impl PageScheduler for DictionaryPageScheduler {
61    fn schedule_ranges(
62        &self,
63        ranges: &[std::ops::Range<u64>],
64        scheduler: &Arc<dyn EncodingsIo>,
65        top_level_row: u64,
66    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
67        // We want to decode indices and items
68        // e.g. indices [0, 1, 2, 0, 1, 0]
69        // items (dictionary) ["abcd", "hello", "apple"]
70        // This will map to ["abcd", "hello", "apple", "abcd", "hello", "abcd"]
71        // We decode all the items during scheduling itself
72        // These are used to rebuild the string later
73
74        // Schedule indices for decoding
75        let indices_page_decoder =
76            self.indices_scheduler
77                .schedule_ranges(ranges, scheduler, top_level_row);
78
79        // Schedule items for decoding
80        let items_range = 0..(self.num_dictionary_items as u64);
81        let items_page_decoder = self.items_scheduler.schedule_ranges(
82            std::slice::from_ref(&items_range),
83            scheduler,
84            top_level_row,
85        );
86
87        let copy_size = self.num_dictionary_items as u64;
88
89        if self.should_decode_dict {
90            tokio::spawn(async move {
91                let items_decoder: Arc<dyn PrimitivePageDecoder> =
92                    Arc::from(items_page_decoder.await?);
93
94                let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data(
95                    items_decoder.clone(),
96                    DataType::Utf8,
97                    copy_size,
98                    false,
99                );
100
101                // Decode all items
102                let drained_task = primitive_wrapper.drain(copy_size)?;
103                let items_decode_task = drained_task.task;
104                let decoded_dict = items_decode_task.decode()?;
105
106                let indices_decoder: Box<dyn PrimitivePageDecoder> = indices_page_decoder.await?;
107
108                Ok(Box::new(DictionaryPageDecoder {
109                    decoded_dict,
110                    indices_decoder,
111                }) as Box<dyn PrimitivePageDecoder>)
112            })
113            .map(|join_handle| join_handle.unwrap())
114            .boxed()
115        } else {
116            let num_dictionary_items = self.num_dictionary_items;
117            tokio::spawn(async move {
118                let items_decoder: Arc<dyn PrimitivePageDecoder> =
119                    Arc::from(items_page_decoder.await?);
120
121                let decoded_dict = items_decoder
122                    .decode(0, num_dictionary_items as u64)?
123                    .clone();
124
125                let indices_decoder = indices_page_decoder.await?;
126
127                Ok(Box::new(DirectDictionaryPageDecoder {
128                    decoded_dict,
129                    indices_decoder,
130                }) as Box<dyn PrimitivePageDecoder>)
131            })
132            .map(|join_handle| join_handle.unwrap())
133            .boxed()
134        }
135    }
136}
137
/// Decoder that returns data still dictionary encoded: a fresh index block
/// paired with a clone of the already-decoded dictionary.
struct DirectDictionaryPageDecoder {
    // The fully decoded dictionary values, cloned into each decoded block
    decoded_dict: DataBlock,
    // Decoder for the dictionary indices
    indices_decoder: Box<dyn PrimitivePageDecoder>,
}
142
143impl PrimitivePageDecoder for DirectDictionaryPageDecoder {
144    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
145        let indices = self
146            .indices_decoder
147            .decode(rows_to_skip, num_rows)?
148            .as_fixed_width()
149            .unwrap();
150        let dict = self.decoded_dict.clone();
151        Ok(DataBlock::Dictionary(DictionaryDataBlock {
152            indices,
153            dictionary: Box::new(dict),
154        }))
155    }
156}
157
/// Decoder that fully materializes dictionary-encoded data back into plain
/// string data during decode.
struct DictionaryPageDecoder {
    // The fully decoded dictionary values as an arrow array
    decoded_dict: Arc<dyn Array>,
    // Decoder for the dictionary indices
    indices_decoder: Box<dyn PrimitivePageDecoder>,
}
162
163impl PrimitivePageDecoder for DictionaryPageDecoder {
164    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
165        // Decode the indices
166        let indices_data = self.indices_decoder.decode(rows_to_skip, num_rows)?;
167
168        let indices_array = make_array(indices_data.into_arrow(DataType::UInt8, false)?);
169        let indices_array = indices_array.as_primitive::<UInt8Type>();
170
171        let dictionary = self.decoded_dict.clone();
172
173        let adjusted_indices: UInt8Array = indices_array
174            .iter()
175            .map(|x| match x {
176                Some(0) => None,
177                Some(x) => Some(x - 1),
178                None => None,
179            })
180            .collect();
181
182        // Build dictionary array using indices and items
183        let dict_array =
184            DictionaryArray::<UInt8Type>::try_new(adjusted_indices, dictionary).unwrap();
185        let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap();
186        let string_array = string_array.as_any().downcast_ref::<StringArray>().unwrap();
187
188        let null_buffer = string_array.nulls().map(|n| n.buffer().clone());
189        let offsets_buffer = string_array.offsets().inner().inner().clone();
190        let bytes_buffer = string_array.values().clone();
191
192        let string_data = DataBlock::VariableWidth(VariableWidthBlock {
193            bits_per_offset: 32,
194            data: LanceBuffer::from(bytes_buffer),
195            offsets: LanceBuffer::from(offsets_buffer),
196            num_values: num_rows,
197            block_info: BlockInfo::new(),
198        });
199        if let Some(nulls) = null_buffer {
200            Ok(DataBlock::Nullable(NullableDataBlock {
201                data: Box::new(string_data),
202                nulls: LanceBuffer::from(nulls),
203                block_info: BlockInfo::new(),
204            }))
205        } else {
206            Ok(string_data)
207        }
208    }
209}
210
/// An encoder for data that is already dictionary encoded.  Stores the
/// data as a dictionary encoding.
#[derive(Debug)]
pub struct AlreadyDictionaryEncoder {
    // Encoder used for the dictionary indices
    indices_encoder: Box<dyn ArrayEncoder>,
    // Encoder used for the dictionary items (the values)
    items_encoder: Box<dyn ArrayEncoder>,
}
218
impl AlreadyDictionaryEncoder {
    /// Creates a new encoder from the encoders used for the indices and the
    /// dictionary items respectively.
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}
230
impl ArrayEncoder for AlreadyDictionaryEncoder {
    /// Encodes data that is already dictionary encoded by encoding the indices
    /// and the dictionary values separately with the wrapped encoders.
    ///
    /// # Panics
    ///
    /// Panics if `data_type` is not a dictionary type, or if `data` is neither
    /// a `Dictionary` nor an `AllNull` block.  These are caller invariants,
    /// not recoverable input errors.
    fn encode(
        &self,
        data: DataBlock,
        data_type: &DataType,
        buffer_index: &mut u32,
    ) -> Result<EncodedArray> {
        let DataType::Dictionary(key_type, value_type) = data_type else {
            panic!("Expected dictionary type");
        };

        let dict_data = match data {
            DataBlock::Dictionary(dict_data) => dict_data,
            DataBlock::AllNull(all_null) => {
                // In 2.1 this won't happen, kind of annoying to materialize a bunch of nulls
                // Build all-zero indices (in the caller's key type) pointing at
                // a single null dictionary entry.
                let indices = UInt8Array::from(vec![0; all_null.num_values as usize]);
                let indices = arrow_cast::cast(&indices, key_type.as_ref()).unwrap();
                let indices = indices.into_data();
                let values = new_null_array(value_type, 1);
                DictionaryDataBlock {
                    indices: FixedWidthDataBlock {
                        bits_per_value: key_type.byte_width() as u64 * 8,
                        data: LanceBuffer::from(indices.buffers()[0].clone()),
                        num_values: all_null.num_values,
                        block_info: BlockInfo::new(),
                    },
                    dictionary: Box::new(DataBlock::from_array(values)),
                }
            }
            _ => panic!("Expected dictionary data"),
        };
        let num_dictionary_items = dict_data.dictionary.num_values() as u32;

        // Encode indices and items independently; both encodings are recorded
        // in the protobuf description below.
        let encoded_indices = self.indices_encoder.encode(
            DataBlock::FixedWidth(dict_data.indices),
            key_type,
            buffer_index,
        )?;
        let encoded_items =
            self.items_encoder
                .encode(*dict_data.dictionary, value_type, buffer_index)?;

        let encoded = DataBlock::Dictionary(DictionaryDataBlock {
            dictionary: Box::new(encoded_items.data),
            indices: encoded_indices.data.as_fixed_width().unwrap(),
        });

        let encoding = ProtobufUtils::dict_encoding(
            encoded_indices.encoding,
            encoded_items.encoding,
            num_dictionary_items,
        );

        Ok(EncodedArray {
            data: encoded,
            encoding,
        })
    }
}
290
/// An encoder that dictionary-encodes plain string data, storing the indices
/// and the dictionary items with separate inner encoders.
#[derive(Debug)]
pub struct DictionaryEncoder {
    // Encoder used for the dictionary indices
    indices_encoder: Box<dyn ArrayEncoder>,
    // Encoder used for the dictionary items (the values)
    items_encoder: Box<dyn ArrayEncoder>,
}
296
impl DictionaryEncoder {
    /// Creates a new encoder from the encoders used for the indices and the
    /// dictionary items respectively.
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}
308
309fn encode_dict_indices_and_items(string_array: &StringArray) -> (ArrayRef, ArrayRef) {
310    let mut arr_hashmap: HashMap<&str, u8> = HashMap::new();
311    // We start with a dict index of 1 because the value 0 is reserved for nulls
312    // The dict indices are adjusted by subtracting 1 later during decode
313    let mut curr_dict_index = 1;
314    let total_capacity = string_array.len();
315
316    let mut dict_indices = Vec::with_capacity(total_capacity);
317    let mut dict_builder = StringBuilder::new();
318
319    for i in 0..string_array.len() {
320        if !string_array.is_valid(i) {
321            // null value
322            dict_indices.push(0);
323            continue;
324        }
325
326        let st = string_array.value(i);
327
328        let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index);
329        dict_indices.push(hashmap_entry);
330
331        // if item didn't exist in the hashmap, add it to the dictionary
332        // and increment the dictionary index
333        if hashmap_entry == curr_dict_index {
334            dict_builder.append_value(st);
335            curr_dict_index += 1;
336        }
337    }
338
339    let array_dict_indices = Arc::new(UInt8Array::from(dict_indices)) as ArrayRef;
340
341    // If there is an empty dictionary:
342    // Either there is an array of nulls or an empty array altogether
343    // In this case create the dictionary with a single null element
344    // Because decoding [] is not currently supported by the binary decoder
345    if dict_builder.is_empty() {
346        dict_builder.append_option(Option::<&str>::None);
347    }
348
349    let dict_elements = dict_builder.finish();
350    let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap();
351
352    (array_dict_indices, array_dict_elements)
353}
354
355impl ArrayEncoder for DictionaryEncoder {
356    fn encode(
357        &self,
358        data: DataBlock,
359        data_type: &DataType,
360        buffer_index: &mut u32,
361    ) -> Result<EncodedArray> {
362        if !matches!(data_type, DataType::Utf8) {
363            return Err(Error::invalid_input_source(
364                format!(
365                    "DictionaryEncoder only supports string arrays but got {}",
366                    data_type
367                )
368                .into(),
369            ));
370        }
371        // We only support string arrays for now
372        let str_data = make_array(data.into_arrow(DataType::Utf8, false)?);
373
374        let (index_array, items_array) = encode_dict_indices_and_items(str_data.as_string());
375        let dict_size = items_array.len() as u32;
376        let index_data = DataBlock::from(index_array);
377        let items_data = DataBlock::from(items_array);
378
379        let encoded_indices =
380            self.indices_encoder
381                .encode(index_data, &DataType::UInt8, buffer_index)?;
382
383        let encoded_items = self
384            .items_encoder
385            .encode(items_data, &DataType::Utf8, buffer_index)?;
386
387        let encoded_data = DataBlock::Dictionary(DictionaryDataBlock {
388            indices: encoded_indices.data.as_fixed_width().unwrap(),
389            dictionary: Box::new(encoded_items.data),
390        });
391
392        let encoding = ProtobufUtils::dict_encoding(
393            encoded_indices.encoding,
394            encoded_items.encoding,
395            dict_size,
396        );
397
398        Ok(EncodedArray {
399            data: encoded_data,
400            encoding,
401        })
402    }
403}
404
#[cfg(test)]
pub mod tests {

    use arrow_array::{
        ArrayRef, DictionaryArray, StringArray, UInt8Array,
        builder::{LargeStringBuilder, StringBuilder},
    };
    use arrow_schema::{DataType, Field};
    use std::{collections::HashMap, sync::Arc, vec};

    use crate::testing::{TestCases, check_basic_random, check_round_trip_encoding_of_data};

    use super::encode_dict_indices_and_items;

    // These tests cover the case where we opportunistically convert some (or all) pages of
    // a string column into dictionaries (and decode on read)

    #[test]
    fn test_encode_dict_nulls() {
        // Null entries in string arrays should be adjusted
        // (null maps to the reserved index 0; valid values start at 1)
        let string_array = Arc::new(StringArray::from(vec![
            None,
            Some("foo"),
            Some("bar"),
            Some("bar"),
            None,
            Some("foo"),
            None,
            None,
        ]));
        let (dict_indices, dict_items) = encode_dict_indices_and_items(&string_array);

        let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef;
        let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
        assert_eq!(&dict_indices, &expected_indices);
        assert_eq!(&dict_items, &expected_items);
    }

    #[test_log::test(tokio::test)]
    async fn test_utf8() {
        let field = Field::new("", DataType::Utf8, false);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_binary() {
        let field = Field::new("", DataType::Binary, false);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_large_binary() {
        let field = Field::new("", DataType::LargeBinary, true);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_large_utf8() {
        let field = Field::new("", DataType::LargeUtf8, true);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_simple_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(0..3)
            .with_range(1..3)
            .with_indices(vec![1, 3]);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    #[test_log::test(tokio::test)]
    async fn test_sliced_utf8() {
        // Slicing exercises non-zero array offsets through the round trip
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
        let string_array = string_array.slice(1, 3);

        let test_cases = TestCases::default()
            .with_range(0..1)
            .with_range(0..2)
            .with_range(1..2);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    #[test_log::test(tokio::test)]
    async fn test_empty_strings() {
        // Scenario 1: Some strings are empty

        let values = [Some("abc"), Some(""), None];
        // Test empty list at beginning, middle, and end
        for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
            let mut string_builder = StringBuilder::new();
            for idx in order {
                string_builder.append_option(values[idx]);
            }
            let string_array = Arc::new(string_builder.finish());
            let test_cases = TestCases::default()
                .with_indices(vec![1])
                .with_indices(vec![0])
                .with_indices(vec![2]);
            check_round_trip_encoding_of_data(
                vec![string_array.clone()],
                &test_cases,
                HashMap::new(),
            )
            .await;
            let test_cases = test_cases.with_batch_size(1);
            check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new())
                .await;
        }

        // Scenario 2: All strings are empty

        // When encoding an array of empty strings there are no bytes to encode
        // which is strange and we want to ensure we handle it
        let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")]));

        let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]);
        check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases, HashMap::new())
            .await;
        let test_cases = test_cases.with_batch_size(1);
        check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new()).await;
    }

    #[test_log::test(tokio::test)]
    #[ignore] // This test is quite slow in debug mode
    async fn test_jumbo_string() {
        // This is an overflow test.  We have a list of lists where each list
        // has 1Mi items.  We encode 5000 of these lists and so we have over 4Gi in the
        // offsets range
        let mut string_builder = LargeStringBuilder::new();
        // a 1 MiB string
        let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0'));
        for _ in 0..5000 {
            string_builder.append_option(Some(&giant_string));
        }
        let giant_array = Arc::new(string_builder.finish()) as ArrayRef;
        let arrs = vec![giant_array];

        // // We can't validate because our validation relies on concatenating all input arrays
        let test_cases = TestCases::default().without_validation();
        check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await;
    }

    // These tests cover the case where the input is already dictionary encoded

    #[test_log::test(tokio::test)]
    async fn test_random_dictionary_input() {
        let dict_field = Field::new(
            "",
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
            false,
        );
        check_basic_random(dict_field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_simple_already_dictionary() {
        let values = StringArray::from_iter_values(["a", "bb", "ccc"]);
        let indices = UInt8Array::from(vec![0, 1, 2, 0, 1, 2, 0, 1, 2]);
        let dict_array = DictionaryArray::new(indices, Arc::new(values));

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(1..3)
            .with_range(2..4)
            .with_indices(vec![1])
            .with_indices(vec![2]);
        check_round_trip_encoding_of_data(vec![Arc::new(dict_array)], &test_cases, HashMap::new())
            .await;
    }
}