1use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26 transport::{
27 WorkerTransport,
28 common::server_side_http::ServerSseMessage,
29 streamable_http_server::session::{
30 SessionId, SessionManager,
31 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32 },
33 },
34};
35use std::sync::Arc;
36use tracing::{debug, info, warn};
37
38use super::proxy_handler::ProxyHandler;
39
/// Bookkeeping recorded for each session created by `ProxyAwareSessionManager`.
///
/// This is plain data (a single `u64`), so it derives `Copy`. That lets callers copy it
/// out of a `DashMap` guard instead of cloning or holding the guard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SessionMetadata {
    /// Value returned by `ProxyHandler::get_backend_version()` when the session was created.
    backend_version: u64,
}
45
46pub struct ProxyAwareSessionManager {
54 inner: LocalSessionManager,
55 handler: Arc<ProxyHandler>,
56 session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60 pub fn new(handler: Arc<ProxyHandler>) -> Self {
61 info!(
62 "[Session管理器] 创建 ProxyAwareSessionManager - MCP ID: {}",
63 handler.mcp_id()
64 );
65 Self {
66 inner: LocalSessionManager::default(),
67 handler,
68 session_versions: DashMap::new(),
69 }
70 }
71
72 fn check_backend_version(&self, session_id: &SessionId) -> bool {
73 if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
74 let current_version = self.handler.get_backend_version();
75 if meta.backend_version != current_version {
76 warn!(
77 "[Session版本不匹配] session_id={}, 创建时版本={}, 当前版本={}, MCP ID: {}",
78 session_id, meta.backend_version, current_version, self.handler.mcp_id()
79 );
80 return false;
81 }
82 }
83 true
84 }
85}
86
87impl SessionManager for ProxyAwareSessionManager {
89 type Error = LocalSessionManagerError;
90 type Transport = WorkerTransport<LocalSessionWorker>;
91
92 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
93 let (session_id, transport) = self.inner.create_session().await?;
94
95 let version = self.handler.get_backend_version();
96 self.session_versions.insert(
97 session_id.to_string(),
98 SessionMetadata {
99 backend_version: version,
100 },
101 );
102
103 info!(
104 "[Session创建] session_id={}, backend_version={}, MCP ID: {}",
105 session_id, version, self.handler.mcp_id()
106 );
107
108 Ok((session_id, transport))
109 }
110
111 async fn initialize_session(
112 &self,
113 id: &SessionId,
114 message: ClientJsonRpcMessage,
115 ) -> Result<ServerJsonRpcMessage, Self::Error> {
116 if !self.handler.is_backend_available() {
117 warn!(
118 "[Session初始化失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
119 id, self.handler.mcp_id()
120 );
121 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
122 }
123
124 if !self.check_backend_version(id) {
125 warn!(
126 "[Session初始化失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
127 id, self.handler.mcp_id()
128 );
129 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
130 }
131
132 debug!(
133 "[Session初始化] session_id={}, MCP ID: {}",
134 id, self.handler.mcp_id()
135 );
136 self.inner.initialize_session(id, message).await
137 }
138
139 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
140 if !self.check_backend_version(id) {
141 return Ok(false);
142 }
143 self.inner.has_session(id).await
144 }
145
146 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
147 info!(
148 "[Session关闭] session_id={}, MCP ID: {}",
149 id, self.handler.mcp_id()
150 );
151 self.session_versions.remove(id.as_ref());
152 self.inner.close_session(id).await
153 }
154
155 async fn create_stream(
156 &self,
157 id: &SessionId,
158 message: ClientJsonRpcMessage,
159 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
160 if !self.handler.is_backend_available() {
161 warn!(
162 "[Stream创建失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
163 id, self.handler.mcp_id()
164 );
165 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
166 }
167
168 if !self.check_backend_version(id) {
169 warn!(
170 "[Stream创建失败] session_id={}, 原因: 版本不匹配, MCP ID: {}",
171 id, self.handler.mcp_id()
172 );
173 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
174 }
175
176 debug!(
177 "[Stream创建] session_id={}, MCP ID: {}",
178 id, self.handler.mcp_id()
179 );
180 self.inner.create_stream(id, message).await
181 }
182
183 async fn accept_message(
184 &self,
185 id: &SessionId,
186 message: ClientJsonRpcMessage,
187 ) -> Result<(), Self::Error> {
188 if !self.handler.is_backend_available() {
189 warn!(
190 "[消息拒绝] session_id={}, 原因: 后端不可用, MCP ID: {}",
191 id, self.handler.mcp_id()
192 );
193 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
194 }
195
196 if !self.check_backend_version(id) {
197 warn!(
198 "[消息拒绝] session_id={}, 原因: 版本不匹配, MCP ID: {}",
199 id, self.handler.mcp_id()
200 );
201 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
202 }
203
204 self.inner.accept_message(id, message).await
205 }
206
207 async fn create_standalone_stream(
208 &self,
209 id: &SessionId,
210 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
211 self.inner.create_standalone_stream(id).await
212 }
213
214 async fn resume(
215 &self,
216 id: &SessionId,
217 last_event_id: String,
218 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
219 if let Some(meta) = self.session_versions.get(id.as_ref()) {
221 let current_version = self.handler.get_backend_version();
222 if meta.backend_version != current_version {
223 warn!(
224 "[Session恢复失败] session_id={}, 原因: 后端版本变化 ({} -> {}), MCP ID: {}",
225 id, meta.backend_version, current_version, self.handler.mcp_id()
226 );
227
228 drop(meta); self.session_versions.remove(id.as_ref());
231 let _ = self.inner.close_session(id).await;
232
233 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
234 }
235 }
236
237 if !self.handler.is_backend_available() {
238 warn!(
239 "[Session恢复失败] session_id={}, 原因: 后端不可用, MCP ID: {}",
240 id, self.handler.mcp_id()
241 );
242 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
243 }
244
245 debug!(
246 "[Session恢复] session_id={}, last_event_id={}, MCP ID: {}",
247 id, last_event_id, self.handler.mcp_id()
248 );
249 self.inner.resume(id, last_event_id).await
250 }
251}