mcpr_core/proxy/
forwarding.rs1use std::time::Duration;
2
3use axum::{
4 body::{Body, Bytes},
5 http::{HeaderMap, Method, StatusCode, header},
6 response::Response,
7};
8use futures_util::StreamExt;
9
10#[derive(Clone)]
17pub struct UpstreamClient {
18 pub http_client: reqwest::Client,
19 pub request_timeout: Duration,
20}
21
22#[derive(Debug)]
27pub enum ReadBodyError {
28 TooLarge,
30 Stream(reqwest::Error),
32}
33
34impl std::fmt::Display for ReadBodyError {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 ReadBodyError::TooLarge => write!(f, "upstream response too large"),
38 ReadBodyError::Stream(e) => write!(f, "upstream read error: {e}"),
39 }
40 }
41}
42
43pub async fn read_body_capped(
48 resp: reqwest::Response,
49 max_bytes: usize,
50) -> Result<Bytes, ReadBodyError> {
51 if let Some(len) = resp.content_length()
52 && len as usize > max_bytes
53 {
54 return Err(ReadBodyError::TooLarge);
55 }
56
57 let mut body =
58 Vec::with_capacity(resp.content_length().unwrap_or(0).min(max_bytes as u64) as usize);
59 let mut stream = resp.bytes_stream();
60 while let Some(chunk) = stream.next().await {
61 let chunk = chunk.map_err(ReadBodyError::Stream)?;
62 if body.len() + chunk.len() > max_bytes {
63 return Err(ReadBodyError::TooLarge);
64 }
65 body.extend_from_slice(&chunk);
66 }
67 Ok(Bytes::from(body))
68}
69
70pub async fn forward_request(
77 upstream: &UpstreamClient,
78 url: &str,
79 method: Method,
80 headers: &HeaderMap,
81 body: &Bytes,
82 is_streaming: bool,
83) -> Result<reqwest::Response, reqwest::Error> {
84 let mut req = upstream.http_client.request(method, url);
85
86 if is_streaming {
87 req = req.timeout(upstream.request_timeout);
88 }
89
90 for key in [header::AUTHORIZATION, header::CONTENT_TYPE, header::ACCEPT] {
91 if let Some(val) = headers.get(&key) {
92 req = req.header(key.as_str(), val.as_bytes());
93 }
94 }
95
96 if let Some(session_id) = headers.get("mcp-session-id") {
97 req = req.header("mcp-session-id", session_id.as_bytes());
98 }
99
100 if let Some(last_event) = headers.get("last-event-id") {
101 req = req.header("last-event-id", last_event.as_bytes());
102 }
103
104 if !body.is_empty() {
105 req = req.body(body.clone());
106 }
107
108 req.send().await
109}
110
111pub fn build_response(status: u16, upstream_headers: &HeaderMap, body: Body) -> Response {
113 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_GATEWAY);
114 let mut builder = Response::builder().status(status_code);
115
116 for key in [header::CONTENT_TYPE, header::CACHE_CONTROL] {
117 if let Some(val) = upstream_headers.get(&key) {
118 builder = builder.header(key.as_str(), val);
119 }
120 }
121
122 if let Some(val) = upstream_headers.get("mcp-session-id") {
123 builder = builder.header("mcp-session-id", val);
124 }
125
126 if let Some(val) = upstream_headers.get(header::WWW_AUTHENTICATE) {
127 builder = builder.header(header::WWW_AUTHENTICATE, val);
128 }
129
130 builder.body(body).unwrap_or_else(|_| {
131 Response::builder()
132 .status(StatusCode::INTERNAL_SERVER_ERROR)
133 .body(Body::from("Failed to build response"))
134 .unwrap()
135 })
136}