use std::collections::HashMap;
use std::sync::Arc;
use crate::arrow::schema::extension::try_add_extension_type;
use crate::arrow::schema::primitive::convert_primitive;
use crate::arrow::schema::virtual_type::{RowGroupIndex, RowNumber};
use crate::arrow::{PARQUET_FIELD_ID_META_KEY, ProjectionMask};
use crate::basic::{ConvertedType, Repetition};
use crate::errors::ParquetError;
use crate::errors::Result;
use crate::schema::types::{SchemaDescriptor, Type, TypePtr};
use arrow_schema::{DataType, Field, Fields, SchemaBuilder, extension::ExtensionType};
/// Returns the repetition of `t`, treating a type that carries no repetition
/// as [`Repetition::REQUIRED`].
fn get_repetition(t: &Type) -> Repetition {
    let basic = t.get_basic_info();
    if basic.has_repetition() {
        basic.repetition()
    } else {
        Repetition::REQUIRED
    }
}
/// A parquet schema element described in terms of arrow types, together with
/// the repetition/definition levels computed while walking the schema.
#[derive(Debug, Clone)]
pub struct ParquetField {
    /// Repetition level computed for this field.
    pub rep_level: i16,
    /// Definition level computed for this field.
    pub def_level: i16,
    /// Whether this field is nullable.
    pub nullable: bool,
    /// The arrow data type this field maps to.
    pub arrow_type: DataType,
    /// Whether this is a primitive leaf, a group of children or a virtual column.
    pub field_type: ParquetFieldType,
}

impl ParquetField {
    /// Wraps `self` as the single child of a non-nullable list field.
    ///
    /// The list keeps the levels of `self`; its element field is named `name`
    /// and is itself non-nullable.
    fn into_list(self, name: &str) -> Self {
        let (rep_level, def_level) = (self.rep_level, self.def_level);
        let element = Field::new(name, self.arrow_type.clone(), false);
        Self {
            rep_level,
            def_level,
            nullable: false,
            arrow_type: DataType::List(Arc::new(element)),
            field_type: ParquetFieldType::Group {
                children: vec![self],
            },
        }
    }

    /// Returns the child fields of a group, or `None` for primitive and
    /// virtual fields.
    pub fn children(&self) -> Option<&[Self]> {
        match &self.field_type {
            ParquetFieldType::Group { children } => Some(children.as_slice()),
            ParquetFieldType::Primitive { .. } | ParquetFieldType::Virtual(_) => None,
        }
    }
}
/// The kinds of virtual column a [`ParquetField`] can represent.
///
/// Each variant is selected in `convert_virtual_field` from the extension
/// type name of the arrow field.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VirtualColumnType {
/// Selected for fields whose extension type name is `RowNumber::NAME`.
RowNumber,
/// Selected for fields whose extension type name is `RowGroupIndex::NAME`.
RowGroupIndex,
}
/// The structural kind of a [`ParquetField`].
#[derive(Debug, Clone)]
pub enum ParquetFieldType {
/// A leaf column.
Primitive {
/// Index of this leaf, assigned in visiting order by the schema visitor.
col_idx: usize,
/// The parquet type of this leaf.
primitive_type: TypePtr,
},
/// A field made up of child fields (struct, list or map).
Group {
/// The child fields of this group.
children: Vec<ParquetField>,
},
/// A virtual column, see [`VirtualColumnType`].
Virtual(VirtualColumnType),
}
/// State handed from a parent schema node to the child being visited.
struct VisitorContext {
    /// Repetition level of the parent.
    rep_level: i16,
    /// Definition level of the parent.
    def_level: i16,
    /// Optional arrow type hint for the node being visited.
    data_type: Option<DataType>,
}

impl VisitorContext {
    /// Computes `(def_level, rep_level, nullable)` for a child with the given
    /// `repetition`, relative to the levels stored in this context.
    ///
    /// Note the tuple order: definition level first, then repetition level.
    fn levels(&self, repetition: Repetition) -> (i16, i16, bool) {
        // How much each repetition kind adds to the parent's levels.
        let (def_delta, rep_delta, nullable) = match repetition {
            Repetition::OPTIONAL => (1, 0, true),
            Repetition::REQUIRED => (0, 0, false),
            Repetition::REPEATED => (1, 1, false),
        };
        (
            self.def_level + def_delta,
            self.rep_level + rep_delta,
            nullable,
        )
    }
}
/// Walks a parquet schema and builds the corresponding [`ParquetField`] tree.
struct Visitor {
/// Index that will be assigned to the next primitive (leaf) column visited.
next_col_idx: usize,
/// Projection mask consulted to decide which leaf columns are included.
mask: ProjectionMask,
}
impl Visitor {
fn visit_primitive(
&mut self,
primitive_type: &TypePtr,
context: VisitorContext,
) -> Result<Option<ParquetField>> {
let col_idx = self.next_col_idx;
self.next_col_idx += 1;
if !self.mask.leaf_included(col_idx) {
return Ok(None);
}
let repetition = get_repetition(primitive_type);
let (def_level, rep_level, nullable) = context.levels(repetition);
let arrow_type = convert_primitive(primitive_type, context.data_type)?;
let primitive_field = ParquetField {
rep_level,
def_level,
nullable,
arrow_type,
field_type: ParquetFieldType::Primitive {
primitive_type: primitive_type.clone(),
col_idx,
},
};
Ok(Some(match repetition {
Repetition::REPEATED => primitive_field.into_list(primitive_type.name()),
_ => primitive_field,
}))
}
fn visit_struct(
&mut self,
struct_type: &TypePtr,
context: VisitorContext,
) -> Result<Option<ParquetField>> {
let repetition = get_repetition(struct_type);
let (def_level, rep_level, nullable) = context.levels(repetition);
let parquet_fields = struct_type.get_fields();
let arrow_fields = match &context.data_type {
Some(DataType::Struct(fields)) => {
if fields.len() != parquet_fields.len() {
return Err(arrow_err!(
"incompatible arrow schema, expected {} struct fields got {}",
parquet_fields.len(),
fields.len()
));
}
Some(fields)
}
Some(d) => {
return Err(arrow_err!(
"incompatible arrow schema, expected struct got {}",
d
));
}
None => None,
};
let mut child_fields = SchemaBuilder::with_capacity(parquet_fields.len());
let mut children = Vec::with_capacity(parquet_fields.len());
for (idx, parquet_field) in parquet_fields.iter().enumerate() {
let data_type = match arrow_fields {
Some(fields) => {
let field = &fields[idx];
if field.name() != parquet_field.name() {
return Err(arrow_err!(
"incompatible arrow schema, expected field named {} got {}",
parquet_field.name(),
field.name()
));
}
Some(field.data_type().clone())
}
None => None,
};
let arrow_field = arrow_fields.map(|x| &*x[idx]);
let child_ctx = VisitorContext {
rep_level,
def_level,
data_type,
};
if let Some(mut child) = self.dispatch(parquet_field, child_ctx)? {
child_fields.push(convert_field(parquet_field, &mut child, arrow_field)?);
children.push(child);
}
}
if children.is_empty() {
return Ok(None);
}
let struct_field = ParquetField {
rep_level,
def_level,
nullable,
arrow_type: DataType::Struct(child_fields.finish().fields),
field_type: ParquetFieldType::Group { children },
};
Ok(Some(match repetition {
Repetition::REPEATED => struct_field.into_list(struct_type.name()),
_ => struct_field,
}))
}
fn visit_map(
&mut self,
map_type: &TypePtr,
context: VisitorContext,
) -> Result<Option<ParquetField>> {
let rep_level = context.rep_level + 1;
let (def_level, nullable) = match get_repetition(map_type) {
Repetition::REQUIRED => (context.def_level + 1, false),
Repetition::OPTIONAL => (context.def_level + 2, true),
Repetition::REPEATED => return Err(arrow_err!("Map cannot be repeated")),
};
if map_type.get_fields().len() != 1 {
return Err(arrow_err!(
"Map field must have exactly one key_value child, found {}",
map_type.get_fields().len()
));
}
let map_key_value = &map_type.get_fields()[0];
if map_key_value.get_basic_info().repetition() != Repetition::REPEATED {
return Err(arrow_err!("Child of map field must be repeated"));
}
if map_key_value.get_fields().len() == 1 {
return self.visit_list(map_type, context);
}
if map_key_value.get_fields().len() != 2 {
return Err(arrow_err!(
"Child of map field must have two children, found {}",
map_key_value.get_fields().len()
));
}
let map_key = &map_key_value.get_fields()[0];
let map_value = &map_key_value.get_fields()[1];
match map_key.get_basic_info().repetition() {
Repetition::REPEATED => {
return Err(arrow_err!("Map keys cannot be repeated"));
}
Repetition::REQUIRED | Repetition::OPTIONAL => {
}
}
if map_value.get_basic_info().repetition() == Repetition::REPEATED {
return Err(arrow_err!("Map values cannot be repeated"));
}
let (arrow_map, arrow_key, arrow_value, sorted) = match &context.data_type {
Some(DataType::Map(field, sorted)) => match field.data_type() {
DataType::Struct(fields) => {
if fields.len() != 2 {
return Err(arrow_err!(
"Map data type should contain struct with two children, got {}",
fields.len()
));
}
(Some(field), Some(&*fields[0]), Some(&*fields[1]), *sorted)
}
d => {
return Err(arrow_err!("Map data type should contain struct got {}", d));
}
},
Some(d) => {
return Err(arrow_err!(
"incompatible arrow schema, expected map got {}",
d
));
}
None => (None, None, None, false),
};
let maybe_key = {
let context = VisitorContext {
rep_level,
def_level,
data_type: arrow_key.map(|x| x.data_type().clone()),
};
self.dispatch(map_key, context)?
};
let maybe_value = {
let context = VisitorContext {
rep_level,
def_level,
data_type: arrow_value.map(|x| x.data_type().clone()),
};
self.dispatch(map_value, context)?
};
match (maybe_key, maybe_value) {
(Some(mut key), Some(mut value)) => {
let key_field = Arc::new(
convert_field(map_key, &mut key, arrow_key)?
.with_nullable(false),
);
let value_field = Arc::new(convert_field(map_value, &mut value, arrow_value)?);
let field_metadata = match arrow_map {
Some(field) => field.metadata().clone(),
_ => HashMap::default(),
};
let map_field = Field::new_struct(
map_key_value.name(),
[key_field, value_field],
false, )
.with_metadata(field_metadata);
Ok(Some(ParquetField {
rep_level,
def_level,
nullable,
arrow_type: DataType::Map(Arc::new(map_field), sorted),
field_type: ParquetFieldType::Group {
children: vec![key, value],
},
}))
}
_ => Ok(None),
}
}
fn visit_list(
&mut self,
list_type: &TypePtr,
context: VisitorContext,
) -> Result<Option<ParquetField>> {
if list_type.is_primitive() {
return Err(arrow_err!(
"{:?} is a list type and can't be processed as primitive.",
list_type
));
}
let fields = list_type.get_fields();
if fields.len() != 1 {
return Err(arrow_err!(
"list type must have a single child, found {}",
fields.len()
));
}
let repeated_field = &fields[0];
if get_repetition(repeated_field) != Repetition::REPEATED {
return Err(arrow_err!("List child must be repeated"));
}
let (def_level, nullable) = match list_type.get_basic_info().repetition() {
Repetition::REQUIRED => (context.def_level, false),
Repetition::OPTIONAL => (context.def_level + 1, true),
Repetition::REPEATED => return Err(arrow_err!("List type cannot be repeated")),
};
let arrow_field = match &context.data_type {
Some(DataType::List(f)) => Some(f.as_ref()),
Some(DataType::LargeList(f)) => Some(f.as_ref()),
Some(DataType::FixedSizeList(f, _)) => Some(f.as_ref()),
Some(DataType::ListView(f)) => Some(f.as_ref()),
Some(DataType::LargeListView(f)) => Some(f.as_ref()),
Some(d) => {
return Err(arrow_err!(
"incompatible arrow schema, expected list got {}",
d
));
}
None => None,
};
if repeated_field.is_primitive() {
let context = VisitorContext {
rep_level: context.rep_level,
def_level,
data_type: arrow_field.map(|f| f.data_type().clone()),
};
return match self.visit_primitive(repeated_field, context) {
Ok(Some(mut field)) => {
field.nullable = nullable;
Ok(Some(field))
}
r => r,
};
}
let items = repeated_field.get_fields();
if items.len() != 1
|| (!repeated_field.is_list()
&& !repeated_field.has_single_repeated_child()
&& (repeated_field.name() == "array"
|| repeated_field.name() == format!("{}_tuple", list_type.name())))
{
let context = VisitorContext {
rep_level: context.rep_level,
def_level,
data_type: arrow_field.map(|f| f.data_type().clone()),
};
return match self.visit_struct(repeated_field, context) {
Ok(Some(mut field)) => {
field.nullable = nullable;
Ok(Some(field))
}
r => r,
};
}
let item_type = &items[0];
let rep_level = context.rep_level + 1;
let def_level = def_level + 1;
let new_context = VisitorContext {
def_level,
rep_level,
data_type: arrow_field.map(|f| f.data_type().clone()),
};
match self.dispatch(item_type, new_context) {
Ok(Some(mut item)) => {
let item_field = Arc::new(convert_field(item_type, &mut item, arrow_field)?);
let arrow_type = match context.data_type {
Some(DataType::LargeList(_)) => DataType::LargeList(item_field),
Some(DataType::FixedSizeList(_, len)) => {
DataType::FixedSizeList(item_field, len)
}
Some(DataType::ListView(_)) => DataType::ListView(item_field),
Some(DataType::LargeListView(_)) => DataType::LargeListView(item_field),
_ => DataType::List(item_field),
};
Ok(Some(ParquetField {
rep_level,
def_level,
nullable,
arrow_type,
field_type: ParquetFieldType::Group {
children: vec![item],
},
}))
}
r => r,
}
}
fn dispatch(
&mut self,
cur_type: &TypePtr,
context: VisitorContext,
) -> Result<Option<ParquetField>> {
if cur_type.is_primitive() {
self.visit_primitive(cur_type, context)
} else {
match cur_type.get_basic_info().converted_type() {
ConvertedType::LIST => self.visit_list(cur_type, context),
ConvertedType::MAP | ConvertedType::MAP_KEY_VALUE => {
self.visit_map(cur_type, context)
}
_ => self.visit_struct(cur_type, context),
}
}
}
}
/// Builds the [`ParquetField`] for a virtual column described by `arrow_field`.
///
/// The repetition level is inherited from the parent; the definition level is
/// the parent's, plus one when the arrow field is nullable.
///
/// # Errors
///
/// Fails when `arrow_field` has no extension type name, or when that name is
/// neither `RowNumber::NAME` nor `RowGroupIndex::NAME`.
pub(super) fn convert_virtual_field(
    arrow_field: &Field,
    parent_rep_level: i16,
    parent_def_level: i16,
) -> Result<ParquetField> {
    let nullable = arrow_field.is_nullable();
    // A nullable field contributes one extra definition level.
    let def_level = parent_def_level + i16::from(nullable);

    let Some(extension_name) = arrow_field.extension_type_name() else {
        return Err(ParquetError::ArrowError(format!(
            "virtual column field '{}' must have an extension type",
            arrow_field.name()
        )));
    };

    let virtual_type = if extension_name == RowNumber::NAME {
        VirtualColumnType::RowNumber
    } else if extension_name == RowGroupIndex::NAME {
        VirtualColumnType::RowGroupIndex
    } else {
        return Err(ParquetError::ArrowError(format!(
            "unsupported virtual column type '{}' for field '{}'",
            extension_name,
            arrow_field.name()
        )));
    };

    Ok(ParquetField {
        rep_level: parent_rep_level,
        def_level,
        nullable,
        arrow_type: arrow_field.data_type().clone(),
        field_type: ParquetFieldType::Virtual(virtual_type),
    })
}
/// Builds the arrow [`Field`] for a converted child column.
///
/// The name comes from `parquet_type`; the data type and nullability come from
/// `field`. With an `arrow_hint`, the hint's metadata is copied over and, for
/// dictionary types, its dictionary id and ordering are preserved. Without a
/// hint, the parquet field id (if any) is stored under
/// `PARQUET_FIELD_ID_META_KEY` and the result is passed through
/// `try_add_extension_type`.
fn convert_field(
    parquet_type: &Type,
    field: &mut ParquetField,
    arrow_hint: Option<&Field>,
) -> Result<Field, ParquetError> {
    let name = parquet_type.name();
    let data_type = field.arrow_type.clone();
    let nullable = field.nullable;

    if let Some(hint) = arrow_hint {
        // Dictionary metadata is only carried over when the converted type is
        // itself a dictionary and the hint provides both values.
        #[allow(deprecated)]
        let hinted = match (&data_type, hint.dict_id(), hint.dict_is_ordered()) {
            (DataType::Dictionary(_, _), Some(id), Some(ordered)) =>
            {
                #[allow(deprecated)]
                Field::new_dict(name, data_type, nullable, id, ordered)
            }
            _ => Field::new(name, data_type, nullable),
        };
        return Ok(hinted.with_metadata(hint.metadata().clone()));
    }

    let mut plain = Field::new(name, data_type, nullable);
    let info = parquet_type.get_basic_info();
    if info.has_id() {
        let mut metadata = HashMap::with_capacity(1);
        metadata.insert(PARQUET_FIELD_ID_META_KEY.to_string(), info.id().to_string());
        plain.set_metadata(metadata);
    }
    try_add_extension_type(plain, parquet_type)
}
/// Converts the root of `schema` into a [`ParquetField`] tree.
///
/// Only leaves included in `mask` are kept; `Ok(None)` is returned when
/// nothing survives the projection. When `embedded_arrow_schema` is given it
/// is used as a struct type hint for the root.
pub fn convert_schema(
    schema: &SchemaDescriptor,
    mask: ProjectionMask,
    embedded_arrow_schema: Option<&Fields>,
) -> Result<Option<ParquetField>> {
    let root_hint = embedded_arrow_schema.map(|fields| DataType::Struct(fields.clone()));
    let root_context = VisitorContext {
        rep_level: 0,
        def_level: 0,
        data_type: root_hint,
    };

    let mut walker = Visitor {
        next_col_idx: 0,
        mask,
    };
    let root = schema.root_schema_ptr();
    walker.dispatch(&root, root_context)
}
pub fn convert_type(parquet_type: &TypePtr) -> Result<ParquetField> {
let mut visitor = Visitor {
next_col_idx: 0,
mask: ProjectionMask::all(),
};
let context = VisitorContext {
rep_level: 0,
def_level: 0,
data_type: None,
};
Ok(visitor.dispatch(parquet_type, context)?.unwrap())
}