1use std::sync::Arc;
5use std::vec;
6
7use arrow_array::builder::{ArrayBuilder, StringBuilder};
8use arrow_array::cast::AsArray;
9use arrow_array::types::UInt8Type;
10use arrow_array::{
11 Array, ArrayRef, DictionaryArray, StringArray, UInt8Array, make_array, new_null_array,
12};
13use arrow_schema::DataType;
14use futures::{FutureExt, future::BoxFuture};
15use lance_arrow::DataTypeExt;
16use lance_core::{Error, Result};
17use std::collections::HashMap;
18
19use crate::buffer::LanceBuffer;
20use crate::data::{
21 BlockInfo, DataBlock, DictionaryDataBlock, FixedWidthDataBlock, NullableDataBlock,
22 VariableWidthBlock,
23};
24use crate::format::ProtobufUtils;
25use crate::previous::decoder::LogicalPageDecoder;
26use crate::previous::encodings::logical::primitive::PrimitiveFieldDecoder;
27use crate::{
28 EncodingsIo,
29 decoder::{PageScheduler, PrimitivePageDecoder},
30 previous::encoder::{ArrayEncoder, EncodedArray},
31};
32
33#[derive(Debug)]
34pub struct DictionaryPageScheduler {
35 indices_scheduler: Arc<dyn PageScheduler>,
36 items_scheduler: Arc<dyn PageScheduler>,
37 num_dictionary_items: u32,
39 should_decode_dict: bool,
42}
43
44impl DictionaryPageScheduler {
45 pub fn new(
46 indices_scheduler: Arc<dyn PageScheduler>,
47 items_scheduler: Arc<dyn PageScheduler>,
48 num_dictionary_items: u32,
49 should_decode_dict: bool,
50 ) -> Self {
51 Self {
52 indices_scheduler,
53 items_scheduler,
54 num_dictionary_items,
55 should_decode_dict,
56 }
57 }
58}
59
60impl PageScheduler for DictionaryPageScheduler {
61 fn schedule_ranges(
62 &self,
63 ranges: &[std::ops::Range<u64>],
64 scheduler: &Arc<dyn EncodingsIo>,
65 top_level_row: u64,
66 ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
67 let indices_page_decoder =
76 self.indices_scheduler
77 .schedule_ranges(ranges, scheduler, top_level_row);
78
79 let items_range = 0..(self.num_dictionary_items as u64);
81 let items_page_decoder = self.items_scheduler.schedule_ranges(
82 std::slice::from_ref(&items_range),
83 scheduler,
84 top_level_row,
85 );
86
87 let copy_size = self.num_dictionary_items as u64;
88
89 if self.should_decode_dict {
90 tokio::spawn(async move {
91 let items_decoder: Arc<dyn PrimitivePageDecoder> =
92 Arc::from(items_page_decoder.await?);
93
94 let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data(
95 items_decoder.clone(),
96 DataType::Utf8,
97 copy_size,
98 false,
99 );
100
101 let drained_task = primitive_wrapper.drain(copy_size)?;
103 let items_decode_task = drained_task.task;
104 let decoded_dict = items_decode_task.decode()?;
105
106 let indices_decoder: Box<dyn PrimitivePageDecoder> = indices_page_decoder.await?;
107
108 Ok(Box::new(DictionaryPageDecoder {
109 decoded_dict,
110 indices_decoder,
111 }) as Box<dyn PrimitivePageDecoder>)
112 })
113 .map(|join_handle| join_handle.unwrap())
114 .boxed()
115 } else {
116 let num_dictionary_items = self.num_dictionary_items;
117 tokio::spawn(async move {
118 let items_decoder: Arc<dyn PrimitivePageDecoder> =
119 Arc::from(items_page_decoder.await?);
120
121 let decoded_dict = items_decoder
122 .decode(0, num_dictionary_items as u64)?
123 .clone();
124
125 let indices_decoder = indices_page_decoder.await?;
126
127 Ok(Box::new(DirectDictionaryPageDecoder {
128 decoded_dict,
129 indices_decoder,
130 }) as Box<dyn PrimitivePageDecoder>)
131 })
132 .map(|join_handle| join_handle.unwrap())
133 .boxed()
134 }
135 }
136}
137
138struct DirectDictionaryPageDecoder {
139 decoded_dict: DataBlock,
140 indices_decoder: Box<dyn PrimitivePageDecoder>,
141}
142
143impl PrimitivePageDecoder for DirectDictionaryPageDecoder {
144 fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
145 let indices = self
146 .indices_decoder
147 .decode(rows_to_skip, num_rows)?
148 .as_fixed_width()
149 .unwrap();
150 let dict = self.decoded_dict.clone();
151 Ok(DataBlock::Dictionary(DictionaryDataBlock {
152 indices,
153 dictionary: Box::new(dict),
154 }))
155 }
156}
157
158struct DictionaryPageDecoder {
159 decoded_dict: Arc<dyn Array>,
160 indices_decoder: Box<dyn PrimitivePageDecoder>,
161}
162
163impl PrimitivePageDecoder for DictionaryPageDecoder {
164 fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<DataBlock> {
165 let indices_data = self.indices_decoder.decode(rows_to_skip, num_rows)?;
167
168 let indices_array = make_array(indices_data.into_arrow(DataType::UInt8, false)?);
169 let indices_array = indices_array.as_primitive::<UInt8Type>();
170
171 let dictionary = self.decoded_dict.clone();
172
173 let adjusted_indices: UInt8Array = indices_array
174 .iter()
175 .map(|x| match x {
176 Some(0) => None,
177 Some(x) => Some(x - 1),
178 None => None,
179 })
180 .collect();
181
182 let dict_array =
184 DictionaryArray::<UInt8Type>::try_new(adjusted_indices, dictionary).unwrap();
185 let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap();
186 let string_array = string_array.as_any().downcast_ref::<StringArray>().unwrap();
187
188 let null_buffer = string_array.nulls().map(|n| n.buffer().clone());
189 let offsets_buffer = string_array.offsets().inner().inner().clone();
190 let bytes_buffer = string_array.values().clone();
191
192 let string_data = DataBlock::VariableWidth(VariableWidthBlock {
193 bits_per_offset: 32,
194 data: LanceBuffer::from(bytes_buffer),
195 offsets: LanceBuffer::from(offsets_buffer),
196 num_values: num_rows,
197 block_info: BlockInfo::new(),
198 });
199 if let Some(nulls) = null_buffer {
200 Ok(DataBlock::Nullable(NullableDataBlock {
201 data: Box::new(string_data),
202 nulls: LanceBuffer::from(nulls),
203 block_info: BlockInfo::new(),
204 }))
205 } else {
206 Ok(string_data)
207 }
208 }
209}
210
211#[derive(Debug)]
214pub struct AlreadyDictionaryEncoder {
215 indices_encoder: Box<dyn ArrayEncoder>,
216 items_encoder: Box<dyn ArrayEncoder>,
217}
218
219impl AlreadyDictionaryEncoder {
220 pub fn new(
221 indices_encoder: Box<dyn ArrayEncoder>,
222 items_encoder: Box<dyn ArrayEncoder>,
223 ) -> Self {
224 Self {
225 indices_encoder,
226 items_encoder,
227 }
228 }
229}
230
231impl ArrayEncoder for AlreadyDictionaryEncoder {
232 fn encode(
233 &self,
234 data: DataBlock,
235 data_type: &DataType,
236 buffer_index: &mut u32,
237 ) -> Result<EncodedArray> {
238 let DataType::Dictionary(key_type, value_type) = data_type else {
239 panic!("Expected dictionary type");
240 };
241
242 let dict_data = match data {
243 DataBlock::Dictionary(dict_data) => dict_data,
244 DataBlock::AllNull(all_null) => {
245 let indices = UInt8Array::from(vec![0; all_null.num_values as usize]);
247 let indices = arrow_cast::cast(&indices, key_type.as_ref()).unwrap();
248 let indices = indices.into_data();
249 let values = new_null_array(value_type, 1);
250 DictionaryDataBlock {
251 indices: FixedWidthDataBlock {
252 bits_per_value: key_type.byte_width() as u64 * 8,
253 data: LanceBuffer::from(indices.buffers()[0].clone()),
254 num_values: all_null.num_values,
255 block_info: BlockInfo::new(),
256 },
257 dictionary: Box::new(DataBlock::from_array(values)),
258 }
259 }
260 _ => panic!("Expected dictionary data"),
261 };
262 let num_dictionary_items = dict_data.dictionary.num_values() as u32;
263
264 let encoded_indices = self.indices_encoder.encode(
265 DataBlock::FixedWidth(dict_data.indices),
266 key_type,
267 buffer_index,
268 )?;
269 let encoded_items =
270 self.items_encoder
271 .encode(*dict_data.dictionary, value_type, buffer_index)?;
272
273 let encoded = DataBlock::Dictionary(DictionaryDataBlock {
274 dictionary: Box::new(encoded_items.data),
275 indices: encoded_indices.data.as_fixed_width().unwrap(),
276 });
277
278 let encoding = ProtobufUtils::dict_encoding(
279 encoded_indices.encoding,
280 encoded_items.encoding,
281 num_dictionary_items,
282 );
283
284 Ok(EncodedArray {
285 data: encoded,
286 encoding,
287 })
288 }
289}
290
291#[derive(Debug)]
292pub struct DictionaryEncoder {
293 indices_encoder: Box<dyn ArrayEncoder>,
294 items_encoder: Box<dyn ArrayEncoder>,
295}
296
297impl DictionaryEncoder {
298 pub fn new(
299 indices_encoder: Box<dyn ArrayEncoder>,
300 items_encoder: Box<dyn ArrayEncoder>,
301 ) -> Self {
302 Self {
303 indices_encoder,
304 items_encoder,
305 }
306 }
307}
308
309fn encode_dict_indices_and_items(string_array: &StringArray) -> (ArrayRef, ArrayRef) {
310 let mut arr_hashmap: HashMap<&str, u8> = HashMap::new();
311 let mut curr_dict_index = 1;
314 let total_capacity = string_array.len();
315
316 let mut dict_indices = Vec::with_capacity(total_capacity);
317 let mut dict_builder = StringBuilder::new();
318
319 for i in 0..string_array.len() {
320 if !string_array.is_valid(i) {
321 dict_indices.push(0);
323 continue;
324 }
325
326 let st = string_array.value(i);
327
328 let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index);
329 dict_indices.push(hashmap_entry);
330
331 if hashmap_entry == curr_dict_index {
334 dict_builder.append_value(st);
335 curr_dict_index += 1;
336 }
337 }
338
339 let array_dict_indices = Arc::new(UInt8Array::from(dict_indices)) as ArrayRef;
340
341 if dict_builder.is_empty() {
346 dict_builder.append_option(Option::<&str>::None);
347 }
348
349 let dict_elements = dict_builder.finish();
350 let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap();
351
352 (array_dict_indices, array_dict_elements)
353}
354
355impl ArrayEncoder for DictionaryEncoder {
356 fn encode(
357 &self,
358 data: DataBlock,
359 data_type: &DataType,
360 buffer_index: &mut u32,
361 ) -> Result<EncodedArray> {
362 if !matches!(data_type, DataType::Utf8) {
363 return Err(Error::invalid_input_source(
364 format!(
365 "DictionaryEncoder only supports string arrays but got {}",
366 data_type
367 )
368 .into(),
369 ));
370 }
371 let str_data = make_array(data.into_arrow(DataType::Utf8, false)?);
373
374 let (index_array, items_array) = encode_dict_indices_and_items(str_data.as_string());
375 let dict_size = items_array.len() as u32;
376 let index_data = DataBlock::from(index_array);
377 let items_data = DataBlock::from(items_array);
378
379 let encoded_indices =
380 self.indices_encoder
381 .encode(index_data, &DataType::UInt8, buffer_index)?;
382
383 let encoded_items = self
384 .items_encoder
385 .encode(items_data, &DataType::Utf8, buffer_index)?;
386
387 let encoded_data = DataBlock::Dictionary(DictionaryDataBlock {
388 indices: encoded_indices.data.as_fixed_width().unwrap(),
389 dictionary: Box::new(encoded_items.data),
390 });
391
392 let encoding = ProtobufUtils::dict_encoding(
393 encoded_indices.encoding,
394 encoded_items.encoding,
395 dict_size,
396 );
397
398 Ok(EncodedArray {
399 data: encoded_data,
400 encoding,
401 })
402 }
403}
404
#[cfg(test)]
pub mod tests {
    // Tests for the legacy dictionary encoding: a unit test of the index/item split, plus
    // round-trip tests driven by the shared `crate::testing` helpers.

    use arrow_array::{
        ArrayRef, DictionaryArray, StringArray, UInt8Array,
        builder::{LargeStringBuilder, StringBuilder},
    };
    use arrow_schema::{DataType, Field};
    use std::{collections::HashMap, sync::Arc, vec};

    use crate::testing::{TestCases, check_basic_random, check_round_trip_encoding_of_data};

    use super::encode_dict_indices_and_items;

    /// Null rows map to index 0; distinct values get 1, 2, ... in order of first appearance.
    #[test]
    fn test_encode_dict_nulls() {
        let string_array = Arc::new(StringArray::from(vec![
            None,
            Some("foo"),
            Some("bar"),
            Some("bar"),
            None,
            Some("foo"),
            None,
            None,
        ]));
        let (dict_indices, dict_items) = encode_dict_indices_and_items(&string_array);

        let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef;
        let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
        assert_eq!(&dict_indices, &expected_indices);
        assert_eq!(&dict_items, &expected_items);
    }

    /// Random non-nullable `Utf8` data through the shared `check_basic_random` helper.
    #[test_log::test(tokio::test)]
    async fn test_utf8() {
        let field = Field::new("", DataType::Utf8, false);
        check_basic_random(field).await;
    }

    /// Random non-nullable `Binary` data through the shared `check_basic_random` helper.
    #[test_log::test(tokio::test)]
    async fn test_binary() {
        let field = Field::new("", DataType::Binary, false);
        check_basic_random(field).await;
    }

    /// Random nullable `LargeBinary` data through the shared `check_basic_random` helper.
    #[test_log::test(tokio::test)]
    async fn test_large_binary() {
        let field = Field::new("", DataType::LargeBinary, true);
        check_basic_random(field).await;
    }

    /// Random nullable `LargeUtf8` data through the shared `check_basic_random` helper.
    #[test_log::test(tokio::test)]
    async fn test_large_utf8() {
        let field = Field::new("", DataType::LargeUtf8, true);
        check_basic_random(field).await;
    }

    /// Small string array containing a null, read back by several ranges and by indices.
    #[test_log::test(tokio::test)]
    async fn test_simple_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(0..3)
            .with_range(1..3)
            .with_indices(vec![1, 3]);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    /// Same data as `test_simple_utf8`, but the input is a slice (offset 1, length 3).
    #[test_log::test(tokio::test)]
    async fn test_sliced_utf8() {
        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
        let string_array = string_array.slice(1, 3);

        let test_cases = TestCases::default()
            .with_range(0..1)
            .with_range(0..2)
            .with_range(1..2);
        check_round_trip_encoding_of_data(
            vec![Arc::new(string_array)],
            &test_cases,
            HashMap::new(),
        )
        .await;
    }

    /// Empty strings must stay distinct from nulls, whatever order they appear in.
    #[test_log::test(tokio::test)]
    async fn test_empty_strings() {
        let values = [Some("abc"), Some(""), None];
        // Each permutation puts the empty string and the null in a different position.
        for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
            let mut string_builder = StringBuilder::new();
            for idx in order {
                string_builder.append_option(values[idx]);
            }
            let string_array = Arc::new(string_builder.finish());
            let test_cases = TestCases::default()
                .with_indices(vec![1])
                .with_indices(vec![0])
                .with_indices(vec![2]);
            check_round_trip_encoding_of_data(
                vec![string_array.clone()],
                &test_cases,
                HashMap::new(),
            )
            .await;
            // Repeat with a batch size of 1.
            let test_cases = test_cases.with_batch_size(1);
            check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new())
                .await;
        }

        // An array with only empty strings and a null.
        let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")]));

        let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]);
        check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases, HashMap::new())
            .await;
        let test_cases = test_cases.with_batch_size(1);
        check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new()).await;
    }

    /// 5000 strings of 1 MiB each (roughly 4.9 GiB of data), with validation disabled.
    /// Ignored by default because of its size.
    #[test_log::test(tokio::test)]
    #[ignore]
    async fn test_jumbo_string() {
        let mut string_builder = LargeStringBuilder::new();
        // 1024 * 1024 single-byte characters per string.
        let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0'));
        for _ in 0..5000 {
            string_builder.append_option(Some(&giant_string));
        }
        let giant_array = Arc::new(string_builder.finish()) as ArrayRef;
        let arrs = vec![giant_array];

        let test_cases = TestCases::default().without_validation();
        check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await;
    }

    /// Random input that is already dictionary-typed (`UInt16` keys, `Utf8` values).
    #[test_log::test(tokio::test)]
    async fn test_random_dictionary_input() {
        let dict_field = Field::new(
            "",
            DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
            false,
        );
        check_basic_random(dict_field).await;
    }

    /// Small already-dictionary-encoded array (`UInt8` keys cycling over three values).
    #[test_log::test(tokio::test)]
    async fn test_simple_already_dictionary() {
        let values = StringArray::from_iter_values(["a", "bb", "ccc"]);
        let indices = UInt8Array::from(vec![0, 1, 2, 0, 1, 2, 0, 1, 2]);
        let dict_array = DictionaryArray::new(indices, Arc::new(values));

        let test_cases = TestCases::default()
            .with_range(0..2)
            .with_range(1..3)
            .with_range(2..4)
            .with_indices(vec![1])
            .with_indices(vec![2]);
        check_round_trip_encoding_of_data(vec![Arc::new(dict_array)], &test_cases, HashMap::new())
            .await;
    }
}