use axum::{
extract::Request, http::StatusCode, middleware::Next, response::Response, routing::get, Router,
};
use mockforge_core::proxy::{body_transform::BodyTransformationMiddleware, config::ProxyConfig};
use serde::Serialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
/// Shared state for the HTTP proxy: the proxy configuration, logging switches,
/// and the counters from which `get_proxy_stats` derives its figures.
#[derive(Debug)]
pub struct ProxyServer {
/// Proxy configuration, shared behind an async read/write lock.
config: Arc<RwLock<ProxyConfig>>,
/// When true, the proxy handler emits an info-level log line per intercepted request.
log_requests: bool,
/// When true, the proxy handler emits an info-level log line per upstream response.
log_responses: bool,
/// Counter incremented by the proxy handler; reported as `total_requests` in the stats.
request_counter: Arc<RwLock<u64>>,
/// Instant captured when the server value was constructed; basis for requests-per-second.
start_time: std::time::Instant,
/// Sum of per-request handling durations in milliseconds, accumulated by the logging middleware.
total_response_time_ms: Arc<RwLock<u64>>,
/// Number of responses with a 5xx status, incremented by the logging middleware.
error_counter: Arc<RwLock<u64>>,
}
impl ProxyServer {
    /// Creates a proxy server around `config`.
    ///
    /// `log_requests` and `log_responses` are stored as given. All counters start
    /// at zero and the start time is taken at the moment of construction.
    pub fn new(config: ProxyConfig, log_requests: bool, log_responses: bool) -> Self {
        // Every statistic is a zero-initialised `u64` behind its own lock.
        let zeroed = || Arc::new(RwLock::new(0u64));

        Self {
            config: Arc::new(RwLock::new(config)),
            log_requests,
            log_responses,
            request_counter: zeroed(),
            start_time: std::time::Instant::now(),
            total_response_time_ms: zeroed(),
            error_counter: zeroed(),
        }
    }

    /// Consumes the server and builds its axum [`Router`].
    ///
    /// `GET /proxy/health` is served by `health_check`; every other request falls
    /// through to `proxy_handler`. `logging_middleware` is layered over all routes
    /// and shares the same state.
    pub fn router(self) -> Router {
        let shared = Arc::new(self);

        let routes: Router = Router::new()
            .route("/proxy/health", get(health_check))
            .fallback(proxy_handler)
            .with_state(Arc::clone(&shared));

        routes.layer(axum::middleware::from_fn_with_state(shared, logging_middleware))
    }
}
/// Health endpoint: answers `200 OK` with a fixed JSON document.
///
/// Returns `500 Internal Server Error` if the response cannot be assembled;
/// the failure is logged first.
async fn health_check() -> Result<Response<String>, StatusCode> {
    let payload = String::from(r#"{"status":"healthy","service":"mockforge-proxy"}"#);

    let built = Response::builder()
        .status(StatusCode::OK)
        .header("Content-Type", "application/json")
        .body(payload);

    match built {
        Ok(response) => Ok(response),
        Err(e) => {
            tracing::error!("Failed to build health check response: {}", e);
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}
async fn proxy_handler(
axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
request: http::Request<axum::body::Body>,
) -> Result<Response<String>, StatusCode> {
let client_addr = request
.extensions()
.get::<SocketAddr>()
.copied()
.unwrap_or_else(|| SocketAddr::from(([0, 0, 0, 0], 0)));
let method = request.method().clone();
let uri = request.uri().clone();
let headers = request.headers().clone();
let body_bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await {
Ok(bytes) => Some(bytes.to_vec()),
Err(e) => {
error!("Failed to read request body: {}", e);
None
}
};
let config = state.config.read().await;
if !config.enabled {
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
if !config.should_proxy_with_condition(&method, &uri, &headers, body_bytes.as_deref()) {
return Err(StatusCode::NOT_FOUND);
}
let stripped_path = config.strip_prefix(uri.path());
let base_upstream_url = config.get_upstream_url(uri.path());
let full_upstream_url =
if stripped_path.starts_with("http://") || stripped_path.starts_with("https://") {
stripped_path.clone()
} else {
let base = base_upstream_url.trim_end_matches('/');
let path = stripped_path.trim_start_matches('/');
let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
if path.is_empty() || path == "/" {
format!("{}{}", base, query)
} else {
format!("{}/{}", base, path) + &query
}
};
let _modified_uri = full_upstream_url.parse::<http::Uri>().unwrap_or_else(|_| uri.clone());
if state.log_requests {
let mut counter = state.request_counter.write().await;
*counter += 1;
let request_id = *counter;
info!(
request_id = request_id,
method = %method,
path = %uri.path(),
upstream = %full_upstream_url,
client_ip = %client_addr.ip(),
"Proxy request intercepted"
);
}
let mut header_map = std::collections::HashMap::new();
for (key, value) in &headers {
if let Ok(value_str) = value.to_str() {
header_map.insert(key.to_string(), value_str.to_string());
}
}
use mockforge_core::proxy::client::ProxyClient;
let proxy_client = ProxyClient::new();
let reqwest_method = match method.as_str() {
"GET" => reqwest::Method::GET,
"POST" => reqwest::Method::POST,
"PUT" => reqwest::Method::PUT,
"DELETE" => reqwest::Method::DELETE,
"HEAD" => reqwest::Method::HEAD,
"OPTIONS" => reqwest::Method::OPTIONS,
"PATCH" => reqwest::Method::PATCH,
_ => {
error!("Unsupported HTTP method: {}", method);
return Err(StatusCode::METHOD_NOT_ALLOWED);
}
};
for (key, value) in &config.headers {
header_map.insert(key.clone(), value.clone());
}
let mut transformed_request_body = body_bytes.clone();
if !config.request_replacements.is_empty() {
let transform_middleware = BodyTransformationMiddleware::new(
config.request_replacements.clone(),
Vec::new(), );
if let Err(e) =
transform_middleware.transform_request_body(uri.path(), &mut transformed_request_body)
{
warn!("Failed to transform request body: {}", e);
}
}
match proxy_client
.send_request(
reqwest_method,
&full_upstream_url,
&header_map,
transformed_request_body.as_deref(),
)
.await
{
Ok(response) => {
let status = StatusCode::from_u16(response.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if state.log_responses {
info!(
method = %method,
path = %uri.path(),
status = status.as_u16(),
"Proxy response sent"
);
}
let mut response_headers = http::HeaderMap::new();
for (name, value) in response.headers() {
if let (Ok(header_name), Ok(header_value)) = (
http::HeaderName::try_from(name.as_str()),
http::HeaderValue::try_from(value.as_bytes()),
) {
response_headers.insert(header_name, header_value);
}
}
let response_body_bytes = response.bytes().await.map_err(|e| {
error!("Failed to read proxy response body: {}", e);
StatusCode::BAD_GATEWAY
})?;
let mut final_body_bytes = response_body_bytes.to_vec();
{
let config_for_response = state.config.read().await;
if !config_for_response.response_replacements.is_empty() {
let transform_middleware = BodyTransformationMiddleware::new(
Vec::new(), config_for_response.response_replacements.clone(),
);
let mut body_option = Some(final_body_bytes.clone());
if let Err(e) = transform_middleware.transform_response_body(
uri.path(),
status.as_u16(),
&mut body_option,
) {
warn!("Failed to transform response body: {}", e);
} else if let Some(transformed_body) = body_option {
final_body_bytes = transformed_body;
}
}
}
let body_string = String::from_utf8_lossy(&final_body_bytes).to_string();
let mut response_builder = Response::builder().status(status);
for (name, value) in response_headers.iter() {
response_builder = response_builder.header(name, value);
}
response_builder
.body(body_string)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Err(e) => {
error!("Proxy request failed: {}", e);
Err(StatusCode::BAD_GATEWAY)
}
}
}
/// Middleware wrapped around every route: logs the request and response at
/// debug level and feeds the timing and error statistics.
///
/// The elapsed handling time (in whole milliseconds) is added to
/// `total_response_time_ms`, and `error_counter` is bumped for any response
/// whose status is a server error (5xx).
async fn logging_middleware(
    axum::extract::State(state): axum::extract::State<Arc<ProxyServer>>,
    request: Request,
    next: Next,
) -> Response {
    let started_at = std::time::Instant::now();

    // Capture what is needed for logging before the request is handed on.
    let method = request.method().clone();
    let uri = request.uri().clone();
    let peer = match request.extensions().get::<SocketAddr>() {
        Some(addr) => *addr,
        None => SocketAddr::from(([0, 0, 0, 0], 0)),
    };

    debug!(
        method = %method,
        uri = %uri,
        client_ip = %peer.ip(),
        "Request received"
    );

    let response = next.run(request).await;
    let elapsed = started_at.elapsed();

    // Each write guard lives only for its own statement.
    *state.total_response_time_ms.write().await += elapsed.as_millis() as u64;
    if response.status().is_server_error() {
        *state.error_counter.write().await += 1;
    }

    debug!(
        method = %method,
        uri = %uri,
        status = %response.status(),
        duration_ms = elapsed.as_millis(),
        "Response sent"
    );

    response
}
/// Serializable snapshot of proxy statistics, as produced by `get_proxy_stats`.
#[derive(Debug, Serialize)]
pub struct ProxyStats {
/// Value of the server's request counter at snapshot time.
pub total_requests: u64,
/// `total_requests` divided by the seconds elapsed since the server was created.
pub requests_per_second: f64,
/// Accumulated response time in milliseconds divided by `total_requests` (0.0 when there are none).
pub avg_response_time_ms: f64,
/// Error count divided by `total_requests`, times 100 (0.0 when there are none).
pub error_rate_percent: f64,
}
/// Builds a [`ProxyStats`] snapshot from the server's counters.
///
/// Each counter is read under its own read lock, one after another. Ratios
/// whose denominator would be zero (no requests yet, or no elapsed time) are
/// reported as `0.0`.
pub async fn get_proxy_stats(state: &ProxyServer) -> ProxyStats {
    let handled = *state.request_counter.read().await;
    let busy_ms = *state.total_response_time_ms.read().await;
    let failures = *state.error_counter.read().await;
    let uptime_secs = state.start_time.elapsed().as_secs_f64();

    // Divides a counter by the number of handled requests; `None` when there are none.
    let per_request = |count: u64| (handled > 0).then(|| count as f64 / handled as f64);

    let requests_per_second = match uptime_secs > 0.0 {
        true => handled as f64 / uptime_secs,
        false => 0.0,
    };

    ProxyStats {
        total_requests: handled,
        requests_per_second,
        avg_response_time_ms: per_request(busy_ms).unwrap_or(0.0),
        error_rate_percent: per_request(failures).map_or(0.0, |ratio| ratio * 100.0),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use mockforge_core::proxy::config::ProxyConfig;

    /// The constructor keeps both logging flags exactly as passed in.
    #[tokio::test]
    async fn test_proxy_server_creation() {
        let server = ProxyServer::new(ProxyConfig::default(), true, true);

        assert!(server.log_requests);
        assert!(server.log_responses);
    }

    /// The health endpoint answers 200 with a body naming the status and the service.
    #[tokio::test]
    async fn test_health_check() {
        let (parts, body) = health_check().await.unwrap().into_parts();

        assert_eq!(parts.status, StatusCode::OK);
        for expected in ["healthy", "mockforge-proxy"] {
            assert!(body.contains(expected));
        }
    }

    /// A freshly created server reports zero requests and zeroed ratios.
    #[tokio::test]
    async fn test_proxy_stats() {
        let server = ProxyServer::new(ProxyConfig::default(), false, false);

        let ProxyStats {
            total_requests,
            avg_response_time_ms,
            error_rate_percent,
            ..
        } = get_proxy_stats(&server).await;

        assert_eq!(total_requests, 0);
        assert_eq!(avg_response_time_ms, 0.0);
        assert_eq!(error_rate_percent, 0.0);
    }
}