Skip to main content

ferray_core_macros/
lib.rs

1// ferray-core-macros: Procedural macros for ferray-core
2//
3// Implements:
4// - #[derive(FerrayRecord)] — generates FerrayRecord trait impl for #[repr(C)] structs
5// - s![] — NumPy-style slice indexing macro
6// - promoted_type!() — compile-time type promotion macro
7
8// Macro implementations branch on parsed AST shapes; identical match arms
9// reflect "different syntactic forms produce the same expansion" and are
10// kept distinct on purpose for future divergence. `needless_pass_by_value`
11// is a syn-API ergonomics tradeoff (the parser owns each TokenStream once).
12#![allow(
13    clippy::match_same_arms,
14    clippy::needless_pass_by_value,
15    clippy::option_if_let_else
16)]
17
18extern crate proc_macro;
19
20use proc_macro::TokenStream;
21use quote::quote;
22use syn::{Data, DeriveInput, Fields, parse_macro_input};
23
24// ---------------------------------------------------------------------------
25// #[derive(FerrayRecord)]
26// ---------------------------------------------------------------------------
27
28/// Derive macro that generates an `unsafe impl FerrayRecord` for a `#[repr(C)]` struct.
29///
30/// # Requirements
31/// - The struct must have `#[repr(C)]`.
32/// - All fields must implement `ferray_core::dtype::Element`.
33///
34/// # Generated code
35/// - `field_descriptors()` returns a static slice of `FieldDescriptor` with correct
36///   name, dtype, offset, and size for each field.
37/// - `record_size()` returns `std::mem::size_of::<Self>()`.
38#[proc_macro_derive(FerrayRecord)]
39pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(input as DeriveInput);
41    match impl_ferray_record(&input) {
42        Ok(ts) => ts.into(),
43        Err(e) => e.to_compile_error().into(),
44    }
45}
46
/// Builds the `unsafe impl FerrayRecord` token stream for `input`.
///
/// Validates that the input is a `#[repr(C)]` struct with named fields,
/// then emits one `FieldDescriptor` expression per field plus the cached
/// `field_descriptors()` and `record_size()` methods.
///
/// # Errors
/// Returns a `syn::Error` (turned into a compile error by the caller)
/// when `#[repr(C)]` is missing, when the item is not a struct, or when
/// the struct's fields are not named.
fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let name = &input.ident;

    // Check for #[repr(C)]: scan every #[repr(...)] attribute for a `C` token.
    let has_repr_c = input.attrs.iter().any(|attr| {
        if !attr.path().is_ident("repr") {
            return false;
        }
        let mut found = false;
        // Parse errors inside repr(...) are ignored on purpose: an
        // unparsable repr simply doesn't count as repr(C).
        let _ = attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("C") {
                found = true;
            }
            Ok(())
        });
        found
    });

    if !has_repr_c {
        return Err(syn::Error::new_spanned(
            &input.ident,
            "FerrayRecord requires #[repr(C)] on the struct",
        ));
    }

    // Only works on structs with named fields
    let fields = match &input.data {
        Data::Struct(data_struct) => match &data_struct.fields {
            Fields::Named(named) => &named.named,
            _ => {
                return Err(syn::Error::new_spanned(
                    &input.ident,
                    "FerrayRecord only supports structs with named fields",
                ));
            }
        },
        _ => {
            return Err(syn::Error::new_spanned(
                &input.ident,
                "FerrayRecord can only be derived for structs",
            ));
        }
    };

    let field_count = fields.len();
    let mut field_descriptors = Vec::with_capacity(field_count);

    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    // Type-with-generics name used in field-offset / size-of expressions.
    // `#name` alone is the bare ident (`Mixed`), which won't compile in
    // `offset_of!` or `size_of::<>` for generic structs (#326). We
    // reconstruct `Mixed::<A, B>` by combining `#name` with the
    // turbofished form of `#ty_generics`.
    // (split_for_impl is called a second time because `as_turbofish`
    // consumes the `TypeGenerics` value, and `ty_generics` above is
    // still needed for the impl header.)
    let ty_generics_turbofish = input.generics.split_for_impl().1.as_turbofish();
    let name_with_generics = quote! { #name #ty_generics_turbofish };

    // One FieldDescriptor expression per named field; the dtype is
    // resolved at the use site via the Element trait so only types
    // implementing Element compile.
    for field in fields {
        let field_name = field.ident.as_ref().unwrap();
        let field_name_str = field_name.to_string();
        let field_ty = &field.ty;

        field_descriptors.push(quote! {
            ferray_core::record::FieldDescriptor {
                name: #field_name_str,
                dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
                offset: std::mem::offset_of!(#name_with_generics, #field_name),
                size: std::mem::size_of::<#field_ty>(),
            }
        });
    }

    let expanded = quote! {
        unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
            fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
                // #326: a plain `static FIELDS: LazyLock<Vec<…>>` here
                // can't reference the outer generic parameters
                // (`error[E0401]: use of generic parameter from outer
                // item` — statics are independent items). Use a
                // TypeId-keyed cache that lazily inserts the descriptor
                // vector on first call per monomorphization, then
                // returns the leaked slice. For non-generic structs
                // this lookups the same entry every call; for generic
                // structs each instantiation gets its own slot keyed
                // by `TypeId::of::<Self>()`.
                //
                // Issue #323 asked whether this could be a `const`
                // static array. `offset_of!` and `size_of::<T>()` are
                // both const, but `<T as Element>::dtype()` is a trait
                // method call and trait methods are not `const` on
                // stable Rust yet — so the runtime cache is
                // unavoidable. Re-evaluate once `const_trait_impl`
                // stabilises.
                use std::any::TypeId;
                use std::collections::HashMap;
                use std::sync::{OnceLock, Mutex};
                static CACHE: OnceLock<
                    Mutex<HashMap<TypeId, &'static [ferray_core::record::FieldDescriptor]>>,
                > = OnceLock::new();
                let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new()));
                let mut guard = cache.lock().unwrap();
                *guard
                    .entry(TypeId::of::<#name_with_generics>())
                    .or_insert_with(|| {
                        let v: Vec<ferray_core::record::FieldDescriptor> = vec![
                            #(#field_descriptors),*
                        ];
                        Box::leak(v.into_boxed_slice())
                    })
            }

            fn record_size() -> usize {
                std::mem::size_of::<#name_with_generics>()
            }
        }
    };

    Ok(expanded)
}
166
167// ---------------------------------------------------------------------------
168// s![] macro — NumPy-style slice indexing
169// ---------------------------------------------------------------------------
170
171/// NumPy-style slice indexing macro.
172///
173/// Produces a `Vec<ferray_core::dtype::SliceInfoElem>` that can be passed
174/// to array slicing methods.
175///
176/// # Syntax
177/// - `s![0..3, 2]` — rows 0..3, column 2
178/// - `s![.., 0..;2]` — all rows, every-other column starting from 0
179/// - `s![1..5;2, ..]` — rows 1..5 step 2, all columns
180/// - `s![3]` — single integer index
181/// - `s![..]` — all elements along this axis
182/// - `s![2..]` — from index 2 to end
183/// - `s![..5]` — from start to index 5
184/// - `s![1..5]` — from index 1 to 5
185/// - `s![1..5;2]` — from index 1 to 5, step 2
186///
187/// Each component in the comma-separated list becomes one `SliceInfoElem`.
188#[proc_macro]
189pub fn s(input: TokenStream) -> TokenStream {
190    let input2: proc_macro2::TokenStream = input.into();
191    let expanded = impl_s_macro(input2);
192    match expanded {
193        Ok(ts) => ts.into(),
194        Err(e) => e.to_compile_error().into(),
195    }
196}
197
198fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
199    // We parse the input as a sequence of comma-separated slice expressions.
200    // Each expression can be:
201    //   - An integer literal or expression: `2` -> Index(2)
202    //   - A full range `..` -> Slice { start: 0, end: None, step: 1 }
203    //   - A range `a..b` -> Slice { start: a, end: Some(b), step: 1 }
204    //   - A range from `a..` -> Slice { start: a, end: None, step: 1 }
205    //   - A range to `..b` -> Slice { start: 0, end: Some(b), step: 1 }
206    //   - Any of the above with `;step` suffix
207    //
208    // We'll output code that constructs a Vec<SliceInfoElem>.
209    //
210    // Since proc macros can't easily parse arbitrary Rust expressions with range syntax
211    // mixed with custom `;step` syntax, we'll use a simpler token-based approach.
212    //
213    // FRAGILITY NOTE (#322): the implementation routes through
214    // `TokenStream::to_string()` and string-parses the result. Rust's
215    // pretty-printer formatting between tokens (whitespace placement
216    // around `..`, `;`, parens) is technically allowed to drift across
217    // compiler versions. The integration tests in
218    // ferray-core/tests/macro_tests.rs (s_macro_*) pin a representative
219    // set of inputs — variable bounds, arithmetic expressions, function
220    // calls, typed literals — to catch any drift. A future cleanup
221    // could walk the token tree directly via `into_iter()` and avoid
222    // `to_string()` entirely; that requires writing a small recursive-
223    // descent parser for `expr (.. expr)? (; expr)?` over `TokenTree`.
224
225    let input_str = input.to_string();
226
227    // Handle empty input
228    if input_str.trim().is_empty() {
229        return Ok(quote! {
230            ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
231        });
232    }
233
234    // Split by commas (respecting parentheses/brackets nesting)
235    let components = split_top_level_commas(&input_str);
236    let mut elems = Vec::new();
237
238    for component in &components {
239        let trimmed = component.trim();
240        if trimmed.is_empty() {
241            continue;
242        }
243        elems.push(parse_slice_component(trimmed)?);
244    }
245
246    Ok(quote! {
247        vec![#(#elems),*]
248    })
249}
250
/// Split `s` on commas that are not nested inside `()`, `[]`, `{}`, or a
/// double-quoted string literal.
///
/// String-literal awareness matters: without it, a literal containing an
/// unbalanced bracket (e.g. `g(")")`) corrupts the depth counter and the
/// following top-level comma is never seen. Backslash escapes inside
/// strings (`"\""`) are honored so an escaped quote does not end the
/// literal. A trailing empty component (input ending in `,`) is dropped,
/// matching the previous behavior.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut result = Vec::new();
    let mut current = String::new();
    let mut depth = 0i32;
    let mut in_string = false;
    let mut escaped = false;

    for ch in s.chars() {
        if in_string {
            // Inside a string literal: copy verbatim, only tracking the
            // escape state and the closing quote.
            current.push(ch);
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => {
                in_string = true;
                current.push(ch);
            }
            '(' | '[' | '{' => {
                depth += 1;
                current.push(ch);
            }
            ')' | ']' | '}' => {
                depth -= 1;
                current.push(ch);
            }
            ',' if depth == 0 => {
                result.push(current.clone());
                current.clear();
            }
            _ => {
                current.push(ch);
            }
        }
    }
    if !current.is_empty() {
        result.push(current);
    }
    result
}
280
/// Find the last top-level semicolon — one not nested inside brackets,
/// braces, parens, or a double-quoted string literal.
///
/// String awareness prevents a `;` inside a literal (e.g. `f(";")`) from
/// being mistaken for a step separator, and keeps the bracket depth
/// counter from being corrupted by brackets inside strings. Backslash
/// escapes (`"\""`) are honored.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    let mut depth = 0i32;
    let mut in_string = false;
    let mut escaped = false;
    let mut last_idx = None;
    for (i, ch) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth -= 1,
            ';' if depth == 0 => last_idx = Some(i),
            _ => {}
        }
    }
    last_idx
}
295
296fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
297    let trimmed = s.trim();
298
299    // Check for step suffix: `expr;step` (depth-aware to handle block expressions)
300    let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
301        let (rp, sp) = trimmed.split_at(idx);
302        (rp.trim(), Some(sp[1..].trim()))
303    } else {
304        (trimmed, None)
305    };
306
307    let step_expr = if let Some(step_str) = step_part {
308        let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
309            syn::Error::new(
310                proc_macro2::Span::call_site(),
311                format!("invalid step expression: {step_str}"),
312            )
313        })?;
314        quote! { #step_tokens }
315    } else {
316        quote! { 1isize }
317    };
318
319    // Now parse range_part
320    if range_part == ".." {
321        // Full range: all elements
322        return Ok(quote! {
323            ferray_core::dtype::SliceInfoElem::Slice {
324                start: 0,
325                end: ::core::option::Option::None,
326                step: #step_expr,
327            }
328        });
329    }
330
331    if let Some(rest) = range_part.strip_prefix("..") {
332        // RangeTo: ..end
333        let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
334            syn::Error::new(
335                proc_macro2::Span::call_site(),
336                format!("invalid end expression: {rest}"),
337            )
338        })?;
339        return Ok(quote! {
340            ferray_core::dtype::SliceInfoElem::Slice {
341                start: 0,
342                end: ::core::option::Option::Some(#end_tokens),
343                step: #step_expr,
344            }
345        });
346    }
347
348    if let Some(idx) = range_part.find("..") {
349        let start_str = range_part[..idx].trim();
350        let end_str = range_part[idx + 2..].trim();
351
352        let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
353            syn::Error::new(
354                proc_macro2::Span::call_site(),
355                format!("invalid start expression: {start_str}"),
356            )
357        })?;
358
359        if end_str.is_empty() {
360            // RangeFrom: start..
361            return Ok(quote! {
362                ferray_core::dtype::SliceInfoElem::Slice {
363                    start: #start_tokens,
364                    end: ::core::option::Option::None,
365                    step: #step_expr,
366                }
367            });
368        }
369
370        let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
371            syn::Error::new(
372                proc_macro2::Span::call_site(),
373                format!("invalid end expression: {end_str}"),
374            )
375        })?;
376
377        return Ok(quote! {
378            ferray_core::dtype::SliceInfoElem::Slice {
379                start: #start_tokens,
380                end: ::core::option::Option::Some(#end_tokens),
381                step: #step_expr,
382            }
383        });
384    }
385
386    // No `..` found — this is a single index (integer expression)
387    if step_part.is_some() {
388        return Err(syn::Error::new(
389            proc_macro2::Span::call_site(),
390            format!("step ';' is not valid for integer indices: {trimmed}"),
391        ));
392    }
393
394    let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
395        syn::Error::new(
396            proc_macro2::Span::call_site(),
397            format!("invalid index expression: {range_part}"),
398        )
399    })?;
400
401    Ok(quote! {
402        ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
403    })
404}
405
406// ---------------------------------------------------------------------------
407// promoted_type!() — compile-time type promotion
408// ---------------------------------------------------------------------------
409
410/// Compile-time type promotion macro.
411///
412/// Given two numeric types, resolves to the smallest type that can represent
413/// both without precision loss, following `NumPy`'s promotion rules.
414///
415/// # Examples
416/// ```ignore
417/// type R = promoted_type!(f32, f64); // R = f64
418/// type R = promoted_type!(i32, f32); // R = f64
419/// type R = promoted_type!(u8, i8);   // R = i16
420/// ```
421#[proc_macro]
422pub fn promoted_type(input: TokenStream) -> TokenStream {
423    let input2: proc_macro2::TokenStream = input.into();
424    match impl_promoted_type(input2) {
425        Ok(ts) => ts.into(),
426        Err(e) => e.to_compile_error().into(),
427    }
428}
429
430fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
431    let input_str = input.to_string();
432    let parts: Vec<&str> = input_str.split(',').map(str::trim).collect();
433
434    if parts.len() != 2 {
435        return Err(syn::Error::new(
436            proc_macro2::Span::call_site(),
437            "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
438        ));
439    }
440
441    let t1 = normalize_type(parts[0]);
442    let t2 = normalize_type(parts[1]);
443
444    let result = promote_types_static(&t1, &t2).ok_or_else(|| {
445        syn::Error::new(
446            proc_macro2::Span::call_site(),
447            format!("cannot promote types: {t1} and {t2}"),
448        )
449    })?;
450
451    let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
452        syn::Error::new(
453            proc_macro2::Span::call_site(),
454            format!("internal error: could not parse result type: {result}"),
455        )
456    })?;
457
458    Ok(result_tokens)
459}
460
/// Canonicalize a type string for table lookup: trim the ends and drop
/// every interior ASCII space, so `Complex < f32 >` becomes
/// `Complex<f32>` (other whitespace is left untouched, matching the
/// spacing produced by `TokenStream::to_string()`).
fn normalize_type(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.trim().chars() {
        if ch != ' ' {
            out.push(ch);
        }
    }
    out
}
465
466/// Static type promotion following `NumPy` rules.
467///
468/// Returns the promoted type as a string, or None if unknown.
469fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
470    // Assign a numeric "kind + rank" to each type, then pick the larger.
471    //
472    // NumPy promotion hierarchy (simplified):
473    //   bool < u8 < u16 < u32 < u64 < u128
474    //   bool < i8 < i16 < i32 < i64 < i128
475    //   f32 < f64
476    //   Complex<f32> < Complex<f64>
477    //
478    // Cross-kind rules:
479    //   unsigned + signed -> next-size signed (e.g. u8 + i8 -> i16)
480    //   any int + float -> float (ensure enough precision)
481    //   any real + complex -> complex with appropriate float size
482
483    let ra = type_rank(a)?;
484    let rb = type_rank(b)?;
485
486    // `promote_ranks` uses an empty string to signal "no lossless
487    // common type exists" (currently the u128 + signed-int case); map
488    // that to `None` so the outer proc-macro emits a compile error.
489    match promote_ranks(ra, rb) {
490        "" => None,
491        other => Some(other),
492    }
493}
494
/// Broad numeric family a scalar type belongs to.
///
/// Cross-family promotion rules live in `promote_ranks`; this enum only
/// tags which family a type is in.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TypeKind {
    /// `bool` — promotes to whatever the other operand is.
    Bool,
    /// `u8` through `u128`.
    Unsigned,
    /// `i8` through `i128`.
    Signed,
    /// `f16`/`bf16`/`f32`/`f64` (see `type_rank` for the bf16 mapping).
    Float,
    /// `Complex<f32>` / `Complex<f64>`.
    Complex,
}
503
/// A type's position in the promotion lattice: numeric family plus bit
/// width. Two types with the same `TypeRank` are treated as the same
/// type by the promotion table.
#[derive(Clone, Copy)]
struct TypeRank {
    /// Numeric family (bool / unsigned / signed / float / complex).
    kind: TypeKind,
    /// Bit width within the kind (e.g., 8 for u8, 32 for f32, etc.)
    bits: u32,
}
510
511fn type_rank(s: &str) -> Option<TypeRank> {
512    let result = match s {
513        "bool" => TypeRank {
514            kind: TypeKind::Bool,
515            bits: 1,
516        },
517        "u8" => TypeRank {
518            kind: TypeKind::Unsigned,
519            bits: 8,
520        },
521        "u16" => TypeRank {
522            kind: TypeKind::Unsigned,
523            bits: 16,
524        },
525        "u32" => TypeRank {
526            kind: TypeKind::Unsigned,
527            bits: 32,
528        },
529        "u64" => TypeRank {
530            kind: TypeKind::Unsigned,
531            bits: 64,
532        },
533        "u128" => TypeRank {
534            kind: TypeKind::Unsigned,
535            bits: 128,
536        },
537        "i8" => TypeRank {
538            kind: TypeKind::Signed,
539            bits: 8,
540        },
541        "i16" => TypeRank {
542            kind: TypeKind::Signed,
543            bits: 16,
544        },
545        "i32" => TypeRank {
546            kind: TypeKind::Signed,
547            bits: 32,
548        },
549        "i64" => TypeRank {
550            kind: TypeKind::Signed,
551            bits: 64,
552        },
553        "i128" => TypeRank {
554            kind: TypeKind::Signed,
555            bits: 128,
556        },
557        "f32" => TypeRank {
558            kind: TypeKind::Float,
559            bits: 32,
560        },
561        "f64" => TypeRank {
562            kind: TypeKind::Float,
563            bits: 64,
564        },
565        "Complex<f32>" | "num_complex::Complex<f32>" => TypeRank {
566            kind: TypeKind::Complex,
567            bits: 32,
568        },
569        "Complex<f64>" | "num_complex::Complex<f64>" => TypeRank {
570            kind: TypeKind::Complex,
571            bits: 64,
572        },
573        "f16" | "half::f16" => TypeRank {
574            kind: TypeKind::Float,
575            bits: 16,
576        },
577        "bf16" | "half::bf16" => TypeRank {
578            kind: TypeKind::Float,
579            bits: 16,
580        },
581        _ => return None,
582    };
583    Some(result)
584}
585
/// Combine two `TypeRank`s into the name of the promoted Rust type.
///
/// Branch order is significant: exact match first, then bool (which
/// promotes to the other operand), then complex, then float, and only
/// then the pure-integer rules. Returns the empty string as a sentinel
/// for "no lossless common type" (the u128 + signed case);
/// `promote_types_static` maps that sentinel to `None`.
fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
    use TypeKind::{Bool, Complex, Float, Signed, Unsigned};

    // Same type
    if a.kind == b.kind && a.bits == b.bits {
        return rank_to_type(a);
    }

    // Handle Bool: bool promotes to anything
    if a.kind == Bool {
        return rank_to_type(b);
    }
    if b.kind == Bool {
        return rank_to_type(a);
    }

    // Complex + anything -> Complex with max float precision
    if a.kind == Complex || b.kind == Complex {
        let float_bits_a = to_float_bits(a);
        let float_bits_b = to_float_bits(b);
        let bits = float_bits_a.max(float_bits_b);
        return if bits <= 32 {
            "num_complex::Complex<f32>"
        } else {
            "num_complex::Complex<f64>"
        };
    }

    // Float + anything -> Float with enough precision
    if a.kind == Float || b.kind == Float {
        let float_bits_a = to_float_bits(a);
        let float_bits_b = to_float_bits(b);
        let bits = float_bits_a.max(float_bits_b);
        return if bits <= 32 { "f32" } else { "f64" };
    }

    // Now both are integer types (Unsigned or Signed)
    match (a.kind, b.kind) {
        (Unsigned, Unsigned) => {
            let bits = a.bits.max(b.bits);
            uint_type(bits)
        }
        (Signed, Signed) => {
            let bits = a.bits.max(b.bits);
            int_type(bits)
        }
        (Unsigned, Signed) | (Signed, Unsigned) => {
            let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
            // unsigned + signed: need a signed type that holds both ranges
            // u8 + i8 -> i16, u16 + i16 -> i32, etc.
            if u.bits < s.bits {
                // Signed type is strictly larger, it can hold the unsigned range
                int_type(s.bits)
            } else {
                // Need the next larger signed type
                let needed = u.bits.max(s.bits) * 2;
                if needed <= 128 {
                    int_type(needed)
                } else {
                    // u128 + any signed int: no lossless common type
                    // on stable Rust (see ferray-core/src/dtype/promotion.rs
                    // and issue #375). Return the empty string as a
                    // sentinel; `promote_types_static` maps that to
                    // `None`, which the outer proc-macro converts
                    // into a compile error with a clear message.
                    ""
                }
            }
        }
        // Unreachable in practice (Bool/Float/Complex were handled
        // above), but the compiler cannot prove it, so keep a fallback.
        _ => "f64", // fallback
    }
}
658
659/// Convert any type rank to the float bit width it requires.
660const fn to_float_bits(r: TypeRank) -> u32 {
661    match r.kind {
662        TypeKind::Bool => 32,
663        TypeKind::Unsigned | TypeKind::Signed => {
664            // Integers up to 24-bit mantissa fit in f32 (i.e., i8, i16, u8, u16).
665            // Larger integers need f64 (53-bit mantissa).
666            if r.bits <= 16 { 32 } else { 64 }
667        }
668        TypeKind::Float => r.bits,
669        TypeKind::Complex => r.bits,
670    }
671}
672
/// Name of the unsigned integer type with the given bit width; widths
/// outside the supported set fall back to `u64`.
const fn uint_type(bits: u32) -> &'static str {
    if bits == 8 {
        "u8"
    } else if bits == 16 {
        "u16"
    } else if bits == 32 {
        "u32"
    } else if bits == 128 {
        "u128"
    } else {
        // 64 and any unexpected width both map to u64, as before.
        "u64"
    }
}
683
/// Name of the signed integer type with the given bit width; widths
/// outside the supported set fall back to `i64`.
const fn int_type(bits: u32) -> &'static str {
    if bits == 8 {
        "i8"
    } else if bits == 16 {
        "i16"
    } else if bits == 32 {
        "i32"
    } else if bits == 128 {
        "i128"
    } else {
        // 64 and any unexpected width both map to i64, as before.
        "i64"
    }
}
694
695const fn rank_to_type(r: TypeRank) -> &'static str {
696    match r.kind {
697        TypeKind::Bool => "bool",
698        TypeKind::Unsigned => uint_type(r.bits),
699        TypeKind::Signed => int_type(r.bits),
700        TypeKind::Float => {
701            if r.bits <= 16 {
702                "half::f16"
703            } else if r.bits <= 32 {
704                "f32"
705            } else {
706                "f64"
707            }
708        }
709        TypeKind::Complex => {
710            if r.bits <= 32 {
711                "num_complex::Complex<f32>"
712            } else {
713                "num_complex::Complex<f64>"
714            }
715        }
716    }
717}