#![allow(clippy::redundant_closure_call)]
use rust_decimal::Decimal;
use thiserror::Error;
use std::convert::TryInto;
use std::iter::Map;
use std::num::TryFromIntError;
use std::slice::IterMut;
use std::str::Utf8Error;
use errors::OrcError;
use kind::Kind;
use vector::{BorrowedColumnVectorBatch, ColumnVectorBatch, DecimalVectorBatch, StructVectorBatch};
/// Errors produced while decoding ORC column vector batches into Rust values.
#[derive(Debug, Error, PartialEq)]
pub enum DeserializationError {
    /// A column batch could not be converted to the vector type the target Rust type
    /// reads from (wraps the `OrcError` returned by the `try_into_*` conversion).
    #[error("Mismatched ORC column type: {0}")]
    MismatchedColumnKind(OrcError),
    /// A field the target type expects is absent from the ORC file.
    /// NOTE(review): not constructed in this part of the module; presumably raised by
    /// struct deserializers defined elsewhere — confirm.
    #[error("Field {0} is missing from ORC file")]
    MissingField(String),
    /// An element count reported by an ORC batch does not fit in `usize` on this
    /// platform.
    #[error("Number of items exceeds maximum buffer capacity on this platform: {0}")]
    UsizeOverflow(TryFromIntError),
    /// A byte string read for a `String` target is not valid UTF-8.
    #[error("Failed to decode ORC byte string as UTF-8: {0}")]
    Utf8Error(Utf8Error),
    /// A null was found in a column decoded into a non-`Option` Rust type; the payload
    /// describes the offending column.
    #[error("Unexpected null value in ORC file: {0}")]
    UnexpectedNull(String),
    /// Source and destination lengths disagree. In this module it is raised when a list
    /// batch holds more lists (`src`) than the destination has slots (`dst`).
    #[error("Tried to deserialize {src}-long buffer into {dst}-long buffer")]
    MismatchedLength { src: u64, dst: u64 },
}
/// Checks that `got_kind` is one of `expected_kinds`.
///
/// On mismatch, returns a message of the form
/// `"<type_name> must be decoded from ORC <K1>/<K2>, not ORC <got_kind>"`, where every
/// kind is rendered with its `Debug` representation.
fn check_kind_equals(
    got_kind: &Kind,
    expected_kinds: &[Kind],
    type_name: &str,
) -> Result<(), String> {
    if !expected_kinds.contains(got_kind) {
        let expected: Vec<String> = expected_kinds
            .iter()
            .map(|kind| format!("{:?}", kind))
            .collect();
        return Err(format!(
            "{} must be decoded from ORC {}, not ORC {:?}",
            type_name,
            expected.join("/"),
            got_kind
        ));
    }
    Ok(())
}
/// Types that can tell whether an ORC column [`Kind`] is compatible with them.
pub trait CheckableKind {
    /// Returns `Ok(())` when values of this type can be decoded from a column of type
    /// `kind`, or a human-readable explanation of the mismatch otherwise.
    fn check_kind(kind: &Kind) -> Result<(), String>;
}

/// `Option<T>` accepts exactly the column kinds that `T` accepts.
impl<T: CheckableKind> CheckableKind for Option<T> {
    fn check_kind(kind: &Kind) -> Result<(), String> {
        <T as CheckableKind>::check_kind(kind)
    }
}
/// Types that map onto one or more named ORC columns.
pub trait OrcStruct {
    /// Column names this type reads, computed with an empty prefix.
    fn columns() -> Vec<String> {
        Self::columns_with_prefix("")
    }

    /// Column names this type reads, computed relative to `prefix`.
    ///
    /// Scalar implementations in this module return just `prefix` itself.
    fn columns_with_prefix(prefix: &str) -> Vec<String>;
}

/// An optional value reads exactly the same columns as the value it wraps.
impl<T: OrcStruct> OrcStruct for Option<T> {
    fn columns_with_prefix(prefix: &str) -> Vec<String> {
        <T as OrcStruct>::columns_with_prefix(prefix)
    }
}
/// Types that can be decoded from an ORC column vector batch.
pub trait OrcDeserialize: Sized + Default + CheckableKind {
    /// Decodes the values of `src` into the slots of `dst`, in order.
    ///
    /// Implementations in this module zip the source values with `dst.iter_mut()` and
    /// return the number of elements in `src`.
    fn read_from_vector_batch<'a, 'b, T>(
        src: &BorrowedColumnVectorBatch,
        dst: &'b mut T,
    ) -> Result<usize, DeserializationError>
    where
        Self: 'a,
        &'b mut T: DeserializationTarget<'a, Item = Self> + 'b;

    /// Decodes every element of `vector_batch` into a freshly allocated `Vec`.
    ///
    /// The vector is pre-filled with `Self::default()` and then overwritten by
    /// [`OrcDeserialize::read_from_vector_batch`].
    ///
    /// # Errors
    /// Returns [`DeserializationError::UsizeOverflow`] if the batch length does not fit
    /// in `usize`, plus any error from `read_from_vector_batch`.
    fn from_vector_batch(
        vector_batch: &BorrowedColumnVectorBatch,
    ) -> Result<Vec<Self>, DeserializationError> {
        let len: usize = vector_batch
            .num_elements()
            .try_into()
            .map_err(DeserializationError::UsizeOverflow)?;
        let mut values: Vec<Self> = std::iter::repeat_with(Self::default).take(len).collect();
        Self::read_from_vector_batch(vector_batch, &mut values)?;
        Ok(values)
    }
}
/// Implements `OrcStruct`, `CheckableKind` and `OrcDeserialize` for a scalar type and
/// `OrcDeserialize` for its `Option` wrapper.
///
/// * `$ty` — the Rust type being implemented.
/// * `$kind` — array of ORC kinds `$ty` may be decoded from.
/// * `$method` — the `BorrowedColumnVectorBatch` conversion (e.g. `try_into_longs`) that
///   exposes the column's raw values.
/// * `$cast` — optional callable turning one raw value into `Result<$ty, _>`; the
///   three-argument form uses a plain `as` cast.
macro_rules! impl_scalar {
    ($ty:ty, $kind:expr, $method:ident) => {
        impl_scalar!($ty, $kind, $method, |s| Ok(s as $ty));
    };
    ($ty:ty, $kind:expr, $method:ident, $cast:expr) => {
        impl OrcStruct for $ty {
            fn columns_with_prefix(prefix: &str) -> Vec<String> {
                vec![String::from(prefix)]
            }
        }

        impl CheckableKind for $ty {
            fn check_kind(kind: &Kind) -> Result<(), String> {
                check_kind_equals(kind, &$kind, stringify!($ty))
            }
        }

        impl OrcDeserialize for $ty {
            fn read_from_vector_batch<'a, 'b, T>(
                src: &BorrowedColumnVectorBatch,
                mut dst: &'b mut T,
            ) -> Result<usize, DeserializationError>
            where
                &'b mut T: DeserializationTarget<'a, Item = Self> + 'b,
            {
                let batch = src
                    .$method()
                    .map_err(DeserializationError::MismatchedColumnKind)?;
                // A non-optional target cannot hold nulls, so they are rejected up front.
                let values = batch.try_iter_not_null().ok_or_else(|| {
                    DeserializationError::UnexpectedNull(format!(
                        "{} column contains nulls",
                        stringify!($ty)
                    ))
                })?;
                for (value, slot) in values.zip(dst.iter_mut()) {
                    *slot = ($cast)(value)?;
                }
                Ok(batch.num_elements().try_into().unwrap())
            }
        }

        impl OrcDeserialize for Option<$ty> {
            fn read_from_vector_batch<'a, 'b, T>(
                src: &BorrowedColumnVectorBatch,
                mut dst: &'b mut T,
            ) -> Result<usize, DeserializationError>
            where
                &'b mut T: DeserializationTarget<'a, Item = Self> + 'b,
            {
                let batch = src
                    .$method()
                    .map_err(DeserializationError::MismatchedColumnKind)?;
                for (value, slot) in batch.iter().zip(dst.iter_mut()) {
                    *slot = match value {
                        Some(raw) => Some(($cast)(raw)?),
                        None => None,
                    };
                }
                Ok(batch.num_elements().try_into().unwrap())
            }
        }
    };
}
// Booleans live in the long vector: any non-zero value is `true`.
impl_scalar!(bool, [Kind::Boolean], try_into_longs, |s| Ok(s != 0));

// Integer and floating-point kinds use the default `as` cast.
impl_scalar!(i8, [Kind::Byte], try_into_longs);
impl_scalar!(i16, [Kind::Short], try_into_longs);
impl_scalar!(i32, [Kind::Int], try_into_longs);
impl_scalar!(i64, [Kind::Long], try_into_longs);
impl_scalar!(f32, [Kind::Float], try_into_doubles);
impl_scalar!(f64, [Kind::Double], try_into_doubles);

// Byte-string kinds: `String` requires valid UTF-8, `Vec<u8>` copies the raw bytes.
impl_scalar!(String, [Kind::String], try_into_strings, |s| {
    std::str::from_utf8(s)
        .map(String::from)
        .map_err(DeserializationError::Utf8Error)
});
impl_scalar!(Vec<u8>, [Kind::Binary], try_into_strings, |bytes: &[u8]| Ok(
    Vec::from(bytes)
));

// Timestamps arrive as (seconds, nanoseconds) pairs.
impl_scalar!(
    crate::Timestamp,
    [Kind::Timestamp],
    try_into_timestamps,
    |(seconds, nanoseconds): (i64, i64)| Ok(crate::Timestamp {
        seconds,
        nanoseconds
    })
);
/// A decimal reads a single column, named by `prefix`.
impl OrcStruct for Decimal {
    fn columns_with_prefix(prefix: &str) -> Vec<String> {
        vec![String::from(prefix)]
    }
}

/// Decimals accept any ORC `Decimal` column, whatever its precision and scale.
impl CheckableKind for Decimal {
    fn check_kind(kind: &Kind) -> Result<(), String> {
        if let Kind::Decimal { .. } = kind {
            return Ok(());
        }
        Err(format!(
            "Decimal must be decoded from ORC Decimal, not ORC {:?}",
            kind
        ))
    }
}
/// Decodes a non-nullable ORC decimal column.
///
/// The 64-bit decimal representation is tried first; if `src` is not a 64-bit decimal
/// batch, the 128-bit one is used. If neither conversion succeeds, the error from the
/// 128-bit attempt is reported.
impl OrcDeserialize for Decimal {
    fn read_from_vector_batch<'a, 'b, T>(
        src: &BorrowedColumnVectorBatch,
        mut dst: &'b mut T,
    ) -> Result<usize, DeserializationError>
    where
        &'b mut T: DeserializationTarget<'a, Item = Self> + 'b,
    {
        let nulls_error =
            || DeserializationError::UnexpectedNull("Decimal column contains nulls".to_string());
        if let Ok(narrow) = src.try_into_decimals64() {
            let values = narrow.try_iter_not_null().ok_or_else(nulls_error)?;
            for (value, slot) in values.zip(dst.iter_mut()) {
                *slot = value;
            }
        } else {
            let wide = src
                .try_into_decimals128()
                .map_err(DeserializationError::MismatchedColumnKind)?;
            let values = wide.try_iter_not_null().ok_or_else(nulls_error)?;
            for (value, slot) in values.zip(dst.iter_mut()) {
                *slot = value;
            }
        }
        Ok(src.num_elements().try_into().unwrap())
    }
}
/// Decodes a nullable ORC decimal column; null rows become `None`.
///
/// Like the non-nullable impl, it tries the 64-bit representation first and falls back
/// to the 128-bit one.
impl OrcDeserialize for Option<Decimal> {
    fn read_from_vector_batch<'a, 'b, T>(
        src: &BorrowedColumnVectorBatch,
        mut dst: &'b mut T,
    ) -> Result<usize, DeserializationError>
    where
        &'b mut T: DeserializationTarget<'a, Item = Self> + 'b,
    {
        if let Ok(narrow) = src.try_into_decimals64() {
            for (value, slot) in narrow.iter().zip(dst.iter_mut()) {
                *slot = value;
            }
        } else {
            let wide = src
                .try_into_decimals128()
                .map_err(DeserializationError::MismatchedColumnKind)?;
            for (value, slot) in wide.iter().zip(dst.iter_mut()) {
                *slot = value;
            }
        }
        Ok(src.num_elements().try_into().unwrap())
    }
}
/// A list reads the same columns as its element type.
impl<T: OrcStruct> OrcStruct for Vec<T> {
    fn columns_with_prefix(prefix: &str) -> Vec<String> {
        <T as OrcStruct>::columns_with_prefix(prefix)
    }
}

/// A `Vec<T>` needs an ORC `List` column whose element kind `T` accepts.
impl<T: CheckableKind> CheckableKind for Vec<T> {
    fn check_kind(kind: &Kind) -> Result<(), String> {
        if let Kind::List(inner) = kind {
            T::check_kind(inner)
        } else {
            Err(format!("Must be a List, not {:?}", kind))
        }
    }
}
/// Shared prologue of the list deserializers.
///
/// Converts `$src` into a list batch and checks that `$dst` has a slot for every list.
/// It then eagerly decodes all child elements into a `Vec<I>`.
///
/// Evaluates to `(lists, elements)`: the list batch and an owning iterator over the
/// decoded child values, in order.
///
/// Must be expanded inside a function returning `Result<_, DeserializationError>`,
/// because it uses `?` and `return Err(..)`. The element type `I` must be in scope at
/// the call site.
macro_rules! init_list_read {
    ($src:expr, $dst: expr) => {{
        let lists = $src
            .try_into_lists()
            .map_err(DeserializationError::MismatchedColumnKind)?;
        let list_count: usize = lists
            .num_elements()
            .try_into()
            .map_err(DeserializationError::UsizeOverflow)?;
        let child_count: usize = lists
            .elements()
            .num_elements()
            .try_into()
            .map_err(DeserializationError::UsizeOverflow)?;
        if list_count > $dst.len() {
            return Err(DeserializationError::MismatchedLength {
                src: list_count as u64,
                dst: $dst.len() as u64,
            });
        }
        let mut children: Vec<I> = std::iter::repeat_with(I::default)
            .take(child_count)
            .collect();
        OrcDeserialize::read_from_vector_batch::<Vec<I>>(&lists.elements(), &mut children)?;
        (lists, children.into_iter())
    }};
}
/// Builds one list (`Vec<I>`) by pulling `range.end - range.start` items from
/// `$elements`.
///
/// * `$range` — offsets of this list within the flattened child column.
/// * `$last_offset` — mutable place holding the end offset of the previous list. It is
///   checked against `range.start` and then advanced to `range.end`.
/// * `$elements` — iterator over the flattened, already-decoded child values.
///
/// The element type `I` is resolved at the call site (type names in `macro_rules!`
/// bodies are not hygienic), so callers must have `I` in scope.
///
/// # Panics
/// Panics if `range.start` differs from `$last_offset` (non-contiguous offsets), or if
/// `$elements` runs out before the list is complete.
macro_rules! build_list_item {
    ($range:expr, $last_offset:expr, $elements:expr) => {{
        let range = $range;
        assert_eq!(
            range.start, $last_offset,
            "Non-continuous list (jumped from offset {} to {})",
            $last_offset, range.start
        );
        let mut array: Vec<I> = Vec::with_capacity((range.end - range.start) as usize);
        for _ in range.clone() {
            match $elements.next() {
                Some(item) => array.push(item),
                None => panic!(
                    "List too short (expected {} elements, got {})",
                    range.end - range.start,
                    array.len()
                ),
            }
        }
        $last_offset = range.end;
        array
    }};
}
impl<I> OrcDeserializeOption for Vec<I>
where
I: Default + OrcDeserialize,
{
fn read_options_from_vector_batch<'a, 'b, T>(
src: &BorrowedColumnVectorBatch,
mut dst: &'b mut T,
) -> Result<usize, DeserializationError>
where
&'b mut T: DeserializationTarget<'a, Item = Option<Self>> + 'b,
{
let (src, mut elements) = init_list_read!(src, dst);
let offsets = src.iter_offsets();
let mut dst = dst.iter_mut();
let mut last_offset = 0;
for offset in offsets {
let dst_item: &mut Option<Vec<I>> = unsafe { dst.next().unwrap_unchecked() };
match offset {
None => *dst_item = None,
Some(range) => {
*dst_item = Some(build_list_item!(range, last_offset, elements));
}
}
}
if elements.next().is_some() {
panic!("List too long");
}
Ok(src.num_elements().try_into().unwrap())
}
}
/// Decodes a non-nullable ORC `List` column into `Vec<I>` values.
///
/// Each list is materialized from the column's flattened child elements, consumed in
/// order.
///
/// # Errors
/// Returns [`DeserializationError::UnexpectedNull`] if the list column contains nulls,
/// and propagates conversion, length and child-decoding errors from `init_list_read!`.
///
/// # Panics
/// Panics if list offsets are not contiguous, or if the number of decoded child
/// elements does not match what the offsets describe.
impl<I> OrcDeserialize for Vec<I>
where
    I: OrcDeserialize,
{
    fn read_from_vector_batch<'a, 'b, T>(
        src: &BorrowedColumnVectorBatch,
        mut dst: &'b mut T,
    ) -> Result<usize, DeserializationError>
    where
        &'b mut T: DeserializationTarget<'a, Item = Self> + 'b,
    {
        let (src, mut elements) = init_list_read!(src, dst);
        // `stringify!($ty)` only works inside a macro; here the type is named explicitly.
        let offsets = src.try_iter_offsets_not_null().ok_or_else(|| {
            DeserializationError::UnexpectedNull("List column contains nulls".to_string())
        })?;
        let mut last_offset = 0;
        // `init_list_read!` already rejected sources with more lists than `dst` has
        // slots. Zipping keeps this loop memory-safe without depending on the offsets
        // iterator's exact length.
        for (range, dst_item) in offsets.zip(dst.iter_mut()) {
            let dst_item: &mut Vec<I> = dst_item;
            *dst_item = build_list_item!(range, last_offset, elements);
        }
        if elements.next().is_some() {
            panic!("List too long");
        }
        src.num_elements()
            .try_into()
            .map_err(DeserializationError::UsizeOverflow)
    }
}
/// A length-aware destination that deserializers write decoded values into.
///
/// Implemented in this module for `&mut Vec<V>` and for `&mut MultiMap<..>`. The latter
/// is built with [`DeserializationTarget::map`] and projects each slot through a closure
/// (for example, to write into one field of a struct).
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in the code. Presumably implementors
/// must guarantee that `iter_mut` yields exactly `len()` items, since deserializers in
/// this module compare `len()` against the source length before writing through
/// `iter_mut()`. Confirm before adding new implementations.
pub unsafe trait DeserializationTarget<'a> {
    /// Type stored in each slot.
    type Item: 'a;
    /// Iterator over mutable references to the slots.
    type IterMut<'b>: Iterator<Item = &'b mut Self::Item>
    where
        Self: 'b,
        'a: 'b;
    /// Number of slots in this target.
    fn len(&self) -> usize;
    /// Iterates mutably over every slot, in order.
    fn iter_mut(&mut self) -> Self::IterMut<'_>;
    /// Returns `true` when the target has no slots.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns an adapter whose slots are the `&mut B` that `f` selects inside each slot
    /// of `self`.
    fn map<B, F>(&mut self, f: F) -> MultiMap<Self, F>
    where
        Self: Sized,
        F: FnMut(&mut Self::Item) -> &mut B,
    {
        MultiMap { iter: self, f }
    }
}
/// A plain `Vec` is a deserialization target whose slots are its elements.
// SAFETY: `iter_mut` walks the vector's slice, yielding exactly one reference per
// element, i.e. exactly `len()` items.
unsafe impl<'a, V: Sized + 'a> DeserializationTarget<'a> for &mut Vec<V> {
    type Item = V;
    type IterMut<'b> = IterMut<'b, V> where V: 'b, 'a: 'b, Self: 'b;

    fn len(&self) -> usize {
        // Go through the slice explicitly: a bare `self.len()` would resolve to this
        // trait method and recurse.
        self.as_slice().len()
    }

    fn iter_mut(&mut self) -> IterMut<'_, V> {
        self.as_mut_slice().iter_mut()
    }
}
/// A [`DeserializationTarget`] adapter that exposes, for each slot of an inner target,
/// the reference returned by applying `f` to that slot.
///
/// Created by [`DeserializationTarget::map`].
pub struct MultiMap<'c, T: Sized, F> {
    /// The wrapped target whose slots are projected.
    iter: &'c mut T,
    /// Projection applied to every slot of `iter`.
    f: F,
}
/// A mapped target's slots are the `&mut V2` values that `f` produces for each slot of
/// the inner target.
// SAFETY: `iter_mut` maps the inner target's iterator one-to-one, and `len` forwards to
// the inner target's `len`, so both agree whenever the inner target's do.
unsafe impl<'a, 'c, V: Sized + 'a, V2: Sized + 'a, T, F> DeserializationTarget<'a>
    for &mut MultiMap<'c, T, F>
where
    F: Copy + for<'b> FnMut(&'b mut V) -> &'b mut V2,
    T: DeserializationTarget<'a, Item = V>,
{
    type Item = V2;
    type IterMut<'b> = Map<T::IterMut<'b>, F> where T: 'b, 'a: 'b, F: 'b, Self: 'b;

    fn len(&self) -> usize {
        (*self.iter).len()
    }

    fn iter_mut(&mut self) -> Map<T::IterMut<'_>, F> {
        // `F: Copy`, so the projection can be handed to `map` without moving it out.
        let project = self.f;
        self.iter.iter_mut().map(project)
    }
}
/// Builds a `Vec` with one entry per row of `vector_batch`.
///
/// A row is `None` where the batch's not-null buffer holds `0`, and
/// `Some(T::default())` otherwise. Every row is `Some` when the batch has no not-null
/// buffer.
pub fn default_option_vec<T: Default>(vector_batch: &StructVectorBatch) -> Vec<Option<T>> {
    if let Some(not_null) = vector_batch.not_null() {
        return not_null
            .iter()
            .map(|&flag| (flag != 0).then(T::default))
            .collect();
    }
    (0..vector_batch.num_elements())
        .map(|_| Some(T::default()))
        .collect()
}
/// Types whose nullable form `Option<Self>` is decoded by a dedicated routine.
///
/// The blanket impl that follows makes `Option<Self>` an [`OrcDeserialize`] type for
/// every implementor.
pub trait OrcDeserializeOption: Sized + CheckableKind {
    /// Decodes `src` into `dst`. Implementations are expected to write `None` for null
    /// rows and to return the number of elements in `src`.
    fn read_options_from_vector_batch<'a, 'b, T>(
        src: &BorrowedColumnVectorBatch,
        dst: &'b mut T,
    ) -> Result<usize, DeserializationError>
    where
        Self: 'a,
        &'b mut T: DeserializationTarget<'a, Item = Option<Self>> + 'b;
}

impl<I: OrcDeserializeOption> OrcDeserialize for Option<I> {
    fn read_from_vector_batch<'a, 'b, T>(
        src: &BorrowedColumnVectorBatch,
        dst: &'b mut T,
    ) -> Result<usize, DeserializationError>
    where
        &'b mut T: DeserializationTarget<'a, Item = Self> + 'b,
        I: 'a,
    {
        // Nullable reads are delegated wholesale to the inner type.
        <I as OrcDeserializeOption>::read_options_from_vector_batch(src, dst)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use kind::Kind;
    use vector::BorrowedColumnVectorBatch;

    /// Compile-time check that `DeserializationTarget::map` can route a struct column
    /// into one field of a user struct, plus a check of the struct's kind message.
    #[test]
    fn test_map_struct() {
        #[derive(Default)]
        struct Test {
            field1: Option<i64>,
        }

        impl CheckableKind for Test {
            fn check_kind(kind: &Kind) -> Result<(), String> {
                check_kind_equals(
                    kind,
                    &[Kind::Struct(vec![("field1".to_owned(), Kind::Long)])],
                    "Test",
                )
            }
        }

        impl OrcDeserialize for Option<Test> {
            fn read_from_vector_batch<'a, 'b, T>(
                src: &BorrowedColumnVectorBatch,
                mut dst: &'b mut T,
            ) -> Result<usize, DeserializationError>
            where
                &'b mut T: DeserializationTarget<'a, Item = Self>,
            {
                let src = src
                    .try_into_structs()
                    .map_err(DeserializationError::MismatchedColumnKind)?;
                let columns = src.fields();
                let column: BorrowedColumnVectorBatch = columns.into_iter().next().unwrap();
                OrcDeserialize::read_from_vector_batch::<MultiMap<&mut T, _>>(
                    &column,
                    &mut dst.map(|struct_| &mut struct_.as_mut().unwrap().field1),
                )?;
                Ok(src.num_elements().try_into().unwrap())
            }
        }

        // The kind check must name the struct being decoded, not some other type.
        assert!(Test::check_kind(&Kind::Long)
            .unwrap_err()
            .starts_with("Test must be decoded from ORC "));
    }

    #[test]
    fn test_check_kind() {
        assert_eq!(i64::check_kind(&Kind::Long), Ok(()));
        assert_eq!(crate::Timestamp::check_kind(&Kind::Timestamp), Ok(()));
        assert_eq!(String::check_kind(&Kind::String), Ok(()));
        assert_eq!(Vec::<u8>::check_kind(&Kind::Binary), Ok(()));
    }

    #[test]
    fn test_check_kind_fail() {
        assert_eq!(
            i64::check_kind(&Kind::String),
            Err("i64 must be decoded from ORC Long, not ORC String".to_string())
        );
        assert_eq!(
            i64::check_kind(&Kind::Int),
            Err("i64 must be decoded from ORC Long, not ORC Int".to_string())
        );
        assert_eq!(
            String::check_kind(&Kind::Int),
            Err("String must be decoded from ORC String, not ORC Int".to_string())
        );
        assert_eq!(
            String::check_kind(&Kind::Binary),
            Err("String must be decoded from ORC String, not ORC Binary".to_string())
        );
        assert_eq!(
            Vec::<u8>::check_kind(&Kind::Int),
            Err("Vec<u8> must be decoded from ORC Binary, not ORC Int".to_string())
        );
        assert_eq!(
            Vec::<u8>::check_kind(&Kind::String),
            Err("Vec<u8> must be decoded from ORC Binary, not ORC String".to_string())
        );
    }

    /// Wrapper types (`Option`, `Vec`) and `Decimal` report kinds consistently.
    #[test]
    fn test_check_kind_wrappers() {
        assert_eq!(Option::<i64>::check_kind(&Kind::Long), Ok(()));
        assert_eq!(
            Option::<i64>::check_kind(&Kind::Int),
            Err("i64 must be decoded from ORC Long, not ORC Int".to_string())
        );
        assert_eq!(
            Vec::<i64>::check_kind(&Kind::Long),
            Err("Must be a List, not Long".to_string())
        );
        assert_eq!(
            Decimal::check_kind(&Kind::Long),
            Err("Decimal must be decoded from ORC Decimal, not ORC Long".to_string())
        );
    }
}