use std::{
collections::{BinaryHeap, VecDeque},
ops::Range,
sync::Arc,
};
use arrow_array::{cast::AsArray, ArrayRef, StructArray};
use arrow_schema::{DataType, Fields};
use futures::{future::BoxFuture, FutureExt};
use log::trace;
use snafu::{location, Location};
use crate::{
decoder::{
DecodeArrayTask, DecoderReady, FieldScheduler, FilterExpression, LogicalPageDecoder,
NextDecodeTask, ScheduledScanLine, SchedulerContext, SchedulingJob,
},
encoder::{EncodeTask, EncodedArray, EncodedColumn, EncodedPage, FieldEncoder},
format::pb,
};
use lance_core::{Error, Result};
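
/// A child scheduling job together with bookkeeping about how many rows it
/// has scheduled so far, so that jobs can be prioritized in a [`BinaryHeap`].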
#[derive(Debug)]
struct SchedulingJobWithStatus<'a> {
col_idx: u32,
col_name: &'a str,
job: Box<dyn SchedulingJob + 'a>,
rows_scheduled: u64,
rows_remaining: u64,
}
impl<'a> PartialEq for SchedulingJobWithStatus<'a> {
fn eq(&self, other: &Self) -> bool {
self.col_idx == other.col_idx
}
}
impl<'a> Eq for SchedulingJobWithStatus<'a> {}
impl<'a> PartialOrd for SchedulingJobWithStatus<'a> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<'a> Ord for SchedulingJobWithStatus<'a> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
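        // Reverse the natural order so the max-heap pops the child that has
        // scheduled the fewest rows (the one that is furthest behind).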
other.rows_scheduled.cmp(&self.rows_scheduled)
}
}
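
/// A scheduling job for a struct field.  It repeatedly advances whichever
/// child job has scheduled the fewest rows until the struct as a whole has
/// made progress.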
#[derive(Debug)]
struct SimpleStructSchedulerJob<'a> {
scheduler: &'a SimpleStructScheduler,
children: BinaryHeap<SchedulingJobWithStatus<'a>>,
rows_scheduled: u64,
num_rows: u64,
initialized: bool,
}
impl<'a> SimpleStructSchedulerJob<'a> {
fn new(
scheduler: &'a SimpleStructScheduler,
children: Vec<Box<dyn SchedulingJob + 'a>>,
num_rows: u64,
) -> Self {
let children = children
.into_iter()
.enumerate()
.map(|(idx, job)| SchedulingJobWithStatus {
col_idx: idx as u32,
col_name: scheduler.child_fields[idx].name(),
job,
rows_scheduled: 0,
rows_remaining: num_rows,
})
.collect::<BinaryHeap<_>>();
Self {
scheduler,
children,
rows_scheduled: 0,
num_rows,
initialized: false,
}
}
}
impl<'a> SchedulingJob for SimpleStructSchedulerJob<'a> {
fn schedule_next(
&mut self,
mut context: &mut SchedulerContext,
top_level_row: u64,
) -> Result<ScheduledScanLine> {
let mut decoders = Vec::new();
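        // On the first call, emit the struct decoder itself so the child
        // decoders scheduled below have a parent to be delivered to.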
if !self.initialized {
let struct_decoder = Box::new(SimpleStructDecoder::new(
self.scheduler.child_fields.clone(),
self.num_rows,
));
let struct_decoder = context.locate_decoder(struct_decoder);
decoders.push(struct_decoder);
self.initialized = true;
}
let old_rows_scheduled = self.rows_scheduled;
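        // Keep scheduling children, always advancing the one that is furthest
        // behind, until the struct as a whole has made progress.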
while old_rows_scheduled == self.rows_scheduled {
let mut next_child = self.children.pop().unwrap();
trace!("Scheduling more rows for child {}", next_child.col_idx);
let scoped = context.push(next_child.col_name, next_child.col_idx);
let child_scan = next_child
.job
.schedule_next(scoped.context, top_level_row)?;
trace!(
"Scheduled {} rows for child {}",
child_scan.rows_scheduled,
next_child.col_idx
);
next_child.rows_scheduled += child_scan.rows_scheduled;
next_child.rows_remaining -= child_scan.rows_scheduled;
decoders.extend(child_scan.decoders);
self.children.push(next_child);
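            // The struct has scheduled as many rows as its least-scheduled
            // child, which now sits at the top of the heap.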
self.rows_scheduled = self.children.peek().unwrap().rows_scheduled;
context = scoped.pop();
}
let struct_rows_scheduled = self.rows_scheduled - old_rows_scheduled;
Ok(ScheduledScanLine {
decoders,
rows_scheduled: struct_rows_scheduled,
})
}
fn num_rows(&self) -> u64 {
self.num_rows
}
}
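
/// A [`FieldScheduler`] for struct fields.  Scheduling a struct simply
/// delegates to the child schedulers; every child must cover the same number
/// of rows.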
#[derive(Debug)]
pub struct SimpleStructScheduler {
children: Vec<Arc<dyn FieldScheduler>>,
child_fields: Fields,
num_rows: u64,
}
impl SimpleStructScheduler {
pub fn new(children: Vec<Arc<dyn FieldScheduler>>, child_fields: Fields) -> Self {
debug_assert!(!children.is_empty());
let num_rows = children[0].num_rows();
debug_assert!(children.iter().all(|child| child.num_rows() == num_rows));
Self {
children,
child_fields,
num_rows,
}
}
}
impl FieldScheduler for SimpleStructScheduler {
fn schedule_ranges<'a>(
&'a self,
ranges: &[Range<u64>],
filter: &FilterExpression,
) -> Result<Box<dyn SchedulingJob + 'a>> {
let child_schedulers = self
.children
.iter()
.map(|child| child.schedule_ranges(ranges, filter))
.collect::<Result<Vec<_>>>()?;
let num_rows = child_schedulers[0].num_rows();
Ok(Box::new(SimpleStructSchedulerJob::new(
self,
child_schedulers,
num_rows,
)))
}
fn num_rows(&self) -> u64 {
self.num_rows
}
}
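
/// Decoding state for a single child of a struct: the decoders that have been
/// scheduled for it, plus counts of how many rows are still unawaited and how
/// many have been awaited and are ready to drain.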
#[derive(Debug)]
struct ChildState {
scheduled: VecDeque<Box<dyn LogicalPageDecoder>>,
rows_unawaited: u64,
rows_available: u64,
field_index: u32,
}
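
/// A decode task that concatenates the output of several smaller tasks into a
/// single array.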
struct CompositeDecodeTask {
tasks: Vec<Box<dyn DecodeArrayTask>>,
num_rows: u64,
has_more: bool,
}
impl CompositeDecodeTask {
fn decode(self) -> Result<ArrayRef> {
let arrays = self
.tasks
.into_iter()
.map(|task| task.decode())
.collect::<Result<Vec<_>>>()?;
let array_refs = arrays.iter().map(|arr| arr.as_ref()).collect::<Vec<_>>();
Ok(arrow_select::concat::concat(&array_refs)?)
}
}
impl ChildState {
fn new(num_rows: u64, field_index: u32) -> Self {
Self {
scheduled: VecDeque::new(),
rows_unawaited: num_rows,
rows_available: 0,
field_index,
}
}
async fn wait(&mut self, num_rows: u64) -> Result<()> {
trace!(
"Struct child {} waiting for {} rows and {} are available already",
self.field_index,
num_rows,
self.rows_available
);
let mut remaining = num_rows.saturating_sub(self.rows_available);
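        // Walk the scheduled pages in order, awaiting only as many rows from
        // each as we still need.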
for next_decoder in &mut self.scheduled {
if next_decoder.unawaited() > 0 {
let rows_to_wait = remaining.min(next_decoder.unawaited());
trace!(
"Struct await an additional {} rows from the current page",
rows_to_wait
);
let previously_avail = next_decoder.avail();
next_decoder.wait(rows_to_wait).await?;
let newly_avail = next_decoder.avail() - previously_avail;
trace!("The await loaded {} rows", newly_avail);
self.rows_available += newly_avail;
self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
remaining -= rows_to_wait;
if remaining == 0 {
break;
}
}
}
if remaining > 0 {
            Err(Error::Internal {
                message: format!(
                    "The struct field at index {} is still waiting for {} rows but ran out of scheduled pages",
                    self.field_index, remaining
                ),
                location: location!(),
            })
} else {
Ok(())
}
}
fn drain(&mut self, num_rows: u64) -> Result<CompositeDecodeTask> {
trace!("Struct draining {} rows", num_rows);
debug_assert!(self.rows_available >= num_rows);
self.rows_available -= num_rows;
let mut remaining = num_rows;
let mut composite = CompositeDecodeTask {
tasks: Vec::new(),
num_rows: 0,
has_more: true,
};
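        // Drain from the front of the scheduled queue, discarding pages once
        // they are fully exhausted.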
while remaining > 0 {
let next = self.scheduled.front_mut().unwrap();
let rows_to_take = remaining.min(next.avail());
let next_task = next.drain(rows_to_take)?;
if next.avail() == 0 && next.unawaited() == 0 {
trace!("Completely drained page");
self.scheduled.pop_front();
}
remaining -= rows_to_take;
composite.tasks.push(next_task.task);
composite.num_rows += next_task.num_rows;
}
composite.has_more = self.rows_available != 0 || self.rows_unawaited != 0;
Ok(composite)
}
}
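
/// A [`LogicalPageDecoder`] for struct fields.  Waiting and draining simply
/// fan out to every child; the resulting struct array never carries its own
/// null buffer.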
#[derive(Debug)]
pub struct SimpleStructDecoder {
children: Vec<ChildState>,
child_fields: Fields,
data_type: DataType,
}
impl SimpleStructDecoder {
pub fn new(child_fields: Fields, num_rows: u64) -> Self {
let data_type = DataType::Struct(child_fields.clone());
Self {
children: child_fields
.iter()
.enumerate()
.map(|(idx, _)| ChildState::new(num_rows, idx as u32))
.collect(),
child_fields,
data_type,
}
}
}
impl LogicalPageDecoder for SimpleStructDecoder {
fn accept_child(&mut self, mut child: DecoderReady) -> Result<()> {
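        // The first element of the path is the index of our child; any
        // remaining path elements are routed further down to that child's
        // most recently scheduled decoder.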
let child_idx = child.path.pop_front().unwrap();
if child.path.is_empty() {
self.children[child_idx as usize]
.scheduled
.push_back(child.decoder);
} else {
            let intended = self.children[child_idx as usize]
                .scheduled
                .back_mut()
                .ok_or_else(|| Error::Internal {
                    message: format!(
                        "Decoder scheduled for child at index {} but we don't have any child at that index yet",
                        child_idx
                    ),
                    location: location!(),
                })?;
intended.accept_child(child)?;
}
Ok(())
}
fn wait(&mut self, num_rows: u64) -> BoxFuture<Result<()>> {
async move {
for child in self.children.iter_mut() {
child.wait(num_rows).await?;
}
Ok(())
}
.boxed()
}
fn drain(&mut self, num_rows: u64) -> Result<NextDecodeTask> {
let child_tasks = self
.children
.iter_mut()
.map(|child| child.drain(num_rows))
.collect::<Result<Vec<_>>>()?;
let num_rows = child_tasks[0].num_rows;
let has_more = child_tasks[0].has_more;
debug_assert!(child_tasks.iter().all(|task| task.num_rows == num_rows));
debug_assert!(child_tasks.iter().all(|task| task.has_more == has_more));
Ok(NextDecodeTask {
task: Box::new(SimpleStructDecodeTask {
children: child_tasks,
child_fields: self.child_fields.clone(),
}),
num_rows,
has_more,
})
}
fn avail(&self) -> u64 {
self.children
.iter()
.map(|c| c.rows_available)
.min()
.unwrap()
}
fn unawaited(&self) -> u64 {
self.children
.iter()
.map(|c| c.rows_unawaited)
.max()
.unwrap()
}
fn data_type(&self) -> &DataType {
&self.data_type
}
}
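
/// Decode task for a struct: decodes each child's composite task and zips the
/// results into a [`StructArray`].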
struct SimpleStructDecodeTask {
children: Vec<CompositeDecodeTask>,
child_fields: Fields,
}
impl DecodeArrayTask for SimpleStructDecodeTask {
fn decode(self: Box<Self>) -> Result<ArrayRef> {
let child_arrays = self
.children
.into_iter()
.map(|child| child.decode())
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(StructArray::try_new(
self.child_fields,
child_arrays,
None,
)?))
}
}
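
/// A [`FieldEncoder`] for struct fields.  Encoding is delegated to the child
/// encoders; on flush the struct also emits a header page of its own that
/// records the encoding and the number of rows seen but carries no buffers.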
pub struct StructFieldEncoder {
children: Vec<Box<dyn FieldEncoder>>,
column_index: u32,
num_rows_seen: u64,
}
impl StructFieldEncoder {
#[allow(dead_code)]
pub fn new(children: Vec<Box<dyn FieldEncoder>>, column_index: u32) -> Self {
Self {
children,
column_index,
num_rows_seen: 0,
}
}
}
impl FieldEncoder for StructFieldEncoder {
fn maybe_encode(&mut self, array: ArrayRef) -> Result<Vec<EncodeTask>> {
self.num_rows_seen += array.len() as u64;
let struct_array = array.as_struct();
let child_tasks = self
.children
.iter_mut()
.zip(struct_array.columns().iter())
.map(|(encoder, arr)| encoder.maybe_encode(arr.clone()))
.collect::<Result<Vec<_>>>()?;
Ok(child_tasks.into_iter().flatten().collect::<Vec<_>>())
}
fn flush(&mut self) -> Result<Vec<EncodeTask>> {
let child_tasks = self
.children
.iter_mut()
.map(|encoder| encoder.flush())
.collect::<Result<Vec<_>>>()?;
let mut child_tasks = child_tasks.into_iter().flatten().collect::<Vec<_>>();
let num_rows_seen = self.num_rows_seen;
let column_index = self.column_index;
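        // Emit the struct's own header page: no buffers, just the encoding
        // description and the number of rows seen so far.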
child_tasks.push(
std::future::ready(Ok(EncodedPage {
array: EncodedArray {
buffers: vec![],
encoding: pb::ArrayEncoding {
array_encoding: Some(pb::array_encoding::ArrayEncoding::Struct(
pb::SimpleStruct {},
)),
},
},
num_rows: num_rows_seen,
column_idx: column_index,
}))
.boxed(),
);
Ok(child_tasks)
}
fn num_columns(&self) -> u32 {
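        // One column per child plus one header column for the struct itself.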
self.children
.iter()
.map(|child| child.num_columns())
.sum::<u32>()
+ 1
}
fn finish(&mut self) -> BoxFuture<'_, Result<Vec<crate::encoder::EncodedColumn>>> {
async move {
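            // The struct contributes a placeholder column (its header page is
            // emitted separately by `flush`); the children's columns follow.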
            let mut columns = vec![EncodedColumn::default()];
for child in self.children.iter_mut() {
columns.extend(child.finish().await?);
}
Ok(columns)
}
.boxed()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{
builder::{Int32Builder, ListBuilder},
Array, ArrayRef, Int32Array, StructArray,
};
use arrow_schema::{DataType, Field, Fields};
use crate::testing::{
check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases,
};
#[test_log::test(tokio::test)]
async fn test_simple_struct() {
let data_type = DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
let field = Field::new("", data_type, false);
check_round_trip_encoding_random(field).await;
}
#[test_log::test(tokio::test)]
async fn test_struct_list() {
let data_type = DataType::Struct(Fields::from(vec![
Field::new(
"inner_list",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
Field::new("outer_int", DataType::Int32, true),
]));
let field = Field::new("row", data_type, false);
check_round_trip_encoding_random(field).await;
}
#[test_log::test(tokio::test)]
async fn test_complicated_struct() {
let data_type = DataType::Struct(Fields::from(vec![
Field::new("int", DataType::Int32, true),
Field::new(
"inner",
DataType::Struct(Fields::from(vec![
Field::new("inner_int", DataType::Int32, true),
Field::new(
"inner_list",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
])),
true,
),
Field::new("outer_binary", DataType::Binary, true),
]));
let field = Field::new("row", data_type, false);
check_round_trip_encoding_random(field).await;
}
#[test_log::test(tokio::test)]
async fn test_ragged_scheduling() {
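        // Round-trip the struct in odd-sized (437-row) slices to exercise
        // scheduling when writes do not divide the data evenly.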
let items_builder = Int32Builder::new();
let mut list_builder = ListBuilder::new(items_builder);
for _ in 0..10000 {
list_builder.append_null();
}
let list_array = Arc::new(list_builder.finish());
let int_array = Arc::new(Int32Array::from_iter_values(0..10000));
let fields = vec![
Field::new("", list_array.data_type().clone(), true),
Field::new("", int_array.data_type().clone(), true),
];
let struct_array = Arc::new(StructArray::new(
Fields::from(fields),
vec![list_array, int_array],
None,
)) as ArrayRef;
let struct_arrays = (0..10000)
.step_by(437)
.map(|offset| struct_array.slice(offset, 437.min(10000 - offset)))
.collect::<Vec<_>>();
check_round_trip_encoding_of_data(struct_arrays, &TestCases::default()).await;
}
}