1use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26 transport::{
27 WorkerTransport,
28 common::server_side_http::ServerSseMessage,
29 streamable_http_server::session::{
30 SessionId, SessionManager,
31 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32 },
33 },
34};
35use std::sync::Arc;
36use tracing::{debug, info, warn};
37
38use super::proxy_handler::ProxyHandler;
39
/// Bookkeeping recorded for every session at the moment it is created.
#[derive(Clone, Debug)]
struct SessionMetadata {
    /// Backend version reported by the proxy handler when the session was created;
    /// later compared against the handler's current version to detect stale sessions.
    backend_version: u64,
}
45
46pub struct ProxyAwareSessionManager {
54 inner: LocalSessionManager,
55 handler: Arc<ProxyHandler>,
56 session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60 pub fn new(handler: Arc<ProxyHandler>) -> Self {
61 info!(
62 "[Session管理器] 创建 ProxyAwareSessionManager - MCP ID: {}",
63 handler.mcp_id()
64 );
65 Self {
66 inner: LocalSessionManager::default(),
67 handler,
68 session_versions: DashMap::new(),
69 }
70 }
71
72 fn check_backend_version(&self, session_id: &SessionId) -> bool {
73 if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
74 let current_version = self.handler.get_backend_version();
75 if meta.backend_version != current_version {
76 warn!(
77 "[Session版本不匹配] session_id={}, 创建时版本={}, 当前版本={}, MCP ID: {}",
78 session_id,
79 meta.backend_version,
80 current_version,
81 self.handler.mcp_id()
82 );
83 return false;
84 }
85 }
86 true
87 }
88}
89
90impl SessionManager for ProxyAwareSessionManager {
92 type Error = LocalSessionManagerError;
93 type Transport = WorkerTransport<LocalSessionWorker>;
94
95 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
96 let (session_id, transport) = self.inner.create_session().await?;
97
98 let version = self.handler.get_backend_version();
99 self.session_versions.insert(
100 session_id.to_string(),
101 SessionMetadata {
102 backend_version: version,
103 },
104 );
105
106 info!(
107 "[Session创建] session_id={}, backend_version={}, MCP ID: {}",
108 session_id,
109 version,
110 self.handler.mcp_id()
111 );
112
113 Ok((session_id, transport))
114 }
115
116 async fn initialize_session(
117 &self,
118 id: &SessionId,
119 message: ClientJsonRpcMessage,
120 ) -> Result<ServerJsonRpcMessage, Self::Error> {
121 if !self.handler.is_backend_available() {
122 warn!(
123 "[Session初始化失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
124 id,
125 self.handler.mcp_id()
126 );
127 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
128 }
129
130 if !self.check_backend_version(id) {
131 warn!(
132 "[Session初始化失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
133 id,
134 self.handler.mcp_id()
135 );
136 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
137 }
138
139 debug!(
140 "[Session初始化] session_id={}, MCP ID: {}",
141 id,
142 self.handler.mcp_id()
143 );
144 self.inner.initialize_session(id, message).await
145 }
146
147 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
148 if !self.check_backend_version(id) {
149 return Ok(false);
150 }
151 self.inner.has_session(id).await
152 }
153
154 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
155 info!(
156 "[Session关闭] session_id={}, MCP ID: {}",
157 id,
158 self.handler.mcp_id()
159 );
160 self.session_versions.remove(id.as_ref());
161 self.inner.close_session(id).await
162 }
163
164 async fn create_stream(
165 &self,
166 id: &SessionId,
167 message: ClientJsonRpcMessage,
168 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
169 if !self.handler.is_backend_available() {
170 warn!(
171 "[Stream创建失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
172 id,
173 self.handler.mcp_id()
174 );
175 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
176 }
177
178 if !self.check_backend_version(id) {
179 warn!(
180 "[Stream创建失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
181 id,
182 self.handler.mcp_id()
183 );
184 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
185 }
186
187 debug!(
188 "[Stream创建] session_id={}, MCP ID: {}",
189 id,
190 self.handler.mcp_id()
191 );
192 self.inner.create_stream(id, message).await
193 }
194
195 async fn accept_message(
196 &self,
197 id: &SessionId,
198 message: ClientJsonRpcMessage,
199 ) -> Result<(), Self::Error> {
200 if !self.handler.is_backend_available() {
201 warn!(
202 "[消息拒绝] session_id={}, 原因: 后端不可用, MCP ID: {}",
203 id,
204 self.handler.mcp_id()
205 );
206 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
207 }
208
209 if !self.check_backend_version(id) {
210 warn!(
211 "[消息拒绝] session_id={}, 原因: 版本不匹配, MCP ID: {}",
212 id,
213 self.handler.mcp_id()
214 );
215 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
216 }
217
218 self.inner.accept_message(id, message).await
219 }
220
221 async fn create_standalone_stream(
222 &self,
223 id: &SessionId,
224 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
225 self.inner.create_standalone_stream(id).await
226 }
227
228 async fn resume(
229 &self,
230 id: &SessionId,
231 last_event_id: String,
232 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
233 if let Some(meta) = self.session_versions.get(id.as_ref()) {
235 let current_version = self.handler.get_backend_version();
236 if meta.backend_version != current_version {
237 warn!(
238 "[Session恢复失败] session_id={}, 原因: 后端版本变化 ({} -> {}), MCP ID: {}",
239 id,
240 meta.backend_version,
241 current_version,
242 self.handler.mcp_id()
243 );
244
245 drop(meta); self.session_versions.remove(id.as_ref());
248 let _ = self.inner.close_session(id).await;
249
250 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
251 }
252 }
253
254 if !self.handler.is_backend_available() {
255 warn!(
256 "[Session恢复失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
257 id,
258 self.handler.mcp_id()
259 );
260 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
261 }
262
263 debug!(
264 "[Session恢复] session_id={}, last_event_id={}, MCP ID: {}",
265 id,
266 last_event_id,
267 self.handler.mcp_id()
268 );
269 self.inner.resume(id, last_event_id).await
270 }
271}