mcpr_core/proxy/pipeline/middleware/
session.rs1use crate::event::{ProxyEvent, SessionEndEvent, SessionStartEvent};
11use crate::protocol::McpMethod;
12use crate::protocol::session::{SessionState, SessionStore};
13use async_trait::async_trait;
14use axum::http::Method;
15use axum::response::Response;
16
17use super::{RequestMiddleware, ResponseMiddleware};
18use crate::proxy::pipeline::context::{RequestContext, ResponseContext};
19use crate::proxy::pipeline::emit::normalize_platform;
20use crate::proxy::proxy_state::ProxyState;
21
#[cfg(test)]
// Test names follow a `subject__behaviour` double-underscore convention,
// which trips the non_snake_case lint.
#[allow(non_snake_case)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use tokio::sync::RwLock;

    use super::*;
    use crate::protocol::schema_manager::{MemorySchemaStore, SchemaManager};
    use crate::protocol::session::{MemorySessionStore, SessionInfo};
    use crate::proxy::forwarding::UpstreamClient;
    use crate::proxy::{CspConfig, RewriteConfig, new_shared_health};

    /// Builds a minimal `ProxyState` with in-memory session and schema stores
    /// and placeholder upstream/proxy URLs.
    fn test_state() -> ProxyState {
        ProxyState {
            name: "t".into(),
            mcp_upstream: "http://u".into(),
            upstream: UpstreamClient {
                http_client: reqwest::Client::builder().build().unwrap(),
                semaphore: Arc::new(tokio::sync::Semaphore::new(1)),
                request_timeout: Duration::from_secs(1),
            },
            max_request_body: 1024,
            max_response_body: 1024,
            rewrite_config: Arc::new(RwLock::new(RewriteConfig {
                proxy_url: "http://p".into(),
                proxy_domain: "p".into(),
                mcp_upstream: "http://u".into(),
                csp: CspConfig::default(),
            })),
            widget_source: None,
            sessions: MemorySessionStore::new(),
            schema_manager: Arc::new(SchemaManager::new("t", MemorySchemaStore::new())),
            health: new_shared_health(),
            event_bus: crate::event::EventManager::new().start().bus,
        }
    }

    /// Builds a `POST /mcp` request context carrying only the given MCP method
    /// and session ID; every other field is empty/default.
    fn ctx_with(method: Option<McpMethod>, session_id: Option<&str>) -> RequestContext {
        RequestContext {
            start: Instant::now(),
            http_method: Method::POST,
            path: "/mcp".into(),
            request_size: 0,
            wants_sse: false,
            session_id: session_id.map(String::from),
            jsonrpc: None,
            mcp_method: method,
            mcp_method_str: None,
            tool: None,
            is_batch: false,
            client_info_from_init: None,
            client_name: None,
            client_version: None,
            tags: Vec::new(),
        }
    }

    /// Returns a fresh state whose store holds `id` in the `Initialized` state.
    async fn state_with_initialized_session(id: &str) -> ProxyState {
        let state = test_state();
        state.sessions.create(id).await;
        state
            .sessions
            .update_state(id, SessionState::Initialized)
            .await;
        state
    }

    #[tokio::test]
    async fn session_touch__initialized_transitions_to_active() {
        let state = state_with_initialized_session("sess-1").await;

        let mut ctx = ctx_with(Some(McpMethod::Initialized), Some("sess-1"));
        assert!(
            SessionTouchMiddleware
                .on_request(&state, &mut ctx)
                .await
                .is_none()
        );

        let info: SessionInfo = state.sessions.get("sess-1").await.unwrap();
        assert_eq!(info.state, SessionState::Active);
    }

    #[tokio::test]
    async fn session_touch__other_method_keeps_state() {
        let state = state_with_initialized_session("sess-1").await;

        // Only `Initialized` may promote the session; any other method leaves it alone.
        let mut ctx = ctx_with(Some(McpMethod::Initialize), Some("sess-1"));
        assert!(
            SessionTouchMiddleware
                .on_request(&state, &mut ctx)
                .await
                .is_none()
        );

        let info: SessionInfo = state.sessions.get("sess-1").await.unwrap();
        assert_eq!(info.state, SessionState::Initialized);
    }

    #[tokio::test]
    async fn session_touch__without_session_id_is_noop() {
        let state = test_state();

        let mut ctx = ctx_with(Some(McpMethod::Initialized), None);
        assert!(
            SessionTouchMiddleware
                .on_request(&state, &mut ctx)
                .await
                .is_none()
        );
    }

    #[tokio::test]
    async fn delete_session_end__delete_removes_session() {
        let state = state_with_initialized_session("sess-1").await;

        let mut ctx = ctx_with(None, Some("sess-1"));
        ctx.http_method = Method::DELETE;
        assert!(
            DeleteSessionEndMiddleware
                .on_request(&state, &mut ctx)
                .await
                .is_none()
        );

        assert!(state.sessions.get("sess-1").await.is_none());
    }
}
103
104pub struct SessionTouchMiddleware;
105
106#[async_trait]
107impl RequestMiddleware for SessionTouchMiddleware {
108 async fn on_request(&self, state: &ProxyState, ctx: &mut RequestContext) -> Option<Response> {
109 let sid = ctx.session_id.as_deref()?;
110 state.sessions.touch(sid).await;
111 if ctx.mcp_method == Some(McpMethod::Initialized) {
112 state.sessions.update_state(sid, SessionState::Active).await;
113 }
114 None
115 }
116}
117
118pub struct DeleteSessionEndMiddleware;
119
120#[async_trait]
121impl RequestMiddleware for DeleteSessionEndMiddleware {
122 async fn on_request(&self, state: &ProxyState, ctx: &mut RequestContext) -> Option<Response> {
123 if ctx.http_method != Method::DELETE {
124 return None;
125 }
126 let sid = ctx.session_id.as_deref()?;
127 state
128 .event_bus
129 .emit(ProxyEvent::SessionEnd(SessionEndEvent {
130 session_id: sid.to_string(),
131 ts: chrono::Utc::now().timestamp_millis(),
132 }));
133 state.sessions.remove(sid).await;
134 None
135 }
136}
137
138pub struct SessionStartMiddleware;
139
140#[async_trait]
141impl ResponseMiddleware for SessionStartMiddleware {
142 async fn on_response(
143 &self,
144 state: &ProxyState,
145 req: &RequestContext,
146 resp: &mut ResponseContext,
147 ) {
148 if req.mcp_method != Some(McpMethod::Initialize) || resp.status >= 400 {
149 return;
150 }
151 let Some(sid) = req.session_id.as_deref() else {
152 return;
153 };
154
155 state.sessions.create(sid).await;
156 state
157 .sessions
158 .update_state(sid, SessionState::Initialized)
159 .await;
160
161 let (client_name, client_version, client_platform) =
162 if let Some(info) = req.client_info_from_init.clone() {
163 let platform = normalize_platform(&info.name).to_string();
164 let name = info.name.clone();
165 let version = info.version.clone();
166 state.sessions.set_client_info(sid, info).await;
167 (Some(name), version, Some(platform))
168 } else {
169 (None, None, None)
170 };
171
172 state
173 .event_bus
174 .emit(ProxyEvent::SessionStart(SessionStartEvent {
175 session_id: sid.to_string(),
176 proxy: state.name.clone(),
177 ts: chrono::Utc::now().timestamp_millis(),
178 client_name,
179 client_version,
180 client_platform,
181 }));
182 }
183}