Skip to main content

mcp_streamable_proxy/
session_manager.rs

//! Session manager with backend version tracking.
//!
//! This module implements `ProxyAwareSessionManager`, which integrates with
//! `ProxyHandler`'s version-control mechanism to automatically invalidate
//! sessions when the backend reconnects.
//!
//! # Architecture
//!
//! ```text
//! ProxyAwareSessionManager
//! ├── LocalSessionManager (base implementation provided by rmcp)
//! ├── ProxyHandler (Arc, gives access to backend_version)
//! └── DashMap<SessionId, SessionMetadata> (tracks the version at session creation)
//!
//! Workflow:
//! 1. create_session: record the current backend_version
//! 2. resume: check whether the version still matches
//!    - match    → resume normally
//!    - mismatch → return NotFound; the client creates a new session
//! ```
21
22use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26    transport::{
27        WorkerTransport,
28        common::server_side_http::ServerSseMessage,
29        streamable_http_server::session::{
30            SessionId, SessionManager,
31            local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32        },
33    },
34};
35use std::sync::Arc;
36use tracing::{debug, info, warn};
37
38use super::proxy_handler::ProxyHandler;
39
/// Per-session metadata: remembers which backend version was current when
/// the session was created.
#[derive(Debug, Clone)]
struct SessionMetadata {
    /// Backend version captured at session-creation time.
    backend_version: u64,
}
45
46/// 感知代理状态的 SessionManager
47///
48/// 职责:
49/// 1. 委托 LocalSessionManager 处理核心 session 逻辑
50/// 2. 维护 session → backend_version 映射
51/// 3. 在 resume 时检查版本一致性
52/// 4. 版本不匹配时使 session 失效
53pub struct ProxyAwareSessionManager {
54    inner: LocalSessionManager,
55    handler: Arc<ProxyHandler>,
56    session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60    pub fn new(handler: Arc<ProxyHandler>) -> Self {
61        info!(
62            "[Session管理器] 创建 ProxyAwareSessionManager - MCP ID: {}",
63            handler.mcp_id()
64        );
65        Self {
66            inner: LocalSessionManager::default(),
67            handler,
68            session_versions: DashMap::new(),
69        }
70    }
71
72    fn check_backend_version(&self, session_id: &SessionId) -> bool {
73        if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
74            let current_version = self.handler.get_backend_version();
75            if meta.backend_version != current_version {
76                warn!(
77                    "[Session版本不匹配] session_id={}, 创建时版本={}, 当前版本={}, MCP ID: {}",
78                    session_id,
79                    meta.backend_version,
80                    current_version,
81                    self.handler.mcp_id()
82                );
83                return false;
84            }
85        }
86        true
87    }
88}
89
90// Implement SessionManager trait
91impl SessionManager for ProxyAwareSessionManager {
92    type Error = LocalSessionManagerError;
93    type Transport = WorkerTransport<LocalSessionWorker>;
94
95    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
96        let (session_id, transport) = self.inner.create_session().await?;
97
98        let version = self.handler.get_backend_version();
99        self.session_versions.insert(
100            session_id.to_string(),
101            SessionMetadata {
102                backend_version: version,
103            },
104        );
105
106        info!(
107            "[Session创建] session_id={}, backend_version={}, MCP ID: {}",
108            session_id,
109            version,
110            self.handler.mcp_id()
111        );
112
113        Ok((session_id, transport))
114    }
115
116    async fn initialize_session(
117        &self,
118        id: &SessionId,
119        message: ClientJsonRpcMessage,
120    ) -> Result<ServerJsonRpcMessage, Self::Error> {
121        if !self.handler.is_backend_available() {
122            warn!(
123                "[Session初始化失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
124                id,
125                self.handler.mcp_id()
126            );
127            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
128        }
129
130        if !self.check_backend_version(id) {
131            warn!(
132                "[Session初始化失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
133                id,
134                self.handler.mcp_id()
135            );
136            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
137        }
138
139        debug!(
140            "[Session初始化] session_id={}, MCP ID: {}",
141            id,
142            self.handler.mcp_id()
143        );
144        self.inner.initialize_session(id, message).await
145    }
146
147    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
148        if !self.check_backend_version(id) {
149            return Ok(false);
150        }
151        self.inner.has_session(id).await
152    }
153
154    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
155        info!(
156            "[Session关闭] session_id={}, MCP ID: {}",
157            id,
158            self.handler.mcp_id()
159        );
160        self.session_versions.remove(id.as_ref());
161        self.inner.close_session(id).await
162    }
163
164    async fn create_stream(
165        &self,
166        id: &SessionId,
167        message: ClientJsonRpcMessage,
168    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
169        if !self.handler.is_backend_available() {
170            warn!(
171                "[Stream创建失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
172                id,
173                self.handler.mcp_id()
174            );
175            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
176        }
177
178        if !self.check_backend_version(id) {
179            warn!(
180                "[Stream创建失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
181                id,
182                self.handler.mcp_id()
183            );
184            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
185        }
186
187        debug!(
188            "[Stream创建] session_id={}, MCP ID: {}",
189            id,
190            self.handler.mcp_id()
191        );
192        self.inner.create_stream(id, message).await
193    }
194
195    async fn accept_message(
196        &self,
197        id: &SessionId,
198        message: ClientJsonRpcMessage,
199    ) -> Result<(), Self::Error> {
200        if !self.handler.is_backend_available() {
201            warn!(
202                "[消息拒绝] session_id={}, 原因: 后端不可用, MCP ID: {}",
203                id,
204                self.handler.mcp_id()
205            );
206            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
207        }
208
209        if !self.check_backend_version(id) {
210            warn!(
211                "[消息拒绝] session_id={}, 原因: 版本不匹配, MCP ID: {}",
212                id,
213                self.handler.mcp_id()
214            );
215            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
216        }
217
218        self.inner.accept_message(id, message).await
219    }
220
221    async fn create_standalone_stream(
222        &self,
223        id: &SessionId,
224    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
225        self.inner.create_standalone_stream(id).await
226    }
227
228    async fn resume(
229        &self,
230        id: &SessionId,
231        last_event_id: String,
232    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
233        // 关键:检查后端版本
234        if let Some(meta) = self.session_versions.get(id.as_ref()) {
235            let current_version = self.handler.get_backend_version();
236            if meta.backend_version != current_version {
237                warn!(
238                    "[Session恢复失败] session_id={}, 原因: 后端版本变化 ({} -> {}), MCP ID: {}",
239                    id,
240                    meta.backend_version,
241                    current_version,
242                    self.handler.mcp_id()
243                );
244
245                // 清理失效 session
246                drop(meta); // 释放 DashMap 的读锁
247                self.session_versions.remove(id.as_ref());
248                let _ = self.inner.close_session(id).await;
249
250                return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
251            }
252        }
253
254        if !self.handler.is_backend_available() {
255            warn!(
256                "[Session恢复失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
257                id,
258                self.handler.mcp_id()
259            );
260            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
261        }
262
263        debug!(
264            "[Session恢复] session_id={}, last_event_id={}, MCP ID: {}",
265            id,
266            last_event_id,
267            self.handler.mcp_id()
268        );
269        self.inner.resume(id, last_event_id).await
270    }
271}