1#![allow(
13 clippy::match_same_arms,
14 clippy::needless_pass_by_value,
15 clippy::option_if_let_else
16)]
17
18extern crate proc_macro;
19
20use proc_macro::TokenStream;
21use quote::quote;
22use syn::{Data, DeriveInput, Fields, parse_macro_input};
23
24#[proc_macro_derive(FerrayRecord)]
39pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
40 let input = parse_macro_input!(input as DeriveInput);
41 match impl_ferray_record(&input) {
42 Ok(ts) => ts.into(),
43 Err(e) => e.to_compile_error().into(),
44 }
45}
46
47fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
48 let name = &input.ident;
49
50 let has_repr_c = input.attrs.iter().any(|attr| {
52 if !attr.path().is_ident("repr") {
53 return false;
54 }
55 let mut found = false;
56 let _ = attr.parse_nested_meta(|meta| {
57 if meta.path.is_ident("C") {
58 found = true;
59 }
60 Ok(())
61 });
62 found
63 });
64
65 if !has_repr_c {
66 return Err(syn::Error::new_spanned(
67 &input.ident,
68 "FerrayRecord requires #[repr(C)] on the struct",
69 ));
70 }
71
72 let fields = match &input.data {
74 Data::Struct(data_struct) => match &data_struct.fields {
75 Fields::Named(named) => &named.named,
76 _ => {
77 return Err(syn::Error::new_spanned(
78 &input.ident,
79 "FerrayRecord only supports structs with named fields",
80 ));
81 }
82 },
83 _ => {
84 return Err(syn::Error::new_spanned(
85 &input.ident,
86 "FerrayRecord can only be derived for structs",
87 ));
88 }
89 };
90
91 let field_count = fields.len();
92 let mut field_descriptors = Vec::with_capacity(field_count);
93
94 for field in fields {
95 let field_name = field.ident.as_ref().unwrap();
96 let field_name_str = field_name.to_string();
97 let field_ty = &field.ty;
98
99 field_descriptors.push(quote! {
100 ferray_core::record::FieldDescriptor {
101 name: #field_name_str,
102 dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
103 offset: std::mem::offset_of!(#name, #field_name),
104 size: std::mem::size_of::<#field_ty>(),
105 }
106 });
107 }
108
109 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
110
111 let expanded = quote! {
112 unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
113 fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
114 static FIELDS: std::sync::LazyLock<Vec<ferray_core::record::FieldDescriptor>> =
122 std::sync::LazyLock::new(|| {
123 vec![
124 #(#field_descriptors),*
125 ]
126 });
127 &FIELDS
128 }
129
130 fn record_size() -> usize {
131 std::mem::size_of::<#name>()
132 }
133 }
134 };
135
136 Ok(expanded)
137}
138
139#[proc_macro]
161pub fn s(input: TokenStream) -> TokenStream {
162 let input2: proc_macro2::TokenStream = input.into();
163 let expanded = impl_s_macro(input2);
164 match expanded {
165 Ok(ts) => ts.into(),
166 Err(e) => e.to_compile_error().into(),
167 }
168}
169
170fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
171 let input_str = input.to_string();
186
187 if input_str.trim().is_empty() {
189 return Ok(quote! {
190 ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
191 });
192 }
193
194 let components = split_top_level_commas(&input_str);
196 let mut elems = Vec::new();
197
198 for component in &components {
199 let trimmed = component.trim();
200 if trimmed.is_empty() {
201 continue;
202 }
203 elems.push(parse_slice_component(trimmed)?);
204 }
205
206 Ok(quote! {
207 vec![#(#elems),*]
208 })
209}
210
/// Splits `s` on commas that are not nested inside `()`, `[]` or `{}`.
///
/// Pieces are returned untrimmed. Empty pieces between two separators are
/// kept, but an empty final piece (empty input, or input ending in a
/// top-level comma) is dropped.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut nesting = 0i32;
    let mut piece_start = 0usize;

    for (pos, c) in s.char_indices() {
        match c {
            '(' | '[' | '{' => nesting += 1,
            ')' | ']' | '}' => nesting -= 1,
            ',' if nesting == 0 => {
                pieces.push(s[piece_start..pos].to_string());
                // ',' is a single byte, so the next piece starts right after it.
                piece_start = pos + 1;
            }
            _ => {}
        }
    }

    let tail = &s[piece_start..];
    if !tail.is_empty() {
        pieces.push(tail.to_string());
    }
    pieces
}
240
/// Returns the byte index of the last `;` that is not nested inside `()`,
/// `[]` or `{}`, or `None` if there is no such semicolon.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    let mut nesting = 0i32;
    s.char_indices()
        .filter_map(|(pos, c)| {
            match c {
                '(' | '[' | '{' => nesting += 1,
                ')' | ']' | '}' => nesting -= 1,
                ';' if nesting == 0 => return Some(pos),
                _ => {}
            }
            None
        })
        .last()
}
255
256fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
257 let trimmed = s.trim();
258
259 let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
261 let (rp, sp) = trimmed.split_at(idx);
262 (rp.trim(), Some(sp[1..].trim()))
263 } else {
264 (trimmed, None)
265 };
266
267 let step_expr = if let Some(step_str) = step_part {
268 let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
269 syn::Error::new(
270 proc_macro2::Span::call_site(),
271 format!("invalid step expression: {step_str}"),
272 )
273 })?;
274 quote! { #step_tokens }
275 } else {
276 quote! { 1isize }
277 };
278
279 if range_part == ".." {
281 return Ok(quote! {
283 ferray_core::dtype::SliceInfoElem::Slice {
284 start: 0,
285 end: ::core::option::Option::None,
286 step: #step_expr,
287 }
288 });
289 }
290
291 if let Some(rest) = range_part.strip_prefix("..") {
292 let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
294 syn::Error::new(
295 proc_macro2::Span::call_site(),
296 format!("invalid end expression: {rest}"),
297 )
298 })?;
299 return Ok(quote! {
300 ferray_core::dtype::SliceInfoElem::Slice {
301 start: 0,
302 end: ::core::option::Option::Some(#end_tokens),
303 step: #step_expr,
304 }
305 });
306 }
307
308 if let Some(idx) = range_part.find("..") {
309 let start_str = range_part[..idx].trim();
310 let end_str = range_part[idx + 2..].trim();
311
312 let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
313 syn::Error::new(
314 proc_macro2::Span::call_site(),
315 format!("invalid start expression: {start_str}"),
316 )
317 })?;
318
319 if end_str.is_empty() {
320 return Ok(quote! {
322 ferray_core::dtype::SliceInfoElem::Slice {
323 start: #start_tokens,
324 end: ::core::option::Option::None,
325 step: #step_expr,
326 }
327 });
328 }
329
330 let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
331 syn::Error::new(
332 proc_macro2::Span::call_site(),
333 format!("invalid end expression: {end_str}"),
334 )
335 })?;
336
337 return Ok(quote! {
338 ferray_core::dtype::SliceInfoElem::Slice {
339 start: #start_tokens,
340 end: ::core::option::Option::Some(#end_tokens),
341 step: #step_expr,
342 }
343 });
344 }
345
346 if step_part.is_some() {
348 return Err(syn::Error::new(
349 proc_macro2::Span::call_site(),
350 format!("step ';' is not valid for integer indices: {trimmed}"),
351 ));
352 }
353
354 let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
355 syn::Error::new(
356 proc_macro2::Span::call_site(),
357 format!("invalid index expression: {range_part}"),
358 )
359 })?;
360
361 Ok(quote! {
362 ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
363 })
364}
365
366#[proc_macro]
382pub fn promoted_type(input: TokenStream) -> TokenStream {
383 let input2: proc_macro2::TokenStream = input.into();
384 match impl_promoted_type(input2) {
385 Ok(ts) => ts.into(),
386 Err(e) => e.to_compile_error().into(),
387 }
388}
389
390fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
391 let input_str = input.to_string();
392 let parts: Vec<&str> = input_str.split(',').map(str::trim).collect();
393
394 if parts.len() != 2 {
395 return Err(syn::Error::new(
396 proc_macro2::Span::call_site(),
397 "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
398 ));
399 }
400
401 let t1 = normalize_type(parts[0]);
402 let t2 = normalize_type(parts[1]);
403
404 let result = promote_types_static(&t1, &t2).ok_or_else(|| {
405 syn::Error::new(
406 proc_macro2::Span::call_site(),
407 format!("cannot promote types: {t1} and {t2}"),
408 )
409 })?;
410
411 let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
412 syn::Error::new(
413 proc_macro2::Span::call_site(),
414 format!("internal error: could not parse result type: {result}"),
415 )
416 })?;
417
418 Ok(result_tokens)
419}
420
/// Canonicalizes a stringified type for lookup.
///
/// Removes all whitespace, so `num_complex :: Complex < f32 >` becomes
/// `num_complex::Complex<f32>`. Also drops a leading `::`, so absolute paths
/// such as `::half::f16` match their relative spelling.
fn normalize_type(s: &str) -> String {
    let compact: String = s.chars().filter(|c| !c.is_whitespace()).collect();
    match compact.strip_prefix("::") {
        Some(rest) => rest.to_owned(),
        None => compact,
    }
}
425
426fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
430 let ra = type_rank(a)?;
444 let rb = type_rank(b)?;
445
446 match promote_ranks(ra, rb) {
450 "" => None,
451 other => Some(other),
452 }
453}
454
/// Family of a supported element type, used by the promotion rules.
#[derive(Copy, Clone, Eq, PartialEq)]
enum TypeKind {
    /// `bool`.
    Bool,
    /// `u8` through `u128`.
    Unsigned,
    /// `i8` through `i128`.
    Signed,
    /// `f16`, `bf16`, `f32`, `f64`.
    Float,
    /// `Complex<f32>` and `Complex<f64>`.
    Complex,
}
463
464#[derive(Clone, Copy)]
465struct TypeRank {
466 kind: TypeKind,
467 bits: u32,
469}
470
471fn type_rank(s: &str) -> Option<TypeRank> {
472 let result = match s {
473 "bool" => TypeRank {
474 kind: TypeKind::Bool,
475 bits: 1,
476 },
477 "u8" => TypeRank {
478 kind: TypeKind::Unsigned,
479 bits: 8,
480 },
481 "u16" => TypeRank {
482 kind: TypeKind::Unsigned,
483 bits: 16,
484 },
485 "u32" => TypeRank {
486 kind: TypeKind::Unsigned,
487 bits: 32,
488 },
489 "u64" => TypeRank {
490 kind: TypeKind::Unsigned,
491 bits: 64,
492 },
493 "u128" => TypeRank {
494 kind: TypeKind::Unsigned,
495 bits: 128,
496 },
497 "i8" => TypeRank {
498 kind: TypeKind::Signed,
499 bits: 8,
500 },
501 "i16" => TypeRank {
502 kind: TypeKind::Signed,
503 bits: 16,
504 },
505 "i32" => TypeRank {
506 kind: TypeKind::Signed,
507 bits: 32,
508 },
509 "i64" => TypeRank {
510 kind: TypeKind::Signed,
511 bits: 64,
512 },
513 "i128" => TypeRank {
514 kind: TypeKind::Signed,
515 bits: 128,
516 },
517 "f32" => TypeRank {
518 kind: TypeKind::Float,
519 bits: 32,
520 },
521 "f64" => TypeRank {
522 kind: TypeKind::Float,
523 bits: 64,
524 },
525 "Complex<f32>" | "num_complex::Complex<f32>" => TypeRank {
526 kind: TypeKind::Complex,
527 bits: 32,
528 },
529 "Complex<f64>" | "num_complex::Complex<f64>" => TypeRank {
530 kind: TypeKind::Complex,
531 bits: 64,
532 },
533 "f16" | "half::f16" => TypeRank {
534 kind: TypeKind::Float,
535 bits: 16,
536 },
537 "bf16" | "half::bf16" => TypeRank {
538 kind: TypeKind::Float,
539 bits: 16,
540 },
541 _ => return None,
542 };
543 Some(result)
544}
545
546fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
547 use TypeKind::{Bool, Complex, Float, Signed, Unsigned};
548
549 if a.kind == b.kind && a.bits == b.bits {
551 return rank_to_type(a);
552 }
553
554 if a.kind == Bool {
556 return rank_to_type(b);
557 }
558 if b.kind == Bool {
559 return rank_to_type(a);
560 }
561
562 if a.kind == Complex || b.kind == Complex {
564 let float_bits_a = to_float_bits(a);
565 let float_bits_b = to_float_bits(b);
566 let bits = float_bits_a.max(float_bits_b);
567 return if bits <= 32 {
568 "num_complex::Complex<f32>"
569 } else {
570 "num_complex::Complex<f64>"
571 };
572 }
573
574 if a.kind == Float || b.kind == Float {
576 let float_bits_a = to_float_bits(a);
577 let float_bits_b = to_float_bits(b);
578 let bits = float_bits_a.max(float_bits_b);
579 return if bits <= 32 { "f32" } else { "f64" };
580 }
581
582 match (a.kind, b.kind) {
584 (Unsigned, Unsigned) => {
585 let bits = a.bits.max(b.bits);
586 uint_type(bits)
587 }
588 (Signed, Signed) => {
589 let bits = a.bits.max(b.bits);
590 int_type(bits)
591 }
592 (Unsigned, Signed) | (Signed, Unsigned) => {
593 let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
594 if u.bits < s.bits {
597 int_type(s.bits)
599 } else {
600 let needed = u.bits.max(s.bits) * 2;
602 if needed <= 128 {
603 int_type(needed)
604 } else {
605 ""
612 }
613 }
614 }
615 _ => "f64", }
617}
618
619const fn to_float_bits(r: TypeRank) -> u32 {
621 match r.kind {
622 TypeKind::Bool => 32,
623 TypeKind::Unsigned | TypeKind::Signed => {
624 if r.bits <= 16 { 32 } else { 64 }
627 }
628 TypeKind::Float => r.bits,
629 TypeKind::Complex => r.bits,
630 }
631}
632
/// Unsigned integer type name for a bit width.
///
/// Unrecognized widths fall back to `"u64"`.
const fn uint_type(bits: u32) -> &'static str {
    match bits {
        128 => "u128",
        64 => "u64",
        32 => "u32",
        16 => "u16",
        8 => "u8",
        _ => "u64",
    }
}
643
/// Signed integer type name for a bit width.
///
/// Unrecognized widths fall back to `"i64"`.
const fn int_type(bits: u32) -> &'static str {
    match bits {
        128 => "i128",
        64 => "i64",
        32 => "i32",
        16 => "i16",
        8 => "i8",
        _ => "i64",
    }
}
654
655const fn rank_to_type(r: TypeRank) -> &'static str {
656 match r.kind {
657 TypeKind::Bool => "bool",
658 TypeKind::Unsigned => uint_type(r.bits),
659 TypeKind::Signed => int_type(r.bits),
660 TypeKind::Float => {
661 if r.bits <= 16 {
662 "half::f16"
663 } else if r.bits <= 32 {
664 "f32"
665 } else {
666 "f64"
667 }
668 }
669 TypeKind::Complex => {
670 if r.bits <= 32 {
671 "num_complex::Complex<f32>"
672 } else {
673 "num_complex::Complex<f64>"
674 }
675 }
676 }
677}