#[cfg(test)]
mod tests;
use crate::application::Application;
use crate::core::New;
use crate::header::Header;
use crate::http_client::Client;
use crate::middleware::Middleware;
use crate::mime_type::MimeType;
use crate::range::Range;
use crate::request::Request;
use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
use crate::server::ConnectionInfo;
/// Header names that are not copied when relaying a message across the
/// auth-service hop.
///
/// Entries must be lowercase ASCII: callers compare them against header names
/// case-insensitively. The list holds the hop-by-hop headers of RFC 7230 §6.1,
/// which describe only a single connection. It also holds `content-type` and
/// `content-length`, because the relayed body is re-attached separately
/// together with its own content type.
const EXCLUDED_PASSTHROUGH_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    // Non-standard but widely sent hop-by-hop header.
    "proxy-connection",
    "te",
    // The actual header is `Trailer`. RFC 2616 §13.5.1 misspelled it as
    // `Trailers` (erratum 4522). Both spellings are listed so neither leaks.
    "trailer",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "content-type",
    "content-length",
];
/// Middleware that asks an external auth service whether a request may
/// proceed before handing it to the next application.
pub struct ForwardAuthLayer {
    /// URL the authorization subrequest is sent to.
    auth_url: String,
    /// Names of headers to copy from a successful auth response onto the
    /// forwarded request.
    copy_headers: Vec<String>,
    /// Timeout, in milliseconds, given to the HTTP client for the subrequest.
    timeout_ms: u64,
}

impl ForwardAuthLayer {
    /// Creates a layer that authorizes against `auth_url`.
    ///
    /// No headers are copied by default, and the subrequest timeout is 5000 ms.
    pub fn new(auth_url: impl Into<String>) -> Self {
        let auth_url: String = auth_url.into();
        Self {
            auth_url,
            copy_headers: vec![],
            timeout_ms: 5000,
        }
    }

    /// Appends `name` to the headers copied from a successful auth response
    /// onto the forwarded request. May be chained any number of times.
    pub fn copy_header(self, name: impl Into<String>) -> Self {
        let mut layer = self;
        layer.copy_headers.push(name.into());
        layer
    }

    /// Replaces the auth subrequest timeout, in milliseconds.
    pub fn timeout_ms(self, ms: u64) -> Self {
        Self { timeout_ms: ms, ..self }
    }
}
impl Middleware for ForwardAuthLayer {
fn handle(&self, request: &Request, connection: &ConnectionInfo, next: &dyn Application) -> Result<Response, String> {
let client = Client::new().timeout_ms(self.timeout_ms).max_redirects(0);
let mut builder = client.get(&self.auth_url);
for h in &request.headers {
builder = builder.header(&h.name, &h.value);
}
let auth_response = match builder.send() {
Ok(r) => r,
Err(_) => return Ok(auth_service_unreachable()),
};
if !auth_response.is_success() {
return Ok(passthrough_response(&auth_response));
}
let mut forwarded = request.clone();
for name in &self.copy_headers {
if let Some(value) = auth_response.header(name) {
forwarded.headers.retain(|h| !h.name.eq_ignore_ascii_case(name));
forwarded.headers.push(Header { name: name.clone(), value: value.to_string() });
}
}
next.execute(&forwarded, connection)
}
}
/// Builds the `502 Bad Gateway` response used when the request to the auth
/// service could not be completed.
///
/// The body is a short `text/plain` message naming the cause.
fn auth_service_unreachable() -> Response {
    const MESSAGE: &[u8] = b"502 Bad Gateway: auth service unreachable";

    let mut response = Response::new();
    response.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
    response.reason_phrase = STATUS_CODE_REASON_PHRASE
        .n502_bad_gateway
        .reason_phrase
        .to_string();
    let body = Range::get_content_range(MESSAGE.to_vec(), MimeType::TEXT_PLAIN.to_string());
    response.content_range_list = vec![body];
    response
}
/// Converts a non-success response from the auth service into the response
/// returned to the client.
///
/// The status code is kept, and its reason phrase is looked up (empty if
/// unknown). Every header not listed in `EXCLUDED_PASSTHROUGH_HEADERS` is
/// copied. A non-empty body is re-attached with the auth service's
/// `Content-Type`, defaulting to `text/plain` when that header is absent.
fn passthrough_response(auth_response: &crate::http_client::Response) -> Response {
    let status = auth_response.status();
    let mut r = Response::new();
    // NOTE(review): HTTP status codes are three-digit values, so this narrowing
    // cast is assumed not to wrap; confirm the client rejects anything larger.
    r.status_code = status as i16;
    r.reason_phrase = reason_phrase_for_status(status);
    for (name, value) in auth_response.headers() {
        // Header field names compare ASCII case-insensitively. Match in place
        // instead of allocating a Unicode-lowercased copy of every name.
        let excluded = EXCLUDED_PASSTHROUGH_HEADERS
            .iter()
            .any(|h| h.eq_ignore_ascii_case(name.as_str()));
        if !excluded {
            r.headers.push(Header { name: name.clone(), value: value.clone() });
        }
    }
    let body = auth_response.bytes();
    if !body.is_empty() {
        let content_type = auth_response
            .header(Header::_CONTENT_TYPE)
            .unwrap_or(MimeType::TEXT_PLAIN)
            .to_string();
        r.content_range_list = vec![Range::get_content_range(body.to_vec(), content_type)];
    }
    r
}
/// Returns the reason phrase registered for `status` in
/// `Response::status_code_reason_phrase_list()`.
///
/// The first matching entry wins; an empty string is returned when no entry
/// matches.
fn reason_phrase_for_status(status: u16) -> String {
    let wanted = status as i16;
    for entry in Response::status_code_reason_phrase_list() {
        if *entry.status_code == wanted {
            return entry.reason_phrase.to_string();
        }
    }
    String::new()
}