use super::{Error, get_hash_key, get_plugin_factory, get_str_conf};
use async_trait::async_trait;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use bytes::{Bytes, BytesMut};
use ctor::ctor;
use http::StatusCode;
use humantime::parse_duration;
use pingap_config::{PluginCategory, PluginConf};
use pingap_core::{
Ctx, ModifyResponseBody, Plugin, PluginStep, RequestPluginResult,
ResponseBodyPluginResult, ResponsePluginResult,
};
use pingap_core::{
HTTP_HEADER_CONTENT_JSON, HTTP_HEADER_TRANSFER_CHUNKED, HttpResponse,
};
use pingora::http::ResponseHeader;
use pingora::proxy::Session;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::sync::Arc;
use std::time::Duration;
use substring::Substring;
use tokio::time::sleep;
use tracing::debug;
/// Key under which the token-signing body modifier is registered in (and looked up
/// from) the request `Ctx`.
const PLUGIN_ID: &str = "_jwt_";
/// Module-local result alias whose error type defaults to this plugin module's `Error`.
type Result<T, E = Error> = std::result::Result<T, E>;
/// Plugin that verifies HMAC-signed JWTs on incoming requests and, on
/// `auth_path`, signs the upstream response body into a new token.
pub struct JwtAuth {
    /// Step at which verification runs; the `TryFrom` conversion always sets
    /// `PluginStep::Request`.
    plugin_step: PluginStep,
    /// Request path that bypasses verification and whose response body is
    /// turned into a token.
    auth_path: String,
    /// Shared HMAC key used both to verify and to sign tokens; must be non-empty.
    secret: String,
    /// Request header carrying the token (a leading `Bearer ` is stripped).
    /// Only one source is consulted per request: header if configured, else
    /// cookie, else query.
    header: Option<String>,
    /// Query parameter carrying the token.
    query: Option<String>,
    /// Cookie carrying the token.
    cookie: Option<String>,
    /// Algorithm for issued tokens: `"HS512"` selects HMAC-SHA512, any other
    /// value HMAC-SHA256. Verification uses the incoming token header's `alg`.
    algorithm: String,
    /// Optional pause before answering a request whose signature does not verify.
    delay: Option<Duration>,
    /// 401 response template; its body is replaced with a reason per failure.
    unauthorized_resp: HttpResponse,
    /// Configuration hash returned by `config_key`.
    hash_value: String,
}
impl TryFrom<&PluginConf> for JwtAuth {
type Error = Error;
fn try_from(value: &PluginConf) -> Result<Self> {
let hash_value = get_hash_key(value);
let header = get_str_conf(value, "header");
let query = get_str_conf(value, "query");
let cookie = get_str_conf(value, "cookie");
if header.is_empty() && query.is_empty() && cookie.is_empty() {
return Err(Error::Invalid {
category: PluginCategory::Jwt.to_string(),
message: "Jwt key or key type is not allowed empty".to_string(),
});
}
let header = if header.is_empty() {
None
} else {
Some(header)
};
let query = if query.is_empty() { None } else { Some(query) };
let cookie = if cookie.is_empty() {
None
} else {
Some(cookie)
};
let delay = get_str_conf(value, "delay");
let delay = if !delay.is_empty() {
let d = parse_duration(&delay).map_err(|e| Error::Invalid {
category: PluginCategory::KeyAuth.to_string(),
message: e.to_string(),
})?;
Some(d)
} else {
None
};
let params = Self {
hash_value,
plugin_step: PluginStep::Request,
secret: get_str_conf(value, "secret"),
auth_path: get_str_conf(value, "auth_path"),
algorithm: get_str_conf(value, "algorithm"),
delay,
header,
query,
cookie,
unauthorized_resp: HttpResponse {
status: StatusCode::UNAUTHORIZED,
body: Bytes::from_static(b"Invalid or expired jwt"),
..Default::default()
},
};
if params.secret.is_empty() {
return Err(Error::Invalid {
category: PluginCategory::Jwt.to_string(),
message: "Jwt secret is not allowed empty".to_string(),
});
}
Ok(params)
}
}
impl JwtAuth {
    /// Creates a [`JwtAuth`] plugin from `params`, logging the raw configuration
    /// at debug level first.
    ///
    /// # Errors
    /// Propagates any validation error from the `TryFrom<&PluginConf>` conversion.
    pub fn new(params: &PluginConf) -> Result<Self> {
        debug!(params = params.to_string(), "new jwt auth plugin");
        let auth: Self = params.try_into()?;
        Ok(auth)
    }
}
#[derive(Debug, Default, Deserialize, Clone, Serialize)]
struct JwtHeader {
alg: String,
typ: String,
}
#[async_trait]
impl Plugin for JwtAuth {
#[inline]
fn config_key(&self) -> Cow<'_, str> {
Cow::Borrowed(&self.hash_value)
}
#[inline]
async fn handle_request(
&self,
step: PluginStep,
session: &mut Session,
_ctx: &mut Ctx,
) -> pingora::Result<RequestPluginResult> {
if step != self.plugin_step {
return Ok(RequestPluginResult::Skipped);
}
let req_header = session.req_header();
if req_header.uri.path() == self.auth_path {
return Ok(RequestPluginResult::Skipped);
}
let value = if let Some(key) = &self.header {
let value = pingap_core::get_req_header_value(req_header, key)
.unwrap_or_default();
let bearer = "Bearer ";
if value.starts_with(bearer) {
value.substring(bearer.len(), value.len())
} else {
value
}
} else if let Some(key) = &self.cookie {
pingap_core::get_cookie_value(req_header, key).unwrap_or_default()
} else if let Some(key) = &self.query {
pingap_core::get_query_value(req_header, key).unwrap_or_default()
} else {
""
};
if value.is_empty() {
let mut resp = self.unauthorized_resp.clone();
resp.body = Bytes::from_static(b"Jwt authorization is missing");
return Ok(RequestPluginResult::Respond(resp));
}
let arr: Vec<&str> = value.split('.').collect();
if arr.len() != 3 {
let mut resp = self.unauthorized_resp.clone();
resp.body =
Bytes::from_static(b"Jwt authorization format is invalid");
return Ok(RequestPluginResult::Respond(resp));
}
let jwt_header = serde_json::from_slice::<JwtHeader>(
&URL_SAFE_NO_PAD.decode(arr[0]).unwrap_or_default(),
)
.unwrap_or_default();
let content = format!("{}.{}", arr[0], arr[1]);
let secret = self.secret.as_bytes();
let valid = match jwt_header.alg.as_str() {
"HS512" => {
let hash = hmac_sha512::HMAC::mac(content.as_bytes(), secret);
URL_SAFE_NO_PAD.encode(hash) == arr[2]
},
_ => {
let hash = hmac_sha256::HMAC::mac(content.as_bytes(), secret);
URL_SAFE_NO_PAD.encode(hash) == arr[2]
},
};
if !valid {
if let Some(d) = self.delay {
sleep(d).await;
}
let mut resp = self.unauthorized_resp.clone();
resp.body = Bytes::from_static(b"Jwt authorization is invalid");
return Ok(RequestPluginResult::Respond(resp));
}
let value: serde_json::Value = serde_json::from_slice(
&URL_SAFE_NO_PAD.decode(arr[1]).unwrap_or_default(),
)
.unwrap_or_default();
if let Some(exp) = value.get("exp")
&& exp.as_u64().unwrap_or_default() < pingap_core::now_sec()
{
let mut resp = self.unauthorized_resp.clone();
resp.body = Bytes::from_static(b"Jwt authorization is expired");
return Ok(RequestPluginResult::Respond(resp));
}
Ok(RequestPluginResult::Continue)
}
#[inline]
async fn handle_response(
&self,
session: &mut Session,
ctx: &mut Ctx,
upstream_response: &mut ResponseHeader,
) -> pingora::Result<ResponsePluginResult> {
if session.req_header().uri.path() != self.auth_path {
return Ok(ResponsePluginResult::Unchanged);
}
upstream_response.remove_header(&http::header::CONTENT_LENGTH);
let json = HTTP_HEADER_CONTENT_JSON.clone();
let _ = upstream_response.insert_header(json.0, json.1);
let _ = upstream_response.insert_header(
http::header::TRANSFER_ENCODING,
HTTP_HEADER_TRANSFER_CHUNKED.1.clone(),
);
ctx.add_modify_body_handler(
PLUGIN_ID,
Box::new(Sign {
algorithm: self.algorithm.clone(),
secret: self.secret.clone(),
buffer: BytesMut::new(),
}),
);
Ok(ResponsePluginResult::Modified)
}
fn handle_response_body(
&self,
session: &mut Session,
ctx: &mut Ctx,
body: &mut Option<bytes::Bytes>,
end_of_stream: bool,
) -> pingora::Result<ResponseBodyPluginResult> {
if let Some(modifier) = ctx.get_modify_body_handler(PLUGIN_ID) {
modifier.handle(session, body, end_of_stream)?;
let result = if end_of_stream {
ResponseBodyPluginResult::FullyReplaced
} else {
ResponseBodyPluginResult::PartialReplaced
};
Ok(result)
} else {
Ok(ResponseBodyPluginResult::Unchanged)
}
}
}
/// Response-body modifier that turns the upstream body of the auth path into a
/// signed JWT: the buffered body becomes the token payload.
struct Sign {
    /// HMAC key, copied from the owning `JwtAuth`.
    secret: String,
    /// `"HS512"` selects HMAC-SHA512; any other value signs with HMAC-SHA256.
    algorithm: String,
    /// Accumulates body chunks until end of stream.
    buffer: BytesMut,
}
impl ModifyResponseBody for Sign {
    /// Collects the upstream body and, at end of stream, replaces it with a token.
    ///
    /// Each incoming chunk is appended to `self.buffer`, and the chunk forwarded
    /// downstream is emptied. Once `end_of_stream` is set, the buffered bytes are
    /// base64url-encoded (no padding) as the JWT payload. They are signed with
    /// HMAC-SHA512 when `algorithm == "HS512"` and HMAC-SHA256 otherwise. The body
    /// then becomes `{"token": "<header>.<payload>.<signature>"}`.
    fn handle(
        &mut self,
        _session: &Session,
        body: &mut Option<bytes::Bytes>,
        end_of_stream: bool,
    ) -> pingora::Result<()> {
        if let Some(chunk) = body.as_mut() {
            self.buffer.extend_from_slice(&chunk[..]);
            chunk.clear();
        }
        if end_of_stream {
            let use_hs512 = self.algorithm == "HS512";
            let alg_name = if use_hs512 { "HS512" } else { "HS256" };
            let encoded_header = URL_SAFE_NO_PAD
                .encode(format!(r#"{{"alg": "{alg_name}","typ": "JWT"}}"#));
            let encoded_payload = URL_SAFE_NO_PAD.encode(&self.buffer);
            let signing_input = format!("{encoded_header}.{encoded_payload}");
            let key = self.secret.as_bytes();
            let signature = if use_hs512 {
                URL_SAFE_NO_PAD
                    .encode(hmac_sha512::HMAC::mac(signing_input.as_bytes(), key))
            } else {
                URL_SAFE_NO_PAD
                    .encode(hmac_sha256::HMAC::mac(signing_input.as_bytes(), key))
            };
            let token = format!("{signing_input}.{signature}");
            *body = Some(Bytes::from(format!(r#"{{"token": "{token}"}}"#)));
        }
        Ok(())
    }

    fn name(&self) -> String {
        String::from("jwt_sign")
    }
}
/// Registers the `jwt` plugin constructor with the plugin factory when the
/// binary is loaded.
#[ctor]
fn init() {
    get_plugin_factory().register("jwt", |params| {
        let plugin = JwtAuth::new(params)?;
        Ok(Arc::new(plugin))
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use pingap_config::PluginConf;
    use pingap_core::{Ctx, PluginStep};
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use tokio_test::io::Builder;

    #[test]
    fn test_jwt_auth_params() {
        let params = JwtAuth::try_from(
            &toml::from_str::<PluginConf>(
                r###"
secret = "123123"
cookie = "jwt"
"###,
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!("jwt", params.cookie.unwrap_or_default());
        assert_eq!("123123", params.secret);
        let result = JwtAuth::try_from(
            &toml::from_str::<PluginConf>(
                r###"
cookie = "jwt"
"###,
            )
            .unwrap(),
        );
        assert_eq!(
            "Plugin jwt invalid, message: Jwt secret is not allowed empty",
            result.err().unwrap().to_string()
        );
        let result = JwtAuth::try_from(
            &toml::from_str::<PluginConf>(
                r###"
secret = "123123"
"###,
            )
            .unwrap(),
        );
        assert_eq!(
            "Plugin jwt invalid, message: Jwt key or key type is not allowed empty",
            result.err().unwrap().to_string()
        );
    }

    #[test]
    fn test_new_jwt() {
        let auth = JwtAuth::new(
            &toml::from_str::<PluginConf>(
                r###"
secret = "123123"
cookie = "jwt"
"###,
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!("jwt", auth.cookie.unwrap());
        let auth = JwtAuth::new(
            &toml::from_str::<PluginConf>(
                r###"
secret = "123123"
cookie = "jwt"
auth_path = "/login"
"###,
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!("jwt", auth.cookie.unwrap());
        assert_eq!("/login", auth.auth_path);
    }

    #[tokio::test]
    async fn test_jwt_auth() {
        let auth = JwtAuth::new(
            &toml::from_str::<PluginConf>(
                r###"
secret = "123123"
header = "Authorization"
"###,
            )
            .unwrap(),
        )
        .unwrap();
        let headers = ["Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiIsImFkbWluIjp0cnVlLCJleHAiOjIzNDgwNTUyNjV9.j6sYJ2dCCSxskwPmvHM7WniGCbkT30z2BrjfsuQLFJc"].join("\r\n");
        let input_header = format!("GET / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = auth
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        assert_eq!(true, result == RequestPluginResult::Continue);
        let headers = ["Authorization: Bearer eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiIsImFkbWluIjp0cnVlLCJleHAiOjIzNDgwNTUyNjV9.HxFVxDd5ZiLsD1dWW1AywWMERhqk0Ck9IsdBHyD_1zap3w-waVOmFq0Yt1fWaYmh8HDtXLN6vlTd0HHYIYEGUw"].join("\r\n");
        let input_header = format!("GET / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = auth
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        assert_eq!(true, result == RequestPluginResult::Continue);
        let headers = [""].join("\r\n");
        let input_header = format!("GET / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = auth
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(401, resp.status.as_u16());
        assert_eq!(
            "Jwt authorization is missing",
            std::string::String::from_utf8_lossy(resp.body.as_ref())
        );
        let headers = ["Authorization: Bearer a.b"].join("\r\n");
        let input_header = format!("GET / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = auth
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(401, resp.status.as_u16());
        assert_eq!(
            "Jwt authorization format is invalid",
            std::string::String::from_utf8_lossy(resp.body.as_ref())
        );
        let headers = ["Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiIsImFkbWluIjp0cnVlLCJleHAiOjE3MTcwODQ4MDB9.zz7VHuqt9t6UGLNr5RZdfzvqMDEei"].join("\r\n");
        let input_header = format!("GET / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = auth
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(401, resp.status.as_u16());
        assert_eq!(
            "Jwt authorization is invalid",
            std::string::String::from_utf8_lossy(resp.body.as_ref())
        );
        let headers = ["Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiIsImFkbWluIjp0cnVlLCJleHAiOjE3MTY5MDMyNjV9.PRS-PZafcGsV_rCL8QQfJdOJAvL5fOI_Z14N16JEcng"].join("\r\n");
        let input_header = format!("GET / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = auth
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(401, resp.status.as_u16());
        assert_eq!(
            "Jwt authorization is expired",
            std::string::String::from_utf8_lossy(resp.body.as_ref())
        );
    }

    /// `Sign` must buffer chunks, emit a well-formed token at end of stream, and
    /// produce tokens that `JwtAuth` accepts, for both supported algorithms.
    #[tokio::test]
    async fn test_jwt_sign() {
        for algorithm in ["HS256", "HS512"] {
            let mut sign = Sign {
                secret: "123123".to_string(),
                algorithm: algorithm.to_string(),
                buffer: BytesMut::new(),
            };
            let mock_io = Builder::new()
                .read(b"GET /login HTTP/1.1\r\n\r\n")
                .build();
            let mut session = Session::new_h1(Box::new(mock_io));
            session.read_request().await.unwrap();

            // Intermediate chunks are buffered and swallowed.
            let mut body = Some(Bytes::from_static(br#"{"name":"pingap","#));
            sign.handle(&session, &mut body, false).unwrap();
            assert_eq!(Some(Bytes::new()), body);

            // The final chunk triggers emission of the signed token.
            let mut body = Some(Bytes::from_static(br#""admin":true}"#));
            sign.handle(&session, &mut body, true).unwrap();
            let resp: serde_json::Value =
                serde_json::from_slice(&body.unwrap()).unwrap();
            let token = resp["token"].as_str().unwrap().to_string();
            let parts: Vec<&str> = token.split('.').collect();
            assert_eq!(3, parts.len());
            let header: JwtHeader = serde_json::from_slice(
                &URL_SAFE_NO_PAD.decode(parts[0]).unwrap(),
            )
            .unwrap();
            assert_eq!(algorithm, header.alg);
            assert_eq!("JWT", header.typ);
            assert_eq!(
                br#"{"name":"pingap","admin":true}"#.to_vec(),
                URL_SAFE_NO_PAD.decode(parts[1]).unwrap()
            );

            // A token produced by `Sign` must be accepted by `JwtAuth`.
            let auth = JwtAuth::new(
                &toml::from_str::<PluginConf>(
                    r###"
secret = "123123"
header = "Authorization"
"###,
                )
                .unwrap(),
            )
            .unwrap();
            let input_header = format!(
                "GET / HTTP/1.1\r\nAuthorization: Bearer {token}\r\n\r\n"
            );
            let mock_io =
                Builder::new().read(input_header.as_bytes()).build();
            let mut session = Session::new_h1(Box::new(mock_io));
            session.read_request().await.unwrap();
            let result = auth
                .handle_request(
                    PluginStep::Request,
                    &mut session,
                    &mut Ctx::default(),
                )
                .await
                .unwrap();
            assert_eq!(true, result == RequestPluginResult::Continue);
        }
    }
}