mcp_streamable_proxy/
session_manager.rs1use std::sync::Arc;
23use dashmap::DashMap;
24use futures::Stream;
25use rmcp::{
26 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
27 transport::{
28 streamable_http_server::session::{
29 SessionManager, SessionId,
30 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
31 },
32 common::server_side_http::ServerSseMessage,
33 WorkerTransport,
34 },
35};
36use tracing::{debug, info};
37
38use super::proxy_handler::ProxyHandler;
39
/// Bookkeeping stored per session next to the inner session manager.
#[derive(Clone, Debug)]
struct SessionMetadata {
    /// Backend version reported by the proxy handler when the session was
    /// created; compared against the current version on later requests.
    backend_version: u64,
}
45
46pub struct ProxyAwareSessionManager {
54 inner: LocalSessionManager,
55 handler: Arc<ProxyHandler>,
56 session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60 pub fn new(handler: Arc<ProxyHandler>) -> Self {
61 info!("创建 ProxyAwareSessionManager");
62 Self {
63 inner: LocalSessionManager::default(),
64 handler,
65 session_versions: DashMap::new(),
66 }
67 }
68
69 fn check_backend_version(&self, session_id: &SessionId) -> bool {
70 if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
71 let current_version = self.handler.get_backend_version();
72 if meta.backend_version != current_version {
73 debug!(
74 "Session {} version mismatch: {} != {}",
75 session_id,
76 meta.backend_version,
77 current_version
78 );
79 return false;
80 }
81 }
82 true
83 }
84}
85
86impl SessionManager for ProxyAwareSessionManager {
88 type Error = LocalSessionManagerError;
89 type Transport = WorkerTransport<LocalSessionWorker>;
90
91 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
92 let (session_id, transport) = self.inner.create_session().await?;
93
94 let version = self.handler.get_backend_version();
95 self.session_versions.insert(
96 session_id.to_string(),
97 SessionMetadata {
98 backend_version: version,
99 },
100 );
101
102 debug!(
103 "Created session {} with backend version {}",
104 session_id, version
105 );
106
107 Ok((session_id, transport))
108 }
109
110 async fn initialize_session(
111 &self,
112 id: &SessionId,
113 message: ClientJsonRpcMessage,
114 ) -> Result<ServerJsonRpcMessage, Self::Error> {
115 if !self.handler.is_backend_available() {
116 info!(
117 "Rejecting session initialization {}: backend not available",
118 id
119 );
120 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
121 }
122
123 if !self.check_backend_version(id) {
124 info!(
125 "Rejecting session initialization {}: version mismatch",
126 id
127 );
128 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
129 }
130
131 self.inner.initialize_session(id, message).await
132 }
133
134 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
135 if !self.check_backend_version(id) {
136 return Ok(false);
137 }
138 self.inner.has_session(id).await
139 }
140
141 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
142 self.session_versions.remove(id.as_ref());
143 self.inner.close_session(id).await
144 }
145
146 async fn create_stream(
147 &self,
148 id: &SessionId,
149 message: ClientJsonRpcMessage,
150 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
151 if !self.handler.is_backend_available() {
152 info!(
153 "Rejecting stream creation {}: backend not available",
154 id
155 );
156 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
157 }
158
159 if !self.check_backend_version(id) {
160 info!("Rejecting stream creation {}: version mismatch", id);
161 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
162 }
163
164 self.inner.create_stream(id, message).await
165 }
166
167 async fn accept_message(
168 &self,
169 id: &SessionId,
170 message: ClientJsonRpcMessage,
171 ) -> Result<(), Self::Error> {
172 if !self.handler.is_backend_available() {
173 info!(
174 "Rejecting message for session {}: backend not available",
175 id
176 );
177 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
178 }
179
180 if !self.check_backend_version(id) {
181 info!(
182 "Rejecting message for session {}: version mismatch",
183 id
184 );
185 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
186 }
187
188 self.inner.accept_message(id, message).await
189 }
190
191 async fn create_standalone_stream(
192 &self,
193 id: &SessionId,
194 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
195 self.inner.create_standalone_stream(id).await
196 }
197
198 async fn resume(
199 &self,
200 id: &SessionId,
201 last_event_id: String,
202 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
203 if let Some(meta) = self.session_versions.get(id.as_ref()) {
205 let current_version = self.handler.get_backend_version();
206 if meta.backend_version != current_version {
207 info!(
208 "Session {} invalidated: backend version changed ({} -> {})",
209 id, meta.backend_version, current_version
210 );
211
212 drop(meta); self.session_versions.remove(id.as_ref());
215 let _ = self.inner.close_session(id).await;
216
217 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
218 }
219 }
220
221 if !self.handler.is_backend_available() {
222 info!("Cannot resume session {}: backend not available", id);
223 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
224 }
225
226 self.inner.resume(id, last_event_id).await
227 }
228}