Skip to main content

progenitor_impl/
httpmock.rs

1// Copyright 2025 Oxide Computer Company
2
3//! Generation of mocking extensions for `httpmock`
4
5use openapiv3::OpenAPI;
6use proc_macro2::TokenStream;
7use quote::{ToTokens, format_ident, quote};
8
9use crate::{
10    Generator, Result,
11    method::{
12        BodyContentType, HttpMethod, OperationParameter, OperationParameterKind,
13        OperationParameterType, OperationResponse, OperationResponseStatus,
14    },
15    to_schema::ToSchema,
16    util::{Case, sanitize},
17    validate_openapi,
18};
19
20struct MockOp {
21    when: TokenStream,
22    when_impl: TokenStream,
23    then: TokenStream,
24    then_impl: TokenStream,
25}
26
27impl Generator {
28    /// Generate a strongly-typed mocking extension to the `httpmock` crate.
29    ///
30    /// The `crate_path` parameter should be a valid Rust path corresponding to
31    /// the SDK. This can include `::` and instances of `-` in the crate name
32    /// should be converted to `_`.
33    pub fn httpmock(&mut self, spec: &OpenAPI, crate_path: &str) -> Result<TokenStream> {
34        validate_openapi(spec)?;
35
36        // Convert our components dictionary to schemars
37        let schemas = spec.components.iter().flat_map(|components| {
38            components
39                .schemas
40                .iter()
41                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
42        });
43
44        self.type_space.add_ref_types(schemas)?;
45
46        let raw_methods = spec
47            .paths
48            .iter()
49            .flat_map(|(path, ref_or_item)| {
50                // Exclude externally defined path items.
51                let item = ref_or_item.as_item().unwrap();
52                item.iter().map(move |(method, operation)| {
53                    (path.as_str(), method, operation, &item.parameters)
54                })
55            })
56            .map(|(path, method, operation, path_parameters)| {
57                self.process_operation(operation, &spec.components, path, method, path_parameters)
58            })
59            .collect::<Result<Vec<_>>>()?;
60
61        let methods = raw_methods
62            .iter()
63            .map(|method| self.httpmock_method(method))
64            .collect::<Vec<_>>();
65
66        let op = raw_methods
67            .iter()
68            .map(|method| format_ident!("{}", &method.operation_id))
69            .collect::<Vec<_>>();
70        let when = methods.iter().map(|op| &op.when).collect::<Vec<_>>();
71        let when_impl = methods.iter().map(|op| &op.when_impl).collect::<Vec<_>>();
72        let then = methods.iter().map(|op| &op.then).collect::<Vec<_>>();
73        let then_impl = methods.iter().map(|op| &op.then_impl).collect::<Vec<_>>();
74
75        let crate_path = syn::TypePath {
76            qself: None,
77            path: syn::parse_str(crate_path)
78                .unwrap_or_else(|_| panic!("{} is not a valid identifier", crate_path)),
79        };
80
81        let code = quote! {
82            pub mod operations {
83
84                //! [`When`](::httpmock::When) and [`Then`](::httpmock::Then)
85                //! wrappers for each operation. Each can be converted to
86                //! its inner type with a call to `into_inner()`. This can
87                //! be used to explicitly deviate from permitted values.
88
89                use #crate_path::*;
90
91                #(
92                    pub struct #when(::httpmock::When);
93                    #when_impl
94
95                    pub struct #then(::httpmock::Then);
96                    #then_impl
97                )*
98            }
99
100            /// An extension trait for [`MockServer`](::httpmock::MockServer) that
101            /// adds a method for each operation. These are the equivalent of
102            /// type-checked [`mock()`](::httpmock::MockServer::mock) calls.
103            pub trait MockServerExt {
104                #(
105                    fn #op<F>(&self, config_fn: F) -> ::httpmock::Mock<'_>
106                    where
107                        F: FnOnce(operations::#when, operations::#then);
108                )*
109            }
110
111            impl MockServerExt for ::httpmock::MockServer {
112                #(
113                    fn #op<F>(&self, config_fn: F) -> ::httpmock::Mock<'_>
114                    where
115                        F: FnOnce(operations::#when, operations::#then)
116                    {
117                        self.mock(|when, then| {
118                            config_fn(
119                                operations::#when::new(when),
120                                operations::#then::new(then),
121                            )
122                        })
123                    }
124                )*
125            }
126        };
127        Ok(code)
128    }
129
130    fn httpmock_method(&mut self, method: &crate::method::OperationMethod) -> MockOp {
131        let when_name = sanitize(&format!("{}-when", method.operation_id), Case::Pascal);
132        let when = format_ident!("{}", when_name).to_token_stream();
133        let then_name = sanitize(&format!("{}-then", method.operation_id), Case::Pascal);
134        let then = format_ident!("{}", then_name).to_token_stream();
135
136        let http_method = match &method.method {
137            HttpMethod::Get => quote! { ::httpmock::Method::GET },
138            HttpMethod::Put => quote! { ::httpmock::Method::PUT },
139            HttpMethod::Post => quote! { ::httpmock::Method::POST },
140            HttpMethod::Delete => quote! { ::httpmock::Method::DELETE },
141            HttpMethod::Options => quote! { ::httpmock::Method::OPTIONS },
142            HttpMethod::Head => quote! { ::httpmock::Method::HEAD },
143            HttpMethod::Patch => quote! { ::httpmock::Method::PATCH },
144            HttpMethod::Trace => quote! { ::httpmock::Method::TRACE },
145        };
146
147        let path_re = method.path.as_wildcard();
148
149        // Generate methods corresponding to each parameter so that callers
150        // can specify a prescribed value for that parameter.
151        let when_methods = method.params.iter().map(
152            |OperationParameter {
153                 name,
154                 typ,
155                 kind,
156                 api_name,
157                 description: _,
158             }| {
159                let arg_type_name = match typ {
160                    OperationParameterType::Type(arg_type_id) => self
161                        .type_space
162                        .get_type(arg_type_id)
163                        .unwrap()
164                        .parameter_ident(),
165                    OperationParameterType::RawBody => match kind {
166                        OperationParameterKind::Body(BodyContentType::OctetStream) => quote! {
167                            ::serde_json::Value
168                        },
169                        OperationParameterKind::Body(BodyContentType::Text(_)) => quote! {
170                            String
171                        },
172                        _ => unreachable!(),
173                    },
174                };
175
176                let name_ident = format_ident!("{}", name);
177                let (required, handler) = match kind {
178                    OperationParameterKind::Path => {
179                        let re_fmt = method.path.as_wildcard_param(api_name);
180                        (
181                            true,
182                            quote! {
183                                let re = regex::Regex::new(
184                                    &format!(#re_fmt, value.to_string())
185                                ).unwrap();
186                                Self(self.0.path_matches(re))
187                            },
188                        )
189                    }
190                    OperationParameterKind::Query(true) => (
191                        true,
192                        quote! {
193                            Self(self.0.query_param(#api_name, value.to_string()))
194                        },
195                    ),
196                    OperationParameterKind::Header(true) => (
197                        true,
198                        quote! {
199                            Self(self.0.header(#api_name, value.to_string()))
200                        },
201                    ),
202
203                    OperationParameterKind::Query(false) => (
204                        false,
205                        quote! {
206                            if let Some(value) = value.into() {
207                                Self(self.0.query_param(
208                                    #api_name,
209                                    value.to_string(),
210                                ))
211                            } else {
212                                Self(self.0.query_param_missing(#api_name))
213                            }
214                        },
215                    ),
216                    OperationParameterKind::Header(false) => (
217                        false,
218                        quote! {
219                            if let Some(value) = value.into() {
220                                Self(self.0.header(
221                                    #api_name,
222                                    value.to_string()
223                                ))
224                            } else {
225                                Self(self.0.header_missing(#api_name))
226                            }
227                        },
228                    ),
229                    OperationParameterKind::Body(body_content_type) => match typ {
230                        OperationParameterType::Type(_) => (
231                            true,
232                            quote! {
233                                Self(self.0.json_body_obj(value))
234                            },
235                        ),
236                        OperationParameterType::RawBody => match body_content_type {
237                            BodyContentType::OctetStream => (
238                                true,
239                                quote! {
240                                    Self(self.0.json_body(value))
241                                },
242                            ),
243                            BodyContentType::Text(_) => (
244                                true,
245                                quote! {
246                                    Self(self.0.body(value))
247                                },
248                            ),
249                            _ => unreachable!(),
250                        },
251                    },
252                };
253
254                if required {
255                    // The value is required so we just check for a simple
256                    // match.
257                    quote! {
258                        pub fn #name_ident(self, value: #arg_type_name) -> Self {
259                            #handler
260                        }
261                    }
262                } else {
263                    // For optional values we permit an input that's an
264                    // `Into<Option<T>`. This allows callers to specify a value
265                    // or specify that the parameter must be absent with None.
266
267                    // If the type is a ref, augment it with a lifetime that
268                    // we'll also use in the function
269                    let (lifetime, arg_type_name) = if let syn::Type::Reference(mut rr) =
270                        syn::parse2::<syn::Type>(arg_type_name.clone()).unwrap()
271                    {
272                        rr.lifetime =
273                            Some(syn::Lifetime::new("'a", proc_macro2::Span::call_site()));
274                        (Some(quote! { 'a, }), rr.to_token_stream())
275                    } else {
276                        (None, arg_type_name)
277                    };
278
279                    quote! {
280                        pub fn #name_ident<#lifetime T>(
281                            self,
282                            value: T,
283                        ) -> Self
284                        where
285                            T: Into<Option<#arg_type_name>>,
286                        {
287                            #handler
288                        }
289                    }
290                }
291            },
292        );
293
294        let when_impl = quote! {
295            impl #when {
296                pub fn new(inner: ::httpmock::When) -> Self {
297                    Self(inner
298                        .method(#http_method)
299                        .path_matches(regex::Regex::new(#path_re).unwrap()))
300                }
301
302                pub fn into_inner(self) -> ::httpmock::When {
303                    self.0
304                }
305
306                #(#when_methods)*
307            }
308        };
309
310        // Methods for each discrete response. For specific status codes we use
311        // the name of that code; for classes of codes we use the class name
312        // and require a status code that must be within the prescribed range.
313        let then_methods = method.responses.iter().map(
314            |OperationResponse {
315                 status_code, typ, ..
316             }| {
317                let (value_param, value_use) = match typ {
318                    crate::method::OperationResponseKind::Type(arg_type_id) => {
319                        let arg_type = self.type_space.get_type(arg_type_id).unwrap();
320                        let arg_type_ident = arg_type.parameter_ident();
321                        (
322                            quote! {
323                                value: #arg_type_ident,
324                            },
325                            quote! {
326                                .header("content-type", "application/json")
327                                .json_body_obj(value)
328                            },
329                        )
330                    }
331                    crate::method::OperationResponseKind::None => Default::default(),
332                    crate::method::OperationResponseKind::Raw => (
333                        quote! {
334                            value: ::serde_json::Value,
335                        },
336                        quote! {
337                            .header("content-type", "application/json")
338                            .json_body(value)
339                        },
340                    ),
341                    crate::method::OperationResponseKind::Upgrade => Default::default(),
342                };
343
344                match status_code {
345                    OperationResponseStatus::Code(status_code) => {
346                        let canonical_reason = http::StatusCode::from_u16(*status_code)
347                            .unwrap()
348                            .canonical_reason()
349                            .unwrap();
350                        let fn_name = format_ident!("{}", &sanitize(canonical_reason, Case::Snake));
351
352                        quote! {
353                            pub fn #fn_name(self, #value_param) -> Self {
354                                Self(self.0
355                                    .status(#status_code)
356                                    #value_use
357                                )
358                            }
359                        }
360                    }
361                    OperationResponseStatus::Range(status_type) => {
362                        let status_string = match status_type {
363                            1 => "informational",
364                            2 => "success",
365                            3 => "redirect",
366                            4 => "client_error",
367                            5 => "server_error",
368                            _ => unreachable!(),
369                        };
370                        let fn_name = format_ident!("{}", status_string);
371                        quote! {
372                            pub fn #fn_name(self, status: u16, #value_param) -> Self {
373                                assert_eq!(status / 100u16, #status_type);
374                                Self(self.0
375                                    .status(status)
376                                    #value_use
377                                )
378                            }
379                        }
380                    }
381                    OperationResponseStatus::Default => quote! {
382                        pub fn default_response(self, status: u16, #value_param) -> Self {
383                            Self(self.0
384                                .status(status)
385                                #value_use
386                            )
387                        }
388                    },
389                }
390            },
391        );
392
393        let then_impl = quote! {
394            impl #then {
395                pub fn new(inner: ::httpmock::Then) -> Self {
396                    Self(inner)
397                }
398
399                pub fn into_inner(self) -> ::httpmock::Then {
400                    self.0
401                }
402
403                #(#then_methods)*
404            }
405        };
406
407        MockOp {
408            when,
409            when_impl,
410            then,
411            then_impl,
412        }
413    }
414}