1use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26 transport::{
27 WorkerTransport,
28 common::server_side_http::ServerSseMessage,
29 streamable_http_server::session::{
30 SessionId, SessionManager,
31 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32 },
33 },
34};
35use std::sync::Arc;
36use std::time::Duration;
37use tracing::{debug, info, warn};
38
39use super::proxy_handler::ProxyHandler;
40
/// Per-session bookkeeping kept by [`ProxyAwareSessionManager`].
#[derive(Clone, Debug)]
struct SessionMetadata {
    /// Value of the handler's backend version at the moment the session was created.
    backend_version: u64,
}
46
47pub struct ProxyAwareSessionManager {
55 inner: LocalSessionManager,
56 handler: Arc<ProxyHandler>,
57 session_versions: DashMap<String, SessionMetadata>,
58}
59
60impl ProxyAwareSessionManager {
61 const DEFAULT_SESSION_KEEP_ALIVE_SECS: u64 = 30 * 60;
64
65 pub fn new(handler: Arc<ProxyHandler>) -> Self {
66 Self::with_keep_alive(handler, Duration::from_secs(Self::DEFAULT_SESSION_KEEP_ALIVE_SECS))
67 }
68
69 pub fn with_keep_alive(handler: Arc<ProxyHandler>, keep_alive: Duration) -> Self {
70 info!(
71 "[Session Manager] Create ProxyAwareSessionManager - MCP ID: {}, keep_alive: {}s",
72 handler.mcp_id(),
73 keep_alive.as_secs()
74 );
75 let mut inner = LocalSessionManager::default();
76 inner.session_config.keep_alive = Some(keep_alive);
77 Self {
78 inner,
79 handler,
80 session_versions: DashMap::new(),
81 }
82 }
83
84 fn check_backend_version(&self, session_id: &SessionId) -> bool {
85 if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
86 let current_version = self.handler.get_backend_version();
87 if meta.backend_version != current_version {
88 warn!(
89 "[Session version mismatch] session_id={}, creation version={}, current version={}, MCP ID: {}",
90 session_id,
91 meta.backend_version,
92 current_version,
93 self.handler.mcp_id()
94 );
95 return false;
96 }
97 }
98 true
99 }
100}
101
102impl SessionManager for ProxyAwareSessionManager {
104 type Error = LocalSessionManagerError;
105 type Transport = WorkerTransport<LocalSessionWorker>;
106
107 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
108 let (session_id, transport) = self.inner.create_session().await?;
109
110 let version = self.handler.get_backend_version();
111 self.session_versions.insert(
112 session_id.to_string(),
113 SessionMetadata {
114 backend_version: version,
115 },
116 );
117
118 info!(
119 "[SessionCreated] session_id={}, backend_version={}, MCP ID: {}",
120 session_id,
121 version,
122 self.handler.mcp_id()
123 );
124
125 Ok((session_id, transport))
126 }
127
128 async fn initialize_session(
129 &self,
130 id: &SessionId,
131 message: ClientJsonRpcMessage,
132 ) -> Result<ServerJsonRpcMessage, Self::Error> {
133 if !self.handler.is_backend_available() {
134 warn!(
135 "[Session initialization failed] session_id={}, reason: backend is unavailable, MCP ID: {}",
136 id,
137 self.handler.mcp_id()
138 );
139 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
140 }
141
142 if !self.check_backend_version(id) {
143 warn!(
144 "[Session initialization failed] session_id={}, reason: version mismatch, MCP ID: {}",
145 id,
146 self.handler.mcp_id()
147 );
148 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
149 }
150
151 debug!(
152 "[Session initialization] session_id={}, MCP ID: {}",
153 id,
154 self.handler.mcp_id()
155 );
156 self.inner.initialize_session(id, message).await
157 }
158
159 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
160 if !self.check_backend_version(id) {
161 return Ok(false);
162 }
163 self.inner.has_session(id).await
164 }
165
166 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
167 info!(
168 "[SessionClosed] session_id={}, MCP ID: {}",
169 id,
170 self.handler.mcp_id()
171 );
172 self.session_versions.remove(id.as_ref());
173 self.inner.close_session(id).await
174 }
175
176 async fn create_stream(
177 &self,
178 id: &SessionId,
179 message: ClientJsonRpcMessage,
180 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
181 if !self.handler.is_backend_available() {
182 warn!(
183 "[Stream creation failed] session_id={}, reason: backend is unavailable, MCP ID: {}",
184 id,
185 self.handler.mcp_id()
186 );
187 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
188 }
189
190 if !self.check_backend_version(id) {
191 warn!(
192 "[Stream creation failed] session_id={}, reason: version mismatch, MCP ID: {}",
193 id,
194 self.handler.mcp_id()
195 );
196 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
197 }
198
199 debug!(
200 "[Stream creation] session_id={}, MCP ID: {}",
201 id,
202 self.handler.mcp_id()
203 );
204 self.inner.create_stream(id, message).await
205 }
206
207 async fn accept_message(
208 &self,
209 id: &SessionId,
210 message: ClientJsonRpcMessage,
211 ) -> Result<(), Self::Error> {
212 if !self.handler.is_backend_available() {
213 warn!(
214 "[Message rejected] session_id={}, reason: backend unavailable, MCP ID: {}",
215 id,
216 self.handler.mcp_id()
217 );
218 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
219 }
220
221 if !self.check_backend_version(id) {
222 warn!(
223 "[Message rejected] session_id={}, reason: version mismatch, MCP ID: {}",
224 id,
225 self.handler.mcp_id()
226 );
227 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
228 }
229
230 self.inner.accept_message(id, message).await
231 }
232
233 async fn create_standalone_stream(
234 &self,
235 id: &SessionId,
236 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
237 self.inner.create_standalone_stream(id).await
238 }
239
240 async fn resume(
241 &self,
242 id: &SessionId,
243 last_event_id: String,
244 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
245 if let Some(meta) = self.session_versions.get(id.as_ref()) {
247 let current_version = self.handler.get_backend_version();
248 if meta.backend_version != current_version {
249 warn!(
250 "[Session recovery failed] session_id={}, reason: backend version change ({} -> {}), MCP ID: {}",
251 id,
252 meta.backend_version,
253 current_version,
254 self.handler.mcp_id()
255 );
256
257 drop(meta); self.session_versions.remove(id.as_ref());
260 let _ = self.inner.close_session(id).await;
261
262 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
263 }
264 }
265
266 if !self.handler.is_backend_available() {
267 warn!(
268 "[Session recovery failed] session_id={}, reason: backend is unavailable, MCP ID: {}",
269 id,
270 self.handler.mcp_id()
271 );
272 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
273 }
274
275 debug!(
276 "[SessionResumed] session_id={}, last_event_id={}, MCP ID: {}",
277 id,
278 last_event_id,
279 self.handler.mcp_id()
280 );
281 self.inner.resume(id, last_event_id).await
282 }
283}