use crate::printer::*;
use crate::rust::lib_gen::{Module, ModuleDef, ModuleName};
use crate::rust::printer::*;
use crate::{Error, Result};
use convert_case::{Case, Casing};
use openapiv3::{OpenAPI, ReferenceOr, SecurityScheme};
fn client() -> TreePrinter<RustContext> {
rust_name("reqwest", "Client")
}
fn url() -> TreePrinter<RustContext> {
rust_name("reqwest", "Url")
}
struct Security {
typedef: TreePrinter<RustContext>,
filed: TreePrinter<RustContext>,
methods: TreePrinter<RustContext>,
}
fn security_gen(name: &str, ref_or_sec: &ReferenceOr<SecurityScheme>) -> Result<Security> {
match ref_or_sec {
ReferenceOr::Reference { .. } => Err(Error::unimplemented(format!(
"$ref in security_schemes[{name}]"
))),
ReferenceOr::Item(sec) => {
match sec {
SecurityScheme::HTTP { scheme, .. } => {
if scheme == "bearer" {
#[rustfmt::skip]
let typedef = unit() +
line("#[derive(Debug, Clone)]") +
line("pub enum Security {") +
indented(
line("Empty,") +
line("Bearer(String),")
) +
line("}") +
NewLine +
line("impl Security {") +
indented(
line("pub fn bearer<S: Into<String>>(s: S) -> Security { Security::Bearer(s.into()) }")
) +
line("}");
let field_name = format!(
"security_{}",
name.from_case(Case::UpperCamel).to_case(Case::Snake)
);
let filed = line(unit() + "pub " + &field_name + ": Security,");
#[rustfmt::skip]
let methods = unit() +
line("pub fn bearer_token(&self) -> Option<&str> {") +
indented(
line(unit() + "match &self." + &field_name + "{") +
indented(
line("Security::Empty => None,") +
line(r#"Security::Bearer(token) => Some(token),"#)
) +
line("}")
) +
line("}");
Ok(Security {
typedef,
filed,
methods,
})
} else {
Err(Error::unimplemented(format!("Unsupported http security_schemes[{name}], only bearer is implemented.")))
}
}
_ => Err(Error::unimplemented(format!(
"Unsupported security_schemes[{name}], only http is implemented."
))),
}
}
}
}
pub fn context_gen(open_api: &OpenAPI) -> Result<Module> {
let security = match &open_api.components {
None => None,
Some(components) => {
let ref_sec = components
.security_schemes
.iter()
.find_map(|(name, ref_or)| {
if ref_or.as_item().is_none() {
Some(name)
} else {
None
}
});
if let Some(name) = ref_sec {
return Err(Error::unimplemented(format!(
"$ref in security_schemes[{name}]"
)));
}
let sec_res: Result<Vec<Security>> = components
.security_schemes
.iter()
.filter(|(_, sec)| {
!matches!(
sec,
ReferenceOr::Item(SecurityScheme::APIKey {
location: openapiv3::APIKeyLocation::Cookie,
..
})
)
})
.map(|(name, sec)| security_gen(name, sec))
.collect();
sec_res?.pop()
}
};
let no_security = security.is_none();
let security = security.unwrap_or(Security {
typedef: unit(),
filed: unit(),
methods: line("pub fn bearer_token(&self) -> Option<&str> { None }"),
});
#[rustfmt::skip]
let code = unit() +
security.typedef +
NewLine +
line("#[derive(Debug, Clone)]") +
line(unit() + "pub struct Context {") +
indented(
line(unit() + "pub client: " + client() + ",") +
line(unit() + "pub base_url: " + url() + ",") +
security.filed
) +
line("}") +
NewLine +
line("impl Context {") +
indented(
security.methods
) +
line("}");
let mut exports = vec!["Context".to_string()];
if !no_security {
exports.push("Security".to_string());
}
Ok(Module {
def: ModuleDef {
name: ModuleName::new("context"),
exports,
},
code: RustContext::new().print_to_string(code),
})
}