mcp_streamable_proxy/
session_manager.rs

//! Session manager with backend version tracking.
//!
//! This module implements `ProxyAwareSessionManager`, which integrates with
//! `ProxyHandler`'s version-control mechanism to automatically invalidate
//! sessions when the backend reconnects.
//!
//! # Architecture
//!
//! ```text
//! ProxyAwareSessionManager
//! ├── LocalSessionManager (base implementation provided by rmcp)
//! ├── ProxyHandler (Arc, used to read backend_version)
//! └── DashMap<String, SessionMetadata> (backend version recorded when each session was created, keyed by session ID)
//!
//! Workflow:
//! 1. create_session: record the current backend_version
//! 2. resume and the other session operations: check whether the versions match
//!    - match    → the request proceeds normally
//!    - mismatch → return NotFound; the client re-creates the session
//! ```
21
22use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26    transport::{
27        WorkerTransport,
28        common::server_side_http::ServerSseMessage,
29        streamable_http_server::session::{
30            SessionId, SessionManager,
31            local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32        },
33    },
34};
35use std::sync::Arc;
36use tracing::{debug, info, warn};
37
38use super::proxy_handler::ProxyHandler;
39
/// Per-session bookkeeping: the backend version that was current when the
/// session was created.
///
/// It is a single `u64`, so it is `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SessionMetadata {
    /// Value returned by `ProxyHandler::get_backend_version()` right after the
    /// session was created.
    backend_version: u64,
}
45
46/// 感知代理状态的 SessionManager
47///
48/// 职责:
49/// 1. 委托 LocalSessionManager 处理核心 session 逻辑
50/// 2. 维护 session → backend_version 映射
51/// 3. 在 resume 时检查版本一致性
52/// 4. 版本不匹配时使 session 失效
53pub struct ProxyAwareSessionManager {
54    inner: LocalSessionManager,
55    handler: Arc<ProxyHandler>,
56    session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60    pub fn new(handler: Arc<ProxyHandler>) -> Self {
61        info!(
62            "[Session管理器] 创建 ProxyAwareSessionManager - MCP ID: {}",
63            handler.mcp_id()
64        );
65        Self {
66            inner: LocalSessionManager::default(),
67            handler,
68            session_versions: DashMap::new(),
69        }
70    }
71
72    fn check_backend_version(&self, session_id: &SessionId) -> bool {
73        if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
74            let current_version = self.handler.get_backend_version();
75            if meta.backend_version != current_version {
76                warn!(
77                    "[Session版本不匹配] session_id={}, 创建时版本={}, 当前版本={}, MCP ID: {}",
78                    session_id, meta.backend_version, current_version, self.handler.mcp_id()
79                );
80                return false;
81            }
82        }
83        true
84    }
85}
86
87// Implement SessionManager trait
88impl SessionManager for ProxyAwareSessionManager {
89    type Error = LocalSessionManagerError;
90    type Transport = WorkerTransport<LocalSessionWorker>;
91
92    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
93        let (session_id, transport) = self.inner.create_session().await?;
94
95        let version = self.handler.get_backend_version();
96        self.session_versions.insert(
97            session_id.to_string(),
98            SessionMetadata {
99                backend_version: version,
100            },
101        );
102
103        info!(
104            "[Session创建] session_id={}, backend_version={}, MCP ID: {}",
105            session_id, version, self.handler.mcp_id()
106        );
107
108        Ok((session_id, transport))
109    }
110
111    async fn initialize_session(
112        &self,
113        id: &SessionId,
114        message: ClientJsonRpcMessage,
115    ) -> Result<ServerJsonRpcMessage, Self::Error> {
116        if !self.handler.is_backend_available() {
117            warn!(
118                "[Session初始化失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
119                id, self.handler.mcp_id()
120            );
121            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
122        }
123
124        if !self.check_backend_version(id) {
125            warn!(
126                "[Session初始化失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
127                id, self.handler.mcp_id()
128            );
129            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
130        }
131
132        debug!(
133            "[Session初始化] session_id={}, MCP ID: {}",
134            id, self.handler.mcp_id()
135        );
136        self.inner.initialize_session(id, message).await
137    }
138
139    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
140        if !self.check_backend_version(id) {
141            return Ok(false);
142        }
143        self.inner.has_session(id).await
144    }
145
146    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
147        info!(
148            "[Session关闭] session_id={}, MCP ID: {}",
149            id, self.handler.mcp_id()
150        );
151        self.session_versions.remove(id.as_ref());
152        self.inner.close_session(id).await
153    }
154
155    async fn create_stream(
156        &self,
157        id: &SessionId,
158        message: ClientJsonRpcMessage,
159    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
160        if !self.handler.is_backend_available() {
161            warn!(
162                "[Stream创建失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
163                id, self.handler.mcp_id()
164            );
165            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
166        }
167
168        if !self.check_backend_version(id) {
169            warn!(
170                "[Stream创建失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
171                id, self.handler.mcp_id()
172            );
173            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
174        }
175
176        debug!(
177            "[Stream创建] session_id={}, MCP ID: {}",
178            id, self.handler.mcp_id()
179        );
180        self.inner.create_stream(id, message).await
181    }
182
183    async fn accept_message(
184        &self,
185        id: &SessionId,
186        message: ClientJsonRpcMessage,
187    ) -> Result<(), Self::Error> {
188        if !self.handler.is_backend_available() {
189            warn!(
190                "[消息拒绝] session_id={}, 原因: 后端不可用, MCP ID: {}",
191                id, self.handler.mcp_id()
192            );
193            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
194        }
195
196        if !self.check_backend_version(id) {
197            warn!(
198                "[消息拒绝] session_id={}, 原因: 版本不匹配, MCP ID: {}",
199                id, self.handler.mcp_id()
200            );
201            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
202        }
203
204        self.inner.accept_message(id, message).await
205    }
206
207    async fn create_standalone_stream(
208        &self,
209        id: &SessionId,
210    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
211        self.inner.create_standalone_stream(id).await
212    }
213
214    async fn resume(
215        &self,
216        id: &SessionId,
217        last_event_id: String,
218    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
219        // 关键:检查后端版本
220        if let Some(meta) = self.session_versions.get(id.as_ref()) {
221            let current_version = self.handler.get_backend_version();
222            if meta.backend_version != current_version {
223                warn!(
224                    "[Session恢复失败] session_id={}, 原因: 后端版本变化 ({} -> {}), MCP ID: {}",
225                    id, meta.backend_version, current_version, self.handler.mcp_id()
226                );
227
228                // 清理失效 session
229                drop(meta); // 释放 DashMap 的读锁
230                self.session_versions.remove(id.as_ref());
231                let _ = self.inner.close_session(id).await;
232
233                return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
234            }
235        }
236
237        if !self.handler.is_backend_available() {
238            warn!(
239                "[Session恢复失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
240                id, self.handler.mcp_id()
241            );
242            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
243        }
244
245        debug!(
246            "[Session恢复] session_id={}, last_event_id={}, MCP ID: {}",
247            id, last_event_id, self.handler.mcp_id()
248        );
249        self.inner.resume(id, last_event_id).await
250    }
251}