use axum::http::{HeaderMap, Method, header};
use super::context::RequestContext;
use crate::protocol::McpMethod;
/// Routing decision produced by [`classify_request`] for a single request.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestKind {
    /// `GET /widgets/<name>.html`; carries `<name>`, i.e. the text between the
    /// `/widgets/` prefix and the `.html` suffix.
    WidgetHtml(String),
    /// `GET /widgets` or `GET /widgets/`.
    WidgetList,
    /// A `GET` matched by the widget-asset heuristic while widgets are enabled.
    WidgetAsset,
    /// JSON-RPC `POST` whose method reports `needs_response_buffering()`.
    McpPostBuffer(McpMethod),
    /// JSON-RPC `POST` whose method does not need its response buffered.
    McpPostStream(McpMethod),
    /// `GET` whose request context has `wants_sse` set.
    McpSseStream,
    /// None of the specialised routes above matched.
    Passthrough,
}
/// Classifies an incoming request so the pipeline can pick a handling strategy.
///
/// Rules are checked in this order, and the first match wins:
///
/// 1. `GET /widgets/<name>.html` → [`RequestKind::WidgetHtml`] carrying `<name>`.
/// 2. `GET /widgets` or `GET /widgets/` → [`RequestKind::WidgetList`].
/// 3. Any other `GET` that `is_widget_asset` accepts, but only when
///    `has_widgets` is set → [`RequestKind::WidgetAsset`].
/// 4. Any other `GET` whose context has `wants_sse` set → [`RequestKind::McpSseStream`].
/// 5. A `POST` whose context carries a JSON-RPC payload →
///    [`RequestKind::McpPostBuffer`] if the method's `needs_response_buffering()`
///    is true, otherwise [`RequestKind::McpPostStream`]. When `ctx.mcp_method` is
///    `None`, the method is reported as `McpMethod::Unknown` with an empty name.
/// 6. Everything else → [`RequestKind::Passthrough`].
///
/// `headers` is only consulted by the widget-asset heuristic.
pub fn classify_request(
    ctx: &RequestContext,
    headers: &HeaderMap,
    has_widgets: bool,
) -> RequestKind {
    let path = ctx.path.as_str();

    if ctx.http_method == Method::GET {
        let widget_name = path
            .strip_prefix("/widgets/")
            .and_then(|rest| rest.strip_suffix(".html"));
        // Explicit widget routes are checked before the heuristics.
        return match widget_name {
            Some(name) => RequestKind::WidgetHtml(name.to_owned()),
            None if matches!(path, "/widgets" | "/widgets/") => RequestKind::WidgetList,
            None if has_widgets && is_widget_asset(path, headers) => RequestKind::WidgetAsset,
            None if ctx.wants_sse => RequestKind::McpSseStream,
            None => RequestKind::Passthrough,
        };
    }

    let is_jsonrpc_post = ctx.http_method == Method::POST && ctx.jsonrpc.is_some();
    if !is_jsonrpc_post {
        return RequestKind::Passthrough;
    }

    let method = match &ctx.mcp_method {
        Some(known) => known.clone(),
        None => McpMethod::Unknown(String::new()),
    };
    if method.needs_response_buffering() {
        RequestKind::McpPostBuffer(method)
    } else {
        RequestKind::McpPostStream(method)
    }
}
/// File extensions (compared ASCII case-insensitively) that mark a path as a widget asset.
const WIDGET_ASSET_EXTENSIONS: &[&str] = &[
    "js", "mjs", "css", "html", "svg", "png", "jpg", "jpeg", "gif", "ico", "woff", "woff2",
    "ttf", "eot", "map", "webp",
];

/// Fragments of an `Accept` header value that indicate an asset request.
///
/// They are matched as substrings of the lowercased header value, so `image/`
/// and `font/` cover whole type families.
const WIDGET_ASSET_ACCEPT_MARKERS: &[&str] = &[
    "text/html",
    "text/css",
    "image/",
    "font/",
    "application/javascript",
];

/// Heuristically decides whether a request is fetching a static widget asset.
///
/// Returns `true` when either:
/// - the final path segment ends in one of [`WIDGET_ASSET_EXTENSIONS`], or
/// - any `Accept` header value contains one of [`WIDGET_ASSET_ACCEPT_MARKERS`].
///
/// Both checks ignore ASCII case. Media-type tokens are case-insensitive
/// (RFC 9110), and `/LOGO.PNG` is as much an asset as `/logo.png`. Every
/// `Accept` field line is inspected, not just the first one. Values that are
/// not visible ASCII are skipped.
fn is_widget_asset(path: &str, headers: &HeaderMap) -> bool {
    // The extension, if any, lives after the last '.' of the final path segment.
    let last_segment = path.rsplit('/').next().unwrap_or(path);
    let has_asset_extension = last_segment.rsplit_once('.').is_some_and(|(_, ext)| {
        WIDGET_ASSET_EXTENSIONS
            .iter()
            .any(|&known| known.eq_ignore_ascii_case(ext))
    });
    if has_asset_extension {
        return true;
    }

    headers
        .get_all(header::ACCEPT)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .any(|accept| {
            let accept = accept.to_ascii_lowercase();
            WIDGET_ASSET_ACCEPT_MARKERS
                .iter()
                .any(|&marker| accept.contains(marker))
        })
}
#[cfg(test)]
// Test names use `unit__scenario`; the double underscore trips `non_snake_case`.
#[allow(non_snake_case)]
mod tests {
    use std::time::Instant;

    use axum::body::Bytes;

    use super::*;
    use crate::proxy::pipeline::parser::build_request_context;

    /// Builds a `RequestContext` through the real parser so tests exercise the
    /// same fields (`jsonrpc`, `mcp_method`, `wants_sse`, ...) the pipeline sees.
    fn mk_ctx(method: Method, path: &str, headers: &HeaderMap, body: &[u8]) -> RequestContext {
        build_request_context(
            method,
            path,
            headers,
            &Bytes::copy_from_slice(body),
            Instant::now(),
        )
    }

    #[test]
    fn classify__tools_call_needs_buffer() {
        let body = br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"echo"}}"#;
        let ctx = mk_ctx(Method::POST, "/mcp", &HeaderMap::new(), body);
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::McpPostBuffer(McpMethod::ToolsCall)
        );
    }

    #[test]
    fn classify__initialize_buffers_for_schema_capture() {
        let body = br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#;
        let ctx = mk_ctx(Method::POST, "/mcp", &HeaderMap::new(), body);
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::McpPostBuffer(McpMethod::Initialize)
        );
    }

    #[test]
    fn classify__ping_streams() {
        let body = br#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
        let ctx = mk_ctx(Method::POST, "/mcp", &HeaderMap::new(), body);
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::McpPostStream(McpMethod::Ping)
        );
    }

    #[test]
    fn classify__all_buffered_methods() {
        for (method_str, expected) in [
            ("initialize", McpMethod::Initialize),
            ("tools/list", McpMethod::ToolsList),
            ("tools/call", McpMethod::ToolsCall),
            ("resources/list", McpMethod::ResourcesList),
            (
                "resources/templates/list",
                McpMethod::ResourcesTemplatesList,
            ),
            ("resources/read", McpMethod::ResourcesRead),
            ("prompts/list", McpMethod::PromptsList),
        ] {
            let body =
                format!(r#"{{"jsonrpc":"2.0","id":1,"method":"{method_str}"}}"#).into_bytes();
            let ctx = mk_ctx(Method::POST, "/mcp", &HeaderMap::new(), &body);
            assert_eq!(
                classify_request(&ctx, &HeaderMap::new(), false),
                RequestKind::McpPostBuffer(expected),
                "method {method_str} should route to McpPostBuffer"
            );
        }
    }

    #[test]
    fn classify__non_mcp_post_is_passthrough() {
        let body = br#"{"client_name":"My App"}"#;
        let ctx = mk_ctx(Method::POST, "/register", &HeaderMap::new(), body);
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::Passthrough
        );
        let ctx = mk_ctx(
            Method::POST,
            "/token",
            &HeaderMap::new(),
            b"grant_type=x&client_id=y",
        );
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::Passthrough
        );
    }

    #[test]
    fn classify__sse_stream() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT, "text/event-stream".parse().unwrap());
        let ctx = mk_ctx(Method::GET, "/mcp", &headers, b"");
        assert_eq!(
            classify_request(&ctx, &headers, false),
            RequestKind::McpSseStream
        );
    }

    #[test]
    fn classify__get_html_is_passthrough() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT, "text/html".parse().unwrap());
        let ctx = mk_ctx(Method::GET, "/mcp", &headers, b"");
        assert_eq!(
            classify_request(&ctx, &headers, false),
            RequestKind::Passthrough
        );
    }

    #[test]
    fn classify__sse_accept_wins_over_widgets() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT, "text/event-stream".parse().unwrap());
        let ctx = mk_ctx(Method::GET, "/mcp", &headers, b"");
        assert_eq!(
            classify_request(&ctx, &headers, true),
            RequestKind::McpSseStream
        );
    }

    #[test]
    fn classify__widget_html_matches_prefix() {
        let ctx = mk_ctx(Method::GET, "/widgets/foo.html", &HeaderMap::new(), b"");
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::WidgetHtml("foo".to_string())
        );
    }

    #[test]
    fn classify__widget_html_wins_over_asset_heuristic() {
        // `.html` is also a widget-asset extension; the explicit route must still win.
        let ctx = mk_ctx(Method::GET, "/widgets/foo.html", &HeaderMap::new(), b"");
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), true),
            RequestKind::WidgetHtml("foo".to_string())
        );
    }

    #[test]
    fn classify__widget_list_at_widgets_root() {
        let ctx = mk_ctx(Method::GET, "/widgets", &HeaderMap::new(), b"");
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::WidgetList
        );
    }

    #[test]
    fn classify__widget_list_accepts_trailing_slash() {
        let ctx = mk_ctx(Method::GET, "/widgets/", &HeaderMap::new(), b"");
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::WidgetList
        );
    }

    #[test]
    fn classify__widget_asset_js_with_widgets() {
        let ctx = mk_ctx(Method::GET, "/assets/main.js", &HeaderMap::new(), b"");
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), true),
            RequestKind::WidgetAsset
        );
    }

    #[test]
    fn classify__widget_asset_image_accept_with_widgets() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT, "image/png".parse().unwrap());
        let ctx = mk_ctx(Method::GET, "/logo", &headers, b"");
        assert_eq!(
            classify_request(&ctx, &headers, true),
            RequestKind::WidgetAsset
        );
    }

    #[test]
    fn classify__widget_asset_gated_by_has_widgets() {
        let ctx = mk_ctx(Method::GET, "/assets/main.js", &HeaderMap::new(), b"");
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), false),
            RequestKind::Passthrough
        );
    }

    #[test]
    fn classify__well_known_not_widget_asset() {
        let ctx = mk_ctx(
            Method::GET,
            "/.well-known/oauth-authorization-server",
            &HeaderMap::new(),
            b"",
        );
        assert_eq!(
            classify_request(&ctx, &HeaderMap::new(), true),
            RequestKind::Passthrough
        );
    }
}