use crate::generation::{self, as_type_name, path_as_rust_path};
use codegen::Scope;
use openapi::v3_0::{ObjectOrReference, Spec};
use openapi::{Error, OpenApi};
use std::path::Path;
use std::{
fs,
io::{self, Write},
};
pub fn generate_from_spec(
path: impl AsRef<Path>,
spec_path: impl AsRef<Path>,
) -> openapi::Result<()> {
let spec = match openapi::from_path(spec_path)? {
OpenApi::V2(_) => unimplemented!("OpenAPI V2 schemas are not supported"),
OpenApi::V3_0(x) => x,
};
generate_spec(path, &spec)?;
Ok(())
}
/// Creates (or truncates) the file at `path` and writes the code generated for `spec`
/// into it.
///
/// # Errors
///
/// Returns an error if the file cannot be created or if writing to it fails.
pub fn generate_spec(path: impl AsRef<Path>, spec: &Spec) -> io::Result<()> {
    fs::File::create(path).and_then(|mut out| generate_spec_with(&mut out, spec))
}
/// Writes all code generated for `spec` to `w`.
///
/// The sections are emitted in this order: the `VERSION` constant, the component
/// types, one module per operation, and the actix-web `server` module.
///
/// # Errors
///
/// Returns the first error produced while writing any section.
pub fn generate_spec_with<W: Write>(w: &mut W, spec: &Spec) -> io::Result<()> {
    generate_version(w, spec)?;
    generation::generate_components(w, spec)?;
    generate_paths(w, spec)?;
    generate_server_mod(w, spec)
}
/// Writes `pub const VERSION: &str = "<info.version>";` followed by a blank line.
///
/// # Errors
///
/// Returns any error produced while writing to `w`.
fn generate_version<W: Write>(w: &mut W, spec: &Spec) -> io::Result<()> {
    // `{:?}` renders the version as a quoted, escaped string literal. A version that
    // contains `"` or `\` therefore still yields valid Rust, and a plain version such as
    // `1.0.0` renders exactly as `"1.0.0"`.
    writeln!(w, "pub const VERSION: &str = {:?};\n", spec.info.version)
}
/// Generates one public module per operation, named after its snake-cased `operationId`.
///
/// Each module contains:
///
/// - `Parameters`, plus `Query` and/or `Path` structs holding the operation's query and
///   path parameters. Only parameters that declare a schema become fields.
/// - One `StatusXXX` type per documented response. It is an alias for `$ref` responses,
///   a generated type for inline schemas, and `()` for responses without a body schema.
/// - A `Success` enum over the `2xx` responses.
/// - An `Error<T>` enum over every other response plus an `Unknown(T)` catch-all, with
///   `Display` and `std::error::Error` impls.
///
/// The rendered scope is written to `w` in one `writeln!`.
///
/// # Panics
///
/// Panics if an operation has no `operationId`, or if a parameter location is neither
/// `query` nor `path`. Also panics (via `todo!`) if a parameter is a `$ref`.
///
/// # Errors
///
/// Returns any error produced while writing to `w`.
pub fn generate_paths<W: Write>(w: &mut W, spec: &Spec) -> io::Result<()> {
    let mut scope = Scope::new();
    for (_, item) in &spec.paths {
        for (_, operation) in &item.operations {
            let operation_id = heck::AsSnakeCase(
                operation
                    .operation_id
                    .as_deref()
                    .expect("every operation must declare an `operationId`"),
            )
            .to_string();
            let operation_mod = scope
                .new_module(&operation_id)
                .vis("pub")
                .attr("allow(unused_assignments, unused_imports, unused_variables)");
            // Re-export the enclosing scope so generated type paths resolve inside the module.
            operation_mod
                .new_module("base")
                .attr("allow(unused_imports, non_snake_case)")
                .scope()
                .raw("pub use super::super::*;");
            operation_mod
                .new_module("base__super")
                .attr("allow(unused_imports, non_snake_case)")
                .scope()
                .raw("pub use super::super::super::*;");

            // (field name, Rust type) pairs, split by where the parameter is read from.
            let mut query_params = vec![];
            let mut path_params = vec![];
            if let Some(ref params) = operation.parameters {
                for param in params {
                    match param {
                        ObjectOrReference::Object(p) => {
                            let v = match p.location.as_str() {
                                "query" => &mut query_params,
                                "path" => &mut path_params,
                                loc => {
                                    panic!("Invalid or unsupported parameter location [{}]", loc)
                                }
                            };
                            // A parameter without a schema has no Rust type and is skipped.
                            if let Some(ref schema) = p.schema {
                                let ty = as_type_name(p.required.unwrap_or(false), schema);
                                v.push((&p.name, ty));
                            }
                        }
                        ObjectOrReference::Ref { .. } => todo!(),
                    }
                }
            }

            // `Parameters` always exists. It is emitted as a unit struct when the
            // operation has no parameters, so the server can construct it directly.
            // NOTE(review): fields are added without an explicit `pub`; confirm that
            // callers outside this module can read them.
            let has_params = !query_params.is_empty() || !path_params.is_empty();
            let parameters = operation_mod.new_struct("Parameters").vis("pub");
            if has_params {
                parameters.derive("Debug, ::serde::Deserialize");
                for (ident, ty) in query_params.iter().chain(&path_params) {
                    parameters.field(ident, ty);
                }

                // `Parameters::new(query?, path?)` merges the extracted structs.
                let new_fn = operation_mod
                    .new_impl("Parameters")
                    .new_fn("new")
                    .vis("pub")
                    .ret("Result<Self, ::serde::de::value::Error>");
                if !query_params.is_empty() {
                    new_fn.arg("query", "Query");
                }
                if !path_params.is_empty() {
                    new_fn.arg("path", "Path");
                }
                new_fn.line("Ok(Self {");
                for (ident, _) in &query_params {
                    new_fn.line(format!("{}: query.{},", ident, ident));
                }
                for (ident, _) in &path_params {
                    new_fn.line(format!("{}: path.{},", ident, ident));
                }
                new_fn.line("})");
            }
            if !query_params.is_empty() {
                let query = operation_mod
                    .new_struct("Query")
                    .vis("pub")
                    .derive("::serde::Serialize, ::serde::Deserialize");
                for (ident, ty) in query_params {
                    query.field(ident, ty);
                }
            }
            if !path_params.is_empty() {
                let path = operation_mod
                    .new_struct("Path")
                    .vis("pub")
                    .derive("::serde::Serialize, ::serde::Deserialize");
                for (ident, ty) in path_params {
                    path.field(ident, ty);
                }
            }

            // One `StatusXXX` type per documented response.
            for (status, resp) in &operation.responses {
                let ident = format!("Status{}", status);
                if let Some(ref path) = resp.ref_path {
                    operation_mod.scope().raw(format!(
                        "pub type {} = {};",
                        ident,
                        path_as_rust_path(path)
                    ));
                    continue;
                }
                // Only the first media type is used. Emitting one type per media type
                // would define `StatusXXX` several times in the same module.
                let schema = resp
                    .content
                    .as_ref()
                    .and_then(|content| content.iter().next())
                    .and_then(|(_, media)| media.schema.as_ref());
                match schema {
                    Some(ObjectOrReference::Object(o)) => {
                        generation::gen_schema_as_type_in(operation_mod.scope(), &ident, o);
                    }
                    Some(ObjectOrReference::Ref { ref_path }) => {
                        operation_mod.scope().raw(format!(
                            "pub type {} = {};",
                            ident,
                            path_as_rust_path(ref_path)
                        ));
                    }
                    // No body (e.g. `204 No Content`): still define the type so the
                    // `Success`/`Error` variants below refer to something.
                    None => {
                        operation_mod
                            .scope()
                            .raw(format!("pub type {} = ();", ident));
                    }
                }
            }

            // `2xx` responses are successes; everything else is an error.
            let success = operation_mod.new_enum("Success").vis("pub").derive("Debug");
            for (status, _) in &operation.responses {
                if status.starts_with('2') {
                    let variant = format!("Status{}", status);
                    success.new_variant(&variant).tuple(variant);
                }
            }
            let error = operation_mod
                .new_enum("Error")
                .vis("pub")
                .derive("Debug")
                .generic("T: std::fmt::Debug");
            for (status, _) in &operation.responses {
                if !status.starts_with('2') {
                    let variant = format!("Status{}", status);
                    error.new_variant(&variant).tuple(variant);
                }
            }
            error.new_variant("Unknown").tuple("T");

            let err_display_impl = operation_mod
                .new_impl("Error<T>")
                .generic("T: std::fmt::Debug + std::fmt::Display")
                .impl_trait("std::fmt::Display")
                .new_fn("fmt")
                .arg_ref_self()
                .arg("f", "&mut std::fmt::Formatter<'_>")
                .ret("std::fmt::Result");
            err_display_impl.line("match self {");
            for (status, _) in &operation.responses {
                if !status.starts_with('2') {
                    err_display_impl.line(format!(
                        r#"Self::Status{}(status) => write!(f, "status {}: {{:?}}", status),"#,
                        status, status
                    ));
                }
            }
            err_display_impl.line(
                r#"Self::Unknown(response) => write!(f, "Unspecified response: `{}`", response),"#,
            );
            err_display_impl.line("}");
            operation_mod
                .new_impl("Error<T>")
                .generic("T: std::fmt::Debug + std::fmt::Display")
                .impl_trait("std::error::Error");
        }
    }
    writeln!(w, "{}\n", scope.to_string())
}
/// Generates the actix-web `server` module.
///
/// The module contains:
///
/// - A trait named after the PascalCased `info.title`. It has associated types
///   `AuthorizedData` and `Error: std::error::Error`, one async method per operation,
///   and one async method per security scheme declared under `components`.
/// - `err_to_string`, which renders an error followed by its `source()` chain.
/// - One async handler per operation. It extracts the query/path parameters,
///   authenticates the request when the operation has security scopes, calls the trait
///   method, and serializes each documented status as JSON. Unknown errors become `500`.
/// - `config`, which registers every handler on a `ServiceConfig` and installs a JSON
///   body error handler that answers `400` with the error text.
///
/// Operation ids are snake-cased everywhere so that the generated `base::<operation_id>`
/// paths resolve to the modules emitted by [`generate_paths`].
///
/// # Panics
///
/// Panics if an operation has no `operationId`. Also panics (via `todo!`) if an
/// operation parameter is a `$ref`.
///
/// # Errors
///
/// Returns any error produced while writing to `w`.
pub fn generate_server_mod<W: Write>(w: &mut W, spec: &Spec) -> io::Result<()> {
    /// Snake-cased `operationId`. This is the name of the operation's module, trait
    /// method and handler function.
    fn operation_ident(operation_id: Option<&str>) -> String {
        let operation_id = operation_id.expect("every operation must declare an `operationId`");
        heck::AsSnakeCase(operation_id).to_string()
    }

    let mut scope = Scope::new();
    let server_mod = scope
        .new_module("server")
        .attr("allow(unused_assignments, unused_imports, unused_variables)")
        .import("actix_web::error", "InternalError")
        .import("actix_web::http", "StatusCode")
        .import("actix_web::web", "*")
        .import(
            "actix_web",
            "{HttpRequest, HttpResponse, HttpResponseBuilder, Responder}",
        )
        .import("async_trait", "async_trait");
    server_mod
        .new_module("base")
        .attr("allow(unused_imports, non_snake_case)")
        .scope()
        .raw("pub use super::super::*;");

    // The API trait: one method per operation plus one per security scheme.
    let api_trait_ident = heck::AsPascalCase(&spec.info.title).to_string();
    let api_trait = server_mod.new_trait(&api_trait_ident);
    api_trait.attr("async_trait(?Send)");
    api_trait.associated_type("AuthorizedData");
    api_trait
        .associated_type("Error")
        .bound("std::error::Error");
    for (path, item) in &spec.paths {
        for (method, operation) in &item.operations {
            let operation_id = operation_ident(operation.operation_id.as_deref());
            let mut doc = format!("Handler for the `{:?}` method for `{}`", method, path);
            if let Some(ref desc) = operation.description {
                doc.push_str("\n\n");
                doc.push_str(desc);
            }
            let operation_fn = api_trait
                .new_fn(&operation_id)
                .doc(doc)
                .set_async(true)
                .arg_ref_self();
            let has_security = operation
                .security
                .as_ref()
                .map_or(false, |s| !s.scopes.is_empty());
            let request_ty = if has_security {
                "Self::AuthorizedData"
            } else {
                "HttpRequest"
            };
            operation_fn.arg("request", request_ty);
            operation_fn.arg("parameters", format!("base::{}::Parameters", operation_id));
            operation_fn.ret(format!(
                "Result<base::{}::Success, base::{}::Error<Self::Error>>",
                operation_id, operation_id
            ));
        }
    }
    if let Some(ref components) = spec.components {
        for (scheme, _) in components.security_schemes.iter().flatten() {
            api_trait
                .new_fn(heck::AsSnakeCase(scheme).to_string())
                .doc(format!("Handler for the `{}` security scheme", scheme))
                .set_async(true)
                .arg_ref_self()
                .arg("request", "HttpRequest")
                .ret("Result<Self::AuthorizedData, Self::Error>");
        }
    }

    // Renders an error and each `source()` in its chain on its own line.
    let err_to_string_body = r#"let mut errors_str = Vec::new();
let mut current_err = err.source();
while let Some(err) = current_err {
    errors_str.push(err.to_string());
    current_err = err.source();
}
format!(
    "error: {}\n\ncaused by:\n\t{}",
    err,
    errors_str.as_slice().join("\n\t")
)"#;
    let err_to_string = server_mod
        .new_fn("err_to_string")
        .arg("err", "&dyn std::error::Error")
        .ret("String");
    for line in err_to_string_body.lines() {
        err_to_string.line(line);
    }

    // One actix-web handler per operation.
    for (_, item) in &spec.paths {
        for (_, operation) in &item.operations {
            let operation_id = operation_ident(operation.operation_id.as_deref());
            let operation_fn = server_mod
                .new_fn(&operation_id)
                .set_async(true)
                .generic(format!("Server: {}", api_trait_ident))
                .ret("impl Responder");
            let mut doc = String::new();
            if let Some(ref desc) = operation.summary {
                doc.push_str(desc);
            }
            if let Some(ref desc) = operation.description {
                if !doc.is_empty() {
                    doc.push('\n');
                }
                doc.push_str(desc);
            }
            if !doc.is_empty() {
                operation_fn.doc(doc);
            }
            operation_fn
                .arg("request", "HttpRequest")
                .arg("server", "Data<Server>");

            // Mirrors `generate_paths`: only parameters with a schema become fields of
            // the generated `Query`/`Path` structs, so only those count here.
            let has_params_in = |location: &str| {
                operation.parameters.as_ref().map_or(false, |params| {
                    params.iter().any(|param| match param {
                        ObjectOrReference::Object(p) => {
                            p.location == location && p.schema.is_some()
                        }
                        ObjectOrReference::Ref { .. } => todo!(),
                    })
                })
            };
            let has_query = has_params_in("query");
            let has_path = has_params_in("path");
            if has_query {
                operation_fn.arg("query", format!("Query<base::{}::Query>", operation_id));
            }
            if has_path {
                operation_fn.arg("path", format!("Path<base::{}::Path>", operation_id));
            }
            operation_fn.line(format!("use base::{}::*;", operation_id));

            // `Parameters::new` takes `query` then `path`, each only when present.
            let mut new_args = Vec::new();
            if has_query {
                new_args.push("query.into_inner()");
            }
            if has_path {
                new_args.push("path.into_inner()");
            }
            if new_args.is_empty() {
                operation_fn.line("let parameters = Parameters;");
            } else {
                operation_fn.line(format!(
                    "let parameters_res = Parameters::new({})\
                     .map_err(|e| HttpResponse::BadRequest().body(err_to_string(&e)));",
                    new_args.join(", ")
                ));
                operation_fn.line(
                    "let parameters = match parameters_res { Ok(x) => x, Err(e) => return e };\n",
                );
            }

            let has_security = operation
                .security
                .as_ref()
                .map_or(false, |s| !s.scopes.is_empty());
            if has_security {
                // NOTE(review): the `bearer_auth` scheme name is hard-coded. It only
                // exists on the trait if the spec declares a scheme whose snake-cased
                // name is `bearer_auth`.
                operation_fn.line(
                    "let request_res = server.bearer_auth(request).await\
                     .map_err(|e| HttpResponse::Unauthorized().body(err_to_string(&e)));",
                );
                operation_fn
                    .line("let request = match request_res { Ok(x) => x, Err(e) => return e };\n");
            }

            operation_fn.line(format!(
                "match server.{}(request, parameters).await {{",
                operation_id
            ));
            for (status, _) in &operation.responses {
                let (res_variant, enum_ident) = if status.starts_with('2') {
                    ("Ok", "Success")
                } else {
                    ("Err", "Error")
                };
                operation_fn.line(format!(
                    "{}({}::Status{}(response)) => \
                     HttpResponseBuilder::new(StatusCode::from_u16({}).unwrap()).json(&response),",
                    res_variant, enum_ident, status, status
                ));
            }
            operation_fn.line(
                "Err(Error::Unknown(err)) => HttpResponse::InternalServerError().body(err_to_string(&err)),",
            );
            operation_fn.line("}");
        }
    }

    // `config` registers every handler; handler names are the snake-cased ids used above.
    let config_fn = server_mod
        .new_fn("config")
        .generic(format!("Server: {} + 'static", api_trait_ident))
        .arg("app", "&mut ServiceConfig");
    config_fn.line("app");
    for (path, item) in &spec.paths {
        for (method, operation) in &item.operations {
            let operation_id = operation_ident(operation.operation_id.as_deref());
            config_fn.line(format!(
                r#"    .service(resource("{}").route({}().to({}::<Server>)))"#,
                path,
                method.as_str(),
                operation_id
            ));
        }
    }
    // Build the 400 response once. In actix-web 4, converting the same builder again
    // after `.body()` panics ("cannot reuse response builder").
    config_fn.line(
        r#"    .app_data(
        actix_web::web::JsonConfig::default().error_handler(|err, _| {
            let response = HttpResponse::BadRequest().body(err_to_string(&err));
            InternalError::from_response(err, response).into()
        }),
    );"#,
    );
    write!(w, "{}", scope.to_string())
}