Skip to main content

ferray_core_macros/
lib.rs

1// ferray-core-macros: Procedural macros for ferray-core
2//
3// Implements:
4// - #[derive(FerrayRecord)] — generates FerrayRecord trait impl for #[repr(C)] structs
5// - s![] — NumPy-style slice indexing macro
6// - promoted_type!() — compile-time type promotion macro
7
8extern crate proc_macro;
9
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::{Data, DeriveInput, Fields, parse_macro_input};
13
14// ---------------------------------------------------------------------------
15// #[derive(FerrayRecord)]
16// ---------------------------------------------------------------------------
17
18/// Derive macro that generates an `unsafe impl FerrayRecord` for a `#[repr(C)]` struct.
19///
20/// # Requirements
21/// - The struct must have `#[repr(C)]`.
22/// - All fields must implement `ferray_core::dtype::Element`.
23///
24/// # Generated code
25/// - `field_descriptors()` returns a static slice of `FieldDescriptor` with correct
26///   name, dtype, offset, and size for each field.
27/// - `record_size()` returns `std::mem::size_of::<Self>()`.
28#[proc_macro_derive(FerrayRecord)]
29pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
30    let input = parse_macro_input!(input as DeriveInput);
31    match impl_ferray_record(&input) {
32        Ok(ts) => ts.into(),
33        Err(e) => e.to_compile_error().into(),
34    }
35}
36
37fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
38    let name = &input.ident;
39
40    // Check for #[repr(C)]
41    let has_repr_c = input.attrs.iter().any(|attr| {
42        if !attr.path().is_ident("repr") {
43            return false;
44        }
45        let mut found = false;
46        let _ = attr.parse_nested_meta(|meta| {
47            if meta.path.is_ident("C") {
48                found = true;
49            }
50            Ok(())
51        });
52        found
53    });
54
55    if !has_repr_c {
56        return Err(syn::Error::new_spanned(
57            &input.ident,
58            "FerrayRecord requires #[repr(C)] on the struct",
59        ));
60    }
61
62    // Only works on structs with named fields
63    let fields = match &input.data {
64        Data::Struct(data_struct) => match &data_struct.fields {
65            Fields::Named(named) => &named.named,
66            _ => {
67                return Err(syn::Error::new_spanned(
68                    &input.ident,
69                    "FerrayRecord only supports structs with named fields",
70                ));
71            }
72        },
73        _ => {
74            return Err(syn::Error::new_spanned(
75                &input.ident,
76                "FerrayRecord can only be derived for structs",
77            ));
78        }
79    };
80
81    let field_count = fields.len();
82    let mut field_descriptors = Vec::with_capacity(field_count);
83
84    for field in fields.iter() {
85        let field_name = field.ident.as_ref().unwrap();
86        let field_name_str = field_name.to_string();
87        let field_ty = &field.ty;
88
89        field_descriptors.push(quote! {
90            ferray_core::record::FieldDescriptor {
91                name: #field_name_str,
92                dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
93                offset: std::mem::offset_of!(#name, #field_name),
94                size: std::mem::size_of::<#field_ty>(),
95            }
96        });
97    }
98
99    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
100
101    let expanded = quote! {
102        unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
103            fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
104                static FIELDS: std::sync::LazyLock<Vec<ferray_core::record::FieldDescriptor>> =
105                    std::sync::LazyLock::new(|| {
106                        vec![
107                            #(#field_descriptors),*
108                        ]
109                    });
110                &FIELDS
111            }
112
113            fn record_size() -> usize {
114                std::mem::size_of::<#name>()
115            }
116        }
117    };
118
119    Ok(expanded)
120}
121
122// ---------------------------------------------------------------------------
123// s![] macro — NumPy-style slice indexing
124// ---------------------------------------------------------------------------
125
126/// NumPy-style slice indexing macro.
127///
128/// Produces a `Vec<ferray_core::dtype::SliceInfoElem>` that can be passed
129/// to array slicing methods.
130///
131/// # Syntax
132/// - `s![0..3, 2]` — rows 0..3, column 2
133/// - `s![.., 0..;2]` — all rows, every-other column starting from 0
134/// - `s![1..5;2, ..]` — rows 1..5 step 2, all columns
135/// - `s![3]` — single integer index
136/// - `s![..]` — all elements along this axis
137/// - `s![2..]` — from index 2 to end
138/// - `s![..5]` — from start to index 5
139/// - `s![1..5]` — from index 1 to 5
140/// - `s![1..5;2]` — from index 1 to 5, step 2
141///
142/// Each component in the comma-separated list becomes one `SliceInfoElem`.
143#[proc_macro]
144pub fn s(input: TokenStream) -> TokenStream {
145    let input2: proc_macro2::TokenStream = input.into();
146    let expanded = impl_s_macro(input2);
147    match expanded {
148        Ok(ts) => ts.into(),
149        Err(e) => e.to_compile_error().into(),
150    }
151}
152
153fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
154    // We parse the input as a sequence of comma-separated slice expressions.
155    // Each expression can be:
156    //   - An integer literal or expression: `2` -> Index(2)
157    //   - A full range `..` -> Slice { start: 0, end: None, step: 1 }
158    //   - A range `a..b` -> Slice { start: a, end: Some(b), step: 1 }
159    //   - A range from `a..` -> Slice { start: a, end: None, step: 1 }
160    //   - A range to `..b` -> Slice { start: 0, end: Some(b), step: 1 }
161    //   - Any of the above with `;step` suffix
162    //
163    // We'll output code that constructs a Vec<SliceInfoElem>.
164    //
165    // Since proc macros can't easily parse arbitrary Rust expressions with range syntax
166    // mixed with custom `;step` syntax, we'll use a simpler token-based approach.
167
168    let input_str = input.to_string();
169
170    // Handle empty input
171    if input_str.trim().is_empty() {
172        return Ok(quote! {
173            ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
174        });
175    }
176
177    // Split by commas (respecting parentheses/brackets nesting)
178    let components = split_top_level_commas(&input_str);
179    let mut elems = Vec::new();
180
181    for component in &components {
182        let trimmed = component.trim();
183        if trimmed.is_empty() {
184            continue;
185        }
186        elems.push(parse_slice_component(trimmed)?);
187    }
188
189    Ok(quote! {
190        vec![#(#elems),*]
191    })
192}
193
/// Splits `s` at every comma that is not nested inside `()`, `[]` or `{}`.
///
/// The commas themselves are dropped; surrounding whitespace is kept. A piece
/// between two adjacent commas is returned as an empty string, but an empty
/// piece after the final comma (or an entirely empty input) is not returned.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut nesting = 0i32;
    // Byte offset at which the piece currently being scanned begins.
    let mut piece_start = 0usize;

    for (pos, c) in s.char_indices() {
        if matches!(c, '(' | '[' | '{') {
            nesting += 1;
        } else if matches!(c, ')' | ']' | '}') {
            nesting -= 1;
        } else if c == ',' && nesting == 0 {
            pieces.push(s[piece_start..pos].to_owned());
            // ',' is one byte long, so `pos + 1` is a valid char boundary.
            piece_start = pos + 1;
        }
    }

    if piece_start < s.len() {
        pieces.push(s[piece_start..].to_owned());
    }
    pieces
}
223
/// Returns the byte offset of the last `;` in `s` that sits at nesting depth
/// zero with respect to `()`, `[]` and `{}`, or `None` if there is none.
///
/// Depth is tracked while scanning left to right, so a semicolon inside a
/// block expression such as `{ a; b }` is ignored.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    let (_, last_hit) = s
        .char_indices()
        .fold((0i32, None), |(nesting, hit), (pos, c)| match c {
            '(' | '[' | '{' => (nesting + 1, hit),
            ')' | ']' | '}' => (nesting - 1, hit),
            ';' if nesting == 0 => (nesting, Some(pos)),
            _ => (nesting, hit),
        });
    last_hit
}
238
239fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
240    let trimmed = s.trim();
241
242    // Check for step suffix: `expr;step` (depth-aware to handle block expressions)
243    let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
244        let (rp, sp) = trimmed.split_at(idx);
245        (rp.trim(), Some(sp[1..].trim()))
246    } else {
247        (trimmed, None)
248    };
249
250    let step_expr = if let Some(step_str) = step_part {
251        let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
252            syn::Error::new(
253                proc_macro2::Span::call_site(),
254                format!("invalid step expression: {step_str}"),
255            )
256        })?;
257        quote! { #step_tokens }
258    } else {
259        quote! { 1isize }
260    };
261
262    // Now parse range_part
263    if range_part == ".." {
264        // Full range: all elements
265        return Ok(quote! {
266            ferray_core::dtype::SliceInfoElem::Slice {
267                start: 0,
268                end: ::core::option::Option::None,
269                step: #step_expr,
270            }
271        });
272    }
273
274    if let Some(rest) = range_part.strip_prefix("..") {
275        // RangeTo: ..end
276        let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
277            syn::Error::new(
278                proc_macro2::Span::call_site(),
279                format!("invalid end expression: {rest}"),
280            )
281        })?;
282        return Ok(quote! {
283            ferray_core::dtype::SliceInfoElem::Slice {
284                start: 0,
285                end: ::core::option::Option::Some(#end_tokens),
286                step: #step_expr,
287            }
288        });
289    }
290
291    if let Some(idx) = range_part.find("..") {
292        let start_str = range_part[..idx].trim();
293        let end_str = range_part[idx + 2..].trim();
294
295        let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
296            syn::Error::new(
297                proc_macro2::Span::call_site(),
298                format!("invalid start expression: {start_str}"),
299            )
300        })?;
301
302        if end_str.is_empty() {
303            // RangeFrom: start..
304            return Ok(quote! {
305                ferray_core::dtype::SliceInfoElem::Slice {
306                    start: #start_tokens,
307                    end: ::core::option::Option::None,
308                    step: #step_expr,
309                }
310            });
311        }
312
313        let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
314            syn::Error::new(
315                proc_macro2::Span::call_site(),
316                format!("invalid end expression: {end_str}"),
317            )
318        })?;
319
320        return Ok(quote! {
321            ferray_core::dtype::SliceInfoElem::Slice {
322                start: #start_tokens,
323                end: ::core::option::Option::Some(#end_tokens),
324                step: #step_expr,
325            }
326        });
327    }
328
329    // No `..` found — this is a single index (integer expression)
330    if step_part.is_some() {
331        return Err(syn::Error::new(
332            proc_macro2::Span::call_site(),
333            format!("step ';' is not valid for integer indices: {trimmed}"),
334        ));
335    }
336
337    let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
338        syn::Error::new(
339            proc_macro2::Span::call_site(),
340            format!("invalid index expression: {range_part}"),
341        )
342    })?;
343
344    Ok(quote! {
345        ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
346    })
347}
348
349// ---------------------------------------------------------------------------
350// promoted_type!() — compile-time type promotion
351// ---------------------------------------------------------------------------
352
353/// Compile-time type promotion macro.
354///
355/// Given two numeric types, resolves to the smallest type that can represent
356/// both without precision loss, following NumPy's promotion rules.
357///
358/// # Examples
359/// ```ignore
360/// type R = promoted_type!(f32, f64); // R = f64
361/// type R = promoted_type!(i32, f32); // R = f64
362/// type R = promoted_type!(u8, i8);   // R = i16
363/// ```
364#[proc_macro]
365pub fn promoted_type(input: TokenStream) -> TokenStream {
366    let input2: proc_macro2::TokenStream = input.into();
367    match impl_promoted_type(input2) {
368        Ok(ts) => ts.into(),
369        Err(e) => e.to_compile_error().into(),
370    }
371}
372
373fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
374    let input_str = input.to_string();
375    let parts: Vec<&str> = input_str.split(',').map(|s| s.trim()).collect();
376
377    if parts.len() != 2 {
378        return Err(syn::Error::new(
379            proc_macro2::Span::call_site(),
380            "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
381        ));
382    }
383
384    let t1 = normalize_type(parts[0]);
385    let t2 = normalize_type(parts[1]);
386
387    let result = promote_types_static(&t1, &t2).ok_or_else(|| {
388        syn::Error::new(
389            proc_macro2::Span::call_site(),
390            format!("cannot promote types: {t1} and {t2}"),
391        )
392    })?;
393
394    let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
395        syn::Error::new(
396            proc_macro2::Span::call_site(),
397            format!("internal error: could not parse result type: {result}"),
398        )
399    })?;
400
401    Ok(result_tokens)
402}
403
/// Canonicalizes a type name for table lookup: trims surrounding whitespace
/// and removes every space character, so that e.g. `Complex < f32 >` becomes
/// `Complex<f32>`.
fn normalize_type(s: &str) -> String {
    s.trim().chars().filter(|&c| c != ' ').collect()
}
408
/// Static type promotion following NumPy-style rules.
///
/// Returns the promoted type as a string, or `None` if either name is not in
/// the table understood by `type_rank`.
///
/// Promotion hierarchy (simplified):
///   bool < u8 < u16 < u32 < u64 < u128
///   bool < i8 < i16 < i32 < i64 < i128
///   f16 < f32 < f64, bf16 < f32 < f64
///   Complex<f32> < Complex<f64>
///
/// Cross-kind rules:
///   unsigned + signed -> a signed type wide enough for both (e.g. u8 + i8 -> i16),
///                        falling back to f64 when more than 128 bits would be needed
///   any int + float   -> f32 or f64, depending on the integer width
///   f16 + bf16        -> f32 (neither 16-bit format can represent the other)
///   any real + complex -> complex with the appropriate float size
fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
    let ra = type_rank(a)?;
    let rb = type_rank(b)?;

    Some(promote_ranks(ra, rb))
}

/// Broad category of a supported numeric type.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TypeKind {
    Bool,
    Unsigned,
    Signed,
    /// `f16`, `f32`, `f64`.
    Float,
    /// `bf16`. Kept apart from `Float` so that it is never confused with `f16`,
    /// which has the same bit width.
    BFloat,
    Complex,
}

/// A type's category together with its width.
#[derive(Clone, Copy, PartialEq, Eq)]
struct TypeRank {
    kind: TypeKind,
    /// Bit width within the kind (e.g., 8 for u8, 32 for f32). For complex
    /// types this is the width of the component float.
    bits: u32,
}

/// Looks up the rank of a normalized type name; `None` for unknown names.
fn type_rank(s: &str) -> Option<TypeRank> {
    use TypeKind::*;

    let (kind, bits) = match s {
        "bool" => (Bool, 1),
        "u8" => (Unsigned, 8),
        "u16" => (Unsigned, 16),
        "u32" => (Unsigned, 32),
        "u64" => (Unsigned, 64),
        "u128" => (Unsigned, 128),
        "i8" => (Signed, 8),
        "i16" => (Signed, 16),
        "i32" => (Signed, 32),
        "i64" => (Signed, 64),
        "i128" => (Signed, 128),
        "f16" | "half::f16" => (Float, 16),
        "bf16" | "half::bf16" => (BFloat, 16),
        "f32" => (Float, 32),
        "f64" => (Float, 64),
        "Complex<f32>" | "num_complex::Complex<f32>" => (Complex, 32),
        "Complex<f64>" | "num_complex::Complex<f64>" => (Complex, 64),
        _ => return None,
    };
    Some(TypeRank { kind, bits })
}

/// Computes the promoted type name for two ranks.
fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
    use TypeKind::*;

    // Same type (kind and width both match).
    if a == b {
        return rank_to_type(a);
    }

    // Handle Bool: bool promotes to whatever the other operand is.
    if a.kind == Bool {
        return rank_to_type(b);
    }
    if b.kind == Bool {
        return rank_to_type(a);
    }

    // Widest float precision either operand requires.
    let float_bits = to_float_bits(a).max(to_float_bits(b));

    // Complex + anything -> Complex with max float precision
    if a.kind == Complex || b.kind == Complex {
        return if float_bits <= 32 {
            "num_complex::Complex<f32>"
        } else {
            "num_complex::Complex<f64>"
        };
    }

    // Float + anything -> Float with enough precision. Two *different* float
    // types always land on at least f32; this includes f16 + bf16.
    let is_float = |r: TypeRank| matches!(r.kind, Float | BFloat);
    if is_float(a) || is_float(b) {
        return if float_bits <= 32 { "f32" } else { "f64" };
    }

    // Now both are integer types (Unsigned or Signed)
    match (a.kind, b.kind) {
        (Unsigned, Unsigned) => uint_type(a.bits.max(b.bits)),
        (Signed, Signed) => int_type(a.bits.max(b.bits)),
        (Unsigned, Signed) | (Signed, Unsigned) => {
            let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
            // unsigned + signed: need a signed type that holds both ranges
            // u8 + i8 -> i16, u16 + i16 -> i32, etc.
            if u.bits < s.bits {
                // Signed type is strictly larger, it can hold the unsigned range
                int_type(s.bits)
            } else {
                // Need the next larger signed type
                let needed = u.bits.max(s.bits) * 2;
                if needed <= 128 {
                    int_type(needed)
                } else {
                    // Fall back to f64 when we exceed i128
                    "f64"
                }
            }
        }
        // Not expected to be reached: every other kind returned above.
        _ => "f64",
    }
}

/// Convert any type rank to the float bit width it requires.
fn to_float_bits(r: TypeRank) -> u32 {
    match r.kind {
        TypeKind::Bool => 32,
        TypeKind::Unsigned | TypeKind::Signed => {
            // Integers of at most 16 bits fit in f32's 24-bit mantissa
            // (i.e., i8, i16, u8, u16). Larger integers are mapped to f64.
            if r.bits <= 16 {
                32
            } else {
                64
            }
        }
        TypeKind::Float | TypeKind::BFloat | TypeKind::Complex => r.bits,
    }
}

/// Name of the unsigned integer type with the given width (`u64` for any
/// width not in the table).
fn uint_type(bits: u32) -> &'static str {
    match bits {
        8 => "u8",
        16 => "u16",
        32 => "u32",
        64 => "u64",
        128 => "u128",
        _ => "u64",
    }
}

/// Name of the signed integer type with the given width (`i64` for any width
/// not in the table).
fn int_type(bits: u32) -> &'static str {
    match bits {
        8 => "i8",
        16 => "i16",
        32 => "i32",
        64 => "i64",
        128 => "i128",
        _ => "i64",
    }
}

/// Maps a rank back to the type name emitted by the macro.
fn rank_to_type(r: TypeRank) -> &'static str {
    match r.kind {
        TypeKind::Bool => "bool",
        TypeKind::Unsigned => uint_type(r.bits),
        TypeKind::Signed => int_type(r.bits),
        TypeKind::Float => {
            if r.bits <= 16 {
                "half::f16"
            } else if r.bits <= 32 {
                "f32"
            } else {
                "f64"
            }
        }
        TypeKind::BFloat => "half::bf16",
        TypeKind::Complex => {
            if r.bits <= 32 {
                "num_complex::Complex<f32>"
            } else {
                "num_complex::Complex<f64>"
            }
        }
    }
}