1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
43use mcpkit_core::error::McpError;
44use mcpkit_core::protocol::{Notification, ProgressToken, RequestId};
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::sync::Arc;
49
50pub trait Peer: Send + Sync {
55 fn notify(
57 &self,
58 notification: Notification,
59 ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>>;
60}
61
/// State shared by a [`CancellationToken`], all of its clones, and every
/// [`CancelledFuture`] created from them.
#[derive(Default)]
struct CancelState {
    /// Flips from `false` to `true` at most once, in [`CancellationToken::cancel`].
    cancelled: AtomicBool,
    /// Wakers of tasks currently parked in `CancelledFuture::poll`; drained and
    /// woken by [`CancellationToken::cancel`].
    waiters: std::sync::Mutex<Vec<std::task::Waker>>,
}

impl CancelState {
    /// Locks the waiter list.
    ///
    /// Poisoning is recovered from rather than propagated: the list is only
    /// ever pushed to or taken whole, so it stays a valid `Vec` even if a
    /// holder panicked.
    fn waiters(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.waiters
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// A cloneable, thread-safe, one-way cancellation flag.
///
/// All clones share the same state: cancelling any clone cancels them all.
/// There is no way to reset a cancelled token.
#[derive(Clone)]
pub struct CancellationToken {
    inner: Arc<CancelState>,
}

impl CancellationToken {
    /// Creates a token that is not yet cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(CancelState::default()),
        }
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.load(Ordering::SeqCst)
    }

    /// Cancels this token and every clone, waking all tasks currently awaiting
    /// [`cancelled`](Self::cancelled). Calling it more than once is harmless.
    pub fn cancel(&self) {
        // Publish the flag *before* draining the list: `poll` re-checks the flag
        // while holding the lock, so a waiter is either drained here or sees
        // the flag itself — never lost.
        self.inner.cancelled.store(true, Ordering::SeqCst);
        // Move the wakers out so the lock is released before any waker runs.
        let waiters = std::mem::take(&mut *self.inner.waiters());
        for waker in waiters {
            waker.wake();
        }
    }

    /// Returns a future that completes once this token is cancelled.
    ///
    /// While pending, the future parks the polling task and is woken by
    /// [`cancel`](Self::cancel); it does not reschedule itself on every poll.
    #[must_use]
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the historical `CancellationToken { cancelled: <bool> }` shape and
        // hide the waker list.
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}

/// Future returned by [`CancellationToken::cancelled`]; resolves to `()` once
/// the token is cancelled.
///
/// Note: a waker registered by a future that is dropped before cancellation
/// stays in the token's list until `cancel` runs. Repeat polls from the same
/// task do not add duplicates.
#[must_use = "futures do nothing unless polled or `.await`ed"]
pub struct CancelledFuture {
    inner: Arc<CancelState>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let state = &*self.inner;
        // Fast path: no locking once cancelled.
        if state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }
        let mut waiters = state.waiters();
        // Re-check under the lock to close the race with a concurrent `cancel`
        // that stored the flag after our first load but before we locked.
        if state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }
        // Register this task (once) and park until `cancel` wakes it.
        if !waiters.iter().any(|w| w.will_wake(cx.waker())) {
            waiters.push(cx.waker().clone());
        }
        std::task::Poll::Pending
    }
}
134
135pub struct Context<'a> {
145 pub request_id: &'a RequestId,
147 pub progress_token: Option<&'a ProgressToken>,
149 pub client_caps: &'a ClientCapabilities,
151 pub server_caps: &'a ServerCapabilities,
153 peer: &'a dyn Peer,
155 cancel: CancellationToken,
157}
158
159impl<'a> Context<'a> {
160 #[must_use]
162 pub fn new(
163 request_id: &'a RequestId,
164 progress_token: Option<&'a ProgressToken>,
165 client_caps: &'a ClientCapabilities,
166 server_caps: &'a ServerCapabilities,
167 peer: &'a dyn Peer,
168 ) -> Self {
169 Self {
170 request_id,
171 progress_token,
172 client_caps,
173 server_caps,
174 peer,
175 cancel: CancellationToken::new(),
176 }
177 }
178
179 #[must_use]
181 pub fn with_cancellation(
182 request_id: &'a RequestId,
183 progress_token: Option<&'a ProgressToken>,
184 client_caps: &'a ClientCapabilities,
185 server_caps: &'a ServerCapabilities,
186 peer: &'a dyn Peer,
187 cancel: CancellationToken,
188 ) -> Self {
189 Self {
190 request_id,
191 progress_token,
192 client_caps,
193 server_caps,
194 peer,
195 cancel,
196 }
197 }
198
199 #[must_use]
201 pub fn is_cancelled(&self) -> bool {
202 self.cancel.is_cancelled()
203 }
204
205 pub fn cancelled(&self) -> impl Future<Output = ()> + '_ {
207 self.cancel.cancelled()
208 }
209
210 #[must_use]
212 pub const fn cancellation_token(&self) -> &CancellationToken {
213 &self.cancel
214 }
215
216 pub async fn notify(
227 &self,
228 method: &str,
229 params: Option<serde_json::Value>,
230 ) -> Result<(), McpError> {
231 let notification = if let Some(p) = params {
232 Notification::with_params(method.to_string(), p)
233 } else {
234 Notification::new(method.to_string())
235 };
236 self.peer.notify(notification).await
237 }
238
239 pub async fn progress(
254 &self,
255 current: u64,
256 total: Option<u64>,
257 message: Option<&str>,
258 ) -> Result<(), McpError> {
259 let Some(token) = self.progress_token else {
260 return Ok(());
262 };
263
264 let params = serde_json::json!({
265 "progressToken": token,
266 "progress": current,
267 "total": total,
268 "message": message,
269 });
270
271 self.notify("notifications/progress", Some(params)).await
272 }
273}
274
275impl std::fmt::Debug for Context<'_> {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 f.debug_struct("Context")
278 .field("request_id", &self.request_id)
279 .field("progress_token", &self.progress_token)
280 .field("client_caps", &self.client_caps)
281 .field("server_caps", &self.server_caps)
282 .field("is_cancelled", &self.is_cancelled())
283 .finish()
284 }
285}
286
287#[derive(Debug, Clone, Copy)]
291pub struct NoOpPeer;
292
293impl Peer for NoOpPeer {
294 fn notify(
295 &self,
296 _notification: Notification,
297 ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
298 Box::pin(async { Ok(()) })
299 }
300}
301
302pub struct ContextData {
307 pub request_id: RequestId,
309 pub progress_token: Option<ProgressToken>,
311 pub client_caps: ClientCapabilities,
313 pub server_caps: ServerCapabilities,
315}
316
317impl ContextData {
318 #[must_use]
320 pub const fn new(
321 request_id: RequestId,
322 client_caps: ClientCapabilities,
323 server_caps: ServerCapabilities,
324 ) -> Self {
325 Self {
326 request_id,
327 progress_token: None,
328 client_caps,
329 server_caps,
330 }
331 }
332
333 #[must_use]
335 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
336 self.progress_token = Some(token);
337 self
338 }
339
340 #[must_use]
342 pub fn to_context<'a>(&'a self, peer: &'a dyn Peer) -> Context<'a> {
343 Context::new(
344 &self.request_id,
345 self.progress_token.as_ref(),
346 &self.client_caps,
347 &self.server_caps,
348 peer,
349 )
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Waker that ignores wake-ups; enough to drive a future by hand.
    struct IgnoreWake;

    impl std::task::Wake for IgnoreWake {
        fn wake(self: std::sync::Arc<Self>) {}
    }

    /// Polls `fut` exactly once with a waker that does nothing.
    fn poll_once<F>(fut: &mut F) -> std::task::Poll<F::Output>
    where
        F: std::future::Future + Unpin,
    {
        let waker = std::task::Waker::from(std::sync::Arc::new(IgnoreWake));
        let mut cx = std::task::Context::from_waker(&waker);
        std::future::Future::poll(std::pin::Pin::new(fut), &mut cx)
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clones_share_state() {
        let token = CancellationToken::default();
        let clone = token.clone();
        assert!(!clone.is_cancelled());
        token.cancel();
        assert!(clone.is_cancelled());
    }

    #[test]
    fn test_cancelled_future_resolves_after_cancel() {
        let token = CancellationToken::new();
        let mut fut = token.cancelled();
        assert!(poll_once(&mut fut).is_pending());
        token.cancel();
        assert!(poll_once(&mut fut).is_ready());
    }

    #[test]
    fn test_context_creation() {
        let request_id = RequestId::Number(1);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(&request_id, None, &client_caps, &server_caps, &peer);

        assert!(!ctx.is_cancelled());
        assert!(ctx.progress_token.is_none());
    }

    #[test]
    fn test_context_with_progress_token() {
        let request_id = RequestId::Number(1);
        let progress_token = ProgressToken::String("token".to_string());
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(
            &request_id,
            Some(&progress_token),
            &client_caps,
            &server_caps,
            &peer,
        );

        assert!(ctx.progress_token.is_some());
    }

    #[test]
    fn test_context_with_cancellation_observes_external_cancel() {
        let request_id = RequestId::Number(1);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;
        let token = CancellationToken::new();

        let ctx = Context::with_cancellation(
            &request_id,
            None,
            &client_caps,
            &server_caps,
            &peer,
            token.clone(),
        );
        assert!(!ctx.is_cancelled());
        assert!(format!("{ctx:?}").contains("is_cancelled: false"));

        token.cancel();
        assert!(ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_cancelled());
    }

    #[test]
    fn test_progress_with_noop_peer_completes() {
        let request_id = RequestId::Number(1);
        let progress_token = ProgressToken::String("token".to_string());
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        // Both the "no token" early return and the real send path must succeed.
        for token in [None, Some(&progress_token)] {
            let ctx = Context::new(&request_id, token, &client_caps, &server_caps, &peer);
            let mut fut = Box::pin(ctx.progress(1, Some(2), Some("halfway")));
            assert!(matches!(poll_once(&mut fut), std::task::Poll::Ready(Ok(()))));
        }
    }

    #[test]
    fn test_context_data() {
        let data = ContextData::new(
            RequestId::Number(42),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
        )
        .with_progress_token(ProgressToken::String("test".to_string()));

        let peer = NoOpPeer;
        let ctx = data.to_context(&peer);

        assert!(ctx.progress_token.is_some());
        assert!(!ctx.is_cancelled());
    }

    #[test]
    fn test_context_data_without_token() {
        let data = ContextData::new(
            RequestId::Number(7),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
        );

        let peer = NoOpPeer;
        let ctx = data.to_context(&peer);

        assert!(ctx.progress_token.is_none());
    }
}