use bytes::Bytes;
use http::Response;
use http_body_util::{BodyExt, Full};
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::prelude::*;
use pingora::proxy::Session;
use std::collections::HashMap;
use crate::routing::RequestInfo;
use crate::trace_id::{generate_for_format, TraceIdFormat};
/// Owned copy of the request attributes filled in by [`extract_request_info`].
///
/// Every field is an owned `String`/`HashMap`, so a value of this type does not
/// borrow from the pingora `Session` it was extracted from.
#[derive(Debug, Clone)]
pub struct OwnedRequestInfo {
    /// HTTP method as text (from `Method::as_str`, e.g. `"GET"`).
    pub method: String,
    /// URI path component (`Uri::path()`), which does not include the query string.
    pub path: String,
    /// Host as returned by [`extract_request_host`]; empty when no host is known.
    pub host: String,
    /// Request headers as produced by `RequestInfo::build_headers`
    /// (key normalization, if any, is decided there — not visible here).
    pub headers: HashMap<String, String>,
    /// Query parameters as produced by `RequestInfo::parse_query_params`.
    pub query_params: HashMap<String, String>,
}
/// Returns the host the downstream client addressed.
///
/// The host component of the request URI is preferred (present for absolute-form
/// URIs). Otherwise the `Host` header is used, as long as it is valid visible
/// ASCII. If neither is available, the empty string is returned.
///
/// NOTE(review): `Uri::host()` excludes any port, while the `Host` header fallback is
/// returned verbatim (port included, e.g. `"im.example.com:443"`). Callers comparing
/// hosts may need to normalize — confirm this asymmetry is intended.
pub fn extract_request_host(req_header: &RequestHeader) -> &str {
    req_header
        .uri
        .host()
        .or_else(|| {
            req_header
                .headers
                .get("host")
                .and_then(|value| value.to_str().ok())
        })
        .unwrap_or("")
}
/// Builds an [`OwnedRequestInfo`] snapshot of the downstream request held by `session`.
///
/// The method, path, host, headers and query parameters are copied out, so the
/// result does not borrow from the session.
pub fn extract_request_info(session: &Session) -> OwnedRequestInfo {
    let req_header = session.req_header();
    let uri = &req_header.uri;

    let path = uri.path().to_string();

    // `Uri::path()` never contains the query string, so parsing it could not yield
    // any parameters. Give the parser the path *with* its query when there is one.
    // NOTE(review): assumes `RequestInfo::parse_query_params` looks for the query
    // after a `?` in its argument (it was previously called with a path) — confirm.
    let path_and_query = uri
        .path_and_query()
        .map_or_else(|| path.clone(), |pq| pq.as_str().to_string());
    let query_params = RequestInfo::parse_query_params(&path_and_query);

    OwnedRequestInfo {
        method: req_header.method.as_str().to_string(),
        path,
        host: extract_request_host(req_header).to_string(),
        headers: RequestInfo::build_headers(req_header.headers.iter()),
        query_params,
    }
}
/// Returns the client-supplied trace ID, or generates a new one in `format`.
///
/// The headers `x-trace-id`, `x-correlation-id` and `x-request-id` are checked in
/// that order. The first value that `HeaderValue::to_str` accepts and that is
/// non-empty wins. When none qualifies, a fresh ID comes from `generate_for_format`.
///
/// NOTE(review): an incoming ID is trusted as-is (no length or charset check beyond
/// `to_str`) — confirm that is acceptable wherever the ID is logged or echoed.
pub fn get_or_create_trace_id(session: &Session, format: TraceIdFormat) -> String {
    // Priority order: earlier entries win over later ones.
    const TRACE_HEADERS: [&str; 3] = ["x-trace-id", "x-correlation-id", "x-request-id"];

    let headers = &session.req_header().headers;
    TRACE_HEADERS
        .iter()
        .filter_map(|name| headers.get(*name))
        .filter_map(|value| value.to_str().ok())
        .find(|candidate| !candidate.is_empty())
        .map(str::to_owned)
        .unwrap_or_else(|| generate_for_format(format))
}
/// Same as [`get_or_create_trace_id`], using `TraceIdFormat::default()` for any
/// newly generated ID.
#[inline]
pub fn get_or_create_trace_id_default(session: &Session) -> String {
    let format = TraceIdFormat::default();
    get_or_create_trace_id(session, format)
}
/// Writes a fully buffered `http::Response` to the downstream `session`.
///
/// The status code and every header are copied into a pingora `ResponseHeader`.
/// The body is collected and sent as a single chunk that ends the stream.
/// `keepalive_secs` is forwarded unchanged to `Session::set_keepalive`.
///
/// # Errors
///
/// Returns an error if `ResponseHeader::build` rejects the status, a header cannot
/// be added, or writing the header/body to the session fails.
pub async fn write_response(
    session: &mut Session,
    response: Response<Full<Bytes>>,
    keepalive_secs: Option<u64>,
) -> Result<(), Box<Error>> {
    let (parts, body) = response.into_parts();

    // `Full<Bytes>` has `Infallible` as its error type, so the default is never used.
    let body_bytes: Bytes = BodyExt::collect(body)
        .await
        .map(|collected| collected.to_bytes())
        .unwrap_or_default();

    let mut resp_header = ResponseHeader::build(parts.status.as_u16(), None)?;
    for (name, value) in parts.headers.iter() {
        // `append_header` (not `insert_header`) keeps every value of a repeated
        // header such as `Set-Cookie`; inserting would keep only the last one.
        // Passing the raw `HeaderValue` preserves values that are not visible ASCII
        // instead of replacing them with an empty string.
        resp_header.append_header(name.as_str().to_owned(), value.clone())?;
    }

    session.set_keepalive(keepalive_secs);
    session
        .write_response_header(Box::new(resp_header), false)
        .await?;
    session.write_response_body(Some(body_bytes), true).await?;
    Ok(())
}
/// Sends a complete error response with the given `status`, `content_type` and `body`.
///
/// `Content-Length` is set to the body's byte length. Keepalive is disabled via
/// `set_keepalive(None)`, so the connection is not reused after the error.
///
/// # Errors
///
/// Propagates failures from building the header or writing to the session.
pub async fn write_error(
    session: &mut Session,
    status: u16,
    body: &str,
    content_type: &str,
) -> Result<(), Box<Error>> {
    let payload = Bytes::copy_from_slice(body.as_bytes());

    let mut header = ResponseHeader::build(status, None)?;
    header.insert_header("Content-Type", content_type)?;
    header.insert_header("Content-Length", payload.len().to_string())?;

    session.set_keepalive(None);
    session.write_response_header(Box::new(header), false).await?;
    session.write_response_body(Some(payload), true).await?;
    Ok(())
}
/// Sends `message` as a UTF-8 plain-text error body; see [`write_error`].
pub async fn write_text_error(
    session: &mut Session,
    status: u16,
    message: &str,
) -> Result<(), Box<Error>> {
    const PLAIN_TEXT: &str = "text/plain; charset=utf-8";
    write_error(session, status, message, PLAIN_TEXT).await
}
/// Escapes `raw` so it can be placed between double quotes in a JSON document.
///
/// Quotes, backslashes and every control character below U+0020 are escaped, as
/// RFC 8259 §7 requires. All other characters are copied unchanged.
fn escape_json_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if u32::from(c) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(c)));
            }
            c => escaped.push(c),
        }
    }
    escaped
}

/// Sends a JSON error body of the form `{"error":"…"}` or
/// `{"error":"…","message":"…"}` with `Content-Type: application/json`.
///
/// `error` and `message` are JSON-escaped first. A quote, backslash or control
/// character in them therefore cannot produce invalid JSON or inject extra fields.
/// Inputs without such characters produce exactly the same body as before.
pub async fn write_json_error(
    session: &mut Session,
    status: u16,
    error: &str,
    message: Option<&str>,
) -> Result<(), Box<Error>> {
    let error = escape_json_string(error);
    let body = match message {
        Some(msg) => format!(
            r#"{{"error":"{}","message":"{}"}}"#,
            error,
            escape_json_string(msg)
        ),
        None => format!(r#"{{"error":"{}"}}"#, error),
    };
    write_error(session, status, &body, "application/json").await
}
/// Sends a plain-text rate-limit rejection carrying the `X-RateLimit-*` quota headers.
///
/// `Retry-After` is added only when `retry_after` is non-zero. Keepalive is disabled
/// via `set_keepalive(None)`, as in [`write_error`].
///
/// # Errors
///
/// Propagates failures from building the header or writing to the session.
pub async fn write_rate_limit_error(
    session: &mut Session,
    status: u16,
    body: &str,
    limit: u32,
    remaining: u32,
    reset_at: u64,
    retry_after: u64,
) -> Result<(), Box<Error>> {
    let mut header = ResponseHeader::build(status, None)?;
    header.insert_header("Content-Type", "text/plain; charset=utf-8")?;
    header.insert_header("Content-Length", body.len().to_string())?;

    let quota_headers = [
        ("X-RateLimit-Limit", limit.to_string()),
        ("X-RateLimit-Remaining", remaining.to_string()),
        ("X-RateLimit-Reset", reset_at.to_string()),
    ];
    for (name, value) in quota_headers {
        header.insert_header(name, value)?;
    }
    if retry_after != 0 {
        header.insert_header("Retry-After", retry_after.to_string())?;
    }

    session.set_keepalive(None);
    session.write_response_header(Box::new(header), false).await?;
    let payload = Bytes::copy_from_slice(body.as_bytes());
    session.write_response_body(Some(payload), true).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GET` request header whose URI is `uri`, optionally adding a `host` header.
    fn req(uri: &str, host_header: Option<&str>) -> RequestHeader {
        let mut header = RequestHeader::build("GET", b"/", None).unwrap();
        header.set_uri(uri.parse().unwrap());
        if let Some(value) = host_header {
            header.insert_header("host", value).unwrap();
        }
        header
    }

    #[test]
    fn extract_host_prefers_uri_host_for_absolute_uri() {
        let request = req("http://example.com/path", Some("other.example.org"));
        assert_eq!(extract_request_host(&request), "example.com");
    }

    #[test]
    fn extract_host_falls_back_to_header_for_relative_uri() {
        let request = req("/_matrix/federation/v1/send/123", Some("im.example.com:443"));
        assert_eq!(extract_request_host(&request), "im.example.com:443");
    }

    #[test]
    fn extract_host_returns_empty_when_no_host_anywhere() {
        let request = req("/path", None);
        assert_eq!(extract_request_host(&request), "");
    }

    #[test]
    fn extract_host_uses_uri_when_header_missing() {
        let request = req("http://api.example.com/v1", None);
        assert_eq!(extract_request_host(&request), "api.example.com");
    }
}