1#[cfg(any(feature = "arrow", feature = "polars"))]
8use ferray_core::DType;
9#[cfg(any(feature = "arrow", feature = "polars"))]
10use ferray_core::FerrayError;
11
#[cfg(feature = "arrow")]
/// Maps a ferray [`DType`] to the corresponding Arrow `DataType`.
///
/// Boolean, integer and floating-point dtypes map to the Arrow type of the
/// same width. `DateTime64` becomes an Arrow `Timestamp` with no timezone and
/// `Timedelta64` becomes an Arrow `Duration`; for both, only the second,
/// millisecond, microsecond and nanosecond units are accepted.
///
/// # Errors
///
/// Returns an invalid-dtype [`FerrayError`] when `dt` has no Arrow
/// counterpart: any dtype not listed above (the in-file tests cover the
/// complex dtypes), `BF16` when the `bf16` feature is enabled, and temporal
/// dtypes whose unit is not one of the four accepted ones.
pub fn dtype_to_arrow(dt: DType) -> Result<arrow::datatypes::DataType, FerrayError> {
    use arrow::datatypes::{DataType as AD, TimeUnit as ATU};
    use ferray_core::dtype::TimeUnit;

    // Translate a ferray time unit; units other than s/ms/us/ns are rejected
    // instead of being rescaled to a different resolution.
    fn to_arrow_time_unit(u: TimeUnit) -> Result<ATU, FerrayError> {
        match u {
            TimeUnit::Ns => Ok(ATU::Nanosecond),
            TimeUnit::Us => Ok(ATU::Microsecond),
            TimeUnit::Ms => Ok(ATU::Millisecond),
            TimeUnit::S => Ok(ATU::Second),
            other => Err(FerrayError::invalid_dtype(format!(
                "Arrow has no time unit equivalent for ferray TimeUnit::{other:?}"
            ))),
        }
    }

    match dt {
        DType::Bool => Ok(AD::Boolean),
        DType::U8 => Ok(AD::UInt8),
        DType::U16 => Ok(AD::UInt16),
        DType::U32 => Ok(AD::UInt32),
        DType::U64 => Ok(AD::UInt64),
        DType::I8 => Ok(AD::Int8),
        DType::I16 => Ok(AD::Int16),
        DType::I32 => Ok(AD::Int32),
        DType::I64 => Ok(AD::Int64),
        DType::F32 => Ok(AD::Float32),
        DType::F64 => Ok(AD::Float64),
        #[cfg(feature = "f16")]
        DType::F16 => Ok(AD::Float16),
        // The hint previously suggested a struct(real, imag) workaround, which
        // only applies to complex numbers; for bfloat16 the usable advice is
        // to widen to f32.
        #[cfg(feature = "bf16")]
        DType::BF16 => Err(FerrayError::invalid_dtype(
            "Arrow has no native bfloat16 type — widen to f32 before converting",
        )),
        // `None` leaves the Arrow timestamp without a timezone.
        DType::DateTime64(u) => Ok(AD::Timestamp(to_arrow_time_unit(u)?, None)),
        DType::Timedelta64(u) => Ok(AD::Duration(to_arrow_time_unit(u)?)),
        other => Err(FerrayError::invalid_dtype(format!(
            "ferray dtype {other} has no Arrow equivalent"
        ))),
    }
}
74
#[cfg(feature = "arrow")]
/// Maps an Arrow `DataType` back to the matching ferray [`DType`].
///
/// Boolean, integer and floating-point types map one-to-one. A `Timestamp`
/// becomes `DateTime64` and a `Duration` becomes `Timedelta64`, both carrying
/// the translated time unit. The timezone of a `Timestamp` is ignored, so a
/// timezone-aware and a timezone-less timestamp yield the same dtype.
///
/// # Errors
///
/// Returns an invalid-dtype [`FerrayError`] for every Arrow type not listed
/// above (including `Float16` when the `f16` feature is disabled).
pub fn arrow_to_dtype(ad: &arrow::datatypes::DataType) -> Result<DType, FerrayError> {
    use arrow::datatypes::{DataType as AD, TimeUnit as ATU};
    use ferray_core::dtype::TimeUnit;

    // Each of the four Arrow time units handled here has a direct ferray
    // counterpart, so this translation cannot fail.
    let ferray_unit = |unit: &ATU| match unit {
        ATU::Second => TimeUnit::S,
        ATU::Millisecond => TimeUnit::Ms,
        ATU::Microsecond => TimeUnit::Us,
        ATU::Nanosecond => TimeUnit::Ns,
    };

    Ok(match ad {
        AD::Boolean => DType::Bool,
        AD::UInt8 => DType::U8,
        AD::UInt16 => DType::U16,
        AD::UInt32 => DType::U32,
        AD::UInt64 => DType::U64,
        AD::Int8 => DType::I8,
        AD::Int16 => DType::I16,
        AD::Int32 => DType::I32,
        AD::Int64 => DType::I64,
        AD::Float32 => DType::F32,
        AD::Float64 => DType::F64,
        #[cfg(feature = "f16")]
        AD::Float16 => DType::F16,
        // The timezone component is intentionally discarded.
        AD::Timestamp(unit, _tz) => DType::DateTime64(ferray_unit(unit)),
        AD::Duration(unit) => DType::Timedelta64(ferray_unit(unit)),
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "Arrow DataType {other:?} has no ferray equivalent"
            )))
        }
    })
}
119
#[cfg(feature = "polars")]
/// Maps a ferray [`DType`] to the corresponding Polars `DataType`.
///
/// Only `Bool`, the 8/16/32/64-bit signed and unsigned integers, and
/// `F32`/`F64` are mapped.
///
/// # Errors
///
/// Returns an invalid-dtype [`FerrayError`] for every other dtype.
pub fn dtype_to_polars(dt: DType) -> Result<polars::prelude::DataType, FerrayError> {
    use polars::prelude::DataType as PD;

    Ok(match dt {
        DType::Bool => PD::Boolean,
        DType::U8 => PD::UInt8,
        DType::U16 => PD::UInt16,
        DType::U32 => PD::UInt32,
        DType::U64 => PD::UInt64,
        DType::I8 => PD::Int8,
        DType::I16 => PD::Int16,
        DType::I32 => PD::Int32,
        DType::I64 => PD::Int64,
        DType::F32 => PD::Float32,
        DType::F64 => PD::Float64,
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "ferray dtype {other} has no Polars equivalent"
            )))
        }
    })
}
150
#[cfg(feature = "polars")]
/// Maps a Polars `DataType` back to the matching ferray [`DType`].
///
/// Only `Boolean`, the 8/16/32/64-bit signed and unsigned integers, and
/// `Float32`/`Float64` are mapped.
///
/// # Errors
///
/// Returns an invalid-dtype [`FerrayError`] for every other Polars type.
pub fn polars_to_dtype(pd: &polars::prelude::DataType) -> Result<DType, FerrayError> {
    use polars::prelude::DataType as PD;

    let mapped = match pd {
        PD::Boolean => DType::Bool,
        PD::UInt8 => DType::U8,
        PD::UInt16 => DType::U16,
        PD::UInt32 => DType::U32,
        PD::UInt64 => DType::U64,
        PD::Int8 => DType::I8,
        PD::Int16 => DType::I16,
        PD::Int32 => DType::I32,
        PD::Int64 => DType::I64,
        PD::Float32 => DType::F32,
        PD::Float64 => DType::F64,
        other => {
            return Err(FerrayError::invalid_dtype(format!(
                "Polars DataType {other:?} has no ferray equivalent"
            )))
        }
    };
    Ok(mapped)
}
177
#[cfg(test)]
mod tests {
    /// Checks for the ferray <-> Arrow dtype mapping.
    #[cfg(feature = "arrow")]
    mod arrow_tests {
        use crate::dtype_map::{arrow_to_dtype, dtype_to_arrow};
        use arrow::datatypes::{DataType as AD, TimeUnit as ATU};
        use ferray_core::dtype::TimeUnit;
        use ferray_core::DType;

        /// Each boolean/integer/float dtype converts to Arrow and back unchanged.
        #[test]
        fn roundtrip_all_supported_dtypes() {
            let pairs = [
                (DType::Bool, AD::Boolean),
                (DType::U8, AD::UInt8),
                (DType::U16, AD::UInt16),
                (DType::U32, AD::UInt32),
                (DType::U64, AD::UInt64),
                (DType::I8, AD::Int8),
                (DType::I16, AD::Int16),
                (DType::I32, AD::Int32),
                (DType::I64, AD::Int64),
                (DType::F32, AD::Float32),
                (DType::F64, AD::Float64),
            ];

            for (ferray_dt, arrow_dt) in pairs.iter() {
                let forward = dtype_to_arrow(*ferray_dt).unwrap();
                assert_eq!(&forward, arrow_dt);
                let restored = arrow_to_dtype(&forward).unwrap();
                assert_eq!(restored, *ferray_dt);
            }
        }

        /// Complex dtypes are rejected by the Arrow mapping.
        #[test]
        fn complex_has_no_arrow_equiv() {
            for dt in [DType::Complex32, DType::Complex64] {
                assert!(dtype_to_arrow(dt).is_err());
            }
        }

        /// An Arrow type outside the mapping table is rejected.
        #[test]
        fn unsupported_arrow_type() {
            let outcome = arrow_to_dtype(&AD::Utf8);
            assert!(outcome.is_err());
        }

        /// `DateTime64` maps to a timezone-less Arrow `Timestamp` of the same unit.
        #[test]
        fn datetime64_to_arrow_timestamp() {
            let nanos = dtype_to_arrow(DType::DateTime64(TimeUnit::Ns)).unwrap();
            assert_eq!(nanos, AD::Timestamp(ATU::Nanosecond, None));

            let millis = dtype_to_arrow(DType::DateTime64(TimeUnit::Ms)).unwrap();
            assert_eq!(millis, AD::Timestamp(ATU::Millisecond, None));
        }

        /// `Timedelta64` maps to an Arrow `Duration` of the same unit.
        #[test]
        fn timedelta64_to_arrow_duration() {
            let micros = dtype_to_arrow(DType::Timedelta64(TimeUnit::Us)).unwrap();
            assert_eq!(micros, AD::Duration(ATU::Microsecond));
        }

        /// Arrow timestamps map to `DateTime64`, with or without a timezone.
        #[test]
        fn arrow_timestamp_to_datetime64() {
            let naive = AD::Timestamp(ATU::Nanosecond, None);
            assert_eq!(
                arrow_to_dtype(&naive).unwrap(),
                DType::DateTime64(TimeUnit::Ns)
            );

            // The timezone is dropped by the conversion.
            let zoned = AD::Timestamp(ATU::Microsecond, Some("UTC".into()));
            assert_eq!(
                arrow_to_dtype(&zoned).unwrap(),
                DType::DateTime64(TimeUnit::Us)
            );
        }

        /// Minute, hour and day units are rejected by the Arrow mapping.
        #[test]
        fn datetime64_minute_unit_arrow_unsupported() {
            for unit in [TimeUnit::M, TimeUnit::H, TimeUnit::D] {
                assert!(dtype_to_arrow(DType::DateTime64(unit)).is_err());
            }
        }
    }

    /// Checks for the ferray <-> Polars dtype mapping.
    #[cfg(feature = "polars")]
    mod polars_tests {
        use crate::dtype_map::{dtype_to_polars, polars_to_dtype};
        use ferray_core::DType;
        use polars::prelude::DataType as PD;

        /// Each boolean/integer/float dtype converts to Polars and back unchanged.
        #[test]
        fn roundtrip_all_supported_dtypes() {
            let pairs = [
                (DType::Bool, PD::Boolean),
                (DType::U8, PD::UInt8),
                (DType::U16, PD::UInt16),
                (DType::U32, PD::UInt32),
                (DType::U64, PD::UInt64),
                (DType::I8, PD::Int8),
                (DType::I16, PD::Int16),
                (DType::I32, PD::Int32),
                (DType::I64, PD::Int64),
                (DType::F32, PD::Float32),
                (DType::F64, PD::Float64),
            ];

            for (ferray_dt, polars_dt) in pairs.iter() {
                let forward = dtype_to_polars(*ferray_dt).unwrap();
                assert_eq!(&forward, polars_dt);
                let restored = polars_to_dtype(&forward).unwrap();
                assert_eq!(restored, *ferray_dt);
            }
        }

        /// Complex dtypes are rejected by the Polars mapping.
        #[test]
        fn complex_has_no_polars_equiv() {
            for dt in [DType::Complex32, DType::Complex64] {
                assert!(dtype_to_polars(dt).is_err());
            }
        }

        /// A Polars type outside the mapping table is rejected.
        #[test]
        fn unsupported_polars_type() {
            let outcome = polars_to_dtype(&PD::String);
            assert!(outcome.is_err());
        }
    }
}