use actix_web::{
Error, HttpRequest, HttpResponse,
http::{Method, StatusCode, header},
web,
};
use futures::TryStreamExt;
use std::time::Duration;
use tracing::{debug, error, instrument};
/// Name of the environment variable that holds the base URL of the upstream
/// service; the request path and query are appended to it when proxying.
pub const TARGET_SERVICE_ENV: &str = "RUNEGATE_TARGET_SERVICE";
/// Upstream base URL used when [`TARGET_SERVICE_ENV`] is not set.
const DEFAULT_TARGET_SERVICE: &str = "http://127.0.0.1:7860";
/// Hop-by-hop header names (RFC 9110 §7.6.1). They describe a single
/// connection and must not be forwarded by a proxy in either direction.
/// `proxy-connection` is non-standard but still sent by some clients.
const HOP_BY_HOP_HEADERS: [&str; 9] = [
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];

/// Identity headers that only this proxy sets. Client-supplied copies are
/// always dropped, so a caller cannot inject its own identity upstream.
const IDENTITY_HEADERS: [&str; 4] = [
    "x-forwarded-user",
    "x-forwarded-email",
    "x-runegate-authenticated",
    "x-runegate-user",
];

/// Forwards the incoming request to the upstream service and relays its response.
///
/// The upstream base URL comes from `get_target_service_url()`. A leading
/// `/proxy` path segment is removed, and the query string is kept.
///
/// Headers sent upstream:
/// - End-to-end client headers are copied (repeated headers keep every value).
/// - Hop-by-hop headers, `Host`, `Content-Length`, `Cookie`, and any
///   client-supplied identity headers are excluded from that copy.
/// - `Host` is then set from the client's `Host` header, or from the
///   connection info when the client sent none.
/// - `Cookie` is rebuilt without the session cookie named by
///   `RUNEGATE_SESSION_COOKIE_NAME` (default `runegate_id`).
/// - `Forwarded: for=...` is set from the real client address.
/// - When `RUNEGATE_IDENTITY_HEADERS` is unset or truthy, the identity headers
///   are set from `identity_email`.
///
/// The request body is streamed for POST/PUT/PATCH, or for any request that
/// declares a body via a non-zero `Content-Length` or a `Transfer-Encoding` header.
///
/// The response status and end-to-end headers are relayed, with every value of
/// repeated headers such as `Set-Cookie` kept. The body is streamed unless
/// `RUNEGATE_STREAM_RESPONSES` is falsy, in which case it is buffered.
///
/// # Errors
/// Returns a 502 Bad Gateway error when the upstream request fails or, in
/// buffered mode, when its response body cannot be read. In streaming mode,
/// body errors surface as 502 errors on the response stream.
#[instrument(skip(payload), fields(method = %req.method(), path = %req.uri().path(), query = %req.uri().query().unwrap_or(""), client_ip = %req.connection_info().realip_remote_addr().unwrap_or("unknown")))]
pub async fn proxy_request(
    req: HttpRequest,
    payload: web::Payload,
    identity_email: Option<String>,
) -> Result<HttpResponse, Error> {
    let target_url = get_target_service_url();
    let session_cookie_name =
        std::env::var("RUNEGATE_SESSION_COOKIE_NAME").unwrap_or_else(|_| "runegate_id".to_string());
    // Unset means enabled. Otherwise only an explicit truthy value enables it.
    let identity_headers_enabled = std::env::var("RUNEGATE_IDENTITY_HEADERS")
        .map(|v| parse_flag(&v) == Some(true))
        .unwrap_or(true);

    // Trim a trailing '/' on the base so "http://x/" + "/path" does not become "//path".
    let forwarded_url = format!(
        "{}{}{}",
        target_url.trim_end_matches('/'),
        upstream_path(req.uri().path()),
        req.uri()
            .query()
            .map_or_else(String::new, |q| format!("?{}", q)),
    );
    debug!(target_url = %target_url, forwarded_url = %forwarded_url, "Proxying request");

    // NOTE(review): a fresh client, and therefore a fresh connection pool, is
    // built per request, so `conn_keep_alive` never gets to reuse a connection.
    // Consider sharing one client per worker, e.g. via app data.
    let connector = awc::Connector::new()
        .timeout(Duration::from_secs(10))
        .conn_keep_alive(Duration::from_secs(15))
        .disconnect_timeout(Duration::from_secs(2));
    let client = awc::ClientBuilder::new()
        .timeout(Duration::from_secs(600))
        .connector(connector)
        .finish();
    let mut forwarded_req = client
        .request(req.method().clone(), forwarded_url)
        .no_decompress();

    // Copy end-to-end request headers. X-Forwarded-Proto and similar headers
    // pass through unchanged here.
    let request_nominated = connection_nominated(
        req.headers()
            .get_all(header::CONNECTION)
            .filter_map(|v| v.to_str().ok()),
    );
    for (header_name, header_value) in req.headers().iter() {
        let name = header_name.as_str();
        let skip = *header_name == header::HOST
            || *header_name == header::CONTENT_LENGTH
            || *header_name == header::COOKIE
            || is_hop_by_hop(name, &request_nominated)
            || IDENTITY_HEADERS.iter().any(|h| name.eq_ignore_ascii_case(h));
        if !skip {
            // Append, not insert, so repeated headers keep every value.
            forwarded_req =
                forwarded_req.append_header((header_name.clone(), header_value.clone()));
        }
    }

    // HTTP/2 clients may split cookies over several Cookie headers, so merge
    // them all before removing the proxy's own session cookie.
    let raw_cookies = req
        .headers()
        .get_all(header::COOKIE)
        .filter_map(|v| v.to_str().ok())
        .collect::<Vec<_>>()
        .join("; ");
    if let Some(cookies) = strip_session_cookie(&raw_cookies, &session_cookie_name) {
        forwarded_req = forwarded_req.insert_header((header::COOKIE, cookies));
    }

    if let Some(host_val) = req.headers().get(header::HOST).cloned() {
        forwarded_req = forwarded_req.insert_header((header::HOST, host_val));
    } else {
        let host = req.connection_info().host().to_string();
        forwarded_req = forwarded_req.insert_header((header::HOST, host));
    }

    // Replaces any client-sent Forwarded header.
    if let Some(client_ip) = req.connection_info().realip_remote_addr() {
        forwarded_req = forwarded_req.insert_header((header::FORWARDED, forwarded_for(client_ip)));
    }

    if identity_headers_enabled {
        if let Some(email) = identity_email {
            forwarded_req = forwarded_req.insert_header(("X-Runegate-Authenticated", "true"));
            forwarded_req = forwarded_req.insert_header(("X-Runegate-User", email.clone()));
            forwarded_req = forwarded_req.insert_header(("X-Forwarded-User", email.clone()));
            forwarded_req = forwarded_req.insert_header(("X-Forwarded-Email", email));
        } else {
            forwarded_req = forwarded_req.insert_header(("X-Runegate-Authenticated", "false"));
        }
    }

    // Forward a body for the usual body-carrying methods, and for any other
    // request that declares one (e.g. DELETE with a JSON payload).
    let declares_body = req.headers().contains_key(header::TRANSFER_ENCODING)
        || matches!(
            req.headers()
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|v| v.trim().parse::<u64>().ok()),
            Some(len) if len > 0
        );
    let pending = if declares_body
        || matches!(*req.method(), Method::POST | Method::PUT | Method::PATCH)
    {
        forwarded_req.send_stream(payload)
    } else {
        forwarded_req.send()
    };

    let mut forwarded_res = pending.await.map_err(|e| {
        error!(error = %e, "Forwarding error to target service");
        actix_web::error::ErrorBadGateway(e)
    })?;
    debug!(status = %forwarded_res.status(), "Received response from target service");

    let mut client_res = HttpResponse::build(
        StatusCode::from_u16(forwarded_res.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY),
    );
    let response_nominated = connection_nominated(
        forwarded_res
            .headers()
            .get_all(header::CONNECTION)
            .filter_map(|v| v.to_str().ok()),
    );
    for (header_name, header_value) in forwarded_res.headers().iter() {
        // Content-Length is dropped because the body is re-emitted (possibly
        // streamed) by actix, which sets its own framing.
        if *header_name == header::CONTENT_LENGTH
            || is_hop_by_hop(header_name.as_str(), &response_nominated)
        {
            continue;
        }
        // Append, not insert: upstream may send several Set-Cookie headers.
        client_res.append_header((header_name.clone(), header_value.clone()));
    }

    // Streaming is on unless the variable holds an explicit falsy value.
    let stream_responses = std::env::var("RUNEGATE_STREAM_RESPONSES")
        .map(|v| parse_flag(&v) != Some(false))
        .unwrap_or(true);
    if stream_responses {
        let stream = forwarded_res.map_err(|e| {
            error!(error = %e, "Upstream body stream error");
            actix_web::error::ErrorBadGateway(e)
        });
        Ok(client_res.streaming(stream))
    } else {
        let body = forwarded_res.body().await.map_err(|e| {
            error!(error = %e, "Failed to read response body from target service");
            actix_web::error::ErrorBadGateway(e)
        })?;
        Ok(client_res.body(body))
    }
}

/// Maps an incoming path to the upstream path by removing a leading `/proxy` segment.
///
/// Only a whole segment is removed: `/proxy/x` becomes `/x`, while `/proxyfoo`
/// is forwarded unchanged. An empty result becomes `/`.
fn upstream_path(path: &str) -> &str {
    let path = match path.strip_prefix("/proxy") {
        Some(rest) if rest.is_empty() || rest.starts_with('/') => rest,
        _ => path,
    };
    if path.is_empty() {
        "/"
    } else {
        path
    }
}

/// Parses a boolean-ish environment value, case-insensitively and ignoring
/// surrounding whitespace. Returns `None` for unrecognised values.
fn parse_flag(value: &str) -> Option<bool> {
    match value.trim().to_ascii_lowercase().as_str() {
        "true" | "1" | "yes" | "on" => Some(true),
        "false" | "0" | "no" | "off" => Some(false),
        _ => None,
    }
}

/// Collects the lower-cased tokens listed in `Connection` header values.
///
/// Per RFC 9110 §7.6.1, headers named there are hop-by-hop for this message.
fn connection_nominated<'a>(values: impl IntoIterator<Item = &'a str>) -> Vec<String> {
    values
        .into_iter()
        .flat_map(|v| v.split(','))
        .map(|token| token.trim().to_ascii_lowercase())
        .filter(|token| !token.is_empty())
        .collect()
}

/// Returns true if `name` is a standard hop-by-hop header, or one nominated by
/// the message's `Connection` header.
fn is_hop_by_hop(name: &str, nominated: &[String]) -> bool {
    HOP_BY_HOP_HEADERS
        .iter()
        .any(|h| name.eq_ignore_ascii_case(h))
        || nominated.iter().any(|n| name.eq_ignore_ascii_case(n))
}

/// Removes the cookie named `session_cookie_name` (case-insensitively) from a
/// `Cookie` header value.
///
/// Empty segments such as a trailing `;` are dropped, and the other pairs are
/// kept verbatim. Returns `None` when no cookies remain.
fn strip_session_cookie(raw: &str, session_cookie_name: &str) -> Option<String> {
    let kept: Vec<&str> = raw
        .split(';')
        .map(str::trim)
        .filter(|pair| !pair.is_empty())
        .filter(|pair| {
            let name = pair.split_once('=').map_or(*pair, |(name, _)| name).trim();
            !name.eq_ignore_ascii_case(session_cookie_name)
        })
        .collect();
    (!kept.is_empty()).then(|| kept.join("; "))
}

/// Builds an RFC 7239 `Forwarded` element for the client address.
///
/// - Bare IPv4 addresses are valid tokens and are emitted unquoted.
/// - IPv6 addresses must be bracketed and quoted.
/// - Anything else (`ip:port`, `[v6]:port`, identifiers) contains characters
///   that are not valid in a token, so it is emitted as an escaped quoted-string.
fn forwarded_for(addr: &str) -> String {
    match addr.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(ip)) => format!("for={}", ip),
        Ok(std::net::IpAddr::V6(ip)) => format!("for=\"[{}]\"", ip),
        Err(_) => {
            let escaped = addr.replace('\\', "\\\\").replace('"', "\\\"");
            format!("for=\"{}\"", escaped)
        }
    }
}
/// Returns the upstream base URL from [`TARGET_SERVICE_ENV`], falling back to
/// [`DEFAULT_TARGET_SERVICE`].
///
/// Surrounding whitespace is trimmed. An empty or whitespace-only value is
/// treated as unset, because it would otherwise produce a scheme-less upstream
/// URL and every proxied request would fail.
#[instrument]
fn get_target_service_url() -> String {
    std::env::var(TARGET_SERVICE_ENV)
        .ok()
        .map(|url| url.trim().to_string())
        .filter(|url| !url.is_empty())
        .unwrap_or_else(|| DEFAULT_TARGET_SERVICE.to_string())
}