progenitor_macro/
lib.rs

1// Copyright 2022 Oxide Computer Company
2
3//! Macros for the progenitor OpenAPI client generator.
4
5#![deny(missing_docs)]
6
7use std::{
8    collections::HashMap,
9    fs::File,
10    path::{Path, PathBuf},
11};
12
13use openapiv3::OpenAPI;
14use proc_macro::TokenStream;
15use progenitor_impl::{
16    CrateVers, GenerationSettings, Generator, InterfaceStyle, TagStyle, TypePatch, UnknownPolicy,
17};
18use quote::{quote, ToTokens};
19use schemars::schema::SchemaObject;
20use serde::Deserialize;
21use serde_tokenstream::{OrderedMap, ParseWrapper};
22use syn::LitStr;
23use token_utils::TypeAndImpls;
24
25mod token_utils;
26
27/// Generates a client from the given OpenAPI document
28///
/// `generate_api!` can be invoked in two ways. The simple form takes a path
30/// to the OpenAPI document:
31/// ```ignore
32/// generate_api!("path/to/spec.json");
33/// ```
34///
35/// The more complex form accepts the following key-value pairs in any order:
36/// ```ignore
37/// generate_api!(
38///     spec = "path/to/spec.json",
39///     [ interface = ( Positional | Builder ), ]
40///     [ tags = ( Merged | Separate ), ]
41///     [ pre_hook = closure::or::path::to::function, ]
42///     [ post_hook = closure::or::path::to::function, ]
43///     [ pre_hook_async = closure::or::path::to::function, ]
44///     [ post_hook_async = closure::or::path::to::function, ]
45///
46///     [ derives = [ path::to::DeriveMacro ], ]
47///
48///     [ unknown_crates = (Generate | Allow | Deny ), ]
49///     [ crates = { "<crate-name>" = ("<version>" | "*" | "!" ) } ]
50///
51///     [ patch = { TypeName = { [rename = NewTypeName], [derives = []] }, } ]
52///     [ replace = { TypeName = full_path::to::other::TypeName, }]
53///     [ convert = { { <schema> } = full_path::to::TypeName, }]
54///     [ timeout = u64 ]
55/// );
56/// ```
57///
58/// The `spec` key is required; it is the OpenAPI document (JSON or YAML) from
59/// which the client is derived.
60///
61/// The optional `interface` lets you specify either a `Positional` argument or
62/// `Builder` argument style; `Positional` is the default.
63///
64/// The optional `tags` may be `Merged` in which case all operations are
65/// methods on the `Client` struct or `Separate` in which case each tag is
66/// represented by an "extension trait" that `Client` implements. The default
67/// is `Merged`.
68///
69/// The optional `inner_type` is for ancillary data, stored with the generated
70/// client that can be used by the pre- and post-hooks.
71///
72/// The optional `pre_hook` is either a closure (that must be within
73/// parentheses: `(fn |[inner,] request| { .. })`) or a path to a function. The
74/// closure or function must take one or two parameters: the inner type (if one
75/// is specified) and a `&reqwest::Request`. This allows clients to examine
76/// requests before they're sent to the server, for example to log them. The
77/// optional `pre_hook_async` is the `async` variant of the same.
78///
79/// The optional `post_hook` is either a closure (that must be within
80/// parentheses: `(fn |[inner,] result| { .. })`) or a path to a function. The
81/// closure or function must take one or two parameters: the inner type (if one
82/// is specified) and a `&Result<reqwest::Response, reqwest::Error>`. This
83/// allows clients to examine responses, for example to log them. The optional
84/// `post_hook_async` is the `async` variant of the same.
85///
86/// Additional options control type generation:
87/// - `derives`: optional array of derive macro paths; the derive macros to be
88///   applied to all generated types
89///
90/// - `struct_builder`: optional boolean; (if true) generates a `::builder()`
91///   method for each generated struct that can be used to specify each
92///   property and construct the struct
93///
94/// - `unknown_crates`: optional policy regarding the handling of schemas that
95///   contain the `x-rust-type` extension whose crates are not explicitly named
96///   in the `crates` section. The options are `generate` to ignore the
97///   extension and generate a *de novo* type, `allow` to use the named type
98///   (which may require the addition of a new dependency to compile, and which
99///   ignores version compatibility checks), or `deny` to produce a
100///   compile-time error (requiring the user to specify the crate's disposition
101///   in the `crates` section).
102///
103/// - `crates`: optional map from crate name to the version of the crate in
104///   use. Types encountered with the Rust type extension (`x-rust-type`) will
105///   use types from the specified crates rather than generating them (within
106///   the constraints of type compatibility).
107///
108/// - `patch`: optional map from type to an object with the optional members
///   `rename` and `derives`. This may be used to rename generated types or
110///   to apply additional (non-default) derive macros to them.
111///
112/// - `replace`: optional map from definition name to a replacement type. This
///   may be used to skip generation of the named type and use an existing Rust
114///   type.
115///
116/// - `convert`: optional map from a JSON schema type defined in `$defs` to a
117///   replacement type. This may be used to skip generation of the schema and
118///   use an existing Rust type.
119///
120/// - `timeout`: the default connection timeout for the underlying reqwest
121///   client (15s if not specified)
122#[proc_macro]
123pub fn generate_api(item: TokenStream) -> TokenStream {
124    match do_generate_api(item) {
125        Err(err) => err.to_compile_error().into(),
126        Ok(out) => out,
127    }
128}
129
/// Settings accepted by the key-value form of `generate_api!`.
///
/// Deserialized from the macro's token stream in `do_generate_api`, which
/// forwards each field to the matching `GenerationSettings::with_*` method.
#[derive(Deserialize)]
struct MacroSettings {
    /// String literal naming the OpenAPI document. This is the only field
    /// without a default or `Option` wrapper.
    spec: ParseWrapper<LitStr>,
    /// Interface style; `InterfaceStyle::default()` when omitted.
    #[serde(default)]
    interface: InterfaceStyle,
    /// Tag style; `TagStyle::default()` when omitted.
    #[serde(default)]
    tags: TagStyle,

    /// Optional type passed to `with_inner_type`.
    inner_type: Option<ParseWrapper<syn::Type>>,
    /// Optional closure or path passed to `with_pre_hook`.
    pre_hook: Option<ParseWrapper<ClosureOrPath>>,
    /// Optional closure or path passed to `with_pre_hook_async`.
    pre_hook_async: Option<ParseWrapper<ClosureOrPath>>,
    /// Optional closure or path passed to `with_post_hook`.
    post_hook: Option<ParseWrapper<ClosureOrPath>>,
    /// Optional closure or path passed to `with_post_hook_async`.
    post_hook_async: Option<ParseWrapper<ClosureOrPath>>,

    /// Optional type passed to `with_map_type`.
    map_type: Option<ParseWrapper<syn::Type>>,

    /// Derive macro paths, each passed to `with_derive`; empty when omitted.
    #[serde(default)]
    derives: Vec<ParseWrapper<syn::Path>>,

    /// Policy passed to `with_unknown_crates`; `UnknownPolicy::default()`
    /// when omitted.
    #[serde(default)]
    unknown_crates: UnknownPolicy,
    /// Per-crate specifications keyed by (validated) crate name, each passed
    /// to `with_crate`.
    #[serde(default)]
    crates: HashMap<CrateName, MacroCrateSpec>,

    /// Per-type patches keyed by type identifier, each passed to
    /// `with_patch`.
    #[serde(default)]
    patch: HashMap<ParseWrapper<syn::Ident>, MacroPatch>,
    /// Type replacements keyed by type identifier, each passed to
    /// `with_replacement`.
    #[serde(default)]
    replace: HashMap<ParseWrapper<syn::Ident>, ParseWrapper<TypeAndImpls>>,
    /// Schema-to-type conversions, each passed to `with_conversion`.
    #[serde(default)]
    convert: OrderedMap<SchemaObject, ParseWrapper<TypeAndImpls>>,
    /// Optional timeout value, passed to `with_timeout` when present.
    timeout: Option<u64>,
}
162
/// Macro-side description of a patch for one generated type; converted into
/// a `TypePatch` via `From`.
#[derive(Deserialize)]
struct MacroPatch {
    /// New name for the type, if any.
    #[serde(default)]
    rename: Option<String>,
    /// Derive macro paths to add to the type; empty when omitted.
    #[serde(default)]
    derives: Vec<ParseWrapper<syn::Path>>,
}
170
171impl From<MacroPatch> for TypePatch {
172    fn from(a: MacroPatch) -> Self {
173        let mut s = Self::default();
174        a.rename.iter().for_each(|rename| {
175            s.with_rename(rename);
176        });
177        a.derives.iter().for_each(|derive| {
178            s.with_derive(derive.to_token_stream().to_string());
179        });
180        s
181    }
182}
183
/// Token stream of either a closure expression or a path, as produced by the
/// `Parse` implementation below.
#[derive(Debug)]
struct ClosureOrPath(proc_macro2::TokenStream);
186
187impl syn::parse::Parse for ClosureOrPath {
188    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
189        let lookahead = input.lookahead1();
190
191        if lookahead.peek(syn::token::Paren) {
192            let group: proc_macro2::Group = input.parse()?;
193            return syn::parse2::<Self>(group.stream());
194        }
195
196        if let Ok(closure) = input.parse::<syn::ExprClosure>() {
197            return Ok(Self(closure.to_token_stream()));
198        }
199
200        input
201            .parse::<syn::Path>()
202            .map(|path| Self(path.to_token_stream()))
203    }
204}
205
/// Value side of an entry in the `crates` map, deserialized from a string of
/// the form `[<original-crate>@]<version>`.
struct MacroCrateSpec {
    /// Crate name that preceded the `@`, when one was given.
    original: Option<String>,
    /// Version parsed from the remainder of the string.
    version: CrateVers,
}
210
211impl<'de> Deserialize<'de> for MacroCrateSpec {
212    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
213    where
214        D: serde::Deserializer<'de>,
215    {
216        let ss = String::deserialize(deserializer)?;
217
218        let (original, vers_str) = if let Some(ii) = ss.find('@') {
219            let original_str = &ss[..ii];
220            let rest = &ss[ii + 1..];
221            if !is_crate(original_str) {
222                return Err(<D::Error as serde::de::Error>::invalid_value(
223                    serde::de::Unexpected::Str(&ss),
224                    &"valid crate name",
225                ));
226            }
227
228            (Some(original_str.to_string()), rest)
229        } else {
230            (None, ss.as_ref())
231        };
232
233        let Some(version) = CrateVers::parse(vers_str) else {
234            return Err(<D::Error as serde::de::Error>::invalid_value(
235                serde::de::Unexpected::Str(&ss),
236                &"valid version",
237            ));
238        };
239
240        Ok(Self { original, version })
241    }
242}
243
/// Key of the `crates` map: a string that passed the `is_crate` check during
/// deserialization.
#[derive(Hash, PartialEq, Eq)]
struct CrateName(String);
246impl<'de> Deserialize<'de> for CrateName {
247    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
248    where
249        D: serde::Deserializer<'de>,
250    {
251        let ss = String::deserialize(deserializer)?;
252
253        if is_crate(&ss) {
254            Ok(Self(ss))
255        } else {
256            Err(<D::Error as serde::de::Error>::invalid_value(
257                serde::de::Unexpected::Str(&ss),
258                &"valid crate name",
259            ))
260        }
261    }
262}
263
/// Returns `true` if `s` is acceptable as a crate name: non-empty and made up
/// solely of alphanumeric characters, `_`, and `-`.
///
/// The empty string is rejected so that inputs such as `"@1.0"` (an empty
/// name before the `@`) are reported as invalid rather than silently
/// accepted.
fn is_crate(s: &str) -> bool {
    !s.is_empty()
        && s.chars()
            .all(|cc| cc.is_alphanumeric() || cc == '_' || cc == '-')
}
267
268fn open_file(path: PathBuf, span: proc_macro2::Span) -> Result<File, syn::Error> {
269    File::open(path.clone()).map_err(|e| {
270        let path_str = path.to_string_lossy();
271        syn::Error::new(span, format!("couldn't read file {}: {}", path_str, e))
272    })
273}
274
275fn do_generate_api(item: TokenStream) -> Result<TokenStream, syn::Error> {
276    let (spec, settings) = if let Ok(spec) = syn::parse::<LitStr>(item.clone()) {
277        (spec, GenerationSettings::default())
278    } else {
279        let MacroSettings {
280            spec,
281            interface,
282            tags,
283            inner_type,
284            pre_hook,
285            pre_hook_async,
286            post_hook,
287            post_hook_async,
288            map_type,
289            unknown_crates,
290            crates,
291            derives,
292            patch,
293            replace,
294            convert,
295            timeout,
296        } = serde_tokenstream::from_tokenstream(&item.into())?;
297
298        let mut settings = GenerationSettings::default();
299        settings.with_interface(interface);
300        settings.with_tag(tags);
301        inner_type.map(|inner_type| settings.with_inner_type(inner_type.to_token_stream()));
302        pre_hook.map(|pre_hook| settings.with_pre_hook(pre_hook.into_inner().0));
303        pre_hook_async
304            .map(|pre_hook_async| settings.with_pre_hook_async(pre_hook_async.into_inner().0));
305        post_hook.map(|post_hook| settings.with_post_hook(post_hook.into_inner().0));
306        post_hook_async
307            .map(|post_hook_async| settings.with_post_hook_async(post_hook_async.into_inner().0));
308        map_type.map(|map_type| settings.with_map_type(map_type.to_token_stream()));
309
310        settings.with_unknown_crates(unknown_crates);
311        crates.into_iter().for_each(
312            |(CrateName(crate_name), MacroCrateSpec { original, version })| {
313                if let Some(original_crate) = original {
314                    settings.with_crate(original_crate, version, Some(&crate_name));
315                } else {
316                    settings.with_crate(crate_name, version, None);
317                }
318            },
319        );
320
321        derives.into_iter().for_each(|derive| {
322            settings.with_derive(derive.to_token_stream());
323        });
324        patch.into_iter().for_each(|(type_name, patch)| {
325            settings.with_patch(type_name.to_token_stream().to_string(), &patch.into());
326        });
327        replace.into_iter().for_each(|(type_name, type_and_impls)| {
328            let type_name = type_name.to_token_stream();
329            let (replace_name, impls) = type_and_impls.into_inner().into_name_and_impls();
330            settings.with_replacement(type_name, replace_name, impls);
331        });
332        convert.into_iter().for_each(|(schema, type_and_impls)| {
333            let (type_name, impls) = type_and_impls.into_inner().into_name_and_impls();
334            settings.with_conversion(schema, type_name, impls);
335        });
336        if let Some(timeout) = timeout {
337            settings.with_timeout(timeout);
338        }
339        (spec.into_inner(), settings)
340    };
341
342    let dir = std::env::var("CARGO_MANIFEST_DIR").map_or_else(
343        |_| std::env::current_dir().unwrap(),
344        |s| Path::new(&s).to_path_buf(),
345    );
346
347    let path = dir.join(spec.value());
348    let path_str = path.to_string_lossy();
349
350    let mut f = open_file(path.clone(), spec.span())?;
351    let oapi: OpenAPI = match serde_json::from_reader(f) {
352        Ok(json_value) => json_value,
353        _ => {
354            f = open_file(path.clone(), spec.span())?;
355            serde_yaml::from_reader(f).map_err(|e| {
356                syn::Error::new(spec.span(), format!("failed to parse {}: {}", path_str, e))
357            })?
358        }
359    };
360
361    let mut builder = Generator::new(&settings);
362
363    let code = builder.generate_tokens(&oapi).map_err(|e| {
364        syn::Error::new(
365            spec.span(),
366            format!("generation error for {}: {}", spec.value(), e),
367        )
368    })?;
369
370    let output = quote! {
371        // The progenitor_client is tautologically visible from macro
372        // consumers.
373        use progenitor::progenitor_client;
374
375        #code
376
377        // Force a rebuild when the given file is modified.
378        const _: &str = include_str!(#path_str);
379    };
380
381    Ok(output.into())
382}