lance_encoding/previous/encodings/physical/
dictionary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::sync::Arc;
5use std::vec;
6
7use arrow_array::builder::{ArrayBuilder, StringBuilder};
8use arrow_array::cast::AsArray;
9use arrow_array::types::UInt8Type;
10use arrow_array::{
11    make_array, new_null_array, Array, ArrayRef, DictionaryArray, StringArray, UInt8Array,
12};
13use arrow_schema::DataType;
14use futures::{future::BoxFuture, FutureExt};
15use lance_arrow::DataTypeExt;
16use lance_core::{Error, Result};
17use snafu::location;
18use std::collections::HashMap;
19
20use crate::buffer::LanceBuffer;
21use crate::data::{
22    BlockInfo, DataBlock, DictionaryDataBlock, FixedWidthDataBlock, NullableDataBlock,
23    VariableWidthBlock,
24};
25use crate::format::ProtobufUtils;
26use crate::previous::decoder::LogicalPageDecoder;
27use crate::previous::encodings::logical::primitive::PrimitiveFieldDecoder;
28use crate::{
29    decoder::{PageScheduler, PrimitivePageDecoder},
30    previous::encoder::{ArrayEncoder, EncodedArray},
31    EncodingsIo,
32};
33
/// Page scheduler for dictionary-encoded pages.
///
/// Holds one scheduler for the per-row dictionary indices and one for the dictionary items
/// themselves, together with the number of items and whether the decoded output should be
/// expanded into plain values or left dictionary encoded.
#[derive(Debug)]
pub struct DictionaryPageScheduler {
    /// Schedules the reads for the per-row dictionary indices
    indices_scheduler: Arc<dyn PageScheduler>,
    /// Schedules the reads for the dictionary items
    items_scheduler: Arc<dyn PageScheduler>,
    /// The number of items in the dictionary
    num_dictionary_items: u32,
    /// If true, decode the dictionary items.  If false, leave them dictionary encoded (e.g. the
    /// output type is probably a dictionary type)
    should_decode_dict: bool,
}

impl DictionaryPageScheduler {
    /// Creates a scheduler from the indices scheduler, the items scheduler, the number of
    /// items in the dictionary and the flag that selects between decoding the dictionary
    /// (`true`) and keeping the output dictionary encoded (`false`).
    pub fn new(
        indices_scheduler: Arc<dyn PageScheduler>,
        items_scheduler: Arc<dyn PageScheduler>,
        num_dictionary_items: u32,
        should_decode_dict: bool,
    ) -> Self {
        Self {
            indices_scheduler,
            items_scheduler,
            num_dictionary_items,
            should_decode_dict,
        }
    }
}
60
61impl PageScheduler for DictionaryPageScheduler {
62    fn schedule_ranges(
63        &self,
64        ranges: &[std::ops::Range<u64>],
65        scheduler: &Arc<dyn EncodingsIo>,
66        top_level_row: u64,
67    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
68        // We want to decode indices and items
69        // e.g. indices [0, 1, 2, 0, 1, 0]
70        // items (dictionary) ["abcd", "hello", "apple"]
71        // This will map to ["abcd", "hello", "apple", "abcd", "hello", "abcd"]
72        // We decode all the items during scheduling itself
73        // These are used to rebuild the string later
74
75        // Schedule indices for decoding
76        let indices_page_decoder =
77            self.indices_scheduler
78                .schedule_ranges(ranges, scheduler, top_level_row);
79
80        // Schedule items for decoding
81        let items_range = 0..(self.num_dictionary_items as u64);
82        let items_page_decoder = self.items_scheduler.schedule_ranges(
83            std::slice::from_ref(&items_range),
84            scheduler,
85            top_level_row,
86        );
87
88        let copy_size = self.num_dictionary_items as u64;
89
90        if self.should_decode_dict {
91            tokio::spawn(async move {
92                let items_decoder: Arc<dyn PrimitivePageDecoder> =
93                    Arc::from(items_page_decoder.await?);
94
95                let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data(
96                    items_decoder.clone(),
97                    DataType::Utf8,
98                    copy_size,
99                    false,
100                );
101
102                // Decode all items
103                let drained_task = primitive_wrapper.drain(copy_size)?;
104                let items_decode_task = drained_task.task;
105                let decoded_dict = items_decode_task.decode()?;
106
107                let indices_decoder: Box<dyn PrimitivePageDecoder> = indices_page_decoder.await?;
108
109                Ok(Box::new(DictionaryPageDecoder {
110                    decoded_dict,
111                    indices_decoder,
112                }) as Box<dyn PrimitivePageDecoder>)
113            })
114            .map(|join_handle| join_handle.unwrap())
115            .boxed()
116        } else {
117            let num_dictionary_items = self.num_dictionary_items;
118            tokio::spawn(async move {
119                let items_decoder: Arc<dyn PrimitivePageDecoder> =
120                    Arc::from(items_page_decoder.await?);
121
122                let decoded_dict = items_decoder
123                    .decode(0, num_dictionary_items as u64)?
124                    .clone();
125
126                let indices_decoder = indices_page_decoder.await?;
127
128                Ok(Box::new(DirectDictionaryPageDecoder {
129                    decoded_dict,
130                    indices_decoder,
131                }) as Box<dyn PrimitivePageDecoder>)
132            })
133            .map(|join_handle| join_handle.unwrap())
134            .boxed()
135        }
136    }
137}
138
139struct DirectDictionaryPageDecoder {
140    decoded_dict: DataBlock,
141    indices_decoder: Box<dyn PrimitivePageDecoder>,
142}
143
144impl PrimitivePageDecoder for DirectDictionaryPageDecoder {
145    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
146        let indices = self
147            .indices_decoder
148            .decode(rows_to_skip, num_rows)?
149            .as_fixed_width()
150            .unwrap();
151        let dict = self.decoded_dict.clone();
152        Ok(DataBlock::Dictionary(DictionaryDataBlock {
153            indices,
154            dictionary: Box::new(dict),
155        }))
156    }
157}
158
159struct DictionaryPageDecoder {
160    decoded_dict: Arc<dyn Array>,
161    indices_decoder: Box<dyn PrimitivePageDecoder>,
162}
163
164impl PrimitivePageDecoder for DictionaryPageDecoder {
165    fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
166        // Decode the indices
167        let indices_data = self.indices_decoder.decode(rows_to_skip, num_rows)?;
168
169        let indices_array = make_array(indices_data.into_arrow(DataType::UInt8, false)?);
170        let indices_array = indices_array.as_primitive::<UInt8Type>();
171
172        let dictionary = self.decoded_dict.clone();
173
174        let adjusted_indices: UInt8Array = indices_array
175            .iter()
176            .map(|x| match x {
177                Some(0) => None,
178                Some(x) => Some(x - 1),
179                None => None,
180            })
181            .collect();
182
183        // Build dictionary array using indices and items
184        let dict_array =
185            DictionaryArray::<UInt8Type>::try_new(adjusted_indices, dictionary).unwrap();
186        let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap();
187        let string_array = string_array.as_any().downcast_ref::<StringArray>().unwrap();
188
189        let null_buffer = string_array.nulls().map(|n| n.buffer().clone());
190        let offsets_buffer = string_array.offsets().inner().inner().clone();
191        let bytes_buffer = string_array.values().clone();
192
193        let string_data = DataBlock::VariableWidth(VariableWidthBlock {
194            bits_per_offset: 32,
195            data: LanceBuffer::from(bytes_buffer),
196            offsets: LanceBuffer::from(offsets_buffer),
197            num_values: num_rows,
198            block_info: BlockInfo::new(),
199        });
200        if let Some(nulls) = null_buffer {
201            Ok(DataBlock::Nullable(NullableDataBlock {
202                data: Box::new(string_data),
203                nulls: LanceBuffer::from(nulls),
204                block_info: BlockInfo::new(),
205            }))
206        } else {
207            Ok(string_data)
208        }
209    }
210}
211
/// An encoder for data that is already dictionary encoded.  Stores the
/// data as a dictionary encoding.
///
/// The indices and the dictionary items are each handed to their own inner encoder.
#[derive(Debug)]
pub struct AlreadyDictionaryEncoder {
    /// Encodes the dictionary indices (using the dictionary's key type)
    indices_encoder: Box<dyn ArrayEncoder>,
    /// Encodes the dictionary items (using the dictionary's value type)
    items_encoder: Box<dyn ArrayEncoder>,
}

impl AlreadyDictionaryEncoder {
    /// Creates an encoder from the encoder used for the indices and the encoder used for
    /// the dictionary items.
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}
231
232impl ArrayEncoder for AlreadyDictionaryEncoder {
233    fn encode(
234        &self,
235        data: DataBlock,
236        data_type: &DataType,
237        buffer_index: &mut u32,
238    ) -> Result<EncodedArray> {
239        let DataType::Dictionary(key_type, value_type) = data_type else {
240            panic!("Expected dictionary type");
241        };
242
243        let dict_data = match data {
244            DataBlock::Dictionary(dict_data) => dict_data,
245            DataBlock::AllNull(all_null) => {
246                // In 2.1 this won't happen, kind of annoying to materialize a bunch of nulls
247                let indices = UInt8Array::from(vec![0; all_null.num_values as usize]);
248                let indices = arrow_cast::cast(&indices, key_type.as_ref()).unwrap();
249                let indices = indices.into_data();
250                let values = new_null_array(value_type, 1);
251                DictionaryDataBlock {
252                    indices: FixedWidthDataBlock {
253                        bits_per_value: key_type.byte_width() as u64 * 8,
254                        data: LanceBuffer::from(indices.buffers()[0].clone()),
255                        num_values: all_null.num_values,
256                        block_info: BlockInfo::new(),
257                    },
258                    dictionary: Box::new(DataBlock::from_array(values)),
259                }
260            }
261            _ => panic!("Expected dictionary data"),
262        };
263        let num_dictionary_items = dict_data.dictionary.num_values() as u32;
264
265        let encoded_indices = self.indices_encoder.encode(
266            DataBlock::FixedWidth(dict_data.indices),
267            key_type,
268            buffer_index,
269        )?;
270        let encoded_items =
271            self.items_encoder
272                .encode(*dict_data.dictionary, value_type, buffer_index)?;
273
274        let encoded = DataBlock::Dictionary(DictionaryDataBlock {
275            dictionary: Box::new(encoded_items.data),
276            indices: encoded_indices.data.as_fixed_width().unwrap(),
277        });
278
279        let encoding = ProtobufUtils::dict_encoding(
280            encoded_indices.encoding,
281            encoded_items.encoding,
282            num_dictionary_items,
283        );
284
285        Ok(EncodedArray {
286            data: encoded,
287            encoding,
288        })
289    }
290}
291
/// An encoder that builds a dictionary for string data and stores the data as a
/// dictionary encoding.
///
/// The generated indices and the distinct dictionary items are each handed to their own
/// inner encoder.
#[derive(Debug)]
pub struct DictionaryEncoder {
    /// Encodes the generated dictionary indices
    indices_encoder: Box<dyn ArrayEncoder>,
    /// Encodes the distinct dictionary items
    items_encoder: Box<dyn ArrayEncoder>,
}

impl DictionaryEncoder {
    /// Creates an encoder from the encoder used for the indices and the encoder used for
    /// the dictionary items.
    pub fn new(
        indices_encoder: Box<dyn ArrayEncoder>,
        items_encoder: Box<dyn ArrayEncoder>,
    ) -> Self {
        Self {
            indices_encoder,
            items_encoder,
        }
    }
}
309
310fn encode_dict_indices_and_items(string_array: &StringArray) -> (ArrayRef, ArrayRef) {
311    let mut arr_hashmap: HashMap<&str, u8> = HashMap::new();
312    // We start with a dict index of 1 because the value 0 is reserved for nulls
313    // The dict indices are adjusted by subtracting 1 later during decode
314    let mut curr_dict_index = 1;
315    let total_capacity = string_array.len();
316
317    let mut dict_indices = Vec::with_capacity(total_capacity);
318    let mut dict_builder = StringBuilder::new();
319
320    for i in 0..string_array.len() {
321        if !string_array.is_valid(i) {
322            // null value
323            dict_indices.push(0);
324            continue;
325        }
326
327        let st = string_array.value(i);
328
329        let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index);
330        dict_indices.push(hashmap_entry);
331
332        // if item didn't exist in the hashmap, add it to the dictionary
333        // and increment the dictionary index
334        if hashmap_entry == curr_dict_index {
335            dict_builder.append_value(st);
336            curr_dict_index += 1;
337        }
338    }
339
340    let array_dict_indices = Arc::new(UInt8Array::from(dict_indices)) as ArrayRef;
341
342    // If there is an empty dictionary:
343    // Either there is an array of nulls or an empty array altogether
344    // In this case create the dictionary with a single null element
345    // Because decoding [] is not currently supported by the binary decoder
346    if dict_builder.is_empty() {
347        dict_builder.append_option(Option::<&str>::None);
348    }
349
350    let dict_elements = dict_builder.finish();
351    let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap();
352
353    (array_dict_indices, array_dict_elements)
354}
355
356impl ArrayEncoder for DictionaryEncoder {
357    fn encode(
358        &self,
359        data: DataBlock,
360        data_type: &DataType,
361        buffer_index: &mut u32,
362    ) -> Result<EncodedArray> {
363        if !matches!(data_type, DataType::Utf8) {
364            return Err(Error::InvalidInput {
365                source: format!(
366                    "DictionaryEncoder only supports string arrays but got {}",
367                    data_type
368                )
369                .into(),
370                location: location!(),
371            });
372        }
373        // We only support string arrays for now
374        let str_data = make_array(data.into_arrow(DataType::Utf8, false)?);
375
376        let (index_array, items_array) = encode_dict_indices_and_items(str_data.as_string());
377        let dict_size = items_array.len() as u32;
378        let index_data = DataBlock::from(index_array);
379        let items_data = DataBlock::from(items_array);
380
381        let encoded_indices =
382            self.indices_encoder
383                .encode(index_data, &DataType::UInt8, buffer_index)?;
384
385        let encoded_items = self
386            .items_encoder
387            .encode(items_data, &DataType::Utf8, buffer_index)?;
388
389        let encoded_data = DataBlock::Dictionary(DictionaryDataBlock {
390            indices: encoded_indices.data.as_fixed_width().unwrap(),
391            dictionary: Box::new(encoded_items.data),
392        });
393
394        let encoding = ProtobufUtils::dict_encoding(
395            encoded_indices.encoding,
396            encoded_items.encoding,
397            dict_size,
398        );
399
400        Ok(EncodedArray {
401            data: encoded_data,
402            encoding,
403        })
404    }
405}
406
#[cfg(test)]
pub mod tests {

    use arrow_array::{
        builder::{LargeStringBuilder, StringBuilder},
        ArrayRef, DictionaryArray, StringArray, UInt8Array,
    };
    use arrow_schema::{DataType, Field};
    use std::{collections::HashMap, sync::Arc, vec};

    use crate::testing::{check_basic_random, check_round_trip_encoding_of_data, TestCases};

    use super::encode_dict_indices_and_items;

    // These tests cover the case where we opportunistically convert some (or all) pages of
    // a string column into dictionaries (and decode on read)

    #[test]
    fn test_encode_dict_nulls() {
        // Null entries are encoded as index 0; valid entries get 1-based indices into the
        // dictionary items, which are listed in first-seen order
        let string_array = Arc::new(StringArray::from(vec![
            None,
            Some("foo"),
            Some("bar"),
            Some("bar"),
            None,
            Some("foo"),
            None,
            None,
        ]));
        let (dict_indices, dict_items) = encode_dict_indices_and_items(&string_array);

        let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef;
        let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
        assert_eq!(&dict_indices, &expected_indices);
        assert_eq!(&dict_items, &expected_items);
    }

    #[test_log::test(tokio::test)]
    async fn test_utf8() {
        let field = Field::new("", DataType::Utf8, false);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_binary() {
        let field = Field::new("", DataType::Binary, false);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_large_binary() {
        let field = Field::new("", DataType::LargeBinary, true);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_large_utf8() {
        let field = Field::new("", DataType::LargeUtf8, true);
        check_basic_random(field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_simple_utf8() {
        // Small array with one null, read back through several ranges and an index selection
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(0..3)
            .with_range(1..3)
            .with_indices(vec![1, 3]);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    #[test_log::test(tokio::test)]
    async fn test_sliced_utf8() {
        // Same data as test_simple_utf8 but the input array is a slice (offset 1, length 3)
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
        let string_array = string_array.slice(1, 3);

        let test_cases = TestCases::default()
            .with_range(0..1)
            .with_range(0..2)
            .with_range(1..2);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    #[test_log::test(tokio::test)]
    async fn test_empty_strings() {
        // Scenario 1: Some strings are empty

        let values = [Some("abc"), Some(""), None];
        // Test empty string at beginning, middle, and end
        for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
            let mut string_builder = StringBuilder::new();
            for idx in order {
                string_builder.append_option(values[idx]);
            }
            let string_array = Arc::new(string_builder.finish());
            let test_cases = TestCases::default()
                .with_indices(vec![1])
                .with_indices(vec![0])
                .with_indices(vec![2]);
            check_round_trip_encoding_of_data(
                vec![string_array.clone()],
                &test_cases,
                HashMap::new(),
            )
            .await;
            let test_cases = test_cases.with_batch_size(1);
            check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new())
                .await;
        }

        // Scenario 2: All strings are empty

        // When encoding an array of empty strings there are no bytes to encode
        // which is strange and we want to ensure we handle it
        let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")]));

        let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]);
        check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases, HashMap::new())
            .await;
        let test_cases = test_cases.with_batch_size(1);
        check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new()).await;
    }

    #[test_log::test(tokio::test)]
    #[ignore] // This test is quite slow in debug mode
    async fn test_jumbo_string() {
        // This is an overflow test.  We have an array of strings where each string
        // is 1MiB long.  We encode 5000 of these strings and so we have over 4GiB of
        // string data, more than 32-bit offsets can address
        let mut string_builder = LargeStringBuilder::new();
        // a 1 MiB string
        let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0'));
        for _ in 0..5000 {
            string_builder.append_option(Some(&giant_string));
        }
        let giant_array = Arc::new(string_builder.finish()) as ArrayRef;
        let arrs = vec![giant_array];

        // We can't validate because our validation relies on concatenating all input arrays
        let test_cases = TestCases::default().without_validation();
        check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await;
    }

    // These tests cover the case where the input is already dictionary encoded

    #[test_log::test(tokio::test)]
    async fn test_random_dictionary_input() {
        let dict_field = Field::new(
            "",
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
            false,
        );
        check_basic_random(dict_field).await;
    }

    #[test_log::test(tokio::test)]
    async fn test_simple_already_dictionary() {
        // Three dictionary items, each referenced three times by UInt8 keys
        let values = StringArray::from_iter_values(["a", "bb", "ccc"]);
        let indices = UInt8Array::from(vec![0, 1, 2, 0, 1, 2, 0, 1, 2]);
        let dict_array = DictionaryArray::new(indices, Arc::new(values));

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(1..3)
            .with_range(2..4)
            .with_indices(vec![1])
            .with_indices(vec![2]);
        check_round_trip_encoding_of_data(vec![Arc::new(dict_array)], &test_cases, HashMap::new())
            .await;
    }
}