1use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
43use mcpkit_core::error::McpError;
44use mcpkit_core::protocol::{Notification, ProgressToken, RequestId};
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::sync::Arc;
49
50pub trait Peer: Send + Sync {
55 fn notify(&self, notification: Notification) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>>;
57}
58
/// State shared by every clone of a [`CancellationToken`] and by the
/// [`CancelledFuture`]s created from them.
#[derive(Default)]
struct CancelState {
    /// Set once by [`CancellationToken::cancel`]; never reset.
    cancelled: AtomicBool,
    /// Wakers of tasks that are currently parked on a [`CancelledFuture`].
    ///
    /// Drained (and woken) by [`CancellationToken::cancel`]. A waker left behind by a
    /// future that is dropped before cancellation stays here until `cancel` or until
    /// the last token clone is dropped. Repeated polls from the same task are
    /// de-duplicated, so the list does not grow with poll count.
    waiters: std::sync::Mutex<Vec<std::task::Waker>>,
}

impl CancelState {
    /// Locks the waiter list.
    ///
    /// Poisoning is ignored because the list is only ever pushed to or taken
    /// wholesale, so it cannot be observed half-updated.
    fn lock_waiters(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.waiters
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// A cloneable, thread-safe cancellation signal.
///
/// All clones share one state. Cancelling any clone cancels them all and wakes
/// every pending [`CancelledFuture`] created from any clone.
#[derive(Clone)]
pub struct CancellationToken {
    state: Arc<CancelState>,
}

impl std::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `CancellationToken { cancelled: .. }` shape; the waiter list is an
        // internal detail.
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}

impl CancellationToken {
    /// Creates a token that has not been cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::new(CancelState::default()),
        }
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on any clone.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.state.cancelled.load(Ordering::SeqCst)
    }

    /// Marks the token (and all its clones) as cancelled and wakes every task
    /// currently waiting on [`cancelled`](Self::cancelled).
    ///
    /// Calling this more than once is harmless.
    pub fn cancel(&self) {
        // Publish the flag *before* draining the waiter list. `CancelledFuture::poll`
        // re-checks the flag while holding the same lock, so a task either sees the
        // flag or has its waker drained here. No wakeup can be lost.
        if self.state.cancelled.swap(true, Ordering::SeqCst) {
            return; // the first call already drained (or is draining) the waiters
        }
        // The guard is a temporary, so the lock is released before any waker runs.
        let waiters = std::mem::take(&mut *self.state.lock_waiters());
        for waker in waiters {
            waker.wake();
        }
    }

    /// Returns a future that resolves once this token is cancelled.
    ///
    /// The future registers the polling task's waker instead of busy-polling. The
    /// task is therefore only re-polled after [`cancel`](Self::cancel) is called.
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            state: Arc::clone(&self.state),
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

/// Future returned by [`CancellationToken::cancelled`]; completes when the token is
/// cancelled.
#[must_use = "futures do nothing unless polled"]
pub struct CancelledFuture {
    state: Arc<CancelState>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // Fast path: no locking once cancellation is visible.
        if self.state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }

        let mut waiters = self.state.lock_waiters();
        // Re-check under the lock. `cancel` stores the flag before taking this lock, so
        // if it has already drained the list we are guaranteed to observe `true` here.
        if self.state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }
        // Register for a wakeup (once per task) instead of waking ourselves
        // immediately. Self-waking would spin the executor until cancellation.
        if !waiters.iter().any(|w| w.will_wake(cx.waker())) {
            waiters.push(cx.waker().clone());
        }
        std::task::Poll::Pending
    }
}
127
/// Per-request context handed to request handlers.
///
/// Borrows the request metadata for `'a` and owns a [`CancellationToken`] for the
/// request. Notifications are sent through the borrowed [`Peer`].
pub struct Context<'a> {
    /// ID of the request being handled.
    pub request_id: &'a RequestId,
    /// Progress token attached to the request, if any.
    /// [`Context::progress`] returns `Ok(())` without sending anything when this is `None`.
    pub progress_token: Option<&'a ProgressToken>,
    /// Client capabilities available to the handler.
    pub client_caps: &'a ClientCapabilities,
    /// Server capabilities available to the handler.
    pub server_caps: &'a ServerCapabilities,
    /// Peer used by [`Context::notify`] to send notifications.
    peer: &'a dyn Peer,
    /// Cancellation signal for this request; queried via [`Context::is_cancelled`].
    cancel: CancellationToken,
}
151
152impl<'a> Context<'a> {
153 #[must_use]
155 pub fn new(
156 request_id: &'a RequestId,
157 progress_token: Option<&'a ProgressToken>,
158 client_caps: &'a ClientCapabilities,
159 server_caps: &'a ServerCapabilities,
160 peer: &'a dyn Peer,
161 ) -> Self {
162 Self {
163 request_id,
164 progress_token,
165 client_caps,
166 server_caps,
167 peer,
168 cancel: CancellationToken::new(),
169 }
170 }
171
172 #[must_use]
174 pub fn with_cancellation(
175 request_id: &'a RequestId,
176 progress_token: Option<&'a ProgressToken>,
177 client_caps: &'a ClientCapabilities,
178 server_caps: &'a ServerCapabilities,
179 peer: &'a dyn Peer,
180 cancel: CancellationToken,
181 ) -> Self {
182 Self {
183 request_id,
184 progress_token,
185 client_caps,
186 server_caps,
187 peer,
188 cancel,
189 }
190 }
191
192 #[must_use]
194 pub fn is_cancelled(&self) -> bool {
195 self.cancel.is_cancelled()
196 }
197
198 pub fn cancelled(&self) -> impl Future<Output = ()> + '_ {
200 self.cancel.cancelled()
201 }
202
203 #[must_use]
205 pub fn cancellation_token(&self) -> &CancellationToken {
206 &self.cancel
207 }
208
209 pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<(), McpError> {
220 let notification = if let Some(p) = params {
221 Notification::with_params(method.to_string(), p)
222 } else {
223 Notification::new(method.to_string())
224 };
225 self.peer.notify(notification).await
226 }
227
228 pub async fn progress(
243 &self,
244 current: u64,
245 total: Option<u64>,
246 message: Option<&str>,
247 ) -> Result<(), McpError> {
248 let Some(token) = self.progress_token else {
249 return Ok(());
251 };
252
253 let params = serde_json::json!({
254 "progressToken": token,
255 "progress": current,
256 "total": total,
257 "message": message,
258 });
259
260 self.notify("notifications/progress", Some(params)).await
261 }
262}
263
264impl std::fmt::Debug for Context<'_> {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 f.debug_struct("Context")
267 .field("request_id", &self.request_id)
268 .field("progress_token", &self.progress_token)
269 .field("client_caps", &self.client_caps)
270 .field("server_caps", &self.server_caps)
271 .field("is_cancelled", &self.is_cancelled())
272 .finish()
273 }
274}
275
276#[derive(Debug, Clone, Copy)]
280pub struct NoOpPeer;
281
282impl Peer for NoOpPeer {
283 fn notify(&self, _notification: Notification) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
284 Box::pin(async { Ok(()) })
285 }
286}
287
288pub struct ContextData {
293 pub request_id: RequestId,
295 pub progress_token: Option<ProgressToken>,
297 pub client_caps: ClientCapabilities,
299 pub server_caps: ServerCapabilities,
301}
302
303impl ContextData {
304 #[must_use]
306 pub fn new(
307 request_id: RequestId,
308 client_caps: ClientCapabilities,
309 server_caps: ServerCapabilities,
310 ) -> Self {
311 Self {
312 request_id,
313 progress_token: None,
314 client_caps,
315 server_caps,
316 }
317 }
318
319 #[must_use]
321 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
322 self.progress_token = Some(token);
323 self
324 }
325
326 #[must_use]
328 pub fn to_context<'a>(&'a self, peer: &'a dyn Peer) -> Context<'a> {
329 Context::new(
330 &self.request_id,
331 self.progress_token.as_ref(),
332 &self.client_caps,
333 &self.server_caps,
334 peer,
335 )
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Default client/server capabilities shared by the context tests.
    fn default_caps() -> (ClientCapabilities, ServerCapabilities) {
        (ClientCapabilities::default(), ServerCapabilities::default())
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());

        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_context_creation() {
        let request_id = RequestId::Number(1);
        let (client_caps, server_caps) = default_caps();

        let ctx = Context::new(&request_id, None, &client_caps, &server_caps, &NoOpPeer);

        assert!(!ctx.is_cancelled());
        assert!(ctx.progress_token.is_none());
    }

    #[test]
    fn test_context_with_progress_token() {
        let request_id = RequestId::Number(1);
        let progress_token = ProgressToken::String(String::from("token"));
        let (client_caps, server_caps) = default_caps();

        let ctx = Context::new(
            &request_id,
            Some(&progress_token),
            &client_caps,
            &server_caps,
            &NoOpPeer,
        );

        assert!(ctx.progress_token.is_some());
    }

    #[test]
    fn test_context_data() {
        let (client_caps, server_caps) = default_caps();
        let data = ContextData::new(RequestId::Number(42), client_caps, server_caps)
            .with_progress_token(ProgressToken::String(String::from("test")));

        let ctx = data.to_context(&NoOpPeer);

        assert!(ctx.progress_token.is_some());
    }
}