use std::sync::Arc;
use arrow::array::{ArrayRef, BooleanBufferBuilder, UnionArray};
use arrow::buffer::{Buffer, NullBuffer};
use arrow::datatypes::UnionFields;
use snafu::ResultExt;
use crate::column::Column;
use crate::encoding::byte::ByteRleDecoder;
use crate::encoding::PrimitiveValueDecoder;
use crate::error::ArrowSnafu;
use crate::error::Result;
use crate::proto::stream::Kind;
use crate::stripe::Stripe;
use super::{array_decoder_factory, derive_present_vec, ArrayBatchDecoder, PresentDecoder};
/// Batch decoder that assembles Arrow `UnionArray`s for a union column of a stripe.
///
/// Each batch is built from a per-row type tag plus one child array per variant.
pub struct UnionArrayDecoder {
/// Union field definitions; cloned into every `UnionArray` this decoder builds.
fields: UnionFields,
/// One child decoder per variant, created by `array_decoder_factory` from the
/// column's children paired positionally with `fields`.
variants: Vec<Box<dyn ArrayBatchDecoder>>,
/// Decoder for the per-row type tags, read from the column's `Kind::Data`
/// stream through a `ByteRleDecoder`.
tags: Box<dyn PrimitiveValueDecoder<i8> + Send>,
/// Result of `PresentDecoder::from_stripe` for this column; combined with the
/// parent's validity via `derive_present_vec` when decoding.
present: Option<PresentDecoder>,
}
impl UnionArrayDecoder {
    /// Builds a union decoder for `column` from the streams of `stripe`.
    ///
    /// `fields` supplies the Arrow field for each variant; the column's children
    /// are paired positionally with these fields to create the child decoders.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `array_decoder_factory` while
    /// constructing a child decoder.
    pub fn new(column: &Column, fields: UnionFields, stripe: &Stripe) -> Result<Self> {
        let present = PresentDecoder::from_stripe(stripe, column);

        // The type tags are stored in the column's data stream.
        let tag_stream = stripe.stream_map().get(column, Kind::Data);
        let tag_decoder = Box::new(ByteRleDecoder::new(tag_stream));

        // One child decoder per (child column, union field) pair.
        let child_columns = column.children();
        let mut variants = Vec::with_capacity(child_columns.len());
        for (child, (_, field)) in child_columns.iter().zip(fields.iter()) {
            variants.push(array_decoder_factory(child, field.data_type(), stripe)?);
        }

        Ok(Self {
            fields,
            variants,
            tags: tag_decoder,
            present,
        })
    }
}
impl ArrayBatchDecoder for UnionArrayDecoder {
fn next_batch(
&mut self,
batch_size: usize,
parent_present: Option<&NullBuffer>,
) -> Result<ArrayRef> {
let present =
derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
let mut tags = vec![0; batch_size];
match &present {
Some(present) => {
self.tags.decode_spaced(&mut tags, present)?;
}
None => {
self.tags.decode(&mut tags)?;
}
}
let mut children_nullability = (0..self.variants.len())
.map(|index| {
let mut child_present = BooleanBufferBuilder::new(batch_size);
child_present.append_n(batch_size, false);
for idx in tags
.iter()
.enumerate()
.filter_map(|(idx, &tag)| (tag as usize == index).then_some(idx))
{
child_present.set_bit(idx, true);
}
child_present
})
.collect::<Vec<_>>();
if let Some(present) = &present {
let first_child = &mut children_nullability[0];
for idx in present
.iter()
.enumerate()
.filter_map(|(idx, parent_present)| (!parent_present).then_some(idx))
{
first_child.set_bit(idx, false);
}
}
let child_arrays = self
.variants
.iter_mut()
.zip(children_nullability)
.map(|(decoder, mut present)| {
let present = NullBuffer::from(present.finish());
decoder.next_batch(batch_size, Some(&present))
})
.collect::<Result<Vec<_>>>()?;
let type_ids = Buffer::from_vec(tags.clone()).into();
let array = UnionArray::try_new(self.fields.clone(), type_ids, None, child_arrays)
.context(ArrowSnafu)?;
let array = Arc::new(array);
Ok(array)
}
fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> {
use super::derive_present_vec;
let present = derive_present_vec(&mut self.present, parent_present, n).transpose()?;
let non_null_count = if let Some(present) = &present {
present.len() - present.null_count()
} else {
n
};
self.tags.skip(non_null_count)?;
for decoder in &mut self.variants {
decoder.skip_values(n, present.as_ref())?;
}
Ok(())
}
}