retrofit_codegen/
api.rs

1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::{parse_macro_input, FnArg, ItemFn, Pat, PatType, ReturnType, Type};
6
7pub fn api(header: TokenStream, function: TokenStream) -> TokenStream {
8    let has_state = !header.to_string().replace(' ', "").is_empty();
9    let input_fn = parse_macro_input!(function as ItemFn);
10
11    let mut args = input_fn.sig.inputs.clone();
12    if has_state {
13        // Remove last element because it is a state
14        args.pop();
15    }
16
17    let arg_idents: Vec<_> = args
18        .iter()
19        .map(|fn_arg| match fn_arg {
20            FnArg::Typed(PatType { pat, .. }) => match &**pat {
21                Pat::Ident(ident) => ident.ident.clone(),
22                _ => panic!("argument pattern is not a simple ident"),
23            },
24            FnArg::Receiver(_) => panic!("argument is a receiver"),
25        })
26        .collect();
27    let (fn_input_args, struct_types): (Vec<_>, Vec<_>) = args
28        .iter()
29        .zip(arg_idents.iter())
30        .map(|(i, ident)| {
31            let s = i.into_token_stream().to_string();
32            let ty = proc_macro2::TokenStream::from_str(if let Some(i) = s.find('&') {
33                &s[i + 1..]
34            } else {
35                &s[s.find(':').unwrap() + 1..]
36            })
37            .unwrap();
38            (
39                if s.contains('&') {
40                    quote! {& #ident}
41                } else {
42                    quote! {#ident}
43                },
44                quote! {#ident: #ty},
45            )
46        })
47        .unzip();
48
49    let input_fn_ident_string = input_fn.sig.ident.to_string();
50    let data_struct_ident = format_ident!("{}Data", input_fn_ident_string);
51
52    let destructure = if args.is_empty() {quote!{_}} else {quote!{#data_struct_ident{ #(#arg_idents),* }}};
53    let route_args = if has_state {
54        let state = parse_macro_input!(header as Type);
55        quote! {
56            axum::extract::State(state) : axum::extract::State<#state>,
57            axum::Json(#destructure) : axum::Json<#data_struct_ident>,
58        }
59    } else {
60        quote! {axum::Json(#destructure) : axum::Json<#data_struct_ident>}
61    };
62    let pass_through_state = if has_state {
63        if args.is_empty() {
64            quote!{&state}
65        } else {
66            quote! {, &state}
67        }
68    } else {
69        quote! {}
70    };
71
72    let input_fn_ident = input_fn.sig.ident.clone();
73    let return_type = match input_fn.sig.output {
74        ReturnType::Default => quote! {-> anyhow::Result<()>},
75        ReturnType::Type(_, ref ty) => quote! {-> anyhow::Result<#ty>},
76    };
77    let route_ident = format_ident!("{}_route", input_fn_ident_string);
78    let request_ident = format_ident!("{}_request", input_fn_ident_string);
79    TokenStream::from(quote! {
80        // Original function
81        #[cfg(feature = "server")]
82        #[allow(clippy::ptr_arg)]
83        #input_fn
84
85        // Data Struct
86        #[derive(serde::Serialize, serde::Deserialize, Clone)]
87        #[allow(non_camel_case_types)]
88        pub struct #data_struct_ident {
89            #(#struct_types),*
90        }
91
92        // Route function
93        #[cfg(feature = "server")]
94        async fn #route_ident ( #route_args ) -> String {
95            serde_json::to_string(
96                & #input_fn_ident ( #(#fn_input_args),* #pass_through_state)
97            ).unwrap()
98        }
99
100        // Request function
101        #[cfg(feature = "client")]
102        #[allow(clippy::ptr_arg)]
103        pub async fn #request_ident ( #args ) #return_type {
104            // Send request to endpoint
105            #[cfg(not(target_family = "wasm"))]
106            return Ok(reqwest::Client::new()
107                .post(&format!("http://localhost:8000/{}", #input_fn_ident_string))
108                .header("Content-Type", "application/json")
109                .body(serde_json::to_string(
110                    &#data_struct_ident {
111                        #(#arg_idents: #arg_idents.to_owned()),*
112                    }
113                ).unwrap())
114                .send().await?
115                .json().await?);
116
117            #[cfg(target_family = "wasm")]
118            return Ok(reqwasm::http::Request::post(&format!("/{}", #input_fn_ident_string))
119                .header("Content-Type", "application/json")
120                .body(serde_json::to_string(
121                    &#data_struct_ident {
122                        #(#arg_idents: #arg_idents.to_owned()),*
123                    }
124                ).unwrap())
125                .send().await?
126                .json().await?);
127        }
128    })
129}