Skip to main content

mcp_streamable_proxy/
session_manager.rs

//! Session manager with backend version tracking.
//!
//! This module implements `ProxyAwareSessionManager`, which integrates with
//! `ProxyHandler`'s version-control mechanism to automatically invalidate
//! sessions when the backend reconnects.
//!
//! # Architecture
//!
//! ```text
//! ProxyAwareSessionManager
//! ├── LocalSessionManager (base implementation provided by rmcp)
//! ├── ProxyHandler (Arc, used to read backend_version)
//! └── DashMap<String, SessionMetadata> (session id → backend version at creation)
//!
//! Workflow:
//! 1. create_session: record the current backend_version
//! 2. resume: check whether the recorded version still matches
//!    - match    → resume normally
//!    - mismatch → return NotFound so the client creates a new session
//! ```
21
22use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26    transport::{
27        WorkerTransport,
28        common::server_side_http::ServerSseMessage,
29        streamable_http_server::session::{
30            SessionId, SessionManager,
31            local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32        },
33    },
34};
35use std::sync::Arc;
36use std::time::Duration;
37use tracing::{debug, info, warn};
38
39use super::proxy_handler::ProxyHandler;
40
/// Per-session bookkeeping: which backend version the session was created against.
#[derive(Clone, Debug)]
struct SessionMetadata {
    /// Backend version reported by the proxy handler when the session was created.
    backend_version: u64,
}
46
47/// 感知代理状态的 SessionManager
48///
49/// 职责:
50/// 1. 委托 LocalSessionManager 处理核心 session 逻辑
51/// 2. 维护 session → backend_version 映射
52/// 3. 在 resume 时检查版本一致性
53/// 4. 版本不匹配时使 session 失效
54pub struct ProxyAwareSessionManager {
55    inner: LocalSessionManager,
56    handler: Arc<ProxyHandler>,
57    session_versions: DashMap<String, SessionMetadata>,
58}
59
60impl ProxyAwareSessionManager {
61    /// 默认 session keep_alive 超时: 30 分钟
62    /// rmcp 默认 5 分钟太短, 对于代理场景, agent 可能长时间不发消息但仍需要保持 session
63    const DEFAULT_SESSION_KEEP_ALIVE_SECS: u64 = 30 * 60;
64
65    pub fn new(handler: Arc<ProxyHandler>) -> Self {
66        Self::with_keep_alive(handler, Duration::from_secs(Self::DEFAULT_SESSION_KEEP_ALIVE_SECS))
67    }
68
69    pub fn with_keep_alive(handler: Arc<ProxyHandler>, keep_alive: Duration) -> Self {
70        info!(
71            "[Session Manager] Create ProxyAwareSessionManager - MCP ID: {}, keep_alive: {}s",
72            handler.mcp_id(),
73            keep_alive.as_secs()
74        );
75        let mut inner = LocalSessionManager::default();
76        inner.session_config.keep_alive = Some(keep_alive);
77        Self {
78            inner,
79            handler,
80            session_versions: DashMap::new(),
81        }
82    }
83
84    fn check_backend_version(&self, session_id: &SessionId) -> bool {
85        if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
86            let current_version = self.handler.get_backend_version();
87            if meta.backend_version != current_version {
88                warn!(
89                    "[Session version mismatch] session_id={}, creation version={}, current version={}, MCP ID: {}",
90                    session_id,
91                    meta.backend_version,
92                    current_version,
93                    self.handler.mcp_id()
94                );
95                return false;
96            }
97        }
98        true
99    }
100}
101
102// Implement SessionManager trait
103impl SessionManager for ProxyAwareSessionManager {
104    type Error = LocalSessionManagerError;
105    type Transport = WorkerTransport<LocalSessionWorker>;
106
107    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
108        let (session_id, transport) = self.inner.create_session().await?;
109
110        let version = self.handler.get_backend_version();
111        self.session_versions.insert(
112            session_id.to_string(),
113            SessionMetadata {
114                backend_version: version,
115            },
116        );
117
118        info!(
119            "[SessionCreated] session_id={}, backend_version={}, MCP ID: {}",
120            session_id,
121            version,
122            self.handler.mcp_id()
123        );
124
125        Ok((session_id, transport))
126    }
127
128    async fn initialize_session(
129        &self,
130        id: &SessionId,
131        message: ClientJsonRpcMessage,
132    ) -> Result<ServerJsonRpcMessage, Self::Error> {
133        if !self.handler.is_backend_available() {
134            warn!(
135                "[Session initialization failed] session_id={}, reason: backend is unavailable, MCP ID: {}",
136                id,
137                self.handler.mcp_id()
138            );
139            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
140        }
141
142        if !self.check_backend_version(id) {
143            warn!(
144                "[Session initialization failed] session_id={}, reason: version mismatch, MCP ID: {}",
145                id,
146                self.handler.mcp_id()
147            );
148            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
149        }
150
151        debug!(
152            "[Session initialization] session_id={}, MCP ID: {}",
153            id,
154            self.handler.mcp_id()
155        );
156        self.inner.initialize_session(id, message).await
157    }
158
159    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
160        if !self.check_backend_version(id) {
161            return Ok(false);
162        }
163        self.inner.has_session(id).await
164    }
165
166    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
167        info!(
168            "[SessionClosed] session_id={}, MCP ID: {}",
169            id,
170            self.handler.mcp_id()
171        );
172        self.session_versions.remove(id.as_ref());
173        self.inner.close_session(id).await
174    }
175
176    async fn create_stream(
177        &self,
178        id: &SessionId,
179        message: ClientJsonRpcMessage,
180    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
181        if !self.handler.is_backend_available() {
182            warn!(
183                "[Stream creation failed] session_id={}, reason: backend is unavailable, MCP ID: {}",
184                id,
185                self.handler.mcp_id()
186            );
187            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
188        }
189
190        if !self.check_backend_version(id) {
191            warn!(
192                "[Stream creation failed] session_id={}, reason: version mismatch, MCP ID: {}",
193                id,
194                self.handler.mcp_id()
195            );
196            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
197        }
198
199        debug!(
200            "[Stream creation] session_id={}, MCP ID: {}",
201            id,
202            self.handler.mcp_id()
203        );
204        self.inner.create_stream(id, message).await
205    }
206
207    async fn accept_message(
208        &self,
209        id: &SessionId,
210        message: ClientJsonRpcMessage,
211    ) -> Result<(), Self::Error> {
212        if !self.handler.is_backend_available() {
213            warn!(
214                "[Message rejected] session_id={}, reason: backend unavailable, MCP ID: {}",
215                id,
216                self.handler.mcp_id()
217            );
218            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
219        }
220
221        if !self.check_backend_version(id) {
222            warn!(
223                "[Message rejected] session_id={}, reason: version mismatch, MCP ID: {}",
224                id,
225                self.handler.mcp_id()
226            );
227            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
228        }
229
230        self.inner.accept_message(id, message).await
231    }
232
233    async fn create_standalone_stream(
234        &self,
235        id: &SessionId,
236    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
237        self.inner.create_standalone_stream(id).await
238    }
239
240    async fn resume(
241        &self,
242        id: &SessionId,
243        last_event_id: String,
244    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
245        // 关键:检查后端版本
246        if let Some(meta) = self.session_versions.get(id.as_ref()) {
247            let current_version = self.handler.get_backend_version();
248            if meta.backend_version != current_version {
249                warn!(
250                    "[Session recovery failed] session_id={}, reason: backend version change ({} -> {}), MCP ID: {}",
251                    id,
252                    meta.backend_version,
253                    current_version,
254                    self.handler.mcp_id()
255                );
256
257                // 清理失效 session
258                drop(meta); // 释放 DashMap 的读锁
259                self.session_versions.remove(id.as_ref());
260                let _ = self.inner.close_session(id).await;
261
262                return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
263            }
264        }
265
266        if !self.handler.is_backend_available() {
267            warn!(
268                "[Session recovery failed] session_id={}, reason: backend is unavailable, MCP ID: {}",
269                id,
270                self.handler.mcp_id()
271            );
272            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
273        }
274
275        debug!(
276            "[SessionResumed] session_id={}, last_event_id={}, MCP ID: {}",
277            id,
278            last_event_id,
279            self.handler.mcp_id()
280        );
281        self.inner.resume(id, last_event_id).await
282    }
283}