mcp_streamable_proxy/
session_manager.rs

1//! Session Manager with backend version tracking
2//!
3//! This module implements ProxyAwareSessionManager that integrates with
4//! ProxyHandler's version control mechanism to automatically invalidate
5//! sessions when the backend reconnects.
6//!
7//! # Architecture
8//!
9//! ```text
10//! ProxyAwareSessionManager
11//! ├── LocalSessionManager (rmcp 提供的基础实现)
12//! ├── ProxyHandler (Arc, 访问 backend_version)
13//! └── DashMap<SessionId, SessionMetadata> (跟踪 session 创建时的版本)
14//!
15//! 工作流程:
16//! 1. create_session: 记录当前 backend_version
17//! 2. resume: 检查版本是否匹配
18//!    - 匹配 → 正常 resume
19//!    - 不匹配 → 返回 NotFound,客户端重新创建 session
20//! ```
21
22use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25    model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26    transport::{
27        WorkerTransport,
28        common::server_side_http::ServerSseMessage,
29        streamable_http_server::session::{
30            SessionId, SessionManager,
31            local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32        },
33    },
34};
35use std::sync::Arc;
36use tracing::{debug, info};
37
38use super::proxy_handler::ProxyHandler;
39
/// Per-session metadata: the backend version that was current when the
/// session was created.
///
/// The struct wraps a single `u64`, so it is `Copy` and comparable; callers
/// can read it out of a map guard by value instead of keeping the guard alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SessionMetadata {
    /// Value of the handler's backend version recorded at session creation.
    backend_version: u64,
}
45
46/// 感知代理状态的 SessionManager
47///
48/// 职责:
49/// 1. 委托 LocalSessionManager 处理核心 session 逻辑
50/// 2. 维护 session → backend_version 映射
51/// 3. 在 resume 时检查版本一致性
52/// 4. 版本不匹配时使 session 失效
53pub struct ProxyAwareSessionManager {
54    inner: LocalSessionManager,
55    handler: Arc<ProxyHandler>,
56    session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60    pub fn new(handler: Arc<ProxyHandler>) -> Self {
61        info!("创建 ProxyAwareSessionManager");
62        Self {
63            inner: LocalSessionManager::default(),
64            handler,
65            session_versions: DashMap::new(),
66        }
67    }
68
69    fn check_backend_version(&self, session_id: &SessionId) -> bool {
70        if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
71            let current_version = self.handler.get_backend_version();
72            if meta.backend_version != current_version {
73                debug!(
74                    "Session {} version mismatch: {} != {}",
75                    session_id, meta.backend_version, current_version
76                );
77                return false;
78            }
79        }
80        true
81    }
82}
83
84// Implement SessionManager trait
85impl SessionManager for ProxyAwareSessionManager {
86    type Error = LocalSessionManagerError;
87    type Transport = WorkerTransport<LocalSessionWorker>;
88
89    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
90        let (session_id, transport) = self.inner.create_session().await?;
91
92        let version = self.handler.get_backend_version();
93        self.session_versions.insert(
94            session_id.to_string(),
95            SessionMetadata {
96                backend_version: version,
97            },
98        );
99
100        debug!(
101            "Created session {} with backend version {}",
102            session_id, version
103        );
104
105        Ok((session_id, transport))
106    }
107
108    async fn initialize_session(
109        &self,
110        id: &SessionId,
111        message: ClientJsonRpcMessage,
112    ) -> Result<ServerJsonRpcMessage, Self::Error> {
113        if !self.handler.is_backend_available() {
114            info!(
115                "Rejecting session initialization {}: backend not available",
116                id
117            );
118            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
119        }
120
121        if !self.check_backend_version(id) {
122            info!("Rejecting session initialization {}: version mismatch", id);
123            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
124        }
125
126        self.inner.initialize_session(id, message).await
127    }
128
129    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
130        if !self.check_backend_version(id) {
131            return Ok(false);
132        }
133        self.inner.has_session(id).await
134    }
135
136    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
137        self.session_versions.remove(id.as_ref());
138        self.inner.close_session(id).await
139    }
140
141    async fn create_stream(
142        &self,
143        id: &SessionId,
144        message: ClientJsonRpcMessage,
145    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
146        if !self.handler.is_backend_available() {
147            info!("Rejecting stream creation {}: backend not available", id);
148            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
149        }
150
151        if !self.check_backend_version(id) {
152            info!("Rejecting stream creation {}: version mismatch", id);
153            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
154        }
155
156        self.inner.create_stream(id, message).await
157    }
158
159    async fn accept_message(
160        &self,
161        id: &SessionId,
162        message: ClientJsonRpcMessage,
163    ) -> Result<(), Self::Error> {
164        if !self.handler.is_backend_available() {
165            info!(
166                "Rejecting message for session {}: backend not available",
167                id
168            );
169            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
170        }
171
172        if !self.check_backend_version(id) {
173            info!("Rejecting message for session {}: version mismatch", id);
174            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
175        }
176
177        self.inner.accept_message(id, message).await
178    }
179
180    async fn create_standalone_stream(
181        &self,
182        id: &SessionId,
183    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
184        self.inner.create_standalone_stream(id).await
185    }
186
187    async fn resume(
188        &self,
189        id: &SessionId,
190        last_event_id: String,
191    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
192        // 关键:检查后端版本
193        if let Some(meta) = self.session_versions.get(id.as_ref()) {
194            let current_version = self.handler.get_backend_version();
195            if meta.backend_version != current_version {
196                info!(
197                    "Session {} invalidated: backend version changed ({} -> {})",
198                    id, meta.backend_version, current_version
199                );
200
201                // 清理失效 session
202                drop(meta); // 释放 DashMap 的读锁
203                self.session_versions.remove(id.as_ref());
204                let _ = self.inner.close_session(id).await;
205
206                return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
207            }
208        }
209
210        if !self.handler.is_backend_available() {
211            info!("Cannot resume session {}: backend not available", id);
212            return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
213        }
214
215        self.inner.resume(id, last_event_id).await
216    }
217}