use std::{borrow::Cow, collections::HashMap, sync::Arc};
use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use cookie::Cookie;
use dashmap::DashMap;
use http::StatusCode;
use once_cell::sync::Lazy;
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_proxy::Session;
use provider::{OauthType, OauthUser, Provider};
use crate::{config::RoutePlugin, proxy_server::https_proxy::RouterContext};
use super::{jwt, MiddlewarePlugin};
mod github;
mod workos;
mod provider;
mod secure_cookie;
mod shared;
/// Shared `reqwest` client, constructed lazily on first access.
/// NOTE(review): not referenced by the code visible in this file — presumably
/// used by the provider submodules; confirm.
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(reqwest::Client::new);
/// Pending OAuth `state` tokens, each mapped to the URI the client originally
/// requested. An entry is inserted when a login redirect is issued and removed
/// when a callback carrying the same `state` arrives.
/// NOTE(review): nothing visible in this file expires entries whose callback
/// never arrives, so the map can grow without bound — confirm whether cleanup
/// happens elsewhere.
static OAUTH2_STATE: Lazy<Arc<DashMap<String, String>>> = Lazy::new(|| Arc::new(DashMap::new()));
/// Name of the request cookie that is looked up and decoded as a JWT when
/// checking whether a client is already logged in.
const COOKIE_NAME: &str = "__Secure_Auth_PRK_JWT";
/// Middleware plugin that gates a route behind an OAuth2 login.
///
/// The type holds no state of its own; provider credentials and validation
/// rules are read from the route's plugin configuration on every request.
pub struct Oauth2;
impl Oauth2 {
    /// Creates the plugin instance.
    pub fn new() -> Self {
        Self {}
    }

    /// Checks `user` against the optional `validations` rules from the plugin
    /// configuration by delegating to `shared::validate_user_from_provider`.
    fn is_authorized(user: &OauthUser, validations: Option<&serde_json::Value>) -> bool {
        shared::validate_user_from_provider(user, validations)
    }

    /// Answers the current request with a `307 Temporary Redirect` to the
    /// provider URL built from a freshly generated `state` token.
    ///
    /// The originally requested URI is stored in `OAUTH2_STATE` under that
    /// token so the callback handler can send the client back to it.
    ///
    /// Returns `Ok(true)` once the response header has been written.
    ///
    /// # Errors
    /// Fails if the response header cannot be built or written.
    async fn redirect_to_oauth_callback(
        &self,
        session: &mut Session,
        oauth_provider: &Provider,
    ) -> Result<bool> {
        let current_address = session.req_header().uri.to_string();
        let state = uuid::Uuid::new_v4().to_string();

        let mut res_headers =
            ResponseHeader::build_no_case(StatusCode::TEMPORARY_REDIRECT, Some(1))?;
        res_headers.append_header(
            http::header::LOCATION,
            oauth_provider.get_oauth_callback_url(&state),
        )?;

        // Only remember the state once the redirect header is known to be valid.
        OAUTH2_STATE.insert(state, current_address);
        session.write_response_header(Box::new(res_headers)).await?;
        Ok(true)
    }

    /// Writes a bare `401 Unauthorized` response header.
    ///
    /// Returns `Ok(true)` once the header has been written.
    ///
    /// # Errors
    /// Fails if the response header cannot be built or written.
    async fn unauthorized_response(&self, session: &mut Session) -> Result<bool> {
        let res_headers = ResponseHeader::build_no_case(StatusCode::UNAUTHORIZED, Some(1))?;
        session.write_response_header(Box::new(res_headers)).await?;
        Ok(true)
    }

    /// Decides whether the request already carries a valid login cookie.
    ///
    /// Returns:
    /// * `Ok(true)`  — the `COOKIE_NAME` cookie is present, its JWT decodes
    ///   with `jwt_secret`, and the decoded user passes `validations`.
    /// * `Ok(false)` — there is no `cookie` header or no `COOKIE_NAME` cookie
    ///   (nothing is written), **or** the cookie was rejected, in which case a
    ///   `401` response header has already been written to `session`.
    ///
    /// # Errors
    /// Fails if the `cookie` header is not valid visible ASCII, or if the
    /// `401` response for a rejected cookie cannot be written. The latter is
    /// propagated on purpose: a rejected cookie must never be reported as
    /// valid just because the rejection could not be delivered.
    async fn validate_cookie(
        &self,
        session: &mut Session,
        jwt_secret: &str,
        validations: Option<&serde_json::Value>,
    ) -> Result<bool> {
        let Some(cookie_header) = session.req_header().headers.get("cookie") else {
            return Ok(false);
        };

        // Malformed individual cookies are skipped rather than failing the request.
        let decoded = match Cookie::split_parse(cookie_header.to_str()?)
            .filter_map(Result::ok)
            .find(|c| c.name() == COOKIE_NAME)
        {
            Some(secure_jwt) => jwt::decode_jwt(secure_jwt.value(), jwt_secret.as_bytes()),
            None => return Ok(false),
        };

        // The user is re-validated on every request so that changes to the
        // validation rules apply to cookies that were issued earlier.
        let authorized = match decoded {
            Ok(claims) => Self::is_authorized(&claims.into(), validations),
            Err(_) => false,
        };

        if !authorized {
            self.unauthorized_response(session).await?;
            return Ok(false);
        }
        Ok(true)
    }

    /// Reads `key` from the plugin configuration as an owned string.
    ///
    /// # Errors
    /// Fails when the key is absent or its value is not a JSON string.
    fn get_required_config(
        plugin_config: &HashMap<Cow<'static, str>, serde_json::Value>,
        key: &str,
    ) -> Result<String> {
        plugin_config
            .get(key)
            .and_then(|v| v.as_str())
            .map(ToString::to_string)
            .ok_or_else(|| anyhow!("Missing or invalid {}", key))
    }

    /// Maps the `provider` configuration entry to an [`OauthType`].
    ///
    /// # Errors
    /// Fails when `provider` is absent, is not a JSON string, or names a
    /// provider other than `github` or `workos`.
    fn parse_provider(
        plugin_config: &HashMap<Cow<'static, str>, serde_json::Value>,
    ) -> Result<OauthType> {
        let provider = plugin_config
            .get("provider")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow!("Missing or invalid provider"))?;

        match provider {
            "github" => Ok(OauthType::Github),
            "workos" => Ok(OauthType::Workos),
            _ => bail!("Provider not found in the plugin configuration"),
        }
    }
}
#[async_trait]
impl MiddlewarePlugin for Oauth2 {
    /// No-op: this plugin does not modify the upstream request.
    async fn upstream_request_filter(
        &self,
        _: &mut Session,
        _: &mut RequestHeader,
        _: &mut RouterContext,
    ) -> Result<()> {
        Ok(())
    }

    /// No-op: this plugin does not modify the upstream response.
    fn upstream_response_filter(
        &self,
        _: &mut Session,
        _: &mut ResponseHeader,
        _: &mut RouterContext,
    ) -> Result<()> {
        Ok(())
    }

    /// Enforces the OAuth2 login for a route.
    ///
    /// Flow, in order:
    /// 1. No plugin configuration: return `Ok(false)` without doing anything.
    /// 2. Request path is `/__/oauth/<provider>/callback`: validate `code`
    ///    and `state`, exchange the code for a user, check the user against
    ///    `validations`, then answer `302 Found` with the JWT cookie and the
    ///    URI saved for that `state`. Any failure answers `401`.
    /// 3. Request carries a valid login cookie: return `Ok(false)`.
    /// 4. Otherwise: redirect the client to the provider.
    ///
    /// Returns `Ok(true)` whenever this method wrote a response itself and
    /// `Ok(false)` when the request passed the check untouched.
    ///
    /// # Errors
    /// Fails on missing/invalid required configuration (`provider`,
    /// `client_id`, `client_secret`, `jwt_secret`), when the cookie cannot be
    /// created, or when a response header cannot be built or written.
    async fn request_filter(
        &self,
        session: &mut Session,
        ctx: &mut RouterContext,
        plugin: &RoutePlugin,
    ) -> Result<bool> {
        let Some(plugin_config) = plugin.config.as_ref() else {
            return Ok(false);
        };

        let provider = Self::parse_provider(plugin_config)?;
        let client_id = Self::get_required_config(plugin_config, "client_id")?;
        let client_secret = Self::get_required_config(plugin_config, "client_secret")?;
        let jwt_secret = Self::get_required_config(plugin_config, "jwt_secret")?;
        let validations = plugin_config.get("validations");

        let callback_path = format!("/__/oauth/{}/callback", &provider);
        let oauth_provider = Provider {
            client_id,
            client_secret,
            typ: provider,
        };

        let uri = &session.req_header().uri;
        if uri.path() == callback_path {
            let Some(query) = uri.query() else {
                return self.unauthorized_response(session).await;
            };
            let query_params = shared::from_string_to_query_params(query);

            let Some(code) = query_params.get("code") else {
                tracing::info!("missing code in the query");
                return self.unauthorized_response(session).await;
            };
            let Some(state) = query_params.get("state") else {
                tracing::info!("missing state in the query");
                return self.unauthorized_response(session).await;
            };

            // Consume the state in a single `remove` so a token can be
            // redeemed at most once, even if two callbacks carrying the same
            // `state` race each other. The removed value is the URI the
            // client asked for before being sent to the provider.
            let Some((_, redirect_from_state)) = OAUTH2_STATE.remove(&state.to_string()) else {
                tracing::info!("state does not exist or was removed");
                return self.unauthorized_response(session).await;
            };

            let user = match oauth_provider.get_oauth_user(code).await {
                Ok(user) => user,
                Err(err) => {
                    tracing::error!("Failed to exchange code {code}, state {state}: {err}");
                    return self.unauthorized_response(session).await;
                }
            };

            if !Self::is_authorized(&user, validations) {
                tracing::info!("user is not authorized {:?}", user);
                return self.unauthorized_response(session).await;
            }

            let jwt_cookie = secure_cookie::create_secure_cookie(&user, &jwt_secret, &ctx.host)?;
            let mut res_headers = ResponseHeader::build_no_case(StatusCode::FOUND, Some(1))?;
            res_headers.insert_header(http::header::SET_COOKIE, jwt_cookie.to_string())?;
            res_headers.insert_header(http::header::LOCATION, redirect_from_state)?;
            session.write_response_header(Box::new(res_headers)).await?;
            return Ok(true);
        }

        if self
            .validate_cookie(session, &jwt_secret, validations)
            .await?
        {
            return Ok(false);
        }

        // NOTE(review): `validate_cookie` also returns `Ok(false)` after it
        // has already written a 401 for a rejected cookie, so this may be a
        // second attempt to write a response header on the same session —
        // confirm how pingora handles that.
        self.redirect_to_oauth_callback(session, &oauth_provider)
            .await
    }

    /// No-op: always returns `Ok(false)`.
    async fn response_filter(
        &self,
        _: &mut Session,
        _: &mut RouterContext,
        _: &RoutePlugin,
    ) -> Result<bool> {
        Ok(false)
    }
}