mcpr_core/proxy/pipeline/middlewares/
session_touch.rs1use async_trait::async_trait;
5
6use crate::protocol::mcp::{ClientKind, ClientNotifMethod};
7use crate::protocol::session::{SessionState, SessionStore};
8use crate::proxy::pipeline::middleware::{Flow, RequestMiddleware};
9use crate::proxy::pipeline::values::{Context, Request};
10
/// Request middleware that records the MCP request method on the working
/// context and refreshes the session named by the message's session hint.
///
/// The type carries no state, so it is a cheap `Copy` unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct SessionTouchMiddleware;
12
13#[async_trait]
14impl RequestMiddleware for SessionTouchMiddleware {
15 fn name(&self) -> &'static str {
16 "session_touch"
17 }
18
19 async fn on_request(&self, req: Request, cx: &mut Context) -> Flow {
20 let Request::Mcp(ref mcp) = req else {
21 return Flow::Continue(req);
22 };
23
24 if let ClientKind::Request(m) = &mcp.kind {
25 cx.working.request_method = Some(m.clone());
26 }
27
28 if let Some(sid) = mcp.session_hint.as_ref() {
29 let store = &cx.intake.proxy.sessions;
30 store.touch(sid.as_str()).await;
31 if matches!(
32 mcp.kind,
33 ClientKind::Notification(ClientNotifMethod::Initialized)
34 ) {
35 store.update_state(sid.as_str(), SessionState::Active).await;
36 }
37 cx.working.session = store.get(sid.as_str()).await;
38 }
39
40 Flow::Continue(req)
41 }
42}
43
#[cfg(test)]
// Test names follow a `method__scenario` convention; the double underscore
// trips rustc's `non_snake_case` lint, so it is silenced for this module only.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, Method};
    use serde_json::Value;

    use crate::protocol::mcp::{ClientMethod, ToolsMethod};
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_notification, mcp_request, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::{RawRequest, Request};

    #[tokio::test]
    async fn on_request__non_mcp_passthrough() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = Request::Raw(RawRequest {
            method: Method::GET,
            path: "/health".into(),
            body: Body::empty(),
            headers: HeaderMap::new(),
        });

        let flow = SessionTouchMiddleware.on_request(req, &mut cx).await;
        assert!(matches!(flow, Flow::Continue(Request::Raw(_))));
        assert!(cx.working.session.is_none());
        assert!(cx.working.request_method.is_none());
    }

    #[tokio::test]
    async fn on_request__mcp_no_session_hint_still_stashes_method() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_request("tools/list", Value::Null, None);

        SessionTouchMiddleware.on_request(req, &mut cx).await;
        assert_eq!(
            cx.working.request_method,
            Some(ClientMethod::Tools(ToolsMethod::List))
        );
        assert!(cx.working.session.is_none());
    }

    #[tokio::test]
    async fn on_request__notification_without_session_hint_leaves_context_untouched() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy);
        let req = mcp_notification("notifications/initialized", None);

        let flow = SessionTouchMiddleware.on_request(req, &mut cx).await;
        assert!(matches!(flow, Flow::Continue(Request::Mcp(_))));
        assert!(cx.working.session.is_none());
        assert!(cx.working.request_method.is_none());
    }

    #[tokio::test]
    async fn on_request__known_session_bumps_request_count() {
        let proxy = test_proxy_state();
        proxy.sessions.create("sess-1").await;
        let mut cx = test_context(proxy.clone());
        let req = mcp_request("tools/list", Value::Null, Some("sess-1"));

        SessionTouchMiddleware.on_request(req, &mut cx).await;
        let info = proxy.sessions.get("sess-1").await.unwrap();
        assert_eq!(info.request_count, 1);
        assert_eq!(cx.working.session.as_ref().unwrap().id, "sess-1");
        // The method is stashed regardless of whether a session was found.
        assert_eq!(
            cx.working.request_method,
            Some(ClientMethod::Tools(ToolsMethod::List))
        );
    }

    #[tokio::test]
    async fn on_request__initialized_notification_flips_state_to_active() {
        let proxy = test_proxy_state();
        proxy.sessions.create("sess-2").await;
        proxy
            .sessions
            .update_state("sess-2", SessionState::Initialized)
            .await;
        let mut cx = test_context(proxy.clone());
        let req = mcp_notification("notifications/initialized", Some("sess-2"));

        SessionTouchMiddleware.on_request(req, &mut cx).await;
        let info = proxy.sessions.get("sess-2").await.unwrap();
        assert_eq!(info.state, SessionState::Active);
        // The cached snapshot is read after the state update, so it must
        // already observe the transition.
        assert_eq!(
            cx.working.session.as_ref().unwrap().state,
            SessionState::Active
        );
        assert!(cx.working.request_method.is_none());
    }

    #[tokio::test]
    async fn on_request__unknown_session_id_is_noop() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        let req = mcp_request("tools/list", Value::Null, Some("missing"));

        SessionTouchMiddleware.on_request(req, &mut cx).await;
        assert!(cx.working.session.is_none());
        assert!(proxy.sessions.get("missing").await.is_none());
    }
}