use client_util::{get_client, Client, ClientConfig};
use futures_util::future::BoxFuture;
use http::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use tower_http::auth::async_require_authorization::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
use super::ApplyLayer;
use crate::automatic_body::{add_automatic_body, AutomaticBody};
use crate::client_ip::GetClientIp;
use crate::target::ReqBody;
/// Configuration for the auth-request layer built by `from_config`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AuthRequestConfig {
/// URI of the authorization endpoint, kept as a string here and parsed
/// into a URI when the layer is built.
target: String,
/// Settings handed to `get_client` to build the HTTP client used for the
/// authorization sub-requests. Falls back to `ClientConfig::default()`
/// when the field is absent from the deserialized input.
#[serde(default)]
client: ClientConfig,
}
/// Per-layer state for the authorization check: where to send the
/// authorization sub-request and which client to send it with.
#[derive(Clone, Debug)]
struct AuthRequest {
/// Parsed URI of the authorization endpoint; it becomes the URI of every
/// sub-request, and its scheme and authority form the upstream target.
uri: hyper::Uri,
/// HTTP client used to perform the sub-request; cloned for each call.
client: Client<ReqBody>,
}
impl<B: Send + 'static> AsyncAuthorizeRequest<B> for AuthRequest {
type RequestBody = B;
type ResponseBody = AutomaticBody;
type Future =
BoxFuture<'static, Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, req: Request<B>) -> Self::Future {
let uri = self.uri.clone();
let client = self.client.clone();
Box::pin(async move {
let client_ip = req.client_ip();
let mut new_req = Request::new(ReqBody::default());
let target = format!("{}://{}", uri.scheme_str().unwrap(), uri.authority().unwrap());
*new_req.uri_mut() = uri;
*new_req.headers_mut() = req.headers().clone();
new_req.headers_mut().remove(hyper::header::CONTENT_LENGTH);
match crate::reverse_proxy::call(client_ip, &target, new_req, &client).await {
Ok(auth_resp) if auth_resp.status().is_success() => Ok(req),
Ok(auth_resp)
if auth_resp.status() == StatusCode::FORBIDDEN
|| auth_resp.status() == StatusCode::UNAUTHORIZED =>
{
let mut res = Response::new(Self::ResponseBody::default());
*res.status_mut() = auth_resp.status();
Err(res)
}
Ok(auth_resp) => {
log::warn!("auth request returned status {}", auth_resp.status());
let mut res = Response::new(Self::ResponseBody::default());
*res.status_mut() = StatusCode::BAD_GATEWAY;
Err(res)
}
Err(e) => {
log::warn!("auth request failed: {}", e);
let mut res = Response::new(Self::ResponseBody::default());
*res.status_mut() = StatusCode::BAD_GATEWAY;
Err(res)
}
}
})
}
}
pub fn from_config(config: AuthRequestConfig) -> anyhow::Result<impl ApplyLayer> {
let uri = config.target.parse()?;
Ok(add_automatic_body(AsyncRequireAuthorizationLayer::new(
AuthRequest {
uri,
client: get_client(config.client)?,
},
)))
}