wiremock_grpc_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5    braced,
6    parse::{Parse, ParseStream},
7    punctuated::Punctuated,
8    Ident, Result, Token,
9};
10
11/// Generates a complete mock gRPC server with RPC method builders.
12///
13/// This macro creates:
14/// - A mock server struct (`{ServiceName}MockServer` or custom name with `as`)
15/// - An extension trait for `WhenBuilder` with `path_{method_name}` methods
16///
17/// # Syntax
18///
19/// ```
20/// # macro_rules! generate_svc {
21/// #     ($($tt:tt)*) => {};
22/// # }
23/// generate_svc! {
24///     package hello;
25///     service Greeter {
26///         SayHello,
27///         WeatherInfo,
28///     }
29/// }
30/// ```
31///
32/// Or with a custom server name:
33///
34/// ```
35/// # macro_rules! generate_svc {
36/// #     ($($tt:tt)*) => {};
37/// # }
38/// generate_svc! {
39///     package hello;
40///     service Greeter as MyMockServer {
41///         SayHello,
42///         WeatherInfo,
43///     }
44/// }
45/// ```
46///
47/// # Generated Code
48///
49/// The macro generates:
50/// - `{ServiceName}MockServer` (or custom name) - the mock server struct
51/// - `{ServiceName}TypeSafeExt` trait with `path_{method_name}` methods
52///
53/// # Example
54///
55/// ```no_run
56/// # macro_rules! generate_svc {
57/// #     ($($tt:tt)*) => {};
58/// # }
59/// # struct MockBuilder;
60/// # impl MockBuilder {
61/// #     fn when() -> Self { MockBuilder }
62/// #     fn path_say_hello(self) -> Self { self }
63/// #     fn then(self) -> Self { self }
64/// #     fn return_body<F>(self, _f: F) -> Self { self }
65/// # }
66/// # struct HelloReply { message: String }
67/// # struct GreeterMockServer;
68/// # impl GreeterMockServer {
69/// #     async fn start_default() -> Self { GreeterMockServer }
70/// #     fn setup(&mut self, _builder: MockBuilder) {}
71/// # }
72/// generate_svc! {
73///     package hello;
74///     service Greeter {
75///         SayHello,
76///         WeatherInfo,
77///     }
78/// }
79///
80/// async fn example() {
81///     let mut server = GreeterMockServer::start_default().await;
82///
83///     server.setup(
84///         MockBuilder::when()
85///             .path_say_hello()
86///             .then()
87///             .return_body(|| HelloReply { message: "Hi".into() })
88///     );
89///
90///     // ... test client code
91/// }
92/// ```
93#[proc_macro]
94pub fn generate_svc(input: TokenStream) -> TokenStream {
95    let service_def = syn::parse_macro_input!(input as ServiceDefinition);
96    service_def.generate().into()
97}
98
99struct ServiceDefinition {
100    package: String,
101    service_name: Ident,
102    server_name: Ident,
103    methods: Punctuated<Ident, Token![,]>,
104}
105
106impl Parse for ServiceDefinition {
107    fn parse(input: ParseStream) -> Result<Self> {
108        // package keyword (not used)
109        let _package_kw: Ident = input.parse()?;
110        if _package_kw != "package" {
111            return Err(syn::Error::new(
112                _package_kw.span(),
113                "expected `package` keyword",
114            ));
115        }
116
117        // parse the package name (x.y.z)
118        let first: Ident = input.parse()?;
119        let mut package = first.to_string();
120        while input.peek(Token![.]) {
121            let _dot: Token![.] = input.parse()?;
122            let next: Ident = input.parse()?;
123            package.push('.');
124            package.push_str(&next.to_string());
125        }
126
127        let _semi: Token![;] = input.parse()?;
128
129        // next line: service <name> [as <custom name>]
130        let _service_kw: Ident = input.parse()?;
131        if _service_kw != "service" {
132            return Err(syn::Error::new(
133                _service_kw.span(),
134                "expected `service` keyword",
135            ));
136        }
137        let service_name: Ident = input.parse()?;
138
139        let server_name = if input.peek(Token![as]) {
140            let _as: Token![as] = input.parse()?;
141            input.parse()?
142        } else {
143            format_ident!("{}MockServer", service_name)
144        };
145
146        let content;
147        braced!(content in input);
148
149        let methods = content.parse_terminated(Ident::parse, Token![,])?;
150
151        Ok(ServiceDefinition {
152            package,
153            service_name,
154            server_name,
155            methods,
156        })
157    }
158}
159
160impl ServiceDefinition {
161    fn generate(&self) -> TokenStream2 {
162        let ext_trait = self.generate_ext_trait();
163        let mock_server = self.generate_mock_server();
164
165        quote! {
166            #ext_trait
167            #mock_server
168        }
169    }
170
171    fn generate_ext_trait(&self) -> TokenStream2 {
172        let trait_name = format_ident!("{}TypeSafeExt", self.service_name);
173        let package = &self.package;
174        let service_name = &self.service_name;
175
176        let method_signatures: Vec<_> = self
177            .methods
178            .iter()
179            .map(|method| {
180                let fn_name = format_ident!("path_{}", to_snake_case(&method.to_string()));
181                quote! {
182                    fn #fn_name(&self) -> Self;
183                }
184            })
185            .collect();
186
187        let method_impls: Vec<_> = self
188            .methods
189            .iter()
190            .map(|method| {
191                let fn_name = format_ident!("path_{}", to_snake_case(&method.to_string()));
192                let path = format!("/{}.{}/{}", package, service_name, method);
193                quote! {
194                    fn #fn_name(&self) -> Self {
195                        #[expect(deprecated)]
196                        self.path(#path)
197                    }
198                }
199            })
200            .collect();
201
202        quote! {
203            pub trait #trait_name {
204                #(#method_signatures)*
205            }
206
207            impl #trait_name for wiremock_grpc::WhenBuilder {
208                #(#method_impls)*
209            }
210        }
211    }
212
213    fn generate_mock_server(&self) -> TokenStream2 {
214        let server_name = &self.server_name;
215        let package = &self.package;
216        let service_name = &self.service_name;
217        let prefix = format!("{}.{}", package, service_name);
218
219        quote! {
220            #[derive(Clone)]
221            pub struct #server_name(wiremock_grpc::GrpcServer);
222
223            impl ::std::ops::Deref for #server_name {
224                type Target = wiremock_grpc::GrpcServer;
225
226                fn deref(&self) -> &Self::Target {
227                    &self.0
228                }
229            }
230
231            impl ::std::ops::DerefMut for #server_name {
232                fn deref_mut(&mut self) -> &mut Self::Target {
233                    &mut self.0
234                }
235            }
236
237            impl<B> wiremock_grpc::tonic::codegen::Service<wiremock_grpc::tonic::codegen::http::Request<B>> for #server_name
238            where
239                B: wiremock_grpc::http_body::Body + Send + 'static,
240                B::Error: Into<wiremock_grpc::tonic::codegen::StdError> + Send + 'static,
241            {
242                type Response = wiremock_grpc::tonic::codegen::http::Response<wiremock_grpc::tonic::body::Body>;
243                type Error = ::std::convert::Infallible;
244                type Future = wiremock_grpc::tonic::codegen::BoxFuture<Self::Response, Self::Error>;
245
246                fn poll_ready(
247                    &mut self,
248                    _cx: &mut ::std::task::Context<'_>,
249                ) -> ::std::task::Poll<Result<(), Self::Error>> {
250                    ::std::task::Poll::Ready(Ok(()))
251                }
252
253                fn call(&mut self, req: wiremock_grpc::tonic::codegen::http::Request<B>) -> Self::Future {
254                    self.0.handle_request(req)
255                }
256            }
257
258            impl wiremock_grpc::tonic::server::NamedService for #server_name {
259                const NAME: &'static str = #prefix;
260            }
261
262            impl #server_name {
263                pub async fn start_default() -> Self {
264                    let port = wiremock_grpc::GrpcServer::find_unused_port()
265                        .await
266                        .expect("Unable to find an open port");
267
268                    Self(wiremock_grpc::GrpcServer::new(port)).start_internal().await
269                }
270
271                pub async fn start(port: u16) -> Self {
272                    Self(wiremock_grpc::GrpcServer::new(port)).start_internal().await
273                }
274
275                pub async fn start_with_addr(addr: ::std::net::SocketAddr) -> Self {
276                    Self(wiremock_grpc::GrpcServer::with_addr(addr)).start_internal().await
277                }
278
279                async fn start_internal(&mut self) -> Self {
280                    let address = self.address().clone();
281                    let thread = ::tokio::spawn(
282                        wiremock_grpc::tonic::transport::Server::builder()
283                            .add_service(self.clone())
284                            .serve(address),
285                    );
286                    self._start(thread).await;
287                    self.to_owned()
288                }
289            }
290        }
291    }
292}
293
/// Converts a `PascalCase`/`camelCase` name to `snake_case`.
///
/// Every uppercase character is lowercased. An underscore is inserted before
/// an uppercase character (other than the first character) when either:
///
/// - the previous character is not uppercase (`GetWeather` -> `get_weather`), or
/// - the next character is lowercase, i.e. an acronym is followed by a new
///   word (`HTTPServer` -> `http_server`).
///
/// No underscore is inserted when the previous character is already `_`, so
/// `Get_Weather` becomes `get_weather` rather than `get__weather`.
/// Non-uppercase characters are copied through unchanged.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    // Room for the input plus a few inserted separators.
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
            let starts_word = !prev.is_uppercase() || next_is_lower;

            // An existing `_` already separates the words; adding another
            // would yield `__`, which rustc's `non_snake_case` lint rejects.
            if starts_word && prev != '_' {
                result.push('_');
            }
        }

        // `to_lowercase` may yield more than one char for some code points.
        result.extend(ch.to_lowercase());
    }
    result
}
321
#[test]
fn test_to_snake_case() {
    // (input, expected) pairs: acronym boundary, plain PascalCase,
    // already-lowercase input, and the empty string.
    let cases = [
        ("HTTPServer", "http_server"),
        ("GetWeather", "get_weather"),
        ("nothing", "nothing"),
        ("", ""),
    ];

    for (input, expected) in cases {
        assert_eq!(to_snake_case(input), expected);
    }
}