Skip to main content

mcpr_core/proxy/pipeline/middlewares/
session_touch.rs

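//! Request-side middleware that records session activity and stashes the
//! originating `ClientMethod` on `Working` for response middlewares.
//!
//! When an MCP request carries a session hint, the session is touched in the
//! store, promoted to `SessionState::Active` on `notifications/initialized`,
//! and then re-read from the store into `Working`.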
3
4use async_trait::async_trait;
5
6use crate::protocol::mcp::{ClientKind, ClientNotifMethod};
7use crate::protocol::session::{SessionState, SessionStore};
8use crate::proxy::pipeline::middleware::{Flow, RequestMiddleware};
9use crate::proxy::pipeline::values::{Context, Request};
10
/// Request middleware that touches the session named by an MCP request's
/// session hint and stashes the request's `ClientMethod` on `Working`.
///
/// The type carries no fields; the session store it updates is reached through
/// the `Context` passed to `on_request`, so instances are free to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct SessionTouchMiddleware;
12
13#[async_trait]
14impl RequestMiddleware for SessionTouchMiddleware {
15    fn name(&self) -> &'static str {
16        "session_touch"
17    }
18
19    async fn on_request(&self, req: Request, cx: &mut Context) -> Flow {
20        let Request::Mcp(ref mcp) = req else {
21            return Flow::Continue(req);
22        };
23
24        if let ClientKind::Request(m) = &mcp.kind {
25            cx.working.request_method = Some(m.clone());
26        }
27
28        if let Some(sid) = mcp.session_hint.as_ref() {
29            let store = &cx.intake.proxy.sessions;
30            store.touch(sid.as_str()).await;
31            if matches!(
32                mcp.kind,
33                ClientKind::Notification(ClientNotifMethod::Initialized)
34            ) {
35                store.update_state(sid.as_str(), SessionState::Active).await;
36            }
37            cx.working.session = store.get(sid.as_str()).await;
38        }
39
40        Flow::Continue(req)
41    }
42}
43
#[cfg(test)]
// Test names follow a `method__scenario` convention; rustc's `non_snake_case`
// lint rejects consecutive underscores, so it is silenced for this module.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, Method};
    use serde_json::Value;

    use crate::protocol::mcp::{ClientMethod, ToolsMethod};
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_notification, mcp_request, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::{RawRequest, Request};

    /// A plain HTTP request is handed back as-is and leaves `Working` empty.
    #[tokio::test]
    async fn on_request__non_mcp_passthrough() {
        let mut cx = test_context(test_proxy_state());
        let health_probe = Request::Raw(RawRequest {
            method: Method::GET,
            path: "/health".into(),
            body: Body::empty(),
            headers: HeaderMap::new(),
        });

        let outcome = SessionTouchMiddleware
            .on_request(health_probe, &mut cx)
            .await;

        assert!(matches!(outcome, Flow::Continue(Request::Raw(_))));
        assert!(cx.working.session.is_none());
        assert!(cx.working.request_method.is_none());
    }

    /// Without a session hint the method is still recorded, but no session is loaded.
    #[tokio::test]
    async fn on_request__mcp_no_session_hint_still_stashes_method() {
        let mut cx = test_context(test_proxy_state());
        let list_tools = mcp_request("tools/list", Value::Null, None);

        SessionTouchMiddleware.on_request(list_tools, &mut cx).await;

        let expected = Some(ClientMethod::Tools(ToolsMethod::List));
        assert_eq!(cx.working.request_method, expected);
        assert!(cx.working.session.is_none());
    }

    /// Touching a freshly created session brings its `request_count` to 1 and
    /// snapshots it onto `Working`.
    #[tokio::test]
    async fn on_request__known_session_bumps_request_count() {
        let proxy = test_proxy_state();
        proxy.sessions.create("sess-1").await;
        let mut cx = test_context(proxy.clone());
        let list_tools = mcp_request("tools/list", Value::Null, Some("sess-1"));

        SessionTouchMiddleware.on_request(list_tools, &mut cx).await;

        let stored = proxy.sessions.get("sess-1").await.unwrap();
        assert_eq!(stored.request_count, 1);
        let snapshot = cx.working.session.as_ref().unwrap();
        assert_eq!(snapshot.id, "sess-1");
    }

    /// `notifications/initialized` promotes the session to `Active` and, being a
    /// notification, records no request method.
    #[tokio::test]
    async fn on_request__initialized_notification_flips_state_to_active() {
        let proxy = test_proxy_state();
        let sessions = &proxy.sessions;
        sessions.create("sess-2").await;
        sessions
            .update_state("sess-2", SessionState::Initialized)
            .await;
        let mut cx = test_context(proxy.clone());
        let initialized = mcp_notification("notifications/initialized", Some("sess-2"));

        SessionTouchMiddleware.on_request(initialized, &mut cx).await;

        let stored = sessions.get("sess-2").await.unwrap();
        assert_eq!(stored.state, SessionState::Active);
        assert!(cx.working.request_method.is_none());
    }

    /// An id the store has never seen is not created and yields no snapshot.
    #[tokio::test]
    async fn on_request__unknown_session_id_is_noop() {
        let proxy = test_proxy_state();
        let mut cx = test_context(proxy.clone());
        let orphan = mcp_request("tools/list", Value::Null, Some("missing"));

        SessionTouchMiddleware.on_request(orphan, &mut cx).await;

        assert!(cx.working.session.is_none());
        assert!(proxy.sessions.get("missing").await.is_none());
    }
}