Skip to main content

progenitor_middleware_impl/
cli.rs

1// Copyright 2024 Oxide Computer Company
2
3use std::collections::BTreeMap;
4
5use heck::ToKebabCase;
6use openapiv3::OpenAPI;
7use proc_macro2::TokenStream;
8use quote::{format_ident, quote};
9use typify::{Type, TypeEnumVariant, TypeSpaceImpl, TypeStructPropInfo};
10
11use crate::{
12    method::{OperationParameterKind, OperationParameterType, OperationResponseStatus},
13    to_schema::ToSchema,
14    util::{sanitize, Case},
15    validate_openapi, Generator, Result,
16};
17
/// Generated code for a single API operation, produced by
/// `Generator::cli_method` and spliced into the output of `Generator::cli`.
struct CliOperation {
    /// `Cli::cli_<op>()`: builds the `clap::Command` for the operation.
    cli_fn: TokenStream,
    /// `Cli::execute_<op>()`: applies the parsed matches to the request
    /// builder, sends the request, and reports the result via `CliConfig`.
    execute_fn: TokenStream,
    /// Default (no-op) `CliConfig::execute_<op>()` hook, which the generated
    /// execute function calls with the request builder before sending it.
    execute_trait: TokenStream,
}
23
24impl Generator {
25    /// Generate a `clap`-based CLI.
26    pub fn cli(&mut self, spec: &OpenAPI, crate_name: &str) -> Result<TokenStream> {
27        validate_openapi(spec)?;
28
29        // Convert our components dictionary to schemars
30        let schemas = spec.components.iter().flat_map(|components| {
31            components
32                .schemas
33                .iter()
34                .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
35        });
36
37        self.type_space.add_ref_types(schemas)?;
38
39        let raw_methods = spec
40            .paths
41            .iter()
42            .flat_map(|(path, ref_or_item)| {
43                // Exclude externally defined path items.
44                let item = ref_or_item.as_item().unwrap();
45                item.iter().map(move |(method, operation)| {
46                    (path.as_str(), method, operation, &item.parameters)
47                })
48            })
49            .map(|(path, method, operation, path_parameters)| {
50                self.process_operation(operation, &spec.components, path, method, path_parameters)
51            })
52            .collect::<Result<Vec<_>>>()?;
53
54        let methods = raw_methods
55            .iter()
56            .map(|method| self.cli_method(method))
57            .collect::<Vec<_>>();
58
59        let cli_ops = methods.iter().map(|op| &op.cli_fn);
60        let execute_ops = methods.iter().map(|op| &op.execute_fn);
61        let trait_ops = methods.iter().map(|op| &op.execute_trait);
62
63        let cli_fns = raw_methods
64            .iter()
65            .map(|method| format_ident!("cli_{}", sanitize(&method.operation_id, Case::Snake)))
66            .collect::<Vec<_>>();
67        let execute_fns = raw_methods
68            .iter()
69            .map(|method| format_ident!("execute_{}", sanitize(&method.operation_id, Case::Snake)))
70            .collect::<Vec<_>>();
71
72        let cli_variants = raw_methods
73            .iter()
74            .map(|method| format_ident!("{}", sanitize(&method.operation_id, Case::Pascal)))
75            .collect::<Vec<_>>();
76
77        let crate_path = syn::TypePath {
78            qself: None,
79            path: syn::parse_str(crate_name).unwrap(),
80        };
81
82        let code = quote! {
83            use #crate_path::*;
84
85            pub struct Cli<T: CliConfig> {
86                client: Client,
87                config: T,
88            }
89            impl<T: CliConfig> Cli<T> {
90                pub fn new(
91                    client: Client,
92                    config: T,
93                ) -> Self {
94                    Self { client, config }
95                }
96
97                pub fn get_command(cmd: CliCommand) -> ::clap::Command {
98                    match cmd {
99                        #(
100                            CliCommand::#cli_variants => Self::#cli_fns(),
101                        )*
102                    }
103                }
104
105                #(#cli_ops)*
106
107                pub async fn execute(
108                    &self,
109                    cmd: CliCommand,
110                    matches: &::clap::ArgMatches,
111                ) -> anyhow::Result<()> {
112                    match cmd {
113                        #(
114                            CliCommand::#cli_variants => {
115                                // TODO ... do something with output
116                                self.#execute_fns(matches).await
117                            }
118                        )*
119                    }
120                }
121
122                #(#execute_ops)*
123            }
124
125            pub trait CliConfig {
126                fn success_item<T>(&self, value: &ResponseValue<T>)
127                where
128                    T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
129                fn success_no_item(&self, value: &ResponseValue<()>);
130                fn error<T>(&self, value: &Error<T>)
131                where
132                    T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
133
134                fn list_start<T>(&self)
135                where
136                    T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
137                fn list_item<T>(&self, value: &T)
138                where
139                    T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
140                fn list_end_success<T>(&self)
141                where
142                    T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
143                fn list_end_error<T>(&self, value: &Error<T>)
144                where
145                    T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
146
147                #(#trait_ops)*
148            }
149
150            #[derive(Copy, Clone, Debug)]
151            pub enum CliCommand {
152                #(#cli_variants,)*
153            }
154
155            impl CliCommand {
156                pub fn iter() -> impl Iterator<Item = CliCommand> {
157                    vec![
158                        #(
159                            CliCommand::#cli_variants,
160                        )*
161                    ].into_iter()
162                }
163            }
164
165        };
166
167        Ok(code)
168    }
169
170    fn cli_method(&mut self, method: &crate::method::OperationMethod) -> CliOperation {
171        let CliArg {
172            parser: parser_args,
173            consumer: consumer_args,
174        } = self.cli_method_args(method);
175
176        let about = method.summary.as_ref().map(|summary| {
177            quote! {
178                .about(#summary)
179            }
180        });
181
182        let long_about = method.description.as_ref().map(|description| {
183            quote! {
184                .long_about(#description)
185            }
186        });
187
188        let fn_name = format_ident!("cli_{}", &method.operation_id);
189
190        let cli_fn = quote! {
191            pub fn #fn_name() -> ::clap::Command
192            {
193                ::clap::Command::new("")
194                #parser_args
195                #about
196                #long_about
197            }
198        };
199
200        let fn_name = format_ident!("execute_{}", &method.operation_id);
201        let op_name = format_ident!("{}", &method.operation_id);
202
203        let (_, success_response_type) =
204            self.extract_responses(method, OperationResponseStatus::is_success_or_default);
205        let (_, error_response_type) =
206            self.extract_responses(method, OperationResponseStatus::is_error_or_default);
207
208        // Extract the underlying OperationResponseKind from ErrorResponseType
209        // For CLI, we only support Single error types (not Multiple)
210        let success_kind = match success_response_type {
211            crate::method::ErrorResponseType::Single(kind) => kind,
212            crate::method::ErrorResponseType::Multiple { .. } => {
213                panic!("CLI generation does not support operations with multiple success types");
214            }
215        };
216        let error_kind = match error_response_type {
217            crate::method::ErrorResponseType::Single(kind) => kind,
218            crate::method::ErrorResponseType::Multiple { .. } => {
219                panic!("CLI generation does not support operations with multiple error types");
220            }
221        };
222
223        let execute_and_output = match method.dropshot_paginated {
224            // Normal, one-shot API calls.
225            None => {
226                let success_output = match success_kind {
227                    crate::method::OperationResponseKind::Type(_) => {
228                        quote! {
229                            {
230                                self.config.success_item(&r);
231                                Ok(())
232                            }
233                        }
234                    }
235                    crate::method::OperationResponseKind::None => {
236                        quote! {
237                            {
238                                self.config.success_no_item(&r);
239                                Ok(())
240                            }
241                        }
242                    }
243                    crate::method::OperationResponseKind::Raw
244                    | crate::method::OperationResponseKind::Upgrade => {
245                        quote! {
246                            {
247                                todo!()
248                            }
249                        }
250                    }
251                };
252
253                let error_output = match error_kind {
254                    crate::method::OperationResponseKind::Type(_)
255                    | crate::method::OperationResponseKind::None => {
256                        quote! {
257                            {
258                                self.config.error(&r);
259                                Err(anyhow::Error::new(r))
260                            }
261                        }
262                    }
263                    crate::method::OperationResponseKind::Raw
264                    | crate::method::OperationResponseKind::Upgrade => {
265                        quote! {
266                            {
267                                todo!()
268                            }
269                        }
270                    }
271                };
272
273                quote! {
274                    let result = request.send().await;
275
276                    match result {
277                        Ok(r) => #success_output
278                        Err(r) => #error_output
279                    }
280                }
281            }
282
283            // Paginated APIs for which we iterate over each item.
284            Some(_) => {
285                let success_type = match success_kind {
286                    crate::method::OperationResponseKind::Type(type_id) => {
287                        self.type_space.get_type(&type_id).unwrap().ident()
288                    }
289                    crate::method::OperationResponseKind::None => quote! { () },
290                    crate::method::OperationResponseKind::Raw => todo!(),
291                    crate::method::OperationResponseKind::Upgrade => todo!(),
292                };
293                let error_output = match error_kind {
294                    crate::method::OperationResponseKind::Type(_)
295                    | crate::method::OperationResponseKind::None => {
296                        quote! {
297                            {
298                                self.config.list_end_error(&r);
299                                return Err(anyhow::Error::new(r))
300                            }
301                        }
302                    }
303                    crate::method::OperationResponseKind::Raw
304                    | crate::method::OperationResponseKind::Upgrade => {
305                        quote! {
306                            {
307                                todo!()
308                            }
309                        }
310                    }
311                };
312                quote! {
313                    self.config.list_start::<#success_type>();
314
315                    // We're using "limit" as both the maximum page size and
316                    // as the full limit. It's not ideal in that we could
317                    // reduce the limit with each iteration and we might get a
318                    // bunch of results we don't display... but it's fine.
319                    let mut stream = futures::StreamExt::take(
320                        request.stream(),
321                        matches
322                            .get_one::<std::num::NonZeroU32>("limit")
323                            .map_or(usize::MAX, |x| x.get() as usize));
324
325                    loop {
326                        match futures::TryStreamExt::try_next(&mut stream).await {
327                            Err(r) => #error_output
328                            Ok(None) => {
329                                self.config.list_end_success::<#success_type>();
330                                return Ok(());
331                            }
332                            Ok(Some(value)) => {
333                                self.config.list_item(&value);
334                            }
335                        }
336                    }
337                }
338            }
339        };
340
341        let execute_fn = quote! {
342            pub async fn #fn_name(&self, matches: &::clap::ArgMatches)
343                -> anyhow::Result<()>
344            {
345                let mut request = self.client.#op_name();
346                #consumer_args
347
348                // Call the override function.
349                self.config.#fn_name(matches, &mut request)?;
350
351                #execute_and_output
352            }
353        };
354
355        // TODO this is copy-pasted--unwisely?
356        let struct_name = sanitize(&method.operation_id, Case::Pascal);
357        let struct_ident = format_ident!("{}", struct_name);
358
359        let execute_trait = quote! {
360            fn #fn_name(
361                &self,
362                matches: &::clap::ArgMatches,
363                request: &mut builder :: #struct_ident,
364            ) -> anyhow::Result<()> {
365                Ok(())
366            }
367        };
368
369        CliOperation {
370            cli_fn,
371            execute_fn,
372            execute_trait,
373        }
374    }
375
376    fn cli_method_args(&self, method: &crate::method::OperationMethod) -> CliArg {
377        let mut args = CliOperationArgs::default();
378
379        let first_page_required_set = method
380            .dropshot_paginated
381            .as_ref()
382            .map(|d| &d.first_page_params);
383
384        for param in &method.params {
385            let innately_required = match &param.kind {
386                // We're not interetested in the body parameter yet.
387                OperationParameterKind::Body(_) => continue,
388
389                OperationParameterKind::Path => true,
390                OperationParameterKind::Query(required) => *required,
391                OperationParameterKind::Header(required) => *required,
392            };
393
394            // For paginated endpoints, we don't generate 'page_token' args.
395            if method.dropshot_paginated.is_some() && param.name.as_str() == "page_token" {
396                continue;
397            }
398
399            let first_page_required = first_page_required_set
400                .map_or(false, |required| required.contains(&param.api_name));
401
402            let volitionality = if innately_required || first_page_required {
403                Volitionality::Required
404            } else {
405                Volitionality::Optional
406            };
407
408            let OperationParameterType::Type(arg_type_id) = &param.typ else {
409                unreachable!("query and path parameters must be typed")
410            };
411            let arg_type = self.type_space.get_type(arg_type_id).unwrap();
412
413            let arg_name = param.name.to_kebab_case();
414
415            // There should be no conflicting path or query parameters.
416            assert!(!args.has_arg(&arg_name));
417
418            let parser = clap_arg(&arg_name, volitionality, &param.description, &arg_type);
419
420            let arg_fn_name = sanitize(&param.name, Case::Snake);
421            let arg_fn = format_ident!("{}", arg_fn_name);
422            let OperationParameterType::Type(arg_type_id) = &param.typ else {
423                panic!()
424            };
425            let arg_type = self.type_space.get_type(arg_type_id).unwrap();
426            let arg_type_name = arg_type.ident();
427
428            let consumer = quote! {
429                if let Some(value) =
430                    matches.get_one::<#arg_type_name>(#arg_name)
431                {
432                    // clone here in case the arg type doesn't impl
433                    // From<&T>
434                    request = request.#arg_fn(value.clone());
435                }
436            };
437
438            args.add_arg(arg_name, CliArg { parser, consumer })
439        }
440
441        let maybe_body_type_id = method
442            .params
443            .iter()
444            .find(|param| matches!(&param.kind, OperationParameterKind::Body(_)))
445            .and_then(|param| match &param.typ {
446                // TODO not sure how to deal with raw bodies, but we definitely
447                // need **some** input so we shouldn't just ignore it... as we
448                // are currently...
449                OperationParameterType::RawBody => None,
450
451                OperationParameterType::Type(body_type_id) => Some(body_type_id),
452            });
453
454        if let Some(body_type_id) = maybe_body_type_id {
455            args.body_present();
456            let body_type = self.type_space.get_type(body_type_id).unwrap();
457            let details = body_type.details();
458
459            match details {
460                typify::TypeDetails::Struct(struct_info) => {
461                    for prop_info in struct_info.properties_info() {
462                        self.cli_method_body_arg(&mut args, prop_info)
463                    }
464                }
465
466                _ => {
467                    // If the body is not a struct, we don't know what's
468                    // required or how to generate it
469                    args.body_required()
470                }
471            }
472        }
473
474        let parser_args = args.args.values().map(|CliArg { parser, .. }| parser);
475
476        // TODO do this as args we add in.
477        let body_json_args = (match args.body {
478            CliBodyArg::None => None,
479            CliBodyArg::Required => Some(true),
480            CliBodyArg::Optional => Some(false),
481        })
482        .map(|required| {
483            let help = "Path to a file that contains the full json body.";
484
485            quote! {
486                .arg(
487                    ::clap::Arg::new("json-body")
488                        .long("json-body")
489                        .value_name("JSON-FILE")
490                        // Required if we can't turn the body into individual
491                        // parameters.
492                        .required(#required)
493                        .value_parser(::clap::value_parser!(std::path::PathBuf))
494                        .help(#help)
495                )
496                .arg(
497                    ::clap::Arg::new("json-body-template")
498                        .long("json-body-template")
499                        .action(::clap::ArgAction::SetTrue)
500                        .help("XXX")
501                )
502            }
503        });
504
505        let parser = quote! {
506            #(
507                .arg(#parser_args)
508            )*
509            #body_json_args
510        };
511
512        let consumer_args = args.args.values().map(|CliArg { consumer, .. }| consumer);
513
514        let body_json_consumer = maybe_body_type_id.map(|body_type_id| {
515            let body_type = self.type_space.get_type(body_type_id).unwrap();
516            let body_type_ident = body_type.ident();
517            quote! {
518                if let Some(value) =
519                    matches.get_one::<std::path::PathBuf>("json-body")
520                {
521                    let body_txt = std::fs::read_to_string(value).unwrap();
522                    let body_value =
523                        serde_json::from_str::<#body_type_ident>(
524                            &body_txt,
525                        )
526                        .unwrap();
527                    request = request.body(body_value);
528                }
529            }
530        });
531
532        let consumer = quote! {
533            #(
534                #consumer_args
535            )*
536            #body_json_consumer
537        };
538
539        CliArg { parser, consumer }
540    }
541
542    fn cli_method_body_arg(&self, args: &mut CliOperationArgs, prop_info: TypeStructPropInfo<'_>) {
543        let TypeStructPropInfo {
544            name,
545            description,
546            required,
547            type_id,
548        } = prop_info;
549
550        let prop_type = self.type_space.get_type(&type_id).unwrap();
551
552        // TODO this is maybe a kludge--not completely sure of the right way to
553        // handle option types. On one hand, we could want types from this
554        // interface to never show us Option<T> types--we could let the
555        // `required` field give us that information. On the other hand, there
556        // might be Option types that are required ... at least in the JSON
557        // sense, meaning that we need to include `"foo": null` rather than
558        // omitting the field. Back to the first hand: is that last point just
559        // a serde issue rather than an interface one?
560        let maybe_inner_type =
561            if let typify::TypeDetails::Option(inner_type_id) = prop_type.details() {
562                let inner_type = self.type_space.get_type(&inner_type_id).unwrap();
563                Some(inner_type)
564            } else {
565                None
566            };
567
568        let prop_type = if let Some(inner_type) = maybe_inner_type {
569            inner_type
570        } else {
571            prop_type
572        };
573
574        let scalar = prop_type.has_impl(TypeSpaceImpl::FromStr);
575
576        let prop_name = name.to_kebab_case();
577        if scalar && !args.has_arg(&prop_name) {
578            let volitionality = if required {
579                Volitionality::RequiredIfNoBody
580            } else {
581                Volitionality::Optional
582            };
583            let parser = clap_arg(
584                &prop_name,
585                volitionality,
586                &description.map(str::to_string),
587                &prop_type,
588            );
589
590            let prop_fn = format_ident!("{}", sanitize(name, Case::Snake));
591            let prop_type_ident = prop_type.ident();
592            let consumer = quote! {
593                if let Some(value) =
594                    matches.get_one::<#prop_type_ident>(
595                        #prop_name,
596                    )
597                {
598                    // clone here in case the arg type
599                    // doesn't impl TryFrom<&T>
600                    request = request.body_map(|body| {
601                        body.#prop_fn(value.clone())
602                    })
603                }
604            };
605            args.add_arg(prop_name, CliArg { parser, consumer })
606        } else if required {
607            args.body_required()
608        }
609
610        // Cases
611        // 1. If the type can be represented as a string, great
612        //
613        // 2. If it's a substruct then we can try to glue the names together
614        // and hope?
615        //
616        // 3. enums
617        // 3.1 simple enums (should be covered by 1 above)
618        //   e.g. enum { A, B }
619        //   args for --a and --b that are in a group
620    }
621}
622
/// How strongly a generated CLI argument is required; `clap_arg` turns each
/// variant into the corresponding `clap::Arg` requirement call.
enum Volitionality {
    /// Never required (`.required(false)`).
    Optional,
    /// Always required (`.required(true)`).
    Required,
    /// Required unless `--json-body` is given
    /// (`.required_unless_present("json-body")`).
    RequiredIfNoBody,
}
628
629fn clap_arg(
630    arg_name: &str,
631    volitionality: Volitionality,
632    description: &Option<String>,
633    arg_type: &Type,
634) -> TokenStream {
635    let help = description.as_ref().map(|description| {
636        quote! {
637            .help(#description)
638        }
639    });
640    let arg_type_name = arg_type.ident();
641
642    // For enums that have **only** simple variants, we do some slightly
643    // fancier argument handling to expose the possible values. In particular,
644    // we use clap's `PossibleValuesParser` with each variant converted to a
645    // string. Then we use TypedValueParser::map to translate that into the
646    // actual type of the enum.
647    let maybe_enum_parser = if let typify::TypeDetails::Enum(e) = arg_type.details() {
648        let maybe_var_names = e
649            .variants()
650            .map(|(var_name, var_details)| {
651                if let TypeEnumVariant::Simple = var_details {
652                    Some(format_ident!("{}", var_name))
653                } else {
654                    None
655                }
656            })
657            .collect::<Option<Vec<_>>>();
658
659        maybe_var_names.map(|var_names| {
660            quote! {
661                ::clap::builder::TypedValueParser::map(
662                    ::clap::builder::PossibleValuesParser::new([
663                        #( #arg_type_name :: #var_names.to_string(), )*
664                    ]),
665                    |s| #arg_type_name :: try_from(s).unwrap()
666                )
667            }
668        })
669    } else {
670        None
671    };
672
673    let value_parser = if let Some(enum_parser) = maybe_enum_parser {
674        enum_parser
675    } else {
676        // Let clap pick a value parser for us. This has the benefit of
677        // allowing for override implementations. A generated client may
678        // implement ValueParserFactory for a type to create a custom parser.
679        quote! {
680            ::clap::value_parser!(#arg_type_name)
681        }
682    };
683
684    let required = match volitionality {
685        Volitionality::Optional => quote! { .required(false) },
686        Volitionality::Required => quote! { .required(true) },
687        Volitionality::RequiredIfNoBody => {
688            quote! { .required_unless_present("json-body") }
689        }
690    };
691
692    quote! {
693        ::clap::Arg::new(#arg_name)
694            .long(#arg_name)
695            .value_parser(#value_parser)
696            #required
697            #help
698    }
699}
700
/// Generated code for one CLI argument, or, as returned by
/// `Generator::cli_method_args`, for all of an operation's arguments combined.
#[derive(Debug)]
struct CliArg {
    /// Code to parse the argument: a `clap::Arg` expression for a single
    /// argument, or a chain of `.arg(..)` calls for a whole operation.
    parser: TokenStream,

    /// Code to consume the argument: statements that read the parsed value
    /// from `matches` and apply it to `request`.
    consumer: TokenStream,
}
709
/// Whether the generated `--json-body` argument is emitted, and whether it is
/// required.
#[derive(Debug, Default, PartialEq, Eq)]
enum CliBodyArg {
    /// No typed request body; no `--json-body` argument is generated.
    #[default]
    None,
    /// `--json-body` is mandatory because the body can't be fully built from
    /// individual arguments.
    Required,
    /// `--json-body` is generated but optional.
    Optional,
}
717
/// Accumulates one operation's arguments while `Generator::cli_method_args`
/// walks its parameters and request body.
#[derive(Default, Debug)]
struct CliOperationArgs {
    /// Arguments keyed by kebab-case name; iterated in sorted name order.
    args: BTreeMap<String, CliArg>,
    /// State of the `--json-body` argument.
    body: CliBodyArg,
}
723
724impl CliOperationArgs {
725    fn has_arg(&self, name: &String) -> bool {
726        self.args.contains_key(name)
727    }
728    fn add_arg(&mut self, name: String, arg: CliArg) {
729        self.args.insert(name, arg);
730    }
731
732    fn body_present(&mut self) {
733        assert_eq!(self.body, CliBodyArg::None);
734        self.body = CliBodyArg::Optional;
735    }
736
737    fn body_required(&mut self) {
738        assert!(self.body == CliBodyArg::Optional || self.body == CliBodyArg::Required);
739        self.body = CliBodyArg::Required;
740    }
741}