arrow2 0.7.1

Unofficial implementation of Apache Arrow spec in safe Rust
Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines take kernel for [`Array`]

use crate::{
    array::{new_empty_array, Array, NullArray, PrimitiveArray},
    datatypes::DataType,
    error::Result,
    types::Index,
};

mod binary;
mod boolean;
mod dict;
mod generic_binary;
mod list;
mod primitive;
mod structure;
mod utf8;

/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls.
/// The returned array has a length equal to `indices.len()`.
pub fn take<O: Index>(values: &dyn Array, indices: &PrimitiveArray<O>) -> Result<Box<dyn Array>> {
    if indices.len() == 0 {
        return Ok(new_empty_array(values.data_type().clone()));
    }

    use crate::datatypes::PhysicalType::*;
    match values.data_type().to_physical_type() {
        Null => Ok(Box::new(NullArray::from_data(
            values.data_type().clone(),
            indices.len(),
        ))),
        Boolean => {
            let values = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(boolean::take::<O>(values, indices)))
        }
        Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
            let values = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(primitive::take::<$T, _>(&values, indices)))
        }),
        Utf8 => {
            let values = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(utf8::take::<i32, _>(values, indices)))
        }
        LargeUtf8 => {
            let values = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(utf8::take::<i64, _>(values, indices)))
        }
        Binary => {
            let values = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(binary::take::<i32, _>(values, indices)))
        }
        LargeBinary => {
            let values = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(binary::take::<i64, _>(values, indices)))
        }
        Dictionary(key_type) => {
            with_match_physical_dictionary_key_type!(key_type, |$T| {
                let values = values.as_any().downcast_ref().unwrap();
                Ok(Box::new(dict::take::<$T, _>(&values, indices)))
            })
        }
        Struct => {
            let array = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(structure::take::<_>(array, indices)?))
        }
        List => {
            let array = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(list::take::<i32, O>(array, indices)))
        }
        LargeList => {
            let array = values.as_any().downcast_ref().unwrap();
            Ok(Box::new(list::take::<i64, O>(array, indices)))
        }
        t => unimplemented!("Take not supported for data type {:?}", t),
    }
}

/// Reports whether the take operation is available for arrays of type `data_type`.
///
/// Dictionary types qualify when their key type is one of the signed or
/// unsigned 8/16/32/64-bit integer types; the dictionary's value type is not
/// examined. Any type not listed in the body yields `false`.
///
/// # Examples
/// ```
/// use arrow2::compute::take::can_take;
/// use arrow2::datatypes::DataType;
///
/// assert!(can_take(&DataType::Int8));
/// ```
pub fn can_take(data_type: &DataType) -> bool {
    // Dictionaries are the only case whose answer depends on an inner type.
    if let DataType::Dictionary(keys, _) = data_type {
        return matches!(
            keys.as_ref(),
            DataType::Int8
                | DataType::Int16
                | DataType::Int32
                | DataType::Int64
                | DataType::UInt8
                | DataType::UInt16
                | DataType::UInt32
                | DataType::UInt64
        );
    }

    matches!(
        data_type,
        // Trivial and boolean
        DataType::Null
            | DataType::Boolean
            // Signed integers
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            // Unsigned integers
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            // Floating point and decimal
            | DataType::Float16
            | DataType::Float32
            | DataType::Float64
            | DataType::Decimal(_, _)
            // Temporal
            | DataType::Date32
            | DataType::Date64
            | DataType::Time32(_)
            | DataType::Time64(_)
            | DataType::Timestamp(_, _)
            | DataType::Duration(_)
            | DataType::Interval(_)
            // Variable-length strings and bytes
            | DataType::Utf8
            | DataType::LargeUtf8
            | DataType::Binary
            | DataType::LargeBinary
            // Nested
            | DataType::Struct(_)
            | DataType::List(_)
            | DataType::LargeList(_)
    )
}