// mcp_streamable_proxy/session_manager.rs

//! Session manager with backend version tracking.
//!
//! This module implements `ProxyAwareSessionManager`, which integrates with
//! `ProxyHandler`'s version-control mechanism to automatically invalidate
//! sessions when the backend reconnects.
//!
//! # Architecture
//!
//! ```text
//! ProxyAwareSessionManager
//! ├── LocalSessionManager (base implementation provided by rmcp)
//! ├── ProxyHandler (Arc, used to read backend_version)
//! └── DashMap<String, SessionMetadata> (backend version recorded when each session was created)
//!
//! Workflow:
//! 1. create_session: record the current backend_version
//! 2. resume: check whether the recorded version still matches
//!    - match    → resume normally
//!    - mismatch → return NotFound so the client creates a new session
//! 3. other session operations (initialize, stream creation, messages) also
//!    reject sessions whose recorded version no longer matches
//! ```

use std::sync::Arc;

use dashmap::DashMap;
use futures::Stream;
use rmcp::{
    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
    transport::{
        common::server_side_http::ServerSseMessage,
        streamable_http_server::session::{
            local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
            SessionId, SessionManager,
        },
        WorkerTransport,
    },
};
use tracing::{debug, info};

use super::proxy_handler::ProxyHandler;

/// Bookkeeping kept for each session: the backend version that was current
/// at the moment the session was created.
#[derive(Clone, Debug)]
struct SessionMetadata {
    /// Value returned by the handler's `get_backend_version` at creation time.
    backend_version: u64,
}

46/// 感知代理状态的 SessionManager
47///
48/// 职责:
49/// 1. 委托 LocalSessionManager 处理核心 session 逻辑
50/// 2. 维护 session → backend_version 映射
51/// 3. 在 resume 时检查版本一致性
52/// 4. 版本不匹配时使 session 失效
53pub struct ProxyAwareSessionManager {
54    inner: LocalSessionManager,
55    handler: Arc<ProxyHandler>,
56    session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60    pub fn new(handler: Arc<ProxyHandler>) -> Self {
61        info!("创建 ProxyAwareSessionManager");
62        Self {
63            inner: LocalSessionManager::default(),
64            handler,
65            session_versions: DashMap::new(),
66        }
67    }
68
69    fn check_backend_version(&self, session_id: &SessionId) -> bool {
70        if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
71            let current_version = self.handler.get_backend_version();
72            if meta.backend_version != current_version {
73                debug!(
74                    "Session {} version mismatch: {} != {}",
75                    session_id,
76                    meta.backend_version,
77                    current_version
78                );
79                return false;
80            }
81        }
82        true
83    }
84}
85
86// Implement SessionManager trait
87impl SessionManager for ProxyAwareSessionManager {
88    type Error = LocalSessionManagerError;
89    type Transport = WorkerTransport<LocalSessionWorker>;
90
91    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
92        let (session_id, transport) = self.inner.create_session().await?;
93
94        let version = self.handler.get_backend_version();
95        self.session_versions.insert(
96            session_id.to_string(),
97            SessionMetadata {
98                backend_version: version,
99            },
100        );
101
102        debug!(
103            "Created session {} with backend version {}",
104            session_id, version
105        );
106
107        Ok((session_id, transport))
108    }
109
110    async fn initialize_session(
111        &self,
112        id: &SessionId,
113        message: ClientJsonRpcMessage,
114    ) -> Result<ServerJsonRpcMessage, Self::Error> {
115        if !self.handler.is_backend_available() {
116            info!(
117                "Rejecting session initialization {}: backend not available",
118                id
119            );
120            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
121        }
122
123        if !self.check_backend_version(id) {
124            info!(
125                "Rejecting session initialization {}: version mismatch",
126                id
127            );
128            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
129        }
130
131        self.inner.initialize_session(id, message).await
132    }
133
134    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
135        if !self.check_backend_version(id) {
136            return Ok(false);
137        }
138        self.inner.has_session(id).await
139    }
140
141    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
142        self.session_versions.remove(id.as_ref());
143        self.inner.close_session(id).await
144    }
145
146    async fn create_stream(
147        &self,
148        id: &SessionId,
149        message: ClientJsonRpcMessage,
150    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
151        if !self.handler.is_backend_available() {
152            info!(
153                "Rejecting stream creation {}: backend not available",
154                id
155            );
156            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
157        }
158
159        if !self.check_backend_version(id) {
160            info!("Rejecting stream creation {}: version mismatch", id);
161            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
162        }
163
164        self.inner.create_stream(id, message).await
165    }
166
167    async fn accept_message(
168        &self,
169        id: &SessionId,
170        message: ClientJsonRpcMessage,
171    ) -> Result<(), Self::Error> {
172        if !self.handler.is_backend_available() {
173            info!(
174                "Rejecting message for session {}: backend not available",
175                id
176            );
177            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
178        }
179
180        if !self.check_backend_version(id) {
181            info!(
182                "Rejecting message for session {}: version mismatch",
183                id
184            );
185            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
186        }
187
188        self.inner.accept_message(id, message).await
189    }
190
191    async fn create_standalone_stream(
192        &self,
193        id: &SessionId,
194    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
195        self.inner.create_standalone_stream(id).await
196    }
197
198    async fn resume(
199        &self,
200        id: &SessionId,
201        last_event_id: String,
202    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
203        // 关键:检查后端版本
204        if let Some(meta) = self.session_versions.get(id.as_ref()) {
205            let current_version = self.handler.get_backend_version();
206            if meta.backend_version != current_version {
207                info!(
208                    "Session {} invalidated: backend version changed ({} -> {})",
209                    id, meta.backend_version, current_version
210                );
211
212                // 清理失效 session
213                drop(meta); // 释放 DashMap 的读锁
214                self.session_versions.remove(id.as_ref());
215                let _ = self.inner.close_session(id).await;
216
217                return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
218            }
219        }
220
221        if !self.handler.is_backend_available() {
222            info!("Cannot resume session {}: backend not available", id);
223            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
224        }
225
226        self.inner.resume(id, last_event_id).await
227    }
228}