parquet 58.1.0

Apache Parquet implementation in Rust
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utilities for traversing Parquet schema types.

use crate::basic::{ConvertedType, Repetition};
use crate::errors::ParquetError::General;
use crate::errors::Result;
use crate::schema::types::{Type, TypePtr};

/// A visitor for traversing a Parquet schema [`Type`] tree.
///
/// Implementors provide callbacks for primitive, struct, map and list types;
/// [`Self::dispatch`] inspects a type and routes it to the matching callback.
/// `R` is the value produced by each callback and `C` is a caller-defined context
/// value passed by value into every callback.
pub trait TypeVisitor<R, C> {
    /// Called when a primitive type is visited.
    fn visit_primitive(&mut self, primitive_type: TypePtr, context: C) -> Result<R>;

    /// Default implementation when visiting a list.
    ///
    /// It checks the list type definition and calls [`Self::visit_list_with_item`] with the
    /// extracted item type.
    ///
    /// To fully understand this algorithm, please refer to the
    /// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md).
    ///
    /// For example, a standard list type looks like:
    ///
    /// ```text
    /// required/optional group my_list (LIST) {
    ///   repeated group list {
    ///     required/optional binary element (UTF8);
    ///   }
    /// }
    /// ```
    ///
    /// In such a case, [`Self::visit_list_with_item`] will be called with `my_list` as the list
    /// type, and `element` as the `item_type`.
    ///
    /// The single child of the list group is itself used as the item type when it is a
    /// repeated primitive, a group that does not have exactly one field, or a single-field
    /// group named `array` or `<list name>_tuple` (legacy list encodings).
    ///
    /// # Errors
    ///
    /// Returns an error if `list_type` is a primitive type, if it is a group that does not
    /// contain exactly one field, or if that field is a primitive that is not repeated.
    fn visit_list(&mut self, list_type: TypePtr, context: C) -> Result<R> {
        // A LIST-annotated type must be a group wrapping exactly one field.
        let list_item = match list_type.as_ref() {
            Type::PrimitiveType { .. } => {
                return Err(General(format!(
                    "{list_type:?} is a list type and must be a group type"
                )));
            }
            Type::GroupType { fields, .. } => match fields.as_slice() {
                [list_item] => list_item,
                _ => {
                    return Err(General(
                        "Group element type of list can only contain one field.".to_string(),
                    ));
                }
            },
        };

        let item_type = match list_item.as_ref() {
            Type::PrimitiveType { .. } => {
                // Legacy encoding: a repeated primitive is the element itself.
                if list_item.get_basic_info().repetition() != Repetition::REPEATED {
                    return Err(General(
                        "Primitive element type of list must be repeated.".to_string(),
                    ));
                }
                list_item.clone()
            }
            Type::GroupType { fields, .. } => {
                // Backward-compatibility rule: a single-field group named `array` or
                // `<list name>_tuple` is the element itself rather than a wrapper around it.
                let item_name = list_item.name();
                let is_legacy_element = item_name == "array"
                    || item_name.strip_suffix("_tuple") == Some(list_type.name());
                match fields.as_slice() {
                    [element] if !is_legacy_element => element.clone(),
                    // Groups with zero or several fields, or legacy-named ones, are the element.
                    _ => list_item.clone(),
                }
            }
        };

        self.visit_list_with_item(list_type.clone(), item_type, context)
    }

    /// Called when a group without a `LIST`, `MAP` or `MAP_KEY_VALUE` annotation is visited.
    fn visit_struct(&mut self, struct_type: TypePtr, context: C) -> Result<R>;

    /// Called when a group annotated as `MAP` or `MAP_KEY_VALUE` is visited.
    fn visit_map(&mut self, map_type: TypePtr, context: C) -> Result<R>;

    /// Routes `cur_type` to the matching callback.
    ///
    /// Primitive types go to [`Self::visit_primitive`]; groups annotated `LIST` go to
    /// [`Self::visit_list`]; groups annotated `MAP` or `MAP_KEY_VALUE` go to
    /// [`Self::visit_map`]; every other group goes to [`Self::visit_struct`].
    fn dispatch(&mut self, cur_type: TypePtr, context: C) -> Result<R> {
        if cur_type.is_primitive() {
            self.visit_primitive(cur_type, context)
        } else {
            match cur_type.get_basic_info().converted_type() {
                ConvertedType::LIST => self.visit_list(cur_type, context),
                ConvertedType::MAP | ConvertedType::MAP_KEY_VALUE => {
                    self.visit_map(cur_type, context)
                }
                _ => self.visit_struct(cur_type, context),
            }
        }
    }

    /// Called by [`Self::visit_list`] with the list type and its resolved item type.
    fn visit_list_with_item(
        &mut self,
        list_type: TypePtr,
        item_type: TypePtr,
        context: C,
    ) -> Result<R>;
}

#[cfg(test)]
mod tests {
    use super::TypeVisitor;
    use crate::basic::Type as PhysicalType;
    use crate::errors::Result;
    use crate::schema::parser::parse_message_type;
    use crate::schema::types::TypePtr;
    use std::sync::Arc;

    struct TestVisitorContext {}

    /// Records which callbacks fired and checks every visited type against a
    /// top-level field of `root_type` with the same name.
    struct TestVisitor {
        primitive_visited: bool,
        struct_visited: bool,
        list_visited: bool,
        root_type: TypePtr,
    }

    impl TypeVisitor<bool, TestVisitorContext> for TestVisitor {
        fn visit_primitive(
            &mut self,
            primitive_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(primitive_type.name()).as_ref(),
                primitive_type.as_ref()
            );
            self.primitive_visited = true;
            Ok(true)
        }

        fn visit_struct(
            &mut self,
            struct_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(struct_type.name()).as_ref(),
                struct_type.as_ref()
            );
            self.struct_visited = true;
            Ok(true)
        }

        fn visit_map(&mut self, _map_type: TypePtr, _context: TestVisitorContext) -> Result<bool> {
            unimplemented!()
        }

        fn visit_list_with_item(
            &mut self,
            list_type: TypePtr,
            item_type: TypePtr,
            _context: TestVisitorContext,
        ) -> Result<bool> {
            assert_eq!(
                self.get_field_by_name(list_type.name()).as_ref(),
                list_type.as_ref()
            );
            assert_eq!("element", item_type.name());
            assert_eq!(PhysicalType::INT32, item_type.get_physical_type());
            self.list_visited = true;
            Ok(true)
        }
    }

    impl TestVisitor {
        fn new(root: TypePtr) -> Self {
            Self {
                primitive_visited: false,
                struct_visited: false,
                list_visited: false,
                root_type: root,
            }
        }

        /// Returns the top-level field of `root_type` named `name`; panics if absent.
        fn get_field_by_name(&self, name: &str) -> TypePtr {
            self.root_type
                .get_fields()
                .iter()
                .find(|t| t.name() == name)
                .cloned()
                .unwrap()
        }
    }

    #[test]
    fn test_visitor() {
        let message_type = "
          message spark_schema {
            REQUIRED INT32 a;
            OPTIONAL group inner_schema {
              REQUIRED INT32 b;
              REQUIRED DOUBLE c;
            }

            OPTIONAL group e (LIST) {
              REPEATED group list {
                REQUIRED INT32 element;
              }
            }
          }
        ";

        let parquet_type = Arc::new(parse_message_type(message_type).unwrap());

        let mut visitor = TestVisitor::new(Arc::clone(&parquet_type));
        for f in parquet_type.get_fields() {
            let c = TestVisitorContext {};
            assert!(visitor.dispatch(Arc::clone(f), c).unwrap());
        }

        assert!(visitor.struct_visited);
        assert!(visitor.primitive_visited);
        assert!(visitor.list_visited);
    }

    #[test]
    fn test_visitor_legacy_repeated_primitive_list() {
        // Legacy two-level list: the repeated primitive is the element itself.
        let message_type = "
          message spark_schema {
            OPTIONAL group e (LIST) {
              REPEATED INT32 element;
            }
          }
        ";

        let parquet_type = Arc::new(parse_message_type(message_type).unwrap());

        let mut visitor = TestVisitor::new(Arc::clone(&parquet_type));
        for f in parquet_type.get_fields() {
            assert!(visitor.dispatch(Arc::clone(f), TestVisitorContext {}).unwrap());
        }

        assert!(visitor.list_visited);
    }

    #[test]
    fn test_visitor_rejects_non_repeated_primitive_list_item() {
        let message_type = "
          message spark_schema {
            OPTIONAL group e (LIST) {
              REQUIRED INT32 element;
            }
          }
        ";

        let parquet_type = Arc::new(parse_message_type(message_type).unwrap());

        let mut visitor = TestVisitor::new(Arc::clone(&parquet_type));
        let list_field = Arc::clone(&parquet_type.get_fields()[0]);
        assert!(visitor.dispatch(list_field, TestVisitorContext {}).is_err());
        assert!(!visitor.list_visited);
    }
}