use super::client::HttpClient;
use super::headers::{
RiftHeadersExt, VALUE_TRUE, X_RIFT_PROXIED, X_RIFT_RECORDED, X_RIFT_REPLAYED,
};
use super::response_ext::ResponseExt;
use crate::recording::{ProxyMode, RecordedResponse, RecordingStore, RequestSignature};
use crate::response::builder::ErrorResponseBuilder;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper::{Request, Response, StatusCode};
use std::convert::Infallible;
use std::sync::Arc;
use tracing::{debug, error};
/// Builds a JSON error response of the form `{"error": "<message>"}`.
///
/// `message` is escaped before being embedded, so quotes, backslashes and
/// control characters cannot produce malformed JSON.
///
/// # Arguments
/// * `status` - HTTP status code to put on the response.
/// * `message` - Human-readable error text placed in the `error` field.
///
/// # Returns
/// A response with `content-type: application/json`. If the builder rejects
/// the inputs (for example `status` is not a valid HTTP status code), a
/// `500` response with body `{"error": "internal"}` is returned instead of
/// an error payload carrying a success status.
pub fn error_response(status: u16, message: &str) -> Response<Full<Bytes>> {
    const INTERNAL_BODY: &str = r#"{"error": "internal"}"#;

    let build = |code: u16, body: String| {
        Response::builder()
            .status(code)
            .header("content-type", "application/json")
            .body(Full::new(Bytes::from(body)))
    };

    let body = format!(r#"{{"error": "{}"}}"#, escape_json_string(message));
    build(status, body)
        // The requested status was rejected: report an internal error rather
        // than silently defaulting to `200 OK`.
        .or_else(|_| build(500, INTERNAL_BODY.to_owned()))
        // Last resort; kept so this function can never panic.
        .unwrap_or_else(|_| Response::new(Full::new(Bytes::from(INTERNAL_BODY))))
}

/// Escapes `raw` so it can be embedded inside a JSON string literal.
fn escape_json_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            // Remaining control characters must use the \uXXXX form.
            c if u32::from(c) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
            }
            c => escaped.push(c),
        }
    }
    escaped
}
/// Starts a request builder targeting the upstream for an incoming request.
///
/// The target URI is `upstream_uri` followed by the incoming request's path
/// and query (or `/` when the incoming URI has none). Every incoming header
/// except `host` is copied onto the builder.
///
/// # Arguments
/// * `method` - HTTP method to use for the upstream request.
/// * `uri` - URI of the incoming request; only its path and query are used.
/// * `headers` - Headers of the incoming request.
/// * `upstream_uri` - Base URI of the upstream, e.g. `http://host:port`.
///
/// # Returns
/// A builder with method, URI and headers set. Any invalid component is
/// reported when the caller finalises the builder with `.body(..)`.
fn build_upstream_request(
    method: hyper::Method,
    uri: &hyper::Uri,
    headers: &hyper::HeaderMap,
    upstream_uri: &str,
) -> hyper::http::request::Builder {
    let upstream_path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");

    // Avoid `http://host//path` when the configured base ends with a slash
    // and the request path already starts with one.
    let base = if upstream_path.starts_with('/') {
        upstream_uri.trim_end_matches('/')
    } else {
        upstream_uri
    };
    let full_uri = format!("{base}{upstream_path}");

    let mut builder = Request::builder().method(method).uri(full_uri);
    for (key, value) in headers.iter() {
        // The incoming `host` header is not forwarded to the upstream.
        if key != "host" {
            builder = builder.header(key, value);
        }
    }
    builder
}
pub async fn forward_request_with_body(
http_client: &HttpClient,
method: hyper::Method,
uri: hyper::Uri,
headers: hyper::HeaderMap,
body_bytes: Bytes,
upstream_uri: &str,
) -> Response<Full<Bytes>> {
let builder = build_upstream_request(method, &uri, &headers, upstream_uri);
debug!(
"Forwarding to: {}",
uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/")
);
let upstream_req = builder
.body(BoxBody::new(
Full::new(body_bytes).map_err(|never: Infallible| match never {}),
))
.unwrap();
match http_client.request(upstream_req).await {
Ok(upstream_response) => {
let (parts, body) = upstream_response.into_parts();
let body_bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
error!("Failed to collect upstream response body: {}", e);
return error_response(502, "Failed to read upstream response");
}
};
let mut response = Response::from_parts(parts, Full::new(body_bytes));
response.set_header(&X_RIFT_PROXIED, &VALUE_TRUE);
response
}
Err(e) => {
error!("Failed to forward request to upstream: {}", e);
error_response(502, "Bad Gateway")
}
}
}
pub async fn forward_request_streaming(
http_client: &HttpClient,
req: Request<hyper::body::Incoming>,
upstream_uri: &str,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let method = req.method().clone();
let uri = req.uri().clone();
let headers = req.headers().clone();
let builder = build_upstream_request(method, &uri, &headers, upstream_uri);
debug!(
"Forwarding (streaming) to: {}",
uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/")
);
let upstream_req = builder.body(BoxBody::new(req.into_body())).unwrap();
match http_client.request(upstream_req).await {
Ok(upstream_response) => {
let (mut parts, body) = upstream_response.into_parts();
parts.set_header(&X_RIFT_PROXIED, &VALUE_TRUE);
Response::from_parts(parts, BoxBody::new(body))
}
Err(e) => {
error!("Failed to forward request to upstream: {}", e);
ErrorResponseBuilder::new(StatusCode::BAD_GATEWAY)
.header("content-type", "application/json")
.body(r#"{"error": "Bad Gateway"}"#)
.build_boxed()
}
}
}
pub async fn forward_with_recording(
http_client: &HttpClient,
recording_store: &Arc<RecordingStore>,
signature_headers: &[(String, String)],
req: Request<hyper::body::Incoming>,
upstream_uri: &str,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let method = req.method().clone();
let uri = req.uri().clone();
let headers = req.headers().clone();
let mode = recording_store.mode();
if mode == ProxyMode::ProxyTransparent {
return forward_request_streaming(http_client, req, upstream_uri).await;
}
let body_bytes = match req.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
error!("Failed to collect request body for recording: {}", e);
return ErrorResponseBuilder::new(StatusCode::INTERNAL_SERVER_ERROR)
.header("content-type", "application/json")
.body(r#"{"error": "Failed to read request body"}"#)
.build_boxed();
}
};
let signature =
RequestSignature::new(method.as_str(), uri.path(), uri.query(), signature_headers);
if !recording_store.should_proxy(&signature) {
if let Some(recorded) = recording_store.get_recorded(&signature) {
debug!(
"Replaying recorded response for {} {} (status: {})",
method,
uri.path(),
recorded.status
);
let mut response = Response::builder().status(recorded.status);
for (key, value) in &recorded.headers {
if let Ok(header_value) = value.parse::<hyper::header::HeaderValue>() {
response = response.header(key.as_str(), header_value);
}
}
response = response.header(X_RIFT_REPLAYED.clone(), VALUE_TRUE.clone());
return response
.body(BoxBody::new(
Full::new(Bytes::from(recorded.body.clone()))
.map_err(|never: Infallible| match never {}),
))
.unwrap();
}
}
let start = std::time::Instant::now();
let response = forward_request_with_body(
http_client,
method.clone(),
uri.clone(),
headers,
body_bytes,
upstream_uri,
)
.await;
let latency_ms = start.elapsed().as_millis() as u64;
let status = response.status().as_u16();
let (parts, body) = response.into_parts();
let response_body_bytes: Bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
let mut recorded_headers = Vec::new();
for (key, value) in parts.headers.iter() {
if let Ok(value_str) = value.to_str() {
recorded_headers.push((key.as_str().to_string(), value_str.to_string()));
}
}
let recorded_response = RecordedResponse {
status,
headers: recorded_headers,
body: response_body_bytes.to_vec(),
latency_ms: Some(latency_ms),
timestamp_secs: crate::util::unix_timestamp(),
};
recording_store.record(signature, recorded_response.clone());
debug!(
"Recorded response for {} {} (status: {}, latency: {}ms)",
method,
uri.path(),
status,
latency_ms
);
let mut response = Response::from_parts(parts, Full::new(response_body_bytes));
response.set_header(&X_RIFT_RECORDED, &VALUE_TRUE);
response.into_boxed()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 500 error response keeps its status and is tagged as JSON.
    #[test]
    fn test_error_response_basic() {
        let resp = error_response(500, "Internal Server Error");
        assert_eq!(resp.status(), 500);

        let content_type = resp.headers().get("content-type").unwrap();
        assert_eq!(content_type, "application/json");
    }

    /// The requested 400 status is carried through.
    #[test]
    fn test_error_response_400() {
        assert_eq!(error_response(400, "Bad Request").status(), 400);
    }

    /// The requested 502 status is carried through.
    #[test]
    fn test_error_response_502() {
        assert_eq!(error_response(502, "Bad Gateway").status(), 502);
    }

    /// The requested 404 status is carried through.
    #[test]
    fn test_error_response_404() {
        assert_eq!(error_response(404, "Not Found").status(), 404);
    }

    /// The requested 503 status is carried through.
    #[test]
    fn test_error_response_503() {
        assert_eq!(error_response(503, "Service Unavailable").status(), 503);
    }
}