1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{Data, DeriveInput, Fields, FnArg, ItemFn, Pat, parse_macro_input};
9
10#[proc_macro_derive(Encode)]
30pub fn derive_encode(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 match impl_encode(&input) {
33 Ok(tokens) => tokens.into(),
34 Err(err) => err.to_compile_error().into(),
35 }
36}
37
38#[proc_macro_derive(Decode)]
58pub fn derive_decode(input: TokenStream) -> TokenStream {
59 let input = parse_macro_input!(input as DeriveInput);
60 match impl_decode(&input) {
61 Ok(tokens) => tokens.into(),
62 Err(err) => err.to_compile_error().into(),
63 }
64}
65
66fn named_fields(input: &DeriveInput) -> syn::Result<&syn::FieldsNamed> {
69 let name = &input.ident;
70 match &input.data {
71 Data::Struct(data) => match &data.fields {
72 Fields::Named(named) => Ok(named),
73 _ => Err(syn::Error::new_spanned(
74 name,
75 "Encode / Decode can only be derived for structs with named fields",
76 )),
77 },
78 Data::Enum(_) => Err(syn::Error::new_spanned(
79 name,
80 "Encode / Decode cannot be derived for enums",
81 )),
82 Data::Union(_) => Err(syn::Error::new_spanned(
83 name,
84 "Encode / Decode cannot be derived for unions",
85 )),
86 }
87}
88
89fn reject_generics(input: &DeriveInput) -> syn::Result<()> {
92 if !input.generics.params.is_empty() {
93 return Err(syn::Error::new_spanned(
94 &input.generics,
95 "Encode / Decode cannot be derived for generic structs",
96 ));
97 }
98 Ok(())
99}
100
101fn impl_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
104 reject_generics(input)?;
105 let name = &input.ident;
106 let fields = named_fields(input)?;
107
108 let encode_stmts: Vec<_> = fields
109 .named
110 .iter()
111 .map(|f| {
112 let ident = f.ident.as_ref().unwrap();
113 quote! {
114 ::conduit_core::Encode::encode(&self.#ident, buf);
115 }
116 })
117 .collect();
118
119 let size_terms: Vec<_> = fields
120 .named
121 .iter()
122 .map(|f| {
123 let ident = f.ident.as_ref().unwrap();
124 quote! {
125 ::conduit_core::Encode::encode_size(&self.#ident)
126 }
127 })
128 .collect();
129
130 let size_expr = if size_terms.is_empty() {
132 quote! { 0 }
133 } else {
134 let first = &size_terms[0];
135 let rest = &size_terms[1..];
136 quote! { #first #(+ #rest)* }
137 };
138
139 Ok(quote! {
140 impl ::conduit_core::Encode for #name {
141 fn encode(&self, buf: &mut Vec<u8>) {
142 #(#encode_stmts)*
143 }
144
145 fn encode_size(&self) -> usize {
146 #size_expr
147 }
148 }
149 })
150}
151
152fn impl_decode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
155 reject_generics(input)?;
156 let name = &input.ident;
157 let fields = named_fields(input)?;
158
159 let decode_stmts: Vec<_> = fields
160 .named
161 .iter()
162 .map(|f| {
163 let ident = f.ident.as_ref().unwrap();
164 quote! {
165 let #ident = {
166 let (__cdec_v__, __cdec_n__) = ::conduit_core::Decode::decode(&__cdec_src__[__cdec_off__..])?;
167 __cdec_off__ += __cdec_n__;
168 __cdec_v__
169 };
170 }
171 })
172 .collect();
173
174 let field_names: Vec<_> = fields
175 .named
176 .iter()
177 .map(|f| f.ident.as_ref().unwrap())
178 .collect();
179
180 Ok(quote! {
181 impl ::conduit_core::Decode for #name {
182 fn decode(__cdec_src__: &[u8]) -> Option<(Self, usize)> {
183 let mut __cdec_off__ = 0usize;
184 #(#decode_stmts)*
185 Some((Self { #(#field_names),* }, __cdec_off__))
186 }
187 }
188 })
189}
190
191#[proc_macro_attribute]
275pub fn command(attr: TokenStream, item: TokenStream) -> TokenStream {
276 if !attr.is_empty() {
277 return syn::Error::new(
278 proc_macro2::Span::call_site(),
279 "#[command] does not accept arguments",
280 )
281 .to_compile_error()
282 .into();
283 }
284 let func = parse_macro_input!(item as ItemFn);
285 match impl_conduit_command(func) {
286 Ok(tokens) => tokens.into(),
287 Err(err) => err.to_compile_error().into(),
288 }
289}
290
291fn is_state_type(ty: &syn::Type) -> bool {
297 if let syn::Type::Reference(type_ref) = ty {
298 return is_state_type(&type_ref.elem);
300 }
301 if let syn::Type::Path(type_path) = ty {
302 if let Some(seg) = type_path.path.segments.last() {
303 return seg.ident == "State";
304 }
305 }
306 false
307}
308
309fn is_app_handle_type(ty: &syn::Type) -> bool {
311 if let syn::Type::Reference(type_ref) = ty {
312 return is_app_handle_type(&type_ref.elem);
313 }
314 if let syn::Type::Path(type_path) = ty {
315 if let Some(seg) = type_path.path.segments.last() {
316 return seg.ident == "AppHandle";
317 }
318 }
319 false
320}
321
322fn extract_state_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
326 let ty = if let syn::Type::Reference(type_ref) = ty {
328 &*type_ref.elem
329 } else {
330 ty
331 };
332
333 if let syn::Type::Path(type_path) = ty {
334 if let Some(seg) = type_path.path.segments.last() {
335 if seg.ident == "State" {
336 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
337 for arg in &args.args {
339 if let syn::GenericArgument::Type(inner_ty) = arg {
340 return Some(inner_ty);
341 }
342 }
343 }
344 }
345 }
346 }
347 None
348}
349
350fn is_window_type(ty: &syn::Type) -> bool {
356 if let syn::Type::Reference(type_ref) = ty {
357 return is_window_type(&type_ref.elem);
358 }
359 if let syn::Type::Path(type_path) = ty {
360 if let Some(seg) = type_path.path.segments.last() {
361 return seg.ident == "Window" || seg.ident == "WebviewWindow";
362 }
363 }
364 false
365}
366
367fn is_webview_type(ty: &syn::Type) -> bool {
369 if let syn::Type::Reference(type_ref) = ty {
370 return is_webview_type(&type_ref.elem);
371 }
372 if let syn::Type::Path(type_path) = ty {
373 if let Some(seg) = type_path.path.segments.last() {
374 return seg.ident == "Webview";
375 }
376 }
377 false
378}
379
380fn is_option_type(ty: &syn::Type) -> bool {
382 if let syn::Type::Path(type_path) = ty {
383 if let Some(seg) = type_path.path.segments.last() {
384 return seg.ident == "Option";
385 }
386 }
387 false
388}
389
390fn is_result_return(output: &syn::ReturnType) -> bool {
392 match output {
393 syn::ReturnType::Default => false,
394 syn::ReturnType::Type(_, ty) => {
395 if let syn::Type::Path(type_path) = ty.as_ref() {
396 if let Some(seg) = type_path.path.segments.last() {
397 return seg.ident == "Result";
398 }
399 }
400 false
401 }
402 }
403}
404
405fn impl_conduit_command(func: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
413 let fn_name = &func.sig.ident;
414 let fn_vis = &func.vis;
415 let fn_sig = &func.sig;
416 let fn_block = &func.block;
417 let fn_attrs = &func.attrs;
418 let is_async = func.sig.asyncness.is_some();
419
420 if !func.sig.generics.params.is_empty() {
421 return Err(syn::Error::new_spanned(
422 &func.sig.generics,
423 "#[command] cannot be used on generic functions",
424 ));
425 }
426
427 if func.sig.generics.where_clause.is_some() {
428 return Err(syn::Error::new_spanned(
429 &func.sig.generics.where_clause,
430 "#[command] cannot be used on functions with where clauses",
431 ));
432 }
433
434 for arg in &func.sig.inputs {
435 if let FnArg::Typed(pat_type) = arg {
436 if matches!(&*pat_type.ty, syn::Type::ImplTrait(_)) {
437 return Err(syn::Error::new_spanned(
438 &pat_type.ty,
439 "#[command] cannot be used with `impl Trait` parameters",
440 ));
441 }
442 }
443 }
444
445 for arg in &func.sig.inputs {
447 if let FnArg::Typed(pat_type) = arg {
448 if !is_state_type(&pat_type.ty)
449 && !is_app_handle_type(&pat_type.ty)
450 && matches!(&*pat_type.ty, syn::Type::Reference(_))
451 {
452 return Err(syn::Error::new_spanned(
453 &pat_type.ty,
454 "#[command] parameters must be owned types (use String instead of &str)",
455 ));
456 }
457 }
458 }
459
460 let handler_struct_name = format_ident!("__conduit_handler_{}", fn_name);
461
462 let mut state_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
464 let mut app_handle_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
465 let mut window_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
466 let mut webview_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
467 let mut regular_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
468 let mut all_param_names: Vec<&syn::Ident> = Vec::new();
470
471 for arg in &func.sig.inputs {
472 if let FnArg::Receiver(_) = arg {
473 return Err(syn::Error::new_spanned(
474 arg,
475 "#[command] cannot be used on methods with `self`",
476 ));
477 }
478 if let FnArg::Typed(pat_type) = arg {
479 if let Pat::Ident(pat_ident) = &*pat_type.pat {
480 if pat_ident.by_ref.is_some() {
481 return Err(syn::Error::new_spanned(
482 &pat_type.pat,
483 "#[command] does not support `ref` parameter bindings",
484 ));
485 }
486 let param_name = &pat_ident.ident;
487 let param_type = &*pat_type.ty;
488
489 all_param_names.push(param_name);
490
491 if is_state_type(param_type) {
492 state_params.push((param_name, param_type));
493 } else if is_app_handle_type(param_type) {
494 app_handle_params.push((param_name, param_type));
495 } else if is_window_type(param_type) {
496 window_params.push((param_name, param_type));
497 } else if is_webview_type(param_type) {
498 webview_params.push((param_name, param_type));
499 } else {
500 regular_params.push((param_name, param_type));
501 }
502 } else {
503 return Err(syn::Error::new_spanned(
504 &pat_type.pat,
505 "#[command] requires named parameters",
506 ));
507 }
508 }
509 }
510
511 let is_result = is_result_return(&func.sig.output);
513
514 let has_args = !regular_params.is_empty();
516 let struct_name = format_ident!("__conduit_args_{}", fn_name);
517
518 let regular_names: Vec<_> = regular_params.iter().map(|(n, _)| *n).collect();
519
520 let has_state = !state_params.is_empty();
521 let has_app_handle = !app_handle_params.is_empty();
522 let has_window = !window_params.is_empty();
523 let has_webview = !webview_params.is_empty();
524 let needs_context = has_state || has_app_handle || has_window || has_webview;
525
526 let state_extraction = if needs_context {
528 let state_stmts: Vec<proc_macro2::TokenStream> = state_params
529 .iter()
530 .map(|(name, ty)| {
531 let inner_ty = extract_state_inner_type(ty);
532 match inner_ty {
533 Some(inner) => {
534 quote! {
535 let #name: ::tauri::State<'_, #inner> = ::tauri::Manager::state(&*__app);
536 }
537 }
538 None => {
539 quote! {
541 let #name: #ty = ::tauri::Manager::state(&*__app);
542 }
543 }
544 }
545 })
546 .collect();
547
548 let app_handle_stmts: Vec<proc_macro2::TokenStream> = app_handle_params
549 .iter()
550 .map(|(name, _ty)| {
551 quote! {
552 let #name = __app.clone();
553 }
554 })
555 .collect();
556
557 let window_stmts: Vec<proc_macro2::TokenStream> = window_params
559 .iter()
560 .map(|(name, _ty)| {
561 quote! {
562 let #name = {
563 let __label = __handler_ctx.webview_label.as_ref()
564 .ok_or_else(|| ::conduit_core::Error::Handler(
565 "Window injection requires X-Conduit-Webview header".into()
566 ))?;
567 ::tauri::Manager::get_webview_window(&*__app, __label)
568 .ok_or_else(|| ::conduit_core::Error::Handler(
569 ::std::format!("webview window '{}' not found", __label)
570 ))?
571 };
572 }
573 })
574 .collect();
575
576 let webview_stmts: Vec<proc_macro2::TokenStream> = webview_params
578 .iter()
579 .map(|(name, _ty)| {
580 quote! {
581 let #name = {
582 let __label = __handler_ctx.webview_label.as_ref()
583 .ok_or_else(|| ::conduit_core::Error::Handler(
584 "Webview injection requires X-Conduit-Webview header".into()
585 ))?;
586 ::tauri::Manager::get_webview(&*__app, __label)
587 .ok_or_else(|| ::conduit_core::Error::Handler(
588 ::std::format!("webview '{}' not found", __label)
589 ))?
590 };
591 }
592 })
593 .collect();
594
595 let context_downcast = quote! {
596 let __handler_ctx = __ctx
597 .downcast_ref::<::conduit_core::HandlerContext>()
598 .ok_or_else(|| ::conduit_core::Error::Handler(
599 "internal: handler context must be HandlerContext".into()
600 ))?;
601 let __app = __handler_ctx.app_handle
602 .downcast_ref::<::tauri::AppHandle<::tauri::Wry>>()
603 .ok_or_else(|| ::conduit_core::Error::Handler(
604 "internal: handler context app_handle must be AppHandle<Wry>".into()
605 ))?;
606 };
607
608 quote! {
609 #context_downcast
610 #(#state_stmts)*
611 #(#app_handle_stmts)*
612 #(#window_stmts)*
613 #(#webview_stmts)*
614 }
615 } else {
616 quote! {}
617 };
618
619 let args_deser = if has_args {
621 quote! {
622 let #struct_name { #(#regular_names),* } =
623 ::conduit_core::sonic_rs::from_slice(&__payload)
624 .map_err(::conduit_core::Error::from)?;
625 }
626 } else {
627 quote! {
628 let _ = &__payload;
629 }
630 };
631
632 let fn_call = if is_async {
634 quote! { #fn_name(#(#all_param_names),*).await }
635 } else {
636 quote! { #fn_name(#(#all_param_names),*) }
637 };
638
639 let result_handling = if is_result {
641 quote! {
642 let __result = #fn_call;
643 match __result {
644 ::std::result::Result::Ok(__v) => {
645 ::conduit_core::sonic_rs::to_vec(&__v).map_err(::conduit_core::Error::from)
646 }
647 ::std::result::Result::Err(__e) => {
648 ::std::result::Result::Err(::conduit_core::Error::Handler(__e.to_string()))
649 }
650 }
651 }
652 } else {
653 quote! {
654 let __result = #fn_call;
655 ::conduit_core::sonic_rs::to_vec(&__result).map_err(::conduit_core::Error::from)
656 }
657 };
658
659 let struct_def = if has_args {
661 let field_defs: Vec<proc_macro2::TokenStream> = regular_params
663 .iter()
664 .map(|(name, ty)| {
665 if is_option_type(ty) {
666 quote! { #[serde(default)] #name: #ty }
667 } else {
668 quote! { #name: #ty }
669 }
670 })
671 .collect();
672 quote! {
673 #[doc(hidden)]
674 #[allow(non_camel_case_types)]
675 #[derive(::conduit_core::serde::Deserialize)]
676 #[serde(crate = "::conduit_core::serde", rename_all = "camelCase")]
677 struct #struct_name {
678 #(#field_defs),*
679 }
680 }
681 } else {
682 quote! {}
683 };
684
685 let handler_body = if is_async {
687 quote! {
688 ::conduit_core::HandlerResponse::Async(::std::boxed::Box::pin(async move {
689 #state_extraction
690 #args_deser
691 #result_handling
692 }))
693 }
694 } else {
695 quote! {
696 ::conduit_core::HandlerResponse::Sync((|| -> ::std::result::Result<::std::vec::Vec<u8>, ::conduit_core::Error> {
697 #state_extraction
698 #args_deser
699 #result_handling
700 })())
701 }
702 };
703
704 Ok(quote! {
705 #struct_def
706
707 #(#fn_attrs)*
709 #fn_vis #fn_sig #fn_block
710
711 #[doc(hidden)]
713 #[allow(non_camel_case_types)]
714 #fn_vis struct #handler_struct_name;
715
716 impl ::conduit_core::ConduitHandler for #handler_struct_name {
717 fn call(
718 &self,
719 __payload: ::std::vec::Vec<u8>,
720 __ctx: ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>,
721 ) -> ::conduit_core::HandlerResponse {
722 #handler_body
723 }
724 }
725 })
726}
727
728#[proc_macro]
757pub fn handler(input: TokenStream) -> TokenStream {
758 let mut path = parse_macro_input!(input as syn::Path);
759 if let Some(last) = path.segments.last_mut() {
760 last.ident = format_ident!("__conduit_handler_{}", last.ident);
761 }
762 quote! { #path }.into()
763}