1use std::collections::HashSet;
4
5use fastmcp_core::SessionState;
6use fastmcp_core::logging::{debug, targets, warn};
7use fastmcp_protocol::{
8 ClientCapabilities, ClientInfo, JsonRpcRequest, LogLevel, ResourceUpdatedNotificationParams,
9 ServerCapabilities, ServerInfo,
10};
11
12use crate::NotificationSender;
13
/// Session bookkeeping for a single client of the server.
///
/// Holds the server's identity and capabilities (fixed at construction by
/// [`Session::new`]), the client's identity, capabilities and protocol version
/// (recorded by [`Session::initialize`]), the set of subscribed resource URIs,
/// the requested log level, and a [`SessionState`] store.
#[derive(Debug)]
pub struct Session {
    /// `false` after [`Session::new`]; set to `true` by [`Session::initialize`].
    initialized: bool,
    /// Client identity; `None` until [`Session::initialize`] records it.
    client_info: Option<ClientInfo>,
    /// Client capabilities; `None` until [`Session::initialize`] records them.
    /// The `supports_*` queries read this field.
    client_capabilities: Option<ClientCapabilities>,
    /// Server identity supplied to [`Session::new`].
    server_info: ServerInfo,
    /// Server capabilities supplied to [`Session::new`].
    server_capabilities: ServerCapabilities,
    /// Protocol version passed to [`Session::initialize`]; `None` before then.
    protocol_version: Option<String>,
    /// URIs added by `subscribe_resource` and removed by `unsubscribe_resource`.
    /// `notify_resource_updated` only sends notifications for URIs in this set.
    resource_subscriptions: HashSet<String>,
    /// Level recorded by `set_log_level`; `None` until it is called.
    log_level: Option<LogLevel>,
    /// Session-scoped state store, created with `SessionState::new()` and exposed
    /// read-only through `state()`.
    state: SessionState,
}
38
39impl Session {
40 #[must_use]
42 pub fn new(server_info: ServerInfo, server_capabilities: ServerCapabilities) -> Self {
43 Self {
44 initialized: false,
45 client_info: None,
46 client_capabilities: None,
47 server_info,
48 server_capabilities,
49 protocol_version: None,
50 resource_subscriptions: HashSet::new(),
51 log_level: None,
52 state: SessionState::new(),
53 }
54 }
55
56 #[must_use]
61 pub fn state(&self) -> &SessionState {
62 &self.state
63 }
64
65 #[must_use]
67 pub fn is_initialized(&self) -> bool {
68 self.initialized
69 }
70
71 pub fn initialize(
73 &mut self,
74 client_info: ClientInfo,
75 client_capabilities: ClientCapabilities,
76 protocol_version: String,
77 ) {
78 self.client_info = Some(client_info);
79 self.client_capabilities = Some(client_capabilities);
80 self.protocol_version = Some(protocol_version);
81 self.initialized = true;
82 }
83
84 #[must_use]
86 pub fn client_info(&self) -> Option<&ClientInfo> {
87 self.client_info.as_ref()
88 }
89
90 #[must_use]
92 pub fn client_capabilities(&self) -> Option<&ClientCapabilities> {
93 self.client_capabilities.as_ref()
94 }
95
96 #[must_use]
98 pub fn server_info(&self) -> &ServerInfo {
99 &self.server_info
100 }
101
102 #[must_use]
104 pub fn server_capabilities(&self) -> &ServerCapabilities {
105 &self.server_capabilities
106 }
107
108 #[must_use]
110 pub fn protocol_version(&self) -> Option<&str> {
111 self.protocol_version.as_deref()
112 }
113
114 pub fn subscribe_resource(&mut self, uri: String) {
116 self.resource_subscriptions.insert(uri);
117 }
118
119 pub fn unsubscribe_resource(&mut self, uri: &str) {
121 self.resource_subscriptions.remove(uri);
122 }
123
124 #[must_use]
126 pub fn is_resource_subscribed(&self, uri: &str) -> bool {
127 self.resource_subscriptions.contains(uri)
128 }
129
130 pub fn set_log_level(&mut self, level: LogLevel) {
132 self.log_level = Some(level);
133 }
134
135 #[must_use]
137 pub fn log_level(&self) -> Option<LogLevel> {
138 self.log_level
139 }
140
141 #[must_use]
143 pub fn supports_sampling(&self) -> bool {
144 self.client_capabilities
145 .as_ref()
146 .is_some_and(|caps| caps.sampling.is_some())
147 }
148
149 #[must_use]
151 pub fn supports_elicitation(&self) -> bool {
152 self.client_capabilities
153 .as_ref()
154 .is_some_and(|caps| caps.elicitation.is_some())
155 }
156
157 #[must_use]
159 pub fn supports_roots(&self) -> bool {
160 self.client_capabilities
161 .as_ref()
162 .is_some_and(|caps| caps.roots.is_some())
163 }
164
165 pub fn notify_resource_updated(&self, uri: &str, sender: &NotificationSender) -> bool {
169 if !self.is_resource_subscribed(uri) {
170 return false;
171 }
172
173 let params = ResourceUpdatedNotificationParams {
174 uri: uri.to_string(),
175 };
176 let payload = match serde_json::to_value(params) {
177 Ok(value) => value,
178 Err(err) => {
179 warn!(
180 target: targets::SESSION,
181 "failed to serialize resource update for {}: {}",
182 uri,
183 err
184 );
185 return false;
186 }
187 };
188
189 debug!(
190 target: targets::SESSION,
191 "sending resource update notification for {}",
192 uri
193 );
194 sender(JsonRpcRequest::notification(
195 "notifications/resources/updated",
196 Some(payload),
197 ));
198 true
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::{ElicitationCapability, RootsCapability, SamplingCapability};

    /// Protocol version used for every handshake in these tests.
    const PROTOCOL_VERSION: &str = "2024-11-05";

    /// Creates an uninitialized session with a fixed test server identity and
    /// default server capabilities.
    fn fresh_session() -> Session {
        let server_info = ServerInfo {
            name: "test".to_string(),
            version: "1.0".to_string(),
        };
        Session::new(server_info, ServerCapabilities::default())
    }

    /// Runs `initialize` on `session` with a fixed test client identity and the
    /// given `capabilities`.
    fn handshake(session: &mut Session, capabilities: ClientCapabilities) {
        let client_info = ClientInfo {
            name: "test-client".to_string(),
            version: "1.0".to_string(),
        };
        session.initialize(client_info, capabilities, PROTOCOL_VERSION.to_string());
    }

    /// Returns a fresh session that has been initialized with `capabilities`.
    fn initialized_with(capabilities: ClientCapabilities) -> Session {
        let mut session = fresh_session();
        handshake(&mut session, capabilities);
        session
    }

    /// Asserts the `(sampling, elicitation, roots)` support flags reported by `session`.
    fn assert_supports(session: &Session, expected: (bool, bool, bool)) {
        let reported = (
            session.supports_sampling(),
            session.supports_elicitation(),
            session.supports_roots(),
        );
        assert_eq!(reported, expected);
    }

    #[test]
    fn test_session_supports_sampling() {
        let mut session = fresh_session();
        // No client capabilities are recorded before the handshake.
        assert!(!session.supports_sampling());

        handshake(
            &mut session,
            ClientCapabilities {
                sampling: Some(SamplingCapability {}),
                elicitation: None,
                roots: None,
            },
        );
        assert_supports(&session, (true, false, false));
    }

    #[test]
    fn test_session_supports_elicitation() {
        let session = initialized_with(ClientCapabilities {
            sampling: None,
            elicitation: Some(ElicitationCapability::form()),
            roots: None,
        });
        assert_supports(&session, (false, true, false));
    }

    #[test]
    fn test_session_supports_roots() {
        let session = initialized_with(ClientCapabilities {
            sampling: None,
            elicitation: None,
            roots: Some(RootsCapability { list_changed: true }),
        });
        assert_supports(&session, (false, false, true));
    }

    #[test]
    fn test_session_supports_all_capabilities() {
        let session = initialized_with(ClientCapabilities {
            sampling: Some(SamplingCapability {}),
            elicitation: Some(ElicitationCapability::both()),
            roots: Some(RootsCapability {
                list_changed: false,
            }),
        });
        assert_supports(&session, (true, true, true));
    }

    #[test]
    fn test_session_no_capabilities() {
        let session = initialized_with(ClientCapabilities::default());
        assert_supports(&session, (false, false, false));
    }
}