Skip to main content

ferray_numpy_interop/
dtype_map.rs

1//! Mapping between ferray [`DType`], Arrow [`DataType`], and NumPy dtype codes.
2//!
3//! This module provides bidirectional conversion functions so that every
4//! interop path (NumPy, Arrow, Polars) shares a single source of truth for
5//! type correspondence.
6
7#[cfg(any(feature = "arrow", feature = "polars"))]
8use ferray_core::DType;
9#[cfg(any(feature = "arrow", feature = "polars"))]
10use ferray_core::FerrayError;
11
12// ---------------------------------------------------------------------------
13// Arrow DataType <-> DType
14// ---------------------------------------------------------------------------
15
/// Translate a ferray [`DType`] into its Arrow [`DataType`] counterpart.
///
/// `Bool`, the unsigned and signed integers from 8 to 64 bits, and the 32- and
/// 64-bit floats each map onto the Arrow type of the same kind and width.
///
/// # Errors
///
/// Every other dtype (for instance `Complex32` or `Complex64`) produces the
/// error built by `FerrayError::invalid_dtype`; its message names the
/// offending dtype.
#[cfg(feature = "arrow")]
pub fn dtype_to_arrow(dt: DType) -> Result<arrow::datatypes::DataType, FerrayError> {
    use arrow::datatypes::DataType as ArrowType;

    let mapped = match dt {
        DType::Bool => ArrowType::Boolean,
        DType::U8 => ArrowType::UInt8,
        DType::U16 => ArrowType::UInt16,
        DType::U32 => ArrowType::UInt32,
        DType::U64 => ArrowType::UInt64,
        DType::I8 => ArrowType::Int8,
        DType::I16 => ArrowType::Int16,
        DType::I32 => ArrowType::Int32,
        DType::I64 => ArrowType::Int64,
        DType::F32 => ArrowType::Float32,
        DType::F64 => ArrowType::Float64,
        // Anything not listed above has no counterpart in this table.
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "ferray dtype {other} has no Arrow equivalent"
            )));
        }
    };
    Ok(mapped)
}
42
/// Translate an Arrow [`DataType`] into its ferray [`DType`] counterpart.
///
/// Only `Boolean`, the 8- to 64-bit signed and unsigned integers, and the
/// 32- and 64-bit floats are recognised.
///
/// # Errors
///
/// Every other Arrow type (for instance `Utf8`) produces the error built by
/// `FerrayError::invalid_dtype`; its message contains the `Debug` rendering
/// of the rejected type.
#[cfg(feature = "arrow")]
pub fn arrow_to_dtype(ad: &arrow::datatypes::DataType) -> Result<DType, FerrayError> {
    use arrow::datatypes::DataType as ArrowType;

    Ok(match ad {
        ArrowType::Boolean => DType::Bool,
        ArrowType::UInt8 => DType::U8,
        ArrowType::UInt16 => DType::U16,
        ArrowType::UInt32 => DType::U32,
        ArrowType::UInt64 => DType::U64,
        ArrowType::Int8 => DType::I8,
        ArrowType::Int16 => DType::I16,
        ArrowType::Int32 => DType::I32,
        ArrowType::Int64 => DType::I64,
        ArrowType::Float32 => DType::F32,
        ArrowType::Float64 => DType::F64,
        // Bail out early for every Arrow type outside the table.
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "Arrow DataType {other:?} has no ferray equivalent"
            )));
        }
    })
}
69
70// ---------------------------------------------------------------------------
71// Polars DataType <-> DType
72// ---------------------------------------------------------------------------
73
/// Translate a ferray [`DType`] into its Polars [`DataType`] counterpart.
///
/// `Bool` maps to Polars `Boolean`; the unsigned and signed integers from 8 to
/// 64 bits and the 32- and 64-bit floats map onto the Polars type of the same
/// kind and width.
///
/// # Errors
///
/// Every other dtype (for instance `Complex32` or `Complex64`) produces the
/// error built by `FerrayError::invalid_dtype`; its message names the
/// offending dtype.
#[cfg(feature = "polars")]
pub fn dtype_to_polars(dt: DType) -> Result<polars::prelude::DataType, FerrayError> {
    use polars::prelude::DataType as PolarsType;

    let mapped = match dt {
        DType::Bool => PolarsType::Boolean,
        DType::U8 => PolarsType::UInt8,
        DType::U16 => PolarsType::UInt16,
        DType::U32 => PolarsType::UInt32,
        DType::U64 => PolarsType::UInt64,
        DType::I8 => PolarsType::Int8,
        DType::I16 => PolarsType::Int16,
        DType::I32 => PolarsType::Int32,
        DType::I64 => PolarsType::Int64,
        DType::F32 => PolarsType::Float32,
        DType::F64 => PolarsType::Float64,
        // Anything not listed above has no counterpart in this table.
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "ferray dtype {other} has no Polars equivalent"
            )));
        }
    };
    Ok(mapped)
}
100
/// Translate a Polars [`DataType`] into its ferray [`DType`] counterpart.
///
/// Only `Boolean`, the 8- to 64-bit signed and unsigned integers, and the
/// 32- and 64-bit floats are recognised.
///
/// # Errors
///
/// Every other Polars type (for instance `String`) produces the error built
/// by `FerrayError::invalid_dtype`; its message contains the `Debug`
/// rendering of the rejected type.
#[cfg(feature = "polars")]
pub fn polars_to_dtype(pd: &polars::prelude::DataType) -> Result<DType, FerrayError> {
    use polars::prelude::DataType as PolarsType;

    Ok(match pd {
        PolarsType::Boolean => DType::Bool,
        PolarsType::UInt8 => DType::U8,
        PolarsType::UInt16 => DType::U16,
        PolarsType::UInt32 => DType::U32,
        PolarsType::UInt64 => DType::U64,
        PolarsType::Int8 => DType::I8,
        PolarsType::Int16 => DType::I16,
        PolarsType::Int32 => DType::I32,
        PolarsType::Int64 => DType::I64,
        PolarsType::Float32 => DType::F32,
        PolarsType::Float64 => DType::F64,
        // Bail out early for every Polars type outside the table.
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "Polars DataType {other:?} has no ferray equivalent"
            )));
        }
    })
}
127
#[cfg(test)]
mod tests {
    /// Checks for the Arrow <-> ferray dtype table.
    #[cfg(feature = "arrow")]
    mod arrow_tests {
        use crate::dtype_map::{arrow_to_dtype, dtype_to_arrow};
        use arrow::datatypes::DataType as AD;
        use ferray_core::DType;

        /// Each supported pair must convert forward to the expected Arrow
        /// type and convert back to the dtype it started from.
        #[test]
        fn roundtrip_all_supported_dtypes() {
            let pairs = [
                (DType::Bool, AD::Boolean),
                (DType::U8, AD::UInt8),
                (DType::U16, AD::UInt16),
                (DType::U32, AD::UInt32),
                (DType::U64, AD::UInt64),
                (DType::I8, AD::Int8),
                (DType::I16, AD::Int16),
                (DType::I32, AD::Int32),
                (DType::I64, AD::Int64),
                (DType::F32, AD::Float32),
                (DType::F64, AD::Float64),
            ];

            for (ferray_dt, expected_arrow) in pairs {
                let forward = dtype_to_arrow(ferray_dt).unwrap();
                assert_eq!(forward, expected_arrow);

                let restored = arrow_to_dtype(&forward).unwrap();
                assert_eq!(restored, ferray_dt);
            }
        }

        /// Complex dtypes must be rejected on the way to Arrow.
        #[test]
        fn complex_has_no_arrow_equiv() {
            for complex in [DType::Complex32, DType::Complex64] {
                assert!(dtype_to_arrow(complex).is_err());
            }
        }

        /// An Arrow type outside the table must be rejected.
        #[test]
        fn unsupported_arrow_type() {
            let outcome = arrow_to_dtype(&AD::Utf8);
            assert!(outcome.is_err());
        }
    }

    /// Checks for the Polars <-> ferray dtype table.
    #[cfg(feature = "polars")]
    mod polars_tests {
        use crate::dtype_map::{dtype_to_polars, polars_to_dtype};
        use ferray_core::DType;
        use polars::prelude::DataType as PD;

        /// Each supported pair must convert forward to the expected Polars
        /// type and convert back to the dtype it started from.
        #[test]
        fn roundtrip_all_supported_dtypes() {
            let pairs = [
                (DType::Bool, PD::Boolean),
                (DType::U8, PD::UInt8),
                (DType::U16, PD::UInt16),
                (DType::U32, PD::UInt32),
                (DType::U64, PD::UInt64),
                (DType::I8, PD::Int8),
                (DType::I16, PD::Int16),
                (DType::I32, PD::Int32),
                (DType::I64, PD::Int64),
                (DType::F32, PD::Float32),
                (DType::F64, PD::Float64),
            ];

            for (ferray_dt, expected_polars) in pairs {
                let forward = dtype_to_polars(ferray_dt).unwrap();
                assert_eq!(forward, expected_polars);

                let restored = polars_to_dtype(&forward).unwrap();
                assert_eq!(restored, ferray_dt);
            }
        }

        /// Complex dtypes must be rejected on the way to Polars.
        #[test]
        fn complex_has_no_polars_equiv() {
            for complex in [DType::Complex32, DType::Complex64] {
                assert!(dtype_to_polars(complex).is_err());
            }
        }

        /// A Polars type outside the table must be rejected.
        #[test]
        fn unsupported_polars_type() {
            let outcome = polars_to_dtype(&PD::String);
            assert!(outcome.is_err());
        }
    }
}