// fire_stream_api_codegen/api.rs

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{ItemFn, Result, Type};

use crate::util::{fire_api_crate, ref_type, validate_inputs, validate_signature};
use crate::ApiArgs;

11pub(crate) fn expand(
12 args: ApiArgs,
13 item: ItemFn
14) -> Result<TokenStream> {
15 let fire = fire_api_crate()?;
16 let req_ty = args.ty;
17
18 validate_signature(&item.sig)?;
19
20 let input_types = validate_inputs(item.sig.inputs.iter())?;
22
23 let struct_name = &item.sig.ident;
24 let struct_gen = generate_struct(&item);
25
26 let ty_as_req = quote!(<#req_ty as #fire::request::Request>);
28
29 let type_action = quote!(
30 type Action = #ty_as_req::Action;
31 );
32
33 let action_fn = quote!(
34 fn action() -> Self::Action
35 where Self: Sized {
36 #ty_as_req::ACTION
37 }
38 );
39
40 let valid_data_fn = {
41 let mut asserts = vec![];
42
43 for ty in &input_types {
44 let error_msg = format!("could not find {}", quote!(#ty));
45
46 let valid_fn = match ref_type(&ty) {
47 Some(reff) => {
48 let elem = &reff.elem;
49 quote!(
50 #fire::util::valid_data_as_ref
51 ::<#elem, #req_ty>
52 )
53 },
54 None => quote!(
55 #fire::util::valid_data_as_owned
56 ::<#ty, #req_ty>
57 )
58 };
59
60 asserts.push(quote!(
61 assert!(#valid_fn(data), #error_msg);
62 ));
63 }
64
65 quote!(
66 fn validate_data(&self, data: &#fire::server::Data) {
67 #(#asserts)*
68 }
69 )
70 };
71
72 let handler_fn = {
73 let asyncness = &item.sig.asyncness;
74 let inputs = &item.sig.inputs;
75 let output = &item.sig.output;
76 let block = &item.block;
77
78 quote!(
79 #asyncness fn handler<B: #fire::message::PacketBytes + Send + 'static>(
80 #inputs
81 ) #output
82 #block
83 )
84 };
85
86 let handle_fn = {
87 let is_async = item.sig.asyncness.is_some();
88 let await_kw = if is_async {
89 quote!(.await)
90 } else {
91 quote!()
92 };
93
94 let mut handler_args_vars = vec![];
95 let mut handler_args = vec![];
96
97 for (idx, ty) in input_types.iter().enumerate() {
98 let get_fn = match ref_type(&ty) {
99 Some(reff) => {
100 let elem = &reff.elem;
101 quote!(
102 #fire::util::get_data_as_ref
103 ::<#elem, #req_ty>
104 )
105 },
106 None => quote!(
107 #fire::util::get_data_as_owned
108 ::<#ty, #req_ty>
109 )
110 };
111
112 let var_name = format_ident!("handler_arg_{idx}");
113
114 handler_args_vars.push(quote!(
115 let #var_name = #get_fn(data, session, &req);
116 ));
117 handler_args.push(quote!(#var_name));
118 }
119
120 let action_ty = quote!(#ty_as_req::Action);
121 let msg_ty = quote!(#fire::message::Message<#action_ty, B>);
122 let from_msg = quote!(#fire::message::FromMessage<#action_ty, B>);
123 let into_msg = quote!(#fire::message::IntoMessage<#action_ty, B>);
124 let api_err = quote!(#fire::error::ApiError);
125
126 quote!(
127 fn handle<'a>(
128 &'a self,
129 msg: #msg_ty,
130 data: &'a #fire::server::Data,
131 session: &'a #fire::server::Session
132 ) -> #fire::util::PinnedFuture<'a,
133 std::result::Result<#msg_ty, #fire::error::Error>
134 > {
135 #handler_fn
136
137 async fn handle_with_api_err<B>(
143 msg: #msg_ty,
144 data: &#fire::server::Data,
145 session: &#fire::server::Session
146 ) -> std::result::Result<#msg_ty, #ty_as_req::Error>
147 where B: #fire::message::PacketBytes + Send + 'static {
148 let req = <#req_ty as #from_msg>::from_message(msg)
149 .map_err(<#ty_as_req::Error as #api_err>::from_message_error)?;
150
151 let req = #fire::util::DataManager::new(req);
152
153 #(#handler_args_vars)*
154
155 let resp: #ty_as_req::Response = handler::<B>(
156 #(#handler_args),*
157 )#await_kw?;
158
159 <#ty_as_req::Response as #into_msg>::into_message(resp)
160 .map_err(<#ty_as_req::Error as #api_err>::from_message_error)
161 }
162
163 #fire::util::PinnedFuture::new(async move {
164 let r = handle_with_api_err(msg, data, session).await;
165
166 match r {
167 Ok(m) => Ok(m),
168 Err(e) => {
169 if data.cfg().log_errors {
170 #fire::tracing::error!(
171 "handler returned an error {:?}", e
172 );
173 }
174
175 <#ty_as_req::Error as #into_msg>::into_message(e)
176 .map(|mut msg| {
177 msg.set_success(false);
178 msg
179 })
180 .map_err(#fire::error::Error::from)
181 }
182 }
183 })
184 }
185 )
186 };
187
188 Ok(quote!(
189 #struct_gen
190
191 impl<B> #fire::request::RequestHandler<B> for #struct_name
192 where B: #fire::message::PacketBytes + Send + 'static {
193 #type_action
194 #action_fn
195 #valid_data_fn
196 #handle_fn
197 }
198 ))
199}
200
201pub(crate) fn generate_struct(item: &ItemFn) -> TokenStream {
202 let struct_name = &item.sig.ident;
203 let attrs = &item.attrs;
204 let vis = &item.vis;
205
206 quote!(
207 #(#attrs)*
208 #[allow(non_camel_case_types)]
209 #vis struct #struct_name;
210 )
211}