pub mod buffered;
pub mod header_phase;
pub mod passthrough;
pub mod sse;
pub mod streamed;
use std::time::Instant;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::{IntoResponse, Response};
use crate::protocol::session::SessionStore;
use crate::proxy::ProxyState;
use crate::proxy::forwarding::{UpstreamClient, forward_request};
use crate::proxy::pipeline::context::RequestContext;
use crate::proxy::pipeline::emit::{ResponseSummary, emit_request_event};
pub(super) use crate::timing::{Stage, StageTimer};
/// Forwards a request to `url` through [`forward_request`], reducing any
/// failure to a [`ForwardError`] that keeps only the error's rendered text.
///
/// A `ForwardError` is what `emit_upstream_error` consumes to produce a
/// `502 Bad Gateway` response, hence the name.
pub(super) async fn forward_or_502(
    upstream: &UpstreamClient,
    url: &str,
    method: Method,
    headers: &HeaderMap,
    body: &axum::body::Bytes,
    is_streaming: bool,
) -> Result<reqwest::Response, ForwardError> {
    let outcome = forward_request(upstream, url, method, headers, body, is_streaming).await;
    outcome.map_err(|source| ForwardError {
        err_string: source.to_string(),
    })
}
/// Failure produced by `forward_or_502`, reduced to a human-readable message.
///
/// Only the `Display` rendering of the underlying forwarding error is kept.
/// `emit_upstream_error` reports it in the request event and echoes it in the
/// 502 response body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct ForwardError {
    /// Rendered text of the underlying forwarding error.
    pub err_string: String,
}
/// Records a failed upstream forward and builds the client-facing response.
///
/// This function:
/// - tags `ctx` with `"upstream error"`;
/// - emits a request event with status 502, the microseconds elapsed since
///   `upstream_start`, and the error message;
/// - returns `502 Bad Gateway` with the body `Upstream error: <message>`.
pub(super) fn emit_upstream_error(
    state: &ProxyState,
    ctx: &mut RequestContext,
    upstream_start: Instant,
    err: ForwardError,
) -> Response {
    // `as_micros` yields u128; saturate instead of silently truncating via `as`.
    let upstream_us = u64::try_from(upstream_start.elapsed().as_micros()).unwrap_or(u64::MAX);
    ctx.tags.push("upstream error");

    // Render the client body first so the owned message can be moved into the
    // event summary rather than cloned.
    let body = format!("Upstream error: {}", err.err_string);
    emit_request_event(
        state,
        ctx,
        &ResponseSummary {
            status: 502,
            response_size: None,
            upstream_us: Some(upstream_us),
            error_code: None,
            error_msg: Some(err.err_string),
            stage_timings: None,
        },
    );
    (StatusCode::BAD_GATEWAY, body).into_response()
}
/// Fills `ctx.client_name` / `ctx.client_version` from the stored session.
///
/// The session named by `ctx.session_id` is looked up in `state.sessions`.
/// `ctx` is left unchanged when:
/// - there is no session id;
/// - the session is not found;
/// - the session carries no client info.
pub(super) async fn populate_client_info(state: &ProxyState, ctx: &mut RequestContext) {
    let Some(session_id) = ctx.session_id.as_ref() else {
        return;
    };
    let Some(session) = state.sessions.get(session_id).await else {
        return;
    };
    if let Some(client) = session.client_info {
        ctx.client_name = Some(client.name);
        ctx.client_version = client.version;
    }
}
/// Copies the `mcp-session-id` header from `upstream_headers` into `ctx.session_id`.
///
/// `ctx.session_id` is left untouched when the header is absent or when its
/// value cannot be read as a string (`to_str` fails).
pub(super) fn capture_session_id(ctx: &mut RequestContext, upstream_headers: &HeaderMap) {
    let Some(session_id) = upstream_headers
        .get("mcp-session-id")
        .and_then(|value| value.to_str().ok())
    else {
        return;
    };
    ctx.session_id = Some(session_id.to_owned());
}