1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10#[proc_macro_derive(H5Type)]
24pub fn derive_h5type(input: TokenStream) -> TokenStream {
25 let input = parse_macro_input!(input as DeriveInput);
26 match impl_h5type(&input) {
27 Ok(ts) => ts.into(),
28 Err(e) => e.to_compile_error().into(),
29 }
30}
31
32fn impl_h5type(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
33 let name = &input.ident;
34
35 let fields = match &input.data {
36 Data::Struct(data) => match &data.fields {
37 Fields::Named(named) => &named.named,
38 _ => {
39 return Err(syn::Error::new_spanned(
40 name,
41 "H5Type can only be derived for structs with named fields",
42 ));
43 }
44 },
45 _ => {
46 return Err(syn::Error::new_spanned(
47 name,
48 "H5Type can only be derived for structs",
49 ));
50 }
51 };
52
53 let mut datatype_member_stmts = Vec::new();
54 let mut serialize_stmts = Vec::new();
55 let mut deserialize_stmts = Vec::new();
56 let mut field_names = Vec::new();
57 let mut size_increments = Vec::new();
58
59 for field in fields.iter() {
60 let field_name = field.ident.as_ref().unwrap();
61 let field_name_str = field_name.to_string();
62 let ty = &field.ty;
63
64 let (dt_expr, ser_expr, deser_expr, size_expr) = type_mapping(ty, field_name)?;
65
66 datatype_member_stmts.push(quote! {
67 _members.push(rustyhdf5_format::datatype::CompoundMember {
68 name: #field_name_str.into(),
69 byte_offset: _offset,
70 datatype: #dt_expr,
71 });
72 _offset += #size_expr as u64;
73 });
74
75 size_increments.push(quote! { + (#size_expr as usize) });
76 serialize_stmts.push(ser_expr);
77 deserialize_stmts.push(deser_expr);
78 field_names.push(field_name.clone());
79 }
80
81 let expanded = quote! {
82 impl #name {
83 pub fn hdf5_datatype() -> rustyhdf5_format::datatype::Datatype {
85 let mut _offset: u64 = 0;
86 let mut _members = Vec::new();
87 #(#datatype_member_stmts)*
88 rustyhdf5_format::datatype::Datatype::Compound {
89 size: _offset as u32,
90 members: _members,
91 }
92 }
93
94 pub fn to_bytes(&self) -> Vec<u8> {
96 let mut _buf = Vec::with_capacity(Self::_h5_compound_size());
97 #(#serialize_stmts)*
98 _buf
99 }
100
101 pub fn from_bytes(_data: &[u8]) -> Self {
103 let mut _pos = 0usize;
104 #(#deserialize_stmts)*
105 Self {
106 #(#field_names),*
107 }
108 }
109
110 fn _h5_compound_size() -> usize {
111 0usize #(#size_increments)*
112 }
113 }
114 };
115
116 Ok(expanded)
117}
118
119fn type_mapping(
120 ty: &Type,
121 field_name: &syn::Ident,
122) -> syn::Result<(
123 proc_macro2::TokenStream, proc_macro2::TokenStream, proc_macro2::TokenStream, proc_macro2::TokenStream, )> {
128 match ty {
129 Type::Path(type_path) => {
130 let seg = type_path.path.segments.last().unwrap();
131 let type_name = seg.ident.to_string();
132 match type_name.as_str() {
133 "f64" => Ok(float_mapping(field_name, 8, 64, 52, 11, 52, 1023)),
134 "f32" => Ok(float_mapping(field_name, 4, 32, 23, 8, 23, 127)),
135 "i8" => Ok(int_mapping(field_name, 1, true)),
136 "i16" => Ok(int_mapping(field_name, 2, true)),
137 "i32" => Ok(int_mapping(field_name, 4, true)),
138 "i64" => Ok(int_mapping(field_name, 8, true)),
139 "u8" => Ok(int_mapping(field_name, 1, false)),
140 "u16" => Ok(int_mapping(field_name, 2, false)),
141 "u32" => Ok(int_mapping(field_name, 4, false)),
142 "u64" => Ok(int_mapping(field_name, 8, false)),
143 "bool" => Ok(bool_mapping(field_name)),
144 _ => Err(syn::Error::new_spanned(
145 ty,
146 format!("unsupported type `{type_name}` for H5Type derive"),
147 )),
148 }
149 }
150 Type::Array(arr) => {
151 let elem_ty = &*arr.elem;
152 let len_expr = &arr.len;
153 array_mapping(field_name, elem_ty, len_expr)
154 }
155 _ => Err(syn::Error::new_spanned(
156 ty,
157 "unsupported type for H5Type derive",
158 )),
159 }
160}
161
162fn float_mapping(
163 field_name: &syn::Ident,
164 size: u32,
165 precision: u16,
166 mant_loc: u8,
167 exp_size: u8,
168 mant_size: u8,
169 exp_bias: u32,
170) -> (
171 proc_macro2::TokenStream,
172 proc_macro2::TokenStream,
173 proc_macro2::TokenStream,
174 proc_macro2::TokenStream,
175) {
176 let size_lit = size;
177 let precision_lit = precision;
178 let exp_size_lit = exp_size;
179 let mant_size_lit = mant_size;
180 let exp_bias_lit = exp_bias;
181 let exp_loc: u8 = mant_loc;
182
183 let dt = quote! {
184 rustyhdf5_format::datatype::Datatype::FloatingPoint {
185 size: #size_lit,
186 byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
187 bit_offset: 0,
188 bit_precision: #precision_lit,
189 exponent_location: #exp_loc,
190 exponent_size: #exp_size_lit,
191 mantissa_location: 0,
192 mantissa_size: #mant_size_lit,
193 exponent_bias: #exp_bias_lit,
194 }
195 };
196
197 let ser = quote! {
198 _buf.extend_from_slice(&self.#field_name.to_le_bytes());
199 };
200
201 let deser = if size == 8 {
202 quote! {
203 let #field_name = f64::from_le_bytes(
204 _data[_pos.._pos + 8].try_into().unwrap()
205 );
206 _pos += 8;
207 }
208 } else {
209 quote! {
210 let #field_name = f32::from_le_bytes(
211 _data[_pos.._pos + 4].try_into().unwrap()
212 );
213 _pos += 4;
214 }
215 };
216
217 let sz = size as usize;
218 let size_expr = quote! { #sz };
219 (dt, ser, deser, size_expr)
220}
221
222fn int_mapping(
223 field_name: &syn::Ident,
224 size: u32,
225 signed: bool,
226) -> (
227 proc_macro2::TokenStream,
228 proc_macro2::TokenStream,
229 proc_macro2::TokenStream,
230 proc_macro2::TokenStream,
231) {
232 let precision = (size * 8) as u16;
233
234 let dt = quote! {
235 rustyhdf5_format::datatype::Datatype::FixedPoint {
236 size: #size,
237 byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
238 signed: #signed,
239 bit_offset: 0,
240 bit_precision: #precision,
241 }
242 };
243
244 let ser = quote! {
245 _buf.extend_from_slice(&self.#field_name.to_le_bytes());
246 };
247
248 let sz = size as usize;
249 let deser = match (size, signed) {
250 (1, true) => quote! {
251 let #field_name = _data[_pos] as i8;
252 _pos += 1;
253 },
254 (1, false) => quote! {
255 let #field_name = _data[_pos];
256 _pos += 1;
257 },
258 (2, true) => quote! {
259 let #field_name = i16::from_le_bytes(
260 _data[_pos.._pos + 2].try_into().unwrap()
261 );
262 _pos += 2;
263 },
264 (2, false) => quote! {
265 let #field_name = u16::from_le_bytes(
266 _data[_pos.._pos + 2].try_into().unwrap()
267 );
268 _pos += 2;
269 },
270 (4, true) => quote! {
271 let #field_name = i32::from_le_bytes(
272 _data[_pos.._pos + 4].try_into().unwrap()
273 );
274 _pos += 4;
275 },
276 (4, false) => quote! {
277 let #field_name = u32::from_le_bytes(
278 _data[_pos.._pos + 4].try_into().unwrap()
279 );
280 _pos += 4;
281 },
282 (8, true) => quote! {
283 let #field_name = i64::from_le_bytes(
284 _data[_pos.._pos + 8].try_into().unwrap()
285 );
286 _pos += 8;
287 },
288 (8, false) => quote! {
289 let #field_name = u64::from_le_bytes(
290 _data[_pos.._pos + 8].try_into().unwrap()
291 );
292 _pos += 8;
293 },
294 _ => quote! {
295 let mut _tmp = [0u8; #sz];
296 _tmp.copy_from_slice(&_data[_pos.._pos + #sz]);
297 let #field_name = _tmp;
298 _pos += #sz;
299 },
300 };
301
302 let sz = size as usize;
303 let size_expr = quote! { #sz };
304 (dt, ser, deser, size_expr)
305}
306
307fn bool_mapping(
308 field_name: &syn::Ident,
309) -> (
310 proc_macro2::TokenStream,
311 proc_macro2::TokenStream,
312 proc_macro2::TokenStream,
313 proc_macro2::TokenStream,
314) {
315 let dt = quote! {
316 rustyhdf5_format::datatype::Datatype::FixedPoint {
317 size: 1,
318 byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
319 signed: false,
320 bit_offset: 0,
321 bit_precision: 8,
322 }
323 };
324
325 let ser = quote! {
326 _buf.push(if self.#field_name { 1u8 } else { 0u8 });
327 };
328
329 let deser = quote! {
330 let #field_name = _data[_pos] != 0;
331 _pos += 1;
332 };
333
334 let size_expr = quote! { 1usize };
335 (dt, ser, deser, size_expr)
336}
337
338fn array_mapping(
339 field_name: &syn::Ident,
340 elem_ty: &Type,
341 len_expr: &syn::Expr,
342) -> syn::Result<(
343 proc_macro2::TokenStream,
344 proc_macro2::TokenStream,
345 proc_macro2::TokenStream,
346 proc_macro2::TokenStream,
347)> {
348 let Type::Path(type_path) = elem_ty else {
349 return Err(syn::Error::new_spanned(
350 elem_ty,
351 "array element must be a primitive type for H5Type derive",
352 ));
353 };
354 let elem_name = type_path.path.segments.last().unwrap().ident.to_string();
355
356 let (base_dt, elem_size, deser_one) = match elem_name.as_str() {
357 "f64" => (
358 quote! {
359 rustyhdf5_format::datatype::Datatype::FloatingPoint {
360 size: 8,
361 byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
362 bit_offset: 0, bit_precision: 64,
363 exponent_location: 52, exponent_size: 11,
364 mantissa_location: 0, mantissa_size: 52,
365 exponent_bias: 1023,
366 }
367 },
368 8usize,
369 quote! { f64::from_le_bytes(_data[_pos.._pos + 8].try_into().unwrap()) },
370 ),
371 "f32" => (
372 quote! {
373 rustyhdf5_format::datatype::Datatype::FloatingPoint {
374 size: 4,
375 byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
376 bit_offset: 0, bit_precision: 32,
377 exponent_location: 23, exponent_size: 8,
378 mantissa_location: 0, mantissa_size: 23,
379 exponent_bias: 127,
380 }
381 },
382 4usize,
383 quote! { f32::from_le_bytes(_data[_pos.._pos + 4].try_into().unwrap()) },
384 ),
385 "i8" => (
386 int_dt_quote(1, true),
387 1usize,
388 quote! { _data[_pos] as i8 },
389 ),
390 "i16" => (
391 int_dt_quote(2, true),
392 2usize,
393 quote! { i16::from_le_bytes(_data[_pos.._pos + 2].try_into().unwrap()) },
394 ),
395 "i32" => (
396 int_dt_quote(4, true),
397 4usize,
398 quote! { i32::from_le_bytes(_data[_pos.._pos + 4].try_into().unwrap()) },
399 ),
400 "i64" => (
401 int_dt_quote(8, true),
402 8usize,
403 quote! { i64::from_le_bytes(_data[_pos.._pos + 8].try_into().unwrap()) },
404 ),
405 "u8" => (
406 int_dt_quote(1, false),
407 1usize,
408 quote! { _data[_pos] },
409 ),
410 "u16" => (
411 int_dt_quote(2, false),
412 2usize,
413 quote! { u16::from_le_bytes(_data[_pos.._pos + 2].try_into().unwrap()) },
414 ),
415 "u32" => (
416 int_dt_quote(4, false),
417 4usize,
418 quote! { u32::from_le_bytes(_data[_pos.._pos + 4].try_into().unwrap()) },
419 ),
420 "u64" => (
421 int_dt_quote(8, false),
422 8usize,
423 quote! { u64::from_le_bytes(_data[_pos.._pos + 8].try_into().unwrap()) },
424 ),
425 _ => {
426 return Err(syn::Error::new_spanned(
427 elem_ty,
428 format!("unsupported array element type `{elem_name}` for H5Type derive"),
429 ));
430 }
431 };
432
433 let dt = quote! {
434 rustyhdf5_format::datatype::Datatype::Array {
435 base_type: Box::new(#base_dt),
436 dimensions: vec![#len_expr as u32],
437 }
438 };
439
440 let ser = quote! {
441 for _elem in &self.#field_name {
442 _buf.extend_from_slice(&_elem.to_le_bytes());
443 }
444 };
445
446 let deser = quote! {
447 let #field_name = {
448 let mut _arr = [Default::default(); #len_expr];
449 for _i in 0..#len_expr {
450 _arr[_i] = #deser_one;
451 _pos += #elem_size;
452 }
453 _arr
454 };
455 };
456
457 let size_expr = quote! { (#len_expr * #elem_size) };
458 Ok((dt, ser, deser, size_expr))
459}
460
461fn int_dt_quote(size: u32, signed: bool) -> proc_macro2::TokenStream {
462 let precision = (size * 8) as u16;
463 quote! {
464 rustyhdf5_format::datatype::Datatype::FixedPoint {
465 size: #size,
466 byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
467 signed: #signed,
468 bit_offset: 0,
469 bit_precision: #precision,
470 }
471 }
472}