use std::{cmp::Ordering, collections::HashMap, ops::Range, sync::Arc};
use crate::{
decoder::DecoderConfig,
encodings::physical::block::CompressionScheme,
format::pb21::{
BufferCompression, CompressiveEncoding, PageLayout, compressive_encoding::Compression,
},
};
use arrow_array::{Array, StructArray, UInt64Array, make_array};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_ord::ord::make_comparator;
use arrow_schema::{DataType, Field, Field as ArrowField, FieldRef, Schema, SortOptions};
use arrow_select::concat::concat;
use bytes::{Bytes, BytesMut};
use futures::{FutureExt, StreamExt, future::BoxFuture};
use log::{debug, info, trace};
use tokio::sync::mpsc::{self, UnboundedSender};
use lance_core::{Result, utils::bit::pad_bytes};
use lance_datagen::{ArrayGenerator, RowCount, Seed, array, gen_batch};
use crate::{
EncodingsIo,
buffer::LanceBuffer,
decoder::{
ColumnInfo, DecodeBatchScheduler, DecoderMessage, DecoderPlugins, FilterExpression,
PageInfo, create_decode_stream,
},
encoder::{
ColumnIndexSequence, EncodedColumn, EncodedPage, EncodingOptions, FieldEncoder,
MIN_PAGE_BUFFER_ALIGNMENT, OutOfLineBuffers, default_encoding_strategy,
},
repdef::RepDefBuilder,
version::LanceFileVersion,
};
const MAX_PAGE_BYTES: u64 = 32 * 1024 * 1024;
const TEST_ALIGNMENT: usize = MIN_PAGE_BUFFER_ALIGNMENT as usize;
#[derive(Debug)]
pub(crate) struct SimulatedScheduler {
data: Bytes,
}
impl SimulatedScheduler {
pub fn new(data: Bytes) -> Self {
Self { data }
}
}
impl EncodingsIo for SimulatedScheduler {
fn submit_request(
&self,
ranges: Vec<Range<u64>>,
priority: u64,
) -> BoxFuture<'static, Result<Vec<Bytes>>> {
let data = ranges
.into_iter()
.map(|range| self.data.slice(range.start as usize..range.end as usize))
.collect();
log::trace!("Scheduled request with priority {}", priority);
std::future::ready(data)
.map(move |data| {
log::trace!("Decoded request with priority {}", priority);
Ok(data)
})
.boxed()
}
}
fn column_indices_from_schema_helper(
fields: &[FieldRef],
column_indices: &mut Vec<u32>,
column_counter: &mut u32,
is_structural_encoding: bool,
) {
for field in fields {
if is_structural_encoding && field.metadata().contains_key("lance-encoding:packed") {
column_indices.push(*column_counter);
*column_counter += 1;
continue;
}
match field.data_type() {
DataType::Struct(fields) => {
if !is_structural_encoding {
column_indices.push(*column_counter);
*column_counter += 1;
}
column_indices_from_schema_helper(
fields.as_ref(),
column_indices,
column_counter,
is_structural_encoding,
);
}
DataType::List(inner) => {
if !is_structural_encoding {
column_indices.push(*column_counter);
*column_counter += 1;
}
column_indices_from_schema_helper(
std::slice::from_ref(inner),
column_indices,
column_counter,
is_structural_encoding,
);
}
DataType::LargeList(inner) => {
if !is_structural_encoding {
column_indices.push(*column_counter);
*column_counter += 1;
}
column_indices_from_schema_helper(
std::slice::from_ref(inner),
column_indices,
column_counter,
is_structural_encoding,
);
}
DataType::Map(entries, _) => {
column_indices_from_schema_helper(
std::slice::from_ref(entries),
column_indices,
column_counter,
is_structural_encoding,
);
}
DataType::FixedSizeList(inner, _) => {
column_indices_from_schema_helper(
std::slice::from_ref(inner),
column_indices,
column_counter,
is_structural_encoding,
);
}
_ => {
column_indices.push(*column_counter);
*column_counter += 1;
column_indices_from_schema_helper(
&[],
column_indices,
column_counter,
is_structural_encoding,
);
}
}
}
}
fn column_indices_from_schema(schema: &Schema, is_structural_encoding: bool) -> Vec<u32> {
let mut column_indices = Vec::new();
let mut column_counter = 0;
column_indices_from_schema_helper(
schema.fields(),
&mut column_indices,
&mut column_counter,
is_structural_encoding,
);
column_indices
}
#[allow(clippy::too_many_arguments)]
async fn test_decode(
num_rows: u64,
batch_size: u32,
schema: &Schema,
column_infos: &[Arc<ColumnInfo>],
expected: Option<Arc<dyn Array>>,
io: Arc<dyn EncodingsIo>,
is_structural_encoding: bool,
schedule_fn: impl FnOnce(
DecodeBatchScheduler,
UnboundedSender<Result<DecoderMessage>>,
) -> BoxFuture<'static, ()>,
) {
let lance_schema = lance_core::datatypes::Schema::try_from(schema).unwrap();
let cache = Arc::new(lance_core::cache::LanceCache::with_capacity(
128 * 1024 * 1024,
));
let column_indices = column_indices_from_schema(schema, is_structural_encoding);
let decode_scheduler = DecodeBatchScheduler::try_new(
&lance_schema,
&column_indices,
column_infos,
&Vec::new(),
num_rows,
Arc::<DecoderPlugins>::default(),
io,
cache,
&FilterExpression::no_filter(),
&DecoderConfig::default(),
)
.await
.unwrap();
let (tx, rx) = mpsc::unbounded_channel();
let scheduler_fut = schedule_fn(decode_scheduler, tx);
scheduler_fut.await;
let mut decode_stream = create_decode_stream(
&lance_schema,
num_rows,
batch_size,
is_structural_encoding,
true,
is_structural_encoding,
rx,
)
.unwrap();
let mut offset = 0;
while let Some(batch) = decode_stream.next().await {
let batch = batch.task.await.unwrap();
if let Some(expected) = expected.as_ref() {
let actual = batch.column(0);
let expected_size = (batch_size as usize).min(expected.len() - offset);
let expected = expected.slice(offset, expected_size);
assert_eq!(expected.data_type(), actual.data_type());
if expected.len() != actual.len() {
panic!(
"Mismatch in length (at offset={}) expected {} but got {}",
offset,
expected.len(),
actual.len()
);
}
if &expected != actual {
if let Ok(comparator) = make_comparator(&expected, &actual, SortOptions::default())
{
for i in 0..expected.len() {
if !matches!(comparator(i, i), Ordering::Equal) {
panic!(
"Mismatch at index {} (offset={}) expected {:?} but got {:?} first mismatch is expected {:?} but got {:?}",
i,
offset,
expected,
actual,
expected.slice(i, 1),
actual.slice(i, 1)
);
}
}
} else {
assert_eq!(&expected, actual);
}
}
}
offset += batch.num_rows();
}
if let Some(expected) = expected.as_ref() {
assert_eq!(offset, expected.len());
}
}
pub trait ArrayGeneratorProvider {
fn provide(&self) -> Box<dyn ArrayGenerator>;
fn copy(&self) -> Box<dyn ArrayGeneratorProvider>;
}
struct RandomArrayGeneratorProvider {
field: Field,
}
impl ArrayGeneratorProvider for RandomArrayGeneratorProvider {
fn provide(&self) -> Box<dyn ArrayGenerator> {
array::rand_type(self.field.data_type())
}
fn copy(&self) -> Box<dyn ArrayGeneratorProvider> {
Box::new(Self {
field: self.field.clone(),
})
}
}
pub async fn check_basic_random(field: Field) {
check_specific_random(field, TestCases::basic()).await;
}
pub async fn check_specific_random(field: Field, test_cases: TestCases) {
let array_generator_provider = RandomArrayGeneratorProvider {
field: field.clone(),
};
check_round_trip_encoding_generated(field, Box::new(array_generator_provider), test_cases)
.await;
}
pub struct FnArrayGeneratorProvider<F: Fn() -> Box<dyn ArrayGenerator> + Clone + 'static> {
provider_fn: F,
}
impl<F: Fn() -> Box<dyn ArrayGenerator> + Clone + 'static> FnArrayGeneratorProvider<F> {
pub fn new(provider_fn: F) -> Self {
Self { provider_fn }
}
}
impl<F: Fn() -> Box<dyn ArrayGenerator> + Clone + 'static> ArrayGeneratorProvider
for FnArrayGeneratorProvider<F>
{
fn provide(&self) -> Box<dyn ArrayGenerator> {
(self.provider_fn)()
}
fn copy(&self) -> Box<dyn ArrayGeneratorProvider> {
Box::new(Self {
provider_fn: self.provider_fn.clone(),
})
}
}
pub async fn check_basic_generated(
field: Field,
array_generator_provider: Box<dyn ArrayGeneratorProvider>,
) {
check_round_trip_encoding_generated(field, array_generator_provider, TestCases::basic()).await;
}
pub async fn check_round_trip_encoding_generated(
field: Field,
array_generator_provider: Box<dyn ArrayGeneratorProvider>,
test_cases: TestCases,
) {
let lance_field = lance_core::datatypes::Field::try_from(&field).unwrap();
for page_size in test_cases.page_sizes.iter().copied() {
debug!("Testing random data with a page size of {}", page_size);
let encoder_factory = |version: LanceFileVersion| {
let encoding_strategy = default_encoding_strategy(version);
let mut column_index_seq = ColumnIndexSequence::default();
let encoding_options = EncodingOptions {
max_page_bytes: MAX_PAGE_BYTES,
cache_bytes_per_column: page_size,
keep_original_array: true,
buffer_alignment: MIN_PAGE_BUFFER_ALIGNMENT,
version,
};
encoding_strategy
.create_field_encoder(
encoding_strategy.as_ref(),
&lance_field,
&mut column_index_seq,
&encoding_options,
)
.unwrap()
};
check_round_trip_random(
encoder_factory,
field.clone(),
array_generator_provider.copy(),
&test_cases,
)
.await
}
}
fn supports_nulls(data_type: &DataType, version: LanceFileVersion) -> bool {
if let DataType::Struct(fields) = data_type {
if version == LanceFileVersion::V2_0 {
false
} else if fields.is_empty() {
false
} else {
true
}
} else {
true
}
}
type EncodingVerificationFn = dyn Fn(&[EncodedColumn], &LanceFileVersion);
#[derive(Clone)]
pub struct TestCases {
ranges: Vec<Range<u64>>,
indices: Vec<Vec<u64>>,
batch_size: u32,
skip_validation: bool,
max_page_size: Option<u64>,
page_sizes: Vec<u64>,
min_file_version: Option<LanceFileVersion>,
max_file_version: Option<LanceFileVersion>,
verify_encoding: Option<Arc<EncodingVerificationFn>>,
expected_encoding: Option<Vec<String>>,
}
impl Default for TestCases {
fn default() -> Self {
Self {
batch_size: 100,
ranges: Vec::new(),
indices: Vec::new(),
skip_validation: false,
max_page_size: None,
page_sizes: vec![4096, 1024 * 1024],
min_file_version: None,
max_file_version: None,
verify_encoding: None,
expected_encoding: None,
}
}
}
impl TestCases {
pub fn basic() -> Self {
Self::default()
.with_range(0..500)
.with_range(100..1100)
.with_range(8000..8500)
.with_indices(vec![100])
.with_indices(vec![0])
.with_indices(vec![9999])
.with_indices(vec![100, 1100, 5000])
.with_indices(vec![1000, 2000, 3000])
.with_indices(vec![2000, 2001, 2002, 2003, 2004])
.with_indices((100..500).map(|i| i * 3).collect::<Vec<_>>())
}
pub fn with_range(mut self, range: Range<u64>) -> Self {
self.ranges.push(range);
self
}
pub fn with_indices(mut self, indices: Vec<u64>) -> Self {
self.indices.push(indices);
self
}
pub fn with_batch_size(mut self, batch_size: u32) -> Self {
self.batch_size = batch_size;
self
}
pub fn without_validation(mut self) -> Self {
self.skip_validation = true;
self
}
pub fn with_min_file_version(mut self, version: LanceFileVersion) -> Self {
self.min_file_version = Some(version);
self
}
pub fn with_max_file_version(mut self, version: LanceFileVersion) -> Self {
self.max_file_version = Some(version);
self
}
pub fn with_page_sizes(mut self, page_sizes: Vec<u64>) -> Self {
self.page_sizes = page_sizes;
self
}
pub fn with_max_page_size(mut self, max_page_size: u64) -> Self {
self.max_page_size = Some(max_page_size);
self
}
fn get_max_page_size(&self) -> u64 {
self.max_page_size.unwrap_or(MAX_PAGE_BYTES)
}
fn get_versions(&self) -> Vec<LanceFileVersion> {
LanceFileVersion::iter_non_legacy()
.filter(|v| {
if let Some(min_file_version) = &self.min_file_version
&& v < min_file_version
{
return false;
}
if let Some(max_file_version) = &self.max_file_version
&& v > max_file_version
{
return false;
}
true
})
.collect()
}
pub fn with_verify_encoding(mut self, verify_encoding: Arc<EncodingVerificationFn>) -> Self {
self.verify_encoding = Some(verify_encoding);
self
}
fn verify_encoding(&self, encoding: &[EncodedColumn], version: &LanceFileVersion) {
if let Some(verify_encoding) = self.verify_encoding.as_ref() {
verify_encoding(encoding, version);
}
}
pub fn with_expected_encoding(mut self, encoding: impl Into<String>) -> Self {
self.expected_encoding = Some(vec![encoding.into()]);
self
}
pub fn with_expected_encoding_chain<I: IntoIterator<Item = T>, T: Into<String>>(
mut self,
encodings: I,
) -> Self {
self.expected_encoding = Some(encodings.into_iter().map(Into::into).collect());
self
}
}
fn tag(e: &Compression) -> &'static str {
use Compression::*;
match e {
Flat(_) => "flat",
Variable(_) => "variable",
Constant(_) => "constant",
OutOfLineBitpacking(_) => "out_of_line_bitpacking",
InlineBitpacking(_) => "inline_bitpacking",
Fsst(_) => "fsst",
Dictionary(_) => "dictionary",
Rle(_) => "rle",
ByteStreamSplit(_) => "byte_stream_split",
General(_) => "general",
FixedSizeList(_) => "fixed_size_list",
PackedStruct(_) => "packed_struct",
VariablePackedStruct(_) => "variable_packed_struct",
}
}
fn buffer(c: &Compression) -> Option<Vec<&BufferCompression>> {
use Compression::*;
match c {
Flat(f) => f.data.as_ref().map(|b| vec![b]),
Variable(v) => v.values.as_ref().map(|b| vec![b]),
InlineBitpacking(i) => i.values.as_ref().map(|b| vec![b]),
General(g) => g.compression.as_ref().map(|c| vec![c]),
_ => None,
}
}
fn child(c: &Compression) -> Option<Vec<&CompressiveEncoding>> {
use Compression::*;
match c {
Variable(v) => v.offsets.as_ref().map(|b| vec![b.as_ref()]),
OutOfLineBitpacking(o) => o.values.as_ref().map(|b| vec![b.as_ref()]),
Fsst(f) => f.values.as_ref().map(|b| vec![b.as_ref()]),
ByteStreamSplit(b) => b.values.as_ref().map(|b| vec![b.as_ref()]),
General(g) => g.values.as_ref().map(|b| vec![b.as_ref()]),
Dictionary(d) => {
let mut children = Vec::new();
if let Some(values) = d.items.as_ref() {
children.push(values.as_ref());
}
if let Some(indices) = d.indices.as_ref() {
children.push(indices.as_ref());
}
Some(children)
}
Rle(r) => {
let mut children = Vec::new();
if let Some(values) = r.values.as_ref() {
children.push(values.as_ref());
}
if let Some(run_lengths) = r.run_lengths.as_ref() {
children.push(run_lengths.as_ref());
}
Some(children)
}
FixedSizeList(f) => f.values.as_ref().map(|b| vec![b.as_ref()]),
VariablePackedStruct(v) => {
let nested: Vec<&CompressiveEncoding> = v
.fields
.iter()
.filter_map(|field| field.value.as_ref())
.collect();
if nested.is_empty() {
None
} else {
Some(nested)
}
}
_ => None,
}
}
pub fn extract_array_encoding_chain(enc: &CompressiveEncoding) -> Vec<String> {
let mut chain = Vec::with_capacity(8);
let mut stack = vec![enc];
while let Some(cur) = stack.pop() {
if let Some(inner) = &cur.compression {
chain.push(tag(inner).to_string());
if let Some(buffer) = buffer(inner) {
chain.extend(buffer.into_iter().map(|b| {
let scheme = CompressionScheme::try_from(b.scheme()).unwrap();
scheme.to_string()
}));
}
if let Some(children) = child(inner) {
stack.extend(children);
}
}
}
chain
}
fn collect_page_encoding(layout: &PageLayout, actual_chain: &mut Vec<String>) -> Result<()> {
use crate::format::pb21::page_layout::Layout;
if let Some(ref layout_type) = layout.layout {
match layout_type {
Layout::MiniBlockLayout(mini_block) => {
if mini_block.dictionary.is_some() {
actual_chain.push("dictionary".to_string());
}
if let Some(ref value_comp) = mini_block.value_compression {
let chain = extract_array_encoding_chain(value_comp);
actual_chain.extend(chain);
}
}
Layout::FullZipLayout(full_zip) => {
if let Some(ref value_comp) = full_zip.value_compression {
let chain = extract_array_encoding_chain(value_comp);
actual_chain.extend(chain);
}
}
Layout::ConstantLayout(_) => {
}
Layout::BlobLayout(blob) => {
if let Some(inner_layout) = &blob.inner_layout {
collect_page_encoding(inner_layout.as_ref(), actual_chain)?
}
}
}
}
Ok(())
}
fn verify_page_encoding(
page: &EncodedPage,
expected_chain: &[String],
col_idx: usize,
) -> Result<()> {
use crate::decoder::PageEncoding;
use lance_core::Error;
let mut actual_chain = Vec::new();
match &page.description {
PageEncoding::Structural(layout) => {
collect_page_encoding(layout, &mut actual_chain)?;
if actual_chain.is_empty()
&& page.data.is_empty()
&& let Some(crate::format::pb21::page_layout::Layout::ConstantLayout(cl)) =
layout.layout.as_ref()
&& cl.inline_value.is_none()
{
return Ok(());
}
}
PageEncoding::Legacy(_) => {
}
}
for expected in expected_chain {
if !actual_chain.iter().any(|actual| actual.contains(expected)) {
return Err(Error::invalid_input_source(
format!(
"Column {} expected encoding chain {:?} but got {:?}",
col_idx, expected_chain, actual_chain
)
.into(),
));
}
}
Ok(())
}
pub async fn check_round_trip_encoding_of_data(
data: Vec<Arc<dyn Array>>,
test_cases: &TestCases,
metadata: HashMap<String, String>,
) {
check_round_trip_encoding_of_data_with_expected(data, None, test_cases, metadata).await
}
pub async fn check_round_trip_encoding_of_data_with_expected(
data: Vec<Arc<dyn Array>>,
expected_override: Option<Arc<dyn Array>>,
test_cases: &TestCases,
metadata: HashMap<String, String>,
) {
let example_data = data.first().expect("Data must have at least one array");
let mut field = Field::new("", example_data.data_type().clone(), true);
field = field.with_metadata(metadata);
let lance_field = lance_core::datatypes::Field::try_from(&field).unwrap();
for file_version in test_cases.get_versions() {
for page_size in test_cases.page_sizes.iter() {
let encoding_strategy = default_encoding_strategy(file_version);
let mut column_index_seq = ColumnIndexSequence::default();
let encoding_options = EncodingOptions {
cache_bytes_per_column: *page_size,
max_page_bytes: test_cases.get_max_page_size(),
keep_original_array: true,
buffer_alignment: MIN_PAGE_BUFFER_ALIGNMENT,
version: file_version,
};
let encoder = encoding_strategy
.create_field_encoder(
encoding_strategy.as_ref(),
&lance_field,
&mut column_index_seq,
&encoding_options,
)
.unwrap();
info!(
"Testing round trip encoding of data with file version {} and page size {}",
file_version, page_size
);
check_round_trip_encoding_inner(
encoder,
&field,
data.clone(),
expected_override.clone(),
test_cases,
file_version,
)
.await
}
}
}
struct SimulatedWriter {
page_infos: Vec<Vec<PageInfo>>,
encoded_data: BytesMut,
}
impl SimulatedWriter {
fn new(num_columns: u32) -> Self {
let mut page_infos = Vec::with_capacity(num_columns as usize);
page_infos.resize_with(num_columns as usize, Default::default);
Self {
page_infos,
encoded_data: BytesMut::new(),
}
}
fn write_buffer(&mut self, buffer: LanceBuffer) -> (u64, u64) {
let offset = self.encoded_data.len() as u64;
self.encoded_data.extend_from_slice(&buffer);
let size = self.encoded_data.len() as u64 - offset;
let pad_bytes = pad_bytes::<TEST_ALIGNMENT>(self.encoded_data.len());
self.encoded_data.extend(std::iter::repeat_n(0, pad_bytes));
(offset, size)
}
fn write_lance_buffer(&mut self, buffer: LanceBuffer) {
self.encoded_data.extend_from_slice(&buffer);
let pad_bytes = pad_bytes::<TEST_ALIGNMENT>(self.encoded_data.len());
self.encoded_data.extend(std::iter::repeat_n(0, pad_bytes));
}
fn write_page(&mut self, encoded_page: EncodedPage) {
trace!("Encoded page {:?}", encoded_page);
let page_buffers = encoded_page.data;
let page_encoding = encoded_page.description;
let buffer_offsets_and_sizes = page_buffers
.into_iter()
.map(|b| {
let (offset, size) = self.write_buffer(b);
trace!("Encoded buffer offset={} size={}", offset, size);
(offset, size)
})
.collect::<Vec<_>>();
let page_info = PageInfo {
num_rows: encoded_page.num_rows,
encoding: page_encoding,
buffer_offsets_and_sizes: Arc::from(buffer_offsets_and_sizes),
priority: encoded_page.row_number,
};
let col_idx = encoded_page.column_idx as usize;
self.page_infos[col_idx].push(page_info);
}
fn new_external_buffers(&self) -> OutOfLineBuffers {
OutOfLineBuffers::new(self.encoded_data.len() as u64, MIN_PAGE_BUFFER_ALIGNMENT)
}
}
async fn check_round_trip_encoding_inner(
mut encoder: Box<dyn FieldEncoder>,
field: &Field,
data: Vec<Arc<dyn Array>>,
expected_override: Option<Arc<dyn Array>>,
test_cases: &TestCases,
file_version: LanceFileVersion,
) {
let mut writer = SimulatedWriter::new(encoder.num_columns());
let log_page = |encoded_page: &EncodedPage| {
debug!(
"Encoded page on column {} with {} rows and start row {} and buffer sizes [{}]",
encoded_page.column_idx,
encoded_page.num_rows,
encoded_page.row_number,
encoded_page
.data
.iter()
.map(|buf| buf.len().to_string())
.collect::<Vec<_>>()
.join(", ")
);
};
let mut row_number = 0;
for arr in &data {
let mut external_buffers = writer.new_external_buffers();
let repdef = RepDefBuilder::default();
let num_rows = arr.len() as u64;
let encode_tasks = encoder
.maybe_encode(
arr.clone(),
&mut external_buffers,
repdef,
row_number,
num_rows,
)
.unwrap();
for buffer in external_buffers.take_buffers() {
writer.write_lance_buffer(buffer);
}
for encode_task in encode_tasks {
let encoded_page = encode_task.await.unwrap();
log_page(&encoded_page);
if file_version >= LanceFileVersion::V2_1
&& let Some(ref expected) = test_cases.expected_encoding
{
verify_page_encoding(&encoded_page, expected, encoded_page.column_idx as usize)
.unwrap();
}
writer.write_page(encoded_page);
}
row_number += arr.len() as u64;
}
let mut external_buffers = writer.new_external_buffers();
let encode_tasks = encoder.flush(&mut external_buffers).unwrap();
for buffer in external_buffers.take_buffers() {
writer.write_lance_buffer(buffer);
}
for task in encode_tasks {
let encoded_page = task.await.unwrap();
log_page(&encoded_page);
if file_version >= LanceFileVersion::V2_1
&& let Some(ref expected) = test_cases.expected_encoding
{
verify_page_encoding(&encoded_page, expected, encoded_page.column_idx as usize)
.unwrap();
}
writer.write_page(encoded_page);
}
let mut external_buffers = writer.new_external_buffers();
let encoded_columns = encoder.finish(&mut external_buffers).await.unwrap();
test_cases.verify_encoding(&encoded_columns, &file_version);
for buffer in external_buffers.take_buffers() {
writer.write_lance_buffer(buffer);
}
let mut column_infos = Vec::new();
for (col_idx, encoded_column) in encoded_columns.into_iter().enumerate() {
for page in encoded_column.final_pages {
writer.write_page(page);
}
let col_buffer_off_and_size = encoded_column
.column_buffers
.into_iter()
.map(|b| writer.write_buffer(b))
.collect::<Vec<_>>();
let column_info = ColumnInfo::new(
col_idx as u32,
Arc::from(std::mem::take(&mut writer.page_infos[col_idx])),
col_buffer_off_and_size,
encoded_column.encoding,
);
column_infos.push(Arc::new(column_info));
}
let encoded_data = writer.encoded_data.freeze();
let scheduler = Arc::new(SimulatedScheduler::new(encoded_data)) as Arc<dyn EncodingsIo>;
let num_rows = data.iter().map(|arr| arr.len() as u64).sum::<u64>();
let concat_data = if test_cases.skip_validation {
None
} else if let Some(DataType::Struct(_)) = data.first().map(|datum| datum.data_type()) {
let capacities = Capacities::Array(num_rows as usize);
let array_data: Vec<_> = data.iter().map(|a| a.to_data()).collect::<Vec<_>>();
let array_data = array_data.iter().collect();
let mut mutable = MutableArrayData::with_capacities(array_data, false, capacities);
for (i, a) in data.iter().enumerate() {
mutable.extend(i, 0, a.len())
}
Some(make_array(mutable.freeze()))
} else {
Some(concat(&data.iter().map(|arr| arr.as_ref()).collect::<Vec<_>>()).unwrap())
};
let expected_data = expected_override.clone().or_else(|| concat_data.clone());
let is_structural_encoding = file_version >= LanceFileVersion::V2_1;
let decode_field = if is_structural_encoding {
let mut lance_field = lance_core::datatypes::Field::try_from(field).unwrap();
if lance_field.is_blob() && matches!(lance_field.data_type(), DataType::Struct(_)) {
lance_field.unloaded_mut();
let mut arrow_field = ArrowField::from(&lance_field);
let mut metadata = arrow_field.metadata().clone();
metadata.insert("lance-encoding:packed".to_string(), "true".to_string());
arrow_field = arrow_field.with_metadata(metadata);
arrow_field
} else {
field.clone()
}
} else {
field.clone()
};
let schema = Schema::new(vec![decode_field]);
debug!("Testing full decode");
let scheduler_copy = scheduler.clone();
test_decode(
num_rows,
test_cases.batch_size,
&schema,
&column_infos,
expected_data.clone(),
scheduler_copy.clone(),
is_structural_encoding,
|mut decode_scheduler, tx| {
async move {
decode_scheduler.schedule_range(
0..num_rows,
&FilterExpression::no_filter(),
tx,
scheduler_copy,
)
}
.boxed()
},
)
.await;
for range in &test_cases.ranges {
debug!("Testing decode of range {:?}", range);
let num_rows = range.end - range.start;
let expected = expected_data
.as_ref()
.map(|arr| arr.slice(range.start as usize, num_rows as usize));
let scheduler = scheduler.clone();
let range = range.clone();
test_decode(
num_rows,
test_cases.batch_size,
&schema,
&column_infos,
expected,
scheduler.clone(),
is_structural_encoding,
|mut decode_scheduler, tx| {
async move {
decode_scheduler.schedule_range(
range,
&FilterExpression::no_filter(),
tx,
scheduler,
)
}
.boxed()
},
)
.await;
}
for indices in &test_cases.indices {
if indices.len() == 1 {
debug!("Testing decode of index {}", indices[0]);
} else {
debug!(
"Testing decode of {} indices spread across range [{}..{}]",
indices.len(),
indices[0],
indices[indices.len() - 1]
);
}
let num_rows = indices.len() as u64;
let indices_arr = UInt64Array::from(indices.clone());
let is_empty_struct = if let DataType::Struct(fields) = field.data_type() {
fields.is_empty()
} else {
false
};
let expected = if is_empty_struct {
Some(Arc::new(StructArray::new_empty_fields(indices_arr.len(), None)) as Arc<dyn Array>)
} else {
concat_data.as_ref().map(|concat_data| {
arrow_select::take::take(&concat_data, &indices_arr, None).unwrap()
})
};
let scheduler = scheduler.clone();
let indices = indices.clone();
test_decode(
num_rows,
test_cases.batch_size,
&schema,
&column_infos,
expected,
scheduler.clone(),
is_structural_encoding,
|mut decode_scheduler, tx| {
async move {
decode_scheduler.schedule_take(
&indices,
&FilterExpression::no_filter(),
tx,
scheduler,
)
}
.boxed()
},
)
.await;
}
}
const NUM_RANDOM_ROWS: u32 = 10000;
async fn check_round_trip_random(
encoder_factory: impl Fn(LanceFileVersion) -> Box<dyn FieldEncoder>,
field: Field,
array_generator_provider: Box<dyn ArrayGeneratorProvider>,
test_cases: &TestCases,
) {
for null_rate in [None, Some(0.5), Some(1.0)] {
for use_slicing in [false, true] {
for file_version in test_cases.get_versions() {
if null_rate != Some(1.0) && matches!(field.data_type(), DataType::Null) {
continue;
}
let field = if null_rate.is_some() {
if !supports_nulls(field.data_type(), file_version) {
continue;
}
field.clone().with_nullable(true)
} else {
field.clone().with_nullable(false)
};
for num_ingest_batches in [1, 5, 10] {
let rows_per_batch = NUM_RANDOM_ROWS / num_ingest_batches;
let mut data = Vec::new();
if use_slicing {
let mut generator =
gen_batch().anon_col(array_generator_provider.provide());
if let Some(null_rate) = null_rate {
if !matches!(field.data_type(), DataType::Null) {
generator.with_random_nulls(null_rate);
}
}
let all_data = generator
.into_batch_rows(RowCount::from(10000))
.unwrap()
.column(0)
.clone();
let mut offset = 0;
for _ in 0..num_ingest_batches {
data.push(all_data.slice(offset, rows_per_batch as usize));
offset += rows_per_batch as usize;
}
} else {
for i in 0..num_ingest_batches {
let mut generator = gen_batch()
.with_seed(Seed::from(i as u64))
.anon_col(array_generator_provider.provide());
if let Some(null_rate) = null_rate {
if !matches!(field.data_type(), DataType::Null) {
generator.with_random_nulls(null_rate);
}
}
let arr = generator
.into_batch_rows(RowCount::from(rows_per_batch as u64))
.unwrap()
.column(0)
.clone();
data.push(arr);
}
}
info!(
"Testing version {} with {} rows divided across {} batches for {} rows per batch with null_rate={:?} and use_slicing={}",
file_version,
NUM_RANDOM_ROWS,
num_ingest_batches,
rows_per_batch,
null_rate,
use_slicing
);
check_round_trip_encoding_inner(
encoder_factory(file_version),
&field,
data,
None,
test_cases,
file_version,
)
.await
}
}
}
}
}