use std::sync::Arc;
use http_body_util::{BodyExt, Limited};
use hyper::body::Incoming;
use hyper::{Method, Request, Response};
use osproxy_spi::{HttpMethod, Protocol};
use crate::admission::{Admission, IngressLimits, Reservation};
use crate::classify::classify;
use crate::handler::IngressHandler;
use crate::request::{
buffered_response, IngressRequest, IngressResponse, ResponseBody, StreamingResponse,
};
/// Connection-level metadata that is copied into every [`IngressRequest`] served on
/// the connection.
#[derive(Debug, Default, Clone)]
pub(crate) struct ConnInfo {
    /// Subject string of the client certificate, if one was recorded for this
    /// connection (presumably only for mutual-TLS peers — confirm at the construction site).
    pub(crate) client_cert_subject: Option<String>,
    /// Whether the connection is marked secure; forwarded verbatim as
    /// `IngressRequest::secure` (presumably "served over TLS" — confirm).
    pub(crate) secure: bool,
}
/// Translates one hyper request into an [`IngressRequest`] and dispatches it to `handler`.
///
/// Dispatch order:
/// 1. Methods other than GET/PUT/POST/DELETE/HEAD are rejected with `405 Method Not
///    Allowed`, carrying the `Allow` header that RFC 9110 requires on such responses.
/// 2. If `handler.forward_plan` accepts the path/logical index, the body is streamed to
///    `handle_forward` as-is. No body-size check or admission reservation happens here.
/// 3. A declared `Content-Length` above `limits.max_body_bytes` is rejected with `413`.
/// 4. If `handler.wants_bulk_stream` accepts the request, the raw body is handed to
///    `handle_bulk_stream` without buffering.
/// 5. Everything else is buffered by `serve_buffered`.
pub(crate) async fn serve_request<H: IngressHandler>(
    handler: &H,
    req: Request<Incoming>,
    conn_info: &ConnInfo,
    limits: IngressLimits,
    admission: &Arc<Admission>,
) -> Response<ResponseBody> {
    let Some(method) = map_method(req.method()) else {
        // The advertised list must match the methods accepted by `map_method`.
        return render(
            IngressResponse::json(405, error_body("method not allowed"))
                .with_header("allow", "GET, PUT, POST, DELETE, HEAD"),
        );
    };
    let path = req.uri().path().to_owned();
    let query = req.uri().query().map(str::to_owned);
    let protocol = map_protocol(req.version());
    // Header values that `HeaderValue::to_str` rejects (non-visible-ASCII bytes) are
    // passed on as empty strings.
    let headers: Vec<(String, String)> = req
        .headers()
        .iter()
        .map(|(k, v)| (k.as_str().to_owned(), v.to_str().unwrap_or("").to_owned()))
        .collect();
    let declared = content_length(&headers);
    let c = classify(method, &path);
    let head = IngressRequest {
        method,
        protocol,
        path,
        endpoint: c.endpoint,
        logical_index: c.logical_index,
        doc_id: c.doc_id,
        headers,
        // Filled in later by `serve_buffered`; streaming paths never populate it.
        body: Vec::new(),
        query,
        client_cert_subject: conn_info.client_cert_subject.clone(),
        secure: conn_info.secure,
    };
    // Forwarded requests are checked before the size limit, so they are not bounded
    // by `max_body_bytes` at this layer.
    if handler.forward_plan(&head.path, &head.logical_index) {
        return render_forward(handler.handle_forward(head, req.into_body()).await);
    }
    if declared.is_some_and(|n| n > limits.max_body_bytes) {
        return render(IngressResponse::json(
            413,
            error_body("request body too large"),
        ));
    }
    if handler.wants_bulk_stream(head.endpoint, &head.headers) {
        return render(handler.handle_bulk_stream(head, req.into_body()).await);
    }
    serve_buffered(handler, req.into_body(), head, declared, limits, admission).await
}
async fn serve_buffered<H: IngressHandler>(
handler: &H,
body: Incoming,
mut head: IngressRequest,
declared: Option<usize>,
limits: IngressLimits,
admission: &Arc<Admission>,
) -> Response<ResponseBody> {
let reserve = declared.unwrap_or(limits.max_body_bytes);
let Some(_reservation): Option<Reservation> = admission.try_reserve(reserve) else {
return render(overloaded_response());
};
let collected = match Limited::new(body, limits.max_body_bytes).collect().await {
Ok(collected) => collected.to_bytes().to_vec(),
Err(_) => {
return render(IngressResponse::json(
413,
error_body("request body too large"),
))
}
};
head.body = collected;
if handler.wants_search_stream(head.endpoint, head.query.as_deref()) {
return render_forward(handler.handle_search_stream(head).await);
}
render(handler.handle(head).await)
}
/// Returns the value of the first `Content-Length` header (case-insensitive name),
/// parsed as `usize` after trimming surrounding whitespace.
///
/// Yields `None` when the header is absent, or when that first occurrence does not
/// parse; later duplicates are not consulted.
fn content_length(headers: &[(String, String)]) -> Option<usize> {
    let (_, raw) = headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("content-length"))?;
    raw.trim().parse::<usize>().ok()
}
/// Builds the `429` response used when admission control cannot reserve capacity.
/// It carries a JSON error body and `retry-after: 1`.
fn overloaded_response() -> IngressResponse {
    let body = error_body("ingress overloaded, retry later");
    let response = IngressResponse::json(429, body);
    response.with_header("retry-after", "1")
}
/// Maps a hyper [`Method`] onto the SPI's [`HttpMethod`].
///
/// Only GET, HEAD, POST, PUT and DELETE are supported. Any other method yields `None`.
fn map_method(method: &Method) -> Option<HttpMethod> {
    let mapped = match *method {
        Method::GET => HttpMethod::Get,
        Method::HEAD => HttpMethod::Head,
        Method::POST => HttpMethod::Post,
        Method::PUT => HttpMethod::Put,
        Method::DELETE => HttpMethod::Delete,
        _ => return None,
    };
    Some(mapped)
}
/// Classifies the HTTP version. HTTP/2 maps to [`Protocol::Http2`]; every other version
/// (including HTTP/1.0 and HTTP/3) is reported as [`Protocol::Http1`].
fn map_protocol(version: hyper::Version) -> Protocol {
    let is_h2 = version == hyper::Version::HTTP_2;
    match is_h2 {
        true => Protocol::Http2,
        false => Protocol::Http1,
    }
}
fn render(out: IngressResponse) -> Response<ResponseBody> {
let mut builder = Response::builder().status(out.status);
if !has_content_type(&out.headers) {
builder = builder.header("content-type", "application/json");
}
for (name, value) in out.headers {
builder = builder.header(name, value);
}
builder
.body(buffered_response(out.body))
.unwrap_or_else(|_| {
Response::new(buffered_response(b"{\"error\":\"internal\"}".to_vec()))
})
}
fn render_forward(out: StreamingResponse) -> Response<ResponseBody> {
let mut builder = Response::builder().status(out.status);
if !has_content_type(&out.headers) {
builder = builder.header("content-type", "application/json");
}
for (name, value) in out.headers {
builder = builder.header(name, value);
}
builder
.body(out.body)
.unwrap_or_else(|_| Response::new(buffered_response(b"{\"error\":\"internal\"}".to_vec())))
}
/// Reports whether any header is named `content-type`, compared ASCII
/// case-insensitively.
fn has_content_type(headers: &[(String, String)]) -> bool {
    for (name, _) in headers {
        if name.eq_ignore_ascii_case("content-type") {
            return true;
        }
    }
    false
}
/// Encodes `message` as the JSON object `{"error":"<message>"}` and returns its bytes.
///
/// The message is escaped as a JSON string (quotes, backslashes and control
/// characters), so the output is valid JSON whatever text is passed in.
fn error_body(message: &str) -> Vec<u8> {
    let mut out = String::with_capacity(message.len() + 12);
    out.push_str("{\"error\":\"");
    for ch in message.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // RFC 8259 requires all remaining C0 control characters to be escaped.
            c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push_str("\"}");
    out.into_bytes()
}