Skip to main content

ferray_core_macros/
lib.rs

1// ferray-core-macros: Procedural macros for ferray-core
2//
3// Implements:
4// - #[derive(FerrayRecord)] — generates FerrayRecord trait impl for #[repr(C)] structs
5// - s![] — NumPy-style slice indexing macro
6// - promoted_type!() — compile-time type promotion macro
7
8extern crate proc_macro;
9
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::{Data, DeriveInput, Fields, parse_macro_input};
13
14// ---------------------------------------------------------------------------
15// #[derive(FerrayRecord)]
16// ---------------------------------------------------------------------------
17
18/// Derive macro that generates an `unsafe impl FerrayRecord` for a `#[repr(C)]` struct.
19///
20/// # Requirements
21/// - The struct must have `#[repr(C)]`.
22/// - All fields must implement `ferray_core::dtype::Element`.
23///
24/// # Generated code
25/// - `field_descriptors()` returns a static slice of `FieldDescriptor` with correct
26///   name, dtype, offset, and size for each field.
27/// - `record_size()` returns `std::mem::size_of::<Self>()`.
28#[proc_macro_derive(FerrayRecord)]
29pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
30    let input = parse_macro_input!(input as DeriveInput);
31    match impl_ferray_record(&input) {
32        Ok(ts) => ts.into(),
33        Err(e) => e.to_compile_error().into(),
34    }
35}
36
37fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
38    let name = &input.ident;
39
40    // Check for #[repr(C)]
41    let has_repr_c = input.attrs.iter().any(|attr| {
42        if !attr.path().is_ident("repr") {
43            return false;
44        }
45        let mut found = false;
46        let _ = attr.parse_nested_meta(|meta| {
47            if meta.path.is_ident("C") {
48                found = true;
49            }
50            Ok(())
51        });
52        found
53    });
54
55    if !has_repr_c {
56        return Err(syn::Error::new_spanned(
57            &input.ident,
58            "FerrayRecord requires #[repr(C)] on the struct",
59        ));
60    }
61
62    // Only works on structs with named fields
63    let fields = match &input.data {
64        Data::Struct(data_struct) => match &data_struct.fields {
65            Fields::Named(named) => &named.named,
66            _ => {
67                return Err(syn::Error::new_spanned(
68                    &input.ident,
69                    "FerrayRecord only supports structs with named fields",
70                ));
71            }
72        },
73        _ => {
74            return Err(syn::Error::new_spanned(
75                &input.ident,
76                "FerrayRecord can only be derived for structs",
77            ));
78        }
79    };
80
81    let field_count = fields.len();
82    let mut field_descriptors = Vec::with_capacity(field_count);
83
84    for field in fields.iter() {
85        let field_name = field.ident.as_ref().unwrap();
86        let field_name_str = field_name.to_string();
87        let field_ty = &field.ty;
88
89        field_descriptors.push(quote! {
90            ferray_core::record::FieldDescriptor {
91                name: #field_name_str,
92                dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
93                offset: std::mem::offset_of!(#name, #field_name),
94                size: std::mem::size_of::<#field_ty>(),
95            }
96        });
97    }
98
99    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
100
101    let expanded = quote! {
102        unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
103            fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
104                // Issue #323 asked whether this could be a `const` static
105                // array. `offset_of!` and `size_of::<T>()` are both const,
106                // but `<T as Element>::dtype()` is a trait method call and
107                // trait methods are not `const` on stable Rust yet — so
108                // the LazyLock is unavoidable without breaking
109                // `Element::dtype()` for every user. Re-evaluate once
110                // `const_trait_impl` stabilises.
111                static FIELDS: std::sync::LazyLock<Vec<ferray_core::record::FieldDescriptor>> =
112                    std::sync::LazyLock::new(|| {
113                        vec![
114                            #(#field_descriptors),*
115                        ]
116                    });
117                &FIELDS
118            }
119
120            fn record_size() -> usize {
121                std::mem::size_of::<#name>()
122            }
123        }
124    };
125
126    Ok(expanded)
127}
128
129// ---------------------------------------------------------------------------
130// s![] macro — NumPy-style slice indexing
131// ---------------------------------------------------------------------------
132
133/// NumPy-style slice indexing macro.
134///
135/// Produces a `Vec<ferray_core::dtype::SliceInfoElem>` that can be passed
136/// to array slicing methods.
137///
138/// # Syntax
139/// - `s![0..3, 2]` — rows 0..3, column 2
140/// - `s![.., 0..;2]` — all rows, every-other column starting from 0
141/// - `s![1..5;2, ..]` — rows 1..5 step 2, all columns
142/// - `s![3]` — single integer index
143/// - `s![..]` — all elements along this axis
144/// - `s![2..]` — from index 2 to end
145/// - `s![..5]` — from start to index 5
146/// - `s![1..5]` — from index 1 to 5
147/// - `s![1..5;2]` — from index 1 to 5, step 2
148///
149/// Each component in the comma-separated list becomes one `SliceInfoElem`.
150#[proc_macro]
151pub fn s(input: TokenStream) -> TokenStream {
152    let input2: proc_macro2::TokenStream = input.into();
153    let expanded = impl_s_macro(input2);
154    match expanded {
155        Ok(ts) => ts.into(),
156        Err(e) => e.to_compile_error().into(),
157    }
158}
159
160fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
161    // We parse the input as a sequence of comma-separated slice expressions.
162    // Each expression can be:
163    //   - An integer literal or expression: `2` -> Index(2)
164    //   - A full range `..` -> Slice { start: 0, end: None, step: 1 }
165    //   - A range `a..b` -> Slice { start: a, end: Some(b), step: 1 }
166    //   - A range from `a..` -> Slice { start: a, end: None, step: 1 }
167    //   - A range to `..b` -> Slice { start: 0, end: Some(b), step: 1 }
168    //   - Any of the above with `;step` suffix
169    //
170    // We'll output code that constructs a Vec<SliceInfoElem>.
171    //
172    // Since proc macros can't easily parse arbitrary Rust expressions with range syntax
173    // mixed with custom `;step` syntax, we'll use a simpler token-based approach.
174
175    let input_str = input.to_string();
176
177    // Handle empty input
178    if input_str.trim().is_empty() {
179        return Ok(quote! {
180            ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
181        });
182    }
183
184    // Split by commas (respecting parentheses/brackets nesting)
185    let components = split_top_level_commas(&input_str);
186    let mut elems = Vec::new();
187
188    for component in &components {
189        let trimmed = component.trim();
190        if trimmed.is_empty() {
191            continue;
192        }
193        elems.push(parse_slice_component(trimmed)?);
194    }
195
196    Ok(quote! {
197        vec![#(#elems),*]
198    })
199}
200
/// Splits `s` at every comma that is not nested inside `()`, `[]` or `{}`.
///
/// Pieces are returned verbatim (no trimming, separators removed). An empty
/// trailing piece is dropped, while empty leading or interior pieces are kept.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut piece = String::new();
    let mut nesting = 0i32;

    for c in s.chars() {
        // A comma at nesting level zero closes the current piece.
        if c == ',' && nesting == 0 {
            pieces.push(std::mem::take(&mut piece));
            continue;
        }
        if matches!(c, '(' | '[' | '{') {
            nesting += 1;
        } else if matches!(c, ')' | ']' | '}') {
            nesting -= 1;
        }
        piece.push(c);
    }

    if !piece.is_empty() {
        pieces.push(piece);
    }
    pieces
}
230
/// Returns the byte index of the last `;` in `s` that sits outside every
/// `()`, `[]` and `{}` pair, or `None` when there is no such semicolon.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    let mut nesting = 0i32;
    s.char_indices()
        .filter_map(|(pos, c)| {
            match c {
                '(' | '[' | '{' => nesting += 1,
                ')' | ']' | '}' => nesting -= 1,
                _ => {}
            }
            // Only semicolons seen at nesting level zero qualify.
            (c == ';' && nesting == 0).then_some(pos)
        })
        .last()
}
245
246fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
247    let trimmed = s.trim();
248
249    // Check for step suffix: `expr;step` (depth-aware to handle block expressions)
250    let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
251        let (rp, sp) = trimmed.split_at(idx);
252        (rp.trim(), Some(sp[1..].trim()))
253    } else {
254        (trimmed, None)
255    };
256
257    let step_expr = if let Some(step_str) = step_part {
258        let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
259            syn::Error::new(
260                proc_macro2::Span::call_site(),
261                format!("invalid step expression: {step_str}"),
262            )
263        })?;
264        quote! { #step_tokens }
265    } else {
266        quote! { 1isize }
267    };
268
269    // Now parse range_part
270    if range_part == ".." {
271        // Full range: all elements
272        return Ok(quote! {
273            ferray_core::dtype::SliceInfoElem::Slice {
274                start: 0,
275                end: ::core::option::Option::None,
276                step: #step_expr,
277            }
278        });
279    }
280
281    if let Some(rest) = range_part.strip_prefix("..") {
282        // RangeTo: ..end
283        let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
284            syn::Error::new(
285                proc_macro2::Span::call_site(),
286                format!("invalid end expression: {rest}"),
287            )
288        })?;
289        return Ok(quote! {
290            ferray_core::dtype::SliceInfoElem::Slice {
291                start: 0,
292                end: ::core::option::Option::Some(#end_tokens),
293                step: #step_expr,
294            }
295        });
296    }
297
298    if let Some(idx) = range_part.find("..") {
299        let start_str = range_part[..idx].trim();
300        let end_str = range_part[idx + 2..].trim();
301
302        let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
303            syn::Error::new(
304                proc_macro2::Span::call_site(),
305                format!("invalid start expression: {start_str}"),
306            )
307        })?;
308
309        if end_str.is_empty() {
310            // RangeFrom: start..
311            return Ok(quote! {
312                ferray_core::dtype::SliceInfoElem::Slice {
313                    start: #start_tokens,
314                    end: ::core::option::Option::None,
315                    step: #step_expr,
316                }
317            });
318        }
319
320        let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
321            syn::Error::new(
322                proc_macro2::Span::call_site(),
323                format!("invalid end expression: {end_str}"),
324            )
325        })?;
326
327        return Ok(quote! {
328            ferray_core::dtype::SliceInfoElem::Slice {
329                start: #start_tokens,
330                end: ::core::option::Option::Some(#end_tokens),
331                step: #step_expr,
332            }
333        });
334    }
335
336    // No `..` found — this is a single index (integer expression)
337    if step_part.is_some() {
338        return Err(syn::Error::new(
339            proc_macro2::Span::call_site(),
340            format!("step ';' is not valid for integer indices: {trimmed}"),
341        ));
342    }
343
344    let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
345        syn::Error::new(
346            proc_macro2::Span::call_site(),
347            format!("invalid index expression: {range_part}"),
348        )
349    })?;
350
351    Ok(quote! {
352        ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
353    })
354}
355
356// ---------------------------------------------------------------------------
357// promoted_type!() — compile-time type promotion
358// ---------------------------------------------------------------------------
359
360/// Compile-time type promotion macro.
361///
362/// Given two numeric types, resolves to the smallest type that can represent
363/// both without precision loss, following NumPy's promotion rules.
364///
365/// # Examples
366/// ```ignore
367/// type R = promoted_type!(f32, f64); // R = f64
368/// type R = promoted_type!(i32, f32); // R = f64
369/// type R = promoted_type!(u8, i8);   // R = i16
370/// ```
371#[proc_macro]
372pub fn promoted_type(input: TokenStream) -> TokenStream {
373    let input2: proc_macro2::TokenStream = input.into();
374    match impl_promoted_type(input2) {
375        Ok(ts) => ts.into(),
376        Err(e) => e.to_compile_error().into(),
377    }
378}
379
380fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
381    let input_str = input.to_string();
382    let parts: Vec<&str> = input_str.split(',').map(|s| s.trim()).collect();
383
384    if parts.len() != 2 {
385        return Err(syn::Error::new(
386            proc_macro2::Span::call_site(),
387            "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
388        ));
389    }
390
391    let t1 = normalize_type(parts[0]);
392    let t2 = normalize_type(parts[1]);
393
394    let result = promote_types_static(&t1, &t2).ok_or_else(|| {
395        syn::Error::new(
396            proc_macro2::Span::call_site(),
397            format!("cannot promote types: {t1} and {t2}"),
398        )
399    })?;
400
401    let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
402        syn::Error::new(
403            proc_macro2::Span::call_site(),
404            format!("internal error: could not parse result type: {result}"),
405        )
406    })?;
407
408    Ok(result_tokens)
409}
410
/// Canonicalises a stringified type name so that spellings such as
/// `num_complex :: Complex < f32 >` compare equal to `num_complex::Complex<f32>`.
///
/// Surrounding whitespace is trimmed and every remaining space character is removed.
fn normalize_type(s: &str) -> String {
    s.trim().split(' ').collect()
}
415
/// Static type promotion following NumPy rules.
///
/// Returns the promoted type as a string, or `None` when either name is not a
/// recognised type or when no lossless common type exists.
///
/// NumPy promotion hierarchy (simplified):
///   bool < u8 < u16 < u32 < u64 < u128
///   bool < i8 < i16 < i32 < i64 < i128
///   f16 / bf16 < f32 < f64
///   Complex<f32> < Complex<f64>
///
/// Cross-kind rules:
///   unsigned + signed -> next-size signed (e.g. u8 + i8 -> i16)
///   any int + float -> float (ensure enough precision)
///   f16 + bf16 -> f32 (neither 16-bit format can represent the other)
///   any real + complex -> complex with appropriate float size
fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
    let ra = type_rank(a)?;
    let rb = type_rank(b)?;

    // `promote_ranks` uses an empty string to signal "no lossless
    // common type exists" (currently the u128 + signed-int case); map
    // that to `None` so the outer proc-macro emits a compile error.
    match promote_ranks(ra, rb) {
        "" => None,
        other => Some(other),
    }
}

/// Broad category of a numeric type, used to pick the promotion rule.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TypeKind {
    Bool,
    Unsigned,
    Signed,
    /// IEEE binary floating point (`f16`, `f32`, `f64`).
    Float,
    /// Brain floating point (`bf16`). Kept apart from `Float` because it has
    /// the same width as `f16` but a different exponent/mantissa split, so the
    /// two must not be treated as the same type.
    BFloat,
    Complex,
}

/// A type's kind together with its width.
#[derive(Clone, Copy)]
struct TypeRank {
    kind: TypeKind,
    /// Bit width within the kind (e.g., 8 for u8, 32 for f32, etc.).
    /// For `Complex` this is the width of the component float.
    bits: u32,
}

/// Maps a normalised type name to its rank, or `None` for unknown names.
fn type_rank(s: &str) -> Option<TypeRank> {
    use TypeKind::*;

    let (kind, bits) = match s {
        "bool" => (Bool, 1),
        "u8" => (Unsigned, 8),
        "u16" => (Unsigned, 16),
        "u32" => (Unsigned, 32),
        "u64" => (Unsigned, 64),
        "u128" => (Unsigned, 128),
        "i8" => (Signed, 8),
        "i16" => (Signed, 16),
        "i32" => (Signed, 32),
        "i64" => (Signed, 64),
        "i128" => (Signed, 128),
        "f32" => (Float, 32),
        "f64" => (Float, 64),
        "Complex<f32>" | "num_complex::Complex<f32>" => (Complex, 32),
        "Complex<f64>" | "num_complex::Complex<f64>" => (Complex, 64),
        "f16" | "half::f16" => (Float, 16),
        "bf16" | "half::bf16" => (BFloat, 16),
        _ => return None,
    };
    Some(TypeRank { kind, bits })
}

/// Computes the promoted type name for two ranks.
///
/// Returns the empty string when no lossless common type exists
/// (u128 combined with a signed integer).
fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
    use TypeKind::*;

    // Same type
    if a.kind == b.kind && a.bits == b.bits {
        return rank_to_type(a);
    }

    // Handle Bool: bool promotes to anything
    if a.kind == Bool {
        return rank_to_type(b);
    }
    if b.kind == Bool {
        return rank_to_type(a);
    }

    // Complex + anything -> Complex with max float precision
    if a.kind == Complex || b.kind == Complex {
        let bits = to_float_bits(a).max(to_float_bits(b));
        return if bits <= 32 {
            "num_complex::Complex<f32>"
        } else {
            "num_complex::Complex<f64>"
        };
    }

    // Float + anything -> Float with enough precision. Two *different*
    // floating types never stay at 16 bits: f16 + bf16 widens to f32.
    let is_floating = |kind: TypeKind| matches!(kind, Float | BFloat);
    if is_floating(a.kind) || is_floating(b.kind) {
        let bits = to_float_bits(a).max(to_float_bits(b));
        return if bits <= 32 { "f32" } else { "f64" };
    }

    // Now both are integer types (Unsigned or Signed)
    match (a.kind, b.kind) {
        (Unsigned, Unsigned) => uint_type(a.bits.max(b.bits)),
        (Signed, Signed) => int_type(a.bits.max(b.bits)),
        (Unsigned, Signed) | (Signed, Unsigned) => {
            let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
            // unsigned + signed: need a signed type that holds both ranges
            // u8 + i8 -> i16, u16 + i16 -> i32, etc.
            if u.bits < s.bits {
                // Signed type is strictly larger, it can hold the unsigned range
                int_type(s.bits)
            } else {
                // Need the next larger signed type
                let needed = u.bits.max(s.bits) * 2;
                if needed <= 128 {
                    int_type(needed)
                } else {
                    // u128 + any signed int: no lossless common type
                    // on stable Rust (see ferray-core/src/dtype/promotion.rs
                    // and issue #375). Return the empty string as a
                    // sentinel; `promote_types_static` maps that to
                    // `None`, which the outer proc-macro converts
                    // into a compile error with a clear message.
                    ""
                }
            }
        }
        // Every non-integer kind returned above; kept as a defensive fallback.
        _ => "f64",
    }
}

/// Convert any type rank to the float bit width it requires.
fn to_float_bits(r: TypeRank) -> u32 {
    match r.kind {
        TypeKind::Bool => 32,
        TypeKind::Unsigned | TypeKind::Signed => {
            // Integers up to 24-bit mantissa fit in f32 (i.e., i8, i16, u8, u16).
            // Larger integers need f64 (53-bit mantissa).
            if r.bits <= 16 { 32 } else { 64 }
        }
        TypeKind::Float | TypeKind::BFloat | TypeKind::Complex => r.bits,
    }
}

/// Name of the unsigned integer type with the given width (`u64` for unknown widths).
fn uint_type(bits: u32) -> &'static str {
    match bits {
        8 => "u8",
        16 => "u16",
        32 => "u32",
        64 => "u64",
        128 => "u128",
        _ => "u64",
    }
}

/// Name of the signed integer type with the given width (`i64` for unknown widths).
fn int_type(bits: u32) -> &'static str {
    match bits {
        8 => "i8",
        16 => "i16",
        32 => "i32",
        64 => "i64",
        128 => "i128",
        _ => "i64",
    }
}

/// Canonical type name for a rank.
fn rank_to_type(r: TypeRank) -> &'static str {
    match r.kind {
        TypeKind::Bool => "bool",
        TypeKind::Unsigned => uint_type(r.bits),
        TypeKind::Signed => int_type(r.bits),
        TypeKind::Float => {
            if r.bits <= 16 {
                "half::f16"
            } else if r.bits <= 32 {
                "f32"
            } else {
                "f64"
            }
        }
        TypeKind::BFloat => "half::bf16",
        TypeKind::Complex => {
            if r.bits <= 32 {
                "num_complex::Complex<f32>"
            } else {
                "num_complex::Complex<f64>"
            }
        }
    }
}
667}