progenitor_impl/
lib.rs

1// Copyright 2025 Oxide Computer Company
2
3//! Core implementation for the progenitor OpenAPI client generator.
4
5#![deny(missing_docs)]
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8
9use openapiv3::OpenAPI;
10use proc_macro2::TokenStream;
11use quote::quote;
12use serde::Deserialize;
13use thiserror::Error;
14use typify::{TypeSpace, TypeSpaceSettings};
15
16use crate::to_schema::ToSchema;
17
18pub use typify::CrateVers;
19pub use typify::TypeSpaceImpl as TypeImpl;
20pub use typify::TypeSpacePatch as TypePatch;
21pub use typify::UnknownPolicy;
22
23mod cli;
24mod httpmock;
25mod method;
26mod template;
27mod to_schema;
28mod util;
29
30#[allow(missing_docs)]
31#[derive(Error, Debug)]
32pub enum Error {
33    #[error("unexpected value type {0}: {1}")]
34    BadValue(String, serde_json::Value),
35    #[error("type error {0}")]
36    TypeError(#[from] typify::Error),
37    #[error("unexpected or unhandled format in the OpenAPI document {0}")]
38    UnexpectedFormat(String),
39    #[error("invalid operation path {0}")]
40    InvalidPath(String),
41    #[error("invalid dropshot extension use: {0}")]
42    InvalidExtension(String),
43    #[error("internal error {0}")]
44    InternalError(String),
45}
46
47#[allow(missing_docs)]
48pub type Result<T> = std::result::Result<T, Error>;
49
50/// OpenAPI generator.
51pub struct Generator {
52    type_space: TypeSpace,
53    settings: GenerationSettings,
54    uses_futures: bool,
55    uses_websockets: bool,
56}
57
58/// Settings for [Generator].
59#[derive(Default, Clone)]
60pub struct GenerationSettings {
61    interface: InterfaceStyle,
62    tag: TagStyle,
63    inner_type: Option<TokenStream>,
64    pre_hook: Option<TokenStream>,
65    pre_hook_async: Option<TokenStream>,
66    post_hook: Option<TokenStream>,
67    post_hook_async: Option<TokenStream>,
68    extra_derives: Vec<String>,
69
70    map_type: Option<String>,
71    unknown_crates: UnknownPolicy,
72    crates: BTreeMap<String, CrateSpec>,
73
74    patch: HashMap<String, TypePatch>,
75    replace: HashMap<String, (String, Vec<TypeImpl>)>,
76    convert: Vec<(schemars::schema::SchemaObject, String, Vec<TypeImpl>)>,
77}
78
79#[derive(Debug, Clone)]
80struct CrateSpec {
81    version: CrateVers,
82    rename: Option<String>,
83}
84
85/// Style of generated client.
86#[derive(Clone, Deserialize, PartialEq, Eq)]
87pub enum InterfaceStyle {
88    /// Use positional style.
89    Positional,
90    /// Use builder style.
91    Builder,
92}
93
94impl Default for InterfaceStyle {
95    fn default() -> Self {
96        Self::Positional
97    }
98}
99
100/// Style for using the OpenAPI tags when generating names in the client.
101#[derive(Clone, Deserialize)]
102pub enum TagStyle {
103    /// Merge tags to create names in the generated client.
104    Merged,
105    /// Use each tag name to create separate names in the generated client.
106    Separate,
107}
108
109impl Default for TagStyle {
110    fn default() -> Self {
111        Self::Merged
112    }
113}
114
115impl GenerationSettings {
116    /// Create new generator settings with default values.
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    /// Set the [InterfaceStyle].
122    pub fn with_interface(&mut self, interface: InterfaceStyle) -> &mut Self {
123        self.interface = interface;
124        self
125    }
126
127    /// Set the [TagStyle].
128    pub fn with_tag(&mut self, tag: TagStyle) -> &mut Self {
129        self.tag = tag;
130        self
131    }
132
133    /// Client inner type available to pre and post hooks.
134    pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self {
135        self.inner_type = Some(inner_type);
136        self
137    }
138
139    /// Hook invoked before issuing the HTTP request.
140    pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self {
141        self.pre_hook = Some(pre_hook);
142        self
143    }
144
145    /// Hook invoked before issuing the HTTP request.
146    pub fn with_pre_hook_async(&mut self, pre_hook: TokenStream) -> &mut Self {
147        self.pre_hook_async = Some(pre_hook);
148        self
149    }
150
151    /// Hook invoked prior to receiving the HTTP response.
152    pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self {
153        self.post_hook = Some(post_hook);
154        self
155    }
156
157    /// Hook invoked prior to receiving the HTTP response.
158    pub fn with_post_hook_async(&mut self, post_hook: TokenStream) -> &mut Self {
159        self.post_hook_async = Some(post_hook);
160        self
161    }
162
163    /// Additional derive macros applied to generated types.
164    pub fn with_derive(&mut self, derive: impl ToString) -> &mut Self {
165        self.extra_derives.push(derive.to_string());
166        self
167    }
168
169    /// Modify a type with the given name.
170    /// See [typify::TypeSpaceSettings::with_patch].
171    pub fn with_patch<S: AsRef<str>>(&mut self, type_name: S, patch: &TypePatch) -> &mut Self {
172        self.patch
173            .insert(type_name.as_ref().to_string(), patch.clone());
174        self
175    }
176
177    /// Replace a referenced type with a named type.
178    /// See [typify::TypeSpaceSettings::with_replacement].
179    pub fn with_replacement<TS: ToString, RS: ToString, I: Iterator<Item = TypeImpl>>(
180        &mut self,
181        type_name: TS,
182        replace_name: RS,
183        impls: I,
184    ) -> &mut Self {
185        self.replace.insert(
186            type_name.to_string(),
187            (replace_name.to_string(), impls.collect()),
188        );
189        self
190    }
191
192    /// Replace a given schema with a named type.
193    /// See [typify::TypeSpaceSettings::with_conversion].
194    pub fn with_conversion<S: ToString, I: Iterator<Item = TypeImpl>>(
195        &mut self,
196        schema: schemars::schema::SchemaObject,
197        type_name: S,
198        impls: I,
199    ) -> &mut Self {
200        self.convert
201            .push((schema, type_name.to_string(), impls.collect()));
202        self
203    }
204
205    /// Policy regarding crates referenced by the schema extension
206    /// `x-rust-type` not explicitly specified via [Self::with_crate].
207    /// See [typify::TypeSpaceSettings::with_unknown_crates].
208    pub fn with_unknown_crates(&mut self, policy: UnknownPolicy) -> &mut Self {
209        self.unknown_crates = policy;
210        self
211    }
212
213    /// Explicitly named crates whose types may be used during generation
214    /// rather than generating new types based on their schemas (base on the
215    /// presence of the x-rust-type extension).
216    /// See [typify::TypeSpaceSettings::with_crate].
217    pub fn with_crate<S1: ToString>(
218        &mut self,
219        crate_name: S1,
220        version: CrateVers,
221        rename: Option<&String>,
222    ) -> &mut Self {
223        self.crates.insert(
224            crate_name.to_string(),
225            CrateSpec {
226                version,
227                rename: rename.cloned(),
228            },
229        );
230        self
231    }
232
233    /// Set the type used for key-value maps. Common examples:
234    /// - [`std::collections::HashMap`] - **Default**
235    /// - [`std::collections::BTreeMap`]
236    /// - [`indexmap::IndexMap`]
237    ///
238    /// The requiremnets for a map type can be found in the
239    /// [typify::TypeSpaceSettings::with_map_type] documentation.
240    pub fn with_map_type<MT: ToString>(&mut self, map_type: MT) -> &mut Self {
241        self.map_type = Some(map_type.to_string());
242        self
243    }
244}
245
246impl Default for Generator {
247    fn default() -> Self {
248        Self {
249            type_space: TypeSpace::new(TypeSpaceSettings::default().with_type_mod("types")),
250            settings: Default::default(),
251            uses_futures: Default::default(),
252            uses_websockets: Default::default(),
253        }
254    }
255}
256
257impl Generator {
258    /// Create a new generator with default values.
259    pub fn new(settings: &GenerationSettings) -> Self {
260        let mut type_settings = TypeSpaceSettings::default();
261        type_settings
262            .with_type_mod("types")
263            .with_struct_builder(settings.interface == InterfaceStyle::Builder);
264        settings.extra_derives.iter().for_each(|derive| {
265            let _ = type_settings.with_derive(derive.clone());
266        });
267
268        // Control use of crates found in x-rust-type extension
269        type_settings.with_unknown_crates(settings.unknown_crates);
270        settings
271            .crates
272            .iter()
273            .for_each(|(crate_name, CrateSpec { version, rename })| {
274                type_settings.with_crate(crate_name, version.clone(), rename.as_ref());
275            });
276
277        // Adjust generation by type, name, or schema.
278        settings.patch.iter().for_each(|(type_name, patch)| {
279            type_settings.with_patch(type_name, patch);
280        });
281        settings
282            .replace
283            .iter()
284            .for_each(|(type_name, (replace_name, impls))| {
285                type_settings.with_replacement(type_name, replace_name, impls.iter().cloned());
286            });
287        settings
288            .convert
289            .iter()
290            .for_each(|(schema, type_name, impls)| {
291                type_settings.with_conversion(schema.clone(), type_name, impls.iter().cloned());
292            });
293
294        // Set the map type if specified.
295        if let Some(map_type) = &settings.map_type {
296            type_settings.with_map_type(map_type.clone());
297        }
298
299        Self {
300            type_space: TypeSpace::new(&type_settings),
301            settings: settings.clone(),
302            uses_futures: false,
303            uses_websockets: false,
304        }
305    }
306
307    /// Emit a [TokenStream] containing the generated client code.
308    pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result<TokenStream> {
309        validate_openapi(spec)?;
310
311        // Convert our components dictionary to schemars
312        let schemas = spec.components.iter().flat_map(|components| {
313            components
314                .schemas
315                .iter()
316                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
317        });
318
319        self.type_space.add_ref_types(schemas)?;
320
321        let raw_methods = spec
322            .paths
323            .iter()
324            .flat_map(|(path, ref_or_item)| {
325                // Exclude externally defined path items.
326                let item = ref_or_item.as_item().unwrap();
327                item.iter().map(move |(method, operation)| {
328                    (path.as_str(), method, operation, &item.parameters)
329                })
330            })
331            .map(|(path, method, operation, path_parameters)| {
332                self.process_operation(operation, &spec.components, path, method, path_parameters)
333            })
334            .collect::<Result<Vec<_>>>()?;
335
336        let operation_code = match (&self.settings.interface, &self.settings.tag) {
337            (InterfaceStyle::Positional, TagStyle::Merged) => self
338                .generate_tokens_positional_merged(
339                    &raw_methods,
340                    self.settings.inner_type.is_some(),
341                ),
342            (InterfaceStyle::Positional, TagStyle::Separate) => {
343                unimplemented!("positional arguments with separate tags are currently unsupported")
344            }
345            (InterfaceStyle::Builder, TagStyle::Merged) => self
346                .generate_tokens_builder_merged(&raw_methods, self.settings.inner_type.is_some()),
347            (InterfaceStyle::Builder, TagStyle::Separate) => {
348                let tag_info = spec
349                    .tags
350                    .iter()
351                    .map(|tag| (&tag.name, tag))
352                    .collect::<BTreeMap<_, _>>();
353                self.generate_tokens_builder_separate(
354                    &raw_methods,
355                    tag_info,
356                    self.settings.inner_type.is_some(),
357                )
358            }
359        }?;
360
361        let types = self.type_space.to_stream();
362
363        // Generate an implementation of a `Self::as_inner` method, if an inner
364        // type is defined.
365        let maybe_inner = self.settings.inner_type.as_ref().map(|inner| {
366            quote! {
367                /// Return a reference to the inner type stored in `self`.
368                pub fn inner(&self) -> &#inner {
369                    &self.inner
370                }
371            }
372        });
373
374        let inner_property = self.settings.inner_type.as_ref().map(|inner| {
375            quote! {
376                pub (crate) inner: #inner,
377            }
378        });
379        let inner_parameter = self.settings.inner_type.as_ref().map(|inner| {
380            quote! {
381                inner: #inner,
382            }
383        });
384        let inner_value = self.settings.inner_type.as_ref().map(|_| {
385            quote! {
386                inner
387            }
388        });
389
390        let client_docstring = {
391            let mut s = format!("Client for {}", spec.info.title);
392
393            if let Some(ss) = &spec.info.description {
394                s.push_str("\n\n");
395                s.push_str(ss);
396            }
397            if let Some(ss) = &spec.info.terms_of_service {
398                s.push_str("\n\n");
399                s.push_str(ss);
400            }
401
402            s.push_str(&format!("\n\nVersion: {}", &spec.info.version));
403
404            s
405        };
406
407        let version_str = &spec.info.version;
408
409        // The allow(unused_imports) on the `pub use` is necessary with Rust
410        // 1.76+, in case the generated file is not at the top level of the
411        // crate.
412
413        let file = quote! {
414            // Re-export ResponseValue and Error since those are used by the
415            // public interface of Client.
416            #[allow(unused_imports)]
417            pub use progenitor_client::{ByteStream, Error, ResponseValue};
418            #[allow(unused_imports)]
419            use progenitor_client::{encode_path, RequestBuilderExt};
420
421            /// Types used as operation parameters and responses.
422            #[allow(clippy::all)]
423            pub mod types {
424                #types
425            }
426
427            #[derive(Clone, Debug)]
428            #[doc = #client_docstring]
429            pub struct Client {
430                pub(crate) baseurl: String,
431                pub(crate) client: reqwest::Client,
432                #inner_property
433            }
434
435            impl Client {
436                /// Create a new client.
437                ///
438                /// `baseurl` is the base URL provided to the internal
439                /// `reqwest::Client`, and should include a scheme and hostname,
440                /// as well as port and a path stem if applicable.
441                pub fn new(
442                    baseurl: &str,
443                    #inner_parameter
444                ) -> Self {
445                    #[cfg(not(target_arch = "wasm32"))]
446                    let client = {
447                        let dur = std::time::Duration::from_secs(15);
448
449                        reqwest::ClientBuilder::new()
450                            .connect_timeout(dur)
451                            .timeout(dur)
452                    };
453                    #[cfg(target_arch = "wasm32")]
454                    let client = reqwest::ClientBuilder::new();
455
456                    Self::new_with_client(baseurl, client.build().unwrap(), #inner_value)
457                }
458
459                /// Construct a new client with an existing `reqwest::Client`,
460                /// allowing more control over its configuration.
461                ///
462                /// `baseurl` is the base URL provided to the internal
463                /// `reqwest::Client`, and should include a scheme and hostname,
464                /// as well as port and a path stem if applicable.
465                pub fn new_with_client(
466                    baseurl: &str,
467                    client: reqwest::Client,
468                    #inner_parameter
469                ) -> Self {
470                    Self {
471                        baseurl: baseurl.to_string(),
472                        client,
473                        #inner_value
474                    }
475                }
476
477                /// Get the base URL to which requests are made.
478                pub fn baseurl(&self) -> &String {
479                    &self.baseurl
480                }
481
482                /// Get the internal `reqwest::Client` used to make requests.
483                pub fn client(&self) -> &reqwest::Client {
484                    &self.client
485                }
486
487                /// Get the version of this API.
488                ///
489                /// This string is pulled directly from the source OpenAPI
490                /// document and may be in any format the API selects.
491                pub fn api_version(&self) -> &'static str {
492                    #version_str
493                }
494
495                #maybe_inner
496            }
497
498            #operation_code
499        };
500
501        Ok(file)
502    }
503
504    fn generate_tokens_positional_merged(
505        &mut self,
506        input_methods: &[method::OperationMethod],
507        has_inner: bool,
508    ) -> Result<TokenStream> {
509        let methods = input_methods
510            .iter()
511            .map(|method| self.positional_method(method, has_inner))
512            .collect::<Result<Vec<_>>>()?;
513
514        // The allow(unused_imports) on the `pub use` is necessary with Rust
515        // 1.76+, in case the generated file is not at the top level of the
516        // crate.
517
518        let out = quote! {
519            #[allow(clippy::all)]
520            #[allow(elided_named_lifetimes)]
521            impl Client {
522                #(#methods)*
523            }
524
525            /// Items consumers will typically use such as the Client.
526            pub mod prelude {
527                #[allow(unused_imports)]
528                pub use super::Client;
529            }
530        };
531        Ok(out)
532    }
533
534    fn generate_tokens_builder_merged(
535        &mut self,
536        input_methods: &[method::OperationMethod],
537        has_inner: bool,
538    ) -> Result<TokenStream> {
539        let builder_struct = input_methods
540            .iter()
541            .map(|method| self.builder_struct(method, TagStyle::Merged, has_inner))
542            .collect::<Result<Vec<_>>>()?;
543
544        let builder_methods = input_methods
545            .iter()
546            .map(|method| self.builder_impl(method))
547            .collect::<Vec<_>>();
548
549        let out = quote! {
550            impl Client {
551                #(#builder_methods)*
552            }
553
554            /// Types for composing operation parameters.
555            #[allow(clippy::all)]
556            pub mod builder {
557                use super::types;
558                #[allow(unused_imports)]
559                use super::{
560                    encode_path,
561                    ByteStream,
562                    Error,
563                    RequestBuilderExt,
564                    ResponseValue,
565                };
566
567                #(#builder_struct)*
568            }
569
570            /// Items consumers will typically use such as the Client.
571            pub mod prelude {
572                pub use self::super::Client;
573            }
574        };
575
576        Ok(out)
577    }
578
579    fn generate_tokens_builder_separate(
580        &mut self,
581        input_methods: &[method::OperationMethod],
582        tag_info: BTreeMap<&String, &openapiv3::Tag>,
583        has_inner: bool,
584    ) -> Result<TokenStream> {
585        let builder_struct = input_methods
586            .iter()
587            .map(|method| self.builder_struct(method, TagStyle::Separate, has_inner))
588            .collect::<Result<Vec<_>>>()?;
589
590        let (traits_and_impls, trait_preludes) = self.builder_tags(input_methods, &tag_info);
591
592        // The allow(unused_imports) on the `pub use` is necessary with Rust
593        // 1.76+, in case the generated file is not at the top level of the
594        // crate.
595
596        let out = quote! {
597            #traits_and_impls
598
599            /// Types for composing operation parameters.
600            #[allow(clippy::all)]
601            pub mod builder {
602                use super::types;
603                #[allow(unused_imports)]
604                use super::{
605                    encode_path,
606                    ByteStream,
607                    Error,
608                    RequestBuilderExt,
609                    ResponseValue,
610                };
611
612                #(#builder_struct)*
613            }
614
615            /// Items consumers will typically use such as the Client and
616            /// extension traits.
617            pub mod prelude {
618                #[allow(unused_imports)]
619                pub use super::Client;
620                #trait_preludes
621            }
622        };
623
624        Ok(out)
625    }
626
627    /// Get the [TypeSpace] for schemas present in the OpenAPI specification.
628    pub fn get_type_space(&self) -> &TypeSpace {
629        &self.type_space
630    }
631
632    /// Whether the generated client needs to use additional crates to support
633    /// futures.
634    pub fn uses_futures(&self) -> bool {
635        self.uses_futures
636    }
637
638    /// Whether the generated client needs to use additional crates to support
639    /// websockets.
640    pub fn uses_websockets(&self) -> bool {
641        self.uses_websockets
642    }
643}
644
645/// Add newlines after end-braces at <= two levels of indentation.
646pub fn space_out_items(content: String) -> Result<String> {
647    Ok(if cfg!(not(windows)) {
648        let regex = regex::Regex::new(r#"(\n\s*})(\n\s{0,8}[^} ])"#).unwrap();
649        regex.replace_all(&content, "$1\n$2").to_string()
650    } else {
651        let regex = regex::Regex::new(r#"(\n\s*})(\r\n\s{0,8}[^} ])"#).unwrap();
652        regex.replace_all(&content, "$1\r\n$2").to_string()
653    })
654}
655
656fn validate_openapi_spec_version(spec_version: &str) -> Result<()> {
657    // progenitor currenlty only support OAS 3.0.x
658    if spec_version.trim().starts_with("3.0.") {
659        Ok(())
660    } else {
661        Err(Error::UnexpectedFormat(format!(
662            "invalid version: {}",
663            spec_version
664        )))
665    }
666}
667
668/// Do some very basic checks of the OpenAPI documents.
669pub fn validate_openapi(spec: &OpenAPI) -> Result<()> {
670    validate_openapi_spec_version(spec.openapi.as_str())?;
671
672    let mut opids = HashSet::new();
673    spec.paths.paths.iter().try_for_each(|p| {
674        match p.1 {
675            openapiv3::ReferenceOr::Reference { reference: _ } => Err(Error::UnexpectedFormat(
676                format!("path {} uses reference, unsupported", p.0,),
677            )),
678            openapiv3::ReferenceOr::Item(item) => {
679                // Make sure every operation has an operation ID, and that each
680                // operation ID is only used once in the document.
681                item.iter().try_for_each(|(_, o)| {
682                    if let Some(oid) = o.operation_id.as_ref() {
683                        if !opids.insert(oid.to_string()) {
684                            return Err(Error::UnexpectedFormat(format!(
685                                "duplicate operation ID: {}",
686                                oid,
687                            )));
688                        }
689                    } else {
690                        return Err(Error::UnexpectedFormat(format!(
691                            "path {} is missing operation ID",
692                            p.0,
693                        )));
694                    }
695                    Ok(())
696                })
697            }
698        }
699    })?;
700
701    Ok(())
702}
703
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::{validate_openapi_spec_version, Error};

    #[test]
    fn test_bad_value() {
        assert_eq!(
            Error::BadValue("nope".to_string(), json! { "nope"},).to_string(),
            "unexpected value type nope: \"nope\"",
        );
    }

    // Exercises `Error::UnexpectedFormat` (previously misnamed
    // `test_type_error`).
    #[test]
    fn test_unexpected_format() {
        assert_eq!(
            Error::UnexpectedFormat("nope".to_string()).to_string(),
            "unexpected or unhandled format in the OpenAPI document nope",
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_eq!(
            Error::InvalidPath("nope".to_string()).to_string(),
            "invalid operation path nope",
        );
    }

    #[test]
    fn test_invalid_extension() {
        assert_eq!(
            Error::InvalidExtension("nope".to_string()).to_string(),
            "invalid dropshot extension use: nope",
        );
    }

    #[test]
    fn test_internal_error() {
        assert_eq!(
            Error::InternalError("nope".to_string()).to_string(),
            "internal error nope",
        );
    }

    #[test]
    fn test_validate_openapi_spec_version() {
        assert!(validate_openapi_spec_version("3.0.0").is_ok());
        assert!(validate_openapi_spec_version("3.0.1").is_ok());
        assert!(validate_openapi_spec_version("3.0.4").is_ok());
        assert!(validate_openapi_spec_version("3.0.5-draft").is_ok());
        // Surrounding whitespace is ignored.
        assert!(validate_openapi_spec_version(" 3.0.3\n").is_ok());
        // Versions outside 3.0.x are rejected.
        assert!(validate_openapi_spec_version("2.0").is_err());
        assert_eq!(
            validate_openapi_spec_version("3.1.0")
                .unwrap_err()
                .to_string(),
            "unexpected or unhandled format in the OpenAPI document invalid version: 3.1.0"
        );
    }
}