mcpr_core/proxy/
forwarding.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 body::{Body, Bytes},
6 http::{HeaderMap, Method, StatusCode, header},
7 response::{IntoResponse, Response},
8};
9use futures_util::StreamExt;
10use tokio::sync::Semaphore;
11
12#[derive(Clone)]
14pub struct UpstreamClient {
15 pub http_client: reqwest::Client,
16 pub semaphore: Arc<Semaphore>,
17 pub request_timeout: Duration,
18}
19
20pub async fn read_body_capped(
22 resp: reqwest::Response,
23 max_bytes: usize,
24) -> Result<Bytes, Response> {
25 if let Some(len) = resp.content_length()
26 && len as usize > max_bytes
27 {
28 return Err((StatusCode::BAD_GATEWAY, "upstream response too large").into_response());
29 }
30
31 let mut body =
32 Vec::with_capacity(resp.content_length().unwrap_or(0).min(max_bytes as u64) as usize);
33 let mut stream = resp.bytes_stream();
34 while let Some(chunk) = stream.next().await {
35 let chunk = chunk.map_err(|e| {
36 (StatusCode::BAD_GATEWAY, format!("upstream read error: {e}")).into_response()
37 })?;
38 if body.len() + chunk.len() > max_bytes {
39 return Err((StatusCode::BAD_GATEWAY, "upstream response too large").into_response());
40 }
41 body.extend_from_slice(&chunk);
42 }
43 Ok(Bytes::from(body))
44}
45
46pub async fn forward_request(
49 upstream: &UpstreamClient,
50 url: &str,
51 method: Method,
52 headers: &HeaderMap,
53 body: &Bytes,
54 is_streaming: bool,
55) -> Result<reqwest::Response, reqwest::Error> {
56 let _permit = upstream
57 .semaphore
58 .acquire()
59 .await
60 .expect("upstream semaphore closed");
61
62 let mut req = upstream.http_client.request(method, url);
63
64 if !is_streaming {
65 req = req.timeout(upstream.request_timeout);
66 }
67
68 for key in [header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT] {
69 if let Some(val) = headers.get(&key) {
70 req = req.header(key.as_str(), val.as_bytes());
71 }
72 }
73
74 if let Some(session_id) = headers.get("mcp-session-id") {
75 req = req.header("mcp-session-id", session_id.as_bytes());
76 }
77
78 if let Some(last_event) = headers.get("last-event-id") {
79 req = req.header("last-event-id", last_event.as_bytes());
80 }
81
82 if !body.is_empty() {
83 req = req.body(body.clone());
84 }
85
86 req.send().await
87}
88
89pub fn build_response(status: u16, upstream_headers: &HeaderMap, body: Body) -> Response {
91 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_GATEWAY);
92 let mut builder = Response::builder().status(status_code);
93
94 for key in [header::CONTENT_TYPE, header::CACHE_CONTROL] {
95 if let Some(val) = upstream_headers.get(&key) {
96 builder = builder.header(key.as_str(), val);
97 }
98 }
99
100 if let Some(val) = upstream_headers.get("mcp-session-id") {
101 builder = builder.header("mcp-session-id", val);
102 }
103
104 if let Some(val) = upstream_headers.get(header::WWW_AUTHENTICATE) {
105 builder = builder.header(header::WWW_AUTHENTICATE, val);
106 }
107
108 builder.body(body).unwrap_or_else(|_| {
109 Response::builder()
110 .status(StatusCode::INTERNAL_SERVER_ERROR)
111 .body(Body::from("Failed to build response"))
112 .unwrap()
113 })
114}