1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
47use mcpkit_core::error::McpError;
48use mcpkit_core::protocol::{Notification, ProgressToken, RequestId};
49use mcpkit_core::protocol_version::ProtocolVersion;
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::Arc;
53use std::sync::atomic::{AtomicBool, Ordering};
54
55pub trait Peer: Send + Sync {
60 fn notify(
62 &self,
63 notification: Notification,
64 ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>>;
65}
66
/// State shared by a [`CancellationToken`], all of its clones, and every
/// [`CancelledFuture`] created from them.
#[derive(Debug, Default)]
struct CancellationState {
    /// Flips from `false` to `true` at most once, in [`CancellationToken::cancel`].
    cancelled: AtomicBool,
    /// Wakers of tasks that polled a [`CancelledFuture`] before cancellation.
    ///
    /// NOTE(review): wakers of futures dropped before cancellation stay here until
    /// `cancel` runs or the last token/future is dropped; `poll` deduplicates by
    /// `will_wake`, so re-polling from the same task does not grow the list.
    waiters: std::sync::Mutex<Vec<std::task::Waker>>,
}

impl CancellationState {
    /// Locks the waiter list.
    ///
    /// Poisoning is ignored: the list is only ever pushed to or drained as a whole,
    /// so a panic while the lock was held cannot leave it inconsistent.
    fn waiters(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.waiters
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// Cloneable, thread-safe flag used to signal that a request was cancelled.
///
/// All clones share the same state: cancelling one cancels them all.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    state: Arc<CancellationState>,
}

impl CancellationToken {
    /// Creates a token that has not been cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::new(CancellationState::default()),
        }
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this token
    /// or any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.state.cancelled.load(Ordering::SeqCst)
    }

    /// Marks the token as cancelled and wakes every task awaiting
    /// [`cancelled`](Self::cancelled). Calling it again has no further effect.
    pub fn cancel(&self) {
        if self.state.cancelled.swap(true, Ordering::SeqCst) {
            // Already cancelled; whoever flipped the flag has woken the waiters.
            return;
        }
        // Move the wakers out first so no waker runs while the lock is held.
        let waiters = std::mem::take(&mut *self.state.waiters());
        for waker in waiters {
            waker.wake();
        }
    }

    /// Returns a future that completes once this token is cancelled.
    ///
    /// A pending future is woken by [`cancel`](Self::cancel) rather than being
    /// re-polled continuously.
    #[must_use]
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            state: Arc::clone(&self.state),
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

/// Future returned by [`CancellationToken::cancelled`]; resolves to `()` once the
/// token is cancelled.
#[derive(Debug)]
pub struct CancelledFuture {
    state: Arc<CancellationState>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let state: &CancellationState = &self.state;
        if state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }

        let mut waiters = state.waiters();
        // Re-check under the lock: `cancel` sets the flag *before* draining the
        // list, so either we see the flag now or `cancel` will see our waker.
        if state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }
        if !waiters.iter().any(|w| w.will_wake(cx.waker())) {
            waiters.push(cx.waker().clone());
        }
        std::task::Poll::Pending
    }
}
139
140pub struct Context<'a> {
150 pub request_id: &'a RequestId,
152 pub progress_token: Option<&'a ProgressToken>,
154 pub client_caps: &'a ClientCapabilities,
156 pub server_caps: &'a ServerCapabilities,
158 pub protocol_version: ProtocolVersion,
163 peer: &'a dyn Peer,
165 cancel: CancellationToken,
167}
168
169impl<'a> Context<'a> {
170 #[must_use]
172 pub fn new(
173 request_id: &'a RequestId,
174 progress_token: Option<&'a ProgressToken>,
175 client_caps: &'a ClientCapabilities,
176 server_caps: &'a ServerCapabilities,
177 protocol_version: ProtocolVersion,
178 peer: &'a dyn Peer,
179 ) -> Self {
180 Self {
181 request_id,
182 progress_token,
183 client_caps,
184 server_caps,
185 protocol_version,
186 peer,
187 cancel: CancellationToken::new(),
188 }
189 }
190
191 #[must_use]
193 pub fn with_cancellation(
194 request_id: &'a RequestId,
195 progress_token: Option<&'a ProgressToken>,
196 client_caps: &'a ClientCapabilities,
197 server_caps: &'a ServerCapabilities,
198 protocol_version: ProtocolVersion,
199 peer: &'a dyn Peer,
200 cancel: CancellationToken,
201 ) -> Self {
202 Self {
203 request_id,
204 progress_token,
205 client_caps,
206 server_caps,
207 protocol_version,
208 peer,
209 cancel,
210 }
211 }
212
213 #[must_use]
215 pub fn is_cancelled(&self) -> bool {
216 self.cancel.is_cancelled()
217 }
218
219 pub fn cancelled(&self) -> impl Future<Output = ()> + '_ {
221 self.cancel.cancelled()
222 }
223
224 #[must_use]
226 pub const fn cancellation_token(&self) -> &CancellationToken {
227 &self.cancel
228 }
229
230 pub async fn notify(
241 &self,
242 method: &str,
243 params: Option<serde_json::Value>,
244 ) -> Result<(), McpError> {
245 let notification = if let Some(p) = params {
246 Notification::with_params(method.to_string(), p)
247 } else {
248 Notification::new(method.to_string())
249 };
250 self.peer.notify(notification).await
251 }
252
253 pub async fn progress(
268 &self,
269 current: u64,
270 total: Option<u64>,
271 message: Option<&str>,
272 ) -> Result<(), McpError> {
273 let Some(token) = self.progress_token else {
274 return Ok(());
276 };
277
278 let params = serde_json::json!({
279 "progressToken": token,
280 "progress": current,
281 "total": total,
282 "message": message,
283 });
284
285 self.notify("notifications/progress", Some(params)).await
286 }
287}
288
289impl std::fmt::Debug for Context<'_> {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("Context")
292 .field("request_id", &self.request_id)
293 .field("progress_token", &self.progress_token)
294 .field("client_caps", &self.client_caps)
295 .field("server_caps", &self.server_caps)
296 .field("protocol_version", &self.protocol_version)
297 .field("is_cancelled", &self.is_cancelled())
298 .finish()
299 }
300}
301
302#[derive(Debug, Clone, Copy)]
306pub struct NoOpPeer;
307
308impl Peer for NoOpPeer {
309 fn notify(
310 &self,
311 _notification: Notification,
312 ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
313 Box::pin(async { Ok(()) })
314 }
315}
316
317pub struct ContextData {
322 pub request_id: RequestId,
324 pub progress_token: Option<ProgressToken>,
326 pub client_caps: ClientCapabilities,
328 pub server_caps: ServerCapabilities,
330 pub protocol_version: ProtocolVersion,
332}
333
334impl ContextData {
335 #[must_use]
337 pub const fn new(
338 request_id: RequestId,
339 client_caps: ClientCapabilities,
340 server_caps: ServerCapabilities,
341 protocol_version: ProtocolVersion,
342 ) -> Self {
343 Self {
344 request_id,
345 progress_token: None,
346 client_caps,
347 server_caps,
348 protocol_version,
349 }
350 }
351
352 #[must_use]
354 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
355 self.progress_token = Some(token);
356 self
357 }
358
359 #[must_use]
361 pub fn to_context<'a>(&'a self, peer: &'a dyn Peer) -> Context<'a> {
362 Context::new(
363 &self.request_id,
364 self.progress_token.as_ref(),
365 &self.client_caps,
366 &self.server_caps,
367 self.protocol_version,
368 peer,
369 )
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Polls `fut` exactly once with a waker that ignores wake-ups.
    fn poll_once<F: Future + Unpin>(fut: &mut F) -> std::task::Poll<F::Output> {
        struct IgnoreWake;
        impl std::task::Wake for IgnoreWake {
            fn wake(self: Arc<Self>) {}
        }
        let waker = std::task::Waker::from(Arc::new(IgnoreWake));
        let mut cx = std::task::Context::from_waker(&waker);
        Pin::new(fut).poll(&mut cx)
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clones_share_state() {
        let token = CancellationToken::default();
        let clone = token.clone();
        assert!(!clone.is_cancelled());
        token.cancel();
        assert!(clone.is_cancelled());
    }

    #[test]
    fn test_cancelled_future_resolves_after_cancel() {
        let token = CancellationToken::new();
        let mut fut = token.cancelled();
        assert!(poll_once(&mut fut).is_pending());
        token.cancel();
        assert!(poll_once(&mut fut).is_ready());
        // A future created after cancellation is ready on its first poll.
        assert!(poll_once(&mut token.cancelled()).is_ready());
    }

    #[test]
    fn test_context_creation() {
        let request_id = RequestId::Number(1);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(
            &request_id,
            None,
            &client_caps,
            &server_caps,
            ProtocolVersion::LATEST,
            &peer,
        );

        assert!(!ctx.is_cancelled());
        assert!(ctx.progress_token.is_none());
        assert_eq!(ctx.protocol_version, ProtocolVersion::LATEST);
    }

    #[test]
    fn test_context_with_progress_token() {
        let request_id = RequestId::Number(1);
        let progress_token = ProgressToken::String("token".to_string());
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(
            &request_id,
            Some(&progress_token),
            &client_caps,
            &server_caps,
            ProtocolVersion::V2025_03_26,
            &peer,
        );

        assert!(ctx.progress_token.is_some());
        assert_eq!(ctx.protocol_version, ProtocolVersion::V2025_03_26);
    }

    #[test]
    fn test_context_with_cancellation_observes_external_token() {
        let request_id = RequestId::Number(7);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;
        let token = CancellationToken::new();

        let ctx = Context::with_cancellation(
            &request_id,
            None,
            &client_caps,
            &server_caps,
            ProtocolVersion::LATEST,
            &peer,
            token.clone(),
        );

        assert!(!ctx.is_cancelled());
        token.cancel();
        assert!(ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_cancelled());
    }

    #[test]
    fn test_context_data() {
        let data = ContextData::new(
            RequestId::Number(42),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            ProtocolVersion::V2025_06_18,
        )
        .with_progress_token(ProgressToken::String("test".to_string()));

        let peer = NoOpPeer;
        let ctx = data.to_context(&peer);

        assert!(ctx.progress_token.is_some());
        assert_eq!(ctx.protocol_version, ProtocolVersion::V2025_06_18);
        assert!(ctx.protocol_version.supports_elicitation());
        assert!(!ctx.protocol_version.supports_tasks());
    }
}