mcpkit_server/
context.rs

1//! Request context for MCP handlers.
2//!
3//! The context provides access to the current request state and allows
4//! handlers to interact with the connection (sending notifications,
5//! progress updates, etc.).
6//!
7//! # Key Features
8//!
9//! - **Borrowing-friendly**: Uses lifetime references, NO `'static` requirement
10//! - **Progress reporting**: Send progress updates for long-running operations
11//! - **Cancellation**: Check if the request has been cancelled
12//! - **Notifications**: Send notifications back to the client via Peer trait
13//!
14//! # Example
15//!
16//! ```rust
17//! use mcpkit_server::{Context, NoOpPeer, ContextData};
18//! use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
19//! use mcpkit_core::protocol::RequestId;
20//!
21//! // Create test context data
22//! let data = ContextData::new(
23//!     RequestId::Number(1),
24//!     ClientCapabilities::default(),
25//!     ServerCapabilities::default(),
26//! );
27//! let peer = NoOpPeer;
28//!
29//! // Create a context from the data
30//! let ctx = Context::new(
31//!     &data.request_id,
32//!     data.progress_token.as_ref(),
33//!     &data.client_caps,
34//!     &data.server_caps,
35//!     &peer,
36//! );
37//!
38//! // Check for cancellation
39//! assert!(!ctx.is_cancelled());
40//! ```
41
42use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
43use mcpkit_core::error::McpError;
44use mcpkit_core::protocol::{Notification, ProgressToken, RequestId};
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::sync::Arc;
49
50/// Trait for sending messages to the peer (client or server).
51///
52/// This trait abstracts over the transport layer, allowing the context
53/// to send notifications without knowing the underlying transport.
54pub trait Peer: Send + Sync {
55    /// Send a notification to the peer.
56    fn notify(
57        &self,
58        notification: Notification,
59    ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>>;
60}
61
/// Shared state behind a [`CancellationToken`] and its [`CancelledFuture`]s.
#[derive(Debug, Default)]
struct CancelState {
    /// Set once, by the first call to [`CancellationToken::cancel`].
    cancelled: AtomicBool,
    /// Wakers of tasks currently parked in [`CancelledFuture::poll`].
    wakers: std::sync::Mutex<Vec<std::task::Waker>>,
}

impl CancelState {
    /// Lock the waker list, recovering from a poisoned mutex.
    ///
    /// The list is only pushed to or drained while locked, so a panic in
    /// another lock holder cannot leave it logically inconsistent.
    fn lock_wakers(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// A cancellation token for tracking request cancellation.
///
/// Clones share the same underlying state (behind an [`Arc`]), so cancelling
/// any clone is observed by all of them and by every [`CancelledFuture`]
/// created from them. Cancellation is one-way: once set it is never cleared.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    state: Arc<CancelState>,
}

impl CancellationToken {
    /// Create a new, not-yet-cancelled token.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::new(CancelState::default()),
        }
    }

    /// Check if cancellation has been requested.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.state.cancelled.load(Ordering::SeqCst)
    }

    /// Request cancellation.
    ///
    /// Wakes every task currently waiting on a [`CancelledFuture`] for this
    /// token. Calling it again after the first time is a no-op.
    pub fn cancel(&self) {
        // Only the first caller flips the flag and drains the wakers.
        if self.state.cancelled.swap(true, Ordering::SeqCst) {
            return;
        }
        // Take the wakers and release the lock before waking, so a waker that
        // polls inline cannot deadlock on `lock_wakers`.
        let wakers = std::mem::take(&mut *self.state.lock_wakers());
        for waker in wakers {
            waker.wake();
        }
    }

    /// Wait for cancellation.
    ///
    /// Returns a future that completes when cancellation is requested. The
    /// future registers the polling task's waker and is woken by
    /// [`CancellationToken::cancel`]; it does not busy-poll.
    #[must_use]
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            state: Arc::clone(&self.state),
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

/// A future that completes when cancellation is requested.
///
/// While pending, it keeps its task's waker registered with the token.
/// A waker from a future dropped before cancellation stays registered until
/// the token is cancelled or the shared state is dropped.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct CancelledFuture {
    state: Arc<CancelState>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // Fast path: once cancelled, no lock is needed.
        if self.state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }

        let mut wakers = self.state.lock_wakers();
        // Re-check under the lock. `cancel` sets the flag before taking this
        // lock, so either we see the flag here or `cancel` sees our waker.
        if self.state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }
        // Avoid growing the list when the same task re-polls.
        if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
            wakers.push(cx.waker().clone());
        }
        std::task::Poll::Pending
    }
}
134
135/// Request context passed to handler methods.
136///
137/// The context uses lifetime references to avoid `'static` requirements
138/// and Arc overhead. This enables:
139/// - Single-threaded async without Arc overhead
140/// - `!Send` types in handlers (important for some runtimes)
141/// - Users who need spawning can wrap in Arc themselves
142///
143/// Per the plan: "Request context - passed by reference, NO 'static requirement"
144pub struct Context<'a> {
145    /// The request ID for this operation.
146    pub request_id: &'a RequestId,
147    /// Optional progress token for reporting progress.
148    pub progress_token: Option<&'a ProgressToken>,
149    /// Client capabilities negotiated during initialization.
150    pub client_caps: &'a ClientCapabilities,
151    /// Server capabilities advertised during initialization.
152    pub server_caps: &'a ServerCapabilities,
153    /// Peer for sending notifications.
154    peer: &'a dyn Peer,
155    /// Cancellation token for this request.
156    cancel: CancellationToken,
157}
158
159impl<'a> Context<'a> {
160    /// Create a new context with all required references.
161    #[must_use]
162    pub fn new(
163        request_id: &'a RequestId,
164        progress_token: Option<&'a ProgressToken>,
165        client_caps: &'a ClientCapabilities,
166        server_caps: &'a ServerCapabilities,
167        peer: &'a dyn Peer,
168    ) -> Self {
169        Self {
170            request_id,
171            progress_token,
172            client_caps,
173            server_caps,
174            peer,
175            cancel: CancellationToken::new(),
176        }
177    }
178
179    /// Create a new context with a custom cancellation token.
180    #[must_use]
181    pub fn with_cancellation(
182        request_id: &'a RequestId,
183        progress_token: Option<&'a ProgressToken>,
184        client_caps: &'a ClientCapabilities,
185        server_caps: &'a ServerCapabilities,
186        peer: &'a dyn Peer,
187        cancel: CancellationToken,
188    ) -> Self {
189        Self {
190            request_id,
191            progress_token,
192            client_caps,
193            server_caps,
194            peer,
195            cancel,
196        }
197    }
198
199    /// Check if the request has been cancelled.
200    #[must_use]
201    pub fn is_cancelled(&self) -> bool {
202        self.cancel.is_cancelled()
203    }
204
205    /// Get a future that completes when the request is cancelled.
206    pub fn cancelled(&self) -> impl Future<Output = ()> + '_ {
207        self.cancel.cancelled()
208    }
209
210    /// Get the cancellation token for this context.
211    #[must_use]
212    pub const fn cancellation_token(&self) -> &CancellationToken {
213        &self.cancel
214    }
215
216    /// Send a notification to the client.
217    ///
218    /// # Arguments
219    ///
220    /// * `method` - The notification method name
221    /// * `params` - Optional notification parameters
222    ///
223    /// # Errors
224    ///
225    /// Returns an error if the notification could not be sent.
226    pub async fn notify(
227        &self,
228        method: &str,
229        params: Option<serde_json::Value>,
230    ) -> Result<(), McpError> {
231        let notification = if let Some(p) = params {
232            Notification::with_params(method.to_string(), p)
233        } else {
234            Notification::new(method.to_string())
235        };
236        self.peer.notify(notification).await
237    }
238
239    /// Report progress for this operation.
240    ///
241    /// This sends a progress notification to the client if a progress token
242    /// was provided with the request.
243    ///
244    /// # Arguments
245    ///
246    /// * `current` - Current progress value
247    /// * `total` - Total progress value (if known)
248    /// * `message` - Optional progress message
249    ///
250    /// # Errors
251    ///
252    /// Returns an error if the notification could not be sent.
253    pub async fn progress(
254        &self,
255        current: u64,
256        total: Option<u64>,
257        message: Option<&str>,
258    ) -> Result<(), McpError> {
259        let Some(token) = self.progress_token else {
260            // No progress token, silently succeed
261            return Ok(());
262        };
263
264        let params = serde_json::json!({
265            "progressToken": token,
266            "progress": current,
267            "total": total,
268            "message": message,
269        });
270
271        self.notify("notifications/progress", Some(params)).await
272    }
273}
274
275impl std::fmt::Debug for Context<'_> {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        f.debug_struct("Context")
278            .field("request_id", &self.request_id)
279            .field("progress_token", &self.progress_token)
280            .field("client_caps", &self.client_caps)
281            .field("server_caps", &self.server_caps)
282            .field("is_cancelled", &self.is_cancelled())
283            .finish()
284    }
285}
286
287/// A no-op peer implementation for testing.
288///
289/// This peer silently accepts all notifications without sending them anywhere.
290#[derive(Debug, Clone, Copy)]
291pub struct NoOpPeer;
292
293impl Peer for NoOpPeer {
294    fn notify(
295        &self,
296        _notification: Notification,
297    ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
298        Box::pin(async { Ok(()) })
299    }
300}
301
302/// Owned data for creating contexts.
303///
304/// This struct holds owned copies of all the data needed to create a Context.
305/// It's useful when you need to create contexts from owned data.
306pub struct ContextData {
307    /// The request ID.
308    pub request_id: RequestId,
309    /// Optional progress token.
310    pub progress_token: Option<ProgressToken>,
311    /// Client capabilities.
312    pub client_caps: ClientCapabilities,
313    /// Server capabilities.
314    pub server_caps: ServerCapabilities,
315}
316
317impl ContextData {
318    /// Create a new context data struct.
319    #[must_use]
320    pub const fn new(
321        request_id: RequestId,
322        client_caps: ClientCapabilities,
323        server_caps: ServerCapabilities,
324    ) -> Self {
325        Self {
326            request_id,
327            progress_token: None,
328            client_caps,
329            server_caps,
330        }
331    }
332
333    /// Set the progress token.
334    #[must_use]
335    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
336        self.progress_token = Some(token);
337        self
338    }
339
340    /// Create a context from this data with the given peer.
341    #[must_use]
342    pub fn to_context<'a>(&'a self, peer: &'a dyn Peer) -> Context<'a> {
343        Context::new(
344            &self.request_id,
345            self.progress_token.as_ref(),
346            &self.client_caps,
347            &self.server_caps,
348            peer,
349        )
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// Default client/server capabilities for tests that don't inspect them.
    fn default_caps() -> (ClientCapabilities, ServerCapabilities) {
        (ClientCapabilities::default(), ServerCapabilities::default())
    }

    #[test]
    fn test_cancellation_token() {
        let tok = CancellationToken::new();
        assert!(!tok.is_cancelled(), "a fresh token must not start cancelled");
        tok.cancel();
        assert!(tok.is_cancelled(), "cancel() must be observable immediately");
    }

    #[test]
    fn test_context_creation() {
        let id = RequestId::Number(1);
        let (client, server) = default_caps();

        let ctx = Context::new(&id, None, &client, &server, &NoOpPeer);

        assert!(!ctx.is_cancelled(), "a new context must not be cancelled");
        assert!(ctx.progress_token.is_none(), "no progress token was given");
    }

    #[test]
    fn test_context_with_progress_token() {
        let id = RequestId::Number(1);
        let token = ProgressToken::String("token".to_string());
        let (client, server) = default_caps();

        let ctx = Context::new(&id, Some(&token), &client, &server, &NoOpPeer);

        assert!(ctx.progress_token.is_some(), "progress token must be kept");
    }

    #[test]
    fn test_context_data() {
        let (client, server) = default_caps();
        let data = ContextData::new(RequestId::Number(42), client, server)
            .with_progress_token(ProgressToken::String("test".to_string()));

        let ctx = data.to_context(&NoOpPeer);

        assert!(
            ctx.progress_token.is_some(),
            "to_context must forward the progress token"
        );
    }
}