Skip to main content

mcpr_core/proxy/
forwarding.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    body::{Body, Bytes},
6    http::{HeaderMap, Method, StatusCode, header},
7    response::{IntoResponse, Response},
8};
9use futures_util::StreamExt;
10use tokio::sync::Semaphore;
11
12/// Shared upstream connection config for forwarding requests.
13#[derive(Clone)]
14pub struct UpstreamClient {
15    pub http_client: reqwest::Client,
16    pub semaphore: Arc<Semaphore>,
17    pub request_timeout: Duration,
18}
19
20/// Read a response body with a size cap. Returns 502 if the upstream response exceeds `max_bytes`.
21pub async fn read_body_capped(
22    resp: reqwest::Response,
23    max_bytes: usize,
24) -> Result<Bytes, Response> {
25    if let Some(len) = resp.content_length()
26        && len as usize > max_bytes
27    {
28        return Err((StatusCode::BAD_GATEWAY, "upstream response too large").into_response());
29    }
30
31    let mut body =
32        Vec::with_capacity(resp.content_length().unwrap_or(0).min(max_bytes as u64) as usize);
33    let mut stream = resp.bytes_stream();
34    while let Some(chunk) = stream.next().await {
35        let chunk = chunk.map_err(|e| {
36            (StatusCode::BAD_GATEWAY, format!("upstream read error: {e}")).into_response()
37        })?;
38        if body.len() + chunk.len() > max_bytes {
39            return Err((StatusCode::BAD_GATEWAY, "upstream response too large").into_response());
40        }
41        body.extend_from_slice(&chunk);
42    }
43    Ok(Bytes::from(body))
44}
45
46/// Send a request to the upstream server, forwarding relevant headers.
47/// When `is_streaming` is false, applies the configured request timeout.
48pub async fn forward_request(
49    upstream: &UpstreamClient,
50    url: &str,
51    method: Method,
52    headers: &HeaderMap,
53    body: &Bytes,
54    is_streaming: bool,
55) -> Result<reqwest::Response, reqwest::Error> {
56    let _permit = upstream
57        .semaphore
58        .acquire()
59        .await
60        .expect("upstream semaphore closed");
61
62    let mut req = upstream.http_client.request(method, url);
63
64    if !is_streaming {
65        req = req.timeout(upstream.request_timeout);
66    }
67
68    for key in [header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT] {
69        if let Some(val) = headers.get(&key) {
70            req = req.header(key.as_str(), val.as_bytes());
71        }
72    }
73
74    if let Some(session_id) = headers.get("mcp-session-id") {
75        req = req.header("mcp-session-id", session_id.as_bytes());
76    }
77
78    if let Some(last_event) = headers.get("last-event-id") {
79        req = req.header("last-event-id", last_event.as_bytes());
80    }
81
82    if !body.is_empty() {
83        req = req.body(body.clone());
84    }
85
86    req.send().await
87}
88
89/// Build an axum Response from status, headers, and body.
90pub fn build_response(status: u16, upstream_headers: &HeaderMap, body: Body) -> Response {
91    let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_GATEWAY);
92    let mut builder = Response::builder().status(status_code);
93
94    for key in [header::CONTENT_TYPE, header::CACHE_CONTROL] {
95        if let Some(val) = upstream_headers.get(&key) {
96            builder = builder.header(key.as_str(), val);
97        }
98    }
99
100    if let Some(val) = upstream_headers.get("mcp-session-id") {
101        builder = builder.header("mcp-session-id", val);
102    }
103
104    if let Some(val) = upstream_headers.get(header::WWW_AUTHENTICATE) {
105        builder = builder.header(header::WWW_AUTHENTICATE, val);
106    }
107
108    builder.body(body).unwrap_or_else(|_| {
109        Response::builder()
110            .status(StatusCode::INTERNAL_SERVER_ERROR)
111            .body(Body::from("Failed to build response"))
112            .unwrap()
113    })
114}