mcp_streamable_proxy/
session_manager.rs1use dashmap::DashMap;
23use futures::Stream;
24use rmcp::{
25 model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
26 transport::{
27 WorkerTransport,
28 common::server_side_http::ServerSseMessage,
29 streamable_http_server::session::{
30 SessionId, SessionManager,
31 local::{LocalSessionManager, LocalSessionManagerError, LocalSessionWorker},
32 },
33 },
34};
35use std::sync::Arc;
36use tracing::{debug, info};
37
38use super::proxy_handler::ProxyHandler;
39
/// Per-session bookkeeping kept alongside the inner session manager.
#[derive(Clone, Debug)]
struct SessionMetadata {
    /// Backend version reported by the proxy handler when the session was created.
    backend_version: u64,
}
45
46pub struct ProxyAwareSessionManager {
54 inner: LocalSessionManager,
55 handler: Arc<ProxyHandler>,
56 session_versions: DashMap<String, SessionMetadata>,
57}
58
59impl ProxyAwareSessionManager {
60 pub fn new(handler: Arc<ProxyHandler>) -> Self {
61 info!("创建 ProxyAwareSessionManager");
62 Self {
63 inner: LocalSessionManager::default(),
64 handler,
65 session_versions: DashMap::new(),
66 }
67 }
68
69 fn check_backend_version(&self, session_id: &SessionId) -> bool {
70 if let Some(meta) = self.session_versions.get(session_id.as_ref()) {
71 let current_version = self.handler.get_backend_version();
72 if meta.backend_version != current_version {
73 debug!(
74 "Session {} version mismatch: {} != {}",
75 session_id, meta.backend_version, current_version
76 );
77 return false;
78 }
79 }
80 true
81 }
82}
83
84impl SessionManager for ProxyAwareSessionManager {
86 type Error = LocalSessionManagerError;
87 type Transport = WorkerTransport<LocalSessionWorker>;
88
89 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
90 let (session_id, transport) = self.inner.create_session().await?;
91
92 let version = self.handler.get_backend_version();
93 self.session_versions.insert(
94 session_id.to_string(),
95 SessionMetadata {
96 backend_version: version,
97 },
98 );
99
100 debug!(
101 "Created session {} with backend version {}",
102 session_id, version
103 );
104
105 Ok((session_id, transport))
106 }
107
108 async fn initialize_session(
109 &self,
110 id: &SessionId,
111 message: ClientJsonRpcMessage,
112 ) -> Result<ServerJsonRpcMessage, Self::Error> {
113 if !self.handler.is_backend_available() {
114 info!(
115 "Rejecting session initialization {}: backend not available",
116 id
117 );
118 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
119 }
120
121 if !self.check_backend_version(id) {
122 info!("Rejecting session initialization {}: version mismatch", id);
123 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
124 }
125
126 self.inner.initialize_session(id, message).await
127 }
128
129 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
130 if !self.check_backend_version(id) {
131 return Ok(false);
132 }
133 self.inner.has_session(id).await
134 }
135
136 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
137 self.session_versions.remove(id.as_ref());
138 self.inner.close_session(id).await
139 }
140
141 async fn create_stream(
142 &self,
143 id: &SessionId,
144 message: ClientJsonRpcMessage,
145 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
146 if !self.handler.is_backend_available() {
147 info!("Rejecting stream creation {}: backend not available", id);
148 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
149 }
150
151 if !self.check_backend_version(id) {
152 info!("Rejecting stream creation {}: version mismatch", id);
153 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
154 }
155
156 self.inner.create_stream(id, message).await
157 }
158
159 async fn accept_message(
160 &self,
161 id: &SessionId,
162 message: ClientJsonRpcMessage,
163 ) -> Result<(), Self::Error> {
164 if !self.handler.is_backend_available() {
165 info!(
166 "Rejecting message for session {}: backend not available",
167 id
168 );
169 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
170 }
171
172 if !self.check_backend_version(id) {
173 info!("Rejecting message for session {}: version mismatch", id);
174 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
175 }
176
177 self.inner.accept_message(id, message).await
178 }
179
180 async fn create_standalone_stream(
181 &self,
182 id: &SessionId,
183 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
184 self.inner.create_standalone_stream(id).await
185 }
186
187 async fn resume(
188 &self,
189 id: &SessionId,
190 last_event_id: String,
191 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
192 if let Some(meta) = self.session_versions.get(id.as_ref()) {
194 let current_version = self.handler.get_backend_version();
195 if meta.backend_version != current_version {
196 info!(
197 "Session {} invalidated: backend version changed ({} -> {})",
198 id, meta.backend_version, current_version
199 );
200
201 drop(meta); self.session_versions.remove(id.as_ref());
204 let _ = self.inner.close_session(id).await;
205
206 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
207 }
208 }
209
210 if !self.handler.is_backend_available() {
211 info!("Cannot resume session {}: backend not available", id);
212 return Err(LocalSessionManagerError::SessionNotFound(id.clone()));
213 }
214
215 self.inner.resume(id, last_event_id).await
216 }
217}