use std::{collections::HashMap, sync::Arc};
use arrow_array::{Array, ArrayRef, RecordBatch};
use arrow_schema::DataType;
use bytes::{Bytes, BytesMut};
use futures::future::BoxFuture;
use lance_core::datatypes::{Field, Schema};
use lance_core::error::LanceOptionExt;
use lance_core::utils::bit::{is_pwr_two, pad_bytes_to};
use lance_core::{Error, Result};
use crate::buffer::LanceBuffer;
use crate::compression::{CompressionStrategy, DefaultCompressionStrategy};
use crate::compression_config::CompressionParams;
use crate::decoder::PageEncoding;
use crate::encodings::logical::blob::{BlobStructuralEncoder, BlobV2StructuralEncoder};
use crate::encodings::logical::fixed_size_list::FixedSizeListStructuralEncoder;
use crate::encodings::logical::list::ListStructuralEncoder;
use crate::encodings::logical::map::MapStructuralEncoder;
use crate::encodings::logical::primitive::PrimitiveStructuralEncoder;
use crate::encodings::logical::r#struct::StructStructuralEncoder;
use crate::repdef::RepDefBuilder;
use crate::version::LanceFileVersion;
use crate::{
decoder::{ColumnInfo, PageInfo},
format::pb,
};
/// Minimum alignment, in bytes, required for page buffers written to a file
/// (enforced by `encode_batch` on `EncodingOptions::buffer_alignment`).
pub const MIN_PAGE_BUFFER_ALIGNMENT: u64 = 8;
/// A single encoded page of data, ready to be written to the file.
#[derive(Debug)]
pub struct EncodedPage {
    /// The encoded buffers making up the page, in write order.
    pub data: Vec<LanceBuffer>,
    /// A description of the encoding used, stored so readers can decode the page.
    pub description: PageEncoding,
    /// The number of rows contained in the page.
    pub num_rows: u64,
    /// The row number of the first row in the page.  Used as the page's
    /// `priority` when the page table is built (see `write_page_to_data_buffer`).
    pub row_number: u64,
    /// The index of the column this page belongs to.
    pub column_idx: u32,
}
/// The finished output for one column: column-level buffers, the column
/// encoding description, and any pages emitted at finish time.
pub struct EncodedColumn {
    /// Column-level buffers shared across all pages of the column.
    pub column_buffers: Vec<LanceBuffer>,
    /// The protobuf description of the column encoding.
    pub encoding: pb::ColumnEncoding,
    /// Pages produced when the encoder was finished (in addition to any
    /// pages produced incrementally during encoding).
    pub final_pages: Vec<EncodedPage>,
}
impl Default for EncodedColumn {
    /// An empty column using the plain "values" column encoding.
    fn default() -> Self {
        Self {
            column_buffers: Default::default(),
            // Default to the basic values encoding; callers override as needed.
            encoding: pb::ColumnEncoding {
                column_encoding: Some(pb::column_encoding::ColumnEncoding::Values(())),
            },
            final_pages: Default::default(),
        }
    }
}
/// Tracks buffers written outside of any page ("out of line"), handing out
/// the file position assigned to each buffer while honoring alignment.
pub struct OutOfLineBuffers {
    // The absolute position the next buffer will be assigned.
    position: u64,
    // Alignment, in bytes, applied after each buffer.
    buffer_alignment: u64,
    // Buffers added so far, in assignment order.
    buffers: Vec<LanceBuffer>,
}
impl OutOfLineBuffers {
    /// Creates a tracker whose first buffer will be placed at `base_position`.
    pub fn new(base_position: u64, buffer_alignment: u64) -> Self {
        Self {
            position: base_position,
            buffer_alignment,
            buffers: Vec::new(),
        }
    }

    /// Registers `buffer` and returns the file position assigned to it.
    ///
    /// The cursor advances by the buffer's length plus any padding needed to
    /// keep the next buffer aligned to `buffer_alignment`.
    pub fn add_buffer(&mut self, buffer: LanceBuffer) -> u64 {
        let assigned_position = self.position;
        let padding = pad_bytes_to(buffer.len(), self.buffer_alignment as usize) as u64;
        self.position += buffer.len() as u64 + padding;
        self.buffers.push(buffer);
        assigned_position
    }

    /// Consumes the tracker, yielding the buffers in the order they were added.
    pub fn take_buffers(self) -> Vec<LanceBuffer> {
        self.buffers
    }

    /// Moves the write cursor to `position`.
    pub fn reset_position(&mut self, position: u64) {
        self.position = position;
    }
}
/// A spawned unit of encoding work that resolves to one encoded page.
pub type EncodeTask = BoxFuture<'static, Result<EncodedPage>>;
/// Encodes a single logical field, producing encode tasks as data arrives.
///
/// `Send` but not `Sync`: one encoder instance is driven from a single task.
pub trait FieldEncoder: Send {
    /// Offers new data to the encoder.
    ///
    /// The encoder may return zero or more encode tasks (it is free to hold
    /// data back until `flush` / `finish`).  `external_buffers` collects any
    /// buffers that must be written outside of pages, `repdef` carries
    /// repetition/definition level information, `row_number` is the row
    /// offset of `array`, and `num_rows` its length.
    fn maybe_encode(
        &mut self,
        array: ArrayRef,
        external_buffers: &mut OutOfLineBuffers,
        repdef: RepDefBuilder,
        row_number: u64,
        num_rows: u64,
    ) -> Result<Vec<EncodeTask>>;
    /// Forces any held-back data to be turned into encode tasks.
    fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result<Vec<EncodeTask>>;
    /// Completes the encoder, returning one `EncodedColumn` per physical column.
    fn finish(
        &mut self,
        external_buffers: &mut OutOfLineBuffers,
    ) -> BoxFuture<'_, Result<Vec<EncodedColumn>>>;
    /// The number of physical columns this field encoder produces.
    fn num_columns(&self) -> u32;
}
/// Hands out sequential column indices, recording which field id received
/// each index.
#[derive(Debug, Default)]
pub struct ColumnIndexSequence {
    // The next index to hand out.
    current_index: u32,
    // (field id, column index) pairs in assignment order.
    mapping: Vec<(u32, u32)>,
}
impl ColumnIndexSequence {
    /// Assigns the next column index to `field_id`, records the mapping, and
    /// returns the assigned index.
    pub fn next_column_index(&mut self, field_id: u32) -> u32 {
        let assigned = self.current_index;
        self.current_index += 1;
        self.mapping.push((field_id, assigned));
        assigned
    }

    /// Advances past one column index without recording any mapping.
    pub fn skip(&mut self) {
        self.current_index += 1;
    }
}
/// Options that control how data is encoded.
pub struct EncodingOptions {
    /// Budget of bytes that may be cached per column before data is flushed.
    /// (Interpreted by the field encoders; not read directly in this module.)
    pub cache_bytes_per_column: u64,
    /// Upper bound, in bytes, on the size of a single page.
    /// (Interpreted by the field encoders; not read directly in this module.)
    pub max_page_bytes: u64,
    /// If true, encoders keep references to the original input arrays.
    pub keep_original_array: bool,
    /// Alignment, in bytes, for buffers written to the file.  `encode_batch`
    /// requires a power of two that is at least `MIN_PAGE_BUFFER_ALIGNMENT`.
    pub buffer_alignment: u64,
    /// The Lance file format version being targeted.
    pub version: LanceFileVersion,
}
impl Default for EncodingOptions {
    /// Defaults: 8 MiB column cache, 32 MiB max page, keep original arrays,
    /// 64-byte buffer alignment, and the default file version.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        Self {
            cache_bytes_per_column: 8 * MIB,
            max_page_bytes: 32 * MIB,
            keep_original_array: true,
            buffer_alignment: 64,
            version: LanceFileVersion::default(),
        }
    }
}
impl EncodingOptions {
    /// Whether the targeted file version supports large chunks (added in 2.2).
    pub fn support_large_chunk(&self) -> bool {
        LanceFileVersion::V2_2 <= self.version
    }
}
/// A strategy that chooses which `FieldEncoder` to use for each field.
pub trait FieldEncodingStrategy: Send + Sync + std::fmt::Debug {
    /// Creates a field encoder for `field`.
    ///
    /// `encoding_strategy_root` is the top-level strategy; implementations
    /// may delegate back to it when building encoders for child fields.
    /// `column_index` assigns global column indices and `options` carries
    /// the encoding configuration.
    fn create_field_encoder(
        &self,
        encoding_strategy_root: &dyn FieldEncodingStrategy,
        field: &Field,
        column_index: &mut ColumnIndexSequence,
        options: &EncodingOptions,
    ) -> Result<Box<dyn FieldEncoder>>;
}
/// Creates the default encoding strategy for the given file version.
///
/// Version 2.0 uses the previous (core) encoding strategy; 2.1 and later use
/// the structural strategy.
///
/// # Panics
///
/// Panics if `version` resolves to the legacy (pre-2.0) format, which this
/// encoder does not support.
pub fn default_encoding_strategy(version: LanceFileVersion) -> Box<dyn FieldEncodingStrategy> {
    match version.resolve() {
        // A bare `panic!()` gave callers no diagnostic; state the reason.
        LanceFileVersion::Legacy => {
            panic!("no default encoding strategy for the legacy file version")
        }
        LanceFileVersion::V2_0 => Box::new(
            crate::previous::encoder::CoreFieldEncodingStrategy::new(version),
        ),
        _ => Box::new(StructuralEncodingStrategy::with_version(version)),
    }
}
pub fn default_encoding_strategy_with_params(
version: LanceFileVersion,
params: CompressionParams,
) -> Result<Box<dyn FieldEncodingStrategy>> {
match version.resolve() {
LanceFileVersion::Legacy | LanceFileVersion::V2_0 => Err(Error::invalid_input(
"Compression parameters are only supported in Lance file version 2.1 and later",
)),
_ => {
let compression_strategy =
Arc::new(DefaultCompressionStrategy::with_params(params).with_version(version));
Ok(Box::new(StructuralEncodingStrategy {
compression_strategy,
version,
}))
}
}
}
/// The default encoding strategy for Lance file versions 2.1 and later.
///
/// Walks the schema structurally, pairing each field with a structural
/// encoder and delegating value compression to `compression_strategy`.
#[derive(Debug)]
pub struct StructuralEncodingStrategy {
    /// Chooses how leaf values are compressed.
    pub compression_strategy: Arc<dyn CompressionStrategy>,
    /// The file version being targeted (gates 2.2+-only encodings).
    pub version: LanceFileVersion,
}
#[allow(clippy::derivable_impls)]
impl Default for StructuralEncodingStrategy {
    /// Default compression heuristics paired with the default file version.
    fn default() -> Self {
        Self {
            compression_strategy: Arc::new(DefaultCompressionStrategy::new()),
            version: LanceFileVersion::default(),
        }
    }
}
impl StructuralEncodingStrategy {
    /// Creates a strategy targeting `version`, with the default compression
    /// strategy configured for that same version.
    pub fn with_version(version: LanceFileVersion) -> Self {
        Self {
            compression_strategy: Arc::new(DefaultCompressionStrategy::new().with_version(version)),
            version,
        }
    }

    /// Returns true if `data_type` is handled by the primitive structural
    /// encoder (fixed-width types, binary, and strings).
    ///
    /// A `FixedSizeList` counts as primitive when its innermost item type is
    /// itself primitive (checked recursively).
    fn is_primitive_type(data_type: &DataType) -> bool {
        match data_type {
            // Recurse through fixed-size lists down to the item type.
            DataType::FixedSizeList(inner, _) => Self::is_primitive_type(inner.data_type()),
            _ => matches!(
                data_type,
                DataType::Boolean
                    | DataType::Date32
                    | DataType::Date64
                    | DataType::Decimal128(_, _)
                    | DataType::Decimal256(_, _)
                    | DataType::Duration(_)
                    | DataType::Float16
                    | DataType::Float32
                    | DataType::Float64
                    | DataType::Int16
                    | DataType::Int32
                    | DataType::Int64
                    | DataType::Int8
                    | DataType::Interval(_)
                    | DataType::Null
                    | DataType::Time32(_)
                    | DataType::Time64(_)
                    | DataType::Timestamp(_, _)
                    | DataType::UInt16
                    | DataType::UInt32
                    | DataType::UInt64
                    | DataType::UInt8
                    | DataType::FixedSizeBinary(_)
                    | DataType::Binary
                    | DataType::LargeBinary
                    | DataType::Utf8
                    | DataType::LargeUtf8,
            ),
        }
    }

    /// Recursively builds a field encoder for `field`.
    ///
    /// `root_field_metadata` is the metadata of the top-level field; it is
    /// threaded down unchanged so every primitive leaf encoder receives it.
    fn do_create_field_encoder(
        &self,
        _encoding_strategy_root: &dyn FieldEncodingStrategy,
        field: &Field,
        column_index: &mut ColumnIndexSequence,
        options: &EncodingOptions,
        root_field_metadata: &HashMap<String, String>,
    ) -> Result<Box<dyn FieldEncoder>> {
        let data_type = field.data_type();
        // Blob-flagged fields bypass the normal type dispatch below.
        if field.is_blob() {
            match data_type {
                // Blob v1: raw binary payloads.
                DataType::Binary | DataType::LargeBinary => {
                    return Ok(Box::new(BlobStructuralEncoder::new(
                        field,
                        column_index.next_column_index(field.id as u32),
                        options,
                        self.compression_strategy.clone(),
                    )?));
                }
                // Blob v2: struct-shaped blob input, available from 2.2 on.
                DataType::Struct(_) if self.version >= LanceFileVersion::V2_2 => {
                    return Ok(Box::new(BlobV2StructuralEncoder::new(
                        field,
                        column_index.next_column_index(field.id as u32),
                        options,
                        self.compression_strategy.clone(),
                    )?));
                }
                DataType::Struct(_) => {
                    return Err(Error::invalid_input_source(
                        "Blob v2 struct input requires file version >= 2.2".into(),
                    ));
                }
                _ => {
                    return Err(Error::invalid_input_source(
                        format!(
                            "Blob encoding only supports Binary/LargeBinary or v2 Struct, got {}",
                            data_type
                        )
                        .into(),
                    ));
                }
            }
        }
        if Self::is_primitive_type(&data_type) {
            // Primitive leaves get a single primitive structural encoder.
            Ok(Box::new(PrimitiveStructuralEncoder::try_new(
                options,
                self.compression_strategy.clone(),
                column_index.next_column_index(field.id as u32),
                field.clone(),
                Arc::new(root_field_metadata.clone()),
            )?))
        } else {
            match data_type {
                DataType::List(_) | DataType::LargeList(_) => {
                    // A list wraps one child encoder for its item field.
                    let child = field.children.first().expect_ok()?;
                    let child_encoder = self.do_create_field_encoder(
                        _encoding_strategy_root,
                        child,
                        column_index,
                        options,
                        root_field_metadata,
                    )?;
                    Ok(Box::new(ListStructuralEncoder::new(
                        options.keep_original_array,
                        child_encoder,
                    )))
                }
                // FixedSizeList<Struct> is not primitive (the check above only
                // accepts primitive item types) and requires format 2.2+.
                DataType::FixedSizeList(inner, _)
                    if matches!(inner.data_type(), DataType::Struct(_)) =>
                {
                    if self.version < LanceFileVersion::V2_2 {
                        return Err(Error::not_supported_source(format!(
                            "FixedSizeList<Struct> is only supported in Lance file format 2.2+, current version: {}",
                            self.version
                        )
                        .into()));
                    }
                    let child = field.children.first().expect_ok()?;
                    let child_encoder = self.do_create_field_encoder(
                        _encoding_strategy_root,
                        child,
                        column_index,
                        options,
                        root_field_metadata,
                    )?;
                    Ok(Box::new(FixedSizeListStructuralEncoder::new(
                        options.keep_original_array,
                        child_encoder,
                    )))
                }
                DataType::Map(_, keys_sorted) => {
                    // Sorted-key maps are not supported.
                    if keys_sorted {
                        return Err(Error::not_supported_source(format!("Map data type is not supported with keys_sorted=true now, current value is {}", keys_sorted).into()));
                    }
                    if self.version < LanceFileVersion::V2_2 {
                        return Err(Error::not_supported_source(format!(
                            "Map data type is only supported in Lance file format 2.2+, current version: {}",
                            self.version
                        )
                        .into()));
                    }
                    // Validate the entries child (a Struct<key, value> per the
                    // Arrow Map spec) before recursing into it.
                    let entries_child = field.children.first().ok_or_else(|| {
                        Error::schema("Map should have an entries child".to_string())
                    })?;
                    let DataType::Struct(struct_fields) = entries_child.data_type() else {
                        return Err(Error::schema(
                            "Map entries field must be a Struct<key, value>".to_string(),
                        ));
                    };
                    if struct_fields.len() < 2 {
                        return Err(Error::schema(
                            "Map entries struct must contain both key and value fields".to_string(),
                        ));
                    }
                    let key_field = &struct_fields[0];
                    // Arrow requires map keys to be non-nullable.
                    if key_field.is_nullable() {
                        return Err(Error::schema(format!(
                            "Map key field '{}' must be non-nullable according to Arrow Map specification",
                            key_field.name()
                        )));
                    }
                    let child_encoder = self.do_create_field_encoder(
                        _encoding_strategy_root,
                        entries_child,
                        column_index,
                        options,
                        root_field_metadata,
                    )?;
                    Ok(Box::new(MapStructuralEncoder::new(
                        options.keep_original_array,
                        child_encoder,
                    )))
                }
                DataType::Struct(fields) => {
                    if field.is_packed_struct() || fields.is_empty() {
                        // Packed (or empty) structs are encoded as one
                        // primitive column instead of one column per child.
                        Ok(Box::new(PrimitiveStructuralEncoder::try_new(
                            options,
                            self.compression_strategy.clone(),
                            column_index.next_column_index(field.id as u32),
                            field.clone(),
                            Arc::new(root_field_metadata.clone()),
                        )?))
                    } else {
                        // Otherwise recurse into each child field.
                        let children_encoders = field
                            .children
                            .iter()
                            .map(|field| {
                                self.do_create_field_encoder(
                                    _encoding_strategy_root,
                                    field,
                                    column_index,
                                    options,
                                    root_field_metadata,
                                )
                            })
                            .collect::<Result<Vec<_>>>()?;
                        Ok(Box::new(StructStructuralEncoder::new(
                            options.keep_original_array,
                            children_encoders,
                        )))
                    }
                }
                DataType::Dictionary(_, value_type) => {
                    // Dictionaries with primitive value types are encoded by
                    // the primitive encoder; nested value types are rejected.
                    if Self::is_primitive_type(&value_type) {
                        Ok(Box::new(PrimitiveStructuralEncoder::try_new(
                            options,
                            self.compression_strategy.clone(),
                            column_index.next_column_index(field.id as u32),
                            field.clone(),
                            Arc::new(root_field_metadata.clone()),
                        )?))
                    } else {
                        Err(Error::not_supported_source(format!("cannot encode a dictionary column whose value type is a logical type ({})", value_type).into()))
                    }
                }
                _ => todo!("Implement encoding for field {}", field),
            }
        }
    }
}
impl FieldEncodingStrategy for StructuralEncodingStrategy {
    /// Builds an encoder for `field`, using the field's own metadata as the
    /// root metadata for the recursive descent.
    fn create_field_encoder(
        &self,
        encoding_strategy_root: &dyn FieldEncodingStrategy,
        field: &Field,
        column_index: &mut ColumnIndexSequence,
        options: &EncodingOptions,
    ) -> Result<Box<dyn FieldEncoder>> {
        let root_metadata = &field.metadata;
        self.do_create_field_encoder(
            encoding_strategy_root,
            field,
            column_index,
            options,
            root_metadata,
        )
    }
}
/// Encodes a record batch: one field encoder per top-level field.
pub struct BatchEncoder {
    /// One encoder per top-level field of the schema, in schema order.
    pub field_encoders: Vec<Box<dyn FieldEncoder>>,
    /// (field id, column index) pairs recorded while columns were assigned.
    pub field_id_to_column_index: Vec<(u32, u32)>,
}
impl BatchEncoder {
    /// Creates an encoder for every top-level field in `schema` using
    /// `strategy`, assigning column indices in schema order.
    ///
    /// # Errors
    ///
    /// Propagates any error from `strategy.create_field_encoder`.
    pub fn try_new(
        schema: &Schema,
        strategy: &dyn FieldEncodingStrategy,
        options: &EncodingOptions,
    ) -> Result<Self> {
        // NOTE: the previous version also accumulated a `col_idx` counter
        // here, but it was never read — removed as dead code.
        let mut col_idx_sequence = ColumnIndexSequence::default();
        let field_encoders = schema
            .fields
            .iter()
            .map(|field| {
                strategy.create_field_encoder(strategy, field, &mut col_idx_sequence, options)
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            field_encoders,
            field_id_to_column_index: col_idx_sequence.mapping,
        })
    }

    /// Total number of physical columns across all field encoders.
    pub fn num_columns(&self) -> u32 {
        self.field_encoders
            .iter()
            .map(|field_encoder| field_encoder.num_columns())
            .sum::<u32>()
    }
}
/// An in-memory encoded batch: the concatenated encoded data plus the page
/// table needed to decode it.
#[derive(Debug)]
pub struct EncodedBatch {
    /// All encoded buffers, concatenated in write order.
    pub data: Bytes,
    /// Per-column metadata (buffer offsets/sizes and page infos).
    pub page_table: Vec<Arc<ColumnInfo>>,
    /// The schema of the encoded data.
    pub schema: Arc<Schema>,
    /// The column index assigned to each top-level field, in field order.
    pub top_level_columns: Vec<u32>,
    /// Number of rows in the batch.
    pub num_rows: u64,
}
/// Appends all of a page's buffers to `data_buffer` and returns the
/// corresponding `PageInfo` (buffer offsets/sizes, encoding, row count, and
/// priority).
fn write_page_to_data_buffer(page: EncodedPage, data_buffer: &mut BytesMut) -> PageInfo {
    let buffer_offsets_and_sizes: Vec<(u64, u64)> = page
        .data
        .into_iter()
        .map(|buffer| {
            let offset = data_buffer.len() as u64;
            data_buffer.extend_from_slice(&buffer);
            let size = data_buffer.len() as u64 - offset;
            (offset, size)
        })
        .collect();
    PageInfo {
        buffer_offsets_and_sizes: Arc::from(buffer_offsets_and_sizes),
        encoding: page.description,
        num_rows: page.num_rows,
        // Pages are prioritized by the row number of their first row.
        priority: page.row_number,
    }
}
/// Encodes `batch` into a single in-memory [`EncodedBatch`].
///
/// All encoded buffers are concatenated into one `Bytes` value and a page
/// table is built describing where each column's buffers and pages live.
///
/// # Errors
///
/// Returns an error if `options.buffer_alignment` is not a power of two of
/// at least `MIN_PAGE_BUFFER_ALIGNMENT`, or if any encoder fails.
pub async fn encode_batch(
    batch: &RecordBatch,
    schema: Arc<Schema>,
    encoding_strategy: &dyn FieldEncodingStrategy,
    options: &EncodingOptions,
) -> Result<EncodedBatch> {
    // Validate alignment up front; all offset arithmetic below relies on it.
    if !is_pwr_two(options.buffer_alignment) || options.buffer_alignment < MIN_PAGE_BUFFER_ALIGNMENT
    {
        return Err(Error::invalid_input_source(
            format!(
                "buffer_alignment must be a power of two and at least {}",
                MIN_PAGE_BUFFER_ALIGNMENT
            )
            .into(),
        ));
    }
    let mut data_buffer = BytesMut::new();
    // The working schema is derived from the batch itself, so encoders see
    // exactly the batch's fields; the `schema` argument is stored as-is.
    let lance_schema = Schema::try_from(batch.schema().as_ref())?;
    // NOTE(review): keep_original_array is forced to true here — presumably
    // fine because encoding completes within this call; confirm.
    let options = EncodingOptions {
        keep_original_array: true,
        ..*options
    };
    let batch_encoder = BatchEncoder::try_new(&lance_schema, encoding_strategy, &options)?;
    let mut page_table = Vec::new();
    // Column indices are global; this tracks how many columns earlier
    // fields consumed.
    let mut col_idx_offset = 0;
    for (arr, mut encoder) in batch.columns().iter().zip(batch_encoder.field_encoders) {
        let mut external_buffers =
            OutOfLineBuffers::new(data_buffer.len() as u64, options.buffer_alignment);
        let repdef = RepDefBuilder::default();
        let encoder = encoder.as_mut();
        let num_rows = arr.len() as u64;
        // Feed the entire column at once, then flush to force out any data
        // the encoder was still holding back.
        let mut tasks =
            encoder.maybe_encode(arr.clone(), &mut external_buffers, repdef, 0, num_rows)?;
        tasks.extend(encoder.flush(&mut external_buffers)?);
        // Out-of-line buffers are written before the pages they belong to,
        // matching the positions handed out by OutOfLineBuffers.
        for buffer in external_buffers.take_buffers() {
            data_buffer.extend_from_slice(&buffer);
        }
        // Await each encode task and collect its page, grouped by column.
        let mut pages = HashMap::<u32, Vec<PageInfo>>::new();
        for task in tasks {
            let encoded_page = task.await?;
            pages
                .entry(encoded_page.column_idx)
                .or_default()
                .push(write_page_to_data_buffer(encoded_page, &mut data_buffer));
        }
        // Finish the encoder to obtain column buffers and any final pages.
        let mut external_buffers =
            OutOfLineBuffers::new(data_buffer.len() as u64, options.buffer_alignment);
        let encoded_columns = encoder.finish(&mut external_buffers).await?;
        for buffer in external_buffers.take_buffers() {
            data_buffer.extend_from_slice(&buffer);
        }
        let num_columns = encoded_columns.len();
        for (col_idx, encoded_column) in encoded_columns.into_iter().enumerate() {
            // Shift the encoder-local index into the global column space.
            let col_idx = col_idx + col_idx_offset;
            let mut col_buffer_offsets_and_sizes = Vec::new();
            for buffer in encoded_column.column_buffers {
                let buffer_offset = data_buffer.len() as u64;
                data_buffer.extend_from_slice(&buffer);
                let size = data_buffer.len() as u64 - buffer_offset;
                col_buffer_offsets_and_sizes.push((buffer_offset, size));
            }
            for page in encoded_column.final_pages {
                pages
                    .entry(page.column_idx)
                    .or_default()
                    .push(write_page_to_data_buffer(page, &mut data_buffer));
            }
            // Take this column's pages out of the map (leaving an empty vec).
            let col_pages = std::mem::take(pages.entry(col_idx as u32).or_default());
            page_table.push(Arc::new(ColumnInfo {
                index: col_idx as u32,
                buffer_offsets_and_sizes: Arc::from(
                    col_buffer_offsets_and_sizes.into_boxed_slice(),
                ),
                page_infos: Arc::from(col_pages.into_boxed_slice()),
                encoding: encoded_column.encoding,
            }))
        }
        col_idx_offset += num_columns;
    }
    // One column index per top-level field, in assignment order.
    let top_level_columns = batch_encoder
        .field_id_to_column_index
        .iter()
        .map(|(_, idx)| *idx)
        .collect();
    Ok(EncodedBatch {
        data: data_buffer.freeze(),
        top_level_columns,
        page_table,
        schema,
        num_rows: batch.num_rows() as u64,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression_config::{CompressionFieldParams, CompressionParams};
    use arrow_schema::{DataType as ArrowDataType, Field as ArrowField, Fields as ArrowFields};

    /// Compression parameters build a structural strategy on 2.1+ and are
    /// rejected with a descriptive error on 2.0 and legacy.
    #[test]
    fn test_configured_encoding_strategy() {
        let mut params = CompressionParams::new();
        params.columns.insert(
            "*_id".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.5),
                compression: Some("lz4".to_string()),
                compression_level: None,
                bss: None,
                minichunk_size: None,
            },
        );
        // V2.1 accepts the parameters.
        let strategy =
            default_encoding_strategy_with_params(LanceFileVersion::V2_1, params.clone())
                .expect("Should succeed for V2.1");
        assert!(format!("{:?}", strategy).contains("StructuralEncodingStrategy"));
        assert!(format!("{:?}", strategy).contains("DefaultCompressionStrategy"));
        // V2.0 and Legacy are rejected.
        let err = default_encoding_strategy_with_params(LanceFileVersion::V2_0, params.clone())
            .expect_err("Should fail for V2.0");
        assert!(
            err.to_string()
                .contains("only supported in Lance file version 2.1")
        );
        let err = default_encoding_strategy_with_params(LanceFileVersion::Legacy, params)
            .expect_err("Should fail for Legacy");
        assert!(
            err.to_string()
                .contains("only supported in Lance file version 2.1")
        );
    }

    /// FixedSizeList<Struct> requires file format 2.2+; targeting 2.1 must
    /// produce a not-supported error.
    #[test]
    fn test_fixed_size_list_struct_requires_v2_2() {
        let list_item = ArrowField::new(
            "item",
            ArrowDataType::Struct(ArrowFields::from(vec![ArrowField::new(
                "x",
                ArrowDataType::Int32,
                true,
            )])),
            true,
        );
        let arrow_field = ArrowField::new(
            "list_struct",
            ArrowDataType::FixedSizeList(Arc::new(list_item), 2),
            true,
        );
        let field = Field::try_from(&arrow_field).unwrap();
        let strategy = StructuralEncodingStrategy::with_version(LanceFileVersion::V2_1);
        let mut column_index = ColumnIndexSequence::default();
        let options = EncodingOptions::default();
        let result = strategy.create_field_encoder(&strategy, &field, &mut column_index, &options);
        assert!(
            result.is_err(),
            "FixedSizeList<Struct> should be rejected for file version 2.1"
        );
        let err = result.err().unwrap();
        assert!(
            err.to_string()
                .contains("FixedSizeList<Struct> is only supported in Lance file format 2.2+")
        );
    }
}