1#![allow(
13 clippy::match_same_arms,
14 clippy::needless_pass_by_value,
15 clippy::option_if_let_else
16)]
17
18extern crate proc_macro;
19
20use proc_macro::TokenStream;
21use quote::quote;
22use syn::{Data, DeriveInput, Fields, parse_macro_input};
23
24#[proc_macro_derive(FerrayRecord)]
39pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
40 let input = parse_macro_input!(input as DeriveInput);
41 match impl_ferray_record(&input) {
42 Ok(ts) => ts.into(),
43 Err(e) => e.to_compile_error().into(),
44 }
45}
46
47fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
48 let name = &input.ident;
49
50 let has_repr_c = input.attrs.iter().any(|attr| {
52 if !attr.path().is_ident("repr") {
53 return false;
54 }
55 let mut found = false;
56 let _ = attr.parse_nested_meta(|meta| {
57 if meta.path.is_ident("C") {
58 found = true;
59 }
60 Ok(())
61 });
62 found
63 });
64
65 if !has_repr_c {
66 return Err(syn::Error::new_spanned(
67 &input.ident,
68 "FerrayRecord requires #[repr(C)] on the struct",
69 ));
70 }
71
72 let fields = match &input.data {
74 Data::Struct(data_struct) => match &data_struct.fields {
75 Fields::Named(named) => &named.named,
76 _ => {
77 return Err(syn::Error::new_spanned(
78 &input.ident,
79 "FerrayRecord only supports structs with named fields",
80 ));
81 }
82 },
83 _ => {
84 return Err(syn::Error::new_spanned(
85 &input.ident,
86 "FerrayRecord can only be derived for structs",
87 ));
88 }
89 };
90
91 let field_count = fields.len();
92 let mut field_descriptors = Vec::with_capacity(field_count);
93
94 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
95
96 let ty_generics_turbofish = input.generics.split_for_impl().1.as_turbofish();
102 let name_with_generics = quote! { #name #ty_generics_turbofish };
103
104 for field in fields {
105 let field_name = field.ident.as_ref().unwrap();
106 let field_name_str = field_name.to_string();
107 let field_ty = &field.ty;
108
109 field_descriptors.push(quote! {
110 ferray_core::record::FieldDescriptor {
111 name: #field_name_str,
112 dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
113 offset: std::mem::offset_of!(#name_with_generics, #field_name),
114 size: std::mem::size_of::<#field_ty>(),
115 }
116 });
117 }
118
119 let expanded = quote! {
120 unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
121 fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
122 use std::any::TypeId;
141 use std::collections::HashMap;
142 use std::sync::{OnceLock, Mutex};
143 static CACHE: OnceLock<
144 Mutex<HashMap<TypeId, &'static [ferray_core::record::FieldDescriptor]>>,
145 > = OnceLock::new();
146 let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
147 let mut guard = cache.lock().unwrap();
148 *guard
149 .entry(TypeId::of::<#name_with_generics>())
150 .or_insert_with(|| {
151 let v: Vec<ferray_core::record::FieldDescriptor> = vec![
152 #(#field_descriptors),*
153 ];
154 Box::leak(v.into_boxed_slice())
155 })
156 }
157
158 fn record_size() -> usize {
159 std::mem::size_of::<#name_with_generics>()
160 }
161 }
162 };
163
164 Ok(expanded)
165}
166
167#[proc_macro]
189pub fn s(input: TokenStream) -> TokenStream {
190 let input2: proc_macro2::TokenStream = input.into();
191 let expanded = impl_s_macro(input2);
192 match expanded {
193 Ok(ts) => ts.into(),
194 Err(e) => e.to_compile_error().into(),
195 }
196}
197
198fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
199 let input_str = input.to_string();
226
227 if input_str.trim().is_empty() {
229 return Ok(quote! {
230 ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
231 });
232 }
233
234 let components = split_top_level_commas(&input_str);
236 let mut elems = Vec::new();
237
238 for component in &components {
239 let trimmed = component.trim();
240 if trimmed.is_empty() {
241 continue;
242 }
243 elems.push(parse_slice_component(trimmed)?);
244 }
245
246 Ok(quote! {
247 vec![#(#elems),*]
248 })
249}
250
/// Splits `s` on commas that are not nested inside `()`, `[]` or `{}`.
///
/// Pieces are returned untrimmed and in order. An empty piece between two
/// adjacent separators is kept, but an empty trailing piece (after a final
/// comma, or for empty input) is dropped. Unbalanced closers simply drive the
/// nesting counter negative, so commas after them are not split on.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut piece = String::new();
    let mut nesting: i32 = 0;

    for c in s.chars() {
        if c == ',' && nesting == 0 {
            pieces.push(std::mem::take(&mut piece));
            continue;
        }
        match c {
            '(' | '[' | '{' => nesting += 1,
            ')' | ']' | '}' => nesting -= 1,
            _ => {}
        }
        piece.push(c);
    }

    if !piece.is_empty() {
        pieces.push(piece);
    }
    pieces
}
279}
280
281fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
283 let mut depth = 0i32;
284 let mut last_idx = None;
285 for (i, ch) in s.char_indices() {
286 match ch {
287 '(' | '[' | '{' => depth += 1,
288 ')' | ']' | '}' => depth -= 1,
289 ';' if depth == 0 => last_idx = Some(i),
290 _ => {}
291 }
292 }
293 last_idx
294}
295
296fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
297 let trimmed = s.trim();
298
299 let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
301 let (rp, sp) = trimmed.split_at(idx);
302 (rp.trim(), Some(sp[1..].trim()))
303 } else {
304 (trimmed, None)
305 };
306
307 let step_expr = if let Some(step_str) = step_part {
308 let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
309 syn::Error::new(
310 proc_macro2::Span::call_site(),
311 format!("invalid step expression: {step_str}"),
312 )
313 })?;
314 quote! { #step_tokens }
315 } else {
316 quote! { 1isize }
317 };
318
319 if range_part == ".." {
321 return Ok(quote! {
323 ferray_core::dtype::SliceInfoElem::Slice {
324 start: 0,
325 end: ::core::option::Option::None,
326 step: #step_expr,
327 }
328 });
329 }
330
331 if let Some(rest) = range_part.strip_prefix("..") {
332 let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
334 syn::Error::new(
335 proc_macro2::Span::call_site(),
336 format!("invalid end expression: {rest}"),
337 )
338 })?;
339 return Ok(quote! {
340 ferray_core::dtype::SliceInfoElem::Slice {
341 start: 0,
342 end: ::core::option::Option::Some(#end_tokens),
343 step: #step_expr,
344 }
345 });
346 }
347
348 if let Some(idx) = range_part.find("..") {
349 let start_str = range_part[..idx].trim();
350 let end_str = range_part[idx + 2..].trim();
351
352 let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
353 syn::Error::new(
354 proc_macro2::Span::call_site(),
355 format!("invalid start expression: {start_str}"),
356 )
357 })?;
358
359 if end_str.is_empty() {
360 return Ok(quote! {
362 ferray_core::dtype::SliceInfoElem::Slice {
363 start: #start_tokens,
364 end: ::core::option::Option::None,
365 step: #step_expr,
366 }
367 });
368 }
369
370 let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
371 syn::Error::new(
372 proc_macro2::Span::call_site(),
373 format!("invalid end expression: {end_str}"),
374 )
375 })?;
376
377 return Ok(quote! {
378 ferray_core::dtype::SliceInfoElem::Slice {
379 start: #start_tokens,
380 end: ::core::option::Option::Some(#end_tokens),
381 step: #step_expr,
382 }
383 });
384 }
385
386 if step_part.is_some() {
388 return Err(syn::Error::new(
389 proc_macro2::Span::call_site(),
390 format!("step ';' is not valid for integer indices: {trimmed}"),
391 ));
392 }
393
394 let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
395 syn::Error::new(
396 proc_macro2::Span::call_site(),
397 format!("invalid index expression: {range_part}"),
398 )
399 })?;
400
401 Ok(quote! {
402 ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
403 })
404}
405
406#[proc_macro]
422pub fn promoted_type(input: TokenStream) -> TokenStream {
423 let input2: proc_macro2::TokenStream = input.into();
424 match impl_promoted_type(input2) {
425 Ok(ts) => ts.into(),
426 Err(e) => e.to_compile_error().into(),
427 }
428}
429
430fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
431 let input_str = input.to_string();
432 let parts: Vec<&str> = input_str.split(',').map(str::trim).collect();
433
434 if parts.len() != 2 {
435 return Err(syn::Error::new(
436 proc_macro2::Span::call_site(),
437 "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
438 ));
439 }
440
441 let t1 = normalize_type(parts[0]);
442 let t2 = normalize_type(parts[1]);
443
444 let result = promote_types_static(&t1, &t2).ok_or_else(|| {
445 syn::Error::new(
446 proc_macro2::Span::call_site(),
447 format!("cannot promote types: {t1} and {t2}"),
448 )
449 })?;
450
451 let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
452 syn::Error::new(
453 proc_macro2::Span::call_site(),
454 format!("internal error: could not parse result type: {result}"),
455 )
456 })?;
457
458 Ok(result_tokens)
459}
460
/// Trims surrounding whitespace and drops every interior space character, so
/// `num_complex :: Complex < f32 >` becomes `num_complex::Complex<f32>`.
/// Interior tabs or newlines are left as they are.
fn normalize_type(s: &str) -> String {
    s.trim().chars().filter(|&c| c != ' ').collect()
}
465
466fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
470 let ra = type_rank(a)?;
484 let rb = type_rank(b)?;
485
486 match promote_ranks(ra, rb) {
490 "" => None,
491 other => Some(other),
492 }
493}
494
/// Category of a recognised element type in the promotion rules.
#[derive(Copy, Clone, Eq, PartialEq)]
enum TypeKind {
    /// `bool`.
    Bool,
    /// `u8` through `u128`.
    Unsigned,
    /// `i8` through `i128`.
    Signed,
    /// `f16`, `bf16`, `f32` and `f64`.
    Float,
    /// `Complex<f32>` and `Complex<f64>`.
    Complex,
}
503
504#[derive(Clone, Copy)]
505struct TypeRank {
506 kind: TypeKind,
507 bits: u32,
509}
510
511fn type_rank(s: &str) -> Option<TypeRank> {
512 let result = match s {
513 "bool" => TypeRank {
514 kind: TypeKind::Bool,
515 bits: 1,
516 },
517 "u8" => TypeRank {
518 kind: TypeKind::Unsigned,
519 bits: 8,
520 },
521 "u16" => TypeRank {
522 kind: TypeKind::Unsigned,
523 bits: 16,
524 },
525 "u32" => TypeRank {
526 kind: TypeKind::Unsigned,
527 bits: 32,
528 },
529 "u64" => TypeRank {
530 kind: TypeKind::Unsigned,
531 bits: 64,
532 },
533 "u128" => TypeRank {
534 kind: TypeKind::Unsigned,
535 bits: 128,
536 },
537 "i8" => TypeRank {
538 kind: TypeKind::Signed,
539 bits: 8,
540 },
541 "i16" => TypeRank {
542 kind: TypeKind::Signed,
543 bits: 16,
544 },
545 "i32" => TypeRank {
546 kind: TypeKind::Signed,
547 bits: 32,
548 },
549 "i64" => TypeRank {
550 kind: TypeKind::Signed,
551 bits: 64,
552 },
553 "i128" => TypeRank {
554 kind: TypeKind::Signed,
555 bits: 128,
556 },
557 "f32" => TypeRank {
558 kind: TypeKind::Float,
559 bits: 32,
560 },
561 "f64" => TypeRank {
562 kind: TypeKind::Float,
563 bits: 64,
564 },
565 "Complex<f32>" | "num_complex::Complex<f32>" => TypeRank {
566 kind: TypeKind::Complex,
567 bits: 32,
568 },
569 "Complex<f64>" | "num_complex::Complex<f64>" => TypeRank {
570 kind: TypeKind::Complex,
571 bits: 64,
572 },
573 "f16" | "half::f16" => TypeRank {
574 kind: TypeKind::Float,
575 bits: 16,
576 },
577 "bf16" | "half::bf16" => TypeRank {
578 kind: TypeKind::Float,
579 bits: 16,
580 },
581 _ => return None,
582 };
583 Some(result)
584}
585
586fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
587 use TypeKind::{Bool, Complex, Float, Signed, Unsigned};
588
589 if a.kind == b.kind && a.bits == b.bits {
591 return rank_to_type(a);
592 }
593
594 if a.kind == Bool {
596 return rank_to_type(b);
597 }
598 if b.kind == Bool {
599 return rank_to_type(a);
600 }
601
602 if a.kind == Complex || b.kind == Complex {
604 let float_bits_a = to_float_bits(a);
605 let float_bits_b = to_float_bits(b);
606 let bits = float_bits_a.max(float_bits_b);
607 return if bits <= 32 {
608 "num_complex::Complex<f32>"
609 } else {
610 "num_complex::Complex<f64>"
611 };
612 }
613
614 if a.kind == Float || b.kind == Float {
616 let float_bits_a = to_float_bits(a);
617 let float_bits_b = to_float_bits(b);
618 let bits = float_bits_a.max(float_bits_b);
619 return if bits <= 32 { "f32" } else { "f64" };
620 }
621
622 match (a.kind, b.kind) {
624 (Unsigned, Unsigned) => {
625 let bits = a.bits.max(b.bits);
626 uint_type(bits)
627 }
628 (Signed, Signed) => {
629 let bits = a.bits.max(b.bits);
630 int_type(bits)
631 }
632 (Unsigned, Signed) | (Signed, Unsigned) => {
633 let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
634 if u.bits < s.bits {
637 int_type(s.bits)
639 } else {
640 let needed = u.bits.max(s.bits) * 2;
642 if needed <= 128 {
643 int_type(needed)
644 } else {
645 ""
652 }
653 }
654 }
655 _ => "f64", }
657}
658
659const fn to_float_bits(r: TypeRank) -> u32 {
661 match r.kind {
662 TypeKind::Bool => 32,
663 TypeKind::Unsigned | TypeKind::Signed => {
664 if r.bits <= 16 { 32 } else { 64 }
667 }
668 TypeKind::Float => r.bits,
669 TypeKind::Complex => r.bits,
670 }
671}
672
/// Name of the unsigned integer type with `bits` bits.
///
/// Widths other than 8/16/32/64/128 fall back to `"u64"`.
const fn uint_type(bits: u32) -> &'static str {
    match bits {
        8 => "u8",
        16 => "u16",
        32 => "u32",
        128 => "u128",
        // 64, plus any unexpected width.
        _ => "u64",
    }
}
683
/// Name of the signed integer type with `bits` bits.
///
/// Widths other than 8/16/32/64/128 fall back to `"i64"`.
const fn int_type(bits: u32) -> &'static str {
    match bits {
        8 => "i8",
        16 => "i16",
        32 => "i32",
        128 => "i128",
        // 64, plus any unexpected width.
        _ => "i64",
    }
}
694
695const fn rank_to_type(r: TypeRank) -> &'static str {
696 match r.kind {
697 TypeKind::Bool => "bool",
698 TypeKind::Unsigned => uint_type(r.bits),
699 TypeKind::Signed => int_type(r.bits),
700 TypeKind::Float => {
701 if r.bits <= 16 {
702 "half::f16"
703 } else if r.bits <= 32 {
704 "f32"
705 } else {
706 "f64"
707 }
708 }
709 TypeKind::Complex => {
710 if r.bits <= 32 {
711 "num_complex::Complex<f32>"
712 } else {
713 "num_complex::Complex<f64>"
714 }
715 }
716 }
717}