Skip to main content

mcpr_core/proxy/pipeline/middleware/
session.rs

//! Session lifecycle middleware — covers both request and response phases.
//!
//! * `SessionTouchMiddleware` (request): bumps `last_seen` on every request that carries a
//!   session id, and sets the session to Active when the client sends
//!   `notifications/initialized`.
//! * `DeleteSessionEndMiddleware` (request): on `DELETE` with an `mcp-session-id`
//!   header, emits `SessionEnd` and removes the session.
//! * `SessionStartMiddleware` (response): on an `initialize` response with an HTTP status
//!   below 400, creates the session, stores parsed client info, and emits `SessionStart`.

10use crate::event::{ProxyEvent, SessionEndEvent, SessionStartEvent};
11use crate::protocol::McpMethod;
12use crate::protocol::session::{SessionState, SessionStore};
13use async_trait::async_trait;
14use axum::http::Method;
15use axum::response::Response;
16
17use super::{RequestMiddleware, ResponseMiddleware};
18use crate::proxy::pipeline::context::{RequestContext, ResponseContext};
19use crate::proxy::pipeline::emit::normalize_platform;
20use crate::proxy::proxy_state::ProxyState;
21
22#[cfg(test)]
23#[allow(non_snake_case)]
24mod tests {
25    use std::sync::Arc;
26    use std::time::{Duration, Instant};
27    use tokio::sync::RwLock;
28
29    use super::*;
30    use crate::protocol::schema_manager::{MemorySchemaStore, SchemaManager};
31    use crate::protocol::session::{MemorySessionStore, SessionInfo};
32    use crate::proxy::forwarding::UpstreamClient;
33    use crate::proxy::{CspConfig, RewriteConfig, new_shared_health};
34
35    fn test_state() -> ProxyState {
36        ProxyState {
37            name: "t".into(),
38            mcp_upstream: "http://u".into(),
39            upstream: UpstreamClient {
40                http_client: reqwest::Client::builder().build().unwrap(),
41                semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
42                request_timeout: Duration::from_secs(1),
43            },
44            max_request_body: 1024,
45            max_response_body: 1024,
46            rewrite_config: Arc::new(RwLock::new(RewriteConfig {
47                proxy_url: "http://p".into(),
48                proxy_domain: "p".into(),
49                mcp_upstream: "http://u".into(),
50                csp: CspConfig::default(),
51            })),
52            widget_source: None,
53            sessions: MemorySessionStore::new(),
54            schema_manager: Arc::new(SchemaManager::new("t", MemorySchemaStore::new())),
55            health: new_shared_health(),
56            event_bus: crate::event::EventManager::new().start().bus,
57        }
58    }
59
60    fn ctx_with(method: Option<McpMethod>, session_id: Option<&str>) -> RequestContext {
61        RequestContext {
62            start: Instant::now(),
63            http_method: Method::POST,
64            path: "/mcp".into(),
65            request_size: 0,
66            wants_sse: false,
67            session_id: session_id.map(String::from),
68            jsonrpc: None,
69            mcp_method: method,
70            mcp_method_str: None,
71            tool: None,
72            is_batch: false,
73            client_info_from_init: None,
74            client_name: None,
75            client_version: None,
76            tags: Vec::new(),
77        }
78    }
79
80    #[tokio::test]
81    async fn session_touch__initialized_transitions_to_active() {
82        let state = test_state();
83        state.sessions.create("sess-1").await;
84        // Initialize first so the session reaches the `Initialized` state —
85        // that's the one that transitions to Active.
86        state
87            .sessions
88            .update_state("sess-1", SessionState::Initialized)
89            .await;
90
91        let mut ctx = ctx_with(Some(McpMethod::Initialized), Some("sess-1"));
92        assert!(
93            SessionTouchMiddleware
94                .on_request(&state, &mut ctx)
95                .await
96                .is_none()
97        );
98
99        let info: SessionInfo = state.sessions.get("sess-1").await.unwrap();
100        assert_eq!(info.state, SessionState::Active);
101    }
102}
103
104pub struct SessionTouchMiddleware;
105
106#[async_trait]
107impl RequestMiddleware for SessionTouchMiddleware {
108    async fn on_request(&self, state: &ProxyState, ctx: &mut RequestContext) -> Option<Response> {
109        let sid = ctx.session_id.as_deref()?;
110        state.sessions.touch(sid).await;
111        if ctx.mcp_method == Some(McpMethod::Initialized) {
112            state.sessions.update_state(sid, SessionState::Active).await;
113        }
114        None
115    }
116}
117
118pub struct DeleteSessionEndMiddleware;
119
120#[async_trait]
121impl RequestMiddleware for DeleteSessionEndMiddleware {
122    async fn on_request(&self, state: &ProxyState, ctx: &mut RequestContext) -> Option<Response> {
123        if ctx.http_method != Method::DELETE {
124            return None;
125        }
126        let sid = ctx.session_id.as_deref()?;
127        state
128            .event_bus
129            .emit(ProxyEvent::SessionEnd(SessionEndEvent {
130                session_id: sid.to_string(),
131                ts: chrono::Utc::now().timestamp_millis(),
132            }));
133        state.sessions.remove(sid).await;
134        None
135    }
136}
137
138pub struct SessionStartMiddleware;
139
140#[async_trait]
141impl ResponseMiddleware for SessionStartMiddleware {
142    async fn on_response(
143        &self,
144        state: &ProxyState,
145        req: &RequestContext,
146        resp: &mut ResponseContext,
147    ) {
148        if req.mcp_method != Some(McpMethod::Initialize) || resp.status >= 400 {
149            return;
150        }
151        let Some(sid) = req.session_id.as_deref() else {
152            return;
153        };
154
155        state.sessions.create(sid).await;
156        state
157            .sessions
158            .update_state(sid, SessionState::Initialized)
159            .await;
160
161        let (client_name, client_version, client_platform) =
162            if let Some(info) = req.client_info_from_init.clone() {
163                let platform = normalize_platform(&info.name).to_string();
164                let name = info.name.clone();
165                let version = info.version.clone();
166                state.sessions.set_client_info(sid, info).await;
167                (Some(name), version, Some(platform))
168            } else {
169                (None, None, None)
170            };
171
172        state
173            .event_bus
174            .emit(ProxyEvent::SessionStart(SessionStartEvent {
175                session_id: sid.to_string(),
176                proxy: state.name.clone(),
177                ts: chrono::Utc::now().timestamp_millis(),
178                client_name,
179                client_version,
180                client_platform,
181            }));
182    }
183}