progenitor_impl/
httpmock.rs

1// Copyright 2025 Oxide Computer Company
2
3//! Generation of mocking extensions for `httpmock`
4
5use openapiv3::OpenAPI;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8
9use crate::{
10    method::{
11        BodyContentType, HttpMethod, OperationParameter, OperationParameterKind,
12        OperationParameterType, OperationResponse, OperationResponseStatus,
13    },
14    to_schema::ToSchema,
15    util::{sanitize, Case},
16    validate_openapi, Generator, Result,
17};
18
19struct MockOp {
20    when: TokenStream,
21    when_impl: TokenStream,
22    then: TokenStream,
23    then_impl: TokenStream,
24}
25
26impl Generator {
27    /// Generate a strongly-typed mocking extension to the `httpmock` crate.
28    ///
29    /// The `crate_path` parameter should be a valid Rust path corresponding to
30    /// the SDK. This can include `::` and instances of `-` in the crate name
31    /// should be converted to `_`.
32    pub fn httpmock(&mut self, spec: &OpenAPI, crate_path: &str) -> Result<TokenStream> {
33        validate_openapi(spec)?;
34
35        // Convert our components dictionary to schemars
36        let schemas = spec.components.iter().flat_map(|components| {
37            components
38                .schemas
39                .iter()
40                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
41        });
42
43        self.type_space.add_ref_types(schemas)?;
44
45        let raw_methods = spec
46            .paths
47            .iter()
48            .flat_map(|(path, ref_or_item)| {
49                // Exclude externally defined path items.
50                let item = ref_or_item.as_item().unwrap();
51                item.iter().map(move |(method, operation)| {
52                    (path.as_str(), method, operation, &item.parameters)
53                })
54            })
55            .map(|(path, method, operation, path_parameters)| {
56                self.process_operation(operation, &spec.components, path, method, path_parameters)
57            })
58            .collect::<Result<Vec<_>>>()?;
59
60        let methods = raw_methods
61            .iter()
62            .map(|method| self.httpmock_method(method))
63            .collect::<Vec<_>>();
64
65        let op = raw_methods
66            .iter()
67            .map(|method| format_ident!("{}", &method.operation_id))
68            .collect::<Vec<_>>();
69        let when = methods.iter().map(|op| &op.when).collect::<Vec<_>>();
70        let when_impl = methods.iter().map(|op| &op.when_impl).collect::<Vec<_>>();
71        let then = methods.iter().map(|op| &op.then).collect::<Vec<_>>();
72        let then_impl = methods.iter().map(|op| &op.then_impl).collect::<Vec<_>>();
73
74        let crate_path = syn::TypePath {
75            qself: None,
76            path: syn::parse_str(crate_path)
77                .unwrap_or_else(|_| panic!("{} is not a valid identifier", crate_path)),
78        };
79
80        let code = quote! {
81            pub mod operations {
82
83                //! [`When`](::httpmock::When) and [`Then`](::httpmock::Then)
84                //! wrappers for each operation. Each can be converted to
85                //! its inner type with a call to `into_inner()`. This can
86                //! be used to explicitly deviate from permitted values.
87
88                use #crate_path::*;
89
90                #(
91                    pub struct #when(::httpmock::When);
92                    #when_impl
93
94                    pub struct #then(::httpmock::Then);
95                    #then_impl
96                )*
97            }
98
99            /// An extension trait for [`MockServer`](::httpmock::MockServer) that
100            /// adds a method for each operation. These are the equivalent of
101            /// type-checked [`mock()`](::httpmock::MockServer::mock) calls.
102            pub trait MockServerExt {
103                #(
104                    fn #op<F>(&self, config_fn: F) -> ::httpmock::Mock<'_>
105                    where
106                        F: FnOnce(operations::#when, operations::#then);
107                )*
108            }
109
110            impl MockServerExt for ::httpmock::MockServer {
111                #(
112                    fn #op<F>(&self, config_fn: F) -> ::httpmock::Mock<'_>
113                    where
114                        F: FnOnce(operations::#when, operations::#then)
115                    {
116                        self.mock(|when, then| {
117                            config_fn(
118                                operations::#when::new(when),
119                                operations::#then::new(then),
120                            )
121                        })
122                    }
123                )*
124            }
125        };
126        Ok(code)
127    }
128
129    fn httpmock_method(&mut self, method: &crate::method::OperationMethod) -> MockOp {
130        let when_name = sanitize(&format!("{}-when", method.operation_id), Case::Pascal);
131        let when = format_ident!("{}", when_name).to_token_stream();
132        let then_name = sanitize(&format!("{}-then", method.operation_id), Case::Pascal);
133        let then = format_ident!("{}", then_name).to_token_stream();
134
135        let http_method = match &method.method {
136            HttpMethod::Get => quote! { ::httpmock::Method::GET },
137            HttpMethod::Put => quote! { ::httpmock::Method::PUT },
138            HttpMethod::Post => quote! { ::httpmock::Method::POST },
139            HttpMethod::Delete => quote! { ::httpmock::Method::DELETE },
140            HttpMethod::Options => quote! { ::httpmock::Method::OPTIONS },
141            HttpMethod::Head => quote! { ::httpmock::Method::HEAD },
142            HttpMethod::Patch => quote! { ::httpmock::Method::PATCH },
143            HttpMethod::Trace => quote! { ::httpmock::Method::TRACE },
144        };
145
146        let path_re = method.path.as_wildcard();
147
148        // Generate methods corresponding to each parameter so that callers
149        // can specify a prescribed value for that parameter.
150        let when_methods = method.params.iter().map(
151            |OperationParameter {
152                 name,
153                 typ,
154                 kind,
155                 api_name,
156                 description: _,
157             }| {
158                let arg_type_name = match typ {
159                    OperationParameterType::Type(arg_type_id) => self
160                        .type_space
161                        .get_type(arg_type_id)
162                        .unwrap()
163                        .parameter_ident(),
164                    OperationParameterType::RawBody => match kind {
165                        OperationParameterKind::Body(BodyContentType::OctetStream) => quote! {
166                            ::serde_json::Value
167                        },
168                        OperationParameterKind::Body(BodyContentType::Text(_)) => quote! {
169                            String
170                        },
171                        _ => unreachable!(),
172                    },
173                };
174
175                let name_ident = format_ident!("{}", name);
176                let (required, handler) = match kind {
177                    OperationParameterKind::Path => {
178                        let re_fmt = method.path.as_wildcard_param(api_name);
179                        (
180                            true,
181                            quote! {
182                                let re = regex::Regex::new(
183                                    &format!(#re_fmt, value.to_string())
184                                ).unwrap();
185                                Self(self.0.path_matches(re))
186                            },
187                        )
188                    }
189                    OperationParameterKind::Query(true) => (
190                        true,
191                        quote! {
192                            Self(self.0.query_param(#api_name, value.to_string()))
193                        },
194                    ),
195                    OperationParameterKind::Header(true) => (
196                        true,
197                        quote! {
198                            Self(self.0.header(#api_name, value.to_string()))
199                        },
200                    ),
201
202                    OperationParameterKind::Query(false) => (
203                        false,
204                        quote! {
205                            if let Some(value) = value.into() {
206                                Self(self.0.query_param(
207                                    #api_name,
208                                    value.to_string(),
209                                ))
210                            } else {
211                                Self(self.0.matches(|req| {
212                                    req.query_params
213                                        .as_ref()
214                                        .and_then(|qs| {
215                                            qs.iter().find(
216                                                |(key, _)| key == #api_name)
217                                        })
218                                        .is_none()
219                                }))
220                            }
221                        },
222                    ),
223                    OperationParameterKind::Header(false) => (
224                        false,
225                        quote! {
226                            if let Some(value) = value.into() {
227                                Self(self.0.header(
228                                    #api_name,
229                                    value.to_string()
230                                ))
231                            } else {
232                                Self(self.0.matches(|req| {
233                                    req.headers
234                                        .as_ref()
235                                        .and_then(|hs| {
236                                            hs.iter().find(
237                                                |(key, _)| key == #api_name
238                                            )
239                                        })
240                                        .is_none()
241                                }))
242                            }
243                        },
244                    ),
245                    OperationParameterKind::Body(body_content_type) => match typ {
246                        OperationParameterType::Type(_) => (
247                            true,
248                            quote! {
249                                Self(self.0.json_body_obj(value))
250
251                            },
252                        ),
253                        OperationParameterType::RawBody => match body_content_type {
254                            BodyContentType::OctetStream => (
255                                true,
256                                quote! {
257                                    Self(self.0.json_body(value))
258                                },
259                            ),
260                            BodyContentType::Text(_) => (
261                                true,
262                                quote! {
263                                    Self(self.0.body(value))
264                                },
265                            ),
266                            _ => unreachable!(),
267                        },
268                    },
269                };
270
271                if required {
272                    // The value is required so we just check for a simple
273                    // match.
274                    quote! {
275                        pub fn #name_ident(self, value: #arg_type_name) -> Self {
276                            #handler
277                        }
278                    }
279                } else {
280                    // For optional values we permit an input that's an
281                    // `Into<Option<T>`. This allows callers to specify a value
282                    // or specify that the parameter must be absent with None.
283
284                    // If the type is a ref, augment it with a lifetime that
285                    // we'll also use in the function
286                    let (lifetime, arg_type_name) = if let syn::Type::Reference(mut rr) =
287                        syn::parse2::<syn::Type>(arg_type_name.clone()).unwrap()
288                    {
289                        rr.lifetime =
290                            Some(syn::Lifetime::new("'a", proc_macro2::Span::call_site()));
291                        (Some(quote! { 'a, }), rr.to_token_stream())
292                    } else {
293                        (None, arg_type_name)
294                    };
295
296                    quote! {
297                        pub fn #name_ident<#lifetime T>(
298                            self,
299                            value: T,
300                        ) -> Self
301                        where
302                            T: Into<Option<#arg_type_name>>,
303                        {
304                            #handler
305                        }
306                    }
307                }
308            },
309        );
310
311        let when_impl = quote! {
312            impl #when {
313                pub fn new(inner: ::httpmock::When) -> Self {
314                    Self(inner
315                        .method(#http_method)
316                        .path_matches(regex::Regex::new(#path_re).unwrap()))
317                }
318
319                pub fn into_inner(self) -> ::httpmock::When {
320                    self.0
321                }
322
323                #(#when_methods)*
324            }
325        };
326
327        // Methods for each discrete response. For specific status codes we use
328        // the name of that code; for classes of codes we use the class name
329        // and require a status code that must be within the prescribed range.
330        let then_methods = method.responses.iter().map(
331            |OperationResponse {
332                 status_code, typ, ..
333             }| {
334                let (value_param, value_use) = match typ {
335                    crate::method::OperationResponseKind::Type(arg_type_id) => {
336                        let arg_type = self.type_space.get_type(arg_type_id).unwrap();
337                        let arg_type_ident = arg_type.parameter_ident();
338                        (
339                            quote! {
340                                value: #arg_type_ident,
341                            },
342                            quote! {
343                                .header("content-type", "application/json")
344                                .json_body_obj(value)
345                            },
346                        )
347                    }
348                    crate::method::OperationResponseKind::None => Default::default(),
349                    crate::method::OperationResponseKind::Raw => (
350                        quote! {
351                            value: ::serde_json::Value,
352                        },
353                        quote! {
354                            .header("content-type", "application/json")
355                            .json_body(value)
356                        },
357                    ),
358                    crate::method::OperationResponseKind::Upgrade => Default::default(),
359                };
360
361                match status_code {
362                    OperationResponseStatus::Code(status_code) => {
363                        let canonical_reason = http::StatusCode::from_u16(*status_code)
364                            .unwrap()
365                            .canonical_reason()
366                            .unwrap();
367                        let fn_name = format_ident!("{}", &sanitize(canonical_reason, Case::Snake));
368
369                        quote! {
370                            pub fn #fn_name(self, #value_param) -> Self {
371                                Self(self.0
372                                    .status(#status_code)
373                                    #value_use
374                                )
375                            }
376                        }
377                    }
378                    OperationResponseStatus::Range(status_type) => {
379                        let status_string = match status_type {
380                            1 => "informational",
381                            2 => "success",
382                            3 => "redirect",
383                            4 => "client_error",
384                            5 => "server_error",
385                            _ => unreachable!(),
386                        };
387                        let fn_name = format_ident!("{}", status_string);
388                        quote! {
389                            pub fn #fn_name(self, status: u16, #value_param) -> Self {
390                                assert_eq!(status / 100u16, #status_type);
391                                Self(self.0
392                                    .status(status)
393                                    #value_use
394                                )
395                            }
396                        }
397                    }
398                    OperationResponseStatus::Default => quote! {
399                        pub fn default_response(self, status: u16, #value_param) -> Self {
400                            Self(self.0
401                                .status(status)
402                                #value_use
403                            )
404                        }
405                    },
406                }
407            },
408        );
409
410        let then_impl = quote! {
411            impl #then {
412                pub fn new(inner: ::httpmock::Then) -> Self {
413                    Self(inner)
414                }
415
416                pub fn into_inner(self) -> ::httpmock::Then {
417                    self.0
418                }
419
420                #(#then_methods)*
421            }
422        };
423
424        MockOp {
425            when,
426            when_impl,
427            then,
428            then_impl,
429        }
430    }
431}