Skip to main content

hdf5_reader/
datatype_api.rs

1use crate::error::{ByteOrder, Error, Result};
2use crate::messages::datatype::Datatype;
3
4// Re-export types from the datatype message module so users don't need to
5// reach into messages::datatype.
6pub use crate::messages::datatype::{
7    CompoundField, EnumMember, ReferenceType, StringEncoding, StringPadding, StringSize, VarLenKind,
8};
9
10/// Trait for types that can be read from HDF5 datasets.
11///
12/// Implemented for primitive numeric types. Users can implement this
13/// for custom types (e.g., compound types).
14pub trait H5Type: Sized + Send + Clone {
15    /// The HDF5 datatype that this Rust type corresponds to.
16    fn hdf5_type() -> Datatype;
17
18    /// Decode a single value from raw bytes with the given datatype.
19    fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self>;
20
21    /// Size of a single element in bytes.
22    fn element_size(dtype: &Datatype) -> usize;
23
24    /// Decode many values at once when the datatype has an efficient bulk path.
25    ///
26    /// Returning `None` falls back to per-element decoding.
27    fn decode_vec(_raw: &[u8], _dtype: &Datatype, _count: usize) -> Option<Result<Vec<Self>>> {
28        None
29    }
30
31    /// Whether raw bytes for this datatype can be copied directly into a `Vec<Self>`
32    /// without any further decoding or byte swapping.
33    fn native_copy_compatible(_dtype: &Datatype) -> bool {
34        false
35    }
36}
37
38/// Read a numeric value from bytes, handling byte-order conversion.
39fn read_numeric<const N: usize>(bytes: &[u8], byte_order: ByteOrder) -> Result<[u8; N]> {
40    if bytes.len() < N {
41        return Err(Error::InvalidData(format!(
42            "expected {} bytes, got {}",
43            N,
44            bytes.len()
45        )));
46    }
47    let mut arr = [0u8; N];
48    arr.copy_from_slice(&bytes[..N]);
49
50    // Swap bytes if the source endianness doesn't match native
51    #[cfg(target_endian = "little")]
52    if byte_order == ByteOrder::BigEndian {
53        arr.reverse();
54    }
55    #[cfg(target_endian = "big")]
56    if byte_order == ByteOrder::LittleEndian {
57        arr.reverse();
58    }
59
60    Ok(arr)
61}
62
63fn byte_order_is_native(byte_order: ByteOrder) -> bool {
64    #[cfg(target_endian = "little")]
65    {
66        byte_order == ByteOrder::LittleEndian
67    }
68    #[cfg(target_endian = "big")]
69    {
70        byte_order == ByteOrder::BigEndian
71    }
72}
73
// Implements `H5Type` for a primitive integer type.
//
// * `$ty`     - the Rust integer type.
// * `$size`   - its width in bytes. This MUST equal `size_of::<$ty>()`.
//               A mismatch is a compile error, because `<$ty>::from_ne_bytes`
//               only accepts a `[u8; size_of::<$ty>()]` and is called with a
//               `[u8; $size]` below.
// * `$signed` - whether the matching HDF5 fixed-point type is signed.
//
// The unsafe bulk copy in `decode_vec` relies on the `$size` invariant above.
macro_rules! impl_h5type_int {
    ($ty:ty, $size:expr, $signed:expr) => {
        impl H5Type for $ty {
            // Native-endian fixed-point datatype with this width and signedness.
            fn hdf5_type() -> Datatype {
                Datatype::FixedPoint {
                    size: $size,
                    signed: $signed,
                    byte_order: if cfg!(target_endian = "little") {
                        ByteOrder::LittleEndian
                    } else {
                        ByteOrder::BigEndian
                    },
                }
            }

            // Decodes one integer.
            // Errors:
            // - `TypeMismatch` for any datatype that is not a fixed-point type
            //   of exactly this size and signedness.
            // - `InvalidData` (from `read_numeric`) when `bytes` is too short.
            fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self> {
                match dtype {
                    Datatype::FixedPoint {
                        size,
                        signed,
                        byte_order,
                    } => {
                        if *size as usize != std::mem::size_of::<$ty>() || *signed != $signed {
                            return Err(Error::TypeMismatch {
                                expected: stringify!($ty).into(),
                                actual: format!("FixedPoint(size={}, signed={})", size, signed),
                            });
                        }
                        let arr = read_numeric::<$size>(bytes, *byte_order)?;
                        Ok(<$ty>::from_ne_bytes(arr))
                    }
                    _ => Err(Error::TypeMismatch {
                        expected: stringify!($ty).into(),
                        actual: format!("{:?}", dtype),
                    }),
                }
            }

            // Always `$size`; the datatype is not consulted.
            fn element_size(_dtype: &Datatype) -> usize {
                $size
            }

            // Bulk-decodes `count` integers. Returns `None` (no fast path)
            // when the datatype does not match, when `count * $size`
            // overflows, or when `raw` is too short.
            fn decode_vec(raw: &[u8], dtype: &Datatype, count: usize) -> Option<Result<Vec<Self>>> {
                match dtype {
                    Datatype::FixedPoint {
                        size,
                        signed,
                        byte_order,
                    } if *size as usize == $size && *signed == $signed => {
                        let total_bytes = count.checked_mul($size)?;
                        let bytes = raw.get(..total_bytes)?;

                        if byte_order_is_native(*byte_order) {
                            let mut values = Vec::<$ty>::with_capacity(count);
                            // SAFETY:
                            // - Capacity: `values` can hold `count` elements,
                            //   which is `count * size_of::<$ty>()` bytes. That
                            //   equals `total_bytes`, because `$size ==
                            //   size_of::<$ty>()` is enforced at compile time
                            //   (see the macro header).
                            // - Source: `bytes` is exactly `total_bytes` long
                            //   and cannot overlap the freshly allocated buffer.
                            // - Alignment: the copy is byte-wise (`u8` pointers),
                            //   so the alignment of `raw` is irrelevant.
                            // - Initialization: every bit pattern is a valid
                            //   integer, so all `count` elements are
                            //   initialized before `set_len`.
                            unsafe {
                                std::ptr::copy_nonoverlapping(
                                    bytes.as_ptr(),
                                    values.as_mut_ptr().cast::<u8>(),
                                    total_bytes,
                                );
                                values.set_len(count);
                            }
                            Some(Ok(values))
                        } else {
                            // Not in host order (assumed to be the opposite
                            // endianness): byte-swap each element.
                            Some(Ok(bytes
                                .chunks_exact($size)
                                .map(|chunk| {
                                    let mut arr = [0u8; $size];
                                    arr.copy_from_slice(chunk);
                                    arr.reverse();
                                    <$ty>::from_ne_bytes(arr)
                                })
                                .collect()))
                        }
                    }
                    _ => None,
                }
            }

            // True only for a fixed-point datatype of this exact size and
            // signedness, stored in host byte order.
            fn native_copy_compatible(dtype: &Datatype) -> bool {
                matches!(
                    dtype,
                    Datatype::FixedPoint {
                        size,
                        signed,
                        byte_order,
                    } if *size as usize == $size
                        && *signed == $signed
                        && byte_order_is_native(*byte_order)
                )
            }
        }
    };
}
171
172impl_h5type_int!(i8, 1, true);
173impl_h5type_int!(u8, 1, false);
174impl_h5type_int!(i16, 2, true);
175impl_h5type_int!(u16, 2, false);
176impl_h5type_int!(i32, 4, true);
177impl_h5type_int!(u32, 4, false);
178impl_h5type_int!(i64, 8, true);
179impl_h5type_int!(u64, 8, false);
180
181impl H5Type for f32 {
182    fn hdf5_type() -> Datatype {
183        Datatype::FloatingPoint {
184            size: 4,
185            byte_order: if cfg!(target_endian = "little") {
186                ByteOrder::LittleEndian
187            } else {
188                ByteOrder::BigEndian
189            },
190        }
191    }
192
193    fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self> {
194        match dtype {
195            Datatype::FloatingPoint { size, byte_order } => {
196                if *size != 4 {
197                    return Err(Error::TypeMismatch {
198                        expected: "f32".into(),
199                        actual: format!("FloatingPoint(size={})", size),
200                    });
201                }
202                let arr = read_numeric::<4>(bytes, *byte_order)?;
203                Ok(f32::from_ne_bytes(arr))
204            }
205            _ => Err(Error::TypeMismatch {
206                expected: "f32".into(),
207                actual: format!("{:?}", dtype),
208            }),
209        }
210    }
211
212    fn element_size(_dtype: &Datatype) -> usize {
213        4
214    }
215
216    fn decode_vec(raw: &[u8], dtype: &Datatype, count: usize) -> Option<Result<Vec<Self>>> {
217        match dtype {
218            Datatype::FloatingPoint { size, byte_order } if *size == 4 => {
219                let total_bytes = count.checked_mul(4)?;
220                if raw.len() < total_bytes {
221                    return None;
222                }
223
224                let bytes = &raw[..total_bytes];
225                if byte_order_is_native(*byte_order) {
226                    let mut values = Vec::<f32>::with_capacity(count);
227                    unsafe {
228                        std::ptr::copy_nonoverlapping(
229                            bytes.as_ptr(),
230                            values.as_mut_ptr() as *mut u8,
231                            total_bytes,
232                        );
233                        values.set_len(count);
234                    }
235                    Some(Ok(values))
236                } else {
237                    Some(Ok(bytes
238                        .chunks_exact(4)
239                        .map(|chunk| {
240                            let mut arr = [0u8; 4];
241                            arr.copy_from_slice(chunk);
242                            arr.reverse();
243                            f32::from_ne_bytes(arr)
244                        })
245                        .collect()))
246                }
247            }
248            _ => None,
249        }
250    }
251
252    fn native_copy_compatible(dtype: &Datatype) -> bool {
253        matches!(
254            dtype,
255            Datatype::FloatingPoint { size, byte_order }
256                if *size == 4 && byte_order_is_native(*byte_order)
257        )
258    }
259}
260
261impl H5Type for f64 {
262    fn hdf5_type() -> Datatype {
263        Datatype::FloatingPoint {
264            size: 8,
265            byte_order: if cfg!(target_endian = "little") {
266                ByteOrder::LittleEndian
267            } else {
268                ByteOrder::BigEndian
269            },
270        }
271    }
272
273    fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self> {
274        match dtype {
275            Datatype::FloatingPoint { size, byte_order } => {
276                if *size != 8 {
277                    return Err(Error::TypeMismatch {
278                        expected: "f64".into(),
279                        actual: format!("FloatingPoint(size={})", size),
280                    });
281                }
282                let arr = read_numeric::<8>(bytes, *byte_order)?;
283                Ok(f64::from_ne_bytes(arr))
284            }
285            _ => Err(Error::TypeMismatch {
286                expected: "f64".into(),
287                actual: format!("{:?}", dtype),
288            }),
289        }
290    }
291
292    fn element_size(_dtype: &Datatype) -> usize {
293        8
294    }
295
296    fn decode_vec(raw: &[u8], dtype: &Datatype, count: usize) -> Option<Result<Vec<Self>>> {
297        match dtype {
298            Datatype::FloatingPoint { size, byte_order } if *size == 8 => {
299                let total_bytes = count.checked_mul(8)?;
300                if raw.len() < total_bytes {
301                    return None;
302                }
303
304                let bytes = &raw[..total_bytes];
305                if byte_order_is_native(*byte_order) {
306                    let mut values = Vec::<f64>::with_capacity(count);
307                    unsafe {
308                        std::ptr::copy_nonoverlapping(
309                            bytes.as_ptr(),
310                            values.as_mut_ptr() as *mut u8,
311                            total_bytes,
312                        );
313                        values.set_len(count);
314                    }
315                    Some(Ok(values))
316                } else {
317                    Some(Ok(bytes
318                        .chunks_exact(8)
319                        .map(|chunk| {
320                            let mut arr = [0u8; 8];
321                            arr.copy_from_slice(chunk);
322                            arr.reverse();
323                            f64::from_ne_bytes(arr)
324                        })
325                        .collect()))
326                }
327            }
328            _ => None,
329        }
330    }
331
332    fn native_copy_compatible(dtype: &Datatype) -> bool {
333        matches!(
334            dtype,
335            Datatype::FloatingPoint { size, byte_order }
336                if *size == 8 && byte_order_is_native(*byte_order)
337        )
338    }
339}
340
341/// Get the element size from a datatype.
342pub fn dtype_element_size(dtype: &Datatype) -> Result<usize> {
343    match dtype {
344        Datatype::FixedPoint { size, .. } => Ok(*size as usize),
345        Datatype::FloatingPoint { size, .. } => Ok(*size as usize),
346        Datatype::String {
347            size: StringSize::Fixed(n),
348            ..
349        } => Ok(*n as usize),
350        Datatype::String {
351            size: StringSize::Variable,
352            ..
353        } => Ok(16),
354        Datatype::Compound { size, .. } => Ok(*size as usize),
355        Datatype::Array { base, dims } => {
356            let base_size = dtype_element_size(base)?;
357            let count = dims.iter().try_fold(1usize, |acc, &dim| {
358                let dim = usize::try_from(dim).map_err(|_| {
359                    Error::InvalidData(
360                        "array datatype dimension exceeds platform usize capacity".to_string(),
361                    )
362                })?;
363                acc.checked_mul(dim).ok_or_else(|| {
364                    Error::InvalidData(
365                        "array datatype element count exceeds platform usize capacity".to_string(),
366                    )
367                })
368            })?;
369            base_size.checked_mul(count).ok_or_else(|| {
370                Error::InvalidData(
371                    "array datatype byte size exceeds platform usize capacity".to_string(),
372                )
373            })
374        }
375        Datatype::Enum { base, .. } => dtype_element_size(base),
376        Datatype::VarLen { .. } => Ok(16),
377        Datatype::Opaque { size, .. } => Ok(*size as usize),
378        Datatype::Reference { size, .. } => Ok(*size as usize),
379        Datatype::Bitfield { size, .. } => Ok(*size as usize),
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte order of the machine running the tests.
    fn native_byte_order() -> ByteOrder {
        if cfg!(target_endian = "little") {
            ByteOrder::LittleEndian
        } else {
            ByteOrder::BigEndian
        }
    }

    #[test]
    fn f32_bulk_decode_native_endian() {
        let dtype = <f32 as H5Type>::hdf5_type();
        let raw = [0.5f32.to_ne_bytes(), 1.25f32.to_ne_bytes()].concat();
        let values = <f32 as H5Type>::decode_vec(&raw, &dtype, 2)
            .unwrap()
            .unwrap();
        assert_eq!(values, vec![0.5, 1.25]);
    }

    // Exercises the byte-swapping (non-native) bulk path for floats.
    #[test]
    fn f64_bulk_decode_big_endian() {
        let dtype = Datatype::FloatingPoint {
            size: 8,
            byte_order: ByteOrder::BigEndian,
        };
        let raw = [2.5f64.to_be_bytes(), (-0.75f64).to_be_bytes()].concat();
        let values = <f64 as H5Type>::decode_vec(&raw, &dtype, 2)
            .unwrap()
            .unwrap();
        assert_eq!(values, vec![2.5, -0.75]);
    }

    #[test]
    fn f64_from_bytes_big_endian() {
        let dtype = Datatype::FloatingPoint {
            size: 8,
            byte_order: ByteOrder::BigEndian,
        };
        let value = <f64 as H5Type>::from_bytes(&2.5f64.to_be_bytes(), &dtype).unwrap();
        assert_eq!(value, 2.5);
    }

    #[test]
    fn u32_bulk_decode_big_endian() {
        let dtype = Datatype::FixedPoint {
            size: 4,
            signed: false,
            byte_order: ByteOrder::BigEndian,
        };
        let raw = [1u32.to_be_bytes(), 7u32.to_be_bytes()].concat();
        let values = <u32 as H5Type>::decode_vec(&raw, &dtype, 2)
            .unwrap()
            .unwrap();
        assert_eq!(values, vec![1, 7]);
    }

    // A buffer shorter than `count` elements must decline the fast path
    // rather than read out of bounds.
    #[test]
    fn bulk_decode_declines_short_buffer() {
        let int_dtype = Datatype::FixedPoint {
            size: 4,
            signed: false,
            byte_order: native_byte_order(),
        };
        let float_dtype = Datatype::FloatingPoint {
            size: 8,
            byte_order: native_byte_order(),
        };

        assert!(<u32 as H5Type>::decode_vec(&[0u8; 7], &int_dtype, 2).is_none());
        assert!(<f64 as H5Type>::decode_vec(&[0u8; 15], &float_dtype, 2).is_none());
    }

    #[test]
    fn integer_from_bytes_rejects_short_input() {
        let dtype = Datatype::FixedPoint {
            size: 4,
            signed: false,
            byte_order: ByteOrder::LittleEndian,
        };

        let err = <u32 as H5Type>::from_bytes(&[1, 2], &dtype).unwrap_err();
        assert!(matches!(err, Error::InvalidData(_)));
    }

    #[test]
    fn float_from_bytes_rejects_size_mismatch() {
        let dtype = Datatype::FloatingPoint {
            size: 8,
            byte_order: native_byte_order(),
        };

        let err = <f32 as H5Type>::from_bytes(&[0u8; 8], &dtype).unwrap_err();
        assert!(matches!(err, Error::TypeMismatch { .. }));
    }

    #[test]
    fn integer_from_bytes_rejects_signedness_mismatch() {
        let dtype = Datatype::FixedPoint {
            size: 2,
            signed: false,
            byte_order: ByteOrder::LittleEndian,
        };

        let err = <i16 as H5Type>::from_bytes(&u16::MAX.to_le_bytes(), &dtype).unwrap_err();
        assert!(matches!(
            err,
            Error::TypeMismatch {
                expected,
                actual
            } if expected == "i16" && actual.contains("signed=false")
        ));
    }

    #[test]
    fn integer_bulk_decode_rejects_signedness_mismatch() {
        let unsigned_dtype = Datatype::FixedPoint {
            size: 2,
            signed: false,
            byte_order: ByteOrder::LittleEndian,
        };
        let signed_dtype = Datatype::FixedPoint {
            size: 2,
            signed: true,
            byte_order: ByteOrder::LittleEndian,
        };

        assert!(<i16 as H5Type>::decode_vec(&[0, 0], &unsigned_dtype, 1).is_none());
        assert!(<u16 as H5Type>::decode_vec(&[0, 0], &signed_dtype, 1).is_none());
    }

    #[test]
    fn integer_native_copy_compatible_rejects_signedness_mismatch() {
        let unsigned_dtype = Datatype::FixedPoint {
            size: 2,
            signed: false,
            byte_order: native_byte_order(),
        };
        let signed_dtype = Datatype::FixedPoint {
            size: 2,
            signed: true,
            byte_order: native_byte_order(),
        };

        assert!(!<i16 as H5Type>::native_copy_compatible(&unsigned_dtype));
        assert!(!<u16 as H5Type>::native_copy_compatible(&signed_dtype));
        assert!(<u16 as H5Type>::native_copy_compatible(&unsigned_dtype));
        assert!(<i16 as H5Type>::native_copy_compatible(&signed_dtype));
    }

    #[test]
    fn dtype_element_size_rejects_array_overflow() {
        let dtype = Datatype::Array {
            base: Box::new(Datatype::FixedPoint {
                size: 8,
                signed: false,
                byte_order: ByteOrder::LittleEndian,
            }),
            dims: vec![u64::MAX, 2],
        };

        let err = dtype_element_size(&dtype).unwrap_err();
        assert!(err.to_string().contains("array datatype"));
    }
}