arrow 6.1.0

Rust implementation of Apache Arrow
Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Arrow Tensor Type, defined in
//! [`format/Tensor.fbs`](https://github.com/apache/arrow/blob/master/format/Tensor.fbs).

use std::marker::PhantomData;
use std::mem;

use crate::buffer::Buffer;
use crate::datatypes::*;

use crate::error::{ArrowError, Result};

/// Computes the strides, in bytes, required assuming a row major memory layout.
///
/// The stride of a dimension is the element size multiplied by the sizes of
/// all dimensions that follow it, so the last dimension always has a stride
/// of exactly one element.
///
/// If the shape describes zero bytes of data (any dimension is `0`), every
/// stride is set to the element size instead of being derived by division,
/// which would otherwise divide by zero.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` if the total size in bytes described by
/// `shape` does not fit in a `usize`.
fn compute_row_major_strides<T: ArrowPrimitiveType>(
    shape: &[usize],
) -> Result<Vec<usize>> {
    let element_size = mem::size_of::<T::Native>();

    // Total number of bytes spanned by the whole tensor, guarding every
    // multiplication against overflow.
    let mut total_bytes = element_size;
    for &dim in shape {
        total_bytes = total_bytes.checked_mul(dim).ok_or_else(|| {
            ArrowError::ComputeError(
                "overflow occurred when computing row major strides.".to_string(),
            )
        })?;
    }

    // A zero-sized tensor has no data to step over; dividing by its zero
    // dimension below would panic, so fall back to element-sized strides.
    if total_bytes == 0 {
        return Ok(vec![element_size; shape.len()]);
    }

    // Peel the dimensions off from the outermost one: what remains after
    // dividing by a dimension is the byte distance between its entries.
    let mut strides = Vec::with_capacity(shape.len());
    let mut remaining_bytes = total_bytes;
    for &dim in shape {
        remaining_bytes /= dim;
        strides.push(remaining_bytes);
    }

    Ok(strides)
}

/// Computes the strides, in bytes, required assuming a column major memory layout.
///
/// The first dimension gets a stride of one element; every following
/// dimension gets the previous stride multiplied by the previous dimension's
/// size.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` if multiplying the running stride by any
/// dimension size (including the last one) overflows `usize`.
fn compute_column_major_strides<T: ArrowPrimitiveType>(
    shape: &[usize],
) -> Result<Vec<usize>> {
    // Stride that the next dimension visited will receive.
    let mut upcoming_stride = mem::size_of::<T::Native>();

    shape
        .iter()
        .map(|&extent| -> Result<usize> {
            let stride = upcoming_stride;
            upcoming_stride = stride.checked_mul(extent).ok_or_else(|| {
                ArrowError::ComputeError(
                    "overflow occurred when computing column major strides."
                        .to_string(),
                )
            })?;
            Ok(stride)
        })
        .collect()
}

/// Tensor of primitive types
///
/// Couples a [`Buffer`] of element data with optional shape, stride and
/// dimension-name metadata. The type parameter `T` selects the element type;
/// the lifetime `'a` is the lifetime of the borrowed dimension names.
#[derive(Debug)]
pub struct Tensor<'a, T: ArrowPrimitiveType> {
    /// Arrow data type of the elements; `try_new` sets it to `T::DATA_TYPE`.
    data_type: DataType,
    /// Underlying buffer holding the tensor data.
    buffer: Buffer,
    /// Size of each dimension, or `None` for a tensor without a shape.
    shape: Option<Vec<usize>>,
    /// Number of bytes between elements in each dimension; `None` when the
    /// tensor has no shape.
    strides: Option<Vec<usize>>,
    /// Optional name of each dimension (one entry per dimension when present).
    names: Option<Vec<&'a str>>,
    /// Ties the tensor to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}

/// A [`Tensor`] whose element type is described by [`BooleanType`].
pub type BooleanTensor<'a> = Tensor<'a, BooleanType>;
/// A [`Tensor`] whose element type is described by [`Int8Type`].
pub type Int8Tensor<'a> = Tensor<'a, Int8Type>;
/// A [`Tensor`] whose element type is described by [`Int16Type`].
pub type Int16Tensor<'a> = Tensor<'a, Int16Type>;
/// A [`Tensor`] whose element type is described by [`Int32Type`].
pub type Int32Tensor<'a> = Tensor<'a, Int32Type>;
/// A [`Tensor`] whose element type is described by [`Int64Type`].
pub type Int64Tensor<'a> = Tensor<'a, Int64Type>;
/// A [`Tensor`] whose element type is described by [`UInt8Type`].
pub type UInt8Tensor<'a> = Tensor<'a, UInt8Type>;
/// A [`Tensor`] whose element type is described by [`UInt16Type`].
pub type UInt16Tensor<'a> = Tensor<'a, UInt16Type>;
/// A [`Tensor`] whose element type is described by [`UInt32Type`].
pub type UInt32Tensor<'a> = Tensor<'a, UInt32Type>;
/// A [`Tensor`] whose element type is described by [`UInt64Type`].
pub type UInt64Tensor<'a> = Tensor<'a, UInt64Type>;
/// A [`Tensor`] whose element type is described by [`Float32Type`].
pub type Float32Tensor<'a> = Tensor<'a, Float32Type>;
/// A [`Tensor`] whose element type is described by [`Float64Type`].
pub type Float64Tensor<'a> = Tensor<'a, Float64Type>;

impl<'a, T: ArrowPrimitiveType> Tensor<'a, T> {
    /// Creates a new `Tensor`, checking that `buffer`, `shape`, `strides` and
    /// `names` are consistent with each other.
    ///
    /// * With `shape == None` the buffer must hold exactly one element and
    ///   neither `strides` nor `names` may be supplied.
    /// * With a shape, `strides` and `names` (when supplied) must have one
    ///   entry per dimension, and the number of elements described by the
    ///   shape must match the number of whole elements in the buffer.
    /// * Supplied strides must equal either the row major or the column major
    ///   strides for the shape. When no strides are supplied, row major
    ///   strides are computed and stored.
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` when any of the checks above
    /// fails, including a shape whose element count does not fit in `usize`.
    /// Errors raised while computing strides are propagated unchanged.
    pub fn try_new(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        strides: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        let element_size = mem::size_of::<T::Native>();

        match shape {
            None => {
                if buffer.len() != element_size {
                    return Err(ArrowError::InvalidArgumentError(
                        "underlying buffer should only contain a single tensor element"
                            .to_string(),
                    ));
                }

                if strides.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None strides for tensor with no shape".to_string(),
                    ));
                }

                if names.is_some() {
                    return Err(ArrowError::InvalidArgumentError(
                        "expected None names for tensor with no shape".to_string(),
                    ));
                }
            }

            Some(ref s) => {
                if let Some(ref st) = strides {
                    if st.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "shape and stride dimensions differ".to_string(),
                        ));
                    }
                }

                if let Some(ref n) = names {
                    if n.len() != s.len() {
                        return Err(ArrowError::InvalidArgumentError(
                            "number of dimensions and number of dimension names differ"
                                .to_string(),
                        ));
                    }
                }

                // Checked product: an element count that overflows `usize`
                // cannot match any buffer, so it is reported as a mismatch
                // instead of panicking (or silently wrapping).
                let total_elements = s
                    .iter()
                    .try_fold(1_usize, |count, &dim| count.checked_mul(dim));
                if total_elements != Some(buffer.len() / element_size) {
                    return Err(ArrowError::InvalidArgumentError(
                        "number of elements in buffer does not match dimensions"
                            .to_string(),
                    ));
                }
            }
        };

        // Checking that the tensor strides used for construction are correct
        // otherwise a row major stride is calculated and used as value for the tensor
        let tensor_strides = match (strides, shape.as_ref()) {
            (Some(st), Some(s)) => {
                let matches_known_layout = compute_row_major_strides::<T>(s)? == st
                    || compute_column_major_strides::<T>(s)? == st;
                if !matches_known_layout {
                    return Err(ArrowError::InvalidArgumentError(
                        "the input stride does not match the selected shape"
                            .to_string(),
                    ));
                }
                Some(st)
            }
            (None, Some(s)) => Some(compute_row_major_strides::<T>(s)?),
            // Without a shape the strides were already required to be `None`.
            (st, None) => st,
        };

        Ok(Self {
            data_type: T::DATA_TYPE,
            buffer,
            shape,
            strides: tensor_strides,
            names,
            _marker: PhantomData,
        })
    }

    /// Creates a new Tensor using row major memory layout
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` if `shape` is `None`, and
    /// otherwise any error produced while computing the strides or by
    /// [`Tensor::try_new`].
    pub fn new_row_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        let strides = match shape {
            Some(ref s) => compute_row_major_strides::<T>(s)?,
            None => {
                return Err(ArrowError::InvalidArgumentError(
                    "shape required to create row major tensor".to_string(),
                ))
            }
        };

        Self::try_new(buffer, shape, Some(strides), names)
    }

    /// Creates a new Tensor using column major memory layout
    ///
    /// # Errors
    ///
    /// Returns `ArrowError::InvalidArgumentError` if `shape` is `None`, and
    /// otherwise any error produced while computing the strides or by
    /// [`Tensor::try_new`].
    pub fn new_column_major(
        buffer: Buffer,
        shape: Option<Vec<usize>>,
        names: Option<Vec<&'a str>>,
    ) -> Result<Self> {
        let strides = match shape {
            Some(ref s) => compute_column_major_strides::<T>(s)?,
            None => {
                return Err(ArrowError::InvalidArgumentError(
                    "shape required to create column major tensor".to_string(),
                ))
            }
        };

        Self::try_new(buffer, shape, Some(strides), names)
    }

    /// The data type of the `Tensor`
    pub fn data_type(&self) -> &DataType {
        &self.data_type
    }

    /// The sizes of the dimensions, or `None` for a tensor without a shape
    pub fn shape(&self) -> Option<&Vec<usize>> {
        self.shape.as_ref()
    }

    /// Returns a reference to the underlying `Buffer`
    pub fn data(&self) -> &Buffer {
        &self.buffer
    }

    /// The number of bytes between elements in each dimension
    pub fn strides(&self) -> Option<&Vec<usize>> {
        self.strides.as_ref()
    }

    /// The names of the dimensions
    pub fn names(&self) -> Option<&Vec<&'a str>> {
        self.names.as_ref()
    }

    /// The number of dimensions (`0` for a tensor without a shape)
    pub fn ndim(&self) -> usize {
        self.shape.as_ref().map_or(0, |dims| dims.len())
    }

    /// The name of dimension i
    ///
    /// Returns `None` when the tensor has no dimension names or when `i` is
    /// not a valid dimension index.
    pub fn dim_name(&self, i: usize) -> Option<&'a str> {
        self.names
            .as_ref()
            .and_then(|names| names.get(i).copied())
    }

    /// The total number of elements in the `Tensor`
    ///
    /// This is the product of the dimension sizes, and `0` for a tensor
    /// without a shape.
    pub fn size(&self) -> usize {
        self.shape
            .as_ref()
            .map_or(0, |dims| dims.iter().product())
    }

    /// Indicates if the data is laid out contiguously in memory, i.e. the
    /// strides are either row major or column major
    pub fn is_contiguous(&self) -> Result<bool> {
        Ok(self.is_row_major()? || self.is_column_major()?)
    }

    /// Indicates if the memory layout row major
    ///
    /// Always `false` for a tensor without a shape.
    pub fn is_row_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => {
                let expected = compute_row_major_strides::<T>(s)?;
                Ok(self.strides.as_ref() == Some(&expected))
            }
        }
    }

    /// Indicates if the memory layout column major
    ///
    /// Always `false` for a tensor without a shape.
    pub fn is_column_major(&self) -> Result<bool> {
        match self.shape {
            None => Ok(false),
            Some(ref s) => {
                let expected = compute_column_major_strides::<T>(s)?;
                Ok(self.strides.as_ref() == Some(&expected))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::array::*;
    use crate::buffer::Buffer;

    /// Builds a buffer holding the `i32` values `0, 1, ..` up to `count` entries.
    fn int32_sequence(count: usize) -> Buffer {
        let mut builder = Int32BufferBuilder::new(count);
        for value in (0_i32..).take(count) {
            builder.append(value);
        }
        builder.finish()
    }

    /// Asserts the properties every tensor built without a shape must have.
    fn assert_shapeless<T: ArrowPrimitiveType>(tensor: &Tensor<'_, T>) {
        assert_eq!(0, tensor.size());
        assert_eq!(None, tensor.shape());
        assert_eq!(None, tensor.names());
        assert_eq!(0, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(!tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_compute_row_major_strides() {
        let int64_strides = compute_row_major_strides::<Int64Type>(&[4_usize, 6]);
        assert_eq!(vec![48_usize, 8], int64_strides.unwrap());

        let int32_strides = compute_row_major_strides::<Int32Type>(&[4_usize, 6]);
        assert_eq!(vec![24_usize, 4], int32_strides.unwrap());

        let int8_strides = compute_row_major_strides::<Int8Type>(&[4_usize, 6]);
        assert_eq!(vec![6_usize, 1], int8_strides.unwrap());
    }

    #[test]
    fn test_compute_column_major_strides() {
        let int64_strides = compute_column_major_strides::<Int64Type>(&[4_usize, 6]);
        assert_eq!(vec![8_usize, 32], int64_strides.unwrap());

        let int32_strides = compute_column_major_strides::<Int32Type>(&[4_usize, 6]);
        assert_eq!(vec![4_usize, 16], int32_strides.unwrap());

        let int8_strides = compute_column_major_strides::<Int8Type>(&[4_usize, 6]);
        assert_eq!(vec![1_usize, 4], int8_strides.unwrap());
    }

    #[test]
    fn test_zero_dim() {
        // A single byte is exactly one UInt8 element.
        let buf = Buffer::from(&[1]);
        let tensor = UInt8Tensor::try_new(buf, None, None, None).unwrap();
        assert_shapeless(&tensor);

        // Four bytes are exactly one Int32 element.
        let buf = Buffer::from(&[1, 2, 2, 2]);
        let tensor = Int32Tensor::try_new(buf, None, None, None).unwrap();
        assert_shapeless(&tensor);
    }

    #[test]
    fn test_tensor() {
        let tensor =
            Int32Tensor::try_new(int32_sequence(16), Some(vec![2, 8]), None, None)
                .unwrap();

        assert_eq!(16, tensor.size());
        assert_eq!(tensor.shape(), Some(&vec![2_usize, 8]));
        // Without explicit strides the row major layout is used.
        assert_eq!(tensor.strides(), Some(&vec![32_usize, 4]));
        assert_eq!(2, tensor.ndim());
        assert_eq!(None, tensor.names());
    }

    #[test]
    fn test_new_row_major() {
        let tensor =
            Int32Tensor::new_row_major(int32_sequence(16), Some(vec![2, 8]), None)
                .unwrap();

        assert_eq!(16, tensor.size());
        assert_eq!(tensor.shape(), Some(&vec![2_usize, 8]));
        assert_eq!(tensor.strides(), Some(&vec![32_usize, 4]));
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(tensor.is_row_major().unwrap());
        assert!(!tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_new_column_major() {
        let tensor =
            Int32Tensor::new_column_major(int32_sequence(16), Some(vec![2, 8]), None)
                .unwrap();

        assert_eq!(16, tensor.size());
        assert_eq!(tensor.shape(), Some(&vec![2_usize, 8]));
        assert_eq!(tensor.strides(), Some(&vec![4_usize, 8]));
        assert_eq!(None, tensor.names());
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_with_names() {
        let mut values = Int64BufferBuilder::new(8);
        for value in 0..8 {
            values.append(value);
        }
        let data = values.finish();
        let dim_names = vec!["Dim 1", "Dim 2"];

        let tensor =
            Int64Tensor::new_column_major(data, Some(vec![2, 4]), Some(dim_names))
                .unwrap();

        assert_eq!(8, tensor.size());
        assert_eq!(tensor.shape(), Some(&vec![2_usize, 4]));
        assert_eq!(tensor.strides(), Some(&vec![8_usize, 16]));
        assert_eq!(tensor.dim_name(0), Some("Dim 1"));
        assert_eq!(tensor.dim_name(1), Some("Dim 2"));
        assert_eq!(2, tensor.ndim());
        assert!(!tensor.is_row_major().unwrap());
        assert!(tensor.is_column_major().unwrap());
        assert!(tensor.is_contiguous().unwrap());
    }

    #[test]
    fn test_inconsistent_strides() {
        // Three strides for a two dimensional shape must be rejected.
        let result = Int32Tensor::try_new(
            int32_sequence(16),
            Some(vec![2, 8]),
            Some(vec![2, 8, 1]),
            None,
        );

        assert!(result.is_err(), "shape and stride dimensions are different");
    }

    #[test]
    fn test_inconsistent_names() {
        // Three names for a two dimensional shape must be rejected.
        let result = Int32Tensor::try_new(
            int32_sequence(16),
            Some(vec![2, 8]),
            Some(vec![4, 8]),
            Some(vec!["1", "2", "3"]),
        );

        assert!(result.is_err(), "dimensions and names have different shape");
    }

    #[test]
    fn test_incorrect_shape() {
        // 2 x 6 = 12 elements, but the buffer holds 16.
        let result =
            Int32Tensor::try_new(int32_sequence(16), Some(vec![2, 6]), None, None);

        assert!(
            result.is_err(),
            "number of elements does not match for the shape"
        );
    }

    #[test]
    fn test_incorrect_stride() {
        // [30, 4] is neither the row major nor the column major layout.
        let result = Int32Tensor::try_new(
            int32_sequence(16),
            Some(vec![2, 8]),
            Some(vec![30, 4]),
            None,
        );

        assert!(
            result.is_err(),
            "the input stride does not match the selected shape"
        );
    }
}