1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{Data, DeriveInput, Fields, FnArg, ItemFn, Pat, parse_macro_input};
9
10#[proc_macro_derive(Encode)]
30pub fn derive_encode(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 match impl_encode(&input) {
33 Ok(tokens) => tokens.into(),
34 Err(err) => err.to_compile_error().into(),
35 }
36}
37
38#[proc_macro_derive(Decode)]
58pub fn derive_decode(input: TokenStream) -> TokenStream {
59 let input = parse_macro_input!(input as DeriveInput);
60 match impl_decode(&input) {
61 Ok(tokens) => tokens.into(),
62 Err(err) => err.to_compile_error().into(),
63 }
64}
65
66fn named_fields(input: &DeriveInput) -> syn::Result<&syn::FieldsNamed> {
69 let name = &input.ident;
70 match &input.data {
71 Data::Struct(data) => match &data.fields {
72 Fields::Named(named) => Ok(named),
73 _ => Err(syn::Error::new_spanned(
74 name,
75 "Encode / Decode can only be derived for structs with named fields",
76 )),
77 },
78 Data::Enum(_) => Err(syn::Error::new_spanned(
79 name,
80 "Encode / Decode cannot be derived for enums",
81 )),
82 Data::Union(_) => Err(syn::Error::new_spanned(
83 name,
84 "Encode / Decode cannot be derived for unions",
85 )),
86 }
87}
88
89fn reject_generics(input: &DeriveInput) -> syn::Result<()> {
92 if !input.generics.params.is_empty() {
93 return Err(syn::Error::new_spanned(
94 &input.generics,
95 "Encode / Decode cannot be derived for generic structs",
96 ));
97 }
98 Ok(())
99}
100
101fn impl_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
104 reject_generics(input)?;
105 let name = &input.ident;
106 let fields = named_fields(input)?;
107
108 let encode_stmts: Vec<_> = fields
109 .named
110 .iter()
111 .map(|f| {
112 let ident = f.ident.as_ref().unwrap();
113 quote! {
114 ::conduit_core::Encode::encode(&self.#ident, buf);
115 }
116 })
117 .collect();
118
119 let size_terms: Vec<_> = fields
120 .named
121 .iter()
122 .map(|f| {
123 let ident = f.ident.as_ref().unwrap();
124 quote! {
125 ::conduit_core::Encode::encode_size(&self.#ident)
126 }
127 })
128 .collect();
129
130 let size_expr = if size_terms.is_empty() {
132 quote! { 0 }
133 } else {
134 let first = &size_terms[0];
135 let rest = &size_terms[1..];
136 quote! { #first #(+ #rest)* }
137 };
138
139 Ok(quote! {
140 impl ::conduit_core::Encode for #name {
141 fn encode(&self, buf: &mut Vec<u8>) {
142 #(#encode_stmts)*
143 }
144
145 fn encode_size(&self) -> usize {
146 #size_expr
147 }
148 }
149 })
150}
151
152fn impl_decode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
158 reject_generics(input)?;
159 let name = &input.ident;
160 let fields = named_fields(input)?;
161
162 let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
163
164 let decode_stmts: Vec<_> = fields
165 .named
166 .iter()
167 .map(|f| {
168 let ident = f.ident.as_ref().unwrap();
169 quote! {
170 let #ident = {
171 let (__cdec_v__, __cdec_n__) = ::conduit_core::Decode::decode(&__cdec_src__[__cdec_off__..])?;
172 __cdec_off__ += __cdec_n__;
173 __cdec_v__
174 };
175 }
176 })
177 .collect();
178
179 let field_names: Vec<_> = fields
180 .named
181 .iter()
182 .map(|f| f.ident.as_ref().unwrap())
183 .collect();
184
185 let min_size_expr = if field_types.is_empty() {
187 quote! { 0 }
188 } else {
189 let tys = &field_types;
190 quote! { 0 #(+ <#tys as ::conduit_core::Decode>::MIN_SIZE)* }
191 };
192
193 Ok(quote! {
194 impl ::conduit_core::Decode for #name {
195 const MIN_SIZE: usize = #min_size_expr;
196
197 fn decode(__cdec_src__: &[u8]) -> Option<(Self, usize)> {
198 if __cdec_src__.len() < Self::MIN_SIZE {
199 return None;
200 }
201 let mut __cdec_off__ = 0usize;
202 #(#decode_stmts)*
203 Some((Self { #(#field_names),* }, __cdec_off__))
204 }
205 }
206 })
207}
208
209#[proc_macro_attribute]
293pub fn command(attr: TokenStream, item: TokenStream) -> TokenStream {
294 if !attr.is_empty() {
295 return syn::Error::new(
296 proc_macro2::Span::call_site(),
297 "#[command] does not accept arguments",
298 )
299 .to_compile_error()
300 .into();
301 }
302 let func = parse_macro_input!(item as ItemFn);
303 match impl_conduit_command(func) {
304 Ok(tokens) => tokens.into(),
305 Err(err) => err.to_compile_error().into(),
306 }
307}
308
309fn is_state_type(ty: &syn::Type) -> bool {
315 if let syn::Type::Reference(type_ref) = ty {
316 return is_state_type(&type_ref.elem);
318 }
319 if let syn::Type::Path(type_path) = ty {
320 if let Some(seg) = type_path.path.segments.last() {
321 return seg.ident == "State";
322 }
323 }
324 false
325}
326
327fn is_app_handle_type(ty: &syn::Type) -> bool {
329 if let syn::Type::Reference(type_ref) = ty {
330 return is_app_handle_type(&type_ref.elem);
331 }
332 if let syn::Type::Path(type_path) = ty {
333 if let Some(seg) = type_path.path.segments.last() {
334 return seg.ident == "AppHandle";
335 }
336 }
337 false
338}
339
340fn extract_state_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
344 let ty = if let syn::Type::Reference(type_ref) = ty {
346 &*type_ref.elem
347 } else {
348 ty
349 };
350
351 if let syn::Type::Path(type_path) = ty {
352 if let Some(seg) = type_path.path.segments.last() {
353 if seg.ident == "State" {
354 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
355 for arg in &args.args {
357 if let syn::GenericArgument::Type(inner_ty) = arg {
358 return Some(inner_ty);
359 }
360 }
361 }
362 }
363 }
364 }
365 None
366}
367
368fn is_window_type(ty: &syn::Type) -> bool {
374 if let syn::Type::Reference(type_ref) = ty {
375 return is_window_type(&type_ref.elem);
376 }
377 if let syn::Type::Path(type_path) = ty {
378 if let Some(seg) = type_path.path.segments.last() {
379 return seg.ident == "Window" || seg.ident == "WebviewWindow";
380 }
381 }
382 false
383}
384
385fn is_webview_type(ty: &syn::Type) -> bool {
387 if let syn::Type::Reference(type_ref) = ty {
388 return is_webview_type(&type_ref.elem);
389 }
390 if let syn::Type::Path(type_path) = ty {
391 if let Some(seg) = type_path.path.segments.last() {
392 return seg.ident == "Webview";
393 }
394 }
395 false
396}
397
398fn is_option_type(ty: &syn::Type) -> bool {
400 if let syn::Type::Path(type_path) = ty {
401 if let Some(seg) = type_path.path.segments.last() {
402 return seg.ident == "Option";
403 }
404 }
405 false
406}
407
408fn is_result_return(output: &syn::ReturnType) -> bool {
410 match output {
411 syn::ReturnType::Default => false,
412 syn::ReturnType::Type(_, ty) => {
413 if let syn::Type::Path(type_path) = ty.as_ref() {
414 if let Some(seg) = type_path.path.segments.last() {
415 return seg.ident == "Result";
416 }
417 }
418 false
419 }
420 }
421}
422
423fn impl_conduit_command(func: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
431 let fn_name = &func.sig.ident;
432 let fn_vis = &func.vis;
433 let fn_sig = &func.sig;
434 let fn_block = &func.block;
435 let fn_attrs = &func.attrs;
436 let is_async = func.sig.asyncness.is_some();
437
438 if !func.sig.generics.params.is_empty() {
439 return Err(syn::Error::new_spanned(
440 &func.sig.generics,
441 "#[command] cannot be used on generic functions",
442 ));
443 }
444
445 if func.sig.generics.where_clause.is_some() {
446 return Err(syn::Error::new_spanned(
447 &func.sig.generics.where_clause,
448 "#[command] cannot be used on functions with where clauses",
449 ));
450 }
451
452 for arg in &func.sig.inputs {
453 if let FnArg::Typed(pat_type) = arg {
454 if matches!(&*pat_type.ty, syn::Type::ImplTrait(_)) {
455 return Err(syn::Error::new_spanned(
456 &pat_type.ty,
457 "#[command] cannot be used with `impl Trait` parameters",
458 ));
459 }
460 }
461 }
462
463 for arg in &func.sig.inputs {
465 if let FnArg::Typed(pat_type) = arg {
466 if !is_state_type(&pat_type.ty)
467 && !is_app_handle_type(&pat_type.ty)
468 && matches!(&*pat_type.ty, syn::Type::Reference(_))
469 {
470 return Err(syn::Error::new_spanned(
471 &pat_type.ty,
472 "#[command] parameters must be owned types (use String instead of &str)",
473 ));
474 }
475 }
476 }
477
478 let handler_struct_name = format_ident!("__conduit_handler_{}", fn_name);
479
480 let mut state_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
482 let mut app_handle_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
483 let mut window_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
484 let mut webview_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
485 let mut regular_params: Vec<(&syn::Ident, &syn::Type)> = Vec::new();
486 let mut all_param_names: Vec<&syn::Ident> = Vec::new();
488
489 for arg in &func.sig.inputs {
490 if let FnArg::Receiver(_) = arg {
491 return Err(syn::Error::new_spanned(
492 arg,
493 "#[command] cannot be used on methods with `self`",
494 ));
495 }
496 if let FnArg::Typed(pat_type) = arg {
497 if let Pat::Ident(pat_ident) = &*pat_type.pat {
498 if pat_ident.by_ref.is_some() {
499 return Err(syn::Error::new_spanned(
500 &pat_type.pat,
501 "#[command] does not support `ref` parameter bindings",
502 ));
503 }
504 let param_name = &pat_ident.ident;
505 let param_type = &*pat_type.ty;
506
507 all_param_names.push(param_name);
508
509 if is_state_type(param_type) {
510 state_params.push((param_name, param_type));
511 } else if is_app_handle_type(param_type) {
512 app_handle_params.push((param_name, param_type));
513 } else if is_window_type(param_type) {
514 window_params.push((param_name, param_type));
515 } else if is_webview_type(param_type) {
516 webview_params.push((param_name, param_type));
517 } else {
518 regular_params.push((param_name, param_type));
519 }
520 } else {
521 return Err(syn::Error::new_spanned(
522 &pat_type.pat,
523 "#[command] requires named parameters",
524 ));
525 }
526 }
527 }
528
529 let is_result = is_result_return(&func.sig.output);
531
532 let has_args = !regular_params.is_empty();
534 let struct_name = format_ident!("__conduit_args_{}", fn_name);
535
536 let regular_names: Vec<_> = regular_params.iter().map(|(n, _)| *n).collect();
537
538 let has_state = !state_params.is_empty();
539 let has_app_handle = !app_handle_params.is_empty();
540 let has_window = !window_params.is_empty();
541 let has_webview = !webview_params.is_empty();
542 let needs_context = has_state || has_app_handle || has_window || has_webview;
543
544 let state_extraction = if needs_context {
546 let state_stmts: Vec<proc_macro2::TokenStream> = state_params
547 .iter()
548 .map(|(name, ty)| {
549 let inner_ty = extract_state_inner_type(ty);
550 match inner_ty {
551 Some(inner) => {
552 quote! {
553 let #name: ::tauri::State<'_, #inner> = ::tauri::Manager::state(&*__app);
554 }
555 }
556 None => {
557 quote! {
559 let #name: #ty = ::tauri::Manager::state(&*__app);
560 }
561 }
562 }
563 })
564 .collect();
565
566 let app_handle_stmts: Vec<proc_macro2::TokenStream> = app_handle_params
567 .iter()
568 .map(|(name, _ty)| {
569 quote! {
570 let #name = __app.clone();
571 }
572 })
573 .collect();
574
575 let window_stmts: Vec<proc_macro2::TokenStream> = window_params
577 .iter()
578 .map(|(name, _ty)| {
579 quote! {
580 let #name = {
581 let __label = __handler_ctx.webview_label.as_ref()
582 .ok_or_else(|| ::conduit_core::Error::Handler(
583 "Window injection requires X-Conduit-Webview header".into()
584 ))?;
585 ::tauri::Manager::get_webview_window(&*__app, __label)
586 .ok_or_else(|| ::conduit_core::Error::Handler(
587 ::std::format!("webview window '{}' not found", __label)
588 ))?
589 };
590 }
591 })
592 .collect();
593
594 let webview_stmts: Vec<proc_macro2::TokenStream> = webview_params
596 .iter()
597 .map(|(name, _ty)| {
598 quote! {
599 let #name = {
600 let __label = __handler_ctx.webview_label.as_ref()
601 .ok_or_else(|| ::conduit_core::Error::Handler(
602 "Webview injection requires X-Conduit-Webview header".into()
603 ))?;
604 ::tauri::Manager::get_webview(&*__app, __label)
605 .ok_or_else(|| ::conduit_core::Error::Handler(
606 ::std::format!("webview '{}' not found", __label)
607 ))?
608 };
609 }
610 })
611 .collect();
612
613 let context_downcast = quote! {
614 let __handler_ctx = __ctx
615 .downcast_ref::<::conduit_core::HandlerContext>()
616 .ok_or_else(|| ::conduit_core::Error::Handler(
617 "internal: handler context must be HandlerContext".into()
618 ))?;
619 let __app = __handler_ctx.app_handle
620 .downcast_ref::<::tauri::AppHandle<::tauri::Wry>>()
621 .ok_or_else(|| ::conduit_core::Error::Handler(
622 "internal: handler context app_handle must be AppHandle<Wry>".into()
623 ))?;
624 };
625
626 quote! {
627 #context_downcast
628 #(#state_stmts)*
629 #(#app_handle_stmts)*
630 #(#window_stmts)*
631 #(#webview_stmts)*
632 }
633 } else {
634 quote! {}
635 };
636
637 let args_deser = if has_args {
639 quote! {
640 let #struct_name { #(#regular_names),* } =
641 ::conduit_core::sonic_rs::from_slice(&__payload)
642 .map_err(::conduit_core::Error::from)?;
643 }
644 } else {
645 quote! {
646 let _ = &__payload;
647 }
648 };
649
650 let fn_call = if is_async {
652 quote! { #fn_name(#(#all_param_names),*).await }
653 } else {
654 quote! { #fn_name(#(#all_param_names),*) }
655 };
656
657 let result_handling = if is_result {
659 quote! {
660 let __result = #fn_call;
661 match __result {
662 ::std::result::Result::Ok(__v) => {
663 ::conduit_core::sonic_rs::to_vec(&__v).map_err(::conduit_core::Error::from)
664 }
665 ::std::result::Result::Err(__e) => {
666 ::std::result::Result::Err(::conduit_core::Error::Handler(__e.to_string()))
667 }
668 }
669 }
670 } else {
671 quote! {
672 let __result = #fn_call;
673 ::conduit_core::sonic_rs::to_vec(&__result).map_err(::conduit_core::Error::from)
674 }
675 };
676
677 let struct_def = if has_args {
679 let field_defs: Vec<proc_macro2::TokenStream> = regular_params
681 .iter()
682 .map(|(name, ty)| {
683 if is_option_type(ty) {
684 quote! { #[serde(default)] #name: #ty }
685 } else {
686 quote! { #name: #ty }
687 }
688 })
689 .collect();
690 quote! {
691 #[doc(hidden)]
692 #[allow(non_camel_case_types)]
693 #[derive(::conduit_core::serde::Deserialize)]
694 #[serde(crate = "::conduit_core::serde", rename_all = "camelCase")]
695 struct #struct_name {
696 #(#field_defs),*
697 }
698 }
699 } else {
700 quote! {}
701 };
702
703 let handler_body = if is_async {
705 quote! {
706 ::conduit_core::HandlerResponse::Async(::std::boxed::Box::pin(async move {
707 #state_extraction
708 #args_deser
709 #result_handling
710 }))
711 }
712 } else {
713 quote! {
714 ::conduit_core::HandlerResponse::Sync((|| -> ::std::result::Result<::std::vec::Vec<u8>, ::conduit_core::Error> {
715 #state_extraction
716 #args_deser
717 #result_handling
718 })())
719 }
720 };
721
722 Ok(quote! {
723 #struct_def
724
725 #(#fn_attrs)*
727 #fn_vis #fn_sig #fn_block
728
729 #[doc(hidden)]
731 #[allow(non_camel_case_types)]
732 #fn_vis struct #handler_struct_name;
733
734 impl ::conduit_core::ConduitHandler for #handler_struct_name {
735 fn call(
736 &self,
737 __payload: ::std::vec::Vec<u8>,
738 __ctx: ::std::sync::Arc<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync>,
739 ) -> ::conduit_core::HandlerResponse {
740 #handler_body
741 }
742 }
743 })
744}
745
746#[proc_macro]
775pub fn handler(input: TokenStream) -> TokenStream {
776 let mut path = parse_macro_input!(input as syn::Path);
777 if let Some(last) = path.segments.last_mut() {
778 last.ident = format_ident!("__conduit_handler_{}", last.ident);
779 }
780 quote! { #path }.into()
781}