1extern crate proc_macro;
9
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::{Data, DeriveInput, Fields, parse_macro_input};
13
14#[proc_macro_derive(FerrayRecord)]
29pub fn derive_ferray_record(input: TokenStream) -> TokenStream {
30 let input = parse_macro_input!(input as DeriveInput);
31 match impl_ferray_record(&input) {
32 Ok(ts) => ts.into(),
33 Err(e) => e.to_compile_error().into(),
34 }
35}
36
37fn impl_ferray_record(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
38 let name = &input.ident;
39
40 let has_repr_c = input.attrs.iter().any(|attr| {
42 if !attr.path().is_ident("repr") {
43 return false;
44 }
45 let mut found = false;
46 let _ = attr.parse_nested_meta(|meta| {
47 if meta.path.is_ident("C") {
48 found = true;
49 }
50 Ok(())
51 });
52 found
53 });
54
55 if !has_repr_c {
56 return Err(syn::Error::new_spanned(
57 &input.ident,
58 "FerrayRecord requires #[repr(C)] on the struct",
59 ));
60 }
61
62 let fields = match &input.data {
64 Data::Struct(data_struct) => match &data_struct.fields {
65 Fields::Named(named) => &named.named,
66 _ => {
67 return Err(syn::Error::new_spanned(
68 &input.ident,
69 "FerrayRecord only supports structs with named fields",
70 ));
71 }
72 },
73 _ => {
74 return Err(syn::Error::new_spanned(
75 &input.ident,
76 "FerrayRecord can only be derived for structs",
77 ));
78 }
79 };
80
81 let field_count = fields.len();
82 let mut field_descriptors = Vec::with_capacity(field_count);
83
84 for field in fields.iter() {
85 let field_name = field.ident.as_ref().unwrap();
86 let field_name_str = field_name.to_string();
87 let field_ty = &field.ty;
88
89 field_descriptors.push(quote! {
90 ferray_core::record::FieldDescriptor {
91 name: #field_name_str,
92 dtype: <#field_ty as ferray_core::dtype::Element>::dtype(),
93 offset: std::mem::offset_of!(#name, #field_name),
94 size: std::mem::size_of::<#field_ty>(),
95 }
96 });
97 }
98
99 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
100
101 let expanded = quote! {
102 unsafe impl #impl_generics ferray_core::record::FerrayRecord for #name #ty_generics #where_clause {
103 fn field_descriptors() -> &'static [ferray_core::record::FieldDescriptor] {
104 static FIELDS: std::sync::LazyLock<Vec<ferray_core::record::FieldDescriptor>> =
112 std::sync::LazyLock::new(|| {
113 vec![
114 #(#field_descriptors),*
115 ]
116 });
117 &FIELDS
118 }
119
120 fn record_size() -> usize {
121 std::mem::size_of::<#name>()
122 }
123 }
124 };
125
126 Ok(expanded)
127}
128
129#[proc_macro]
151pub fn s(input: TokenStream) -> TokenStream {
152 let input2: proc_macro2::TokenStream = input.into();
153 let expanded = impl_s_macro(input2);
154 match expanded {
155 Ok(ts) => ts.into(),
156 Err(e) => e.to_compile_error().into(),
157 }
158}
159
160fn impl_s_macro(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
161 let input_str = input.to_string();
176
177 if input_str.trim().is_empty() {
179 return Ok(quote! {
180 ::std::vec::Vec::<ferray_core::dtype::SliceInfoElem>::new()
181 });
182 }
183
184 let components = split_top_level_commas(&input_str);
186 let mut elems = Vec::new();
187
188 for component in &components {
189 let trimmed = component.trim();
190 if trimmed.is_empty() {
191 continue;
192 }
193 elems.push(parse_slice_component(trimmed)?);
194 }
195
196 Ok(quote! {
197 vec![#(#elems),*]
198 })
199}
200
/// Splits `s` on commas that are not nested inside `()`, `[]` or `{}`.
///
/// Pieces are returned untrimmed. Empty pieces between commas are kept, but a
/// final empty piece (after a trailing comma, or for empty input) is not emitted.
fn split_top_level_commas(s: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut piece = String::new();
    let mut nesting: i32 = 0;

    for c in s.chars() {
        if c == ',' && nesting == 0 {
            pieces.push(std::mem::take(&mut piece));
            continue;
        }
        match c {
            '(' | '[' | '{' => nesting += 1,
            ')' | ']' | '}' => nesting -= 1,
            _ => {}
        }
        piece.push(c);
    }

    if !piece.is_empty() {
        pieces.push(piece);
    }
    pieces
}
230
/// Byte index of the last `;` in `s` that is not nested inside `()`, `[]` or `{}`,
/// or `None` if there is no such semicolon.
fn rfind_top_level_semicolon(s: &str) -> Option<usize> {
    let (_, last) = s
        .char_indices()
        .fold((0i32, None), |(depth, last), (pos, c)| match c {
            '(' | '[' | '{' => (depth + 1, last),
            ')' | ']' | '}' => (depth - 1, last),
            ';' if depth == 0 => (depth, Some(pos)),
            _ => (depth, last),
        });
    last
}
245
246fn parse_slice_component(s: &str) -> syn::Result<proc_macro2::TokenStream> {
247 let trimmed = s.trim();
248
249 let (range_part, step_part) = if let Some(idx) = rfind_top_level_semicolon(trimmed) {
251 let (rp, sp) = trimmed.split_at(idx);
252 (rp.trim(), Some(sp[1..].trim()))
253 } else {
254 (trimmed, None)
255 };
256
257 let step_expr = if let Some(step_str) = step_part {
258 let step_tokens: proc_macro2::TokenStream = step_str.parse().map_err(|_| {
259 syn::Error::new(
260 proc_macro2::Span::call_site(),
261 format!("invalid step expression: {step_str}"),
262 )
263 })?;
264 quote! { #step_tokens }
265 } else {
266 quote! { 1isize }
267 };
268
269 if range_part == ".." {
271 return Ok(quote! {
273 ferray_core::dtype::SliceInfoElem::Slice {
274 start: 0,
275 end: ::core::option::Option::None,
276 step: #step_expr,
277 }
278 });
279 }
280
281 if let Some(rest) = range_part.strip_prefix("..") {
282 let end_tokens: proc_macro2::TokenStream = rest.parse().map_err(|_| {
284 syn::Error::new(
285 proc_macro2::Span::call_site(),
286 format!("invalid end expression: {rest}"),
287 )
288 })?;
289 return Ok(quote! {
290 ferray_core::dtype::SliceInfoElem::Slice {
291 start: 0,
292 end: ::core::option::Option::Some(#end_tokens),
293 step: #step_expr,
294 }
295 });
296 }
297
298 if let Some(idx) = range_part.find("..") {
299 let start_str = range_part[..idx].trim();
300 let end_str = range_part[idx + 2..].trim();
301
302 let start_tokens: proc_macro2::TokenStream = start_str.parse().map_err(|_| {
303 syn::Error::new(
304 proc_macro2::Span::call_site(),
305 format!("invalid start expression: {start_str}"),
306 )
307 })?;
308
309 if end_str.is_empty() {
310 return Ok(quote! {
312 ferray_core::dtype::SliceInfoElem::Slice {
313 start: #start_tokens,
314 end: ::core::option::Option::None,
315 step: #step_expr,
316 }
317 });
318 }
319
320 let end_tokens: proc_macro2::TokenStream = end_str.parse().map_err(|_| {
321 syn::Error::new(
322 proc_macro2::Span::call_site(),
323 format!("invalid end expression: {end_str}"),
324 )
325 })?;
326
327 return Ok(quote! {
328 ferray_core::dtype::SliceInfoElem::Slice {
329 start: #start_tokens,
330 end: ::core::option::Option::Some(#end_tokens),
331 step: #step_expr,
332 }
333 });
334 }
335
336 if step_part.is_some() {
338 return Err(syn::Error::new(
339 proc_macro2::Span::call_site(),
340 format!("step ';' is not valid for integer indices: {trimmed}"),
341 ));
342 }
343
344 let idx_tokens: proc_macro2::TokenStream = range_part.parse().map_err(|_| {
345 syn::Error::new(
346 proc_macro2::Span::call_site(),
347 format!("invalid index expression: {range_part}"),
348 )
349 })?;
350
351 Ok(quote! {
352 ferray_core::dtype::SliceInfoElem::Index(#idx_tokens)
353 })
354}
355
356#[proc_macro]
372pub fn promoted_type(input: TokenStream) -> TokenStream {
373 let input2: proc_macro2::TokenStream = input.into();
374 match impl_promoted_type(input2) {
375 Ok(ts) => ts.into(),
376 Err(e) => e.to_compile_error().into(),
377 }
378}
379
380fn impl_promoted_type(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
381 let input_str = input.to_string();
382 let parts: Vec<&str> = input_str.split(',').map(|s| s.trim()).collect();
383
384 if parts.len() != 2 {
385 return Err(syn::Error::new(
386 proc_macro2::Span::call_site(),
387 "promoted_type! expects exactly two type arguments: promoted_type!(T1, T2)",
388 ));
389 }
390
391 let t1 = normalize_type(parts[0]);
392 let t2 = normalize_type(parts[1]);
393
394 let result = promote_types_static(&t1, &t2).ok_or_else(|| {
395 syn::Error::new(
396 proc_macro2::Span::call_site(),
397 format!("cannot promote types: {t1} and {t2}"),
398 )
399 })?;
400
401 let result_tokens: proc_macro2::TokenStream = result.parse().map_err(|_| {
402 syn::Error::new(
403 proc_macro2::Span::call_site(),
404 format!("internal error: could not parse result type: {result}"),
405 )
406 })?;
407
408 Ok(result_tokens)
409}
410
/// Canonicalizes a stringified type for lookup in the promotion table.
///
/// All whitespace is removed, because `TokenStream` rendering inserts spaces
/// (e.g. `Complex < f64 >`). A leading `::` is also stripped, so absolute paths
/// such as `::half::f16` match the same entries as `half::f16`.
fn normalize_type(s: &str) -> String {
    let compact: String = s.chars().filter(|c| !c.is_whitespace()).collect();
    match compact.strip_prefix("::") {
        Some(rest) => rest.to_owned(),
        None => compact,
    }
}
415
416fn promote_types_static(a: &str, b: &str) -> Option<&'static str> {
420 let ra = type_rank(a)?;
434 let rb = type_rank(b)?;
435
436 match promote_ranks(ra, rb) {
440 "" => None,
441 other => Some(other),
442 }
443}
444
/// Category of an element type understood by `promoted_type!`.
#[derive(PartialEq, Eq, Clone, Copy)]
enum TypeKind {
    /// `bool`.
    Bool,
    /// Unsigned integers, `u8` through `u128`.
    Unsigned,
    /// Signed integers, `i8` through `i128`.
    Signed,
    /// Real floating-point types: `f16`, `bf16`, `f32`, `f64`.
    Float,
    /// `Complex<f32>` / `Complex<f64>`.
    Complex,
}
453
454#[derive(Clone, Copy)]
455struct TypeRank {
456 kind: TypeKind,
457 bits: u32,
459}
460
461fn type_rank(s: &str) -> Option<TypeRank> {
462 let result = match s {
463 "bool" => TypeRank {
464 kind: TypeKind::Bool,
465 bits: 1,
466 },
467 "u8" => TypeRank {
468 kind: TypeKind::Unsigned,
469 bits: 8,
470 },
471 "u16" => TypeRank {
472 kind: TypeKind::Unsigned,
473 bits: 16,
474 },
475 "u32" => TypeRank {
476 kind: TypeKind::Unsigned,
477 bits: 32,
478 },
479 "u64" => TypeRank {
480 kind: TypeKind::Unsigned,
481 bits: 64,
482 },
483 "u128" => TypeRank {
484 kind: TypeKind::Unsigned,
485 bits: 128,
486 },
487 "i8" => TypeRank {
488 kind: TypeKind::Signed,
489 bits: 8,
490 },
491 "i16" => TypeRank {
492 kind: TypeKind::Signed,
493 bits: 16,
494 },
495 "i32" => TypeRank {
496 kind: TypeKind::Signed,
497 bits: 32,
498 },
499 "i64" => TypeRank {
500 kind: TypeKind::Signed,
501 bits: 64,
502 },
503 "i128" => TypeRank {
504 kind: TypeKind::Signed,
505 bits: 128,
506 },
507 "f32" => TypeRank {
508 kind: TypeKind::Float,
509 bits: 32,
510 },
511 "f64" => TypeRank {
512 kind: TypeKind::Float,
513 bits: 64,
514 },
515 "Complex<f32>" | "num_complex::Complex<f32>" => TypeRank {
516 kind: TypeKind::Complex,
517 bits: 32,
518 },
519 "Complex<f64>" | "num_complex::Complex<f64>" => TypeRank {
520 kind: TypeKind::Complex,
521 bits: 64,
522 },
523 "f16" | "half::f16" => TypeRank {
524 kind: TypeKind::Float,
525 bits: 16,
526 },
527 "bf16" | "half::bf16" => TypeRank {
528 kind: TypeKind::Float,
529 bits: 16,
530 },
531 _ => return None,
532 };
533 Some(result)
534}
535
536fn promote_ranks(a: TypeRank, b: TypeRank) -> &'static str {
537 use TypeKind::*;
538
539 if a.kind == b.kind && a.bits == b.bits {
541 return rank_to_type(a);
542 }
543
544 if a.kind == Bool {
546 return rank_to_type(b);
547 }
548 if b.kind == Bool {
549 return rank_to_type(a);
550 }
551
552 if a.kind == Complex || b.kind == Complex {
554 let float_bits_a = to_float_bits(a);
555 let float_bits_b = to_float_bits(b);
556 let bits = float_bits_a.max(float_bits_b);
557 return if bits <= 32 {
558 "num_complex::Complex<f32>"
559 } else {
560 "num_complex::Complex<f64>"
561 };
562 }
563
564 if a.kind == Float || b.kind == Float {
566 let float_bits_a = to_float_bits(a);
567 let float_bits_b = to_float_bits(b);
568 let bits = float_bits_a.max(float_bits_b);
569 return if bits <= 32 { "f32" } else { "f64" };
570 }
571
572 match (a.kind, b.kind) {
574 (Unsigned, Unsigned) => {
575 let bits = a.bits.max(b.bits);
576 uint_type(bits)
577 }
578 (Signed, Signed) => {
579 let bits = a.bits.max(b.bits);
580 int_type(bits)
581 }
582 (Unsigned, Signed) | (Signed, Unsigned) => {
583 let (u, s) = if a.kind == Unsigned { (a, b) } else { (b, a) };
584 if u.bits < s.bits {
587 int_type(s.bits)
589 } else {
590 let needed = u.bits.max(s.bits) * 2;
592 if needed <= 128 {
593 int_type(needed)
594 } else {
595 ""
602 }
603 }
604 }
605 _ => "f64", }
607}
608
609fn to_float_bits(r: TypeRank) -> u32 {
611 match r.kind {
612 TypeKind::Bool => 32,
613 TypeKind::Unsigned | TypeKind::Signed => {
614 if r.bits <= 16 { 32 } else { 64 }
617 }
618 TypeKind::Float => r.bits,
619 TypeKind::Complex => r.bits,
620 }
621}
622
/// Name of the unsigned integer type with the given width in bits.
///
/// Widths other than 8, 16, 32, 64 and 128 fall back to `"u64"`.
fn uint_type(bits: u32) -> &'static str {
    const NAMES: [(u32, &str); 5] = [
        (8, "u8"),
        (16, "u16"),
        (32, "u32"),
        (64, "u64"),
        (128, "u128"),
    ];
    NAMES
        .iter()
        .find(|&&(width, _)| width == bits)
        .map_or("u64", |&(_, name)| name)
}
633
/// Name of the signed integer type with the given width in bits.
///
/// Widths other than 8, 16, 32, 64 and 128 fall back to `"i64"`.
fn int_type(bits: u32) -> &'static str {
    const NAMES: [(u32, &str); 5] = [
        (8, "i8"),
        (16, "i16"),
        (32, "i32"),
        (64, "i64"),
        (128, "i128"),
    ];
    NAMES
        .iter()
        .find(|&&(width, _)| width == bits)
        .map_or("i64", |&(_, name)| name)
}
644
645fn rank_to_type(r: TypeRank) -> &'static str {
646 match r.kind {
647 TypeKind::Bool => "bool",
648 TypeKind::Unsigned => uint_type(r.bits),
649 TypeKind::Signed => int_type(r.bits),
650 TypeKind::Float => {
651 if r.bits <= 16 {
652 "half::f16"
653 } else if r.bits <= 32 {
654 "f32"
655 } else {
656 "f64"
657 }
658 }
659 TypeKind::Complex => {
660 if r.bits <= 32 {
661 "num_complex::Complex<f32>"
662 } else {
663 "num_complex::Complex<f64>"
664 }
665 }
666 }
667}