waynest_gen/
server.rs

1use heck::{ToSnekCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use tracing::debug;
5
6use crate::{
7    common::write_dispatchers,
8    parser::{ArgType, Interface, Pair},
9    utils::{description_to_docs, find_enum, make_ident, write_enums},
10};
11
12pub fn generate_server_code(current: &[Pair], pairs: &[Pair]) -> TokenStream {
13    let mut modules = Vec::new();
14
15    for pair in current {
16        let protocol = &pair.protocol;
17        debug!("Generating server code for \"{}\"", &protocol.name);
18
19        let mut inner_modules = Vec::new();
20
21        for interface in &protocol.interfaces {
22            let docs = description_to_docs(interface.description.as_ref());
23            let module_name = make_ident(&interface.name);
24            let trait_name = make_ident(interface.name.to_upper_camel_case());
25
26            let trait_docs = format!(
27                "Trait to implement the {} interface. See the module level documentation for more info",
28                interface.name
29            );
30
31            let name = &interface.name;
32            let version = &interface.version;
33
34            let dispatchers = write_dispatchers(interface, interface.requests.clone().into_iter());
35            let requests = write_requests(pairs, pair, interface);
36            let events = write_events(pairs, pair, interface);
37            let enums = write_enums(interface);
38
39            let handler_args = if dispatchers.is_empty() {
40                quote! {
41                    _client: &mut crate::server::Client,
42                    _sender_id: crate::wire::ObjectId,
43                }
44            } else {
45                quote! {
46                    client: &mut crate::server::Client,
47                    sender_id: crate::wire::ObjectId,
48                }
49            };
50
51            inner_modules.push(quote! {
52                #(#docs)*
53                #[allow(clippy::too_many_arguments)]
54                pub mod #module_name {
55                    #[allow(unused)]
56                    use std::os::fd::AsRawFd;
57                    #[allow(unused)]
58                    use futures_util::SinkExt;
59
60                    #(#enums)*
61
62                    #[doc = #trait_docs]
63                    pub trait #trait_name: crate::server::Dispatcher {
64                        const INTERFACE: &'static str = #name;
65                        const VERSION: u32 = #version;
66
67                        fn handle_request(
68                            &self,
69                            #handler_args
70                            message: &mut crate::wire::Message,
71                        ) -> impl Future<Output = crate::server::Result<()>> + Send {
72                            async move {
73                                #[allow(clippy::match_single_binding)]
74                                match message.opcode() {
75                                    #(#dispatchers),*
76                                    opcode => Err(crate::server::error::Error::UnknownOpcode(opcode)),
77                                }
78                            }
79                        }
80
81                        #(#requests)*
82                        #(#events)*
83                    }
84                }
85            })
86        }
87
88        let docs = description_to_docs(protocol.description.as_ref());
89        let module_name = make_ident(&protocol.name);
90
91        modules.push(quote! {
92            #(#docs)*
93            #[allow(clippy::module_inception)]
94            pub mod #module_name {
95                #(#inner_modules)*
96            }
97        })
98    }
99
100    quote! {
101        #(#modules)*
102    }
103}
104
105fn write_requests(pairs: &[Pair], pair: &Pair, interface: &Interface) -> Vec<TokenStream> {
106    let mut requests = Vec::new();
107
108    for request in &interface.requests {
109        let docs = description_to_docs(request.description.as_ref());
110        let name = make_ident(request.name.to_snek_case());
111        let mut args = vec![
112            quote! {&self },
113            quote! {client: &mut crate::server::Client},
114            quote! {sender_id: crate::wire::ObjectId},
115        ];
116
117        for arg in &request.args {
118            let mut ty = arg.to_rust_type_token(arg.find_protocol(pairs).as_ref().unwrap_or(pair));
119
120            if arg.allow_null {
121                ty = quote! {Option<#ty>};
122            }
123
124            let name = make_ident(arg.name.to_snek_case());
125
126            args.push(quote! {#name: #ty})
127        }
128
129        requests.push(quote! {
130            #(#docs)*
131            fn #name(#(#args),*) -> impl Future<Output = crate::server::Result<()>> + Send;
132        });
133    }
134
135    requests
136}
137
138fn write_events(pairs: &[Pair], pair: &Pair, interface: &Interface) -> Vec<TokenStream> {
139    let mut events = Vec::new();
140
141    for (opcode, event) in interface.events.iter().enumerate() {
142        let opcode = opcode as u16;
143
144        let docs = description_to_docs(event.description.as_ref());
145        let name = make_ident(event.name.to_snek_case());
146
147        let mut args = vec![
148            quote! {&self },
149            quote! {client: &mut crate::server::Client},
150            quote! {sender_id: crate::wire::ObjectId},
151        ];
152
153        let mut tracing_fmt = Vec::new();
154        let mut tracing_args = Vec::new();
155
156        for arg in &event.args {
157            let mut ty = arg.to_rust_type_token(arg.find_protocol(pairs).as_ref().unwrap_or(pair));
158
159            let mut map_display = quote! {};
160
161            if arg.allow_null {
162                ty = quote! {Option<#ty>};
163                map_display = quote! {.as_ref().map_or("null".to_string(), |v| v.to_string())}
164            }
165
166            let name = make_ident(arg.name.to_snek_case());
167
168            args.push(quote! {#name: #ty});
169
170            match arg.ty {
171                ArgType::Array => {
172                    tracing_fmt.push("array[{}]");
173                    tracing_args.push(quote! { #name .len() });
174                }
175                ArgType::String => {
176                    tracing_fmt.push("\"{}\"");
177                    tracing_args.push(quote! { #name #map_display });
178                }
179                ArgType::Fd => {
180                    tracing_fmt.push("{}");
181                    tracing_args.push(quote! { #name .as_raw_fd() #map_display });
182                }
183                _ => {
184                    tracing_fmt.push("{}");
185                    tracing_args.push(quote! { #name #map_display });
186                }
187            }
188        }
189
190        let tracing_fmt = tracing_fmt.join(", ");
191
192        let tracing_inner = format!(
193            "-> {interface}#{{}}.{event}({tracing_fmt})",
194            interface = interface.name,
195            event = event.name.to_snek_case()
196        );
197
198        let mut build_args = Vec::new();
199
200        for arg in &event.args {
201            let build_ty = arg.to_caller();
202            let build_ty = format_ident!("put_{build_ty}");
203
204            let mut build_convert = quote! {};
205
206            if let Some((enum_interface, name)) = arg.to_enum_name() {
207                let e = if let Some(enum_interface) = enum_interface {
208                    pairs.iter().find_map(|pair| {
209                        pair.protocol
210                            .interfaces
211                            .iter()
212                            .find(|e| e.name == enum_interface)
213                            .and_then(|interface| interface.enums.iter().find(|e| e.name == name))
214                    })
215                } else {
216                    find_enum(&pair.protocol, &name)
217                };
218
219                if let Some(e) = e {
220                    if e.bitfield {
221                        build_convert = quote! { .bits() };
222                    } else {
223                        build_convert = quote! {  as u32 };
224                    }
225                }
226            }
227
228            let build_name = make_ident(arg.name.to_snek_case());
229            let mut build_name = quote! { #build_name };
230
231            if arg.is_return_option() && !arg.allow_null {
232                build_name = quote! { Some(#build_name) }
233            }
234
235            build_args.push(quote! { .#build_ty(#build_name #build_convert) })
236        }
237
238        events.push(quote! {
239            #(#docs)*
240            fn #name(#(#args),*) -> impl Future<Output = crate::server::Result<()>> + Send {
241                async move {
242                    tracing::debug!(#tracing_inner, sender_id, #(#tracing_args),*);
243
244                    let (payload,fds) = crate::wire::PayloadBuilder::new()
245                        #(#build_args)*
246                        .build();
247
248                    client
249                        .send_message(crate::wire::Message::new(sender_id, #opcode, payload, fds))
250                        .await
251                        .map_err(crate::server::error::Error::IoError)
252                }
253            }
254        });
255    }
256
257    events
258}