progenitor_middleware_impl/
lib.rs

1// Copyright 2025 Oxide Computer Company
2
3//! Core implementation for the progenitor OpenAPI client generator.
4
5#![deny(missing_docs)]
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8
9use openapiv3::OpenAPI;
10use proc_macro2::TokenStream;
11use quote::quote;
12use serde::Deserialize;
13use thiserror::Error;
14use typify::{TypeSpace, TypeSpaceSettings};
15
16use crate::to_schema::ToSchema;
17
18pub use typify::CrateVers;
19pub use typify::TypeSpaceImpl as TypeImpl;
20pub use typify::TypeSpacePatch as TypePatch;
21pub use typify::UnknownPolicy;
22
23mod cli;
24mod httpmock;
25mod method;
26mod template;
27mod to_schema;
28mod util;
29
30#[allow(missing_docs)]
31#[derive(Error, Debug)]
32pub enum Error {
33    #[error("unexpected value type {0}: {1}")]
34    BadValue(String, serde_json::Value),
35    #[error("type error {0}")]
36    TypeError(#[from] typify::Error),
37    #[error("unexpected or unhandled format in the OpenAPI document {0}")]
38    UnexpectedFormat(String),
39    #[error("invalid operation path {0}")]
40    InvalidPath(String),
41    #[error("invalid dropshot extension use: {0}")]
42    InvalidExtension(String),
43    #[error("internal error {0}")]
44    InternalError(String),
45}
46
47#[allow(missing_docs)]
48pub type Result<T> = std::result::Result<T, Error>;
49
50/// OpenAPI generator.
51pub struct Generator {
52    type_space: TypeSpace,
53    settings: GenerationSettings,
54    uses_futures: bool,
55    uses_websockets: bool,
56}
57
58/// Settings for [Generator].
59#[derive(Default, Clone)]
60pub struct GenerationSettings {
61    interface: InterfaceStyle,
62    tag: TagStyle,
63    inner_type: Option<TokenStream>,
64    pre_hook: Option<TokenStream>,
65    pre_hook_async: Option<TokenStream>,
66    post_hook: Option<TokenStream>,
67    post_hook_async: Option<TokenStream>,
68    extra_derives: Vec<String>,
69
70    map_type: Option<String>,
71    unknown_crates: UnknownPolicy,
72    crates: BTreeMap<String, CrateSpec>,
73
74    patch: HashMap<String, TypePatch>,
75    replace: HashMap<String, (String, Vec<TypeImpl>)>,
76    convert: Vec<(schemars::schema::SchemaObject, String, Vec<TypeImpl>)>,
77}
78
79#[derive(Debug, Clone)]
80struct CrateSpec {
81    version: CrateVers,
82    rename: Option<String>,
83}
84
85/// Style of generated client.
86#[derive(Clone, Deserialize, PartialEq, Eq)]
87pub enum InterfaceStyle {
88    /// Use positional style.
89    Positional,
90    /// Use builder style.
91    Builder,
92}
93
94impl Default for InterfaceStyle {
95    fn default() -> Self {
96        Self::Positional
97    }
98}
99
100/// Style for using the OpenAPI tags when generating names in the client.
101#[derive(Clone, Deserialize)]
102pub enum TagStyle {
103    /// Merge tags to create names in the generated client.
104    Merged,
105    /// Use each tag name to create separate names in the generated client.
106    Separate,
107}
108
109impl Default for TagStyle {
110    fn default() -> Self {
111        Self::Merged
112    }
113}
114
115impl GenerationSettings {
116    /// Create new generator settings with default values.
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    /// Set the [InterfaceStyle].
122    pub fn with_interface(&mut self, interface: InterfaceStyle) -> &mut Self {
123        self.interface = interface;
124        self
125    }
126
127    /// Set the [TagStyle].
128    pub fn with_tag(&mut self, tag: TagStyle) -> &mut Self {
129        self.tag = tag;
130        self
131    }
132
133    /// Client inner type available to pre and post hooks.
134    pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self {
135        self.inner_type = Some(inner_type);
136        self
137    }
138
139    /// Hook invoked before issuing the HTTP request.
140    pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self {
141        self.pre_hook = Some(pre_hook);
142        self
143    }
144
145    /// Hook invoked before issuing the HTTP request.
146    pub fn with_pre_hook_async(&mut self, pre_hook: TokenStream) -> &mut Self {
147        self.pre_hook_async = Some(pre_hook);
148        self
149    }
150
151    /// Hook invoked prior to receiving the HTTP response.
152    pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self {
153        self.post_hook = Some(post_hook);
154        self
155    }
156
157    /// Hook invoked prior to receiving the HTTP response.
158    pub fn with_post_hook_async(&mut self, post_hook: TokenStream) -> &mut Self {
159        self.post_hook_async = Some(post_hook);
160        self
161    }
162
163    /// Additional derive macros applied to generated types.
164    pub fn with_derive(&mut self, derive: impl ToString) -> &mut Self {
165        self.extra_derives.push(derive.to_string());
166        self
167    }
168
169    /// Modify a type with the given name.
170    /// See [typify::TypeSpaceSettings::with_patch].
171    pub fn with_patch<S: AsRef<str>>(&mut self, type_name: S, patch: &TypePatch) -> &mut Self {
172        self.patch
173            .insert(type_name.as_ref().to_string(), patch.clone());
174        self
175    }
176
177    /// Replace a referenced type with a named type.
178    /// See [typify::TypeSpaceSettings::with_replacement].
179    pub fn with_replacement<TS: ToString, RS: ToString, I: Iterator<Item = TypeImpl>>(
180        &mut self,
181        type_name: TS,
182        replace_name: RS,
183        impls: I,
184    ) -> &mut Self {
185        self.replace.insert(
186            type_name.to_string(),
187            (replace_name.to_string(), impls.collect()),
188        );
189        self
190    }
191
192    /// Replace a given schema with a named type.
193    /// See [typify::TypeSpaceSettings::with_conversion].
194    pub fn with_conversion<S: ToString, I: Iterator<Item = TypeImpl>>(
195        &mut self,
196        schema: schemars::schema::SchemaObject,
197        type_name: S,
198        impls: I,
199    ) -> &mut Self {
200        self.convert
201            .push((schema, type_name.to_string(), impls.collect()));
202        self
203    }
204
205    /// Policy regarding crates referenced by the schema extension
206    /// `x-rust-type` not explicitly specified via [Self::with_crate].
207    /// See [typify::TypeSpaceSettings::with_unknown_crates].
208    pub fn with_unknown_crates(&mut self, policy: UnknownPolicy) -> &mut Self {
209        self.unknown_crates = policy;
210        self
211    }
212
213    /// Explicitly named crates whose types may be used during generation
214    /// rather than generating new types based on their schemas (base on the
215    /// presence of the x-rust-type extension).
216    /// See [typify::TypeSpaceSettings::with_crate].
217    pub fn with_crate<S1: ToString>(
218        &mut self,
219        crate_name: S1,
220        version: CrateVers,
221        rename: Option<&String>,
222    ) -> &mut Self {
223        self.crates.insert(
224            crate_name.to_string(),
225            CrateSpec {
226                version,
227                rename: rename.cloned(),
228            },
229        );
230        self
231    }
232
233    /// Set the type used for key-value maps. Common examples:
234    /// - [`std::collections::HashMap`] - **Default**
235    /// - [`std::collections::BTreeMap`]
236    /// - [`indexmap::IndexMap`]
237    ///
238    /// The requiremnets for a map type can be found in the
239    /// [typify::TypeSpaceSettings::with_map_type] documentation.
240    pub fn with_map_type<MT: ToString>(&mut self, map_type: MT) -> &mut Self {
241        self.map_type = Some(map_type.to_string());
242        self
243    }
244}
245
246impl Default for Generator {
247    fn default() -> Self {
248        Self {
249            type_space: TypeSpace::new(TypeSpaceSettings::default().with_type_mod("types")),
250            settings: Default::default(),
251            uses_futures: Default::default(),
252            uses_websockets: Default::default(),
253        }
254    }
255}
256
257impl Generator {
258    /// Create a new generator with default values.
259    pub fn new(settings: &GenerationSettings) -> Self {
260        let mut type_settings = TypeSpaceSettings::default();
261        type_settings
262            .with_type_mod("types")
263            .with_struct_builder(settings.interface == InterfaceStyle::Builder);
264        settings.extra_derives.iter().for_each(|derive| {
265            let _ = type_settings.with_derive(derive.clone());
266        });
267
268        // Control use of crates found in x-rust-type extension
269        type_settings.with_unknown_crates(settings.unknown_crates);
270        settings
271            .crates
272            .iter()
273            .for_each(|(crate_name, CrateSpec { version, rename })| {
274                type_settings.with_crate(crate_name, version.clone(), rename.as_ref());
275            });
276
277        // Adjust generation by type, name, or schema.
278        settings.patch.iter().for_each(|(type_name, patch)| {
279            type_settings.with_patch(type_name, patch);
280        });
281        settings
282            .replace
283            .iter()
284            .for_each(|(type_name, (replace_name, impls))| {
285                type_settings.with_replacement(type_name, replace_name, impls.iter().cloned());
286            });
287        settings
288            .convert
289            .iter()
290            .for_each(|(schema, type_name, impls)| {
291                type_settings.with_conversion(schema.clone(), type_name, impls.iter().cloned());
292            });
293
294        // Set the map type if specified.
295        if let Some(map_type) = &settings.map_type {
296            type_settings.with_map_type(map_type.clone());
297        }
298
299        Self {
300            type_space: TypeSpace::new(&type_settings),
301            settings: settings.clone(),
302            uses_futures: false,
303            uses_websockets: false,
304        }
305    }
306
307    /// Emit a [TokenStream] containing the generated client code.
308    pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result<TokenStream> {
309        validate_openapi(spec)?;
310
311        // Convert our components dictionary to schemars
312        let schemas = spec.components.iter().flat_map(|components| {
313            components
314                .schemas
315                .iter()
316                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
317        });
318
319        self.type_space.add_ref_types(schemas)?;
320
321        let raw_methods = spec
322            .paths
323            .iter()
324            .flat_map(|(path, ref_or_item)| {
325                // Exclude externally defined path items.
326                let item = ref_or_item.as_item().unwrap();
327                item.iter().map(move |(method, operation)| {
328                    (path.as_str(), method, operation, &item.parameters)
329                })
330            })
331            .map(|(path, method, operation, path_parameters)| {
332                self.process_operation(operation, &spec.components, path, method, path_parameters)
333            })
334            .collect::<Result<Vec<_>>>()?;
335
336        let operation_code = match (&self.settings.interface, &self.settings.tag) {
337            (InterfaceStyle::Positional, TagStyle::Merged) => self
338                .generate_tokens_positional_merged(
339                    &raw_methods,
340                    self.settings.inner_type.is_some(),
341                ),
342            (InterfaceStyle::Positional, TagStyle::Separate) => {
343                unimplemented!("positional arguments with separate tags are currently unsupported")
344            }
345            (InterfaceStyle::Builder, TagStyle::Merged) => self
346                .generate_tokens_builder_merged(&raw_methods, self.settings.inner_type.is_some()),
347            (InterfaceStyle::Builder, TagStyle::Separate) => {
348                let tag_info = spec
349                    .tags
350                    .iter()
351                    .map(|tag| (&tag.name, tag))
352                    .collect::<BTreeMap<_, _>>();
353                self.generate_tokens_builder_separate(
354                    &raw_methods,
355                    tag_info,
356                    self.settings.inner_type.is_some(),
357                )
358            }
359        }?;
360
361        let types = self.type_space.to_stream();
362
363        let (inner_type, inner_fn_value) = match self.settings.inner_type.as_ref() {
364            Some(inner_type) => (inner_type.clone(), quote! { &self.inner }),
365            None => (quote! { () }, quote! { &() }),
366        };
367
368        let inner_property = self.settings.inner_type.as_ref().map(|inner| {
369            quote! {
370                pub (crate) inner: #inner,
371            }
372        });
373        let inner_parameter = self.settings.inner_type.as_ref().map(|inner| {
374            quote! {
375                inner: #inner,
376            }
377        });
378        let inner_value = self.settings.inner_type.as_ref().map(|_| {
379            quote! {
380                inner
381            }
382        });
383
384        let client_docstring = {
385            let mut s = format!("Client for {}", spec.info.title);
386
387            if let Some(ss) = &spec.info.description {
388                s.push_str("\n\n");
389                s.push_str(ss);
390            }
391            if let Some(ss) = &spec.info.terms_of_service {
392                s.push_str("\n\n");
393                s.push_str(ss);
394            }
395
396            s.push_str(&format!("\n\nVersion: {}", &spec.info.version));
397
398            s
399        };
400
401        let version_str = &spec.info.version;
402
403        // The allow(unused_imports) on the `pub use` is necessary with Rust
404        // 1.76+, in case the generated file is not at the top level of the
405        // crate.
406
407        let file = quote! {
408            // Re-export types that are used by the public interface of Client.
409            #[allow(unused_imports)]
410            pub use progenitor_middleware_client::{
411                ByteStream,
412                ClientInfo,
413                Error,
414                ResponseValue,
415            };
416            #[allow(unused_imports)]
417            use progenitor_middleware_client::{
418                encode_path,
419                ClientHooks,
420                OperationInfo,
421                RequestBuilderExt,
422            };
423
424            /// Types used as operation parameters and responses.
425            #[allow(clippy::all)]
426            pub mod types {
427                #types
428            }
429
430            #[derive(Clone, Debug)]
431            #[doc = #client_docstring]
432            pub struct Client {
433                pub(crate) baseurl: String,
434                pub(crate) client: reqwest_middleware::ClientWithMiddleware,
435                #inner_property
436            }
437
438            impl Client {
439                /// Create a new client.
440                ///
441                /// `baseurl` is the base URL provided to the internal
442                /// `reqwest::Client`, and should include a scheme and hostname,
443                /// as well as port and a path stem if applicable.
444                pub fn new(
445                    baseurl: &str,
446                    #inner_parameter
447                ) -> Self {
448                    #[cfg(not(target_arch = "wasm32"))]
449                    let client = {
450                        let dur = std::time::Duration::from_secs(15);
451
452                        let reqwest_client = reqwest::ClientBuilder::new()
453                            .connect_timeout(dur)
454                            .timeout(dur)
455                            .build()
456                            .unwrap();
457
458                        reqwest_middleware::ClientBuilder::new(reqwest_client)
459                            .build()
460                    };
461                    #[cfg(target_arch = "wasm32")]
462                    let client = {
463                        let reqwest_client = reqwest::ClientBuilder::new()
464                            .build()
465                            .unwrap();
466
467                        reqwest_middleware::ClientBuilder::new(reqwest_client)
468                            .build()
469                    };
470
471                    Self::new_with_client(baseurl, client, #inner_value)
472                }
473
474                /// Construct a new client with an existing `reqwest_middleware::ClientWithMiddleware`,
475                /// allowing more control over its configuration.
476                ///
477                /// `baseurl` is the base URL provided to the internal
478                /// `reqwest_middleware::ClientWithMiddleware`, and should include a scheme and hostname,
479                /// as well as port and a path stem if applicable.
480                pub fn new_with_client(
481                    baseurl: &str,
482                    client: reqwest_middleware::ClientWithMiddleware,
483                    #inner_parameter
484                ) -> Self {
485                    Self {
486                        baseurl: baseurl.to_string(),
487                        client,
488                        #inner_value
489                    }
490                }
491            }
492
493            impl ClientInfo<#inner_type> for Client {
494                fn api_version() -> &'static str {
495                    #version_str
496                }
497
498                fn baseurl(&self) -> &str {
499                    self.baseurl.as_str()
500                }
501
502                fn client(&self) -> &reqwest_middleware::ClientWithMiddleware {
503                    &self.client
504                }
505
506                fn inner(&self) -> &#inner_type {
507                    #inner_fn_value
508                }
509            }
510
511            impl ClientHooks<#inner_type> for &Client {}
512
513            #operation_code
514        };
515
516        Ok(file)
517    }
518
519    fn generate_tokens_positional_merged(
520        &mut self,
521        input_methods: &[method::OperationMethod],
522        has_inner: bool,
523    ) -> Result<TokenStream> {
524        let methods = input_methods
525            .iter()
526            .map(|method| self.positional_method(method, has_inner))
527            .collect::<Result<Vec<_>>>()?;
528
529        // The allow(unused_imports) on the `pub use` is necessary with Rust
530        // 1.76+, in case the generated file is not at the top level of the
531        // crate.
532
533        let out = quote! {
534            #[allow(clippy::all)]
535            impl Client {
536                #(#methods)*
537            }
538
539            /// Items consumers will typically use such as the Client.
540            pub mod prelude {
541                #[allow(unused_imports)]
542                pub use super::Client;
543            }
544        };
545        Ok(out)
546    }
547
548    fn generate_tokens_builder_merged(
549        &mut self,
550        input_methods: &[method::OperationMethod],
551        has_inner: bool,
552    ) -> Result<TokenStream> {
553        let builder_struct = input_methods
554            .iter()
555            .map(|method| self.builder_struct(method, TagStyle::Merged, has_inner))
556            .collect::<Result<Vec<_>>>()?;
557
558        let builder_methods = input_methods
559            .iter()
560            .map(|method| self.builder_impl(method))
561            .collect::<Vec<_>>();
562
563        let out = quote! {
564            impl Client {
565                #(#builder_methods)*
566            }
567
568            /// Types for composing operation parameters.
569            #[allow(clippy::all)]
570            pub mod builder {
571                use super::types;
572                #[allow(unused_imports)]
573                use super::{
574                    encode_path,
575                    ByteStream,
576                    ClientInfo,
577                    ClientHooks,
578                    Error,
579                    OperationInfo,
580                    RequestBuilderExt,
581                    ResponseValue,
582                };
583
584                #(#builder_struct)*
585            }
586
587            /// Items consumers will typically use such as the Client.
588            pub mod prelude {
589                pub use self::super::Client;
590            }
591        };
592
593        Ok(out)
594    }
595
596    fn generate_tokens_builder_separate(
597        &mut self,
598        input_methods: &[method::OperationMethod],
599        tag_info: BTreeMap<&String, &openapiv3::Tag>,
600        has_inner: bool,
601    ) -> Result<TokenStream> {
602        let builder_struct = input_methods
603            .iter()
604            .map(|method| self.builder_struct(method, TagStyle::Separate, has_inner))
605            .collect::<Result<Vec<_>>>()?;
606
607        let (traits_and_impls, trait_preludes) = self.builder_tags(input_methods, &tag_info);
608
609        // The allow(unused_imports) on the `pub use` is necessary with Rust
610        // 1.76+, in case the generated file is not at the top level of the
611        // crate.
612
613        let out = quote! {
614            #traits_and_impls
615
616            /// Types for composing operation parameters.
617            #[allow(clippy::all)]
618            pub mod builder {
619                use super::types;
620                #[allow(unused_imports)]
621                use super::{
622                    encode_path,
623                    ByteStream,
624                    ClientInfo,
625                    ClientHooks,
626                    Error,
627                    OperationInfo,
628                    RequestBuilderExt,
629                    ResponseValue,
630                };
631
632                #(#builder_struct)*
633            }
634
635            /// Items consumers will typically use such as the Client and
636            /// extension traits.
637            pub mod prelude {
638                #[allow(unused_imports)]
639                pub use super::Client;
640                #trait_preludes
641            }
642        };
643
644        Ok(out)
645    }
646
647    /// Get the [TypeSpace] for schemas present in the OpenAPI specification.
648    pub fn get_type_space(&self) -> &TypeSpace {
649        &self.type_space
650    }
651
652    /// Whether the generated client needs to use additional crates to support
653    /// futures.
654    pub fn uses_futures(&self) -> bool {
655        self.uses_futures
656    }
657
658    /// Whether the generated client needs to use additional crates to support
659    /// websockets.
660    pub fn uses_websockets(&self) -> bool {
661        self.uses_websockets
662    }
663}
664
665/// Add newlines after end-braces at <= two levels of indentation.
666pub fn space_out_items(content: String) -> Result<String> {
667    Ok(if cfg!(not(windows)) {
668        let regex = regex::Regex::new(r#"(\n\s*})(\n\s{0,8}[^} ])"#).unwrap();
669        regex.replace_all(&content, "$1\n$2").to_string()
670    } else {
671        let regex = regex::Regex::new(r#"(\n\s*})(\r\n\s{0,8}[^} ])"#).unwrap();
672        regex.replace_all(&content, "$1\r\n$2").to_string()
673    })
674}
675
676fn validate_openapi_spec_version(spec_version: &str) -> Result<()> {
677    // progenitor currenlty only support OAS 3.0.x
678    if spec_version.trim().starts_with("3.0.") {
679        Ok(())
680    } else {
681        Err(Error::UnexpectedFormat(format!(
682            "invalid version: {}",
683            spec_version
684        )))
685    }
686}
687
688/// Do some very basic checks of the OpenAPI documents.
689pub fn validate_openapi(spec: &OpenAPI) -> Result<()> {
690    validate_openapi_spec_version(spec.openapi.as_str())?;
691
692    let mut opids = HashSet::new();
693    spec.paths.paths.iter().try_for_each(|p| {
694        match p.1 {
695            openapiv3::ReferenceOr::Reference { reference: _ } => Err(Error::UnexpectedFormat(
696                format!("path {} uses reference, unsupported", p.0,),
697            )),
698            openapiv3::ReferenceOr::Item(item) => {
699                // Make sure every operation has an operation ID, and that each
700                // operation ID is only used once in the document.
701                item.iter().try_for_each(|(_, o)| {
702                    if let Some(oid) = o.operation_id.as_ref() {
703                        if !opids.insert(oid.to_string()) {
704                            return Err(Error::UnexpectedFormat(format!(
705                                "duplicate operation ID: {}",
706                                oid,
707                            )));
708                        }
709                    } else {
710                        return Err(Error::UnexpectedFormat(format!(
711                            "path {} is missing operation ID",
712                            p.0,
713                        )));
714                    }
715                    Ok(())
716                })
717            }
718        }
719    })?;
720
721    Ok(())
722}
723
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::{validate_openapi_spec_version, Error};

    #[test]
    fn test_bad_value() {
        assert_eq!(
            Error::BadValue("nope".to_string(), json! { "nope"},).to_string(),
            "unexpected value type nope: \"nope\"",
        );
    }

    // Previously named `test_type_error`, although it exercises the
    // `UnexpectedFormat` variant.
    #[test]
    fn test_unexpected_format() {
        assert_eq!(
            Error::UnexpectedFormat("nope".to_string()).to_string(),
            "unexpected or unhandled format in the OpenAPI document nope",
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_eq!(
            Error::InvalidPath("nope".to_string()).to_string(),
            "invalid operation path nope",
        );
    }

    #[test]
    fn test_invalid_extension() {
        assert_eq!(
            Error::InvalidExtension("nope".to_string()).to_string(),
            "invalid dropshot extension use: nope",
        );
    }

    #[test]
    fn test_internal_error() {
        assert_eq!(
            Error::InternalError("nope".to_string()).to_string(),
            "internal error nope",
        );
    }

    #[test]
    fn test_validate_openapi_spec_version() {
        assert!(validate_openapi_spec_version("3.0.0").is_ok());
        assert!(validate_openapi_spec_version("3.0.1").is_ok());
        assert!(validate_openapi_spec_version("3.0.4").is_ok());
        assert!(validate_openapi_spec_version("3.0.5-draft").is_ok());
        // Surrounding whitespace is ignored.
        assert!(validate_openapi_spec_version("  3.0.2\n").is_ok());
        assert_eq!(
            validate_openapi_spec_version("3.1.0")
                .unwrap_err()
                .to_string(),
            "unexpected or unhandled format in the OpenAPI document invalid version: 3.1.0"
        );
        // A bare major.minor without a patch component is rejected.
        assert!(validate_openapi_spec_version("3.0").is_err());
        assert!(validate_openapi_spec_version("2.0").is_err());
    }
}