use super::{
Error, get_bool_conf, get_hash_key, get_plugin_factory, get_str_conf,
get_str_slice_conf,
};
use async_trait::async_trait;
use bytes::Bytes;
use ctor::ctor;
use http::HeaderValue;
use http::StatusCode;
use humantime::parse_duration;
use pingap_config::{PluginCategory, PluginConf};
use pingap_core::{Ctx, HttpResponse, Plugin, PluginStep, RequestPluginResult};
use pingap_util::base64_decode;
use pingora::proxy::Session;
use std::borrow::Cow;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use tracing::debug;
type Result<T, E = Error> = std::result::Result<T, E>;
/// Request plugin that checks the `Authorization` header against a fixed
/// list of HTTP Basic credentials.
pub struct BasicAuth {
    /// Configuration hash (from `get_hash_key`), exposed through `config_key`.
    hash_value: String,
    /// Step at which the plugin acts; other steps are skipped.
    plugin_step: PluginStep,
    /// Accepted header values, each stored as the bytes of `"Basic <base64>"`.
    authorizations: Vec<Vec<u8>>,
    /// When true, the `Authorization` header is removed after a successful check.
    hide_credentials: bool,
    /// Optional pause applied before answering a request with wrong credentials.
    delay: Option<Duration>,
    /// Response sent when the request has no `Authorization` header.
    miss_authorization_resp: HttpResponse,
    /// Response sent when the `Authorization` header matches no credential.
    unauthorized_resp: HttpResponse,
}
impl TryFrom<&PluginConf> for BasicAuth {
type Error = Error;
fn try_from(value: &PluginConf) -> Result<Self> {
let hash_value = get_hash_key(value);
let delay = get_str_conf(value, "delay");
let delay = if !delay.is_empty() {
let d = parse_duration(&delay).map_err(|e| Error::Invalid {
category: PluginCategory::KeyAuth.to_string(),
message: e.to_string(),
})?;
Some(d)
} else {
None
};
let mut authorizations = vec![];
for item in get_str_slice_conf(value, "authorizations").iter() {
let _ = base64_decode(item).map_err(|e| Error::Base64Decode {
category: PluginCategory::BasicAuth.to_string(),
source: e,
})?;
authorizations.push(format!("Basic {item}").as_bytes().to_vec());
}
if authorizations.is_empty() {
return Err(Error::Invalid {
category: PluginCategory::BasicAuth.to_string(),
message: "basic authorizations can't be empty".to_string(),
});
}
let miss_authorization_headers = if let Ok(value) =
HeaderValue::from_str(
r###"Basic realm="Access to the staging site""###,
) {
Some(vec![(http::header::WWW_AUTHENTICATE, value)])
} else {
None
};
let unauthorized_headers = if let Ok(value) = HeaderValue::from_str(
r###"Basic realm="Access to the staging site""###,
) {
Some(vec![(http::header::WWW_AUTHENTICATE, value)])
} else {
None
};
let params = Self {
hash_value,
plugin_step: PluginStep::Request,
delay,
hide_credentials: get_bool_conf(value, "hide_credentials"),
authorizations,
miss_authorization_resp: HttpResponse {
status: StatusCode::UNAUTHORIZED,
headers: miss_authorization_headers,
body: Bytes::from_static(b"Authorization is missing"),
..Default::default()
},
unauthorized_resp: HttpResponse {
status: StatusCode::UNAUTHORIZED,
headers: unauthorized_headers,
body: Bytes::from_static(b"Invalid user or password"),
..Default::default()
},
};
Ok(params)
}
}
impl BasicAuth {
    /// Creates the plugin from `params`, logging the configuration at debug level.
    ///
    /// # Errors
    /// Returns whatever the `TryFrom<&PluginConf>` conversion reports.
    pub fn new(params: &PluginConf) -> Result<Self> {
        debug!(params = params.to_string(), "new basic auth plugin");
        <Self as TryFrom<&PluginConf>>::try_from(params)
    }
}
/// Compares two byte slices without exiting early at the first differing byte.
///
/// The run time depends only on the slice lengths, not on where the contents
/// differ. This keeps a client from recovering a credential byte by byte via
/// response timing. It is best effort: no compiler barrier is applied.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

#[async_trait]
impl Plugin for BasicAuth {
    /// Returns the configuration hash that identifies this plugin instance.
    #[inline]
    fn config_key(&self) -> Cow<'_, str> {
        Cow::Borrowed(&self.hash_value)
    }

    /// Checks the request's `Authorization` header.
    ///
    /// - Returns `Skipped` when `step` is not the plugin's configured step.
    /// - Responds with `miss_authorization_resp` when the header is absent or empty.
    /// - Responds with `unauthorized_resp` when the header matches no configured
    ///   credential, first sleeping for `delay` if one is configured.
    /// - Otherwise returns `Continue`, first removing the header when
    ///   `hide_credentials` is set.
    #[inline]
    async fn handle_request(
        &self,
        step: PluginStep,
        session: &mut Session,
        _ctx: &mut Ctx,
    ) -> pingora::Result<RequestPluginResult> {
        if step != self.plugin_step {
            return Ok(RequestPluginResult::Skipped);
        }
        let value = session.get_header_bytes(http::header::AUTHORIZATION);
        if value.is_empty() {
            return Ok(RequestPluginResult::Respond(
                self.miss_authorization_resp.clone(),
            ));
        }
        // The header is client-controlled, so use a comparison whose timing
        // does not reveal how many leading bytes matched.
        let authorized = self
            .authorizations
            .iter()
            .any(|auth| constant_time_eq(auth, value));
        if !authorized {
            // The optional delay slows down brute-force guessing.
            if let Some(d) = self.delay {
                sleep(d).await;
            }
            return Ok(RequestPluginResult::Respond(
                self.unauthorized_resp.clone(),
            ));
        }
        if self.hide_credentials {
            // Keep the credentials from being forwarded upstream.
            session
                .req_header_mut()
                .remove_header(&http::header::AUTHORIZATION);
        }
        Ok(RequestPluginResult::Continue)
    }
}
/// Registers the `basic_auth` constructor with the plugin factory at load
/// time (run by `#[ctor]`).
#[ctor]
fn init() {
    get_plugin_factory().register("basic_auth", |params| {
        let plugin = BasicAuth::new(params)?;
        Ok(Arc::new(plugin))
    });
}
#[cfg(test)]
mod tests {
    use super::{BasicAuth, Plugin};
    use pingap_config::PluginConf;
    use pingap_core::{Ctx, PluginStep, RequestPluginResult};
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use std::time::Duration;
    use tokio_test::io::Builder;

    /// Parses a TOML snippet into a `PluginConf`, panicking on bad input.
    fn parse_conf(data: &str) -> PluginConf {
        toml::from_str::<PluginConf>(data).unwrap()
    }

    /// Builds an HTTP/1.1 session whose request carries the given
    /// `Authorization` header value, and reads its request head.
    async fn session_with_authorization(authorization: &str) -> Session {
        let request = format!(
            "GET /vicanso/pingap?size=1 HTTP/1.1\r\nAuthorization: {authorization}\r\n\r\n"
        );
        let mock_io = Builder::new().read(request.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        session
    }

    #[test]
    fn test_basic_auth_params() {
        let params = BasicAuth::try_from(&parse_conf(
            r###"
authorizations = [
"MTIz",
"NDU2",
]
delay = "10s"
"###,
        ))
        .unwrap();
        assert_eq!("request", params.plugin_step.to_string());
        let joined = params
            .authorizations
            .iter()
            .map(|item| String::from_utf8_lossy(item))
            .collect::<Vec<_>>()
            .join(",");
        assert_eq!("Basic MTIz,Basic NDU2", joined);
        assert_eq!(Duration::from_secs(10), params.delay.unwrap());
        assert_eq!("AC7E9E03", params.config_key());

        // An authorization that is not valid base64 is rejected.
        let err = BasicAuth::try_from(&parse_conf(
            r###"
authorizations = [
"1"
]
"###,
        ))
        .err()
        .unwrap();
        assert_eq!(
            "Plugin basic_auth, base64 decode error Invalid input length: 1",
            err.to_string()
        );
    }

    #[tokio::test]
    async fn test_basic_auth() {
        let auth = BasicAuth::new(&parse_conf(
            r###"
authorizations = [
"YWRtaW46MTIzMTIz"
]
hide_credentials = true
"###,
        ))
        .unwrap();

        // Matching credentials pass, and hide_credentials strips the header.
        let mut session =
            session_with_authorization("Basic YWRtaW46MTIzMTIz").await;
        let result = auth
            .handle_request(PluginStep::Request, &mut session, &mut Ctx::default())
            .await
            .unwrap();
        assert!(result == RequestPluginResult::Continue);
        assert!(!session.req_header().headers.contains_key("Authorization"));

        // Wrong credentials produce a 401 response.
        let mut session =
            session_with_authorization("Basic YWRtaW46MTIzMTIa").await;
        let result = auth
            .handle_request(PluginStep::Request, &mut session, &mut Ctx::default())
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(resp.status, http::StatusCode::UNAUTHORIZED);
    }
}