ella_common/
tensor_type.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, TimeUnit};

#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    serde::Serialize,
    serde::Deserialize,
    strum::Display,
    strum::EnumString,
    strum::FromRepr,
)]
#[repr(u8)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum TensorType {
    Bool,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float32,
    Float64,
    Timestamp,
    Duration,
    String,
}

impl TensorType {
    pub fn to_arrow(&self) -> DataType {
        use TensorType::*;

        match self {
            Bool => DataType::Boolean,
            Int8 => DataType::Int8,
            Int16 => DataType::Int16,
            Int32 => DataType::Int32,
            Int64 => DataType::Int64,
            UInt8 => DataType::UInt8,
            UInt16 => DataType::UInt16,
            UInt32 => DataType::UInt32,
            UInt64 => DataType::UInt64,
            Float32 => DataType::Float32,
            Float64 => DataType::Float64,
            Duration => DataType::Duration(TimeUnit::Nanosecond),
            Timestamp => DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+00:00"))),
            String => DataType::Utf8,
        }
    }

    pub fn from_arrow(dtype: &DataType) -> crate::Result<Self> {
        use DataType::*;

        Ok(match dtype {
            Boolean => Self::Bool,
            Int8 => Self::Int8,
            Int16 => Self::Int16,
            Int32 => Self::Int32,
            Int64 => Self::Int64,
            UInt8 => Self::UInt8,
            UInt16 => Self::UInt16,
            UInt32 => Self::UInt32,
            UInt64 => Self::UInt64,
            Float32 => Self::Float32,
            Float64 => Self::Float64,
            Duration(TimeUnit::Nanosecond) => Self::Duration,
            Timestamp(TimeUnit::Nanosecond, _) => Self::Timestamp,
            Utf8 => Self::String,
            _ => return Err(crate::Error::DataType(dtype.clone())),
        })
    }
}