use futures_util::Future;
use http::Response;
use hyper::client::{HttpConnector, ResponseFuture};
use hyper::service::Service;
use hyper::{Body, Client, Request, Server};
use serde::Serialize;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::challenge::config::{ProtocolVersion, ProxyConfig};
use crate::cli_stderr_printline;
use crate::cli_stdout_printline;
use crate::http_utils::is_hop_by_hop_header;
use crate::otoroshi::protocol::OtoroshiProtocol;
/// JSON payload used for error responses generated by this proxy itself.
#[derive(Serialize)]
struct ErrorResponse<'a> {
/// Short error message; always present in the serialized JSON.
error: &'a str,
/// Optional extra detail; the field is omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<&'a str>,
}
fn json_error_response(
status: http::StatusCode,
error: &str,
details: Option<&str>,
) -> Response<Body> {
let body = ErrorResponse { error, details };
let json = serde_json::to_string(&body).unwrap_or_else(|e| {
error!("Failed to serialize error response: {}", e);
r#"{"error":"Internal error"}"#.to_string()
});
Response::builder()
.status(status)
.header("Content-Type", "application/json")
.body(Body::from(json))
.expect("Failed to build error response")
}
/// HTTP client paired with the shared proxy configuration; used to send
/// requests to the backend.
#[derive(Clone, Debug)]
struct ProxyClient {
/// Shared proxy configuration; `backend_url` is read from it when the
/// target URI of a forwarded request is built.
config: Arc<ProxyConfig>,
/// hyper client that issues the forwarded request.
http_client: Client<HttpConnector>,
}
impl ProxyClient {
    /// Creates a client backed by a default hyper HTTP client and the given
    /// shared configuration.
    fn new(config: Arc<ProxyConfig>) -> Self {
        Self {
            config,
            http_client: Client::new(),
        }
    }

    /// Re-targets `req` at the backend and starts sending it.
    ///
    /// The new URI is the configured `backend_url` followed by the request's
    /// path and query (`/` when the request URI has none). Method, headers and
    /// body of `req` are left untouched.
    ///
    /// # Errors
    ///
    /// Returns `InvalidUri` when the combined string does not parse as a URI.
    fn forward_request(
        &self,
        mut req: Request<Body>,
    ) -> Result<ResponseFuture, http::uri::InvalidUri> {
        let path_and_query = req.uri().path_and_query().map_or("/", |pq| pq.as_str());
        let target: http::Uri =
            format!("{}{}", self.config.backend_url, path_and_query).parse()?;
        *req.uri_mut() = target;
        Ok(self.http_client.request(req))
    }
}
/// Request-handling service: holds the backend client and the shared
/// configuration used to validate and forward each request.
struct ProxySvc {
/// Client used to forward accepted requests to the backend.
client: ProxyClient,
/// Shared configuration read on every call (protocol version, header names,
/// key material, timeout, Consumer Info settings).
config: Arc<ProxyConfig>,
}
/// Request handler of the challenge proxy.
///
/// For each request it:
/// 1. reads the challenge from the configured state header and, depending on
///    the protocol version, either echoes it (V1) or verifies it and produces
///    a signed response token (V2);
/// 2. optionally verifies the Consumer Info token and serializes its claims;
/// 3. forwards the request to the backend with a filtered header set;
/// 4. adds the challenge response header to the backend's reply.
impl Service<Request<Body>> for ProxySvc {
    type Response = Response<Body>;
    type Error = hyper::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready; nothing is reserved ahead of a call.
    fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles one request.
    ///
    /// Every failure (missing/invalid challenge, missing key material, invalid
    /// Consumer Info token in strict mode, backend error or timeout) is turned
    /// into a JSON error response, so the returned future never resolves to
    /// `Err`.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let config = self.config.clone();
        let client = self.client.clone();
        Box::pin(async move {
            // Challenge value sent by the caller; values that are not valid
            // header text are treated as absent.
            let state_value = req
                .headers()
                .get(&config.state_header)
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_string());

            let response_value = match config.version {
                // V1: the challenge is echoed back unchanged.
                ProtocolVersion::V1 => {
                    let Some(value) = state_value else {
                        return Ok(json_error_response(
                            http::StatusCode::UNAUTHORIZED,
                            "Missing Otoroshi-State header",
                            None,
                        ));
                    };
                    Some(value)
                }
                // V2: the challenge token is verified and a response token is signed.
                ProtocolVersion::V2 => {
                    // Key material is checked before the header so that a
                    // misconfiguration is reported as 500 rather than 401.
                    let verify_key = if config.algorithm.is_asymmetric() {
                        match &config.public_key {
                            Some(pk) => pk.as_slice(),
                            None => {
                                return Ok(json_error_response(
                                    http::StatusCode::INTERNAL_SERVER_ERROR,
                                    "Public key is required for asymmetric V2 protocol verification",
                                    None,
                                ));
                            }
                        }
                    } else {
                        match &config.secret {
                            Some(s) => s.as_slice(),
                            None => {
                                return Ok(json_error_response(
                                    http::StatusCode::INTERNAL_SERVER_ERROR,
                                    "Secret is required for V2 protocol",
                                    None,
                                ));
                            }
                        }
                    };

                    // The signing key falls back from the response secret to
                    // the main secret; only the error text depends on the
                    // response algorithm family.
                    let sign_key =
                        match config.response_secret.as_ref().or(config.secret.as_ref()) {
                            Some(key) => key.as_slice(),
                            None => {
                                let message = if config.response_algorithm.is_asymmetric() {
                                    "Private key (--response-secret or --secret) is required for asymmetric response signing"
                                } else {
                                    "Secret (--response-secret or --secret) is required for response signing"
                                };
                                return Ok(json_error_response(
                                    http::StatusCode::INTERNAL_SERVER_ERROR,
                                    message,
                                    None,
                                ));
                            }
                        };

                    let Some(token) = state_value else {
                        return Ok(json_error_response(
                            http::StatusCode::UNAUTHORIZED,
                            "Missing Otoroshi-State header",
                            None,
                        ));
                    };

                    let protocol = OtoroshiProtocol::new_asymmetric_with_ttl(
                        verify_key,
                        config.algorithm,
                        sign_key,
                        config.response_algorithm,
                        config.token_ttl,
                    );
                    match protocol.process_v2(&token) {
                        Ok(resp_token) => Some(resp_token),
                        Err(e) => {
                            return Ok(json_error_response(
                                http::StatusCode::UNAUTHORIZED,
                                "Invalid Otoroshi challenge",
                                Some(&e.to_string()),
                            ));
                        }
                    }
                }
            };

            // Verified Consumer Info claims as JSON, or `None` when the feature
            // is disabled, the token is absent, or (non-strict) invalid.
            let consumer_info_decoded: Option<String> = match &config.consumer_info {
                None => None,
                Some(ci_config) => {
                    let token_opt = req
                        .headers()
                        .get(&ci_config.in_header)
                        .and_then(|v| v.to_str().ok())
                        .map(|s| s.to_string());
                    match token_opt {
                        None => {
                            if ci_config.strict {
                                return Ok(json_error_response(
                                    http::StatusCode::UNAUTHORIZED,
                                    "Missing Consumer Info header",
                                    None,
                                ));
                            }
                            None
                        }
                        Some(token) => match ci_config.verifier.verify_and_decode(&token) {
                            Ok(claims) => match serde_json::to_string(&claims) {
                                Ok(json) => Some(json),
                                Err(e) => {
                                    warn!("Failed to serialize Consumer Info claims: {}", e);
                                    None
                                }
                            },
                            Err(e) => {
                                if ci_config.strict {
                                    return Ok(json_error_response(
                                        http::StatusCode::UNAUTHORIZED,
                                        "Invalid Consumer Info token",
                                        None,
                                    ));
                                }
                                warn!("Consumer Info token verification failed: {}", e);
                                None
                            }
                        },
                    }
                }
            };

            let (parts, body) = req.into_parts();
            let mut backend_req_builder =
                Request::builder().method(parts.method).uri(parts.uri);

            for (name, value) in parts.headers.iter() {
                if is_hop_by_hop_header(name) {
                    continue;
                }
                if config.strip_otoroshi_headers && name == config.state_header {
                    continue;
                }
                if let Some(ci_config) = &config.consumer_info {
                    let is_in_header = name == ci_config.in_header;
                    let is_out_header = *name == ci_config.out_header;
                    let skip = if is_in_header {
                        // The token header is dropped when stripping is enabled,
                        // or when the decoded claims will replace it under the
                        // same name.
                        // NOTE(review): with in_header == out_header, no
                        // stripping and a failed non-strict verification, the
                        // raw unverified value is still forwarded — confirm
                        // this passthrough is intended.
                        config.strip_otoroshi_headers
                            || (is_out_header && consumer_info_decoded.is_some())
                    } else {
                        // A client-supplied header already using the output
                        // name must never reach the backend: only claims
                        // verified by this proxy may appear under it.
                        is_out_header
                    };
                    if skip {
                        continue;
                    }
                }
                backend_req_builder = backend_req_builder.header(name, value);
            }

            if let Some(ci_config) = &config.consumer_info {
                if let Some(ref json) = consumer_info_decoded {
                    match json.parse::<http::header::HeaderValue>() {
                        Ok(header_value) => {
                            backend_req_builder = backend_req_builder
                                .header(ci_config.out_header.clone(), header_value);
                        }
                        Err(e) => {
                            warn!("Failed to set Consumer Info output header: {}", e);
                        }
                    }
                }
            }

            let backend_req = match backend_req_builder.body(body) {
                Ok(req) => req,
                Err(e) => {
                    return Ok(json_error_response(
                        http::StatusCode::INTERNAL_SERVER_ERROR,
                        "Failed to build request",
                        Some(&e.to_string()),
                    ));
                }
            };

            let backend_future = match client.forward_request(backend_req) {
                Ok(future) => future,
                Err(e) => {
                    return Ok(json_error_response(
                        http::StatusCode::INTERNAL_SERVER_ERROR,
                        "Invalid backend URI",
                        Some(&e.to_string()),
                    ));
                }
            };

            // Outer `Err` means the timeout elapsed; inner `Err` is a client error.
            let result = tokio::time::timeout(config.request_timeout, backend_future).await;
            match result {
                Ok(Ok(backend_resp)) => {
                    let (mut resp_parts, resp_body) = backend_resp.into_parts();
                    // The challenge response is attached only when it is a
                    // valid header value.
                    if let Some(value) = response_value
                        && let Ok(header_value) = value.parse()
                    {
                        resp_parts
                            .headers
                            .insert(config.state_resp_header.clone(), header_value);
                    }
                    Ok(Response::from_parts(resp_parts, resp_body))
                }
                Ok(Err(e)) => Ok(json_error_response(
                    http::StatusCode::BAD_GATEWAY,
                    "Backend unavailable",
                    Some(&e.to_string()),
                )),
                Err(_) => Ok(json_error_response(
                    http::StatusCode::GATEWAY_TIMEOUT,
                    "Backend request timed out",
                    None,
                )),
            }
        })
    }
}
/// Factory service handed to the hyper server; produces a `ProxySvc` on each
/// call.
struct MakeSvc {
/// Backend client cloned into every `ProxySvc` produced.
client: ProxyClient,
/// Configuration handle cloned into every `ProxySvc` produced.
config: Arc<ProxyConfig>,
}
impl<T> Service<T> for MakeSvc {
type Response = ProxySvc;
type Error = hyper::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: T) -> Self::Future {
let client = self.client.clone();
let config = self.config.clone();
Box::pin(async move { Ok(ProxySvc { client, config }) })
}
}
/// Builds the proxy configuration from the CLI arguments, starts the HTTP
/// server and runs it until Ctrl+C is received.
///
/// All arguments are passed unchanged to `ProxyConfig::new`; `use_v1` also
/// selects the label shown in the startup banner.
///
/// The process exits with status 1 when V2 is selected without `secret` or
/// `public_key`, when the configuration is rejected, or when the server
/// terminates with an error.
///
/// # Panics
///
/// Panics if the Ctrl+C signal handler cannot be installed.
// The parameter list mirrors the CLI options one-to-one.
#[allow(clippy::too_many_arguments)]
pub async fn run(
    port: u16,
    backend_host: String,
    backend_port: u16,
    secret: Option<String>,
    secret_base64: bool,
    state_header: String,
    state_resp_header: String,
    timeout: u64,
    token_ttl: i64,
    alg: String,
    public_key: Option<String>,
    response_secret: Option<String>,
    response_secret_base64: bool,
    response_alg: Option<String>,
    use_v1: bool,
    consumer_info_enabled: bool,
    consumer_info_header: String,
    consumer_info_out_header: Option<String>,
    consumer_info_alg: String,
    consumer_info_secret: Option<String>,
    consumer_info_secret_base64: bool,
    consumer_info_public_key: Option<String>,
    consumer_info_strict: bool,
    strip_otoroshi_headers: bool,
) {
    // V2 cannot verify anything without key material; fail before building the config.
    let v2_without_key = !use_v1 && secret.is_none() && public_key.is_none();
    if v2_without_key {
        cli_stderr_printline!(
            "Error: --secret or --public-key is required for V2 protocol (or use --v1 for simple echo mode)"
        );
        std::process::exit(1);
    }

    let built = ProxyConfig::new(
        port,
        backend_host,
        backend_port,
        secret,
        secret_base64,
        state_header,
        state_resp_header,
        timeout,
        token_ttl,
        alg,
        public_key,
        response_secret,
        response_secret_base64,
        response_alg,
        use_v1,
        consumer_info_enabled,
        consumer_info_header,
        consumer_info_out_header,
        consumer_info_alg,
        consumer_info_secret,
        consumer_info_secret_base64,
        consumer_info_public_key,
        consumer_info_strict,
        strip_otoroshi_headers,
    );
    let config = match built {
        Err(e) => {
            cli_stderr_printline!("Configuration error: {}", e);
            std::process::exit(1);
        }
        Ok(cfg) => Arc::new(cfg),
    };

    let addr: SocketAddr = config.listen_addr;
    let make_svc = MakeSvc {
        client: ProxyClient::new(Arc::clone(&config)),
        config: Arc::clone(&config),
    };

    let protocol_label = if use_v1 { "V1 (echo)" } else { "V2 (JWT)" };
    cli_stdout_printline!(
        "Otoroshi {} Challenge Proxy listening on http://{}",
        protocol_label,
        addr
    );
    cli_stdout_printline!("Forwarding requests to {}", config.backend_url);

    // Resolves once Ctrl+C is received, which triggers the graceful shutdown.
    let shutdown_signal = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install CTRL+C signal handler");
        cli_stdout_printline!("Shutdown signal received, stopping server...");
    };

    let outcome = Server::bind(&addr)
        .serve(make_svc)
        .with_graceful_shutdown(shutdown_signal)
        .await;
    if let Err(e) = outcome {
        cli_stderr_printline!("Server error: {}", e);
        std::process::exit(1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::body::to_bytes;

    /// Drains the response body and returns it as UTF-8 text.
    async fn read_body_text(resp: Response<Body>) -> String {
        let raw = to_bytes(resp.into_body()).await.unwrap();
        String::from_utf8(raw.to_vec()).unwrap()
    }

    /// Drains the response body and parses it as JSON.
    async fn read_body_json(resp: Response<Body>) -> serde_json::Value {
        let text = read_body_text(resp).await;
        serde_json::from_str(&text).unwrap()
    }

    // Without details, only the `error` field is emitted.
    #[tokio::test]
    async fn test_json_error_response_without_details() {
        let resp = json_error_response(http::StatusCode::BAD_REQUEST, "Test error", None);
        assert_eq!(resp.status(), http::StatusCode::BAD_REQUEST);

        let parsed = read_body_json(resp).await;
        assert_eq!(parsed["error"], "Test error");
        assert!(parsed.get("details").is_none());
    }

    // With details, both fields are emitted.
    #[tokio::test]
    async fn test_json_error_response_with_details() {
        let resp = json_error_response(
            http::StatusCode::INTERNAL_SERVER_ERROR,
            "Error message",
            Some("Detailed info"),
        );
        assert_eq!(resp.status(), http::StatusCode::INTERNAL_SERVER_ERROR);

        let parsed = read_body_json(resp).await;
        assert_eq!(parsed["error"], "Error message");
        assert_eq!(parsed["details"], "Detailed info");
    }

    // The response is always labelled as JSON.
    #[test]
    fn test_json_error_response_content_type() {
        let resp = json_error_response(http::StatusCode::OK, "test", None);
        let header = resp.headers().get("content-type").unwrap();
        assert_eq!(header, "application/json");
    }

    // The body must parse as JSON.
    #[tokio::test]
    async fn test_json_error_response_is_valid_json() {
        let resp = json_error_response(
            http::StatusCode::UNAUTHORIZED,
            "Unauthorized",
            Some("Token expired"),
        );
        let text = read_body_text(resp).await;
        let parsed: Result<serde_json::Value, _> = serde_json::from_str(&text);
        assert!(parsed.is_ok(), "Response body should be valid JSON");
    }

    // Quotes, backslashes and markup characters survive a JSON round trip.
    #[tokio::test]
    async fn test_json_error_response_special_characters() {
        let resp = json_error_response(
            http::StatusCode::BAD_REQUEST,
            "Error with \"quotes\" and \\backslash",
            Some("Details with <html> & special chars"),
        );

        let parsed = read_body_json(resp).await;
        assert_eq!(parsed["error"], "Error with \"quotes\" and \\backslash");
        assert_eq!(parsed["details"], "Details with <html> & special chars");
    }

    // Shape of the response used for backend timeouts.
    #[tokio::test]
    async fn test_json_error_response_gateway_timeout() {
        let resp = json_error_response(
            http::StatusCode::GATEWAY_TIMEOUT,
            "Backend request timed out",
            None,
        );
        assert_eq!(resp.status(), http::StatusCode::GATEWAY_TIMEOUT);

        let parsed = read_body_json(resp).await;
        assert_eq!(parsed["error"], "Backend request timed out");
        assert!(parsed.get("details").is_none());
    }

    // Shape of the response used for backend connection failures.
    #[tokio::test]
    async fn test_json_error_response_bad_gateway() {
        let resp = json_error_response(
            http::StatusCode::BAD_GATEWAY,
            "Backend unavailable",
            Some("Connection refused"),
        );
        assert_eq!(resp.status(), http::StatusCode::BAD_GATEWAY);

        let parsed = read_body_json(resp).await;
        assert_eq!(parsed["error"], "Backend unavailable");
        assert_eq!(parsed["details"], "Connection refused");
    }
}