Skip to main content

progenitor_impl/
cli.rs

1// Copyright 2024 Oxide Computer Company
2
3use std::collections::BTreeMap;
4
5use heck::ToKebabCase;
6use openapiv3::OpenAPI;
7use proc_macro2::TokenStream;
8use quote::{ToTokens, format_ident, quote};
9use typify::{Type, TypeEnumVariant, TypeSpaceImpl, TypeStructPropInfo};
10
11use crate::{
12    Generator, Result,
13    method::{OperationParameterKind, OperationParameterType, OperationResponseStatus},
14    to_schema::ToSchema,
15    util::{Case, sanitize},
16    validate_openapi,
17};
18
19struct CliOperation {
20    cli_fn: TokenStream,
21    execute_fn: TokenStream,
22    execute_trait: TokenStream,
23}
24
25impl Generator {
26    /// Generate a `clap`-based CLI.
27    pub fn cli(&mut self, spec: &OpenAPI, crate_name: &str) -> Result<TokenStream> {
28        validate_openapi(spec)?;
29
30        // Convert our components dictionary to schemars
31        let schemas = spec.components.iter().flat_map(|components| {
32            components
33                .schemas
34                .iter()
35                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
36        });
37
38        self.type_space.add_ref_types(schemas)?;
39
40        let raw_methods = spec
41            .paths
42            .iter()
43            .flat_map(|(path, ref_or_item)| {
44                // Exclude externally defined path items.
45                let item = ref_or_item.as_item().unwrap();
46                item.iter().map(move |(method, operation)| {
47                    (path.as_str(), method, operation, &item.parameters)
48                })
49            })
50            .map(|(path, method, operation, path_parameters)| {
51                self.process_operation(operation, &spec.components, path, method, path_parameters)
52            })
53            .collect::<Result<Vec<_>>>()?;
54
55        let methods = raw_methods
56            .iter()
57            .map(|method| self.cli_method(method))
58            .collect::<Vec<_>>();
59
60        let cli_ops = methods.iter().map(|op| &op.cli_fn);
61        let execute_ops = methods.iter().map(|op| &op.execute_fn);
62        let trait_ops = methods.iter().map(|op| &op.execute_trait);
63
64        let cli_fns = raw_methods
65            .iter()
66            .map(|method| format_ident!("cli_{}", sanitize(&method.operation_id, Case::Snake)))
67            .collect::<Vec<_>>();
68        let execute_fns = raw_methods
69            .iter()
70            .map(|method| format_ident!("execute_{}", sanitize(&method.operation_id, Case::Snake)))
71            .collect::<Vec<_>>();
72
73        let cli_variants = raw_methods
74            .iter()
75            .map(|method| format_ident!("{}", sanitize(&method.operation_id, Case::Pascal)))
76            .collect::<Vec<_>>();
77
78        let crate_path = syn::TypePath {
79            qself: None,
80            path: syn::parse_str(crate_name).unwrap(),
81        };
82
83        let cli_bounds: Vec<_> = self
84            .settings
85            .extra_cli_bounds
86            .iter()
87            .map(|b| syn::parse_str::<syn::Path>(b).unwrap().into_token_stream())
88            .collect();
89
90        let code = quote! {
91            use #crate_path::*;
92            use anyhow::Context as _;
93
94            pub struct Cli<T: CliConfig> {
95                client: Client,
96                config: T,
97            }
98            impl<T: CliConfig> Cli<T> {
99                pub fn new(
100                    client: Client,
101                    config: T,
102                ) -> Self {
103                    Self { client, config }
104                }
105
106                pub fn get_command(cmd: CliCommand) -> ::clap::Command {
107                    match cmd {
108                        #(
109                            CliCommand::#cli_variants => Self::#cli_fns(),
110                        )*
111                    }
112                }
113
114                #(#cli_ops)*
115
116                pub async fn execute(
117                    &self,
118                    cmd: CliCommand,
119                    matches: &::clap::ArgMatches,
120                ) -> anyhow::Result<()> {
121                    match cmd {
122                        #(
123                            CliCommand::#cli_variants => {
124                                // TODO ... do something with output
125                                self.#execute_fns(matches).await
126                            }
127                        )*
128                    }
129                }
130
131                #(#execute_ops)*
132            }
133
134            pub trait CliConfig {
135                fn success_item<T>(&self, value: &ResponseValue<T>)
136                where
137                    T: #(#cli_bounds+)* schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
138                fn success_no_item(&self, value: &ResponseValue<()>);
139                fn error<T>(&self, value: &Error<T>)
140                where
141                    T: #(#cli_bounds+)* schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
142
143                fn list_start<T>(&self)
144                where
145                    T: #(#cli_bounds+)* schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
146                fn list_item<T>(&self, value: &T)
147                where
148                    T: #(#cli_bounds+)* schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
149                fn list_end_success<T>(&self)
150                where
151                    T: #(#cli_bounds+)* schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
152                fn list_end_error<T>(&self, value: &Error<T>)
153                where
154                    T: #(#cli_bounds+)* schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
155
156                #(#trait_ops)*
157            }
158
159            #[derive(Copy, Clone, Debug)]
160            pub enum CliCommand {
161                #(#cli_variants,)*
162            }
163
164            impl CliCommand {
165                pub fn iter() -> impl Iterator<Item = CliCommand> {
166                    vec![
167                        #(
168                            CliCommand::#cli_variants,
169                        )*
170                    ].into_iter()
171                }
172            }
173
174        };
175
176        Ok(code)
177    }
178
179    fn cli_method(&mut self, method: &crate::method::OperationMethod) -> CliOperation {
180        let CliArg {
181            parser: parser_args,
182            consumer: consumer_args,
183        } = self.cli_method_args(method);
184
185        let about = method.summary.as_ref().map(|summary| {
186            quote! {
187                .about(#summary)
188            }
189        });
190
191        let long_about = method.description.as_ref().map(|description| {
192            quote! {
193                .long_about(#description)
194            }
195        });
196
197        let fn_name = format_ident!("cli_{}", &method.operation_id);
198
199        let cli_fn = quote! {
200            pub fn #fn_name() -> ::clap::Command
201            {
202                ::clap::Command::new("")
203                #parser_args
204                #about
205                #long_about
206            }
207        };
208
209        let fn_name = format_ident!("execute_{}", &method.operation_id);
210        let op_name = format_ident!("{}", &method.operation_id);
211
212        let (_, success_kind) =
213            self.extract_responses(method, OperationResponseStatus::is_success_or_default);
214        let (_, error_kind) =
215            self.extract_responses(method, OperationResponseStatus::is_error_or_default);
216
217        let execute_and_output = match method.dropshot_paginated {
218            // Normal, one-shot API calls.
219            None => {
220                let success_output = match success_kind {
221                    crate::method::OperationResponseKind::Type(_) => {
222                        quote! {
223                            {
224                                self.config.success_item(&r);
225                                Ok(())
226                            }
227                        }
228                    }
229                    crate::method::OperationResponseKind::None => {
230                        quote! {
231                            {
232                                self.config.success_no_item(&r);
233                                Ok(())
234                            }
235                        }
236                    }
237                    crate::method::OperationResponseKind::Raw
238                    | crate::method::OperationResponseKind::Upgrade => {
239                        quote! {
240                            {
241                                todo!()
242                            }
243                        }
244                    }
245                };
246
247                let error_output = match error_kind {
248                    crate::method::OperationResponseKind::Type(_)
249                    | crate::method::OperationResponseKind::None => {
250                        quote! {
251                            {
252                                self.config.error(&r);
253                                Err(anyhow::Error::new(r))
254                            }
255                        }
256                    }
257                    crate::method::OperationResponseKind::Raw
258                    | crate::method::OperationResponseKind::Upgrade => {
259                        quote! {
260                            {
261                                todo!()
262                            }
263                        }
264                    }
265                };
266
267                quote! {
268                    let result = request.send().await;
269
270                    match result {
271                        Ok(r) => #success_output
272                        Err(r) => #error_output
273                    }
274                }
275            }
276
277            // Paginated APIs for which we iterate over each item.
278            Some(_) => {
279                let success_type = match success_kind {
280                    crate::method::OperationResponseKind::Type(type_id) => {
281                        self.type_space.get_type(&type_id).unwrap().ident()
282                    }
283                    crate::method::OperationResponseKind::None => quote! { () },
284                    crate::method::OperationResponseKind::Raw => todo!(),
285                    crate::method::OperationResponseKind::Upgrade => todo!(),
286                };
287                let error_output = match error_kind {
288                    crate::method::OperationResponseKind::Type(_)
289                    | crate::method::OperationResponseKind::None => {
290                        quote! {
291                            {
292                                self.config.list_end_error(&r);
293                                return Err(anyhow::Error::new(r))
294                            }
295                        }
296                    }
297                    crate::method::OperationResponseKind::Raw
298                    | crate::method::OperationResponseKind::Upgrade => {
299                        quote! {
300                            {
301                                todo!()
302                            }
303                        }
304                    }
305                };
306                quote! {
307                    self.config.list_start::<#success_type>();
308
309                    // We're using "limit" as both the maximum page size and
310                    // as the full limit. It's not ideal in that we could
311                    // reduce the limit with each iteration and we might get a
312                    // bunch of results we don't display... but it's fine.
313                    let mut stream = futures::StreamExt::take(
314                        request.stream(),
315                        matches
316                            .get_one::<std::num::NonZeroU32>("limit")
317                            .map_or(usize::MAX, |x| x.get() as usize));
318
319                    loop {
320                        match futures::TryStreamExt::try_next(&mut stream).await {
321                            Err(r) => #error_output
322                            Ok(None) => {
323                                self.config.list_end_success::<#success_type>();
324                                return Ok(());
325                            }
326                            Ok(Some(value)) => {
327                                self.config.list_item(&value);
328                            }
329                        }
330                    }
331                }
332            }
333        };
334
335        let execute_fn = quote! {
336            pub async fn #fn_name(&self, matches: &::clap::ArgMatches)
337                -> anyhow::Result<()>
338            {
339                let mut request = self.client.#op_name();
340                #consumer_args
341
342                // Call the override function.
343                self.config.#fn_name(matches, &mut request)?;
344
345                #execute_and_output
346            }
347        };
348
349        // TODO this is copy-pasted--unwisely?
350        let struct_name = sanitize(&method.operation_id, Case::Pascal);
351        let struct_ident = format_ident!("{}", struct_name);
352
353        let execute_trait = quote! {
354            fn #fn_name(
355                &self,
356                matches: &::clap::ArgMatches,
357                request: &mut builder :: #struct_ident,
358            ) -> anyhow::Result<()> {
359                Ok(())
360            }
361        };
362
363        CliOperation {
364            cli_fn,
365            execute_fn,
366            execute_trait,
367        }
368    }
369
370    fn cli_method_args(&self, method: &crate::method::OperationMethod) -> CliArg {
371        let mut args = CliOperationArgs::default();
372
373        let first_page_required_set = method
374            .dropshot_paginated
375            .as_ref()
376            .map(|d| &d.first_page_params);
377
378        for param in &method.params {
379            let innately_required = match &param.kind {
380                // We're not interetested in the body parameter yet.
381                OperationParameterKind::Body(_) => continue,
382
383                OperationParameterKind::Path => true,
384                OperationParameterKind::Query(required) => *required,
385                OperationParameterKind::Header(required) => *required,
386            };
387
388            // For paginated endpoints, we don't generate 'page_token' args.
389            if method.dropshot_paginated.is_some() && param.name.as_str() == "page_token" {
390                continue;
391            }
392
393            let first_page_required = first_page_required_set
394                .map_or(false, |required| required.contains(&param.api_name));
395
396            let volitionality = if innately_required || first_page_required {
397                Volitionality::Required
398            } else {
399                Volitionality::Optional
400            };
401
402            let OperationParameterType::Type(arg_type_id) = &param.typ else {
403                unreachable!("query and path parameters must be typed")
404            };
405            let arg_type = self.type_space.get_type(arg_type_id).unwrap();
406
407            let arg_name = param.name.to_kebab_case();
408
409            // There should be no conflicting path or query parameters.
410            assert!(!args.has_arg(&arg_name));
411
412            let parser = clap_arg(&arg_name, volitionality, &param.description, &arg_type);
413
414            let arg_fn_name = sanitize(&param.name, Case::Snake);
415            let arg_fn = format_ident!("{}", arg_fn_name);
416            let OperationParameterType::Type(arg_type_id) = &param.typ else {
417                panic!()
418            };
419            let arg_type = self.type_space.get_type(arg_type_id).unwrap();
420            let arg_type_name = arg_type.ident();
421
422            let consumer = quote! {
423                if let Some(value) =
424                    matches.get_one::<#arg_type_name>(#arg_name)
425                {
426                    // clone here in case the arg type doesn't impl
427                    // From<&T>
428                    request = request.#arg_fn(value.clone());
429                }
430            };
431
432            args.add_arg(arg_name, CliArg { parser, consumer })
433        }
434
435        let maybe_body_type_id = method
436            .params
437            .iter()
438            .find(|param| matches!(&param.kind, OperationParameterKind::Body(_)))
439            .and_then(|param| match &param.typ {
440                // TODO not sure how to deal with raw bodies, but we definitely
441                // need **some** input so we shouldn't just ignore it... as we
442                // are currently...
443                OperationParameterType::RawBody => None,
444
445                OperationParameterType::Type(body_type_id) => Some(body_type_id),
446            });
447
448        if let Some(body_type_id) = maybe_body_type_id {
449            args.body_present();
450            let body_type = self.type_space.get_type(body_type_id).unwrap();
451            let details = body_type.details();
452
453            match details {
454                typify::TypeDetails::Struct(struct_info) => {
455                    for prop_info in struct_info.properties_info() {
456                        self.cli_method_body_arg(&mut args, prop_info)
457                    }
458                }
459
460                _ => {
461                    // If the body is not a struct, we don't know what's
462                    // required or how to generate it
463                    args.body_required()
464                }
465            }
466        }
467
468        let parser_args = args.args.values().map(|CliArg { parser, .. }| parser);
469
470        // TODO do this as args we add in.
471        let body_json_args = (match args.body {
472            CliBodyArg::None => None,
473            CliBodyArg::Required => Some(true),
474            CliBodyArg::Optional => Some(false),
475        })
476        .map(|required| {
477            let help = "Path to a file that contains the full json body.";
478
479            quote! {
480                .arg(
481                    ::clap::Arg::new("json-body")
482                        .long("json-body")
483                        .value_name("JSON-FILE")
484                        // Required if we can't turn the body into individual
485                        // parameters.
486                        .required(#required)
487                        .value_parser(::clap::value_parser!(std::path::PathBuf))
488                        .help(#help)
489                )
490                .arg(
491                    ::clap::Arg::new("json-body-template")
492                        .long("json-body-template")
493                        .action(::clap::ArgAction::SetTrue)
494                        .help("XXX")
495                )
496            }
497        });
498
499        let parser = quote! {
500            #(
501                .arg(#parser_args)
502            )*
503            #body_json_args
504        };
505
506        let consumer_args = args.args.values().map(|CliArg { consumer, .. }| consumer);
507
508        let body_json_consumer = maybe_body_type_id.map(|body_type_id| {
509            let body_type = self.type_space.get_type(body_type_id).unwrap();
510            let body_type_ident = body_type.ident();
511            quote! {
512                if let Some(value) =
513                    matches.get_one::<std::path::PathBuf>("json-body")
514                {
515                    let body_txt = std::fs::read_to_string(value).with_context(|| format!("failed to read {}", value.display()))?;
516                    let body_value =
517                        serde_json::from_str::<#body_type_ident>(
518                            &body_txt,
519                        )
520                        .with_context(|| format!("failed to parse {}", value.display()))?;
521                    request = request.body(body_value);
522                }
523            }
524        });
525
526        let consumer = quote! {
527            #(
528                #consumer_args
529            )*
530            #body_json_consumer
531        };
532
533        CliArg { parser, consumer }
534    }
535
536    fn cli_method_body_arg(&self, args: &mut CliOperationArgs, prop_info: TypeStructPropInfo<'_>) {
537        let TypeStructPropInfo {
538            name,
539            description,
540            required,
541            type_id,
542        } = prop_info;
543
544        let prop_type = self.type_space.get_type(&type_id).unwrap();
545
546        // TODO this is maybe a kludge--not completely sure of the right way to
547        // handle option types. On one hand, we could want types from this
548        // interface to never show us Option<T> types--we could let the
549        // `required` field give us that information. On the other hand, there
550        // might be Option types that are required ... at least in the JSON
551        // sense, meaning that we need to include `"foo": null` rather than
552        // omitting the field. Back to the first hand: is that last point just
553        // a serde issue rather than an interface one?
554        let maybe_inner_type =
555            if let typify::TypeDetails::Option(inner_type_id) = prop_type.details() {
556                let inner_type = self.type_space.get_type(&inner_type_id).unwrap();
557                Some(inner_type)
558            } else {
559                None
560            };
561
562        let prop_type = if let Some(inner_type) = maybe_inner_type {
563            inner_type
564        } else {
565            prop_type
566        };
567
568        let scalar = prop_type.has_impl(TypeSpaceImpl::FromStr);
569
570        let prop_name = name.to_kebab_case();
571        if scalar && !args.has_arg(&prop_name) {
572            let volitionality = if required {
573                Volitionality::RequiredIfNoBody
574            } else {
575                Volitionality::Optional
576            };
577            let parser = clap_arg(
578                &prop_name,
579                volitionality,
580                &description.map(str::to_string),
581                &prop_type,
582            );
583
584            let prop_fn = format_ident!("{}", sanitize(name, Case::Snake));
585            let prop_type_ident = prop_type.ident();
586            let consumer = quote! {
587                if let Some(value) =
588                    matches.get_one::<#prop_type_ident>(
589                        #prop_name,
590                    )
591                {
592                    // clone here in case the arg type
593                    // doesn't impl TryFrom<&T>
594                    request = request.body_map(|body| {
595                        body.#prop_fn(value.clone())
596                    })
597                }
598            };
599            args.add_arg(prop_name, CliArg { parser, consumer })
600        } else if required {
601            args.body_required()
602        }
603
604        // Cases
605        // 1. If the type can be represented as a string, great
606        //
607        // 2. If it's a substruct then we can try to glue the names together
608        // and hope?
609        //
610        // 3. enums
611        // 3.1 simple enums (should be covered by 1 above)
612        //   e.g. enum { A, B }
613        //   args for --a and --b that are in a group
614    }
615}
616
/// Whether a generated CLI argument must be supplied.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Volitionality {
    /// The argument may be omitted (`.required(false)`).
    Optional,
    /// The argument must always be supplied (`.required(true)`).
    Required,
    /// The argument must be supplied unless `json-body` is present
    /// (`.required_unless_present("json-body")`).
    RequiredIfNoBody,
}
622
623fn clap_arg(
624    arg_name: &str,
625    volitionality: Volitionality,
626    description: &Option<String>,
627    arg_type: &Type,
628) -> TokenStream {
629    let help = description.as_ref().map(|description| {
630        quote! {
631            .help(#description)
632        }
633    });
634    let arg_type_name = arg_type.ident();
635
636    // For enums that have **only** simple variants, we do some slightly
637    // fancier argument handling to expose the possible values. In particular,
638    // we use clap's `PossibleValuesParser` with each variant converted to a
639    // string. Then we use TypedValueParser::map to translate that into the
640    // actual type of the enum.
641    let maybe_enum_parser = if let typify::TypeDetails::Enum(e) = arg_type.details() {
642        let maybe_var_names = e
643            .variants()
644            .map(|(var_name, var_details)| {
645                if let TypeEnumVariant::Simple = var_details {
646                    Some(format_ident!("{}", var_name))
647                } else {
648                    None
649                }
650            })
651            .collect::<Option<Vec<_>>>();
652
653        maybe_var_names.map(|var_names| {
654            quote! {
655                ::clap::builder::TypedValueParser::map(
656                    ::clap::builder::PossibleValuesParser::new([
657                        #( #arg_type_name :: #var_names.to_string(), )*
658                    ]),
659                    |s| #arg_type_name :: try_from(s).unwrap()
660                )
661            }
662        })
663    } else {
664        None
665    };
666
667    let value_parser = if let Some(enum_parser) = maybe_enum_parser {
668        enum_parser
669    } else {
670        // Let clap pick a value parser for us. This has the benefit of
671        // allowing for override implementations. A generated client may
672        // implement ValueParserFactory for a type to create a custom parser.
673        quote! {
674            ::clap::value_parser!(#arg_type_name)
675        }
676    };
677
678    let required = match volitionality {
679        Volitionality::Optional => quote! { .required(false) },
680        Volitionality::Required => quote! { .required(true) },
681        Volitionality::RequiredIfNoBody => {
682            quote! { .required_unless_present("json-body") }
683        }
684    };
685
686    quote! {
687        ::clap::Arg::new(#arg_name)
688            .long(#arg_name)
689            .value_parser(#value_parser)
690            #required
691            #help
692    }
693}
694
695#[derive(Debug)]
696struct CliArg {
697    /// Code to parse the argument
698    parser: TokenStream,
699
700    /// Code to consume the argument
701    consumer: TokenStream,
702}
703
/// Whether an operation gets a `--json-body` argument, and if so whether it is
/// mandatory.
#[derive(Default, Debug, Eq, PartialEq)]
enum CliBodyArg {
    /// No typed request body, so no `--json-body` argument is generated.
    #[default]
    None,
    /// `--json-body` must be given, because the body cannot be fully built
    /// from individual arguments.
    Required,
    /// `--json-body` is accepted but not mandatory.
    Optional,
}
711
712#[derive(Default, Debug)]
713struct CliOperationArgs {
714    args: BTreeMap<String, CliArg>,
715    body: CliBodyArg,
716}
717
718impl CliOperationArgs {
719    fn has_arg(&self, name: &String) -> bool {
720        self.args.contains_key(name)
721    }
722    fn add_arg(&mut self, name: String, arg: CliArg) {
723        self.args.insert(name, arg);
724    }
725
726    fn body_present(&mut self) {
727        assert_eq!(self.body, CliBodyArg::None);
728        self.body = CliBodyArg::Optional;
729    }
730
731    fn body_required(&mut self) {
732        assert!(self.body == CliBodyArg::Optional || self.body == CliBodyArg::Required);
733        self.body = CliBodyArg::Required;
734    }
735}