use anyhow::Context;
use log::{debug, error, info, warn};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use super::{UpstreamHandle, CHANNEL_DEPTH};
/// Parses a `--upstream-header` argument of the form `Name: value`.
///
/// The split happens at the first `:`, so the value may itself contain colons
/// (e.g. URLs). Surrounding whitespace is trimmed from both the name and the value.
///
/// # Errors
/// Returns an error if `raw` contains no `:`, or if the name is empty or not a
/// valid HTTP header token. Rejecting a bad name here gives a clear CLI error
/// instead of a request-build failure on every upstream call later.
pub fn parse_header(raw: &str) -> anyhow::Result<(String, String)> {
    let (name, value) = raw
        .split_once(':')
        .with_context(|| format!("--upstream-header '{}' is not 'Name: value'", raw))?;
    let (name, value) = (name.trim(), value.trim());
    anyhow::ensure!(
        is_header_token(name),
        "--upstream-header '{}' has an empty or invalid header name",
        raw
    );
    Ok((name.to_string(), value.to_string()))
}

/// True if `name` is a non-empty RFC 9110 `token`, the grammar for header names.
fn is_header_token(name: &str) -> bool {
    // `tchar` = ALPHA / DIGIT / one of these punctuation characters.
    const EXTRA_TCHARS: &[u8] = b"!#$%&'*+-.^_`|~";
    !name.is_empty()
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || EXTRA_TCHARS.contains(&b))
}
pub fn spawn_http_upstream(
url: &str,
extra_headers: Vec<(String, String)>,
) -> anyhow::Result<UpstreamHandle> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.context("building HTTP client for upstream")?;
let (to_tx, mut to_rx) = mpsc::channel::<String>(CHANNEL_DEPTH);
let (from_tx, from_rx) = mpsc::channel::<String>(CHANNEL_DEPTH);
let session: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
let url_owned = url.to_string();
{
let client = client.clone();
let url = url_owned.clone();
let from_tx = from_tx.clone();
let session = session.clone();
tokio::spawn(async move {
while let Some(frame) = to_rx.recv().await {
let is_initialize = frame.contains("\"initialize\"");
let mut req = client
.post(&url)
.header("content-type", "application/json")
.header("accept", "application/json, text/event-stream");
for (k, v) in &extra_headers {
req = req.header(k, v);
}
if let Some(sid) = session.lock().await.clone() {
req = req.header("mcp-session-id", sid);
}
let resp = match req.body(frame.clone()).send().await {
Ok(r) => r,
Err(e) => {
error!("[shield] upstream POST failed: {}", e);
if let Some(err) = transport_error_for(&frame, &e.to_string()) {
let _ = from_tx.send(err).await;
}
continue;
}
};
if is_initialize {
if let Some(sid) = resp
.headers()
.get("mcp-session-id")
.and_then(|v| v.to_str().ok())
{
info!("[shield] upstream assigned MCP session id");
*session.lock().await = Some(sid.to_string());
}
}
let status = resp.status();
if status == reqwest::StatusCode::ACCEPTED {
continue;
}
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
warn!(
"[shield] upstream POST returned {}: {}",
status,
body.chars().take(300).collect::<String>()
);
if let Some(err) =
transport_error_for(&frame, &format!("upstream returned {}", status))
{
let _ = from_tx.send(err).await;
}
continue;
}
let ct = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
if ct.starts_with("text/event-stream") {
if let Err(e) = pump_sse(resp, &from_tx).await {
warn!("[shield] upstream SSE stream ended with error: {}", e);
}
} else {
match resp.text().await {
Ok(body) if !body.trim().is_empty() => {
if from_tx.send(body.trim().to_string()).await.is_err() {
break;
}
}
Ok(_) => {}
Err(e) => error!("[shield] upstream body read error: {}", e),
}
}
}
debug!("[shield] upstream POST pump finished");
});
}
{
let url = url_owned.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(1500)).await;
let mut req = client
.get(&url)
.header("accept", "text/event-stream");
if let Some(sid) = session.lock().await.clone() {
req = req.header("mcp-session-id", sid);
}
match req.send().await {
Ok(resp)
if resp.status().is_success()
&& resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|ct| ct.starts_with("text/event-stream"))
.unwrap_or(false) =>
{
info!("[shield] upstream GET SSE stream open (server-initiated messages)");
if let Err(e) = pump_sse(resp, &from_tx).await {
debug!("[shield] upstream GET SSE stream closed: {}", e);
}
}
Ok(resp) => {
debug!(
"[shield] upstream has no GET SSE stream (status {}) -- skipping",
resp.status()
);
}
Err(e) => debug!("[shield] upstream GET SSE probe failed: {}", e),
}
});
}
Ok(UpstreamHandle {
tx: to_tx,
rx: from_rx,
label: url_owned,
child: None,
})
}
/// Forwards the `data` payload of every complete event in an SSE response
/// body to `from_tx`.
///
/// The `data:` lines of one event are joined with `\n`; one optional space
/// after the colon is stripped. Other fields and `:` comment lines are ignored.
/// Events with no data, or only whitespace data, are skipped. Bytes after the
/// last event boundary when the body ends are discarded.
///
/// Returns `Ok(())` when the body ends or the receiver of `from_tx` has been dropped.
///
/// # Errors
/// Returns an error if reading a chunk of the body fails.
async fn pump_sse(
    resp: reqwest::Response,
    from_tx: &mpsc::Sender<String>,
) -> anyhow::Result<()> {
    use futures_util::StreamExt;
    let mut stream = resp.bytes_stream();
    let mut buf: Vec<u8> = Vec::new();
    // Where the next boundary search starts. Earlier bytes were already searched
    // without success. A boundary spans at most 4 bytes, so only the last 3
    // searched bytes can start one that new data completes. This keeps a large
    // event arriving in many chunks linear, instead of rescanning the whole
    // buffer on every chunk.
    let mut scan_from = 0;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.context("SSE chunk")?;
        buf.extend_from_slice(&chunk);
        loop {
            let Some(found) = find_event_boundary(&buf[scan_from..]) else {
                scan_from = buf.len().saturating_sub(3);
                break;
            };
            let (start, end) = (scan_from + found.start, scan_from + found.end);
            scan_from = 0;
            let event = String::from_utf8_lossy(&buf[..start]).into_owned();
            buf.drain(..end);
            // SSE lines may end in `\r\n`, `\n` or a bare `\r`. Splitting on both
            // characters covers all three; the empty pieces between `\r` and `\n`
            // carry no field and are dropped by the `data:` filter.
            let data_lines: Vec<&str> = event
                .split(|c: char| c == '\r' || c == '\n')
                .filter_map(|line| line.strip_prefix("data:"))
                .map(|rest| rest.strip_prefix(' ').unwrap_or(rest))
                .collect();
            if data_lines.is_empty() {
                continue;
            }
            let payload = data_lines.join("\n");
            if payload.trim().is_empty() {
                continue;
            }
            if from_tx.send(payload).await.is_err() {
                return Ok(());
            }
        }
    }
    Ok(())
}
/// Location of the first blank line (the SSE event terminator) in a buffer.
///
/// `buf[..start]` is the text of the event, without its last line terminator.
/// `buf[..end]` is everything to consume, including the blank line.
struct EventBoundary {
    start: usize,
    end: usize,
}

/// Finds the first SSE event boundary in `buf`: a line terminator followed
/// immediately by a second one (an empty line).
///
/// SSE lines may end in `\r\n`, `\n` or a bare `\r`, and the two terminators
/// need not match (e.g. `\n\r\n`). A `\r` in the last byte is treated as
/// undecided, because a `\n` may still arrive in the next chunk. A boundary that
/// depends on such a byte is not reported yet.
///
/// Returns `None` if `buf` holds no complete boundary yet.
fn find_event_boundary(buf: &[u8]) -> Option<EventBoundary> {
    // Length of the line terminator at `i`, or `None` if there is no
    // (decided) terminator there.
    let terminator_len = |i: usize| -> Option<usize> {
        match buf.get(i).copied()? {
            b'\n' => Some(1),
            b'\r' => match buf.get(i + 1).copied()? {
                b'\n' => Some(2),
                _ => Some(1),
            },
            _ => None,
        }
    };
    (0..buf.len()).find_map(|start| {
        let first = terminator_len(start)?;
        let second = terminator_len(start + first)?;
        Some(EventBoundary {
            start,
            end: start + first + second,
        })
    })
}
/// Builds a synthetic JSON-RPC error reply for `frame` after the upstream
/// transport failed, so the client is not left waiting for an answer.
///
/// Only requests get a reply, i.e. messages with a `method` and a non-null
/// `id`. Notifications expect no reply. A frame that is itself a response
/// (e.g. the client answering a server-initiated request) must not be
/// answered: an error carrying its `id` could be matched to one of the
/// client's own pending requests. A JSON array is treated as a JSON-RPC batch
/// and yields an array with one error per request in it.
///
/// Returns `None` when `frame` is not valid JSON or contains no requests.
fn transport_error_for(frame: &str, detail: &str) -> Option<String> {
    let error_for = |msg: &serde_json::Value| -> Option<serde_json::Value> {
        // Responses and other method-less messages never get a reply.
        msg.get("method")?;
        let id = msg.get("id").filter(|id| !id.is_null())?;
        Some(serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "error": {
                "code": -32000,
                "message": "shield_upstream_transport_error",
                "data": { "detail": detail }
            }
        }))
    };
    match serde_json::from_str::<serde_json::Value>(frame).ok()? {
        serde_json::Value::Array(batch) => {
            let errors: Vec<serde_json::Value> = batch.iter().filter_map(error_for).collect();
            (!errors.is_empty()).then(|| serde_json::Value::Array(errors).to_string())
        }
        single => error_for(&single).map(|err| err.to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_header_ok() {
        let parsed = parse_header("Authorization: Bearer abc").unwrap();
        assert_eq!(
            parsed,
            ("Authorization".to_string(), "Bearer abc".to_string())
        );
    }

    #[test]
    fn parse_header_rejects_missing_colon() {
        let outcome = parse_header("not-a-header");
        assert!(outcome.is_err(), "input without ':' must be rejected");
    }

    #[test]
    fn event_boundary_lf() {
        let input = b"data: {\"a\":1}\n\nrest";
        let boundary = find_event_boundary(input).expect("an LF blank line is present");
        let event = &input[..boundary.start];
        let separator_len = boundary.end - boundary.start;
        assert_eq!(event, b"data: {\"a\":1}");
        assert_eq!(separator_len, 2);
    }

    #[test]
    fn event_boundary_crlf() {
        let input = b"data: x\r\n\r\n";
        let boundary = find_event_boundary(input).expect("a CRLF blank line is present");
        let event = &input[..boundary.start];
        let separator_len = boundary.end - boundary.start;
        assert_eq!(event, b"data: x");
        assert_eq!(separator_len, 4);
    }

    #[test]
    fn transport_error_only_for_requests() {
        let cases = [
            (r#"{"jsonrpc":"2.0","id":1,"method":"x"}"#, true),
            (r#"{"jsonrpc":"2.0","method":"notify"}"#, false),
            ("not json", false),
        ];
        for &(frame, expect_reply) in &cases {
            assert_eq!(
                transport_error_for(frame, "boom").is_some(),
                expect_reply,
                "frame: {}",
                frame
            );
        }
    }
}