1#![deny(missing_docs)]
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8
9use openapiv3::OpenAPI;
10use proc_macro2::TokenStream;
11use quote::quote;
12use serde::Deserialize;
13use thiserror::Error;
14use typify::{TypeSpace, TypeSpaceSettings};
15
16use crate::to_schema::ToSchema;
17
18pub use typify::CrateVers;
19pub use typify::TypeSpaceImpl as TypeImpl;
20pub use typify::TypeSpacePatch as TypePatch;
21pub use typify::UnknownPolicy;
22
23mod cli;
24mod httpmock;
25mod method;
26mod template;
27mod to_schema;
28mod util;
29
/// Errors produced while validating an OpenAPI document or generating client
/// code from it.
///
/// The `Display` text of each variant is the `#[error(...)]` string attached
/// to it, with the payload substituted in.
#[allow(missing_docs)]
#[derive(Error, Debug)]
pub enum Error {
    /// A JSON value did not have the expected type. Carries a description and
    /// the offending value.
    #[error("unexpected value type {0}: {1}")]
    BadValue(String, serde_json::Value),
    /// An error reported by `typify`; converted automatically via `From`, so
    /// it can be propagated with `?`.
    #[error("type error {0}")]
    TypeError(#[from] typify::Error),
    /// The OpenAPI document uses a construct that is not handled, e.g. an
    /// unsupported spec version, a `$ref` path item, or a missing/duplicate
    /// operation ID (see `validate_openapi`).
    #[error("unexpected or unhandled format in the OpenAPI document {0}")]
    UnexpectedFormat(String),
    /// An operation path could not be used.
    #[error("invalid operation path {0}")]
    InvalidPath(String),
    /// A dropshot extension was used incorrectly in the document.
    #[error("invalid dropshot extension use: {0}")]
    InvalidExtension(String),
    /// An internal invariant was violated.
    #[error("internal error {0}")]
    InternalError(String),
}
46
/// Convenience alias: a `std::result::Result` whose error type is this
/// crate's [`Error`].
#[allow(missing_docs)]
pub type Result<T> = std::result::Result<T, Error>;
49
/// OpenAPI client generator.
///
/// Holds the type space and settings used by `generate_tokens` to turn an
/// OpenAPI document into a `TokenStream` containing a client.
pub struct Generator {
    // Types registered from the document's component schemas (and whatever
    // operation processing adds); rendered into the generated `types` module.
    type_space: TypeSpace,
    // Copy of the settings this generator was created with.
    settings: GenerationSettings,
    // Reported by `uses_futures()`; starts out `false`.
    // NOTE(review): presumably set while processing operations -- the code
    // that flips it is not in this part of the file.
    uses_futures: bool,
    // Reported by `uses_websockets()`; starts out `false`.
    // NOTE(review): presumably set while processing operations -- confirm.
    uses_websockets: bool,
}
57
/// Settings that customise how a [`Generator`] emits code.
///
/// Construct with `GenerationSettings::new()` (or `Default`) and adjust with
/// the chainable `with_*` methods.
#[derive(Default, Clone)]
pub struct GenerationSettings {
    // Interface style of the generated client (positional vs. builder).
    interface: InterfaceStyle,
    // How operations are grouped in the generated client.
    tag: TagStyle,
    // When set, the generated `Client` gets an `inner` field of this type and
    // its constructors take a matching `inner` argument.
    inner_type: Option<TokenStream>,
    // Hook token streams. They are only stored in this part of the file.
    // NOTE(review): presumably spliced into generated operation bodies by the
    // method generator -- confirm there.
    pre_hook: Option<TokenStream>,
    pre_hook_async: Option<TokenStream>,
    post_hook: Option<TokenStream>,
    post_hook_async: Option<TokenStream>,
    // Extra derive macros, forwarded to typify via `with_derive`.
    extra_derives: Vec<String>,
    // NOTE(review): not consumed in this part of the file; presumably extra
    // trait bounds for generated CLI code -- confirm in the `cli` module.
    extra_cli_bounds: Vec<String>,

    // Map type name forwarded to typify via `with_map_type` when set.
    map_type: Option<String>,
    // Policy forwarded to typify via `with_unknown_crates`.
    unknown_crates: UnknownPolicy,
    // Crates (name -> version/rename) forwarded to typify via `with_crate`.
    crates: BTreeMap<String, CrateSpec>,

    // Per-type patches, forwarded to typify via `with_patch`.
    patch: HashMap<String, TypePatch>,
    // Type replacements (type name -> (replacement name, impls)), forwarded
    // to typify via `with_replacement`.
    replace: HashMap<String, (String, Vec<TypeImpl>)>,
    // Schema conversions (schema, type name, impls), forwarded to typify via
    // `with_conversion`.
    convert: Vec<(schemars::schema::SchemaObject, String, Vec<TypeImpl>)>,
    // Timeout in seconds for the generated client's default `reqwest`
    // client; 15 is used when unset.
    timeout: Option<u64>,
}
80
/// Version and optional rename recorded for a crate registered through
/// `GenerationSettings::with_crate`.
#[derive(Debug, Clone)]
struct CrateSpec {
    // Version passed on to typify's `TypeSpaceSettings::with_crate`.
    version: CrateVers,
    // Optional rename passed on to typify's `TypeSpaceSettings::with_crate`.
    rename: Option<String>,
}
86
87#[derive(Clone, Deserialize, PartialEq, Eq)]
89pub enum InterfaceStyle {
90 Positional,
92 Builder,
94}
95
96impl Default for InterfaceStyle {
97 fn default() -> Self {
98 Self::Positional
99 }
100}
101
102#[derive(Clone, Deserialize)]
104pub enum TagStyle {
105 Merged,
107 Separate,
109}
110
111impl Default for TagStyle {
112 fn default() -> Self {
113 Self::Merged
114 }
115}
116
117impl GenerationSettings {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn with_interface(&mut self, interface: InterfaceStyle) -> &mut Self {
125 self.interface = interface;
126 self
127 }
128
129 pub fn with_tag(&mut self, tag: TagStyle) -> &mut Self {
131 self.tag = tag;
132 self
133 }
134
135 pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self {
137 self.inner_type = Some(inner_type);
138 self
139 }
140
141 pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self {
143 self.pre_hook = Some(pre_hook);
144 self
145 }
146
147 pub fn with_pre_hook_async(&mut self, pre_hook: TokenStream) -> &mut Self {
149 self.pre_hook_async = Some(pre_hook);
150 self
151 }
152
153 pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self {
155 self.post_hook = Some(post_hook);
156 self
157 }
158
159 pub fn with_post_hook_async(&mut self, post_hook: TokenStream) -> &mut Self {
161 self.post_hook_async = Some(post_hook);
162 self
163 }
164
165 pub fn with_derive(&mut self, derive: impl ToString) -> &mut Self {
167 self.extra_derives.push(derive.to_string());
168 self
169 }
170
171 pub fn with_cli_bounds(&mut self, derive: impl ToString) -> &mut Self {
173 self.extra_cli_bounds.push(derive.to_string());
174 self
175 }
176
177 pub fn with_patch<S: AsRef<str>>(&mut self, type_name: S, patch: &TypePatch) -> &mut Self {
180 self.patch
181 .insert(type_name.as_ref().to_string(), patch.clone());
182 self
183 }
184
185 pub fn with_replacement<TS: ToString, RS: ToString, I: Iterator<Item = TypeImpl>>(
188 &mut self,
189 type_name: TS,
190 replace_name: RS,
191 impls: I,
192 ) -> &mut Self {
193 self.replace.insert(
194 type_name.to_string(),
195 (replace_name.to_string(), impls.collect()),
196 );
197 self
198 }
199
200 pub fn with_conversion<S: ToString, I: Iterator<Item = TypeImpl>>(
203 &mut self,
204 schema: schemars::schema::SchemaObject,
205 type_name: S,
206 impls: I,
207 ) -> &mut Self {
208 self.convert
209 .push((schema, type_name.to_string(), impls.collect()));
210 self
211 }
212
213 pub fn with_unknown_crates(&mut self, policy: UnknownPolicy) -> &mut Self {
217 self.unknown_crates = policy;
218 self
219 }
220
221 pub fn with_crate<S1: ToString>(
226 &mut self,
227 crate_name: S1,
228 version: CrateVers,
229 rename: Option<&String>,
230 ) -> &mut Self {
231 self.crates.insert(
232 crate_name.to_string(),
233 CrateSpec {
234 version,
235 rename: rename.cloned(),
236 },
237 );
238 self
239 }
240
241 pub fn with_map_type<MT: ToString>(&mut self, map_type: MT) -> &mut Self {
249 self.map_type = Some(map_type.to_string());
250 self
251 }
252
253 pub fn with_timeout(&mut self, timeout: u64) -> &mut Self {
255 self.timeout = Some(timeout);
256 self
257 }
258}
259
260impl Default for Generator {
261 fn default() -> Self {
262 Self {
263 type_space: TypeSpace::new(TypeSpaceSettings::default().with_type_mod("types")),
264 settings: Default::default(),
265 uses_futures: Default::default(),
266 uses_websockets: Default::default(),
267 }
268 }
269}
270
271impl Generator {
272 pub fn new(settings: &GenerationSettings) -> Self {
274 let mut type_settings = TypeSpaceSettings::default();
275 type_settings
276 .with_type_mod("types")
277 .with_struct_builder(settings.interface == InterfaceStyle::Builder);
278 settings.extra_derives.iter().for_each(|derive| {
279 let _ = type_settings.with_derive(derive.clone());
280 });
281
282 type_settings.with_unknown_crates(settings.unknown_crates);
284 settings
285 .crates
286 .iter()
287 .for_each(|(crate_name, CrateSpec { version, rename })| {
288 type_settings.with_crate(crate_name, version.clone(), rename.as_ref());
289 });
290
291 settings.patch.iter().for_each(|(type_name, patch)| {
293 type_settings.with_patch(type_name, patch);
294 });
295 settings
296 .replace
297 .iter()
298 .for_each(|(type_name, (replace_name, impls))| {
299 type_settings.with_replacement(type_name, replace_name, impls.iter().cloned());
300 });
301 settings
302 .convert
303 .iter()
304 .for_each(|(schema, type_name, impls)| {
305 type_settings.with_conversion(schema.clone(), type_name, impls.iter().cloned());
306 });
307
308 if let Some(map_type) = &settings.map_type {
310 type_settings.with_map_type(map_type.clone());
311 }
312
313 Self {
314 type_space: TypeSpace::new(&type_settings),
315 settings: settings.clone(),
316 uses_futures: false,
317 uses_websockets: false,
318 }
319 }
320
321 pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result<TokenStream> {
323 validate_openapi(spec)?;
324
325 let schemas = spec.components.iter().flat_map(|components| {
327 components
328 .schemas
329 .iter()
330 .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
331 });
332
333 self.type_space.add_ref_types(schemas)?;
334
335 let raw_methods = spec
336 .paths
337 .iter()
338 .flat_map(|(path, ref_or_item)| {
339 let item = ref_or_item.as_item().unwrap();
341 item.iter().map(move |(method, operation)| {
342 (path.as_str(), method, operation, &item.parameters)
343 })
344 })
345 .map(|(path, method, operation, path_parameters)| {
346 self.process_operation(operation, &spec.components, path, method, path_parameters)
347 })
348 .collect::<Result<Vec<_>>>()?;
349
350 let operation_code = match (&self.settings.interface, &self.settings.tag) {
351 (InterfaceStyle::Positional, TagStyle::Merged) => self
352 .generate_tokens_positional_merged(
353 &raw_methods,
354 self.settings.inner_type.is_some(),
355 ),
356 (InterfaceStyle::Positional, TagStyle::Separate) => {
357 unimplemented!("positional arguments with separate tags are currently unsupported")
358 }
359 (InterfaceStyle::Builder, TagStyle::Merged) => self
360 .generate_tokens_builder_merged(&raw_methods, self.settings.inner_type.is_some()),
361 (InterfaceStyle::Builder, TagStyle::Separate) => {
362 let tag_info = spec
363 .tags
364 .iter()
365 .map(|tag| (&tag.name, tag))
366 .collect::<BTreeMap<_, _>>();
367 self.generate_tokens_builder_separate(
368 &raw_methods,
369 tag_info,
370 self.settings.inner_type.is_some(),
371 )
372 }
373 }?;
374
375 let types = self.type_space.to_stream();
376
377 let (inner_type, inner_fn_value) = match self.settings.inner_type.as_ref() {
378 Some(inner_type) => (inner_type.clone(), quote! { &self.inner }),
379 None => (quote! { () }, quote! { &() }),
380 };
381
382 let inner_property = self.settings.inner_type.as_ref().map(|inner| {
383 quote! {
384 pub (crate) inner: #inner,
385 }
386 });
387 let inner_parameter = self.settings.inner_type.as_ref().map(|inner| {
388 quote! {
389 inner: #inner,
390 }
391 });
392 let inner_value = self.settings.inner_type.as_ref().map(|_| {
393 quote! {
394 inner
395 }
396 });
397 let client_timeout = self.settings.timeout.unwrap_or(15);
398
399 let client_docstring = {
400 let mut s = format!("Client for {}", spec.info.title);
401
402 if let Some(ss) = &spec.info.description {
403 s.push_str("\n\n");
404 s.push_str(ss);
405 }
406 if let Some(ss) = &spec.info.terms_of_service {
407 s.push_str("\n\n");
408 s.push_str(ss);
409 }
410
411 s.push_str(&format!("\n\nVersion: {}", &spec.info.version));
412
413 s
414 };
415
416 let version_str = &spec.info.version;
417
418 let file = quote! {
423 #[allow(unused_imports)]
425 pub use progenitor_client::{
426 ByteStream,
427 ClientInfo,
428 Error,
429 ResponseValue,
430 };
431 #[allow(unused_imports)]
432 use progenitor_client::{
433 encode_path,
434 ClientHooks,
435 OperationInfo,
436 RequestBuilderExt,
437 };
438
439 #[allow(clippy::all)]
441 pub mod types {
442 #types
443 }
444
445 #[derive(Clone, Debug)]
446 #[doc = #client_docstring]
447 pub struct Client {
448 pub(crate) baseurl: String,
449 pub(crate) client: reqwest::Client,
450 #inner_property
451 }
452
453 impl Client {
454 pub fn new(
460 baseurl: &str,
461 #inner_parameter
462 ) -> Self {
463 #[cfg(not(target_arch = "wasm32"))]
464 let client = {
465 let dur = ::std::time::Duration::from_secs(#client_timeout);
466
467 reqwest::ClientBuilder::new()
468 .connect_timeout(dur)
469 .timeout(dur)
470 };
471
472 #[cfg(target_arch = "wasm32")]
473 let client = reqwest::ClientBuilder::new();
474
475 Self::new_with_client(baseurl, client.build().unwrap(), #inner_value)
476 }
477
478 pub fn new_with_client(
485 baseurl: &str,
486 client: reqwest::Client,
487 #inner_parameter
488 ) -> Self {
489 Self {
490 baseurl: baseurl.to_string(),
491 client,
492 #inner_value
493 }
494 }
495 }
496
497 impl ClientInfo<#inner_type> for Client {
498 fn api_version() -> &'static str {
499 #version_str
500 }
501
502 fn baseurl(&self) -> &str {
503 self.baseurl.as_str()
504 }
505
506 fn client(&self) -> &reqwest::Client {
507 &self.client
508 }
509
510 fn inner(&self) -> &#inner_type {
511 #inner_fn_value
512 }
513 }
514
515 impl ClientHooks<#inner_type> for &Client {}
516
517 #operation_code
518 };
519
520 Ok(file)
521 }
522
523 fn generate_tokens_positional_merged(
524 &mut self,
525 input_methods: &[method::OperationMethod],
526 has_inner: bool,
527 ) -> Result<TokenStream> {
528 let methods = input_methods
529 .iter()
530 .map(|method| self.positional_method(method, has_inner))
531 .collect::<Result<Vec<_>>>()?;
532
533 let out = quote! {
538 #[allow(clippy::all)]
539 impl Client {
540 #(#methods)*
541 }
542
543 pub mod prelude {
545 #[allow(unused_imports)]
546 pub use super::Client;
547 }
548 };
549 Ok(out)
550 }
551
552 fn generate_tokens_builder_merged(
553 &mut self,
554 input_methods: &[method::OperationMethod],
555 has_inner: bool,
556 ) -> Result<TokenStream> {
557 let builder_struct = input_methods
558 .iter()
559 .map(|method| self.builder_struct(method, TagStyle::Merged, has_inner))
560 .collect::<Result<Vec<_>>>()?;
561
562 let builder_methods = input_methods
563 .iter()
564 .map(|method| self.builder_impl(method))
565 .collect::<Vec<_>>();
566
567 let out = quote! {
568 impl Client {
569 #(#builder_methods)*
570 }
571
572 #[allow(clippy::all)]
574 pub mod builder {
575 use super::types;
576 #[allow(unused_imports)]
577 use super::{
578 encode_path,
579 ByteStream,
580 ClientInfo,
581 ClientHooks,
582 Error,
583 OperationInfo,
584 RequestBuilderExt,
585 ResponseValue,
586 };
587
588 #(#builder_struct)*
589 }
590
591 pub mod prelude {
593 pub use self::super::Client;
594 }
595 };
596
597 Ok(out)
598 }
599
600 fn generate_tokens_builder_separate(
601 &mut self,
602 input_methods: &[method::OperationMethod],
603 tag_info: BTreeMap<&String, &openapiv3::Tag>,
604 has_inner: bool,
605 ) -> Result<TokenStream> {
606 let builder_struct = input_methods
607 .iter()
608 .map(|method| self.builder_struct(method, TagStyle::Separate, has_inner))
609 .collect::<Result<Vec<_>>>()?;
610
611 let (traits_and_impls, trait_preludes) = self.builder_tags(input_methods, &tag_info);
612
613 let out = quote! {
618 #traits_and_impls
619
620 #[allow(clippy::all)]
622 pub mod builder {
623 use super::types;
624 #[allow(unused_imports)]
625 use super::{
626 encode_path,
627 ByteStream,
628 ClientInfo,
629 ClientHooks,
630 Error,
631 OperationInfo,
632 RequestBuilderExt,
633 ResponseValue,
634 };
635
636 #(#builder_struct)*
637 }
638
639 pub mod prelude {
642 #[allow(unused_imports)]
643 pub use super::Client;
644 #trait_preludes
645 }
646 };
647
648 Ok(out)
649 }
650
    /// Borrow the [`TypeSpace`] that holds the types registered by this
    /// generator.
    pub fn get_type_space(&self) -> &TypeSpace {
        &self.type_space
    }
655
    /// Return the generator's `uses_futures` flag, which is `false` on a
    /// freshly constructed generator.
    // NOTE(review): presumably set during generation when the emitted code
    // needs futures support -- confirm where operations are processed.
    pub fn uses_futures(&self) -> bool {
        self.uses_futures
    }
661
    /// Return the generator's `uses_websockets` flag, which is `false` on a
    /// freshly constructed generator.
    // NOTE(review): presumably set during generation when the emitted code
    // needs websocket support -- confirm where operations are processed.
    pub fn uses_websockets(&self) -> bool {
        self.uses_websockets
    }
667}
668
669pub fn space_out_items(content: String) -> Result<String> {
671 Ok(if cfg!(not(windows)) {
672 let regex = regex::Regex::new(r#"(\n\s*})(\n\s{0,8}[^} ])"#).unwrap();
673 regex.replace_all(&content, "$1\n$2").to_string()
674 } else {
675 let regex = regex::Regex::new(r#"(\n\s*})(\r\n\s{0,8}[^} ])"#).unwrap();
676 regex.replace_all(&content, "$1\r\n$2").to_string()
677 })
678}
679
680fn validate_openapi_spec_version(spec_version: &str) -> Result<()> {
681 if spec_version.trim().starts_with("3.0.") {
683 Ok(())
684 } else {
685 Err(Error::UnexpectedFormat(format!(
686 "invalid version: {}",
687 spec_version
688 )))
689 }
690}
691
692pub fn validate_openapi(spec: &OpenAPI) -> Result<()> {
694 validate_openapi_spec_version(spec.openapi.as_str())?;
695
696 let mut opids = HashSet::new();
697 spec.paths.paths.iter().try_for_each(|p| {
698 match p.1 {
699 openapiv3::ReferenceOr::Reference { reference: _ } => Err(Error::UnexpectedFormat(
700 format!("path {} uses reference, unsupported", p.0,),
701 )),
702 openapiv3::ReferenceOr::Item(item) => {
703 item.iter().try_for_each(|(_, o)| {
706 if let Some(oid) = o.operation_id.as_ref() {
707 if !opids.insert(oid.to_string()) {
708 return Err(Error::UnexpectedFormat(format!(
709 "duplicate operation ID: {}",
710 oid,
711 )));
712 }
713 } else {
714 return Err(Error::UnexpectedFormat(format!(
715 "path {} is missing operation ID",
716 p.0,
717 )));
718 }
719 Ok(())
720 })
721 }
722 }
723 })?;
724
725 Ok(())
726}
727
#[cfg(test)]
mod tests {
    //! Unit tests for the `Display` text of [`Error`] and for OpenAPI
    //! version validation.

    use serde_json::json;

    use crate::{validate_openapi_spec_version, Error};

    #[test]
    fn test_bad_value() {
        assert_eq!(
            Error::BadValue("nope".to_string(), json! { "nope"},).to_string(),
            "unexpected value type nope: \"nope\"",
        );
    }

    // Previously named `test_type_error`, although it exercises
    // `Error::UnexpectedFormat`.
    #[test]
    fn test_unexpected_format() {
        assert_eq!(
            Error::UnexpectedFormat("nope".to_string()).to_string(),
            "unexpected or unhandled format in the OpenAPI document nope",
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_eq!(
            Error::InvalidPath("nope".to_string()).to_string(),
            "invalid operation path nope",
        );
    }

    #[test]
    fn test_invalid_extension() {
        assert_eq!(
            Error::InvalidExtension("nope".to_string()).to_string(),
            "invalid dropshot extension use: nope",
        );
    }

    #[test]
    fn test_internal_error() {
        assert_eq!(
            Error::InternalError("nope".to_string()).to_string(),
            "internal error nope",
        );
    }

    #[test]
    fn test_validate_openapi_spec_version() {
        assert!(validate_openapi_spec_version("3.0.0").is_ok());
        assert!(validate_openapi_spec_version("3.0.1").is_ok());
        assert!(validate_openapi_spec_version("3.0.4").is_ok());
        assert!(validate_openapi_spec_version("3.0.5-draft").is_ok());
        // Surrounding whitespace is trimmed before the prefix check.
        assert!(validate_openapi_spec_version(" 3.0.2\n").is_ok());
        assert_eq!(
            validate_openapi_spec_version("3.1.0")
                .unwrap_err()
                .to_string(),
            "unexpected or unhandled format in the OpenAPI document invalid version: 3.1.0"
        );
        // A version without the `3.0.` prefix is rejected.
        assert!(validate_openapi_spec_version("3.0").is_err());
        assert!(validate_openapi_spec_version("2.0").is_err());
        assert!(validate_openapi_spec_version("").is_err());
    }
}