1#![deny(missing_docs)]
6
7use std::collections::{BTreeMap, HashMap, HashSet};
8
9use openapiv3::OpenAPI;
10use proc_macro2::TokenStream;
11use quote::quote;
12use serde::Deserialize;
13use thiserror::Error;
14use typify::{TypeSpace, TypeSpaceSettings};
15
16use crate::to_schema::ToSchema;
17
18pub use typify::CrateVers;
19pub use typify::TypeSpaceImpl as TypeImpl;
20pub use typify::TypeSpacePatch as TypePatch;
21pub use typify::UnknownPolicy;
22
23mod cli;
24mod httpmock;
25mod method;
26mod template;
27mod to_schema;
28mod util;
29
/// Errors produced while validating an OpenAPI document or generating code from it.
///
/// The `Display` text of every variant is defined by its `#[error(...)]` attribute.
#[allow(missing_docs)]
#[derive(Error, Debug)]
pub enum Error {
    /// A JSON value was not of the expected type; carries a label and the offending value.
    #[error("unexpected value type {0}: {1}")]
    BadValue(String, serde_json::Value),
    /// An error reported by `typify`; converted automatically through `#[from]`.
    #[error("type error {0}")]
    TypeError(#[from] typify::Error),
    /// The OpenAPI document contains something this crate does not accept or handle.
    #[error("unexpected or unhandled format in the OpenAPI document {0}")]
    UnexpectedFormat(String),
    /// An operation path could not be used.
    #[error("invalid operation path {0}")]
    InvalidPath(String),
    /// A dropshot extension was used incorrectly.
    #[error("invalid dropshot extension use: {0}")]
    InvalidExtension(String),
    /// A failure inside the generator itself.
    #[error("internal error {0}")]
    InternalError(String),
}
46
/// Result alias used throughout this crate; the error type is always [`Error`].
#[allow(missing_docs)]
pub type Result<T> = std::result::Result<T, Error>;
49
/// Generator that turns an OpenAPI document into a token stream of Rust client code.
///
/// Construct it with [`Generator::new`] (or `Default`) and call
/// [`Generator::generate_tokens`].
pub struct Generator {
    /// typify type space into which the document's schemas are registered.
    type_space: TypeSpace,
    /// Copy of the settings this generator was created with.
    settings: GenerationSettings,
    /// Flag reported by [`Generator::uses_futures`]; starts out `false`.
    uses_futures: bool,
    /// Flag reported by [`Generator::uses_websockets`]; starts out `false`.
    uses_websockets: bool,
}
57
/// Settings that control how a [`Generator`] produces code.
///
/// Start from [`GenerationSettings::new`], adjust with the chained `with_*` setters, then
/// hand a reference to [`Generator::new`].
#[derive(Default, Clone)]
pub struct GenerationSettings {
    /// Interface style of the generated client; `Builder` also enables typify struct builders.
    interface: InterfaceStyle,
    /// How operations are grouped in the generated code.
    tag: TagStyle,
    /// Optional type of an extra `inner` field carried by the generated `Client`.
    inner_type: Option<TokenStream>,
    /// Optional pre-request hook tokens (consumed by code outside this file section).
    pre_hook: Option<TokenStream>,
    /// Optional async pre-request hook tokens (consumed by code outside this file section).
    pre_hook_async: Option<TokenStream>,
    /// Optional post-request hook tokens (consumed by code outside this file section).
    post_hook: Option<TokenStream>,
    /// Optional async post-request hook tokens (consumed by code outside this file section).
    post_hook_async: Option<TokenStream>,
    /// Extra derive names forwarded to typify via `with_derive`.
    extra_derives: Vec<String>,

    /// Optional map type name forwarded to typify via `with_map_type`.
    map_type: Option<String>,
    /// Policy forwarded to typify via `with_unknown_crates`.
    unknown_crates: UnknownPolicy,
    /// Explicitly registered crates, keyed by crate name, forwarded via `with_crate`.
    crates: BTreeMap<String, CrateSpec>,

    /// Per-type patches, keyed by type name, forwarded via `with_patch`.
    patch: HashMap<String, TypePatch>,
    /// Type replacements: type name -> (replacement name, impls), forwarded via
    /// `with_replacement`.
    replace: HashMap<String, (String, Vec<TypeImpl>)>,
    /// Schema conversions: (schema, type name, impls), forwarded via `with_conversion`.
    convert: Vec<(schemars::schema::SchemaObject, String, Vec<TypeImpl>)>,
}
78
/// Version and optional rename recorded for a crate registered through
/// `GenerationSettings::with_crate`.
#[derive(Debug, Clone)]
struct CrateSpec {
    /// Version specification handed to typify for this crate.
    version: CrateVers,
    /// Optional alternate name for the crate, passed through to typify unchanged.
    rename: Option<String>,
}
84
/// Style of the interface exposed by the generated client.
#[derive(Clone, Deserialize, PartialEq, Eq)]
pub enum InterfaceStyle {
    /// Operations are generated as methods taking positional arguments (the default).
    Positional,
    /// Operations are generated through builder types; also enables typify struct builders.
    Builder,
}
93
94impl Default for InterfaceStyle {
95 fn default() -> Self {
96 Self::Positional
97 }
98}
99
/// How operations are grouped in the generated client.
#[derive(Clone, Deserialize)]
pub enum TagStyle {
    /// All operations are emitted together on the generated `Client` (the default).
    Merged,
    /// Operations are split up using the tags declared in the OpenAPI document; only
    /// implemented together with `InterfaceStyle::Builder`.
    Separate,
}
108
109impl Default for TagStyle {
110 fn default() -> Self {
111 Self::Merged
112 }
113}
114
115impl GenerationSettings {
116 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn with_interface(&mut self, interface: InterfaceStyle) -> &mut Self {
123 self.interface = interface;
124 self
125 }
126
127 pub fn with_tag(&mut self, tag: TagStyle) -> &mut Self {
129 self.tag = tag;
130 self
131 }
132
133 pub fn with_inner_type(&mut self, inner_type: TokenStream) -> &mut Self {
135 self.inner_type = Some(inner_type);
136 self
137 }
138
139 pub fn with_pre_hook(&mut self, pre_hook: TokenStream) -> &mut Self {
141 self.pre_hook = Some(pre_hook);
142 self
143 }
144
145 pub fn with_pre_hook_async(&mut self, pre_hook: TokenStream) -> &mut Self {
147 self.pre_hook_async = Some(pre_hook);
148 self
149 }
150
151 pub fn with_post_hook(&mut self, post_hook: TokenStream) -> &mut Self {
153 self.post_hook = Some(post_hook);
154 self
155 }
156
157 pub fn with_post_hook_async(&mut self, post_hook: TokenStream) -> &mut Self {
159 self.post_hook_async = Some(post_hook);
160 self
161 }
162
163 pub fn with_derive(&mut self, derive: impl ToString) -> &mut Self {
165 self.extra_derives.push(derive.to_string());
166 self
167 }
168
169 pub fn with_patch<S: AsRef<str>>(&mut self, type_name: S, patch: &TypePatch) -> &mut Self {
172 self.patch
173 .insert(type_name.as_ref().to_string(), patch.clone());
174 self
175 }
176
177 pub fn with_replacement<TS: ToString, RS: ToString, I: Iterator<Item = TypeImpl>>(
180 &mut self,
181 type_name: TS,
182 replace_name: RS,
183 impls: I,
184 ) -> &mut Self {
185 self.replace.insert(
186 type_name.to_string(),
187 (replace_name.to_string(), impls.collect()),
188 );
189 self
190 }
191
192 pub fn with_conversion<S: ToString, I: Iterator<Item = TypeImpl>>(
195 &mut self,
196 schema: schemars::schema::SchemaObject,
197 type_name: S,
198 impls: I,
199 ) -> &mut Self {
200 self.convert
201 .push((schema, type_name.to_string(), impls.collect()));
202 self
203 }
204
205 pub fn with_unknown_crates(&mut self, policy: UnknownPolicy) -> &mut Self {
209 self.unknown_crates = policy;
210 self
211 }
212
213 pub fn with_crate<S1: ToString>(
218 &mut self,
219 crate_name: S1,
220 version: CrateVers,
221 rename: Option<&String>,
222 ) -> &mut Self {
223 self.crates.insert(
224 crate_name.to_string(),
225 CrateSpec {
226 version,
227 rename: rename.cloned(),
228 },
229 );
230 self
231 }
232
233 pub fn with_map_type<MT: ToString>(&mut self, map_type: MT) -> &mut Self {
241 self.map_type = Some(map_type.to_string());
242 self
243 }
244}
245
246impl Default for Generator {
247 fn default() -> Self {
248 Self {
249 type_space: TypeSpace::new(TypeSpaceSettings::default().with_type_mod("types")),
250 settings: Default::default(),
251 uses_futures: Default::default(),
252 uses_websockets: Default::default(),
253 }
254 }
255}
256
257impl Generator {
258 pub fn new(settings: &GenerationSettings) -> Self {
260 let mut type_settings = TypeSpaceSettings::default();
261 type_settings
262 .with_type_mod("types")
263 .with_struct_builder(settings.interface == InterfaceStyle::Builder);
264 settings.extra_derives.iter().for_each(|derive| {
265 let _ = type_settings.with_derive(derive.clone());
266 });
267
268 type_settings.with_unknown_crates(settings.unknown_crates);
270 settings
271 .crates
272 .iter()
273 .for_each(|(crate_name, CrateSpec { version, rename })| {
274 type_settings.with_crate(crate_name, version.clone(), rename.as_ref());
275 });
276
277 settings.patch.iter().for_each(|(type_name, patch)| {
279 type_settings.with_patch(type_name, patch);
280 });
281 settings
282 .replace
283 .iter()
284 .for_each(|(type_name, (replace_name, impls))| {
285 type_settings.with_replacement(type_name, replace_name, impls.iter().cloned());
286 });
287 settings
288 .convert
289 .iter()
290 .for_each(|(schema, type_name, impls)| {
291 type_settings.with_conversion(schema.clone(), type_name, impls.iter().cloned());
292 });
293
294 if let Some(map_type) = &settings.map_type {
296 type_settings.with_map_type(map_type.clone());
297 }
298
299 Self {
300 type_space: TypeSpace::new(&type_settings),
301 settings: settings.clone(),
302 uses_futures: false,
303 uses_websockets: false,
304 }
305 }
306
307 pub fn generate_tokens(&mut self, spec: &OpenAPI) -> Result<TokenStream> {
309 validate_openapi(spec)?;
310
311 let schemas = spec.components.iter().flat_map(|components| {
313 components
314 .schemas
315 .iter()
316 .map(|(name, ref_or_schema)| (name.clone(), ref_or_schema.to_schema()))
317 });
318
319 self.type_space.add_ref_types(schemas)?;
320
321 let raw_methods = spec
322 .paths
323 .iter()
324 .flat_map(|(path, ref_or_item)| {
325 let item = ref_or_item.as_item().unwrap();
327 item.iter().map(move |(method, operation)| {
328 (path.as_str(), method, operation, &item.parameters)
329 })
330 })
331 .map(|(path, method, operation, path_parameters)| {
332 self.process_operation(operation, &spec.components, path, method, path_parameters)
333 })
334 .collect::<Result<Vec<_>>>()?;
335
336 let operation_code = match (&self.settings.interface, &self.settings.tag) {
337 (InterfaceStyle::Positional, TagStyle::Merged) => self
338 .generate_tokens_positional_merged(
339 &raw_methods,
340 self.settings.inner_type.is_some(),
341 ),
342 (InterfaceStyle::Positional, TagStyle::Separate) => {
343 unimplemented!("positional arguments with separate tags are currently unsupported")
344 }
345 (InterfaceStyle::Builder, TagStyle::Merged) => self
346 .generate_tokens_builder_merged(&raw_methods, self.settings.inner_type.is_some()),
347 (InterfaceStyle::Builder, TagStyle::Separate) => {
348 let tag_info = spec
349 .tags
350 .iter()
351 .map(|tag| (&tag.name, tag))
352 .collect::<BTreeMap<_, _>>();
353 self.generate_tokens_builder_separate(
354 &raw_methods,
355 tag_info,
356 self.settings.inner_type.is_some(),
357 )
358 }
359 }?;
360
361 let types = self.type_space.to_stream();
362
363 let (inner_type, inner_fn_value) = match self.settings.inner_type.as_ref() {
364 Some(inner_type) => (inner_type.clone(), quote! { &self.inner }),
365 None => (quote! { () }, quote! { &() }),
366 };
367
368 let inner_property = self.settings.inner_type.as_ref().map(|inner| {
369 quote! {
370 pub (crate) inner: #inner,
371 }
372 });
373 let inner_parameter = self.settings.inner_type.as_ref().map(|inner| {
374 quote! {
375 inner: #inner,
376 }
377 });
378 let inner_value = self.settings.inner_type.as_ref().map(|_| {
379 quote! {
380 inner
381 }
382 });
383
384 let client_docstring = {
385 let mut s = format!("Client for {}", spec.info.title);
386
387 if let Some(ss) = &spec.info.description {
388 s.push_str("\n\n");
389 s.push_str(ss);
390 }
391 if let Some(ss) = &spec.info.terms_of_service {
392 s.push_str("\n\n");
393 s.push_str(ss);
394 }
395
396 s.push_str(&format!("\n\nVersion: {}", &spec.info.version));
397
398 s
399 };
400
401 let version_str = &spec.info.version;
402
403 let file = quote! {
408 #[allow(unused_imports)]
410 pub use progenitor_middleware_client::{
411 ByteStream,
412 ClientInfo,
413 Error,
414 ResponseValue,
415 };
416 #[allow(unused_imports)]
417 use progenitor_middleware_client::{
418 encode_path,
419 ClientHooks,
420 OperationInfo,
421 RequestBuilderExt,
422 };
423
424 #[allow(clippy::all)]
426 pub mod types {
427 #types
428 }
429
430 #[derive(Clone, Debug)]
431 #[doc = #client_docstring]
432 pub struct Client {
433 pub(crate) baseurl: String,
434 pub(crate) client: reqwest_middleware::ClientWithMiddleware,
435 #inner_property
436 }
437
438 impl Client {
439 pub fn new(
445 baseurl: &str,
446 #inner_parameter
447 ) -> Self {
448 #[cfg(not(target_arch = "wasm32"))]
449 let client = {
450 let dur = std::time::Duration::from_secs(15);
451
452 let reqwest_client = reqwest::ClientBuilder::new()
453 .connect_timeout(dur)
454 .timeout(dur)
455 .build()
456 .unwrap();
457
458 reqwest_middleware::ClientBuilder::new(reqwest_client)
459 .build()
460 };
461 #[cfg(target_arch = "wasm32")]
462 let client = {
463 let reqwest_client = reqwest::ClientBuilder::new()
464 .build()
465 .unwrap();
466
467 reqwest_middleware::ClientBuilder::new(reqwest_client)
468 .build()
469 };
470
471 Self::new_with_client(baseurl, client, #inner_value)
472 }
473
474 pub fn new_with_client(
481 baseurl: &str,
482 client: reqwest_middleware::ClientWithMiddleware,
483 #inner_parameter
484 ) -> Self {
485 Self {
486 baseurl: baseurl.to_string(),
487 client,
488 #inner_value
489 }
490 }
491 }
492
493 impl ClientInfo<#inner_type> for Client {
494 fn api_version() -> &'static str {
495 #version_str
496 }
497
498 fn baseurl(&self) -> &str {
499 self.baseurl.as_str()
500 }
501
502 fn client(&self) -> &reqwest_middleware::ClientWithMiddleware {
503 &self.client
504 }
505
506 fn inner(&self) -> &#inner_type {
507 #inner_fn_value
508 }
509 }
510
511 impl ClientHooks<#inner_type> for &Client {}
512
513 #operation_code
514 };
515
516 Ok(file)
517 }
518
519 fn generate_tokens_positional_merged(
520 &mut self,
521 input_methods: &[method::OperationMethod],
522 has_inner: bool,
523 ) -> Result<TokenStream> {
524 let methods = input_methods
525 .iter()
526 .map(|method| self.positional_method(method, has_inner))
527 .collect::<Result<Vec<_>>>()?;
528
529 let out = quote! {
534 #[allow(clippy::all)]
535 impl Client {
536 #(#methods)*
537 }
538
539 pub mod prelude {
541 #[allow(unused_imports)]
542 pub use super::Client;
543 }
544 };
545 Ok(out)
546 }
547
548 fn generate_tokens_builder_merged(
549 &mut self,
550 input_methods: &[method::OperationMethod],
551 has_inner: bool,
552 ) -> Result<TokenStream> {
553 let builder_struct = input_methods
554 .iter()
555 .map(|method| self.builder_struct(method, TagStyle::Merged, has_inner))
556 .collect::<Result<Vec<_>>>()?;
557
558 let builder_methods = input_methods
559 .iter()
560 .map(|method| self.builder_impl(method))
561 .collect::<Vec<_>>();
562
563 let out = quote! {
564 impl Client {
565 #(#builder_methods)*
566 }
567
568 #[allow(clippy::all)]
570 pub mod builder {
571 use super::types;
572 #[allow(unused_imports)]
573 use super::{
574 encode_path,
575 ByteStream,
576 ClientInfo,
577 ClientHooks,
578 Error,
579 OperationInfo,
580 RequestBuilderExt,
581 ResponseValue,
582 };
583
584 #(#builder_struct)*
585 }
586
587 pub mod prelude {
589 pub use self::super::Client;
590 }
591 };
592
593 Ok(out)
594 }
595
596 fn generate_tokens_builder_separate(
597 &mut self,
598 input_methods: &[method::OperationMethod],
599 tag_info: BTreeMap<&String, &openapiv3::Tag>,
600 has_inner: bool,
601 ) -> Result<TokenStream> {
602 let builder_struct = input_methods
603 .iter()
604 .map(|method| self.builder_struct(method, TagStyle::Separate, has_inner))
605 .collect::<Result<Vec<_>>>()?;
606
607 let (traits_and_impls, trait_preludes) = self.builder_tags(input_methods, &tag_info);
608
609 let out = quote! {
614 #traits_and_impls
615
616 #[allow(clippy::all)]
618 pub mod builder {
619 use super::types;
620 #[allow(unused_imports)]
621 use super::{
622 encode_path,
623 ByteStream,
624 ClientInfo,
625 ClientHooks,
626 Error,
627 OperationInfo,
628 RequestBuilderExt,
629 ResponseValue,
630 };
631
632 #(#builder_struct)*
633 }
634
635 pub mod prelude {
638 #[allow(unused_imports)]
639 pub use super::Client;
640 #trait_preludes
641 }
642 };
643
644 Ok(out)
645 }
646
647 pub fn get_type_space(&self) -> &TypeSpace {
649 &self.type_space
650 }
651
652 pub fn uses_futures(&self) -> bool {
655 self.uses_futures
656 }
657
658 pub fn uses_websockets(&self) -> bool {
661 self.uses_websockets
662 }
663}
664
665pub fn space_out_items(content: String) -> Result<String> {
667 Ok(if cfg!(not(windows)) {
668 let regex = regex::Regex::new(r#"(\n\s*})(\n\s{0,8}[^} ])"#).unwrap();
669 regex.replace_all(&content, "$1\n$2").to_string()
670 } else {
671 let regex = regex::Regex::new(r#"(\n\s*})(\r\n\s{0,8}[^} ])"#).unwrap();
672 regex.replace_all(&content, "$1\r\n$2").to_string()
673 })
674}
675
676fn validate_openapi_spec_version(spec_version: &str) -> Result<()> {
677 if spec_version.trim().starts_with("3.0.") {
679 Ok(())
680 } else {
681 Err(Error::UnexpectedFormat(format!(
682 "invalid version: {}",
683 spec_version
684 )))
685 }
686}
687
688pub fn validate_openapi(spec: &OpenAPI) -> Result<()> {
690 validate_openapi_spec_version(spec.openapi.as_str())?;
691
692 let mut opids = HashSet::new();
693 spec.paths.paths.iter().try_for_each(|p| {
694 match p.1 {
695 openapiv3::ReferenceOr::Reference { reference: _ } => Err(Error::UnexpectedFormat(
696 format!("path {} uses reference, unsupported", p.0,),
697 )),
698 openapiv3::ReferenceOr::Item(item) => {
699 item.iter().try_for_each(|(_, o)| {
702 if let Some(oid) = o.operation_id.as_ref() {
703 if !opids.insert(oid.to_string()) {
704 return Err(Error::UnexpectedFormat(format!(
705 "duplicate operation ID: {}",
706 oid,
707 )));
708 }
709 } else {
710 return Err(Error::UnexpectedFormat(format!(
711 "path {} is missing operation ID",
712 p.0,
713 )));
714 }
715 Ok(())
716 })
717 }
718 }
719 })?;
720
721 Ok(())
722}
723
// Unit tests for the `Display` text of each string-carrying `Error` variant and for the
// OpenAPI version gate.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::{validate_openapi_spec_version, Error};

    #[test]
    fn test_bad_value() {
        assert_eq!(
            Error::BadValue("nope".to_string(), json! { "nope"},).to_string(),
            "unexpected value type nope: \"nope\"",
        );
    }

    // Previously named `test_type_error`, although it exercises `UnexpectedFormat`.
    #[test]
    fn test_unexpected_format() {
        assert_eq!(
            Error::UnexpectedFormat("nope".to_string()).to_string(),
            "unexpected or unhandled format in the OpenAPI document nope",
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_eq!(
            Error::InvalidPath("nope".to_string()).to_string(),
            "invalid operation path nope",
        );
    }

    #[test]
    fn test_invalid_extension() {
        assert_eq!(
            Error::InvalidExtension("nope".to_string()).to_string(),
            "invalid dropshot extension use: nope",
        );
    }

    #[test]
    fn test_internal_error() {
        assert_eq!(
            Error::InternalError("nope".to_string()).to_string(),
            "internal error nope",
        );
    }

    #[test]
    fn test_validate_openapi_spec_version() {
        assert!(validate_openapi_spec_version("3.0.0").is_ok());
        assert!(validate_openapi_spec_version("3.0.1").is_ok());
        assert!(validate_openapi_spec_version("3.0.4").is_ok());
        assert!(validate_openapi_spec_version("3.0.5-draft").is_ok());
        // Surrounding whitespace is trimmed before the prefix check.
        assert!(validate_openapi_spec_version(" 3.0.3\n").is_ok());
        assert_eq!(
            validate_openapi_spec_version("3.1.0")
                .unwrap_err()
                .to_string(),
            "unexpected or unhandled format in the OpenAPI document invalid version: 3.1.0"
        );
        // The prefix must include the trailing dot, and other major versions are rejected.
        assert!(validate_openapi_spec_version("3.0").is_err());
        assert!(validate_openapi_spec_version("2.0").is_err());
        assert!(validate_openapi_spec_version("").is_err());
    }
}
775}