use std::time::Instant;
use axum::body::{Body, Bytes};
use axum::http::{HeaderMap, Method};
use axum::response::Response;
use serde_json::Value;
use crate::protocol::{self as jsonrpc, McpMethod};
use crate::proxy::ProxyState;
use crate::proxy::forwarding::{build_response, read_body_capped};
use crate::proxy::pipeline::context::RequestContext;
use crate::proxy::pipeline::emit::{ResponseSummary, emit_request_event};
use crate::proxy::pipeline::steps::{health, rewrite, schema, session, widget};
use crate::proxy::sse::{extract_json_from_sse, wrap_as_sse};
use super::{
Stage, StageTimer, capture_session_id, emit_upstream_error, forward_or_502,
populate_client_info,
};
pub async fn forward_and_buffer(
state: &ProxyState,
ctx: &mut RequestContext,
method: &McpMethod,
headers: &HeaderMap,
body: &Bytes,
) -> Response {
let upstream_url = state.mcp_upstream.trim_end_matches('/').to_string();
let upstream_start = Instant::now();
let resp = match forward_or_502(
&state.upstream,
&upstream_url,
Method::POST,
headers,
body,
false,
)
.await
{
Ok(r) => r,
Err(e) => return emit_upstream_error(state, ctx, upstream_start, e),
};
let status = resp.status().as_u16();
let upstream_headers = resp.headers().clone();
capture_session_id(ctx, &upstream_headers);
let raw = match read_body_capped(resp, state.max_response_body).await {
Ok(b) => b,
Err(err_resp) => return err_resp,
};
let upstream_us = upstream_start.elapsed().as_micros() as u64;
let mut timer = StageTimer::new();
let (json_bytes, was_sse): (Vec<u8>, bool) = match extract_json_from_sse(&raw) {
Some(v) => (v, true),
None => (raw.to_vec(), false),
};
timer.mark(Stage::SseUnwrap);
let mut parsed: Option<Value> = serde_json::from_slice(&json_bytes).ok();
let rpc_error = parsed
.as_ref()
.and_then(|v| jsonrpc::extract_error_code(v).map(|(c, m)| (c, m.to_string())));
timer.mark(Stage::JsonParse);
let mut mutated = false;
if let Some(json) = parsed.as_mut() {
schema::spawn_ingest(state, ctx, json);
schema::mark_stale_if_listchanged(state, json);
timer.mark(Stage::Schema);
if widget::maybe_overlay(state, ctx, json).await {
mutated = true;
}
timer.mark(Stage::WidgetOverlay);
let markers_present = ctx.mcp_method_str.is_some() && rewrite::has_markers(&json_bytes);
timer.mark(Stage::MarkerScan);
if markers_present
&& let Some(method_str) = ctx.mcp_method_str.as_deref()
&& rewrite::rewrite_in_place(&state.rewrite_config, method_str, json)
{
mutated = true;
}
timer.mark(Stage::Rewrite);
}
health::track_post_response(state, method, status);
session::maybe_record_start(state, ctx, method, status).await;
populate_client_info(state, ctx).await;
timer.mark(Stage::SideEffects);
let final_body: Vec<u8> = if mutated {
let bytes = match parsed.as_ref().and_then(|v| serde_json::to_vec(v).ok()) {
Some(serialized) if was_sse => wrap_as_sse(&serialized),
Some(serialized) => serialized,
None => raw.to_vec(),
};
timer.mark(Stage::Reserialize);
bytes
} else {
raw.to_vec()
};
if parsed.is_some() {
ctx.tags.push("rewritten");
if was_sse {
ctx.tags.push("sse");
}
} else {
ctx.tags.push("passthrough");
}
let mut summary = ResponseSummary {
status,
response_size: Some(final_body.len() as u64),
upstream_us: Some(upstream_us),
error_code: None,
error_msg: None,
stage_timings: timer.finish(),
};
if let Some((code, msg)) = rpc_error {
summary = summary.with_rpc_error(code, msg);
}
emit_request_event(state, ctx, &summary);
build_response(status, &upstream_headers, Body::from(final_body))
}