1use proc_macro::TokenStream;
102use proc_macro2::TokenStream as TokenStream2;
103use quote::{format_ident, quote};
104use std::collections::HashSet;
105use syn::parse::{Parse, ParseStream};
106use syn::{
107 parse_macro_input, FnArg, Ident, ImplItem, ImplItemFn, Item, ItemFn, ItemImpl, ItemStruct, Pat,
108 ReturnType, Token, Type, Visibility,
109};
110
111struct RingExtension {
112 prefix: Option<String>,
113 items: Vec<Item>,
114}
115
116impl Parse for RingExtension {
117 fn parse(input: ParseStream) -> syn::Result<Self> {
118 let mut prefix = None;
119 let mut items = Vec::new();
120
121 while !input.is_empty() {
122 if input.peek(Ident) {
123 let ident: Ident = input.parse()?;
124 if ident == "prefix" {
125 let _: Token![:] = input.parse()?;
126 let lit: syn::LitStr = input.parse()?;
127 let _: Token![;] = input.parse()?;
128 prefix = Some(lit.value());
129 continue;
130 } else {
131 return Err(syn::Error::new(ident.span(), "expected 'prefix' or item"));
132 }
133 }
134 items.push(input.parse()?);
135 }
136
137 Ok(RingExtension { prefix, items })
138 }
139}
140
141#[proc_macro]
143pub fn ring_extension(input: TokenStream) -> TokenStream {
144 let module = parse_macro_input!(input as RingExtension);
145
146 let prefix = module.prefix.unwrap_or_default();
147 let prefix_underscore = if prefix.is_empty() {
148 String::new()
149 } else {
150 format!("{}_", prefix)
151 };
152
153 let mut structs_with_custom_new: HashSet<String> = HashSet::new();
154 let mut impl_methods: HashSet<(String, String)> = HashSet::new();
155
156 for item in &module.items {
157 if let Item::Impl(i) = item {
158 if let Type::Path(p) = &*i.self_ty {
159 let struct_name = p.path.segments.last().unwrap().ident.to_string();
160 for impl_item in &i.items {
161 if let ImplItem::Fn(method) = impl_item {
162 let method_name = method.sig.ident.to_string();
163 if method_name == "new" {
164 structs_with_custom_new.insert(struct_name.clone());
165 }
166 impl_methods.insert((struct_name.clone(), method_name));
167 }
168 }
169 }
170 }
171 }
172
173 let mut original_items = Vec::new();
174 let mut generated_code = Vec::new();
175 let mut registrations: Vec<(String, syn::Ident)> = Vec::new();
176
177 for item in module.items {
178 match item {
179 Item::Struct(s) => {
180 let has_custom_new = structs_with_custom_new.contains(&s.ident.to_string());
181 let (orig, generated, regs) =
182 process_struct(&s, &prefix_underscore, has_custom_new, &impl_methods);
183 original_items.push(orig);
184 generated_code.push(generated);
185 registrations.extend(regs);
186 }
187 Item::Impl(i) => {
188 let (orig, generated, regs) = process_impl(&i, &prefix_underscore);
189 original_items.push(orig);
190 generated_code.push(generated);
191 registrations.extend(regs);
192 }
193 Item::Fn(f) => {
194 let (orig, generated, regs) = process_function(&f, &prefix_underscore);
195 original_items.push(orig);
196 generated_code.push(generated);
197 registrations.extend(regs);
198 }
199 other => {
200 original_items.push(quote! { #other });
201 }
202 }
203 }
204
205 let libinit_entries: Vec<_> = registrations
206 .iter()
207 .map(|(name, fn_ident)| {
208 let name_bytes = format!("{}\0", name);
209 quote! { #name_bytes.as_bytes() => #fn_ident }
210 })
211 .collect();
212
213 let expanded = quote! {
214 #(#original_items)*
215 #(#generated_code)*
216
217 ring_libinit! {
218 #(#libinit_entries),*
219 }
220 };
221
222 expanded.into()
223}
224
225fn process_struct(
226 s: &ItemStruct,
227 prefix: &str,
228 has_custom_new: bool,
229 impl_methods: &HashSet<(String, String)>,
230) -> (TokenStream2, TokenStream2, Vec<(String, syn::Ident)>) {
231 let struct_name = &s.ident;
232 let struct_name_lower = struct_name.to_string().to_lowercase();
233 let type_const = format_ident!("{}_TYPE", struct_name.to_string().to_uppercase());
234 let type_const_str = format!("{}\0", struct_name);
235
236 let mut regs = Vec::new();
237
238 let delete_fn_name = format_ident!("ring_{}{}_delete", prefix, struct_name_lower);
239 let delete_ring_name = format!("{}{}_delete", prefix, struct_name_lower);
240 regs.push((delete_ring_name, delete_fn_name.clone()));
241
242 let new_code = if !has_custom_new {
243 let new_fn_name = format_ident!("ring_{}{}_new", prefix, struct_name_lower);
244 let new_ring_name = format!("{}{}_new", prefix, struct_name_lower);
245 regs.push((new_ring_name, new_fn_name.clone()));
246
247 quote! {
248 ring_func!(#new_fn_name, |p| {
249 ring_check_paracount!(p, 0);
250 let obj = Box::new(#struct_name::default());
251 ring_ret_cpointer!(p, Box::into_raw(obj), #type_const);
252 });
253 }
254 } else {
255 quote! {}
256 };
257
258 let mut accessors = Vec::new();
259 let struct_name_str = struct_name.to_string();
260
261 if let syn::Fields::Named(fields) = &s.fields {
262 for field in &fields.named {
263 if !matches!(field.vis, Visibility::Public(_)) {
264 continue;
265 }
266
267 let field_name = field.ident.as_ref().unwrap();
268 let field_name_str = field_name.to_string();
269 let field_type = &field.ty;
270
271 let getter_method = format!("get_{}", field_name_str);
272 let setter_method = format!("set_{}", field_name_str);
273
274 if !impl_methods.contains(&(struct_name_str.clone(), getter_method.clone()))
275 && !impl_methods.contains(&(struct_name_str.clone(), field_name_str.clone()))
276 {
277 let getter_fn =
278 format_ident!("ring_{}{}_get_{}", prefix, struct_name_lower, field_name);
279 let getter_name = format!("{}{}_get_{}", prefix, struct_name_lower, field_name);
280 regs.push((getter_name, getter_fn.clone()));
281
282 let getter_code = generate_field_getter(
283 &getter_fn,
284 struct_name,
285 &type_const,
286 field_name,
287 field_type,
288 );
289 accessors.push(getter_code);
290 }
291
292 if !impl_methods.contains(&(struct_name_str.clone(), setter_method)) {
293 let setter_fn =
294 format_ident!("ring_{}{}_set_{}", prefix, struct_name_lower, field_name);
295 let setter_name = format!("{}{}_set_{}", prefix, struct_name_lower, field_name);
296 regs.push((setter_name, setter_fn.clone()));
297
298 let setter_code = generate_field_setter(
299 &setter_fn,
300 struct_name,
301 &type_const,
302 field_name,
303 field_type,
304 );
305 accessors.push(setter_code);
306 }
307 }
308 }
309
310 let original = quote! { #s };
311
312 let generated = quote! {
313 const #type_const: &[u8] = #type_const_str.as_bytes();
314
315 #new_code
316
317 ring_func!(#delete_fn_name, |p| {
318 ring_check_paracount!(p, 1);
319 ring_check_cpointer!(p, 1);
320 let ptr = ring_get_cpointer!(p, 1, #type_const);
321 if !ptr.is_null() {
322 unsafe { let _ = Box::from_raw(ptr as *mut #struct_name); }
323 }
324 });
325
326 #(#accessors)*
327 };
328
329 (original, generated, regs)
330}
331
332fn process_impl(
333 i: &ItemImpl,
334 prefix: &str,
335) -> (TokenStream2, TokenStream2, Vec<(String, syn::Ident)>) {
336 let struct_name = match &*i.self_ty {
337 Type::Path(p) => p.path.segments.last().unwrap().ident.clone(),
338 _ => return (quote! { #i }, quote! {}, vec![]),
339 };
340
341 let struct_name_lower = struct_name.to_string().to_lowercase();
342 let type_const = format_ident!("{}_TYPE", struct_name.to_string().to_uppercase());
343
344 let mut regs = Vec::new();
345 let mut method_wrappers = Vec::new();
346
347 for item in &i.items {
348 if let ImplItem::Fn(method) = item {
349 if !matches!(method.vis, Visibility::Public(_)) {
350 continue;
351 }
352
353 let method_name = &method.sig.ident;
354 let method_name_str = method_name.to_string();
355
356 if method_name_str == "new" {
357 let (code, name, fn_ident) = generate_custom_new(
358 &struct_name,
359 &struct_name_lower,
360 &type_const,
361 method,
362 prefix,
363 );
364 method_wrappers.push(code);
365 regs.push((name, fn_ident));
366 continue;
367 }
368
369 let has_self = method
370 .sig
371 .inputs
372 .iter()
373 .any(|arg| matches!(arg, FnArg::Receiver(_)));
374
375 if has_self {
376 let (code, name, fn_ident) = generate_method(
377 &struct_name,
378 &struct_name_lower,
379 &type_const,
380 method,
381 prefix,
382 );
383 method_wrappers.push(code);
384 regs.push((name, fn_ident));
385 }
386 }
387 }
388
389 let original = quote! { #i };
390 let generated = quote! { #(#method_wrappers)* };
391
392 (original, generated, regs)
393}
394
395fn process_function(
396 f: &ItemFn,
397 prefix: &str,
398) -> (TokenStream2, TokenStream2, Vec<(String, syn::Ident)>) {
399 let fn_name = &f.sig.ident;
400 let ring_fn_name = format_ident!("ring_{}{}", prefix, fn_name);
401 let ring_name = format!("{}{}", prefix, fn_name);
402
403 let params: Vec<_> = f
404 .sig
405 .inputs
406 .iter()
407 .filter_map(|arg| {
408 if let FnArg::Typed(pat) = arg {
409 let name = if let Pat::Ident(ident) = &*pat.pat {
410 ident.ident.clone()
411 } else {
412 return None;
413 };
414 Some((name, (*pat.ty).clone()))
415 } else {
416 None
417 }
418 })
419 .collect();
420
421 let param_count = params.len();
422 let mut checks = Vec::new();
423 let mut gets = Vec::new();
424 let mut args = Vec::new();
425
426 for (i, (name, ty)) in params.iter().enumerate() {
427 let idx = (i + 1) as i32;
428 let type_str = quote!(#ty).to_string();
429
430 if is_number_type(&type_str) {
431 checks.push(quote! { ring_check_number!(p, #idx); });
432 let cast = get_number_cast(&type_str);
433 gets.push(quote! { let #name = ring_get_number!(p, #idx) as #cast; });
434 } else if is_string_type(&type_str) {
435 checks.push(quote! { ring_check_string!(p, #idx); });
436 gets.push(quote! { let #name = ring_get_string!(p, #idx); });
437 } else if type_str == "bool" {
438 checks.push(quote! { ring_check_number!(p, #idx); });
439 gets.push(quote! { let #name = ring_get_number!(p, #idx) != 0.0; });
440 } else {
441 checks.push(quote! { ring_check_number!(p, #idx); });
442 gets.push(quote! { let #name = ring_get_number!(p, #idx) as _; });
443 }
444 args.push(name.clone());
445 }
446
447 let param_count_i32 = param_count as i32;
448 let return_code = generate_return_code(&f.sig.output, quote! { #fn_name(#(#args),*) });
449
450 let original = quote! { #f };
451 let generated = quote! {
452 ring_func!(#ring_fn_name, |p| {
453 ring_check_paracount!(p, #param_count_i32);
454 #(#checks)*
455 #(#gets)*
456 #return_code
457 });
458 };
459
460 (original, generated, vec![(ring_name, ring_fn_name)])
461}
462
463fn generate_field_getter(
464 fn_name: &syn::Ident,
465 struct_name: &syn::Ident,
466 type_const: &syn::Ident,
467 field_name: &syn::Ident,
468 field_type: &Type,
469) -> TokenStream2 {
470 let type_str = quote!(#field_type).to_string();
471 let return_expr = if is_number_type(&type_str) {
472 quote! { ring_ret_number!(p, obj.#field_name as f64); }
473 } else if is_string_type(&type_str) {
474 quote! { ring_ret_string!(p, &obj.#field_name); }
475 } else if type_str == "bool" {
476 quote! { ring_ret_number!(p, if obj.#field_name { 1.0 } else { 0.0 }); }
477 } else {
478 quote! { ring_ret_number!(p, obj.#field_name as f64); }
479 };
480
481 quote! {
482 ring_func!(#fn_name, |p| {
483 ring_check_paracount!(p, 1);
484 ring_check_cpointer!(p, 1);
485 if let Some(obj) = ring_get_pointer!(p, 1, #struct_name, #type_const) {
486 #return_expr
487 } else {
488 ring_error!(p, concat!("Invalid ", stringify!(#struct_name), " pointer"));
489 }
490 });
491 }
492}
493
494fn generate_field_setter(
495 fn_name: &syn::Ident,
496 struct_name: &syn::Ident,
497 type_const: &syn::Ident,
498 field_name: &syn::Ident,
499 field_type: &Type,
500) -> TokenStream2 {
501 let type_str = quote!(#field_type).to_string();
502
503 let (check, set_expr) = if is_number_type(&type_str) {
504 let cast = get_number_cast(&type_str);
505 (
506 quote! { ring_check_number!(p, 2); },
507 quote! { obj.#field_name = ring_get_number!(p, 2) as #cast; },
508 )
509 } else if is_string_type(&type_str) {
510 (
511 quote! { ring_check_string!(p, 2); },
512 quote! { obj.#field_name = ring_get_string!(p, 2).to_string(); },
513 )
514 } else if type_str == "bool" {
515 (
516 quote! { ring_check_number!(p, 2); },
517 quote! { obj.#field_name = ring_get_number!(p, 2) != 0.0; },
518 )
519 } else {
520 (
521 quote! { ring_check_number!(p, 2); },
522 quote! { obj.#field_name = ring_get_number!(p, 2) as _; },
523 )
524 };
525
526 quote! {
527 ring_func!(#fn_name, |p| {
528 ring_check_paracount!(p, 2);
529 ring_check_cpointer!(p, 1);
530 #check
531 if let Some(obj) = ring_get_pointer!(p, 1, #struct_name, #type_const) {
532 #set_expr
533 } else {
534 ring_error!(p, concat!("Invalid ", stringify!(#struct_name), " pointer"));
535 }
536 });
537 }
538}
539
540fn generate_custom_new(
541 struct_name: &syn::Ident,
542 struct_name_lower: &str,
543 type_const: &syn::Ident,
544 method: &ImplItemFn,
545 prefix: &str,
546) -> (TokenStream2, String, syn::Ident) {
547 let fn_name = format_ident!("ring_{}{}_new", prefix, struct_name_lower);
548 let ring_name = format!("{}{}_new", prefix, struct_name_lower);
549
550 let params: Vec<_> = method
551 .sig
552 .inputs
553 .iter()
554 .filter_map(|arg| {
555 if let FnArg::Typed(pat) = arg {
556 let name = if let Pat::Ident(ident) = &*pat.pat {
557 ident.ident.clone()
558 } else {
559 return None;
560 };
561 Some((name, (*pat.ty).clone()))
562 } else {
563 None
564 }
565 })
566 .collect();
567
568 let param_count = params.len() as i32;
569 let mut checks = Vec::new();
570 let mut gets = Vec::new();
571 let mut args = Vec::new();
572
573 for (i, (name, ty)) in params.iter().enumerate() {
574 let idx = (i + 1) as i32;
575 let type_str = quote!(#ty).to_string();
576
577 if is_number_type(&type_str) {
578 checks.push(quote! { ring_check_number!(p, #idx); });
579 let cast = get_number_cast(&type_str);
580 gets.push(quote! { let #name = ring_get_number!(p, #idx) as #cast; });
581 } else if is_string_type(&type_str) {
582 checks.push(quote! { ring_check_string!(p, #idx); });
583 gets.push(quote! { let #name = ring_get_string!(p, #idx); });
584 } else if type_str == "bool" {
585 checks.push(quote! { ring_check_number!(p, #idx); });
586 gets.push(quote! { let #name = ring_get_number!(p, #idx) != 0.0; });
587 } else {
588 checks.push(quote! { ring_check_number!(p, #idx); });
589 gets.push(quote! { let #name = ring_get_number!(p, #idx) as _; });
590 }
591 args.push(name.clone());
592 }
593
594 let code = quote! {
595 ring_func!(#fn_name, |p| {
596 ring_check_paracount!(p, #param_count);
597 #(#checks)*
598 #(#gets)*
599 let obj = Box::new(#struct_name::new(#(#args),*));
600 ring_ret_cpointer!(p, Box::into_raw(obj), #type_const);
601 });
602 };
603
604 (code, ring_name, fn_name)
605}
606
607fn generate_method(
608 struct_name: &syn::Ident,
609 struct_name_lower: &str,
610 type_const: &syn::Ident,
611 method: &ImplItemFn,
612 prefix: &str,
613) -> (TokenStream2, String, syn::Ident) {
614 let method_name = &method.sig.ident;
615 let fn_name = format_ident!("ring_{}{}_{}", prefix, struct_name_lower, method_name);
616 let ring_name = format!("{}{}_{}", prefix, struct_name_lower, method_name);
617
618 let params: Vec<_> = method
619 .sig
620 .inputs
621 .iter()
622 .filter_map(|arg| {
623 if let FnArg::Typed(pat) = arg {
624 let name = if let Pat::Ident(ident) = &*pat.pat {
625 ident.ident.clone()
626 } else {
627 return None;
628 };
629 Some((name, (*pat.ty).clone()))
630 } else {
631 None
632 }
633 })
634 .collect();
635
636 let param_count = (params.len() + 1) as i32;
637 let mut checks = Vec::new();
638 let mut gets = Vec::new();
639 let mut args = Vec::new();
640
641 for (i, (name, ty)) in params.iter().enumerate() {
642 let idx = (i + 2) as i32;
643 let type_str = quote!(#ty).to_string();
644
645 if is_number_type(&type_str) {
646 checks.push(quote! { ring_check_number!(p, #idx); });
647 let cast = get_number_cast(&type_str);
648 gets.push(quote! { let #name = ring_get_number!(p, #idx) as #cast; });
649 } else if is_string_type(&type_str) {
650 checks.push(quote! { ring_check_string!(p, #idx); });
651 gets.push(quote! { let #name = ring_get_string!(p, #idx); });
652 } else if type_str == "bool" {
653 checks.push(quote! { ring_check_number!(p, #idx); });
654 gets.push(quote! { let #name = ring_get_number!(p, #idx) != 0.0; });
655 } else {
656 checks.push(quote! { ring_check_number!(p, #idx); });
657 gets.push(quote! { let #name = ring_get_number!(p, #idx) as _; });
658 }
659 args.push(name.clone());
660 }
661
662 let return_code =
663 generate_return_code(&method.sig.output, quote! { obj.#method_name(#(#args),*) });
664
665 let code = quote! {
666 ring_func!(#fn_name, |p| {
667 ring_check_paracount!(p, #param_count);
668 ring_check_cpointer!(p, 1);
669 #(#checks)*
670 if let Some(obj) = ring_get_pointer!(p, 1, #struct_name, #type_const) {
671 #(#gets)*
672 #return_code
673 } else {
674 ring_error!(p, concat!("Invalid ", stringify!(#struct_name), " pointer"));
675 }
676 });
677 };
678
679 (code, ring_name, fn_name)
680}
681
682fn generate_return_code(output: &ReturnType, call: TokenStream2) -> TokenStream2 {
683 match output {
684 ReturnType::Default => quote! { #call; },
685 ReturnType::Type(_, ty) => {
686 let type_str = quote!(#ty).to_string();
687 if is_number_type(&type_str) {
688 quote! {
689 let __result = #call;
690 ring_ret_number!(p, __result as f64);
691 }
692 } else if is_string_type(&type_str) {
693 quote! {
694 let __result = #call;
695 ring_ret_string!(p, &__result);
696 }
697 } else if type_str == "bool" {
698 quote! {
699 let __result = #call;
700 ring_ret_number!(p, if __result { 1.0 } else { 0.0 });
701 }
702 } else {
703 quote! {
704 let __result = #call;
705 ring_ret_number!(p, __result as f64);
706 }
707 }
708 }
709 }
710}
711
/// Returns `true` when `ty` (a type rendered by `quote!(..).to_string()`)
/// names a primitive integer or floating-point type.
///
/// Surrounding whitespace is ignored. Anything else, including references such
/// as `& i32`, is rejected.
fn is_number_type(ty: &str) -> bool {
    const NUMERIC_PRIMITIVES: [&str; 14] = [
        "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", "u64", "u128", "usize",
        "f32", "f64",
    ];
    NUMERIC_PRIMITIVES.contains(&ty.trim())
}
731
/// Returns `true` when `ty` (a type rendered by `quote!(..).to_string()`) is a
/// string type the generated glue passes to or from Ring.
///
/// Accepted forms are an owned `String` (optionally path-qualified) or a
/// borrowed `&str` with an optional lifetime. Matching is exact rather than by
/// substring, so types that merely contain `str` (`Instrument`, `Vec<&str>`,
/// `Option<&str>`) are not treated as strings.
fn is_string_type(ty: &str) -> bool {
    let ty = ty.trim();

    // Owned `String`, possibly written as a path. Token rendering inserts
    // spaces around `::`, so compare with all whitespace removed.
    let compact: String = ty.split_whitespace().collect();
    if matches!(
        compact.as_str(),
        "String" | "std::string::String" | "::std::string::String" | "alloc::string::String"
    ) {
        return true;
    }

    // Borrowed `&str`, `& str`, `&'a str`, `& 'static str`.
    let rest = match ty.strip_prefix('&') {
        Some(rest) => rest.trim_start(),
        None => return false,
    };
    let rest = match rest.strip_prefix('\'') {
        Some(after_quote) => after_quote
            .trim_start_matches(|c: char| c.is_alphanumeric() || c == '_')
            .trim_start(),
        None => rest,
    };
    rest == "str"
}
736
737fn get_number_cast(ty: &str) -> TokenStream2 {
738 let ty = ty.trim();
739 match ty {
740 "i8" => quote!(i8),
741 "i16" => quote!(i16),
742 "i32" => quote!(i32),
743 "i64" => quote!(i64),
744 "i128" => quote!(i128),
745 "isize" => quote!(isize),
746 "u8" => quote!(u8),
747 "u16" => quote!(u16),
748 "u32" => quote!(u32),
749 "u64" => quote!(u64),
750 "u128" => quote!(u128),
751 "usize" => quote!(usize),
752 "f32" => quote!(f32),
753 "f64" => quote!(f64),
754 _ => quote!(f64),
755 }
756}