progenitor_middleware_impl/
cli.rs1use std::collections::BTreeMap;
4
5use heck::ToKebabCase;
6use openapiv3::OpenAPI;
7use proc_macro2::TokenStream;
8use quote::{format_ident, quote};
9use typify::{Type, TypeEnumVariant, TypeSpaceImpl, TypeStructPropInfo};
10
11use crate::{
12 method::{OperationParameterKind, OperationParameterType, OperationResponseStatus},
13 to_schema::ToSchema,
14 util::{sanitize, Case},
15 validate_openapi, Generator, Result,
16};
17
18struct CliOperation {
19 cli_fn: TokenStream,
20 execute_fn: TokenStream,
21 execute_trait: TokenStream,
22}
23
24impl Generator {
25 pub fn cli(&mut self, spec: &OpenAPI, crate_name: &str) -> Result<TokenStream> {
27 validate_openapi(spec)?;
28
29 let schemas = spec.components.iter().flat_map(|components| {
31 components
32 .schemas
33 .iter()
34 .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
35 });
36
37 self.type_space.add_ref_types(schemas)?;
38
39 let raw_methods = spec
40 .paths
41 .iter()
42 .flat_map(|(path, ref_or_item)| {
43 let item = ref_or_item.as_item().unwrap();
45 item.iter().map(move |(method, operation)| {
46 (path.as_str(), method, operation, &item.parameters)
47 })
48 })
49 .map(|(path, method, operation, path_parameters)| {
50 self.process_operation(operation, &spec.components, path, method, path_parameters)
51 })
52 .collect::<Result<Vec<_>>>()?;
53
54 let methods = raw_methods
55 .iter()
56 .map(|method| self.cli_method(method))
57 .collect::<Vec<_>>();
58
59 let cli_ops = methods.iter().map(|op| &op.cli_fn);
60 let execute_ops = methods.iter().map(|op| &op.execute_fn);
61 let trait_ops = methods.iter().map(|op| &op.execute_trait);
62
63 let cli_fns = raw_methods
64 .iter()
65 .map(|method| format_ident!("cli_{}", sanitize(&method.operation_id, Case::Snake)))
66 .collect::<Vec<_>>();
67 let execute_fns = raw_methods
68 .iter()
69 .map(|method| format_ident!("execute_{}", sanitize(&method.operation_id, Case::Snake)))
70 .collect::<Vec<_>>();
71
72 let cli_variants = raw_methods
73 .iter()
74 .map(|method| format_ident!("{}", sanitize(&method.operation_id, Case::Pascal)))
75 .collect::<Vec<_>>();
76
77 let crate_path = syn::TypePath {
78 qself: None,
79 path: syn::parse_str(crate_name).unwrap(),
80 };
81
82 let code = quote! {
83 use #crate_path::*;
84
85 pub struct Cli<T: CliConfig> {
86 client: Client,
87 config: T,
88 }
89 impl<T: CliConfig> Cli<T> {
90 pub fn new(
91 client: Client,
92 config: T,
93 ) -> Self {
94 Self { client, config }
95 }
96
97 pub fn get_command(cmd: CliCommand) -> ::clap::Command {
98 match cmd {
99 #(
100 CliCommand::#cli_variants => Self::#cli_fns(),
101 )*
102 }
103 }
104
105 #(#cli_ops)*
106
107 pub async fn execute(
108 &self,
109 cmd: CliCommand,
110 matches: &::clap::ArgMatches,
111 ) -> anyhow::Result<()> {
112 match cmd {
113 #(
114 CliCommand::#cli_variants => {
115 self.#execute_fns(matches).await
117 }
118 )*
119 }
120 }
121
122 #(#execute_ops)*
123 }
124
125 pub trait CliConfig {
126 fn success_item<T>(&self, value: &ResponseValue<T>)
127 where
128 T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
129 fn success_no_item(&self, value: &ResponseValue<()>);
130 fn error<T>(&self, value: &Error<T>)
131 where
132 T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
133
134 fn list_start<T>(&self)
135 where
136 T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
137 fn list_item<T>(&self, value: &T)
138 where
139 T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
140 fn list_end_success<T>(&self)
141 where
142 T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
143 fn list_end_error<T>(&self, value: &Error<T>)
144 where
145 T: schemars::JsonSchema + serde::Serialize + std::fmt::Debug;
146
147 #(#trait_ops)*
148 }
149
150 #[derive(Copy, Clone, Debug)]
151 pub enum CliCommand {
152 #(#cli_variants,)*
153 }
154
155 impl CliCommand {
156 pub fn iter() -> impl Iterator<Item = CliCommand> {
157 vec![
158 #(
159 CliCommand::#cli_variants,
160 )*
161 ].into_iter()
162 }
163 }
164
165 };
166
167 Ok(code)
168 }
169
170 fn cli_method(&mut self, method: &crate::method::OperationMethod) -> CliOperation {
171 let CliArg {
172 parser: parser_args,
173 consumer: consumer_args,
174 } = self.cli_method_args(method);
175
176 let about = method.summary.as_ref().map(|summary| {
177 quote! {
178 .about(#summary)
179 }
180 });
181
182 let long_about = method.description.as_ref().map(|description| {
183 quote! {
184 .long_about(#description)
185 }
186 });
187
188 let fn_name = format_ident!("cli_{}", &method.operation_id);
189
190 let cli_fn = quote! {
191 pub fn #fn_name() -> ::clap::Command
192 {
193 ::clap::Command::new("")
194 #parser_args
195 #about
196 #long_about
197 }
198 };
199
200 let fn_name = format_ident!("execute_{}", &method.operation_id);
201 let op_name = format_ident!("{}", &method.operation_id);
202
203 let (_, success_response_type) =
204 self.extract_responses(method, OperationResponseStatus::is_success_or_default);
205 let (_, error_response_type) =
206 self.extract_responses(method, OperationResponseStatus::is_error_or_default);
207
208 let success_kind = match success_response_type {
211 crate::method::ErrorResponseType::Single(kind) => kind,
212 crate::method::ErrorResponseType::Multiple { .. } => {
213 panic!("CLI generation does not support operations with multiple success types");
214 }
215 };
216 let error_kind = match error_response_type {
217 crate::method::ErrorResponseType::Single(kind) => kind,
218 crate::method::ErrorResponseType::Multiple { .. } => {
219 panic!("CLI generation does not support operations with multiple error types");
220 }
221 };
222
223 let execute_and_output = match method.dropshot_paginated {
224 None => {
226 let success_output = match success_kind {
227 crate::method::OperationResponseKind::Type(_) => {
228 quote! {
229 {
230 self.config.success_item(&r);
231 Ok(())
232 }
233 }
234 }
235 crate::method::OperationResponseKind::None => {
236 quote! {
237 {
238 self.config.success_no_item(&r);
239 Ok(())
240 }
241 }
242 }
243 crate::method::OperationResponseKind::Raw
244 | crate::method::OperationResponseKind::Upgrade => {
245 quote! {
246 {
247 todo!()
248 }
249 }
250 }
251 };
252
253 let error_output = match error_kind {
254 crate::method::OperationResponseKind::Type(_)
255 | crate::method::OperationResponseKind::None => {
256 quote! {
257 {
258 self.config.error(&r);
259 Err(anyhow::Error::new(r))
260 }
261 }
262 }
263 crate::method::OperationResponseKind::Raw
264 | crate::method::OperationResponseKind::Upgrade => {
265 quote! {
266 {
267 todo!()
268 }
269 }
270 }
271 };
272
273 quote! {
274 let result = request.send().await;
275
276 match result {
277 Ok(r) => #success_output
278 Err(r) => #error_output
279 }
280 }
281 }
282
283 Some(_) => {
285 let success_type = match success_kind {
286 crate::method::OperationResponseKind::Type(type_id) => {
287 self.type_space.get_type(&type_id).unwrap().ident()
288 }
289 crate::method::OperationResponseKind::None => quote! { () },
290 crate::method::OperationResponseKind::Raw => todo!(),
291 crate::method::OperationResponseKind::Upgrade => todo!(),
292 };
293 let error_output = match error_kind {
294 crate::method::OperationResponseKind::Type(_)
295 | crate::method::OperationResponseKind::None => {
296 quote! {
297 {
298 self.config.list_end_error(&r);
299 return Err(anyhow::Error::new(r))
300 }
301 }
302 }
303 crate::method::OperationResponseKind::Raw
304 | crate::method::OperationResponseKind::Upgrade => {
305 quote! {
306 {
307 todo!()
308 }
309 }
310 }
311 };
312 quote! {
313 self.config.list_start::<#success_type>();
314
315 let mut stream = futures::StreamExt::take(
320 request.stream(),
321 matches
322 .get_one::<std::num::NonZeroU32>("limit")
323 .map_or(usize::MAX, |x| x.get() as usize));
324
325 loop {
326 match futures::TryStreamExt::try_next(&mut stream).await {
327 Err(r) => #error_output
328 Ok(None) => {
329 self.config.list_end_success::<#success_type>();
330 return Ok(());
331 }
332 Ok(Some(value)) => {
333 self.config.list_item(&value);
334 }
335 }
336 }
337 }
338 }
339 };
340
341 let execute_fn = quote! {
342 pub async fn #fn_name(&self, matches: &::clap::ArgMatches)
343 -> anyhow::Result<()>
344 {
345 let mut request = self.client.#op_name();
346 #consumer_args
347
348 self.config.#fn_name(matches, &mut request)?;
350
351 #execute_and_output
352 }
353 };
354
355 let struct_name = sanitize(&method.operation_id, Case::Pascal);
357 let struct_ident = format_ident!("{}", struct_name);
358
359 let execute_trait = quote! {
360 fn #fn_name(
361 &self,
362 matches: &::clap::ArgMatches,
363 request: &mut builder :: #struct_ident,
364 ) -> anyhow::Result<()> {
365 Ok(())
366 }
367 };
368
369 CliOperation {
370 cli_fn,
371 execute_fn,
372 execute_trait,
373 }
374 }
375
376 fn cli_method_args(&self, method: &crate::method::OperationMethod) -> CliArg {
377 let mut args = CliOperationArgs::default();
378
379 let first_page_required_set = method
380 .dropshot_paginated
381 .as_ref()
382 .map(|d| &d.first_page_params);
383
384 for param in &method.params {
385 let innately_required = match ¶m.kind {
386 OperationParameterKind::Body(_) => continue,
388
389 OperationParameterKind::Path => true,
390 OperationParameterKind::Query(required) => *required,
391 OperationParameterKind::Header(required) => *required,
392 };
393
394 if method.dropshot_paginated.is_some() && param.name.as_str() == "page_token" {
396 continue;
397 }
398
399 let first_page_required = first_page_required_set
400 .map_or(false, |required| required.contains(¶m.api_name));
401
402 let volitionality = if innately_required || first_page_required {
403 Volitionality::Required
404 } else {
405 Volitionality::Optional
406 };
407
408 let OperationParameterType::Type(arg_type_id) = ¶m.typ else {
409 unreachable!("query and path parameters must be typed")
410 };
411 let arg_type = self.type_space.get_type(arg_type_id).unwrap();
412
413 let arg_name = param.name.to_kebab_case();
414
415 assert!(!args.has_arg(&arg_name));
417
418 let parser = clap_arg(&arg_name, volitionality, ¶m.description, &arg_type);
419
420 let arg_fn_name = sanitize(¶m.name, Case::Snake);
421 let arg_fn = format_ident!("{}", arg_fn_name);
422 let OperationParameterType::Type(arg_type_id) = ¶m.typ else {
423 panic!()
424 };
425 let arg_type = self.type_space.get_type(arg_type_id).unwrap();
426 let arg_type_name = arg_type.ident();
427
428 let consumer = quote! {
429 if let Some(value) =
430 matches.get_one::<#arg_type_name>(#arg_name)
431 {
432 request = request.#arg_fn(value.clone());
435 }
436 };
437
438 args.add_arg(arg_name, CliArg { parser, consumer })
439 }
440
441 let maybe_body_type_id = method
442 .params
443 .iter()
444 .find(|param| matches!(¶m.kind, OperationParameterKind::Body(_)))
445 .and_then(|param| match ¶m.typ {
446 OperationParameterType::RawBody => None,
450
451 OperationParameterType::Type(body_type_id) => Some(body_type_id),
452 });
453
454 if let Some(body_type_id) = maybe_body_type_id {
455 args.body_present();
456 let body_type = self.type_space.get_type(body_type_id).unwrap();
457 let details = body_type.details();
458
459 match details {
460 typify::TypeDetails::Struct(struct_info) => {
461 for prop_info in struct_info.properties_info() {
462 self.cli_method_body_arg(&mut args, prop_info)
463 }
464 }
465
466 _ => {
467 args.body_required()
470 }
471 }
472 }
473
474 let parser_args = args.args.values().map(|CliArg { parser, .. }| parser);
475
476 let body_json_args = (match args.body {
478 CliBodyArg::None => None,
479 CliBodyArg::Required => Some(true),
480 CliBodyArg::Optional => Some(false),
481 })
482 .map(|required| {
483 let help = "Path to a file that contains the full json body.";
484
485 quote! {
486 .arg(
487 ::clap::Arg::new("json-body")
488 .long("json-body")
489 .value_name("JSON-FILE")
490 .required(#required)
493 .value_parser(::clap::value_parser!(std::path::PathBuf))
494 .help(#help)
495 )
496 .arg(
497 ::clap::Arg::new("json-body-template")
498 .long("json-body-template")
499 .action(::clap::ArgAction::SetTrue)
500 .help("XXX")
501 )
502 }
503 });
504
505 let parser = quote! {
506 #(
507 .arg(#parser_args)
508 )*
509 #body_json_args
510 };
511
512 let consumer_args = args.args.values().map(|CliArg { consumer, .. }| consumer);
513
514 let body_json_consumer = maybe_body_type_id.map(|body_type_id| {
515 let body_type = self.type_space.get_type(body_type_id).unwrap();
516 let body_type_ident = body_type.ident();
517 quote! {
518 if let Some(value) =
519 matches.get_one::<std::path::PathBuf>("json-body")
520 {
521 let body_txt = std::fs::read_to_string(value).unwrap();
522 let body_value =
523 serde_json::from_str::<#body_type_ident>(
524 &body_txt,
525 )
526 .unwrap();
527 request = request.body(body_value);
528 }
529 }
530 });
531
532 let consumer = quote! {
533 #(
534 #consumer_args
535 )*
536 #body_json_consumer
537 };
538
539 CliArg { parser, consumer }
540 }
541
542 fn cli_method_body_arg(&self, args: &mut CliOperationArgs, prop_info: TypeStructPropInfo<'_>) {
543 let TypeStructPropInfo {
544 name,
545 description,
546 required,
547 type_id,
548 } = prop_info;
549
550 let prop_type = self.type_space.get_type(&type_id).unwrap();
551
552 let maybe_inner_type =
561 if let typify::TypeDetails::Option(inner_type_id) = prop_type.details() {
562 let inner_type = self.type_space.get_type(&inner_type_id).unwrap();
563 Some(inner_type)
564 } else {
565 None
566 };
567
568 let prop_type = if let Some(inner_type) = maybe_inner_type {
569 inner_type
570 } else {
571 prop_type
572 };
573
574 let scalar = prop_type.has_impl(TypeSpaceImpl::FromStr);
575
576 let prop_name = name.to_kebab_case();
577 if scalar && !args.has_arg(&prop_name) {
578 let volitionality = if required {
579 Volitionality::RequiredIfNoBody
580 } else {
581 Volitionality::Optional
582 };
583 let parser = clap_arg(
584 &prop_name,
585 volitionality,
586 &description.map(str::to_string),
587 &prop_type,
588 );
589
590 let prop_fn = format_ident!("{}", sanitize(name, Case::Snake));
591 let prop_type_ident = prop_type.ident();
592 let consumer = quote! {
593 if let Some(value) =
594 matches.get_one::<#prop_type_ident>(
595 #prop_name,
596 )
597 {
598 request = request.body_map(|body| {
601 body.#prop_fn(value.clone())
602 })
603 }
604 };
605 args.add_arg(prop_name, CliArg { parser, consumer })
606 } else if required {
607 args.body_required()
608 }
609
610 }
621}
622
/// How strongly a generated CLI argument is required.
enum Volitionality {
    /// The argument may always be omitted (`.required(false)`).
    Optional,
    /// The argument must always be supplied (`.required(true)`).
    Required,
    /// The argument must be supplied unless `--json-body` is present.
    RequiredIfNoBody,
}
628
629fn clap_arg(
630 arg_name: &str,
631 volitionality: Volitionality,
632 description: &Option<String>,
633 arg_type: &Type,
634) -> TokenStream {
635 let help = description.as_ref().map(|description| {
636 quote! {
637 .help(#description)
638 }
639 });
640 let arg_type_name = arg_type.ident();
641
642 let maybe_enum_parser = if let typify::TypeDetails::Enum(e) = arg_type.details() {
648 let maybe_var_names = e
649 .variants()
650 .map(|(var_name, var_details)| {
651 if let TypeEnumVariant::Simple = var_details {
652 Some(format_ident!("{}", var_name))
653 } else {
654 None
655 }
656 })
657 .collect::<Option<Vec<_>>>();
658
659 maybe_var_names.map(|var_names| {
660 quote! {
661 ::clap::builder::TypedValueParser::map(
662 ::clap::builder::PossibleValuesParser::new([
663 #( #arg_type_name :: #var_names.to_string(), )*
664 ]),
665 |s| #arg_type_name :: try_from(s).unwrap()
666 )
667 }
668 })
669 } else {
670 None
671 };
672
673 let value_parser = if let Some(enum_parser) = maybe_enum_parser {
674 enum_parser
675 } else {
676 quote! {
680 ::clap::value_parser!(#arg_type_name)
681 }
682 };
683
684 let required = match volitionality {
685 Volitionality::Optional => quote! { .required(false) },
686 Volitionality::Required => quote! { .required(true) },
687 Volitionality::RequiredIfNoBody => {
688 quote! { .required_unless_present("json-body") }
689 }
690 };
691
692 quote! {
693 ::clap::Arg::new(#arg_name)
694 .long(#arg_name)
695 .value_parser(#value_parser)
696 #required
697 #help
698 }
699}
700
701#[derive(Debug)]
702struct CliArg {
703 parser: TokenStream,
705
706 consumer: TokenStream,
708}
709
/// Whether an operation gets a `--json-body` argument and whether it is
/// mandatory.
#[derive(Debug, Default, PartialEq, Eq)]
enum CliBodyArg {
    /// No typed body, so no `--json-body` argument is generated.
    #[default]
    None,
    /// `--json-body` must be supplied: the body is not a struct, or a required
    /// field has no flag of its own.
    Required,
    /// A typed body exists, but `--json-body` may be omitted.
    Optional,
}
717
718#[derive(Default, Debug)]
719struct CliOperationArgs {
720 args: BTreeMap<String, CliArg>,
721 body: CliBodyArg,
722}
723
724impl CliOperationArgs {
725 fn has_arg(&self, name: &String) -> bool {
726 self.args.contains_key(name)
727 }
728 fn add_arg(&mut self, name: String, arg: CliArg) {
729 self.args.insert(name, arg);
730 }
731
732 fn body_present(&mut self) {
733 assert_eq!(self.body, CliBodyArg::None);
734 self.body = CliBodyArg::Optional;
735 }
736
737 fn body_required(&mut self) {
738 assert!(self.body == CliBodyArg::Optional || self.body == CliBodyArg::Required);
739 self.body = CliBodyArg::Required;
740 }
741}