// hdf5_reader/datatype_api.rs

1use crate::error::{ByteOrder, Error, Result};
2use crate::messages::datatype::Datatype;
3
4// Re-export types from the datatype message module so users don't need to
5// reach into messages::datatype.
6pub use crate::messages::datatype::{
7    CompoundField, EnumMember, ReferenceType, StringEncoding, StringPadding, StringSize,
8};
9
10/// Trait for types that can be read from HDF5 datasets.
11///
12/// Implemented for primitive numeric types. Users can implement this
13/// for custom types (e.g., compound types).
14pub trait H5Type: Sized + Send + Clone {
15    /// The HDF5 datatype that this Rust type corresponds to.
16    fn hdf5_type() -> Datatype;
17
18    /// Decode a single value from raw bytes with the given datatype.
19    fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self>;
20
21    /// Size of a single element in bytes.
22    fn element_size(dtype: &Datatype) -> usize;
23
24    /// Decode many values at once when the datatype has an efficient bulk path.
25    ///
26    /// Returning `None` falls back to per-element decoding.
27    fn decode_vec(_raw: &[u8], _dtype: &Datatype, _count: usize) -> Option<Result<Vec<Self>>> {
28        None
29    }
30
31    /// Whether raw bytes for this datatype can be copied directly into a `Vec<Self>`
32    /// without any further decoding or byte swapping.
33    fn native_copy_compatible(_dtype: &Datatype) -> bool {
34        false
35    }
36}
37
38/// Read a numeric value from bytes, handling byte-order conversion.
39fn read_numeric<const N: usize>(bytes: &[u8], byte_order: ByteOrder) -> Result<[u8; N]> {
40    if bytes.len() < N {
41        return Err(Error::InvalidData(format!(
42            "expected {} bytes, got {}",
43            N,
44            bytes.len()
45        )));
46    }
47    let mut arr = [0u8; N];
48    arr.copy_from_slice(&bytes[..N]);
49
50    // Swap bytes if the source endianness doesn't match native
51    #[cfg(target_endian = "little")]
52    if byte_order == ByteOrder::BigEndian {
53        arr.reverse();
54    }
55    #[cfg(target_endian = "big")]
56    if byte_order == ByteOrder::LittleEndian {
57        arr.reverse();
58    }
59
60    Ok(arr)
61}
62
63fn byte_order_is_native(byte_order: ByteOrder) -> bool {
64    #[cfg(target_endian = "little")]
65    {
66        byte_order == ByteOrder::LittleEndian
67    }
68    #[cfg(target_endian = "big")]
69    {
70        byte_order == ByteOrder::BigEndian
71    }
72}
73
74macro_rules! impl_h5type_int {
75    ($ty:ty, $size:expr, $signed:expr) => {
76        impl H5Type for $ty {
77            fn hdf5_type() -> Datatype {
78                Datatype::FixedPoint {
79                    size: $size,
80                    signed: $signed,
81                    byte_order: if cfg!(target_endian = "little") {
82                        ByteOrder::LittleEndian
83                    } else {
84                        ByteOrder::BigEndian
85                    },
86                }
87            }
88
89            fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self> {
90                match dtype {
91                    Datatype::FixedPoint {
92                        size, byte_order, ..
93                    } => {
94                        if *size as usize != std::mem::size_of::<$ty>() {
95                            return Err(Error::TypeMismatch {
96                                expected: stringify!($ty).into(),
97                                actual: format!("FixedPoint(size={})", size),
98                            });
99                        }
100                        let arr = read_numeric::<$size>(bytes, *byte_order)?;
101                        Ok(<$ty>::from_ne_bytes(arr))
102                    }
103                    _ => Err(Error::TypeMismatch {
104                        expected: stringify!($ty).into(),
105                        actual: format!("{:?}", dtype),
106                    }),
107                }
108            }
109
110            fn element_size(_dtype: &Datatype) -> usize {
111                $size
112            }
113
114            fn decode_vec(raw: &[u8], dtype: &Datatype, count: usize) -> Option<Result<Vec<Self>>> {
115                match dtype {
116                    Datatype::FixedPoint {
117                        size, byte_order, ..
118                    } if *size as usize == $size => {
119                        let total_bytes = count.checked_mul($size)?;
120                        if raw.len() < total_bytes {
121                            return None;
122                        }
123
124                        let bytes = &raw[..total_bytes];
125                        if byte_order_is_native(*byte_order) {
126                            let mut values = Vec::<$ty>::with_capacity(count);
127                            unsafe {
128                                std::ptr::copy_nonoverlapping(
129                                    bytes.as_ptr(),
130                                    values.as_mut_ptr() as *mut u8,
131                                    total_bytes,
132                                );
133                                values.set_len(count);
134                            }
135                            Some(Ok(values))
136                        } else {
137                            Some(Ok(bytes
138                                .chunks_exact($size)
139                                .map(|chunk| {
140                                    let mut arr = [0u8; $size];
141                                    arr.copy_from_slice(chunk);
142                                    arr.reverse();
143                                    <$ty>::from_ne_bytes(arr)
144                                })
145                                .collect()))
146                        }
147                    }
148                    _ => None,
149                }
150            }
151
152            fn native_copy_compatible(dtype: &Datatype) -> bool {
153                matches!(
154                    dtype,
155                    Datatype::FixedPoint {
156                        size,
157                        byte_order,
158                        ..
159                    } if *size as usize == $size && byte_order_is_native(*byte_order)
160                )
161            }
162        }
163    };
164}
165
166impl_h5type_int!(i8, 1, true);
167impl_h5type_int!(u8, 1, false);
168impl_h5type_int!(i16, 2, true);
169impl_h5type_int!(u16, 2, false);
170impl_h5type_int!(i32, 4, true);
171impl_h5type_int!(u32, 4, false);
172impl_h5type_int!(i64, 8, true);
173impl_h5type_int!(u64, 8, false);
174
175impl H5Type for f32 {
176    fn hdf5_type() -> Datatype {
177        Datatype::FloatingPoint {
178            size: 4,
179            byte_order: if cfg!(target_endian = "little") {
180                ByteOrder::LittleEndian
181            } else {
182                ByteOrder::BigEndian
183            },
184        }
185    }
186
187    fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self> {
188        match dtype {
189            Datatype::FloatingPoint { size, byte_order } => {
190                if *size != 4 {
191                    return Err(Error::TypeMismatch {
192                        expected: "f32".into(),
193                        actual: format!("FloatingPoint(size={})", size),
194                    });
195                }
196                let arr = read_numeric::<4>(bytes, *byte_order)?;
197                Ok(f32::from_ne_bytes(arr))
198            }
199            _ => Err(Error::TypeMismatch {
200                expected: "f32".into(),
201                actual: format!("{:?}", dtype),
202            }),
203        }
204    }
205
206    fn element_size(_dtype: &Datatype) -> usize {
207        4
208    }
209
210    fn decode_vec(raw: &[u8], dtype: &Datatype, count: usize) -> Option<Result<Vec<Self>>> {
211        match dtype {
212            Datatype::FloatingPoint { size, byte_order } if *size == 4 => {
213                let total_bytes = count.checked_mul(4)?;
214                if raw.len() < total_bytes {
215                    return None;
216                }
217
218                let bytes = &raw[..total_bytes];
219                if byte_order_is_native(*byte_order) {
220                    let mut values = Vec::<f32>::with_capacity(count);
221                    unsafe {
222                        std::ptr::copy_nonoverlapping(
223                            bytes.as_ptr(),
224                            values.as_mut_ptr() as *mut u8,
225                            total_bytes,
226                        );
227                        values.set_len(count);
228                    }
229                    Some(Ok(values))
230                } else {
231                    Some(Ok(bytes
232                        .chunks_exact(4)
233                        .map(|chunk| {
234                            let mut arr = [0u8; 4];
235                            arr.copy_from_slice(chunk);
236                            arr.reverse();
237                            f32::from_ne_bytes(arr)
238                        })
239                        .collect()))
240                }
241            }
242            _ => None,
243        }
244    }
245
246    fn native_copy_compatible(dtype: &Datatype) -> bool {
247        matches!(
248            dtype,
249            Datatype::FloatingPoint { size, byte_order }
250                if *size == 4 && byte_order_is_native(*byte_order)
251        )
252    }
253}
254
255impl H5Type for f64 {
256    fn hdf5_type() -> Datatype {
257        Datatype::FloatingPoint {
258            size: 8,
259            byte_order: if cfg!(target_endian = "little") {
260                ByteOrder::LittleEndian
261            } else {
262                ByteOrder::BigEndian
263            },
264        }
265    }
266
267    fn from_bytes(bytes: &[u8], dtype: &Datatype) -> Result<Self> {
268        match dtype {
269            Datatype::FloatingPoint { size, byte_order } => {
270                if *size != 8 {
271                    return Err(Error::TypeMismatch {
272                        expected: "f64".into(),
273                        actual: format!("FloatingPoint(size={})", size),
274                    });
275                }
276                let arr = read_numeric::<8>(bytes, *byte_order)?;
277                Ok(f64::from_ne_bytes(arr))
278            }
279            _ => Err(Error::TypeMismatch {
280                expected: "f64".into(),
281                actual: format!("{:?}", dtype),
282            }),
283        }
284    }
285
286    fn element_size(_dtype: &Datatype) -> usize {
287        8
288    }
289
290    fn decode_vec(raw: &[u8], dtype: &Datatype, count: usize) -> Option<Result<Vec<Self>>> {
291        match dtype {
292            Datatype::FloatingPoint { size, byte_order } if *size == 8 => {
293                let total_bytes = count.checked_mul(8)?;
294                if raw.len() < total_bytes {
295                    return None;
296                }
297
298                let bytes = &raw[..total_bytes];
299                if byte_order_is_native(*byte_order) {
300                    let mut values = Vec::<f64>::with_capacity(count);
301                    unsafe {
302                        std::ptr::copy_nonoverlapping(
303                            bytes.as_ptr(),
304                            values.as_mut_ptr() as *mut u8,
305                            total_bytes,
306                        );
307                        values.set_len(count);
308                    }
309                    Some(Ok(values))
310                } else {
311                    Some(Ok(bytes
312                        .chunks_exact(8)
313                        .map(|chunk| {
314                            let mut arr = [0u8; 8];
315                            arr.copy_from_slice(chunk);
316                            arr.reverse();
317                            f64::from_ne_bytes(arr)
318                        })
319                        .collect()))
320                }
321            }
322            _ => None,
323        }
324    }
325
326    fn native_copy_compatible(dtype: &Datatype) -> bool {
327        matches!(
328            dtype,
329            Datatype::FloatingPoint { size, byte_order }
330                if *size == 8 && byte_order_is_native(*byte_order)
331        )
332    }
333}
334
335/// Get the element size from a datatype.
336pub fn dtype_element_size(dtype: &Datatype) -> usize {
337    match dtype {
338        Datatype::FixedPoint { size, .. } => *size as usize,
339        Datatype::FloatingPoint { size, .. } => *size as usize,
340        Datatype::String {
341            size: StringSize::Fixed(n),
342            ..
343        } => *n as usize,
344        Datatype::String {
345            size: StringSize::Variable,
346            ..
347        } => 16,
348        Datatype::Compound { size, .. } => *size as usize,
349        Datatype::Array { base, dims } => {
350            let base_size = dtype_element_size(base);
351            let count: u64 = dims.iter().product();
352            base_size * count as usize
353        }
354        Datatype::Enum { base, .. } => dtype_element_size(base),
355        Datatype::VarLen { .. } => 16,
356        Datatype::Opaque { size, .. } => *size as usize,
357        Datatype::Reference { size, .. } => *size as usize,
358        Datatype::Bitfield { size, .. } => *size as usize,
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Native-endian `f32` input goes through the bulk decoder and comes back
    /// unchanged.
    #[test]
    fn test_f32_bulk_decode_native_endian() {
        let mut raw = Vec::new();
        raw.extend_from_slice(&0.5f32.to_ne_bytes());
        raw.extend_from_slice(&1.25f32.to_ne_bytes());

        let native_f32 = <f32 as H5Type>::hdf5_type();
        let decoded = <f32 as H5Type>::decode_vec(&raw, &native_f32, 2);

        assert_eq!(decoded.unwrap().unwrap(), vec![0.5, 1.25]);
    }

    /// Big-endian `u32` input decodes to the expected values whatever the
    /// host's byte order is.
    #[test]
    fn test_u32_bulk_decode_big_endian() {
        let mut raw = Vec::new();
        raw.extend_from_slice(&1u32.to_be_bytes());
        raw.extend_from_slice(&7u32.to_be_bytes());

        let big_endian_u32 = Datatype::FixedPoint {
            size: 4,
            signed: false,
            byte_order: ByteOrder::BigEndian,
        };
        let decoded = <u32 as H5Type>::decode_vec(&raw, &big_endian_u32, 2);

        assert_eq!(decoded.unwrap().unwrap(), vec![1, 7]);
    }
}