progenitor_impl/
lib.rs

1// Copyright 2025 Oxide Computer Company
2
3//! Core implementation for the progenitor OpenAPI client generator.
4
5#![deny(missing_docs)]
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8
9use openapiv3::OpenAPI;
10use proc_macro2::TokenStream;
11use quote::quote;
12use serde::Deserialize;
13use thiserror::Error;
14use typify::{TypeSpace, TypeSpaceSettings};
15
16use crate::to_schema::ToSchema;
17
18pub use typify::CrateVers;
19pub use typify::TypeSpaceImpl as TypeImpl;
20pub use typify::TypeSpacePatch as TypePatch;
21pub use typify::UnknownPolicy;
22
23mod cli;
24mod httpmock;
25mod method;
26mod template;
27mod to_schema;
28mod util;
29
30#[allow(missing_docs)]
31#[derive(Error, Debug)]
32pub enum Error {
33    #[error("unexpected value type {0}: {1}")]
34    BadValue(String, serde_json::Value),
35    #[error("type error {0}")]
36    TypeError(#[from] typify::Error),
37    #[error("unexpected or unhandled format in the OpenAPI document {0}")]
38    UnexpectedFormat(String),
39    #[error("invalid operation path {0}")]
40    InvalidPath(String),
41    #[error("invalid dropshot extension use: {0}")]
42    InvalidExtension(String),
43    #[error("internal error {0}")]
44    InternalError(String),
45}
46
47#[allow(missing_docs)]
48pub type Result<T> = std::result::Result<T, Error>;
49
50/// OpenAPI generator.
51pub struct Generator {
52    type_space: TypeSpace,
53    settings: GenerationSettings,
54    uses_futures: bool,
55    uses_websockets: bool,
56}
57
58/// Settings for [Generator].
59#[derive(Default, Clone)]
60pub struct GenerationSettings {
61    interface: InterfaceStyle,
62    tag: TagStyle,
63    inner_type: Option<TokenStream>,
64    pre_hook: Option<TokenStream>,
65    pre_hook_async: Option<TokenStream>,
66    post_hook: Option<TokenStream>,
67    post_hook_async: Option<TokenStream>,
68    extra_derives: Vec<String>,
69    extra_cli_bounds: Vec<String>,
70
71    map_type: Option<String>,
72    unknown_crates: UnknownPolicy,
73    crates: BTreeMap<String, CrateSpec>,
74
75    patch: HashMap<String, TypePatch>,
76    replace: HashMap<String, (String, Vec<TypeImpl>)>,
77    convert: Vec<(schemars::schema::SchemaObject, String, Vec<TypeImpl>)>,
78    timeout: Option<u64>,
79}
80
81#[derive(Debug, Clone)]
82struct CrateSpec {
83    version: CrateVers,
84    rename: Option<String>,
85}
86
87/// Style of generated client.
88#[derive(Clone, Deserialize, PartialEq, Eq)]
89pub enum InterfaceStyle {
90    /// Use positional style.
91    Positional,
92    /// Use builder style.
93    Builder,
94}
95
96impl Default for InterfaceStyle {
97    fn default() -> Self {
98        Self::Positional
99    }
100}
101
102/// Style for using the OpenAPI tags when generating names in the client.
103#[derive(Clone, Deserialize)]
104pub enum TagStyle {
105    /// Merge tags to create names in the generated client.
106    Merged,
107    /// Use each tag name to create separate names in the generated client.
108    Separate,
109}
110
111impl Default for TagStyle {
112    fn default() -> Self {
113        Self::Merged
114    }
115}
116
117impl GenerationSettings {
118    /// Create new generator settings with default values.
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Set the [InterfaceStyle].
124    pub fn with_interface(&mut self, interface: InterfaceStyle) -> &mut Self {
125        self.interface = interface;
126        self
127    }
128
129    /// Set the [TagStyle].
130    pub fn with_tag(&mut self, tag: TagStyle) -> &mut Self {
131        self.tag = tag;
132        self
133    }
134
135    /// Client inner type available to pre and post hooks.
136    pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self {
137        self.inner_type = Some(inner_type);
138        self
139    }
140
141    /// Hook invoked before issuing the HTTP request.
142    pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self {
143        self.pre_hook = Some(pre_hook);
144        self
145    }
146
147    /// Hook invoked before issuing the HTTP request.
148    pub fn with_pre_hook_async(&mut self, pre_hook: TokenStream) -> &mut Self {
149        self.pre_hook_async = Some(pre_hook);
150        self
151    }
152
153    /// Hook invoked prior to receiving the HTTP response.
154    pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self {
155        self.post_hook = Some(post_hook);
156        self
157    }
158
159    /// Hook invoked prior to receiving the HTTP response.
160    pub fn with_post_hook_async(&mut self, post_hook: TokenStream) -> &mut Self {
161        self.post_hook_async = Some(post_hook);
162        self
163    }
164
165    /// Additional derive macros applied to generated types.
166    pub fn with_derive(&mut self, derive: impl ToString) -> &mut Self {
167        self.extra_derives.push(derive.to_string());
168        self
169    }
170
171    /// Additional trait bounds applied to `CliConfig` methods.
172    pub fn with_cli_bounds(&mut self, derive: impl ToString) -> &mut Self {
173        self.extra_cli_bounds.push(derive.to_string());
174        self
175    }
176
177    /// Modify a type with the given name.
178    /// See [typify::TypeSpaceSettings::with_patch].
179    pub fn with_patch<S: AsRef<str>>(&mut self, type_name: S, patch: &TypePatch) -> &mut Self {
180        self.patch
181            .insert(type_name.as_ref().to_string(), patch.clone());
182        self
183    }
184
185    /// Replace a referenced type with a named type.
186    /// See [typify::TypeSpaceSettings::with_replacement].
187    pub fn with_replacement<TS: ToString, RS: ToString, I: Iterator<Item = TypeImpl>>(
188        &mut self,
189        type_name: TS,
190        replace_name: RS,
191        impls: I,
192    ) -> &mut Self {
193        self.replace.insert(
194            type_name.to_string(),
195            (replace_name.to_string(), impls.collect()),
196        );
197        self
198    }
199
200    /// Replace a given schema with a named type.
201    /// See [typify::TypeSpaceSettings::with_conversion].
202    pub fn with_conversion<S: ToString, I: Iterator<Item = TypeImpl>>(
203        &mut self,
204        schema: schemars::schema::SchemaObject,
205        type_name: S,
206        impls: I,
207    ) -> &mut Self {
208        self.convert
209            .push((schema, type_name.to_string(), impls.collect()));
210        self
211    }
212
213    /// Policy regarding crates referenced by the schema extension
214    /// `x-rust-type` not explicitly specified via [Self::with_crate].
215    /// See [typify::TypeSpaceSettings::with_unknown_crates].
216    pub fn with_unknown_crates(&mut self, policy: UnknownPolicy) -> &mut Self {
217        self.unknown_crates = policy;
218        self
219    }
220
221    /// Explicitly named crates whose types may be used during generation
222    /// rather than generating new types based on their schemas (base on the
223    /// presence of the x-rust-type extension).
224    /// See [typify::TypeSpaceSettings::with_crate].
225    pub fn with_crate<S1: ToString>(
226        &mut self,
227        crate_name: S1,
228        version: CrateVers,
229        rename: Option<&String>,
230    ) -> &mut Self {
231        self.crates.insert(
232            crate_name.to_string(),
233            CrateSpec {
234                version,
235                rename: rename.cloned(),
236            },
237        );
238        self
239    }
240
241    /// Set the type used for key-value maps. Common examples:
242    /// - [`std::collections::HashMap`] - **Default**
243    /// - [`std::collections::BTreeMap`]
244    /// - [`indexmap::IndexMap`]
245    ///
246    /// The requiremnets for a map type can be found in the
247    /// [typify::TypeSpaceSettings::with_map_type] documentation.
248    pub fn with_map_type<MT: ToString>(&mut self, map_type: MT) -> &mut Self {
249        self.map_type = Some(map_type.to_string());
250        self
251    }
252
253    /// Set the underlying reqwest client's timeout
254    pub fn with_timeout(&mut self, timeout: u64) -> &mut Self {
255        self.timeout = Some(timeout);
256        self
257    }
258}
259
260impl Default for Generator {
261    fn default() -> Self {
262        Self {
263            type_space: TypeSpace::new(TypeSpaceSettings::default().with_type_mod("types")),
264            settings: Default::default(),
265            uses_futures: Default::default(),
266            uses_websockets: Default::default(),
267        }
268    }
269}
270
271impl Generator {
272    /// Create a new generator with default values.
273    pub fn new(settings: &GenerationSettings) -> Self {
274        let mut type_settings = TypeSpaceSettings::default();
275        type_settings
276            .with_type_mod("types")
277            .with_struct_builder(settings.interface == InterfaceStyle::Builder);
278        settings.extra_derives.iter().for_each(|derive| {
279            let _ = type_settings.with_derive(derive.clone());
280        });
281
282        // Control use of crates found in x-rust-type extension
283        type_settings.with_unknown_crates(settings.unknown_crates);
284        settings
285            .crates
286            .iter()
287            .for_each(|(crate_name, CrateSpec { version, rename })| {
288                type_settings.with_crate(crate_name, version.clone(), rename.as_ref());
289            });
290
291        // Adjust generation by type, name, or schema.
292        settings.patch.iter().for_each(|(type_name, patch)| {
293            type_settings.with_patch(type_name, patch);
294        });
295        settings
296            .replace
297            .iter()
298            .for_each(|(type_name, (replace_name, impls))| {
299                type_settings.with_replacement(type_name, replace_name, impls.iter().cloned());
300            });
301        settings
302            .convert
303            .iter()
304            .for_each(|(schema, type_name, impls)| {
305                type_settings.with_conversion(schema.clone(), type_name, impls.iter().cloned());
306            });
307
308        // Set the map type if specified.
309        if let Some(map_type) = &settings.map_type {
310            type_settings.with_map_type(map_type.clone());
311        }
312
313        Self {
314            type_space: TypeSpace::new(&type_settings),
315            settings: settings.clone(),
316            uses_futures: false,
317            uses_websockets: false,
318        }
319    }
320
321    /// Emit a [TokenStream] containing the generated client code.
322    pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result<TokenStream> {
323        validate_openapi(spec)?;
324
325        // Convert our components dictionary to schemars
326        let schemas = spec.components.iter().flat_map(|components| {
327            components
328                .schemas
329                .iter()
330                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
331        });
332
333        self.type_space.add_ref_types(schemas)?;
334
335        let raw_methods = spec
336            .paths
337            .iter()
338            .flat_map(|(path, ref_or_item)| {
339                // Exclude externally defined path items.
340                let item = ref_or_item.as_item().unwrap();
341                item.iter().map(move |(method, operation)| {
342                    (path.as_str(), method, operation, &item.parameters)
343                })
344            })
345            .map(|(path, method, operation, path_parameters)| {
346                self.process_operation(operation, &spec.components, path, method, path_parameters)
347            })
348            .collect::<Result<Vec<_>>>()?;
349
350        let operation_code = match (&self.settings.interface, &self.settings.tag) {
351            (InterfaceStyle::Positional, TagStyle::Merged) => self
352                .generate_tokens_positional_merged(
353                    &raw_methods,
354                    self.settings.inner_type.is_some(),
355                ),
356            (InterfaceStyle::Positional, TagStyle::Separate) => {
357                unimplemented!("positional arguments with separate tags are currently unsupported")
358            }
359            (InterfaceStyle::Builder, TagStyle::Merged) => self
360                .generate_tokens_builder_merged(&raw_methods, self.settings.inner_type.is_some()),
361            (InterfaceStyle::Builder, TagStyle::Separate) => {
362                let tag_info = spec
363                    .tags
364                    .iter()
365                    .map(|tag| (&tag.name, tag))
366                    .collect::<BTreeMap<_, _>>();
367                self.generate_tokens_builder_separate(
368                    &raw_methods,
369                    tag_info,
370                    self.settings.inner_type.is_some(),
371                )
372            }
373        }?;
374
375        let types = self.type_space.to_stream();
376
377        let (inner_type, inner_fn_value) = match self.settings.inner_type.as_ref() {
378            Some(inner_type) => (inner_type.clone(), quote! { &self.inner }),
379            None => (quote! { () }, quote! { &() }),
380        };
381
382        let inner_property = self.settings.inner_type.as_ref().map(|inner| {
383            quote! {
384                pub (crate) inner: #inner,
385            }
386        });
387        let inner_parameter = self.settings.inner_type.as_ref().map(|inner| {
388            quote! {
389                inner: #inner,
390            }
391        });
392        let inner_value = self.settings.inner_type.as_ref().map(|_| {
393            quote! {
394                inner
395            }
396        });
397        let client_timeout = self.settings.timeout.unwrap_or(15);
398
399        let client_docstring = {
400            let mut s = format!("Client for {}", spec.info.title);
401
402            if let Some(ss) = &spec.info.description {
403                s.push_str("\n\n");
404                s.push_str(ss);
405            }
406            if let Some(ss) = &spec.info.terms_of_service {
407                s.push_str("\n\n");
408                s.push_str(ss);
409            }
410
411            s.push_str(&format!("\n\nVersion: {}", &spec.info.version));
412
413            s
414        };
415
416        let version_str = &spec.info.version;
417
418        // The allow(unused_imports) on the `pub use` is necessary with Rust
419        // 1.76+, in case the generated file is not at the top level of the
420        // crate.
421
422        let file = quote! {
423            // Re-export types that are used by the public interface of Client.
424            #[allow(unused_imports)]
425            pub use progenitor_client::{
426                ByteStream,
427                ClientInfo,
428                Error,
429                ResponseValue,
430            };
431            #[allow(unused_imports)]
432            use progenitor_client::{
433                encode_path,
434                ClientHooks,
435                OperationInfo,
436                RequestBuilderExt,
437            };
438
439            /// Types used as operation parameters and responses.
440            #[allow(clippy::all)]
441            pub mod types {
442                #types
443            }
444
445            #[derive(Clone, Debug)]
446            #[doc = #client_docstring]
447            pub struct Client {
448                pub(crate) baseurl: String,
449                pub(crate) client: reqwest::Client,
450                #inner_property
451            }
452
453            impl Client {
454                /// Create a new client.
455                ///
456                /// `baseurl` is the base URL provided to the internal
457                /// `reqwest::Client`, and should include a scheme and hostname,
458                /// as well as port and a path stem if applicable.
459                pub fn new(
460                    baseurl: &str,
461                    #inner_parameter
462                ) -> Self {
463                    #[cfg(not(target_arch = "wasm32"))]
464                    let client = {
465                        let dur = ::std::time::Duration::from_secs(#client_timeout);
466
467                        reqwest::ClientBuilder::new()
468                            .connect_timeout(dur)
469                            .timeout(dur)
470                    };
471
472                    #[cfg(target_arch = "wasm32")]
473                    let client = reqwest::ClientBuilder::new();
474
475                    Self::new_with_client(baseurl, client.build().unwrap(), #inner_value)
476                }
477
478                /// Construct a new client with an existing `reqwest::Client`,
479                /// allowing more control over its configuration.
480                ///
481                /// `baseurl` is the base URL provided to the internal
482                /// `reqwest::Client`, and should include a scheme and hostname,
483                /// as well as port and a path stem if applicable.
484                pub fn new_with_client(
485                    baseurl: &str,
486                    client: reqwest::Client,
487                    #inner_parameter
488                ) -> Self {
489                    Self {
490                        baseurl: baseurl.to_string(),
491                        client,
492                        #inner_value
493                    }
494                }
495            }
496
497            impl ClientInfo<#inner_type> for Client {
498                fn api_version() -> &'static str {
499                    #version_str
500                }
501
502                fn baseurl(&self) -> &str {
503                    self.baseurl.as_str()
504                }
505
506                fn client(&self) -> &reqwest::Client {
507                    &self.client
508                }
509
510                fn inner(&self) -> &#inner_type {
511                    #inner_fn_value
512                }
513            }
514
515            impl ClientHooks<#inner_type> for &Client {}
516
517            #operation_code
518        };
519
520        Ok(file)
521    }
522
523    fn generate_tokens_positional_merged(
524        &mut self,
525        input_methods: &[method::OperationMethod],
526        has_inner: bool,
527    ) -> Result<TokenStream> {
528        let methods = input_methods
529            .iter()
530            .map(|method| self.positional_method(method, has_inner))
531            .collect::<Result<Vec<_>>>()?;
532
533        // The allow(unused_imports) on the `pub use` is necessary with Rust
534        // 1.76+, in case the generated file is not at the top level of the
535        // crate.
536
537        let out = quote! {
538            #[allow(clippy::all)]
539            impl Client {
540                #(#methods)*
541            }
542
543            /// Items consumers will typically use such as the Client.
544            pub mod prelude {
545                #[allow(unused_imports)]
546                pub use super::Client;
547            }
548        };
549        Ok(out)
550    }
551
552    fn generate_tokens_builder_merged(
553        &mut self,
554        input_methods: &[method::OperationMethod],
555        has_inner: bool,
556    ) -> Result<TokenStream> {
557        let builder_struct = input_methods
558            .iter()
559            .map(|method| self.builder_struct(method, TagStyle::Merged, has_inner))
560            .collect::<Result<Vec<_>>>()?;
561
562        let builder_methods = input_methods
563            .iter()
564            .map(|method| self.builder_impl(method))
565            .collect::<Vec<_>>();
566
567        let out = quote! {
568            impl Client {
569                #(#builder_methods)*
570            }
571
572            /// Types for composing operation parameters.
573            #[allow(clippy::all)]
574            pub mod builder {
575                use super::types;
576                #[allow(unused_imports)]
577                use super::{
578                    encode_path,
579                    ByteStream,
580                    ClientInfo,
581                    ClientHooks,
582                    Error,
583                    OperationInfo,
584                    RequestBuilderExt,
585                    ResponseValue,
586                };
587
588                #(#builder_struct)*
589            }
590
591            /// Items consumers will typically use such as the Client.
592            pub mod prelude {
593                pub use self::super::Client;
594            }
595        };
596
597        Ok(out)
598    }
599
600    fn generate_tokens_builder_separate(
601        &mut self,
602        input_methods: &[method::OperationMethod],
603        tag_info: BTreeMap<&String, &openapiv3::Tag>,
604        has_inner: bool,
605    ) -> Result<TokenStream> {
606        let builder_struct = input_methods
607            .iter()
608            .map(|method| self.builder_struct(method, TagStyle::Separate, has_inner))
609            .collect::<Result<Vec<_>>>()?;
610
611        let (traits_and_impls, trait_preludes) = self.builder_tags(input_methods, &tag_info);
612
613        // The allow(unused_imports) on the `pub use` is necessary with Rust
614        // 1.76+, in case the generated file is not at the top level of the
615        // crate.
616
617        let out = quote! {
618            #traits_and_impls
619
620            /// Types for composing operation parameters.
621            #[allow(clippy::all)]
622            pub mod builder {
623                use super::types;
624                #[allow(unused_imports)]
625                use super::{
626                    encode_path,
627                    ByteStream,
628                    ClientInfo,
629                    ClientHooks,
630                    Error,
631                    OperationInfo,
632                    RequestBuilderExt,
633                    ResponseValue,
634                };
635
636                #(#builder_struct)*
637            }
638
639            /// Items consumers will typically use such as the Client and
640            /// extension traits.
641            pub mod prelude {
642                #[allow(unused_imports)]
643                pub use super::Client;
644                #trait_preludes
645            }
646        };
647
648        Ok(out)
649    }
650
651    /// Get the [TypeSpace] for schemas present in the OpenAPI specification.
652    pub fn get_type_space(&self) -> &TypeSpace {
653        &self.type_space
654    }
655
656    /// Whether the generated client needs to use additional crates to support
657    /// futures.
658    pub fn uses_futures(&self) -> bool {
659        self.uses_futures
660    }
661
662    /// Whether the generated client needs to use additional crates to support
663    /// websockets.
664    pub fn uses_websockets(&self) -> bool {
665        self.uses_websockets
666    }
667}
668
669/// Add newlines after end-braces at <= two levels of indentation.
670pub fn space_out_items(content: String) -> Result<String> {
671    Ok(if cfg!(not(windows)) {
672        let regex = regex::Regex::new(r#"(\n\s*})(\n\s{0,8}[^} ])"#).unwrap();
673        regex.replace_all(&content, "$1\n$2").to_string()
674    } else {
675        let regex = regex::Regex::new(r#"(\n\s*})(\r\n\s{0,8}[^} ])"#).unwrap();
676        regex.replace_all(&content, "$1\r\n$2").to_string()
677    })
678}
679
680fn validate_openapi_spec_version(spec_version: &str) -> Result<()> {
681    // progenitor currenlty only support OAS 3.0.x
682    if spec_version.trim().starts_with("3.0.") {
683        Ok(())
684    } else {
685        Err(Error::UnexpectedFormat(format!(
686            "invalid version: {}",
687            spec_version
688        )))
689    }
690}
691
692/// Do some very basic checks of the OpenAPI documents.
693pub fn validate_openapi(spec: &OpenAPI) -> Result<()> {
694    validate_openapi_spec_version(spec.openapi.as_str())?;
695
696    let mut opids = HashSet::new();
697    spec.paths.paths.iter().try_for_each(|p| {
698        match p.1 {
699            openapiv3::ReferenceOr::Reference { reference: _ } => Err(Error::UnexpectedFormat(
700                format!("path {} uses reference, unsupported", p.0,),
701            )),
702            openapiv3::ReferenceOr::Item(item) => {
703                // Make sure every operation has an operation ID, and that each
704                // operation ID is only used once in the document.
705                item.iter().try_for_each(|(_, o)| {
706                    if let Some(oid) = o.operation_id.as_ref() {
707                        if !opids.insert(oid.to_string()) {
708                            return Err(Error::UnexpectedFormat(format!(
709                                "duplicate operation ID: {}",
710                                oid,
711                            )));
712                        }
713                    } else {
714                        return Err(Error::UnexpectedFormat(format!(
715                            "path {} is missing operation ID",
716                            p.0,
717                        )));
718                    }
719                    Ok(())
720                })
721            }
722        }
723    })?;
724
725    Ok(())
726}
727
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::{space_out_items, validate_openapi_spec_version, Error};

    #[test]
    fn test_bad_value() {
        assert_eq!(
            Error::BadValue("nope".to_string(), json! { "nope"},).to_string(),
            "unexpected value type nope: \"nope\"",
        );
    }

    #[test]
    fn test_unexpected_format() {
        assert_eq!(
            Error::UnexpectedFormat("nope".to_string()).to_string(),
            "unexpected or unhandled format in the OpenAPI document nope",
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_eq!(
            Error::InvalidPath("nope".to_string()).to_string(),
            "invalid operation path nope",
        );
    }

    #[test]
    fn test_invalid_extension() {
        assert_eq!(
            Error::InvalidExtension("nope".to_string()).to_string(),
            "invalid dropshot extension use: nope",
        );
    }

    #[test]
    fn test_internal_error() {
        assert_eq!(
            Error::InternalError("nope".to_string()).to_string(),
            "internal error nope",
        );
    }

    #[test]
    fn test_validate_openapi_spec_version() {
        assert!(validate_openapi_spec_version("3.0.0").is_ok());
        assert!(validate_openapi_spec_version("3.0.1").is_ok());
        assert!(validate_openapi_spec_version("3.0.4").is_ok());
        assert!(validate_openapi_spec_version("3.0.5-draft").is_ok());
        // Surrounding whitespace is ignored.
        assert!(validate_openapi_spec_version(" 3.0.3\n").is_ok());
        // Anything other than 3.0.x is rejected.
        assert!(validate_openapi_spec_version("2.0").is_err());
        assert!(validate_openapi_spec_version("3.0").is_err());
        assert_eq!(
            validate_openapi_spec_version("3.1.0")
                .unwrap_err()
                .to_string(),
            "unexpected or unhandled format in the OpenAPI document invalid version: 3.1.0"
        );
    }

    #[test]
    #[cfg(not(windows))]
    fn test_space_out_items() {
        let input = "mod a {\n    fn f() {\n    }\n    fn g() {}\n}\n".to_string();
        assert_eq!(
            space_out_items(input).unwrap(),
            "mod a {\n    fn f() {\n    }\n\n    fn g() {}\n}\n",
        );
    }
}