1extern crate proc_macro;
9
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::{Data, DeriveInput, Fields, parse_macro_input};
13
14#[proc_macro_derive(FerrayRecord)]
29pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
30 let input = parse_macro_input!(input as DeriveInput);
31 match impl_ferray_record(&input) {
32 Ok(ts) => ts.into(),
33 Err(e) => e.to_compile_error().into(),
34 }
35}
36
37fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
38 let name = &input.ident;
39
40 let has_repr_c = input.attrs.iter().any(|attr| {
42 if !attr.path().is_ident("repr") {
43 return false;
44 }
45 let mut found = false;
46 let _ = attr.parse_nested_meta(|meta| {
47 if meta.path.is_ident("C") {
48 found = true;
49 }
50 Ok(())
51 });
52 found
53 });
54
55 if !has_repr_c {
56 return Err(syn::Error::new_spanned(
57 &input.ident,
58 "FerrayRecord requires #[repr(C)] on the struct",
59 ));
60 }
61
62 let fields = match &input.data {
64 Data::Struct(data_struct) => match &data_struct.fields {
65 Fields::Named(named) => &named.named,
66 _ => {
67 return Err(syn::Error::new_spanned(
68 &input.ident,
69 "FerrayRecord only supports structs with named fields",
70 ));
71 }
72 },
73 _ => {
74 return Err(syn::Error::new_spanned(
75 &input.ident,
76 "FerrayRecord can only be derived for structs",
77 ));
78 }
79 };
80
81 let field_count = fields.len();
82 let mut field_descriptors = Vec::with_capacity(field_count);
83
84 for field in fields.iter() {
85 let field_name = field.ident.as_ref().unwrap();
86 let field_name_str = field_name.to_string();
87 let field_ty = &field.ty;
88
89 field_descriptors.push(quote! {
90 ferray_core::record::FieldDescriptor {
91 name: #field_name_str,
92 dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
93 offset: std::mem::offset_of!(#name, #field_name),
94 size: std::mem::size_of::<#field_ty>(),
95 }
96 });
97 }
98
99 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
100
101 let expanded = quote! {
102 unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
103 fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
104 static FIELDS: std::sync::LazyLock<Vec<ferray_core::record::FieldDescriptor>> =
105 std::sync::LazyLock::new(|| {
106 vec![
107 #(#field_descriptors),*
108 ]
109 });
110 &FIELDS
111 }
112
113 fn record_size() -> usize {
114 std::mem::size_of::<#name>()
115 }
116 }
117 };
118
119 Ok(expanded)
120}
121
122#[proc_macro]
144pub fn s(input: TokenStream) -> TokenStream {
145 let input2: proc_macro2::TokenStream = input.into();
146 let expanded = impl_s_macro(input2);
147 match expanded {
148 Ok(ts) => ts.into(),
149 Err(e) => e.to_compile_error().into(),
150 }
151}
152
153fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
154 let input_str = input.to_string();
169
170 if input_str.trim().is_empty() {
172 return Ok(quote! {
173 ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
174 });
175 }
176
177 let components = split_top_level_commas(&input_str);
179 let mut elems = Vec::new();
180
181 for component in &components {
182 let trimmed = component.trim();
183 if trimmed.is_empty() {
184 continue;
185 }
186 elems.push(parse_slice_component(trimmed)?);
187 }
188
189 Ok(quote! {
190 vec![#(#elems),*]
191 })
192}
193
/// Split `s` at every comma that is not nested inside `()`, `[]` or `{}`.
///
/// Pieces are returned untrimmed. A comma always closes a piece, even an
/// empty one. The text after the last top-level comma becomes a final piece
/// only if it is non-empty, so a trailing comma does not add an empty piece.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut nesting: i32 = 0;
    let mut piece_start = 0usize;

    for (pos, ch) in s.char_indices() {
        match ch {
            '(' | '[' | '{' => nesting += 1,
            ')' | ']' | '}' => nesting -= 1,
            ',' if nesting == 0 => {
                pieces.push(s[piece_start..pos].to_string());
                // ',' is a single byte, so pos + 1 is a char boundary.
                piece_start = pos + 1;
            }
            _ => {}
        }
    }

    if piece_start < s.len() {
        pieces.push(s[piece_start..].to_string());
    }
    pieces
}
223
/// Byte index of the last `;` in `s` that is not nested inside `()`, `[]`
/// or `{}`, or `None` if there is no such semicolon.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    s.char_indices()
        // Track bracket depth alongside each character. A ';' does not
        // change the depth, so the depth after it equals the depth before.
        .scan(0i32, |depth, (pos, ch)| {
            match ch {
                '(' | '[' | '{' => *depth += 1,
                ')' | ']' | '}' => *depth -= 1,
                _ => {}
            }
            Some((pos, ch, *depth))
        })
        .filter(|&(_, ch, depth)| ch == ';' && depth == 0)
        .map(|(pos, _, _)| pos)
        .last()
}
238
239fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
240 let trimmed = s.trim();
241
242 let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
244 let (rp, sp) = trimmed.split_at(idx);
245 (rp.trim(), Some(sp[1..].trim()))
246 } else {
247 (trimmed, None)
248 };
249
250 let step_expr = if let Some(step_str) = step_part {
251 let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
252 syn::Error::new(
253 proc_macro2::Span::call_site(),
254 format!("invalid step expression: {step_str}"),
255 )
256 })?;
257 quote! { #step_tokens }
258 } else {
259 quote! { 1isize }
260 };
261
262 if range_part == ".." {
264 return Ok(quote! {
266 ferray_core::dtype::SliceInfoElem::Slice {
267 start: 0,
268 end: ::core::option::Option::None,
269 step: #step_expr,
270 }
271 });
272 }
273
274 if let Some(rest) = range_part.strip_prefix("..") {
275 let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
277 syn::Error::new(
278 proc_macro2::Span::call_site(),
279 format!("invalid end expression: {rest}"),
280 )
281 })?;
282 return Ok(quote! {
283 ferray_core::dtype::SliceInfoElem::Slice {
284 start: 0,
285 end: ::core::option::Option::Some(#end_tokens),
286 step: #step_expr,
287 }
288 });
289 }
290
291 if let Some(idx) = range_part.find("..") {
292 let start_str = range_part[..idx].trim();
293 let end_str = range_part[idx + 2..].trim();
294
295 let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
296 syn::Error::new(
297 proc_macro2::Span::call_site(),
298 format!("invalid start expression: {start_str}"),
299 )
300 })?;
301
302 if end_str.is_empty() {
303 return Ok(quote! {
305 ferray_core::dtype::SliceInfoElem::Slice {
306 start: #start_tokens,
307 end: ::core::option::Option::None,
308 step: #step_expr,
309 }
310 });
311 }
312
313 let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
314 syn::Error::new(
315 proc_macro2::Span::call_site(),
316 format!("invalid end expression: {end_str}"),
317 )
318 })?;
319
320 return Ok(quote! {
321 ferray_core::dtype::SliceInfoElem::Slice {
322 start: #start_tokens,
323 end: ::core::option::Option::Some(#end_tokens),
324 step: #step_expr,
325 }
326 });
327 }
328
329 if step_part.is_some() {
331 return Err(syn::Error::new(
332 proc_macro2::Span::call_site(),
333 format!("step ';' is not valid for integer indices: {trimmed}"),
334 ));
335 }
336
337 let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
338 syn::Error::new(
339 proc_macro2::Span::call_site(),
340 format!("invalid index expression: {range_part}"),
341 )
342 })?;
343
344 Ok(quote! {
345 ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
346 })
347}
348
349#[proc_macro]
365pub fn promoted_type(input: TokenStream) -> TokenStream {
366 let input2: proc_macro2::TokenStream = input.into();
367 match impl_promoted_type(input2) {
368 Ok(ts) => ts.into(),
369 Err(e) => e.to_compile_error().into(),
370 }
371}
372
373fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
374 let input_str = input.to_string();
375 let parts: Vec<&str> = input_str.split(',').map(|s| s.trim()).collect();
376
377 if parts.len() != 2 {
378 return Err(syn::Error::new(
379 proc_macro2::Span::call_site(),
380 "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
381 ));
382 }
383
384 let t1 = normalize_type(parts[0]);
385 let t2 = normalize_type(parts[1]);
386
387 let result = promote_types_static(&t1, &t2).ok_or_else(|| {
388 syn::Error::new(
389 proc_macro2::Span::call_site(),
390 format!("cannot promote types: {t1} and {t2}"),
391 )
392 })?;
393
394 let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
395 syn::Error::new(
396 proc_macro2::Span::call_site(),
397 format!("internal error: could not parse result type: {result}"),
398 )
399 })?;
400
401 Ok(result_tokens)
402}
403
/// Canonicalize a stringified type name for lookup in the promotion table.
///
/// Stringifying a token stream puts whitespace between tokens
/// (`Complex < f32 >`). Removing every whitespace character, not just ASCII
/// spaces, produces the compact spelling (`Complex<f32>`) that `type_rank`
/// matches against, however the tokens were printed.
fn normalize_type(s: &str) -> String {
    s.chars().filter(|c| !c.is_whitespace()).collect()
}
408
409fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
413 let ra = type_rank(a)?;
427 let rb = type_rank(b)?;
428
429 Some(promote_ranks(ra, rb))
430}
431
/// Category of a numeric element type in the promotion table.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TypeKind {
    /// `bool`.
    Bool,
    /// `u8` through `u128`.
    Unsigned,
    /// `i8` through `i128`.
    Signed,
    /// `f16`, `bf16`, `f32` and `f64`.
    Float,
    /// `num_complex::Complex<f32>` / `Complex<f64>`.
    Complex,
}
440
441#[derive(Clone, Copy)]
442struct TypeRank {
443 kind: TypeKind,
444 bits: u32,
446}
447
448fn type_rank(s: &str) -> Option<TypeRank> {
449 let result = match s {
450 "bool" => TypeRank {
451 kind: TypeKind::Bool,
452 bits: 1,
453 },
454 "u8" => TypeRank {
455 kind: TypeKind::Unsigned,
456 bits: 8,
457 },
458 "u16" => TypeRank {
459 kind: TypeKind::Unsigned,
460 bits: 16,
461 },
462 "u32" => TypeRank {
463 kind: TypeKind::Unsigned,
464 bits: 32,
465 },
466 "u64" => TypeRank {
467 kind: TypeKind::Unsigned,
468 bits: 64,
469 },
470 "u128" => TypeRank {
471 kind: TypeKind::Unsigned,
472 bits: 128,
473 },
474 "i8" => TypeRank {
475 kind: TypeKind::Signed,
476 bits: 8,
477 },
478 "i16" => TypeRank {
479 kind: TypeKind::Signed,
480 bits: 16,
481 },
482 "i32" => TypeRank {
483 kind: TypeKind::Signed,
484 bits: 32,
485 },
486 "i64" => TypeRank {
487 kind: TypeKind::Signed,
488 bits: 64,
489 },
490 "i128" => TypeRank {
491 kind: TypeKind::Signed,
492 bits: 128,
493 },
494 "f32" => TypeRank {
495 kind: TypeKind::Float,
496 bits: 32,
497 },
498 "f64" => TypeRank {
499 kind: TypeKind::Float,
500 bits: 64,
501 },
502 "Complex<f32>" | "num_complex::Complex<f32>" => TypeRank {
503 kind: TypeKind::Complex,
504 bits: 32,
505 },
506 "Complex<f64>" | "num_complex::Complex<f64>" => TypeRank {
507 kind: TypeKind::Complex,
508 bits: 64,
509 },
510 "f16" | "half::f16" => TypeRank {
511 kind: TypeKind::Float,
512 bits: 16,
513 },
514 "bf16" | "half::bf16" => TypeRank {
515 kind: TypeKind::Float,
516 bits: 16,
517 },
518 _ => return None,
519 };
520 Some(result)
521}
522
523fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
524 use TypeKind::*;
525
526 if a.kind == b.kind && a.bits == b.bits {
528 return rank_to_type(a);
529 }
530
531 if a.kind == Bool {
533 return rank_to_type(b);
534 }
535 if b.kind == Bool {
536 return rank_to_type(a);
537 }
538
539 if a.kind == Complex || b.kind == Complex {
541 let float_bits_a = to_float_bits(a);
542 let float_bits_b = to_float_bits(b);
543 let bits = float_bits_a.max(float_bits_b);
544 return if bits <= 32 {
545 "num_complex::Complex<f32>"
546 } else {
547 "num_complex::Complex<f64>"
548 };
549 }
550
551 if a.kind == Float || b.kind == Float {
553 let float_bits_a = to_float_bits(a);
554 let float_bits_b = to_float_bits(b);
555 let bits = float_bits_a.max(float_bits_b);
556 return if bits <= 32 { "f32" } else { "f64" };
557 }
558
559 match (a.kind, b.kind) {
561 (Unsigned, Unsigned) => {
562 let bits = a.bits.max(b.bits);
563 uint_type(bits)
564 }
565 (Signed, Signed) => {
566 let bits = a.bits.max(b.bits);
567 int_type(bits)
568 }
569 (Unsigned, Signed) | (Signed, Unsigned) => {
570 let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
571 if u.bits < s.bits {
574 int_type(s.bits)
576 } else {
577 let needed = u.bits.max(s.bits) * 2;
579 if needed <= 128 {
580 int_type(needed)
581 } else {
582 "f64"
584 }
585 }
586 }
587 _ => "f64", }
589}
590
591fn to_float_bits(r: TypeRank) -> u32 {
593 match r.kind {
594 TypeKind::Bool => 32,
595 TypeKind::Unsigned | TypeKind::Signed => {
596 if r.bits <= 16 { 32 } else { 64 }
599 }
600 TypeKind::Float => r.bits,
601 TypeKind::Complex => r.bits,
602 }
603}
604
/// Name of the unsigned integer type with the given bit width.
///
/// Any width other than 8, 16, 32, 64 or 128 falls back to `"u64"`.
fn uint_type(bits: u32) -> &'static str {
    const UNSIGNED: [(u32, &str); 5] = [
        (8, "u8"),
        (16, "u16"),
        (32, "u32"),
        (64, "u64"),
        (128, "u128"),
    ];
    UNSIGNED
        .iter()
        .find(|&&(width, _)| width == bits)
        .map_or("u64", |&(_, name)| name)
}
615
/// Name of the signed integer type with the given bit width.
///
/// Any width other than 8, 16, 32, 64 or 128 falls back to `"i64"`.
fn int_type(bits: u32) -> &'static str {
    const SIGNED: [(u32, &str); 5] = [
        (8, "i8"),
        (16, "i16"),
        (32, "i32"),
        (64, "i64"),
        (128, "i128"),
    ];
    SIGNED
        .iter()
        .find(|&&(width, _)| width == bits)
        .map_or("i64", |&(_, name)| name)
}
626
627fn rank_to_type(r: TypeRank) -> &'static str {
628 match r.kind {
629 TypeKind::Bool => "bool",
630 TypeKind::Unsigned => uint_type(r.bits),
631 TypeKind::Signed => int_type(r.bits),
632 TypeKind::Float => {
633 if r.bits <= 16 {
634 "half::f16"
635 } else if r.bits <= 32 {
636 "f32"
637 } else {
638 "f64"
639 }
640 }
641 TypeKind::Complex => {
642 if r.bits <= 32 {
643 "num_complex::Complex<f32>"
644 } else {
645 "num_complex::Complex<f64>"
646 }
647 }
648 }
649}