use crate::fang::SendSyncOnThreaded;
use crate::prelude::*;
#[cfg(feature = "openapi")]
use crate::openapi;
#[derive(Clone, Debug)]
pub struct BasicAuth<S>
where
S: AsRef<str> + Clone + SendSyncOnThreaded + 'static,
{
pub username: S,
pub password: S,
}
impl<S> BasicAuth<S>
where
    S: AsRef<str> + Clone + SendSyncOnThreaded + 'static,
{
    /// Returns `true` when `username` and `password` both equal the configured
    /// credentials.
    ///
    /// The comparison aims not to reveal, through response timing, how many
    /// leading bytes of a guess were correct or whether the username alone
    /// matched. Each field is compared over all of its bytes, and both fields
    /// are always compared. Only a length mismatch returns early, which leaks
    /// the length but not the content.
    #[inline]
    fn matches(&self, username: &str, password: &str) -> bool {
        /// Byte-wise equality whose work does not depend on the position of the
        /// first differing byte (best-effort; see `std::hint::black_box`).
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            if a.len() != b.len() {
                return false;
            }
            let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
            // Discourage the optimizer from turning the fold into an early exit.
            std::hint::black_box(diff) == 0
        }

        // Non-short-circuit `&`: the password is compared even when the
        // username is wrong, so timing does not reveal a valid username.
        constant_time_eq(self.username.as_ref().as_bytes(), username.as_bytes())
            & constant_time_eq(self.password.as_ref().as_bytes(), password.as_bytes())
    }
}
const _: () = {
    /// Builds the `401 Unauthorized` response returned on every authentication
    /// failure, carrying a `WWW-Authenticate` challenge for the `Basic` scheme.
    fn unauthorized() -> Response {
        Response::Unauthorized()
            .with_headers(|h| h.www_authenticate("Basic realm=\"Secure Area\""))
    }

    /// Returns the token that follows the `Basic` auth-scheme in an
    /// `Authorization` header value, or `None` for any other scheme.
    ///
    /// The scheme name is matched case-insensitively and may be followed by one
    /// or more spaces (RFC 7235 §2.1), so `basic dXNlcjpwYXNz` is accepted too.
    #[inline]
    fn basic_token_of(authorization: &str) -> Option<&str> {
        let (scheme, token) = authorization.split_once(' ')?;
        scheme
            .eq_ignore_ascii_case("Basic")
            .then_some(token.trim_start_matches(' '))
    }

    /// Decodes the `user-id:password` string carried by the request's
    /// `Authorization: Basic ...` header.
    ///
    /// # Errors
    /// Returns the `unauthorized()` response when the header is missing, uses
    /// another scheme, or its token is rejected by `base64_decode_utf8`.
    #[inline]
    fn basic_credential_of(req: &Request) -> Result<String, Response> {
        req.headers
            .authorization()
            .and_then(basic_token_of)
            .and_then(|token| crate::util::base64_decode_utf8(token).ok())
            .ok_or_else(unauthorized)
    }

    /// Authentication flow shared by both `FangAction` impls. It decodes the
    /// credential and splits it at the first `:`, since a user-id may not
    /// contain `:` (RFC 7617 §2). The request is accepted only if
    /// `is_valid(username, password)` holds.
    ///
    /// # Errors
    /// Returns the `unauthorized()` response on any failure.
    #[inline]
    fn authenticate(
        req: &Request,
        is_valid: impl FnOnce(&str, &str) -> bool,
    ) -> Result<(), Response> {
        let credential = basic_credential_of(req)?;
        let (username, password) = credential.split_once(':').ok_or_else(unauthorized)?;
        if is_valid(username, password) {
            Ok(())
        } else {
            Err(unauthorized())
        }
    }

    impl<S> FangAction for BasicAuth<S>
    where
        S: AsRef<str> + Clone + SendSyncOnThreaded + 'static,
    {
        #[inline]
        async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
            authenticate(req, |username, password| self.matches(username, password))
        }

        #[cfg(feature = "openapi")]
        fn openapi_map_operation(&self, operation: openapi::Operation) -> openapi::Operation {
            use openapi::security::SecurityScheme;
            operation.security(SecurityScheme::basic("basicAuth"), &[])
        }
    }

    impl<S, const N: usize> FangAction for [BasicAuth<S>; N]
    where
        S: AsRef<str> + Clone + SendSyncOnThreaded + 'static,
    {
        #[inline]
        async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> {
            authenticate(req, |username, password| {
                // Check every candidate (no short-circuit) so the response time
                // does not reveal which entry, if any, matched.
                self.iter().fold(false, |found, candidate| {
                    found | candidate.matches(username, password)
                })
            })
        }

        #[cfg(feature = "openapi")]
        fn openapi_map_operation(&self, operation: openapi::Operation) -> openapi::Operation {
            use openapi::security::SecurityScheme;
            operation.security(SecurityScheme::basic("basicAuth"), &[])
        }
    }
};
#[cfg(test)]
mod test {
    /// Compile-time check that `BasicAuth` satisfies the `Fang` bound for both
    /// borrowed and owned credential strings.
    #[test]
    fn test_basicauth_fang_bound() {
        use crate::fang::{BoxedFPC, Fang};

        fn assert_fang<T: Fang<BoxedFPC>>() {}

        assert_fang::<super::BasicAuth<&'static str>>();
        assert_fang::<super::BasicAuth<String>>();
    }

    /// End-to-end check: a public route stays reachable, while a nested
    /// `BasicAuth`-guarded route accepts only the configured credentials.
    #[cfg(feature = "__rt_native__")]
    #[test]
    fn test_basicauth() {
        use super::*;
        use crate::testing::*;

        let private = Ohkami::new((
            BasicAuth {
                username: "ohkami",
                password: "password",
            },
            "/".GET(|| async { "Hello, private!" }),
        ));
        let t = Ohkami::new((
            "/hello".GET(|| async { "Hello!" }),
            "/private".By(private),
        ))
        .test();

        // `GET /private` carrying `Authorization: Basic <base64(credential)>`.
        let private_with = |credential: &str| {
            TestRequest::GET("/private").header(
                "Authorization",
                format!("Basic {}", crate::util::base64_encode(credential)),
            )
        };

        crate::__rt__::testing::block_on(async {
            // Unguarded route is served normally.
            let res = t.oneshot(TestRequest::GET("/hello")).await;
            assert_eq!(res.status().code(), 200);
            assert_eq!(res.text(), Some("Hello!"));

            // Guarded route without any credentials is rejected.
            let res = t.oneshot(TestRequest::GET("/private")).await;
            assert_eq!(res.status().code(), 401);

            // Correct credentials are let through.
            let res = t.oneshot(private_with("ohkami:password")).await;
            assert_eq!(res.status().code(), 200);
            assert_eq!(res.text(), Some("Hello, private!"));

            // Wrong password is rejected.
            let res = t.oneshot(private_with("ohkami:wrong")).await;
            assert_eq!(res.status().code(), 401);
        });
    }
}