Skip to main content

omnigraph_compiler/
types.rs

1use arrow_schema::DataType;
2use serde::{Deserialize, Serialize};
3
/// Largest dimension accepted for a `Vector(N)` type.
///
/// Arrow's `FixedSizeList` stores its length as an `i32`, so any dimension above `i32::MAX`
/// could not be converted by `ScalarType::to_arrow`.
const MAX_VECTOR_DIM: u32 = i32::MAX.unsigned_abs();
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum ScalarType {
8    String,
9    Bool,
10    I32,
11    I64,
12    U32,
13    U64,
14    F32,
15    F64,
16    Date,
17    DateTime,
18    Vector(u32),
19    Blob,
20}
21
22impl ScalarType {
23    pub fn from_str_name(s: &str) -> Option<Self> {
24        if let Some(inner) = s.strip_prefix("Vector(").and_then(|t| t.strip_suffix(')')) {
25            let dim = inner.parse::<u32>().ok()?;
26            if dim == 0 || dim > MAX_VECTOR_DIM {
27                return None;
28            }
29            return Some(Self::Vector(dim));
30        }
31
32        match s {
33            "String" => Some(Self::String),
34            "Bool" => Some(Self::Bool),
35            "I32" => Some(Self::I32),
36            "I64" => Some(Self::I64),
37            "U32" => Some(Self::U32),
38            "U64" => Some(Self::U64),
39            "F32" => Some(Self::F32),
40            "F64" => Some(Self::F64),
41            "Date" => Some(Self::Date),
42            "DateTime" => Some(Self::DateTime),
43            "Blob" => Some(Self::Blob),
44            _ => None,
45        }
46    }
47
48    pub fn to_arrow(&self) -> DataType {
49        match self {
50            Self::String => DataType::Utf8,
51            Self::Bool => DataType::Boolean,
52            Self::I32 => DataType::Int32,
53            Self::I64 => DataType::Int64,
54            Self::U32 => DataType::UInt32,
55            Self::U64 => DataType::UInt64,
56            Self::F32 => DataType::Float32,
57            Self::F64 => DataType::Float64,
58            Self::Date => DataType::Date32,
59            Self::DateTime => DataType::Date64,
60            Self::Blob => DataType::LargeBinary,
61            Self::Vector(dim) => {
62                let dim = i32::try_from(*dim)
63                    .expect("vector dimension exceeds Arrow FixedSizeList i32 bound");
64                DataType::FixedSizeList(
65                    std::sync::Arc::new(arrow_schema::Field::new("item", DataType::Float32, true)),
66                    dim,
67                )
68            }
69        }
70    }
71
72    pub fn is_numeric(&self) -> bool {
73        matches!(
74            self,
75            Self::I32 | Self::I64 | Self::U32 | Self::U64 | Self::F32 | Self::F64
76        )
77    }
78}
79
80impl std::fmt::Display for ScalarType {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        let s = match self {
83            Self::String => "String",
84            Self::Bool => "Bool",
85            Self::I32 => "I32",
86            Self::I64 => "I64",
87            Self::U32 => "U32",
88            Self::U64 => "U64",
89            Self::F32 => "F32",
90            Self::F64 => "F64",
91            Self::Date => "Date",
92            Self::DateTime => "DateTime",
93            Self::Blob => "Blob",
94            Self::Vector(dim) => return write!(f, "Vector({})", dim),
95        };
96        write!(f, "{}", s)
97    }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
101pub struct PropType {
102    pub scalar: ScalarType,
103    pub nullable: bool,
104    pub list: bool,
105    pub enum_values: Option<Vec<String>>,
106}
107
108impl PropType {
109    pub fn from_param_type_name(s: &str, nullable: bool) -> Option<Self> {
110        if let Some(inner) = s
111            .strip_prefix('[')
112            .and_then(|value| value.strip_suffix(']'))
113        {
114            let scalar = ScalarType::from_str_name(inner)?;
115            return Some(Self::list_of(scalar, nullable));
116        }
117
118        let scalar = ScalarType::from_str_name(s)?;
119        Some(Self::scalar(scalar, nullable))
120    }
121
122    pub fn scalar(scalar: ScalarType, nullable: bool) -> Self {
123        Self {
124            scalar,
125            nullable,
126            list: false,
127            enum_values: None,
128        }
129    }
130
131    pub fn list_of(scalar: ScalarType, nullable: bool) -> Self {
132        Self {
133            scalar,
134            nullable,
135            list: true,
136            enum_values: None,
137        }
138    }
139
140    pub fn enum_type(mut values: Vec<String>, nullable: bool) -> Self {
141        values.sort();
142        values.dedup();
143        Self {
144            scalar: ScalarType::String,
145            nullable,
146            list: false,
147            enum_values: Some(values),
148        }
149    }
150
151    pub fn is_enum(&self) -> bool {
152        self.enum_values.is_some()
153    }
154
155    pub fn to_arrow(&self) -> DataType {
156        let scalar_dt = self.scalar.to_arrow();
157        if self.list {
158            DataType::List(std::sync::Arc::new(arrow_schema::Field::new(
159                "item", scalar_dt, true,
160            )))
161        } else {
162            scalar_dt
163        }
164    }
165
166    pub fn display_name(&self) -> String {
167        let base = if let Some(values) = &self.enum_values {
168            format!("enum({})", values.join(", "))
169        } else {
170            self.scalar.to_string()
171        };
172        let wrapped = if self.list {
173            format!("[{}]", base)
174        } else {
175            base
176        };
177        if self.nullable {
178            format!("{}?", wrapped)
179        } else {
180            wrapped
181        }
182    }
183}
184
/// One of two directions, `Out` or `In`.
///
/// Presumably this is the direction of an edge relative to a node (outgoing vs. incoming);
/// TODO confirm against callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    Out,
    In,
}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field};
    use std::sync::Arc;

    #[test]
    fn vector_to_arrow_uses_nullable_float32_child() {
        let dt = ScalarType::Vector(4).to_arrow();
        assert_eq!(
            dt,
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4)
        );
    }

    #[test]
    fn scalar_type_from_str_name_rejects_vector_dimensions_outside_arrow_bounds() {
        let too_large = format!("Vector({})", (i32::MAX as u64) + 1);
        assert!(ScalarType::from_str_name(&too_large).is_none());
        assert_eq!(
            ScalarType::from_str_name("Vector(2147483647)"),
            Some(ScalarType::Vector(2147483647))
        );
    }

    /// Covers the `dim == 0` guard and the malformed-spelling paths of `from_str_name`.
    #[test]
    fn scalar_type_from_str_name_rejects_zero_and_malformed_vectors() {
        let rejected = [
            "Vector(0)",
            "Vector(-1)",
            "Vector()",
            "Vector(abc)",
            "Vector(4",
            "vector(4)",
        ];
        for name in rejected {
            assert!(
                ScalarType::from_str_name(name).is_none(),
                "{:?} should be rejected",
                name
            );
        }
    }

    /// `Display` output must parse back to the same value for every parseable type.
    #[test]
    fn scalar_type_display_round_trips_through_from_str_name() {
        let all = [
            ScalarType::String,
            ScalarType::Bool,
            ScalarType::I32,
            ScalarType::I64,
            ScalarType::U32,
            ScalarType::U64,
            ScalarType::F32,
            ScalarType::F64,
            ScalarType::Date,
            ScalarType::DateTime,
            ScalarType::Blob,
            ScalarType::Vector(1),
            ScalarType::Vector(768),
        ];
        for ty in all {
            assert_eq!(ScalarType::from_str_name(&ty.to_string()), Some(ty));
        }
    }

    #[test]
    fn scalar_type_is_numeric_covers_only_number_types() {
        assert!(ScalarType::I32.is_numeric());
        assert!(ScalarType::U64.is_numeric());
        assert!(ScalarType::F32.is_numeric());
        assert!(!ScalarType::String.is_numeric());
        assert!(!ScalarType::Date.is_numeric());
        assert!(!ScalarType::Vector(3).is_numeric());
    }

    #[test]
    fn prop_type_from_param_type_name_supports_lists_and_nullable_scalars() {
        assert_eq!(
            PropType::from_param_type_name("[DateTime]", false),
            Some(PropType::list_of(ScalarType::DateTime, false))
        );
        assert_eq!(
            PropType::from_param_type_name("DateTime", true),
            Some(PropType::scalar(ScalarType::DateTime, true))
        );
    }

    #[test]
    fn prop_type_from_param_type_name_rejects_unknown_empty_and_nested_lists() {
        assert_eq!(PropType::from_param_type_name("[Nope]", false), None);
        assert_eq!(PropType::from_param_type_name("[[I32]]", false), None);
        assert_eq!(PropType::from_param_type_name("[]", true), None);
        assert_eq!(
            PropType::from_param_type_name("[Vector(3)]", true),
            Some(PropType::list_of(ScalarType::Vector(3), true))
        );
    }

    #[test]
    fn prop_type_list_to_arrow_wraps_nullable_item_field() {
        assert_eq!(
            PropType::list_of(ScalarType::I32, false).to_arrow(),
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true)))
        );
        assert_eq!(
            PropType::scalar(ScalarType::I32, true).to_arrow(),
            DataType::Int32
        );
    }

    #[test]
    fn prop_type_enum_type_sorts_and_dedups_values() {
        let ty = PropType::enum_type(
            vec!["b".to_string(), "a".to_string(), "b".to_string()],
            true,
        );
        assert!(ty.is_enum());
        assert_eq!(ty.scalar, ScalarType::String);
        assert_eq!(ty.enum_values, Some(vec!["a".to_string(), "b".to_string()]));
        assert_eq!(ty.display_name(), "enum(a, b)?");
    }

    #[test]
    fn prop_type_display_name_marks_lists_and_nullability() {
        assert_eq!(
            PropType::list_of(ScalarType::I64, true).display_name(),
            "[I64]?"
        );
        assert_eq!(
            PropType::scalar(ScalarType::Vector(4), false).display_name(),
            "Vector(4)"
        );
    }
}