1#![recursion_limit = "256"]
2#![allow(unused_imports)]
3
4use std::convert::TryFrom;
5use std::fmt;
6use std::fs;
7use std::path::Path;
8use std::str::FromStr;
9
10use actix_http::StatusCode;
11use derive_more::{Deref, Display};
12use either::Either;
13use heck::{ToPascalCase, ToSnakeCase};
14use indexmap::{IndexMap as Map, IndexSet as Set};
15use log::{debug, info};
16use openapiv3::{
17 AnySchema, ObjectType, OpenAPI, ReferenceOr, Schema, SchemaData, SchemaKind,
18 StatusCode as ApiStatusCode, Type as ApiType,
19};
20use proc_macro2::{Ident as QIdent, TokenStream};
21use quote::quote;
22use regex::Regex;
23use thiserror::Error;
24
/// Early-returns `Err(Error::Validation(..))` from the enclosing function, with the
/// message built by `format!` from the macro's arguments.
///
/// `Error` is resolved at the call site, so the caller must have the crate's
/// `Error` type in scope.
macro_rules! invalid {
    ($($fmt:tt)+) => {
        return Err(Error::Validation(format!($($fmt)+)))
    };
}
30
31mod route;
32mod walk;
33
34use route::Route;
35
36const SWAGGER_UI_TEMPLATE: &'static str = include_str!("../ui-template.html");
37
38fn ident(s: impl fmt::Display) -> QIdent {
39 QIdent::new(&s.to_string(), proc_macro2::Span::call_site())
40}
41
/// Insertion-ordered map (`IndexMap`) from a schema key to a schema or a reference
/// to one.
///
/// NOTE(review): whether keys are bare component names or full
/// `#/components/schemas/...` references is decided by the code that builds the
/// map — confirm before relying on it.
type SchemaLookup = Map<String, ReferenceOr<Schema>>;
43
44#[derive(Debug, Error)]
45pub enum Error {
46 #[error("IO Error: {}", _0)]
47 Io(#[from] std::io::Error),
48 #[error("Yaml Error: {}", _0)]
49 Yaml(#[from] serde_yaml::Error),
50 #[error("Codegen failed: {}", _0)]
51 BadCodegen(String),
52 #[error("Bad reference: {}", _0)]
53 BadReference(String),
54 #[error("OpenAPI validation failed: {}", _0)]
55 Validation(String),
56}
57
/// Result alias used throughout the generator, with this crate's [`Error`] as the
/// error type.
pub type Result<T> = std::result::Result<T, Error>;
59
60fn unwrap_ref<T>(item: &ReferenceOr<T>) -> Result<&T> {
63 match item {
64 ReferenceOr::Item(item) => Ok(item),
65 ReferenceOr::Reference { reference } => Err(Error::BadReference(reference.to_string())),
66 }
67}
68
69fn dereference<'a, T>(
71 refr: &'a ReferenceOr<T>,
72 lookup: &'a Map<String, ReferenceOr<T>>,
73) -> Result<&'a T> {
74 match refr {
75 ReferenceOr::Reference { reference } => lookup
76 .get(reference)
77 .ok_or_else(|| Error::BadReference(reference.to_string()))
78 .and_then(|refr| dereference(refr, lookup)),
79 ReferenceOr::Item(item) => Ok(item),
80 }
81}
82
83fn api_trait_name(api: &OpenAPI) -> TypeName {
84 TypeName::from_str(&format!("{}Api", api.info.title.to_pascal_case())).unwrap()
85}
86
/// An HTTP method as it appears on an OpenAPI path item, before it is checked
/// against the presence of a request body (see `Method::from_raw`).
///
/// The derived `Display` (from `derive_more`) is used when formatting validation
/// errors. NOTE(review): presumably it renders the bare variant name (e.g. `Get`) —
/// confirm against the derive_more version in use.
#[derive(Debug, Clone, Copy, derive_more::Display)]
enum RawMethod {
    Get,
    Post,
    Delete,
    Put,
    Patch,
    Options,
    Head,
    Trace,
}
98
/// An HTTP method, split by whether it may carry a request body.
#[derive(Debug, Clone)]
enum Method {
    /// GET, HEAD, OPTIONS or TRACE. `Method::from_raw` rejects these if a body
    /// type is supplied.
    WithoutBody(MethodWithoutBody),
    /// POST, PATCH, PUT or DELETE, together with the request body's type.
    WithBody {
        method: MethodWithBody,
        /// Type of the request body, or `None` if the operation declares none.
        body_type: Option<TypePath>,
    },
}
110
111impl fmt::Display for Method {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 match self {
114 Self::WithoutBody(method) => method.fmt(f),
115 Self::WithBody { method, .. } => method.fmt(f),
116 }
117 }
118}
119
120impl Method {
121 fn from_raw(method: RawMethod, body_type: Option<TypePath>) -> Result<Self> {
122 use Method as M;
123 use MethodWithBody::*;
124 use MethodWithoutBody::*;
125 use RawMethod as R;
126 match method {
127 R::Get | R::Head | R::Options | R::Trace => {
128 if body_type.is_some() {
129 invalid!("Method '{}' canoot have a body", method);
130 }
131 }
132 _ => {}
133 }
134 let meth = match method {
135 R::Get => M::WithoutBody(Get),
136 R::Head => M::WithoutBody(Head),
137 R::Trace => M::WithoutBody(Trace),
138 R::Options => M::WithoutBody(Options),
139 R::Post => M::WithBody {
140 method: Post,
141 body_type,
142 },
143 R::Patch => M::WithBody {
144 method: Patch,
145 body_type,
146 },
147 R::Put => M::WithBody {
148 method: Put,
149 body_type,
150 },
151 R::Delete => M::WithBody {
152 method: Delete,
153 body_type,
154 },
155 };
156 Ok(meth)
157 }
158
159 fn body_type(&self) -> Option<&TypePath> {
160 match self {
161 Method::WithoutBody(_)
162 | Method::WithBody {
163 body_type: None, ..
164 } => None,
165 Method::WithBody {
166 body_type: Some(ref body_ty),
167 ..
168 } => Some(body_ty),
169 }
170 }
171}
172
/// HTTP methods that may not carry a request body in the generated API
/// (`Method::from_raw` rejects a body for these). `Display` renders the upper-case
/// verb.
#[derive(Debug, Clone, Copy)]
enum MethodWithoutBody {
    Get,
    Head,
    Options,
    Trace,
}
180
181impl fmt::Display for MethodWithoutBody {
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 use MethodWithoutBody::*;
184 match self {
185 Get => write!(f, "GET"),
186 Head => write!(f, "HEAD"),
187 Options => write!(f, "OPTIONS"),
188 Trace => write!(f, "TRACE"),
189 }
190 }
191}
192
193impl fmt::Display for MethodWithBody {
194 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
195 use MethodWithBody::*;
196 match self {
197 Post => write!(f, "POST"),
198 Delete => write!(f, "DELETE"),
199 Put => write!(f, "PUT"),
200 Patch => write!(f, "PATCH"),
201 }
202 }
203}
204
/// HTTP methods that may carry a request body (paired with its type in
/// `Method::WithBody`). `Display` renders the upper-case verb.
#[derive(Debug, Clone, Copy)]
enum MethodWithBody {
    Post,
    Delete,
    Put,
    Patch,
}
212
/// A name intended to be usable as a Rust identifier. Build it with `FromStr`
/// (`"name".parse::<Ident>()`), which performs the validation.
///
/// Derefs to the inner `String` and emits itself as an identifier token via
/// `quote::ToTokens`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Display, Deref)]
struct Ident(String);
218
219impl FromStr for Ident {
220 type Err = Error;
221 fn from_str(val: &str) -> Result<Self> {
222 let ident_re = Regex::new("^([[:alpha:]]|_)([[:alnum:]]|_)*$").unwrap();
225 if ident_re.is_match(val) {
226 Ok(Ident(val.to_string()))
227 } else {
228 invalid!("Bad identifier '{}' (not a valid Rust identifier)", val)
229 }
230 }
231}
232
233impl quote::ToTokens for Ident {
234 fn to_tokens(&self, tokens: &mut TokenStream) {
235 let id = ident(&self.0);
236 id.to_tokens(tokens)
237 }
238}
239
/// A location inside the OpenAPI document as a list of path components (for
/// example `components`, `schemas`, `<name>`). `Display` joins the components
/// with `.`.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct ApiPath {
    path: Vec<String>,
}
247
248impl ApiPath {
249 fn push(mut self, s: impl Into<String>) -> Self {
250 self.path.push(s.into());
251 self
252 }
253}
254
255impl From<TypePath> for ApiPath {
256 fn from(path: TypePath) -> Self {
257 Self { path: path.0 }
258 }
259}
260
261impl std::fmt::Display for ApiPath {
262 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
263 let joined = self.path.join(".");
264 write!(f, "{}", joined)
265 }
266}
267
/// A path of components identifying a type in the OpenAPI document, e.g.
/// `["components", "schemas", "Pet"]` as built by `from_reference`.
/// `canonicalize` turns it into the Rust type name.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct TypePath(Vec<String>);
274
275impl From<ApiPath> for TypePath {
276 fn from(path: ApiPath) -> Self {
277 Self(path.path)
278 }
279}
280
281impl TypePath {
282 pub(crate) fn from_reference(refr: &str) -> Result<Self> {
283 let rx = Regex::new("^#/components/schemas/([[:alnum:]]+)$").unwrap();
284 let cap = rx
285 .captures(refr)
286 .ok_or_else(|| Error::BadReference(refr.into()))?;
287 let name = cap.get(1).unwrap();
288 let path = vec![
289 "components".into(),
290 "schemas".into(),
291 name.as_str().to_string(),
292 ];
293 Ok(Self(path))
294 }
295
296 pub(crate) fn canonicalize(&self) -> TypeName {
299 let parts: Vec<&str> = self.0.iter().map(String::as_str).collect();
300 let parts = match &parts[..] {
301 ["components", "schemas", rest @ ..] => &rest,
303 ["paths", _path, _method, rest @ ..] => &rest,
307 rest => rest,
309 };
310 let joined = parts.join(" ");
311 TypeName::from_str(&joined.to_pascal_case()).unwrap()
312 }
313}
314
/// A PascalCase Rust type name. `FromStr` accepts only strings that heck's
/// `to_pascal_case` leaves unchanged. Derefs to the inner `String` and emits
/// itself as an identifier token via `quote::ToTokens`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Display, Deref)]
struct TypeName(String);
320
321impl FromStr for TypeName {
322 type Err = Error;
323 fn from_str(val: &str) -> Result<Self> {
324 let camel = val.to_pascal_case();
325 if val == camel {
326 Ok(TypeName(camel))
327 } else {
328 invalid!("Bad type name '{}', must be ClassCase", val)
329 }
330 }
331}
332
333impl quote::ToTokens for TypeName {
334 fn to_tokens(&self, tokens: &mut TokenStream) {
335 let id = ident(&self.0);
336 id.to_tokens(tokens)
337 }
338}
339
/// One `/`-separated segment of a route path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PathSegment {
    /// A fixed segment, e.g. `pets` in `/pets/{id}`.
    Literal(String),
    /// A templated segment; holds the name inside the braces, e.g. `id` for `{id}`.
    Parameter(String),
}
345
/// A route path template parsed into segments by `RoutePath::analyse`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct RoutePath {
    /// Segments in path order.
    segments: Vec<PathSegment>,
}
350
351impl fmt::Display for RoutePath {
352 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
353 let mut path = String::new();
354 for segment in &self.segments {
355 match segment {
356 PathSegment::Literal(p) => {
357 path.push('/');
358 path.push_str(&p);
359 }
360 PathSegment::Parameter(p) => {
361 path.push_str(&format!("/{{{}}}", p));
362 }
363 }
364 }
365 write!(f, "{}", path)
366 }
367}
368
369impl RoutePath {
370 fn analyse(path: &str) -> Result<RoutePath> {
372 let literal_re = Regex::new("^[[:alpha:]]([[:alnum:]]|_)*$").unwrap();
374 let param_re = Regex::new(r#"^\{([[:alpha:]]([[:alnum:]]|_)*)\}$"#).unwrap();
375
376 if !path.starts_with('/') {
377 invalid!("Bad path '{}' (must start with '/')", path);
378 }
379
380 let mut segments = Vec::new();
381
382 let mut dupe_params = Set::new();
383 for segment in path.split('/').skip(1) {
384 if literal_re.is_match(segment) {
385 segments.push(PathSegment::Literal(segment.to_string()))
386 } else if let Some(seg) = param_re.captures(segment) {
387 let param = seg.get(1).unwrap().as_str().to_string();
388 if !dupe_params.insert(param.clone()) {
389 invalid!("Duplicate parameter in path '{}'", path);
390 }
391 segments.push(PathSegment::Parameter(param))
392 } else {
393 invalid!("Bad path '{}'", path);
394 }
395 }
396 Ok(RoutePath { segments })
397 }
398
399 fn path_args(&self) -> impl Iterator<Item = &str> {
400 self.segments.iter().filter_map(|s| {
401 if let PathSegment::Parameter(ref p) = s {
402 Some(p.as_ref())
403 } else {
404 None
405 }
406 })
407 }
408}
409
/// Visibility of a generated item. `Public` emits `pub`, `Private` emits nothing
/// (see the `ToTokens` impl). The default is `Public`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum Visibility {
    Public,
    Private,
}
415
416impl quote::ToTokens for Visibility {
417 fn to_tokens(&self, tokens: &mut TokenStream) {
418 let tok = match self {
419 Visibility::Public => quote! {pub},
420 Visibility::Private => quote! {},
421 };
422 tok.to_tokens(tokens)
423 }
424}
425
426impl Default for Visibility {
427 fn default() -> Self {
428 Visibility::Public
429 }
430}
431
/// Metadata attached to a generated type.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub(crate) struct TypeMetadata {
    /// Schema `title`, copied from `SchemaData`.
    title: Option<String>,
    /// Schema `description`; rendered as a `#[doc]` attribute by `description()`.
    description: Option<String>,
    /// Schema `nullable` flag, copied from `SchemaData`.
    nullable: bool,
    /// Whether the generated item is emitted with `pub`.
    visibility: Visibility,
}
439
440impl TypeMetadata {
441 fn with_visibility(self, visibility: Visibility) -> Self {
442 Self { visibility, ..self }
443 }
444
445 fn with_description(self, description: String) -> Self {
446 Self {
447 description: Some(description),
448 ..self
449 }
450 }
451
452 fn description(&self) -> Option<TokenStream> {
453 self.description.as_ref().map(|s| {
454 quote! {
455 #[doc = #s]
456 }
457 })
458 }
459}
460
461impl From<openapiv3::SchemaData> for TypeMetadata {
462 fn from(from: openapiv3::SchemaData) -> Self {
463 Self {
464 title: from.title,
465 description: from.description,
466 nullable: from.nullable,
467 visibility: Visibility::Public,
468 }
469 }
470}
471
/// Metadata attached to a generated struct field.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub(crate) struct FieldMetadata {
    /// Field description, if any.
    description: Option<String>,
    /// NOTE(review): presumably whether the property is listed in the schema's
    /// `required` array — confirm against the code that sets it.
    required: bool,
}
477
478impl FieldMetadata {
479 fn with_required(self, required: bool) -> Self {
480 Self { required, ..self }
481 }
482}
483
484pub(crate) fn variant_from_status_code(code: &StatusCode) -> Ident {
485 code.canonical_reason()
486 .and_then(|reason| reason.to_pascal_case().parse().ok())
487 .unwrap_or_else(|| format!("Status{}", code.as_str()).parse().unwrap())
488}
489
490fn doc_comment(msg: impl AsRef<str>) -> TokenStream {
491 let msg = msg.as_ref();
492 quote! {
493 #[doc = #msg]
494 }
495}
496
497fn get_derive_tokens() -> TokenStream {
498 quote! {
499 # [derive(Debug, Clone, PartialEq, hsr::Serialize, hsr::Deserialize)]
500 }
501}
502
503fn generate_rust_interface(
504 routes: &Map<String, Vec<Route>>,
505 title: &str,
506 trait_name: &TypeName,
507) -> TokenStream {
508 let mut methods = TokenStream::new();
509 let descr = doc_comment(format!("Api generated from '{}' spec", title));
510 for (_, route_methods) in routes {
511 for route in route_methods {
512 methods.extend(route.generate_api_signature());
513 }
514 }
515 quote! {
516 #descr
517 #[hsr::async_trait::async_trait(?Send)]
518 pub trait #trait_name: 'static + Send + Sync {
519 #methods
520 }
521 }
522}
523
524fn generate_rust_dispatchers(
525 routes: &Map<String, Vec<Route>>,
526 trait_name: &TypeName,
527) -> TokenStream {
528 let mut dispatchers = TokenStream::new();
529 for (_api_path, route_methods) in routes {
530 for route in route_methods {
531 dispatchers.extend(route.generate_dispatcher(trait_name));
532 }
533 }
534 quote! {#dispatchers}
535}
536
537fn generate_rust_server(routemap: &Map<String, Vec<Route>>, trait_name: &TypeName) -> TokenStream {
538 let resources: Vec<_> = routemap
539 .iter()
540 .map(|(path, routes)| {
541 let (meth, opid): (Vec<_>, Vec<_>) = routes
542 .iter()
543 .map(|route| {
544 (
545 ident(route.method().to_string().to_snake_case()),
546 route.operation_id(),
547 )
548 })
549 .unzip();
550 quote! {
551 web::resource(#path)
552 #(.route(web::#meth().to(#opid::<A>)))*
553 }
554 })
555 .collect();
556
557 let server = quote! {
558 #[allow(dead_code)]
559 pub mod server {
560 use super::*;
561
562 fn configure_hsr<A: #trait_name>(cfg: &mut actix_web::web::ServiceConfig) {
563 cfg #(.service(#resources))*;
564 }
565
566 pub async fn serve<A: #trait_name>(api: A, cfg: hsr::Config) -> std::io::Result<()> {
569 let api = AxData::new(api);
578
579 let server = HttpServer::new(move || {
580 App::new()
581 .app_data(api.clone())
582 .wrap(Logger::default())
583 .configure(|cfg| hsr::configure_spec(cfg, JSON_SPEC, UI_TEMPLATE))
584 .configure(configure_hsr::<A>)
585 });
586
587 let server = if let Some(ssl) = cfg.ssl {
589 server.bind_openssl((cfg.host.host_str().unwrap(), cfg.host.port().unwrap()), ssl)
590 } else {
591 server.bind((cfg.host.host_str().unwrap(), cfg.host.port().unwrap()))
592 }?;
593
594 server.run().await
596 }
597 }
598 };
599 server
600}
601
602fn generate_rust_client(routes: &Map<String, Vec<Route>>) -> TokenStream {
603 let mut method_impls = TokenStream::new();
604 for (_, route_methods) in routes {
605 for route in route_methods {
606 method_impls.extend(route.generate_client_impl());
607 }
608 }
609
610 quote! {
611 #[allow(dead_code)]
612 #[allow(unused_imports)]
613 pub mod client {
614 use super::*;
615 use hsr::actix_http::Method;
616 use hsr::awc::Client as ActixClient;
617 use hsr::ClientError;
618 use hsr::futures::future::{err as fut_err, ok as fut_ok};
619 use hsr::serde_urlencoded;
620
621 pub struct Client {
622 domain: Url,
623 inner: ActixClient,
624 }
625
626 impl Client {
627
628 pub fn new(domain: Url) -> Self {
629 Client {
630 domain: domain,
631 inner: ActixClient::new()
632 }
633 }
634
635 #method_impls
636 }
637 }
638 }
639}
640
641pub fn generate_from_yaml_file(yaml: impl AsRef<Path>) -> Result<String> {
642 let f = fs::File::open(yaml)?;
644 generate_from_yaml_source(f)
645}
646
647pub fn generate_from_yaml_source(mut yaml: impl std::io::Read) -> Result<String> {
648 let mut openapi_source = String::new();
650 yaml.read_to_string(&mut openapi_source)?;
651 let api: OpenAPI = serde_yaml::from_str(&openapi_source)?;
652
653 let json_spec = serde_json::to_string(&api).expect("Bad api serialization");
662
663 let trait_name = api_trait_name(&api);
664
665 debug!("Gather types");
667 let (type_lookup, routes) = walk::walk_api(&api)?;
668
669 debug!("Generate API types");
671 let rust_api_types = walk::generate_rust_types(&type_lookup)?;
672
673 debug!("Generate response types");
675 let rust_response_types: Vec<_> = routes
676 .values()
677 .map(|routes| routes.iter().map(|route| route.generate_return_type()))
678 .flatten()
679 .collect();
680
681 debug!("Generate API trait");
682 let rust_trait = generate_rust_interface(&routes, &api.info.title, &trait_name);
683
684 debug!("Generate dispatchers");
685 let rust_dispatchers = generate_rust_dispatchers(&routes, &trait_name);
686
687 debug!("Generate server");
688 let rust_server = generate_rust_server(&routes, &trait_name);
689
690 debug!("Generate client");
691 let rust_client = generate_rust_client(&routes);
692
693 let code = quote! {
694 #[allow(dead_code)]
695
696 const JSON_SPEC: &'static str = #json_spec;
698 const UI_TEMPLATE: &'static str = #SWAGGER_UI_TEMPLATE;
699
700 mod __imports {
701 pub use hsr::HasStatusCode;
702 pub use hsr::actix_web::{
703 self, App, HttpServer, HttpRequest, HttpResponse, Responder, Either as AxEither,
704 Error as ActixError,
705 error::ErrorInternalServerError,
706 web::{self, Json as AxJson, Query as AxQuery, Path as AxPath, Data as AxData, ServiceConfig},
707 body::BoxBody,
708 HttpResponseBuilder,
709 middleware::Logger
710 };
711 pub use hsr::url::Url;
712 pub use hsr::actix_http::{StatusCode};
713 pub use hsr::futures::future::{Future, FutureExt, TryFutureExt, Ready, ok as fut_ok};
714 pub use hsr::serde_json::Value as JsonValue;
715
716 pub use hsr::{Serialize, Deserialize};
718 }
719 #[allow(dead_code)]
720 use __imports::*;
721
722 #rust_api_types
724 #(#rust_response_types)*
725 #rust_trait
727 #rust_dispatchers
729 #rust_server
731 #rust_client
733 };
734 let code = code.to_string();
735 #[cfg(feature = "pretty")]
736 {
737 debug!("Prettify");
738 prettify_code(code)
739 }
740 #[cfg(not(feature = "pretty"))]
741 {
742 Ok(code)
743 }
744}
745
/// Formats generated Rust source with `rustfmt`.
///
/// # Errors
/// Returns [`Error::BadCodegen`] if `rustfmt` cannot be run or rejects `input`.
/// This used to panic even though the function returns `Result` and is reached
/// from the public `generate_*` entry points.
#[cfg(feature = "pretty")]
pub fn prettify_code(input: String) -> Result<String> {
    // `{:?}` keeps any rustfmt stderr captured inside the wrapper's error value.
    rustfmt_wrapper::rustfmt(input)
        .map_err(|e| Error::BadCodegen(format!("rustfmt failed: {:?}", e)))
}
752
/// Unit tests for name and path handling in the code generator.
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins how heck's `to_snake_case` treats path-like strings: going by the
    /// expected values, `/`, braces, whitespace and other punctuation act as word
    /// separators and are dropped.
    #[test]
    fn test_snake_casify() {
        assert_eq!("/a/b/c".to_snake_case(), "a_b_c");
        assert_eq!(
            "/All/ThisIs/justFine".to_snake_case(),
            "all_this_is_just_fine"
        );
        assert_eq!("/{someId}".to_snake_case(), "some_id");
        assert_eq!(
            "/123_abc{xyz\\!\"£$%^}/456 asdf".to_snake_case(),
            "123_abc_xyz_456_asdf"
        )
    }

    /// `Ident::from_str` accepts a letter or `_` followed by letters, digits or
    /// `_`, and rejects empty strings, leading digits and punctuation.
    #[test]
    fn test_valid_identifier() {
        assert!(Ident::from_str("x").is_ok());
        assert!(Ident::from_str("_").is_ok());
        assert!(Ident::from_str("x1").is_ok());
        assert!(Ident::from_str("x1_23_aB").is_ok());

        assert!(Ident::from_str("").is_err());
        assert!(Ident::from_str("1abc").is_err());
        assert!(Ident::from_str("abc!").is_err());
    }

    /// Exercises the validation rules of `RoutePath::analyse` and the segments it
    /// produces.
    #[test]
    fn test_analyse_path() {
        use PathSegment::*;

        // Rejected: no leading '/', empty segments (trailing or doubled '/'),
        // malformed braces, whitespace and segments starting with a digit.
        assert!(RoutePath::analyse("").is_err());
        assert!(RoutePath::analyse("a").is_err());
        assert!(RoutePath::analyse("/a/").is_err());
        assert!(RoutePath::analyse("/a/b/c/").is_err());
        assert!(RoutePath::analyse("/a{").is_err());
        assert!(RoutePath::analyse("/a{}").is_err());
        assert!(RoutePath::analyse("/{}a").is_err());
        assert!(RoutePath::analyse("/{a}a").is_err());
        assert!(RoutePath::analyse("/ a").is_err());
        assert!(RoutePath::analyse("/1").is_err());
        assert!(RoutePath::analyse("/a//b").is_err());

        // Accepted literal-only paths; repeated literals are fine.
        assert!(RoutePath::analyse("/a").is_ok());
        assert!(RoutePath::analyse("/a/b/c").is_ok());
        assert!(RoutePath::analyse("/a/a/a").is_ok());
        assert!(RoutePath::analyse("/a1/b2/c3").is_ok());

        // Accepted parameter segments, alone or mixed with literals.
        assert!(RoutePath::analyse("/{a1}").is_ok());
        assert!(RoutePath::analyse("/{a1}/b2/{c3}").is_ok());
        assert!(RoutePath::analyse("/{a1B2c3}").is_ok());
        assert!(RoutePath::analyse("/{a1_b2_c3}").is_ok());

        // A parameter name may not repeat.
        assert!(RoutePath::analyse("/{a}/{b}/{a}").is_err());

        // Segments are kept in order and classified as parameter or literal.
        assert_eq!(
            RoutePath::analyse("/{a_1}/{b2C3}/a/b").unwrap(),
            RoutePath {
                segments: vec![
                    Parameter("a_1".into()),
                    Parameter("b2C3".into()),
                    Literal("a".into()),
                    Literal("b".into())
                ]
            }
        );
    }

}