Skip to main content

mcpr_core/proxy/
forwarding.rs

1use std::time::Duration;
2
3use axum::{
4    body::{Body, Bytes},
5    http::{HeaderMap, Method, StatusCode, header},
6    response::Response,
7};
8use futures_util::StreamExt;
9
10/// Shared upstream connection config for forwarding requests.
11///
12/// Concurrency and buffered-path timeout are enforced at the axum edge
13/// by `ConcurrencyLimitLayer` / `TimeoutLayer`. `request_timeout` here
14/// applies only to streaming calls, which the tower-http timeout can't
15/// cover (it cancels at response start, not mid-stream).
16#[derive(Clone)]
17pub struct UpstreamClient {
18    pub http_client: reqwest::Client,
19    pub request_timeout: Duration,
20}
21
22/// Reason the upstream body could not be read.
23///
24/// `ProxyTransport` maps this to `Response::Upstream502 { reason }` so
25/// the failure flows through the response middleware chain.
26#[derive(Debug)]
27pub enum ReadBodyError {
28    /// `Content-Length` or streamed bytes exceeded `max_bytes`.
29    TooLarge,
30    /// Underlying reqwest stream error.
31    Stream(reqwest::Error),
32}
33
34impl std::fmt::Display for ReadBodyError {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            ReadBodyError::TooLarge => write!(f, "upstream response too large"),
38            ReadBodyError::Stream(e) => write!(f, "upstream read error: {e}"),
39        }
40    }
41}
42
43/// Read a response body with a size cap. Returns a typed error on
44/// overflow or stream failure — callers produce the appropriate
45/// `Response` variant (`Upstream502`) so the failure flows through the
46/// response chain like any other upstream problem.
47pub async fn read_body_capped(
48    resp: reqwest::Response,
49    max_bytes: usize,
50) -> Result<Bytes, ReadBodyError> {
51    if let Some(len) = resp.content_length()
52        && len as usize > max_bytes
53    {
54        return Err(ReadBodyError::TooLarge);
55    }
56
57    let mut body =
58        Vec::with_capacity(resp.content_length().unwrap_or(0).min(max_bytes as u64) as usize);
59    let mut stream = resp.bytes_stream();
60    while let Some(chunk) = stream.next().await {
61        let chunk = chunk.map_err(ReadBodyError::Stream)?;
62        if body.len() + chunk.len() > max_bytes {
63            return Err(ReadBodyError::TooLarge);
64        }
65        body.extend_from_slice(&chunk);
66    }
67    Ok(Bytes::from(body))
68}
69
70/// Send a request to the upstream server, forwarding relevant headers.
71///
72/// For streaming calls we still apply the per-request reqwest timeout:
73/// the axum-edge `TimeoutLayer` (tower-http) cancels at response start
74/// and can't see mid-stream stalls. For non-streaming calls the tower
75/// layer handles timeout budget end-to-end, so no reqwest timeout here.
76pub async fn forward_request(
77    upstream: &UpstreamClient,
78    url: &str,
79    method: Method,
80    headers: &HeaderMap,
81    body: &Bytes,
82    is_streaming: bool,
83) -> Result<reqwest::Response, reqwest::Error> {
84    let mut req = upstream.http_client.request(method, url);
85
86    if is_streaming {
87        req = req.timeout(upstream.request_timeout);
88    }
89
90    for key in [header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT] {
91        if let Some(val) = headers.get(&key) {
92            req = req.header(key.as_str(), val.as_bytes());
93        }
94    }
95
96    if let Some(session_id) = headers.get("mcp-session-id") {
97        req = req.header("mcp-session-id", session_id.as_bytes());
98    }
99
100    if let Some(last_event) = headers.get("last-event-id") {
101        req = req.header("last-event-id", last_event.as_bytes());
102    }
103
104    if !body.is_empty() {
105        req = req.body(body.clone());
106    }
107
108    req.send().await
109}
110
111/// Build an axum Response from status, headers, and body.
112pub fn build_response(status: u16, upstream_headers: &HeaderMap, body: Body) -> Response {
113    let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_GATEWAY);
114    let mut builder = Response::builder().status(status_code);
115
116    for key in [header::CONTENT_TYPE, header::CACHE_CONTROL] {
117        if let Some(val) = upstream_headers.get(&key) {
118            builder = builder.header(key.as_str(), val);
119        }
120    }
121
122    if let Some(val) = upstream_headers.get("mcp-session-id") {
123        builder = builder.header("mcp-session-id", val);
124    }
125
126    if let Some(val) = upstream_headers.get(header::WWW_AUTHENTICATE) {
127        builder = builder.header(header::WWW_AUTHENTICATE, val);
128    }
129
130    builder.body(body).unwrap_or_else(|_| {
131        Response::builder()
132            .status(StatusCode::INTERNAL_SERVER_ERROR)
133            .body(Body::from("Failed to build response"))
134            .unwrap()
135    })
136}