ella_tensor/
arrow.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use std::collections::HashMap;

use arrow::datatypes::{DataType, Field};

use crate::Dyn;

/// Parameters of a fixed-shape tensor extension type, (de)serializable with serde.
///
/// This is the payload stored by [`ExtensionType::encode`] under the extension metadata key and
/// parsed back by [`ExtensionType::decode`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FixedShapeTensor {
    /// Shape reported by [`row_shape`] for a tensor column; (de)serialized under the key
    /// `shape` rather than the Rust field name.
    #[serde(rename = "shape")]
    pub row_shape: Dyn,
    /// Optional dimension names. Left out of the serialized form when `None`, and read back as
    /// `None` when the key is absent.
    // NOTE(review): presumably one name per tensor dimension, as in the Arrow spec — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dim_names: Option<Vec<String>>,
    /// Optional dimension permutation. Left out of the serialized form when `None`, and read
    /// back as `None` when the key is absent. [`row_shape`] does not support a set permutation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub permutation: Option<Vec<u32>>,
}

impl FixedShapeTensor {
    /// Extension name under which this type is registered in field metadata.
    const EXT_NAME: &'static str = "arrow.fixed_shape_tensor";

    /// Creates tensor parameters with the given row shape, no dimension names and no
    /// permutation.
    fn new(row_shape: Dyn) -> Self {
        let dim_names = None;
        let permutation = None;
        Self { row_shape, dim_names, permutation }
    }
}

/// Extension types that can be written to and read back from field metadata via
/// [`ExtensionType::encode`] and [`ExtensionType::decode`].
#[derive(Debug, Clone)]
pub enum ExtensionType {
    /// A fixed-shape tensor, carrying its shape and optional layout parameters.
    FixedShapeTensor(FixedShapeTensor),
}

impl ExtensionType {
    const EXTENSION_NAME_KEY: &str = "ARROW:extension:name";
    const EXTENSION_META_KEY: &str = "ARROW:extension:metadata";

    pub fn tensor(row_shape: Dyn) -> Self {
        Self::FixedShapeTensor(FixedShapeTensor::new(row_shape))
    }

    pub fn encode(&self) -> HashMap<String, String> {
        let mut meta = HashMap::with_capacity(2);
        let (name, value) = match self {
            ExtensionType::FixedShapeTensor(tensor) => (
                FixedShapeTensor::EXT_NAME,
                serde_json::to_string(tensor).unwrap(),
            ),
        };
        meta.insert(Self::EXTENSION_NAME_KEY.to_string(), name.to_string());
        meta.insert(Self::EXTENSION_META_KEY.to_string(), value);
        meta
    }

    pub fn decode(meta: &HashMap<String, String>) -> crate::Result<Option<Self>> {
        if let Some(name) = meta.get(Self::EXTENSION_NAME_KEY) {
            let value = meta
                .get(Self::EXTENSION_META_KEY)
                .ok_or_else(|| crate::Error::MissingMetadata(name.to_owned()))?;
            match name.as_str() {
                FixedShapeTensor::EXT_NAME => {
                    Ok(Some(Self::FixedShapeTensor(serde_json::from_str(value)?)))
                }
                _ => Err(crate::Error::UnknownExtension(name.to_owned())),
            }
        } else {
            Ok(None)
        }
    }
}

/// Returns the shape of a single row of the column described by `f`.
///
/// - A `FixedSizeList` field whose metadata decodes to a fixed-shape tensor extension yields
///   the shape stored in that metadata.
/// - A `FixedSizeList` field without extension metadata yields a one-element shape holding the
///   list size.
/// - Any other data type yields a shape built from an empty array.
///
/// # Errors
///
/// Propagates any error returned by [`ExtensionType::decode`] for the field's metadata.
///
/// # Panics
///
/// Panics if the tensor metadata carries a `permutation` that actually reorders dimensions;
/// permuted layouts are not implemented. An identity permutation (`[0, 1, .., N-1]`) is
/// accepted, since the Arrow fixed-shape tensor specification defines it as equivalent to
/// omitting the field.
pub fn row_shape(f: &Field) -> crate::Result<Dyn> {
    match f.data_type() {
        DataType::FixedSizeList(_, row_size) => {
            if let Some(ExtensionType::FixedShapeTensor(tensor)) =
                ExtensionType::decode(f.metadata())?
            {
                // A missing permutation and an identity permutation both mean that the logical
                // and physical layouts coincide.
                // NOTE(review): the permutation length is not checked against the shape's
                // dimensionality here — confirm whether that should be validated upstream.
                let is_identity = tensor.permutation.as_deref().map_or(true, |perm| {
                    perm.iter()
                        .enumerate()
                        .all(|(axis, &target)| target as usize == axis)
                });
                if !is_identity {
                    unimplemented!("non-identity tensor permutations are not supported");
                }
                Ok(tensor.row_shape)
            } else {
                Ok(Dyn::from([*row_size as usize]))
            }
        }
        _ => Ok(Dyn::from([])),
    }
}