use super::{Error, get_hash_key, get_plugin_factory, get_str_conf};
use async_trait::async_trait;
use bytes::Bytes;
use cookie::Cookie;
use ctor::ctor;
use http::{HeaderValue, Method, StatusCode, header};
use humantime::parse_duration;
use nanoid::nanoid;
use pingap_config::{PluginCategory, PluginConf};
use pingap_core::{
Ctx, HTTP_HEADER_NO_STORE, HttpResponse, Plugin, PluginStep,
RequestPluginResult,
};
use pingap_core::{get_cookie_value, new_internal_error, now_sec};
use pingap_util::base64_encode;
use pingora::proxy::Session;
use sha2::{Digest, Sha256};
use std::borrow::Cow;
use std::sync::Arc;
use tracing::debug;
type Result<T, E = Error> = std::result::Result<T, E>;
/// CSRF protection plugin.
///
/// A signed token is issued as a cookie from `token_path`. State-changing
/// requests must then echo the same token in a request header. The header
/// and the cookie share the name stored in `name`.
pub struct Csrf {
    /// Configuration hash (from `get_hash_key`) reported by `config_key`.
    hash_value: String,
    /// Step in which the plugin acts; the constructor sets `PluginStep::Request`.
    plugin_step: PluginStep,
    /// Request path that is answered directly with a freshly issued token cookie.
    token_path: String,
    /// Secret appended to the token payload before SHA-256 signing.
    key: String,
    /// Cookie name and request-header name carrying the token
    /// (defaults to `x-csrf-token`).
    name: String,
    /// Token lifetime in seconds; `0` disables both expiry checks and the
    /// cookie `Max-Age`.
    ttl: u64,
    /// Response (401) sent when the token is missing or invalid.
    unauthorized_resp: HttpResponse,
}
impl TryFrom<&PluginConf> for Csrf {
type Error = Error;
fn try_from(value: &PluginConf) -> Result<Self> {
let hash_value = get_hash_key(value);
let mut csrf = Self {
hash_value,
plugin_step: PluginStep::Request,
name: get_str_conf(value, "name"),
token_path: get_str_conf(value, "token_path"),
key: get_str_conf(value, "key"),
ttl: 0,
unauthorized_resp: HttpResponse {
status: StatusCode::UNAUTHORIZED,
body: Bytes::from("Csrf token is empty or invalid"),
..Default::default()
},
};
if csrf.name.is_empty() {
csrf.name = "x-csrf-token".to_string();
}
let ttl = get_str_conf(value, "ttl");
if !ttl.is_empty() {
let ttl = parse_duration(&ttl).map_err(|e| Error::Invalid {
category: PluginCategory::Csrf.to_string(),
message: e.to_string(),
})?;
csrf.ttl = ttl.as_secs();
}
if csrf.token_path.is_empty() {
return Err(Error::Invalid {
category: PluginCategory::Csrf.to_string(),
message: "Token path is not allowed empty".to_string(),
});
}
if csrf.key.is_empty() {
return Err(Error::Invalid {
category: PluginCategory::Csrf.to_string(),
message: "Key is not allowed empty".to_string(),
});
}
Ok(csrf)
}
}
impl Csrf {
    /// Creates the plugin from `params`, first logging the configuration at
    /// debug level.
    ///
    /// # Errors
    /// Propagates any validation error from the `TryFrom<&PluginConf>` impl.
    pub fn new(params: &PluginConf) -> Result<Self> {
        debug!(params = params.to_string(), "new csrf plugin");
        Self::try_from(params)
    }
}
/// Issues a new signed token of the form `<id>.<hex issue time>.<signature>`.
///
/// * `<id>` is a 12-character nanoid.
/// * The issue time is `now_sec()` rendered in lowercase hex.
/// * `<signature>` is the base64 encoding of SHA-256 over
///   `<id>.<hex issue time>` immediately followed by `key`.
#[inline]
fn generate_token(key: &str) -> String {
    let payload = format!("{}.{:x}", nanoid!(12), now_sec());
    // Hashing the concatenation is equivalent to feeding payload and key
    // to the hasher one after the other.
    let signature = base64_encode(Sha256::digest(format!("{payload}{key}").as_bytes()));
    format!("{payload}.{signature}")
}
/// Validates a token produced by `generate_token`.
///
/// Rules:
/// * The token must have exactly three `.`-separated parts
///   (`<id>.<hex issue time>.<signature>`).
/// * When `ttl > 0`, a token is rejected if it is more than `ttl` seconds
///   old, or if its issue time lies in the future.
/// * An unparseable issue time is read as `0`, so the token counts as
///   expired.
/// * The signature is recomputed as base64(SHA-256(`<id>.<hex time>` +
///   `key`)) and compared in constant time.
#[inline]
fn validate_token(key: &str, ttl: u64, value: &str) -> bool {
    let mut parts = value.split('.');
    // Exactly three parts: a fourth `next()` must yield `None`.
    let (Some(id), Some(issued_hex), Some(signature), None) =
        (parts.next(), parts.next(), parts.next(), parts.next())
    else {
        return false;
    };
    if ttl > 0 {
        let issued_at = u64::from_str_radix(issued_hex, 16).unwrap_or_default();
        // The timestamp is client-supplied: plain `-` would underflow on a
        // future value (panic in debug, wrap in release). A future issue
        // time is treated as invalid.
        match now_sec().checked_sub(issued_at) {
            Some(age) if age <= ttl => {}
            _ => return false,
        }
    }
    let mut hasher = Sha256::new();
    hasher.update(format!("{id}.{issued_hex}").as_bytes());
    hasher.update(key.as_bytes());
    let expected = base64_encode(hasher.finalize());
    constant_time_eq(signature.as_bytes(), expected.as_bytes())
}

/// Compares two byte strings without stopping at the first mismatch, so the
/// comparison time does not reveal how much of a forged signature matched.
/// Length is not secret here (the encoded SHA-256 has a fixed length), so
/// an early `false` on length mismatch is acceptable.
#[inline]
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
#[async_trait]
impl Plugin for Csrf {
    /// Returns the configuration hash computed by `get_hash_key` when the
    /// plugin was built.
    #[inline]
    fn config_key(&self) -> Cow<'_, str> {
        Cow::Borrowed(&self.hash_value)
    }

    /// Runs the CSRF check for one request.
    ///
    /// * Outside `plugin_step` the plugin is skipped.
    /// * A request for `token_path` is answered directly with `204 No
    ///   Content`. The response carries the no-store header and a
    ///   `Set-Cookie` with a freshly signed token, path `/`, plus `Max-Age`
    ///   when a ttl is configured.
    /// * `GET`, `HEAD` and `OPTIONS` requests are skipped.
    /// * Any other request must send a non-empty header named `name`. It
    ///   must equal the cookie of the same name and pass `validate_token`;
    ///   otherwise the configured 401 response is returned.
    ///
    /// # Errors
    /// Returns an error (built by `new_internal_error` with status 400) if
    /// the cookie cannot be encoded as a header value.
    async fn handle_request(
        &self,
        step: PluginStep,
        session: &mut Session,
        _ctx: &mut Ctx,
    ) -> pingora::Result<RequestPluginResult> {
        if step != self.plugin_step {
            return Ok(RequestPluginResult::Skipped);
        }

        if session.req_header().uri.path() == self.token_path {
            let token = generate_token(&self.key);
            let mut builder = Cookie::build((&self.name, &token)).path("/");
            if self.ttl > 0 {
                // Saturate rather than `as i64`, which would wrap an
                // oversized ttl into a negative Max-Age.
                let max_age = i64::try_from(self.ttl).unwrap_or(i64::MAX);
                builder = builder.max_age(cookie::time::Duration::seconds(max_age));
            }
            let set_cookie = (
                header::SET_COOKIE,
                HeaderValue::from_str(&builder.build().to_string())
                    .map_err(|e| new_internal_error(400, e))?,
            );
            return Ok(RequestPluginResult::Respond(HttpResponse {
                status: StatusCode::NO_CONTENT,
                headers: Some(vec![HTTP_HEADER_NO_STORE.clone(), set_cookie]),
                ..Default::default()
            }));
        }

        // These methods are not checked (HTTP "safe" methods).
        if matches!(
            session.req_header().method,
            Method::GET | Method::HEAD | Method::OPTIONS
        ) {
            return Ok(RequestPluginResult::Skipped);
        }

        let header_value = session.get_header_bytes(&self.name);
        if header_value.is_empty() {
            return Ok(RequestPluginResult::Respond(
                self.unauthorized_resp.clone(),
            ));
        }
        let header_value = String::from_utf8_lossy(header_value);
        // The header must match the cookie, and the token must carry a
        // valid, unexpired signature.
        let cookie_value =
            get_cookie_value(session.req_header(), &self.name).unwrap_or_default();
        if header_value != cookie_value
            || !validate_token(&self.key, self.ttl, &header_value)
        {
            return Ok(RequestPluginResult::Respond(
                self.unauthorized_resp.clone(),
            ));
        }
        Ok(RequestPluginResult::Continue)
    }
}
/// Registers the `csrf` plugin constructor with the plugin factory at load
/// time (via `#[ctor]`).
#[ctor]
fn init() {
    let factory = get_plugin_factory();
    factory.register("csrf", |params| {
        let plugin = Csrf::new(params)?;
        Ok(Arc::new(plugin))
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use cookie::Cookie;
    use pingap_config::PluginConf;
    use pingap_core::{Ctx, PluginStep};
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use std::str::FromStr;
    use tokio_test::io::Builder;

    /// Configuration parsing cases:
    /// * a valid config;
    /// * an unknown `ttl` unit;
    /// * a missing `token_path`;
    /// * a missing `key`.
    #[test]
    fn test_csrf_params() {
        // Valid config: "1h" is stored as 3600 seconds.
        let params = Csrf::try_from(
            &toml::from_str::<PluginConf>(
                r###"
token_path = "/csrf-token"
key = "WjrXUG47wu"
ttl = "1h"
"###,
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!("/csrf-token", params.token_path);
        assert_eq!("WjrXUG47wu", params.key);
        assert_eq!(3600, params.ttl);
        // Unknown duration unit: humantime's error text is surfaced verbatim.
        let result = Csrf::try_from(
            &toml::from_str::<PluginConf>(
                r###"
token_path = "/csrf-token"
key = "WjrXUG47wu"
ttl = "1a"
"###,
            )
            .unwrap(),
        );
        assert_eq!(
            r#"Plugin csrf invalid, message: unknown time unit "a", supported units: ns, us/µs, ms, sec, min, hours, days, weeks, months, years (and few variations)"#,
            result.err().unwrap().to_string()
        );
        // Missing token_path is rejected.
        let result = Csrf::try_from(
            &toml::from_str::<PluginConf>(
                r###"
key = "WjrXUG47wu"
"###,
            )
            .unwrap(),
        );
        assert_eq!(
            "Plugin csrf invalid, message: Token path is not allowed empty",
            result.err().unwrap().to_string()
        );
        // Missing key is rejected.
        let result = Csrf::try_from(
            &toml::from_str::<PluginConf>(
                r###"
token_path = "/csrf-token"
"###,
            )
            .unwrap(),
        );
        assert_eq!(
            "Plugin csrf invalid, message: Key is not allowed empty",
            result.err().unwrap().to_string()
        );
    }

    /// Round-trips a generated token through `validate_token`.
    #[test]
    fn test_generate_token() {
        let key = "123";
        let value = generate_token(key);
        // A fresh token validates with a 10 s ttl.
        assert_eq!(true, validate_token(key, 10, &value));
        // Appending ":1" alters the signature part, so validation must fail.
        assert_eq!(false, validate_token(key, 10, &format!("{value}:1")));
    }

    /// End-to-end flow against mocked HTTP/1.1 sessions:
    /// 1. fetch a token;
    /// 2. POST with a bad header and no cookie;
    /// 3. POST with matching header and cookie.
    #[tokio::test]
    async fn test_csrf() {
        let csrf = Csrf::new(
            &toml::from_str::<PluginConf>(
                r###"
token_path = "/csrf-token"
key = "WjrXUG47wu"
ttl = "1h"
"###,
            )
            .unwrap(),
        )
        .unwrap();
        // 1. GET the token path: the plugin answers directly with Set-Cookie.
        let headers = [""].join("\r\n");
        let input_header =
            format!("GET /csrf-token HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = csrf
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        // Header index 1 is Set-Cookie; index 0 is the no-store header.
        let binding = resp.headers.unwrap();
        let cookie = binding[1].1.to_str().unwrap();
        let c = Cookie::from_str(cookie).unwrap();
        assert_eq!("x-csrf-token", c.name());
        // NOTE(review): 66 = 12 (nanoid) + 1 + hex timestamp + 1 + base64
        // digest. This presumably assumes an 8-hex-digit timestamp and a
        // 44-char padded base64 encoding — confirm.
        assert_eq!(66, c.value().len());
        // 2. POST with an invalid header token and no cookie -> 401.
        let headers = [format!("x-csrf-token:{}", "123")].join("\r\n");
        let input_header = format!("POST / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = csrf
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(401, resp.status.as_u16());
        // 3. POST echoing the issued token in both header and cookie -> Continue.
        let headers = [
            format!("x-csrf-token: {}", c.value()),
            format!("Cookie: x-csrf-token={}", c.value()),
        ]
        .join("\r\n");
        let input_header = format!("POST / HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let result = csrf
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        assert_eq!(true, result == RequestPluginResult::Continue);
    }
}