Skip to main content

ferray_numpy_interop/
dtype_map.rs

//! Mapping between ferray `DType`, Arrow `DataType`, and Polars `DataType`.
//!
//! This module provides bidirectional conversion functions between ferray
//! dtypes and the Arrow and Polars type systems, so that those interop paths
//! share a single source of truth for type correspondence. Each half is only
//! compiled when its crate feature (`arrow` / `polars`) is enabled.
6
7#[cfg(any(feature = "arrow", feature = "polars"))]
8use ferray_core::DType;
9#[cfg(any(feature = "arrow", feature = "polars"))]
10use ferray_core::FerrayError;
11
12// ---------------------------------------------------------------------------
13// Arrow DataType <-> DType
14// ---------------------------------------------------------------------------
15
/// Convert a ferray [`DType`] to the corresponding Arrow
/// [`DataType`](arrow::datatypes::DataType).
///
/// Temporal dtypes are mapped as follows:
///
/// * `DateTime64(unit)` → `Timestamp(unit, None)`. ferray, like `NumPy`
///   `datetime64`, is timezone-naive, so no timezone is attached.
/// * `Timedelta64(unit)` → `Duration(unit)`.
///
/// # Errors
///
/// Returns [`FerrayError::InvalidDtype`] if the ferray dtype has no Arrow
/// equivalent (e.g. `Complex32`, `Complex64`, `U128`, `I128`, `BF16`), or if a
/// `DateTime64` / `Timedelta64` uses a time unit other than `S`, `Ms`, `Us`
/// or `Ns` (e.g. minute, hour, day), which Arrow's `TimeUnit` cannot express.
#[cfg(feature = "arrow")]
pub fn dtype_to_arrow(dt: DType) -> Result<arrow::datatypes::DataType, FerrayError> {
    use arrow::datatypes::{DataType as AD, TimeUnit as ATU};
    use ferray_core::dtype::TimeUnit;

    /// Map a ferray time unit onto Arrow's four supported units.
    fn to_arrow_time_unit(u: TimeUnit) -> Result<ATU, FerrayError> {
        match u {
            TimeUnit::Ns => Ok(ATU::Nanosecond),
            TimeUnit::Us => Ok(ATU::Microsecond),
            TimeUnit::Ms => Ok(ATU::Millisecond),
            TimeUnit::S => Ok(ATU::Second),
            // Arrow's TimeUnit only has Ns/Us/Ms/S — any other datetime64
            // unit (minute/hour/day, ...) has no direct Arrow correspondence.
            other => Err(FerrayError::invalid_dtype(format!(
                "Arrow has no time unit equivalent for ferray TimeUnit::{other:?}"
            ))),
        }
    }

    match dt {
        DType::Bool => Ok(AD::Boolean),
        DType::U8 => Ok(AD::UInt8),
        DType::U16 => Ok(AD::UInt16),
        DType::U32 => Ok(AD::UInt32),
        DType::U64 => Ok(AD::UInt64),
        DType::I8 => Ok(AD::Int8),
        DType::I16 => Ok(AD::Int16),
        DType::I32 => Ok(AD::Int32),
        DType::I64 => Ok(AD::Int64),
        DType::F32 => Ok(AD::Float32),
        DType::F64 => Ok(AD::Float64),
        // Arrow's Float16 is IEEE 754 binary16, the same encoding as
        // ferray's `DType::F16`.
        #[cfg(feature = "f16")]
        DType::F16 => Ok(AD::Float16),
        // bfloat16 has no Arrow primitive. Widening to f32 is lossless
        // (bf16 is a truncated f32), so point the caller there.
        #[cfg(feature = "bf16")]
        DType::BF16 => Err(FerrayError::invalid_dtype(
            "Arrow has no native bfloat16 type — widen to f32 (lossless) before exporting",
        )),
        // datetime64 → Arrow Timestamp(unit, None). The Option<Tz> is None
        // since ferray doesn't track timezones (NumPy datetime64 is also
        // TZ-naive).
        DType::DateTime64(u) => Ok(AD::Timestamp(to_arrow_time_unit(u)?, None)),
        // timedelta64 → Arrow Duration(unit).
        DType::Timedelta64(u) => Ok(AD::Duration(to_arrow_time_unit(u)?)),
        other => Err(FerrayError::invalid_dtype(format!(
            "ferray dtype {other} has no Arrow equivalent"
        ))),
    }
}
74
/// Convert an Arrow [`DataType`](arrow::datatypes::DataType) to the
/// corresponding ferray [`DType`].
///
/// `Timestamp(unit, tz)` maps to `DateTime64(unit)` and `Duration(unit)` maps
/// to `Timedelta64(unit)`. Any timezone attached to a `Timestamp` is
/// discarded, so a TZ-tagged timestamp does not survive a round trip through
/// ferray.
///
/// # Errors
///
/// Returns [`FerrayError::InvalidDtype`] for Arrow types that ferray does
/// not support (e.g. `Utf8`, `Struct`, `List`, `Date32`, etc.). `Float16` is
/// only accepted when the `f16` feature is enabled; otherwise it is rejected
/// with the same error.
#[cfg(feature = "arrow")]
pub fn arrow_to_dtype(ad: &arrow::datatypes::DataType) -> Result<DType, FerrayError> {
    use arrow::datatypes::{DataType as AD, TimeUnit as ATU};
    use ferray_core::dtype::TimeUnit;

    /// Every Arrow time unit has a ferray counterpart, so this is infallible.
    fn from_arrow_time_unit(u: &ATU) -> TimeUnit {
        match u {
            ATU::Nanosecond => TimeUnit::Ns,
            ATU::Microsecond => TimeUnit::Us,
            ATU::Millisecond => TimeUnit::Ms,
            ATU::Second => TimeUnit::S,
        }
    }

    match ad {
        AD::Boolean => Ok(DType::Bool),
        AD::UInt8 => Ok(DType::U8),
        AD::UInt16 => Ok(DType::U16),
        AD::UInt32 => Ok(DType::U32),
        AD::UInt64 => Ok(DType::U64),
        AD::Int8 => Ok(DType::I8),
        AD::Int16 => Ok(DType::I16),
        AD::Int32 => Ok(DType::I32),
        AD::Int64 => Ok(DType::I64),
        AD::Float32 => Ok(DType::F32),
        AD::Float64 => Ok(DType::F64),
        #[cfg(feature = "f16")]
        AD::Float16 => Ok(DType::F16),
        // Arrow Timestamp -> datetime64. Timezone-tagged timestamps are
        // mapped to the same TZ-naive ferray dtype (the TZ is dropped).
        AD::Timestamp(u, _tz) => Ok(DType::DateTime64(from_arrow_time_unit(u))),
        AD::Duration(u) => Ok(DType::Timedelta64(from_arrow_time_unit(u))),
        other => Err(FerrayError::invalid_dtype(format!(
            "Arrow DataType {other:?} has no ferray equivalent"
        ))),
    }
}
119
120// ---------------------------------------------------------------------------
121// Polars DataType <-> DType
122// ---------------------------------------------------------------------------
123
/// Convert a ferray [`DType`] to the corresponding Polars
/// [`DataType`](polars::prelude::DataType).
///
/// Only boolean, integer (8–64 bit) and `F32`/`F64` dtypes are mapped.
///
/// # Errors
///
/// Returns [`FerrayError::InvalidDtype`] if the ferray dtype has no Polars
/// mapping here (e.g. `Complex32`, `Complex64`, `U128`, `I128`, `F16`, and the
/// `DateTime64` / `Timedelta64` dtypes, which are not mapped yet).
#[cfg(feature = "polars")]
pub fn dtype_to_polars(dt: DType) -> Result<polars::prelude::DataType, FerrayError> {
    use polars::prelude::DataType as PD;
    match dt {
        DType::Bool => Ok(PD::Boolean),
        DType::U8 => Ok(PD::UInt8),
        DType::U16 => Ok(PD::UInt16),
        DType::U32 => Ok(PD::UInt32),
        DType::U64 => Ok(PD::UInt64),
        DType::I8 => Ok(PD::Int8),
        DType::I16 => Ok(PD::Int16),
        DType::I32 => Ok(PD::Int32),
        DType::I64 => Ok(PD::Int64),
        DType::F32 => Ok(PD::Float32),
        DType::F64 => Ok(PD::Float64),
        other => Err(FerrayError::invalid_dtype(format!(
            "ferray dtype {other} has no Polars equivalent"
        ))),
    }
}
150
/// Translate a Polars [`DataType`](polars::prelude::DataType) into the
/// matching ferray [`DType`].
///
/// This is the inverse of `dtype_to_polars`: only `Boolean`, the 8–64 bit
/// integer types, `Float32` and `Float64` are recognised.
///
/// # Errors
///
/// Any other Polars type (e.g. `String`, `Date`, `Datetime`, ...) yields
/// [`FerrayError::InvalidDtype`].
#[cfg(feature = "polars")]
pub fn polars_to_dtype(pd: &polars::prelude::DataType) -> Result<DType, FerrayError> {
    use polars::prelude::DataType;

    let mapped = match pd {
        DataType::Boolean => DType::Bool,
        DataType::UInt8 => DType::U8,
        DataType::UInt16 => DType::U16,
        DataType::UInt32 => DType::U32,
        DataType::UInt64 => DType::U64,
        DataType::Int8 => DType::I8,
        DataType::Int16 => DType::I16,
        DataType::Int32 => DType::I32,
        DataType::Int64 => DType::I64,
        DataType::Float32 => DType::F32,
        DataType::Float64 => DType::F64,
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "Polars DataType {other:?} has no ferray equivalent"
            )));
        }
    };
    Ok(mapped)
}
177
#[cfg(test)]
mod tests {
    #[cfg(feature = "arrow")]
    mod arrow_tests {
        use crate::dtype_map::{arrow_to_dtype, dtype_to_arrow};
        use arrow::datatypes::DataType as AD;
        use ferray_core::DType;

        #[test]
        fn roundtrip_all_supported_dtypes() {
            let dtypes = [
                (DType::Bool, AD::Boolean),
                (DType::U8, AD::UInt8),
                (DType::U16, AD::UInt16),
                (DType::U32, AD::UInt32),
                (DType::U64, AD::UInt64),
                (DType::I8, AD::Int8),
                (DType::I16, AD::Int16),
                (DType::I32, AD::Int32),
                (DType::I64, AD::Int64),
                (DType::F32, AD::Float32),
                (DType::F64, AD::Float64),
            ];

            for (ferray_dt, arrow_dt) in &dtypes {
                let converted = dtype_to_arrow(*ferray_dt).unwrap();
                assert_eq!(&converted, arrow_dt);
                let back = arrow_to_dtype(&converted).unwrap();
                assert_eq!(back, *ferray_dt);
            }
        }

        #[cfg(feature = "f16")]
        #[test]
        fn f16_roundtrip() {
            assert_eq!(dtype_to_arrow(DType::F16).unwrap(), AD::Float16);
            assert_eq!(arrow_to_dtype(&AD::Float16).unwrap(), DType::F16);
        }

        #[cfg(feature = "bf16")]
        #[test]
        fn bf16_has_no_arrow_equiv() {
            assert!(dtype_to_arrow(DType::BF16).is_err());
        }

        #[test]
        fn complex_has_no_arrow_equiv() {
            assert!(dtype_to_arrow(DType::Complex32).is_err());
            assert!(dtype_to_arrow(DType::Complex64).is_err());
        }

        #[test]
        fn unsupported_arrow_type() {
            assert!(arrow_to_dtype(&AD::Utf8).is_err());
        }

        #[test]
        fn datetime64_to_arrow_timestamp() {
            use arrow::datatypes::TimeUnit as ATU;
            use ferray_core::dtype::TimeUnit;
            assert_eq!(
                dtype_to_arrow(DType::DateTime64(TimeUnit::Ns)).unwrap(),
                AD::Timestamp(ATU::Nanosecond, None)
            );
            assert_eq!(
                dtype_to_arrow(DType::DateTime64(TimeUnit::Ms)).unwrap(),
                AD::Timestamp(ATU::Millisecond, None)
            );
        }

        #[test]
        fn timedelta64_to_arrow_duration() {
            use arrow::datatypes::TimeUnit as ATU;
            use ferray_core::dtype::TimeUnit;
            assert_eq!(
                dtype_to_arrow(DType::Timedelta64(TimeUnit::Us)).unwrap(),
                AD::Duration(ATU::Microsecond)
            );
        }

        #[test]
        fn arrow_timestamp_to_datetime64() {
            use arrow::datatypes::TimeUnit as ATU;
            use ferray_core::dtype::TimeUnit;
            let arrow_dt = AD::Timestamp(ATU::Nanosecond, None);
            assert_eq!(
                arrow_to_dtype(&arrow_dt).unwrap(),
                DType::DateTime64(TimeUnit::Ns)
            );
            // Timezone is dropped on the way back.
            let arrow_tz = AD::Timestamp(ATU::Microsecond, Some("UTC".into()));
            assert_eq!(
                arrow_to_dtype(&arrow_tz).unwrap(),
                DType::DateTime64(TimeUnit::Us)
            );
        }

        #[test]
        fn arrow_duration_to_timedelta64() {
            use arrow::datatypes::TimeUnit as ATU;
            use ferray_core::dtype::TimeUnit;
            assert_eq!(
                arrow_to_dtype(&AD::Duration(ATU::Second)).unwrap(),
                DType::Timedelta64(TimeUnit::S)
            );
        }

        #[test]
        fn temporal_roundtrip_all_arrow_units() {
            use ferray_core::dtype::TimeUnit;
            // Every unit Arrow can represent must survive ferray -> Arrow -> ferray.
            for unit in [TimeUnit::Ns, TimeUnit::Us, TimeUnit::Ms, TimeUnit::S] {
                for dt in [DType::DateTime64(unit), DType::Timedelta64(unit)] {
                    let arrow_dt = dtype_to_arrow(dt).unwrap();
                    assert_eq!(arrow_to_dtype(&arrow_dt).unwrap(), dt);
                }
            }
        }

        #[test]
        fn datetime64_minute_unit_arrow_unsupported() {
            use ferray_core::dtype::TimeUnit;
            // Arrow has no minute/hour/day TimeUnit — surface InvalidDtype.
            assert!(dtype_to_arrow(DType::DateTime64(TimeUnit::M)).is_err());
            assert!(dtype_to_arrow(DType::DateTime64(TimeUnit::H)).is_err());
            assert!(dtype_to_arrow(DType::DateTime64(TimeUnit::D)).is_err());
        }

        #[test]
        fn timedelta64_minute_unit_arrow_unsupported() {
            use ferray_core::dtype::TimeUnit;
            assert!(dtype_to_arrow(DType::Timedelta64(TimeUnit::M)).is_err());
            assert!(dtype_to_arrow(DType::Timedelta64(TimeUnit::D)).is_err());
        }
    }

    #[cfg(feature = "polars")]
    mod polars_tests {
        use crate::dtype_map::{dtype_to_polars, polars_to_dtype};
        use ferray_core::DType;
        use polars::prelude::DataType as PD;

        #[test]
        fn roundtrip_all_supported_dtypes() {
            let dtypes = [
                (DType::Bool, PD::Boolean),
                (DType::U8, PD::UInt8),
                (DType::U16, PD::UInt16),
                (DType::U32, PD::UInt32),
                (DType::U64, PD::UInt64),
                (DType::I8, PD::Int8),
                (DType::I16, PD::Int16),
                (DType::I32, PD::Int32),
                (DType::I64, PD::Int64),
                (DType::F32, PD::Float32),
                (DType::F64, PD::Float64),
            ];

            for (ferray_dt, polars_dt) in &dtypes {
                let converted = dtype_to_polars(*ferray_dt).unwrap();
                assert_eq!(&converted, polars_dt);
                let back = polars_to_dtype(&converted).unwrap();
                assert_eq!(back, *ferray_dt);
            }
        }

        #[test]
        fn complex_has_no_polars_equiv() {
            assert!(dtype_to_polars(DType::Complex32).is_err());
            assert!(dtype_to_polars(DType::Complex64).is_err());
        }

        #[test]
        fn unsupported_polars_type() {
            assert!(polars_to_dtype(&PD::String).is_err());
        }
    }
}