use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::builder::{ArrayBuilder, StringBuilder};
use arrow_array::cast::AsArray;
use arrow_array::types::UInt8Type;
use arrow_array::{make_array, Array, ArrayRef, DictionaryArray, StringArray, UInt8Array};
use arrow_schema::DataType;
use futures::{future::BoxFuture, FutureExt};
use lance_core::Result;

use crate::buffer::LanceBuffer;
use crate::data::{DataBlock, DictionaryDataBlock, NullableDataBlock, VariableWidthBlock};
use crate::decoder::{LogicalPageDecoder, PageScheduler, PrimitivePageDecoder};
use crate::encoder::{ArrayEncoder, EncodedArray};
use crate::encodings::logical::primitive::PrimitiveFieldDecoder;
use crate::format::pb::{self, nullable::AllNull};
use crate::EncodingsIo;
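
/// A page scheduler for dictionary-encoded pages.
///
/// Wraps two inner schedulers: one for the index buffer and one for the
/// dictionary items. The requested ranges only restrict the indices; the
/// dictionary itself is always scheduled in full.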
#[derive(Debug)]
pub struct DictionaryPageScheduler {
indices_scheduler: Arc<dyn PageScheduler>,
items_scheduler: Arc<dyn PageScheduler>,
num_dictionary_items: u32,
should_decode_dict: bool,
}
impl DictionaryPageScheduler {
pub fn new(
indices_scheduler: Arc<dyn PageScheduler>,
items_scheduler: Arc<dyn PageScheduler>,
num_dictionary_items: u32,
should_decode_dict: bool,
) -> Self {
Self {
indices_scheduler,
items_scheduler,
num_dictionary_items,
should_decode_dict,
}
}
}
impl PageScheduler for DictionaryPageScheduler {
fn schedule_ranges(
&self,
ranges: &[std::ops::Range<u64>],
scheduler: &Arc<dyn EncodingsIo>,
top_level_row: u64,
) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
let indices_page_decoder =
self.indices_scheduler
.schedule_ranges(ranges, scheduler, top_level_row);
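        // Only the indices are restricted to the requested ranges; the
        // dictionary items are always read in full.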
let items_range = 0..(self.num_dictionary_items as u64);
let items_page_decoder = self.items_scheduler.schedule_ranges(
std::slice::from_ref(&items_range),
scheduler,
top_level_row,
);
let copy_size = self.num_dictionary_items as u64;
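        // `should_decode_dict` selects between eagerly decoding the dictionary
        // into an Arrow array (so indices can be materialized into strings at
        // decode time) and passing the dictionary through as a raw data block.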
if self.should_decode_dict {
tokio::spawn(async move {
let items_decoder: Arc<dyn PrimitivePageDecoder> =
Arc::from(items_page_decoder.await?);
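                // Wrap the items decoder so the entire dictionary can be
                // drained and decoded into a single Utf8 array.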
let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data(
items_decoder.clone(),
DataType::Utf8,
copy_size,
false,
);
let drained_task = primitive_wrapper.drain(copy_size)?;
let items_decode_task = drained_task.task;
let decoded_dict = items_decode_task.decode()?;
let indices_decoder: Box<dyn PrimitivePageDecoder> = indices_page_decoder.await?;
Ok(Box::new(DictionaryPageDecoder {
decoded_dict,
indices_decoder,
}) as Box<dyn PrimitivePageDecoder>)
})
.map(|join_handle| join_handle.unwrap())
.boxed()
} else {
let num_dictionary_items = self.num_dictionary_items;
tokio::spawn(async move {
let items_decoder: Arc<dyn PrimitivePageDecoder> =
Arc::from(items_page_decoder.await?);
let decoded_dict = items_decoder
.decode(0, num_dictionary_items as u64)?
.borrow_and_clone();
let indices_decoder = indices_page_decoder.await?;
Ok(Box::new(DirectDictionaryPageDecoder {
decoded_dict,
indices_decoder,
}) as Box<dyn PrimitivePageDecoder>)
})
.map(|join_handle| join_handle.unwrap())
.boxed()
}
}
}
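
/// A decoder that returns the indices and dictionary together as a
/// `DictionaryDataBlock`, without materializing the dictionary into values.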
struct DirectDictionaryPageDecoder {
decoded_dict: Box<dyn DataBlock>,
indices_decoder: Box<dyn PrimitivePageDecoder>,
}
impl PrimitivePageDecoder for DirectDictionaryPageDecoder {
fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<Box<dyn DataBlock>> {
let indices = self.indices_decoder.decode(rows_to_skip, num_rows)?;
let dict = self.decoded_dict.try_clone()?;
Ok(Box::new(DictionaryDataBlock {
indices,
dictionary: dict,
}))
}
}
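
/// A decoder that materializes indices against an already-decoded dictionary,
/// producing plain string data.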
struct DictionaryPageDecoder {
decoded_dict: Arc<dyn Array>,
indices_decoder: Box<dyn PrimitivePageDecoder>,
}
impl PrimitivePageDecoder for DictionaryPageDecoder {
fn decode(&self, rows_to_skip: u64, num_rows: u64) -> Result<Box<dyn DataBlock>> {
let indices_data = self.indices_decoder.decode(rows_to_skip, num_rows)?;
let indices_array = make_array(indices_data.into_arrow(DataType::UInt8, false)?);
let indices_array = indices_array.as_primitive::<UInt8Type>();
let dictionary = self.decoded_dict.clone();
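        // Index 0 is the reserved null sentinel, so shift every other index
        // down by one to line it up with the decoded dictionary items.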
let adjusted_indices: UInt8Array = indices_array
.iter()
.map(|x| match x {
Some(0) => None,
Some(x) => Some(x - 1),
None => None,
})
.collect();
let dict_array =
DictionaryArray::<UInt8Type>::try_new(adjusted_indices, dictionary).unwrap();
let string_array = arrow_cast::cast(&dict_array, &DataType::Utf8).unwrap();
let string_array = string_array.as_any().downcast_ref::<StringArray>().unwrap();
let null_buffer = string_array.nulls().map(|n| n.buffer().clone());
let offsets_buffer = string_array.offsets().inner().inner().clone();
let bytes_buffer = string_array.values().clone();
let string_data = Box::new(VariableWidthBlock {
bits_per_offset: 32,
data: LanceBuffer::from(bytes_buffer),
offsets: LanceBuffer::from(offsets_buffer),
num_values: num_rows,
});
if let Some(nulls) = null_buffer {
Ok(Box::new(NullableDataBlock {
data: string_data,
nulls: LanceBuffer::from(nulls),
}))
} else {
Ok(string_data)
}
}
}
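
/// An encoder for arrays that arrive already dictionary encoded.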
#[derive(Debug)]
pub struct AlreadyDictionaryEncoder {
indices_encoder: Box<dyn ArrayEncoder>,
items_encoder: Box<dyn ArrayEncoder>,
}
impl AlreadyDictionaryEncoder {
pub fn new(
indices_encoder: Box<dyn ArrayEncoder>,
items_encoder: Box<dyn ArrayEncoder>,
) -> Self {
Self {
indices_encoder,
items_encoder,
}
}
}
impl ArrayEncoder for AlreadyDictionaryEncoder {
fn encode(&self, arrays: &[ArrayRef], buffer_index: &mut u32) -> Result<EncodedArray> {
let array_refs = arrays.iter().map(|arr| arr.as_ref()).collect::<Vec<_>>();
let array = arrow_select::concat::concat(&array_refs)?;
let array_dict = array.as_any_dictionary();
let indices = make_array(array_dict.keys().to_data());
let items = array_dict.values().clone();
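        // An empty dictionary means every key is null, so encode the page
        // as all-null rather than as a dictionary.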
if items.is_empty() {
return Ok(EncodedArray {
buffers: Vec::default(),
encoding: pb::ArrayEncoding {
array_encoding: Some(pb::array_encoding::ArrayEncoding::Nullable(Box::new(
pb::Nullable {
nullability: Some(pb::nullable::Nullability::AllNulls(AllNull {})),
},
))),
},
});
}
let dictionary_size = items.len() as u32;
let encoded_indices = self.indices_encoder.encode(&[indices], buffer_index)?;
let encoded_items = self.items_encoder.encode(&[items], buffer_index)?;
let mut all_buffers = encoded_indices.buffers;
all_buffers.extend(encoded_items.buffers);
Ok(EncodedArray {
buffers: all_buffers,
encoding: pb::ArrayEncoding {
array_encoding: Some(pb::array_encoding::ArrayEncoding::Dictionary(Box::new(
pb::Dictionary {
indices: Some(Box::new(encoded_indices.encoding)),
items: Some(Box::new(encoded_items.encoding)),
num_dictionary_items: dictionary_size,
},
))),
},
})
}
}
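
/// An encoder that dictionary encodes string arrays, reserving index 0 as
/// the null sentinel (see `encode_dict_indices_and_items`).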
#[derive(Debug)]
pub struct DictionaryEncoder {
indices_encoder: Box<dyn ArrayEncoder>,
items_encoder: Box<dyn ArrayEncoder>,
}
impl DictionaryEncoder {
pub fn new(
indices_encoder: Box<dyn ArrayEncoder>,
items_encoder: Box<dyn ArrayEncoder>,
) -> Self {
Self {
indices_encoder,
items_encoder,
}
}
}
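
/// Builds a `u8` index array and a deduplicated items array from string
/// arrays. Index 0 is reserved for null; each new distinct value is assigned
/// the next free index, so a page is limited to 255 distinct values.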
fn encode_dict_indices_and_items(arrays: &[ArrayRef]) -> (ArrayRef, ArrayRef) {
let mut arr_hashmap: HashMap<&str, u8> = HashMap::new();
let mut curr_dict_index = 1;
let total_capacity = arrays.iter().map(|arr| arr.len()).sum();
let mut dict_indices = Vec::with_capacity(total_capacity);
let mut dict_builder = StringBuilder::new();
for arr in arrays.iter() {
let string_array = arrow_array::cast::as_string_array(arr);
for i in 0..string_array.len() {
if !string_array.is_valid(i) {
dict_indices.push(0);
continue;
}
let st = string_array.value(i);
let hashmap_entry = *arr_hashmap.entry(st).or_insert(curr_dict_index);
dict_indices.push(hashmap_entry);
if hashmap_entry == curr_dict_index {
dict_builder.append_value(st);
curr_dict_index += 1;
}
}
}
let array_dict_indices = Arc::new(UInt8Array::from(dict_indices)) as ArrayRef;
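    // If every input value was null the builder is empty; append a single
    // null item so the dictionary is never zero-length.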
if dict_builder.is_empty() {
dict_builder.append_option(Option::<&str>::None);
}
let dict_elements = dict_builder.finish();
let array_dict_elements = arrow_cast::cast(&dict_elements, &DataType::Utf8).unwrap();
(array_dict_indices, array_dict_elements)
}
impl ArrayEncoder for DictionaryEncoder {
fn encode(&self, arrays: &[ArrayRef], buffer_index: &mut u32) -> Result<EncodedArray> {
let (index_array, items_array) = encode_dict_indices_and_items(arrays);
let encoded_indices = self
.indices_encoder
.encode(&[index_array.clone()], buffer_index)?;
let encoded_items = self
.items_encoder
.encode(&[items_array.clone()], buffer_index)?;
let mut encoded_buffers = encoded_indices.buffers;
encoded_buffers.extend(encoded_items.buffers);
let dict_size = items_array.len() as u32;
Ok(EncodedArray {
buffers: encoded_buffers,
encoding: pb::ArrayEncoding {
array_encoding: Some(pb::array_encoding::ArrayEncoding::Dictionary(Box::new(
pb::Dictionary {
indices: Some(Box::new(encoded_indices.encoding)),
items: Some(Box::new(encoded_items.encoding)),
num_dictionary_items: dict_size,
},
))),
},
})
}
}
#[cfg(test)]
pub mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow_array::{
        builder::{LargeStringBuilder, StringBuilder},
        ArrayRef, StringArray, UInt8Array,
    };
    use arrow_schema::{DataType, Field};

    use crate::testing::{
        check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases,
    };
    use super::encode_dict_indices_and_items;
#[test]
fn test_encode_dict_nulls() {
let string_array1 = Arc::new(StringArray::from(vec![None, Some("foo"), Some("bar")]));
let string_array2 = Arc::new(StringArray::from(vec![Some("bar"), None, Some("foo")]));
let string_array3 = Arc::new(StringArray::from(vec![None as Option<&str>, None]));
let (dict_indices, dict_items) =
encode_dict_indices_and_items(&[string_array1, string_array2, string_array3]);
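        // 0 is the null sentinel; "foo" encodes to 1 and "bar" to 2.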
let expected_indices = Arc::new(UInt8Array::from(vec![0, 1, 2, 2, 0, 1, 0, 0])) as ArrayRef;
let expected_items = Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef;
assert_eq!(&dict_indices, &expected_indices);
assert_eq!(&dict_items, &expected_items);
}
#[test_log::test(tokio::test)]
async fn test_utf8() {
let field = Field::new("", DataType::Utf8, false);
check_round_trip_encoding_random(field, HashMap::new()).await;
}
#[test_log::test(tokio::test)]
async fn test_binary() {
let field = Field::new("", DataType::Binary, false);
check_round_trip_encoding_random(field, HashMap::new()).await;
}
#[test_log::test(tokio::test)]
async fn test_large_binary() {
let field = Field::new("", DataType::LargeBinary, true);
check_round_trip_encoding_random(field, HashMap::new()).await;
}
#[test_log::test(tokio::test)]
async fn test_large_utf8() {
let field = Field::new("", DataType::LargeUtf8, true);
check_round_trip_encoding_random(field, HashMap::new()).await;
}
#[test_log::test(tokio::test)]
async fn test_simple_utf8() {
let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
let test_cases = TestCases::default()
.with_range(0..2)
.with_range(0..3)
.with_range(1..3)
.with_indices(vec![1, 3]);
check_round_trip_encoding_of_data(
vec![Arc::new(string_array)],
&test_cases,
HashMap::new(),
)
.await;
}
#[test_log::test(tokio::test)]
async fn test_sliced_utf8() {
let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
let string_array = string_array.slice(1, 3);
let test_cases = TestCases::default()
.with_range(0..1)
.with_range(0..2)
.with_range(1..2);
check_round_trip_encoding_of_data(
vec![Arc::new(string_array)],
&test_cases,
HashMap::new(),
)
.await;
}
#[test_log::test(tokio::test)]
async fn test_empty_strings() {
let values = [Some("abc"), Some(""), None];
for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
let mut string_builder = StringBuilder::new();
for idx in order {
string_builder.append_option(values[idx]);
}
let string_array = Arc::new(string_builder.finish());
let test_cases = TestCases::default()
.with_indices(vec![1])
.with_indices(vec![0])
.with_indices(vec![2]);
check_round_trip_encoding_of_data(
vec![string_array.clone()],
&test_cases,
HashMap::new(),
)
.await;
let test_cases = test_cases.with_batch_size(1);
check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new())
.await;
}
let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")]));
let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]);
check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases, HashMap::new())
.await;
let test_cases = test_cases.with_batch_size(1);
check_round_trip_encoding_of_data(vec![string_array], &test_cases, HashMap::new()).await;
}
#[test_log::test(tokio::test)]
    // Ignored by default: the test materializes roughly 5 GiB of string data.
    #[ignore]
    async fn test_jumbo_string() {
let mut string_builder = LargeStringBuilder::new();
let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0'));
for _ in 0..5000 {
string_builder.append_option(Some(&giant_string));
}
let giant_array = Arc::new(string_builder.finish()) as ArrayRef;
let arrs = vec![giant_array];
let test_cases = TestCases::default().without_validation();
check_round_trip_encoding_of_data(arrs, &test_cases, HashMap::new()).await;
}
#[test_log::test(tokio::test)]
async fn test_random_dictionary_input() {
let dict_field = Field::new(
"",
DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
false,
);
check_round_trip_encoding_random(dict_field, HashMap::new()).await;
}
}