Skip to main content

ferray_core_macros/
lib.rs

1// ferray-core-macros: Procedural macros for ferray-core
2//
3// Implements:
4// - #[derive(FerrayRecord)] — generates FerrayRecord trait impl for #[repr(C)] structs
5// - s![] — NumPy-style slice indexing macro
6// - promoted_type!() — compile-time type promotion macro
7
8// Macro implementations branch on parsed AST shapes; identical match arms
9// reflect "different syntactic forms produce the same expansion" and are
10// kept distinct on purpose for future divergence. `needless_pass_by_value`
11// is a syn-API ergonomics tradeoff (the parser owns each TokenStream once).
12#![allow(
13    clippy::match_same_arms,
14    clippy::needless_pass_by_value,
15    clippy::option_if_let_else
16)]
17
18extern crate proc_macro;
19
20use proc_macro::TokenStream;
21use quote::quote;
22use syn::{Data, DeriveInput, Fields, parse_macro_input};
23
24// ---------------------------------------------------------------------------
25// #[derive(FerrayRecord)]
26// ---------------------------------------------------------------------------
27
28/// Derive macro that generates an `unsafe impl FerrayRecord` for a `#[repr(C)]` struct.
29///
30/// # Requirements
31/// - The struct must have `#[repr(C)]`.
32/// - All fields must implement `ferray_core::dtype::Element`.
33///
34/// # Generated code
35/// - `field_descriptors()` returns a static slice of `FieldDescriptor` with correct
36///   name, dtype, offset, and size for each field.
37/// - `record_size()` returns `std::mem::size_of::<Self>()`.
38#[proc_macro_derive(FerrayRecord)]
39pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(input as DeriveInput);
41    match impl_ferray_record(&input) {
42        Ok(ts) => ts.into(),
43        Err(e) => e.to_compile_error().into(),
44    }
45}
46
47fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
48    let name = &input.ident;
49
50    // Check for #[repr(C)]
51    let has_repr_c = input.attrs.iter().any(|attr| {
52        if !attr.path().is_ident("repr") {
53            return false;
54        }
55        let mut found = false;
56        let _ = attr.parse_nested_meta(|meta| {
57            if meta.path.is_ident("C") {
58                found = true;
59            }
60            Ok(())
61        });
62        found
63    });
64
65    if !has_repr_c {
66        return Err(syn::Error::new_spanned(
67            &input.ident,
68            "FerrayRecord requires #[repr(C)] on the struct",
69        ));
70    }
71
72    // Only works on structs with named fields
73    let fields = match &input.data {
74        Data::Struct(data_struct) => match &data_struct.fields {
75            Fields::Named(named) => &named.named,
76            _ => {
77                return Err(syn::Error::new_spanned(
78                    &input.ident,
79                    "FerrayRecord only supports structs with named fields",
80                ));
81            }
82        },
83        _ => {
84            return Err(syn::Error::new_spanned(
85                &input.ident,
86                "FerrayRecord can only be derived for structs",
87            ));
88        }
89    };
90
91    let field_count = fields.len();
92    let mut field_descriptors = Vec::with_capacity(field_count);
93
94    for field in fields {
95        let field_name = field.ident.as_ref().unwrap();
96        let field_name_str = field_name.to_string();
97        let field_ty = &field.ty;
98
99        field_descriptors.push(quote! {
100            ferray_core::record::FieldDescriptor {
101                name: #field_name_str,
102                dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
103                offset: std::mem::offset_of!(#name, #field_name),
104                size: std::mem::size_of::<#field_ty>(),
105            }
106        });
107    }
108
109    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
110
111    let expanded = quote! {
112        unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
113            fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
114                // Issue #323 asked whether this could be a `const` static
115                // array. `offset_of!` and `size_of::<T>()` are both const,
116                // but `<T as Element>::dtype()` is a trait method call and
117                // trait methods are not `const` on stable Rust yet — so
118                // the LazyLock is unavoidable without breaking
119                // `Element::dtype()` for every user. Re-evaluate once
120                // `const_trait_impl` stabilises.
121                static FIELDS: std::sync::LazyLock<Vec<ferray_core::record::FieldDescriptor>> =
122                    std::sync::LazyLock::new(|| {
123                        vec![
124                            #(#field_descriptors),*
125                        ]
126                    });
127                &FIELDS
128            }
129
130            fn record_size() -> usize {
131                std::mem::size_of::<#name>()
132            }
133        }
134    };
135
136    Ok(expanded)
137}
138
139// ---------------------------------------------------------------------------
140// s![] macro — NumPy-style slice indexing
141// ---------------------------------------------------------------------------
142
143/// NumPy-style slice indexing macro.
144///
145/// Produces a `Vec<ferray_core::dtype::SliceInfoElem>` that can be passed
146/// to array slicing methods.
147///
148/// # Syntax
149/// - `s![0..3, 2]` — rows 0..3, column 2
150/// - `s![.., 0..;2]` — all rows, every-other column starting from 0
151/// - `s![1..5;2, ..]` — rows 1..5 step 2, all columns
152/// - `s![3]` — single integer index
153/// - `s![..]` — all elements along this axis
154/// - `s![2..]` — from index 2 to end
155/// - `s![..5]` — from start to index 5
156/// - `s![1..5]` — from index 1 to 5
157/// - `s![1..5;2]` — from index 1 to 5, step 2
158///
159/// Each component in the comma-separated list becomes one `SliceInfoElem`.
160#[proc_macro]
161pub fn s(input: TokenStream) -> TokenStream {
162    let input2: proc_macro2::TokenStream = input.into();
163    let expanded = impl_s_macro(input2);
164    match expanded {
165        Ok(ts) => ts.into(),
166        Err(e) => e.to_compile_error().into(),
167    }
168}
169
170fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
171    // We parse the input as a sequence of comma-separated slice expressions.
172    // Each expression can be:
173    //   - An integer literal or expression: `2` -> Index(2)
174    //   - A full range `..` -> Slice { start: 0, end: None, step: 1 }
175    //   - A range `a..b` -> Slice { start: a, end: Some(b), step: 1 }
176    //   - A range from `a..` -> Slice { start: a, end: None, step: 1 }
177    //   - A range to `..b` -> Slice { start: 0, end: Some(b), step: 1 }
178    //   - Any of the above with `;step` suffix
179    //
180    // We'll output code that constructs a Vec<SliceInfoElem>.
181    //
182    // Since proc macros can't easily parse arbitrary Rust expressions with range syntax
183    // mixed with custom `;step` syntax, we'll use a simpler token-based approach.
184
185    let input_str = input.to_string();
186
187    // Handle empty input
188    if input_str.trim().is_empty() {
189        return Ok(quote! {
190            ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
191        });
192    }
193
194    // Split by commas (respecting parentheses/brackets nesting)
195    let components = split_top_level_commas(&input_str);
196    let mut elems = Vec::new();
197
198    for component in &components {
199        let trimmed = component.trim();
200        if trimmed.is_empty() {
201            continue;
202        }
203        elems.push(parse_slice_component(trimmed)?);
204    }
205
206    Ok(quote! {
207        vec![#(#elems),*]
208    })
209}
210
/// Splits `s` on commas that are not nested inside `()`, `[]` or `{}`.
///
/// Components are returned verbatim (not trimmed). An empty component between
/// two adjacent commas is kept, but an empty trailing component — after a
/// final comma, or for empty input — is dropped.
///
/// The scan is purely textual: commas inside string/char literals or inside
/// `<...>` generic argument lists are not recognised as nested.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut current = String::new();
    let mut depth = 0i32;

    for ch in s.chars() {
        match ch {
            ',' if depth == 0 => {
                // Move the finished component out instead of cloning it.
                pieces.push(std::mem::take(&mut current));
                continue;
            }
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth -= 1,
            _ => {}
        }
        current.push(ch);
    }

    if !current.is_empty() {
        pieces.push(current);
    }
    pieces
}
240
/// Returns the byte offset of the last `;` in `s` that is not enclosed in
/// parentheses, brackets or braces, or `None` when there is no such `;`.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    // Single pass carrying (current nesting depth, last match so far).
    let (_, last) = s
        .char_indices()
        .fold((0i32, None), |(depth, last), (pos, c)| match c {
            '(' | '[' | '{' => (depth + 1, last),
            ')' | ']' | '}' => (depth - 1, last),
            ';' if depth == 0 => (depth, Some(pos)),
            _ => (depth, last),
        });
    last
}
255
256fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
257    let trimmed = s.trim();
258
259    // Check for step suffix: `expr;step` (depth-aware to handle block expressions)
260    let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
261        let (rp, sp) = trimmed.split_at(idx);
262        (rp.trim(), Some(sp[1..].trim()))
263    } else {
264        (trimmed, None)
265    };
266
267    let step_expr = if let Some(step_str) = step_part {
268        let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
269            syn::Error::new(
270                proc_macro2::Span::call_site(),
271                format!("invalid step expression: {step_str}"),
272            )
273        })?;
274        quote! { #step_tokens }
275    } else {
276        quote! { 1isize }
277    };
278
279    // Now parse range_part
280    if range_part == ".." {
281        // Full range: all elements
282        return Ok(quote! {
283            ferray_core::dtype::SliceInfoElem::Slice {
284                start: 0,
285                end: ::core::option::Option::None,
286                step: #step_expr,
287            }
288        });
289    }
290
291    if let Some(rest) = range_part.strip_prefix("..") {
292        // RangeTo: ..end
293        let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
294            syn::Error::new(
295                proc_macro2::Span::call_site(),
296                format!("invalid end expression: {rest}"),
297            )
298        })?;
299        return Ok(quote! {
300            ferray_core::dtype::SliceInfoElem::Slice {
301                start: 0,
302                end: ::core::option::Option::Some(#end_tokens),
303                step: #step_expr,
304            }
305        });
306    }
307
308    if let Some(idx) = range_part.find("..") {
309        let start_str = range_part[..idx].trim();
310        let end_str = range_part[idx + 2..].trim();
311
312        let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
313            syn::Error::new(
314                proc_macro2::Span::call_site(),
315                format!("invalid start expression: {start_str}"),
316            )
317        })?;
318
319        if end_str.is_empty() {
320            // RangeFrom: start..
321            return Ok(quote! {
322                ferray_core::dtype::SliceInfoElem::Slice {
323                    start: #start_tokens,
324                    end: ::core::option::Option::None,
325                    step: #step_expr,
326                }
327            });
328        }
329
330        let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
331            syn::Error::new(
332                proc_macro2::Span::call_site(),
333                format!("invalid end expression: {end_str}"),
334            )
335        })?;
336
337        return Ok(quote! {
338            ferray_core::dtype::SliceInfoElem::Slice {
339                start: #start_tokens,
340                end: ::core::option::Option::Some(#end_tokens),
341                step: #step_expr,
342            }
343        });
344    }
345
346    // No `..` found — this is a single index (integer expression)
347    if step_part.is_some() {
348        return Err(syn::Error::new(
349            proc_macro2::Span::call_site(),
350            format!("step ';' is not valid for integer indices: {trimmed}"),
351        ));
352    }
353
354    let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
355        syn::Error::new(
356            proc_macro2::Span::call_site(),
357            format!("invalid index expression: {range_part}"),
358        )
359    })?;
360
361    Ok(quote! {
362        ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
363    })
364}
365
366// ---------------------------------------------------------------------------
367// promoted_type!() — compile-time type promotion
368// ---------------------------------------------------------------------------
369
370/// Compile-time type promotion macro.
371///
372/// Given two numeric types, resolves to the smallest type that can represent
373/// both without precision loss, following `NumPy`'s promotion rules.
374///
375/// # Examples
376/// ```ignore
377/// type R = promoted_type!(f32, f64); // R = f64
378/// type R = promoted_type!(i32, f32); // R = f64
379/// type R = promoted_type!(u8, i8);   // R = i16
380/// ```
381#[proc_macro]
382pub fn promoted_type(input: TokenStream) -> TokenStream {
383    let input2: proc_macro2::TokenStream = input.into();
384    match impl_promoted_type(input2) {
385        Ok(ts) => ts.into(),
386        Err(e) => e.to_compile_error().into(),
387    }
388}
389
390fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
391    let input_str = input.to_string();
392    let parts: Vec<&str> = input_str.split(',').map(str::trim).collect();
393
394    if parts.len() != 2 {
395        return Err(syn::Error::new(
396            proc_macro2::Span::call_site(),
397            "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
398        ));
399    }
400
401    let t1 = normalize_type(parts[0]);
402    let t2 = normalize_type(parts[1]);
403
404    let result = promote_types_static(&t1, &t2).ok_or_else(|| {
405        syn::Error::new(
406            proc_macro2::Span::call_site(),
407            format!("cannot promote types: {t1} and {t2}"),
408        )
409    })?;
410
411    let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
412        syn::Error::new(
413            proc_macro2::Span::call_site(),
414            format!("internal error: could not parse result type: {result}"),
415        )
416    })?;
417
418    Ok(result_tokens)
419}
420
/// Canonicalises a stringified type by removing every whitespace character,
/// so that e.g. `num_complex :: Complex < f32 >` becomes
/// `num_complex::Complex<f32>`.
///
/// All Unicode whitespace is stripped (not only ASCII spaces) because the
/// textual rendering of a token stream is not a stable format and may contain
/// tabs or newlines.
fn normalize_type(s: &str) -> String {
    s.chars().filter(|c| !c.is_whitespace()).collect()
}
425
426/// Static type promotion following `NumPy` rules.
427///
428/// Returns the promoted type as a string, or None if unknown.
429fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
430    // Assign a numeric "kind + rank" to each type, then pick the larger.
431    //
432    // NumPy promotion hierarchy (simplified):
433    //   bool < u8 < u16 < u32 < u64 < u128
434    //   bool < i8 < i16 < i32 < i64 < i128
435    //   f32 < f64
436    //   Complex<f32> < Complex<f64>
437    //
438    // Cross-kind rules:
439    //   unsigned + signed -> next-size signed (e.g. u8 + i8 -> i16)
440    //   any int + float -> float (ensure enough precision)
441    //   any real + complex -> complex with appropriate float size
442
443    let ra = type_rank(a)?;
444    let rb = type_rank(b)?;
445
446    // `promote_ranks` uses an empty string to signal "no lossless
447    // common type exists" (currently the u128 + signed-int case); map
448    // that to `None` so the outer proc-macro emits a compile error.
449    match promote_ranks(ra, rb) {
450        "" => None,
451        other => Some(other),
452    }
453}
454
/// Broad category of a scalar type, as used by the promotion rules.
#[derive(Copy, Clone, Eq, PartialEq)]
enum TypeKind {
    /// `bool`.
    Bool,
    /// Unsigned integers (`u8` .. `u128`).
    Unsigned,
    /// Signed integers (`i8` .. `i128`).
    Signed,
    /// Real floating-point types.
    Float,
    /// Complex numbers with floating-point components.
    Complex,
}
463
464#[derive(Clone, Copy)]
465struct TypeRank {
466    kind: TypeKind,
467    /// Bit width within the kind (e.g., 8 for u8, 32 for f32, etc.)
468    bits: u32,
469}
470
471fn type_rank(s: &str) -> Option<TypeRank> {
472    let result = match s {
473        "bool" => TypeRank {
474            kind: TypeKind::Bool,
475            bits: 1,
476        },
477        "u8" => TypeRank {
478            kind: TypeKind::Unsigned,
479            bits: 8,
480        },
481        "u16" => TypeRank {
482            kind: TypeKind::Unsigned,
483            bits: 16,
484        },
485        "u32" => TypeRank {
486            kind: TypeKind::Unsigned,
487            bits: 32,
488        },
489        "u64" => TypeRank {
490            kind: TypeKind::Unsigned,
491            bits: 64,
492        },
493        "u128" => TypeRank {
494            kind: TypeKind::Unsigned,
495            bits: 128,
496        },
497        "i8" => TypeRank {
498            kind: TypeKind::Signed,
499            bits: 8,
500        },
501        "i16" => TypeRank {
502            kind: TypeKind::Signed,
503            bits: 16,
504        },
505        "i32" => TypeRank {
506            kind: TypeKind::Signed,
507            bits: 32,
508        },
509        "i64" => TypeRank {
510            kind: TypeKind::Signed,
511            bits: 64,
512        },
513        "i128" => TypeRank {
514            kind: TypeKind::Signed,
515            bits: 128,
516        },
517        "f32" => TypeRank {
518            kind: TypeKind::Float,
519            bits: 32,
520        },
521        "f64" => TypeRank {
522            kind: TypeKind::Float,
523            bits: 64,
524        },
525        "Complex<f32>" | "num_complex::Complex<f32>" => TypeRank {
526            kind: TypeKind::Complex,
527            bits: 32,
528        },
529        "Complex<f64>" | "num_complex::Complex<f64>" => TypeRank {
530            kind: TypeKind::Complex,
531            bits: 64,
532        },
533        "f16" | "half::f16" => TypeRank {
534            kind: TypeKind::Float,
535            bits: 16,
536        },
537        "bf16" | "half::bf16" => TypeRank {
538            kind: TypeKind::Float,
539            bits: 16,
540        },
541        _ => return None,
542    };
543    Some(result)
544}
545
546fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
547    use TypeKind::{Bool, Complex, Float, Signed, Unsigned};
548
549    // Same type
550    if a.kind == b.kind && a.bits == b.bits {
551        return rank_to_type(a);
552    }
553
554    // Handle Bool: bool promotes to anything
555    if a.kind == Bool {
556        return rank_to_type(b);
557    }
558    if b.kind == Bool {
559        return rank_to_type(a);
560    }
561
562    // Complex + anything -> Complex with max float precision
563    if a.kind == Complex || b.kind == Complex {
564        let float_bits_a = to_float_bits(a);
565        let float_bits_b = to_float_bits(b);
566        let bits = float_bits_a.max(float_bits_b);
567        return if bits <= 32 {
568            "num_complex::Complex<f32>"
569        } else {
570            "num_complex::Complex<f64>"
571        };
572    }
573
574    // Float + anything -> Float with enough precision
575    if a.kind == Float || b.kind == Float {
576        let float_bits_a = to_float_bits(a);
577        let float_bits_b = to_float_bits(b);
578        let bits = float_bits_a.max(float_bits_b);
579        return if bits <= 32 { "f32" } else { "f64" };
580    }
581
582    // Now both are integer types (Unsigned or Signed)
583    match (a.kind, b.kind) {
584        (Unsigned, Unsigned) => {
585            let bits = a.bits.max(b.bits);
586            uint_type(bits)
587        }
588        (Signed, Signed) => {
589            let bits = a.bits.max(b.bits);
590            int_type(bits)
591        }
592        (Unsigned, Signed) | (Signed, Unsigned) => {
593            let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
594            // unsigned + signed: need a signed type that holds both ranges
595            // u8 + i8 -> i16, u16 + i16 -> i32, etc.
596            if u.bits < s.bits {
597                // Signed type is strictly larger, it can hold the unsigned range
598                int_type(s.bits)
599            } else {
600                // Need the next larger signed type
601                let needed = u.bits.max(s.bits) * 2;
602                if needed <= 128 {
603                    int_type(needed)
604                } else {
605                    // u128 + any signed int: no lossless common type
606                    // on stable Rust (see ferray-core/src/dtype/promotion.rs
607                    // and issue #375). Return the empty string as a
608                    // sentinel; `promote_types_static` maps that to
609                    // `None`, which the outer proc-macro converts
610                    // into a compile error with a clear message.
611                    ""
612                }
613            }
614        }
615        _ => "f64", // fallback
616    }
617}
618
619/// Convert any type rank to the float bit width it requires.
620const fn to_float_bits(r: TypeRank) -> u32 {
621    match r.kind {
622        TypeKind::Bool => 32,
623        TypeKind::Unsigned | TypeKind::Signed => {
624            // Integers up to 24-bit mantissa fit in f32 (i.e., i8, i16, u8, u16).
625            // Larger integers need f64 (53-bit mantissa).
626            if r.bits <= 16 { 32 } else { 64 }
627        }
628        TypeKind::Float => r.bits,
629        TypeKind::Complex => r.bits,
630    }
631}
632
/// Name of the unsigned integer type with the given bit width.
///
/// Any width other than 8, 16, 32 or 128 — including 64 itself — yields
/// `"u64"`.
const fn uint_type(bits: u32) -> &'static str {
    if bits == 8 {
        "u8"
    } else if bits == 16 {
        "u16"
    } else if bits == 32 {
        "u32"
    } else if bits == 128 {
        "u128"
    } else {
        "u64"
    }
}
643
/// Name of the signed integer type with the given bit width.
///
/// Any width other than 8, 16, 32 or 128 — including 64 itself — yields
/// `"i64"`.
const fn int_type(bits: u32) -> &'static str {
    if bits == 8 {
        "i8"
    } else if bits == 16 {
        "i16"
    } else if bits == 32 {
        "i32"
    } else if bits == 128 {
        "i128"
    } else {
        "i64"
    }
}
654
655const fn rank_to_type(r: TypeRank) -> &'static str {
656    match r.kind {
657        TypeKind::Bool => "bool",
658        TypeKind::Unsigned => uint_type(r.bits),
659        TypeKind::Signed => int_type(r.bits),
660        TypeKind::Float => {
661            if r.bits <= 16 {
662                "half::f16"
663            } else if r.bits <= 32 {
664                "f32"
665            } else {
666                "f64"
667            }
668        }
669        TypeKind::Complex => {
670            if r.bits <= 32 {
671                "num_complex::Complex<f32>"
672            } else {
673                "num_complex::Complex<f64>"
674            }
675        }
676    }
677}