mcpkit_server/
context.rs

1//! Request context for MCP handlers.
2//!
3//! The context provides access to the current request state and allows
4//! handlers to interact with the connection (sending notifications,
5//! progress updates, etc.).
6//!
7//! # Key Features
8//!
9//! - **Borrowing-friendly**: Uses lifetime references, NO `'static` requirement
10//! - **Progress reporting**: Send progress updates for long-running operations
11//! - **Cancellation**: Check if the request has been cancelled
12//! - **Notifications**: Send notifications back to the client via Peer trait
13//!
14//! # Example
15//!
16//! ```rust
17//! use mcpkit_server::{Context, NoOpPeer, ContextData};
18//! use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
19//! use mcpkit_core::protocol::RequestId;
20//!
21//! // Create test context data
22//! let data = ContextData::new(
23//!     RequestId::Number(1),
24//!     ClientCapabilities::default(),
25//!     ServerCapabilities::default(),
26//! );
27//! let peer = NoOpPeer;
28//!
29//! // Create a context from the data
30//! let ctx = Context::new(
31//!     &data.request_id,
32//!     data.progress_token.as_ref(),
33//!     &data.client_caps,
34//!     &data.server_caps,
35//!     &peer,
36//! );
37//!
38//! // Check for cancellation
39//! assert!(!ctx.is_cancelled());
40//! ```
41
42use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
43use mcpkit_core::error::McpError;
44use mcpkit_core::protocol::{Notification, ProgressToken, RequestId};
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::sync::Arc;
49
50/// Trait for sending messages to the peer (client or server).
51///
52/// This trait abstracts over the transport layer, allowing the context
53/// to send notifications without knowing the underlying transport.
54pub trait Peer: Send + Sync {
55    /// Send a notification to the peer.
56    fn notify(&self, notification: Notification) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>>;
57}
58
/// State shared by a [`CancellationToken`], its clones and the
/// [`CancelledFuture`]s created from it.
#[derive(Debug, Default)]
struct CancelState {
    /// Becomes `true` once cancellation has been requested; never reset.
    cancelled: AtomicBool,
    /// Wakers of tasks that polled a [`CancelledFuture`] before cancellation.
    wakers: std::sync::Mutex<Vec<std::task::Waker>>,
}

impl CancelState {
    /// Lock the waker list, recovering the guard if the mutex was poisoned.
    ///
    /// The list holds plain wakers, so a panic in another thread cannot leave
    /// it logically inconsistent and poisoning can be ignored.
    fn lock_wakers(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

/// A cancellation token for tracking request cancellation.
///
/// Cloning the token is cheap and every clone observes the same state:
/// calling [`cancel`](Self::cancel) on one clone is visible through all of
/// them and wakes every task awaiting [`cancelled`](Self::cancelled).
#[derive(Debug, Clone)]
pub struct CancellationToken {
    state: Arc<CancelState>,
}

impl CancellationToken {
    /// Create a new token that has not been cancelled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::new(CancelState::default()),
        }
    }

    /// Check if cancellation has been requested.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.state.cancelled.load(Ordering::SeqCst)
    }

    /// Request cancellation and wake every task waiting on
    /// [`cancelled`](Self::cancelled).
    ///
    /// Calling this more than once is harmless.
    pub fn cancel(&self) {
        // Publish the flag first so that a poll racing with us either sees the
        // flag under the lock or has its waker drained below.
        self.state.cancelled.store(true, Ordering::SeqCst);

        // Move the wakers out while holding the lock, but wake them after the
        // guard is released so a waker that re-polls inline cannot deadlock.
        let parked = std::mem::take(&mut *self.state.lock_wakers());
        for waker in parked {
            waker.wake();
        }
    }

    /// Wait for cancellation.
    ///
    /// Returns a future that completes once cancellation has been requested.
    /// While pending, the future registers the task's waker with the token
    /// instead of re-waking itself, so awaiting it does not spin the executor.
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            state: Arc::clone(&self.state),
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

/// A future that completes when cancellation is requested.
///
/// Created by [`CancellationToken::cancelled`].
#[derive(Debug)]
pub struct CancelledFuture {
    state: Arc<CancelState>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
        // Fast path: already cancelled, no need to touch the lock.
        if self.state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }

        let mut wakers = self.state.lock_wakers();

        // Re-check under the lock: `cancel` sets the flag before it drains the
        // list, so if the flag is still clear here our waker will be drained.
        if self.state.cancelled.load(Ordering::SeqCst) {
            return std::task::Poll::Ready(());
        }

        // Repeated polls from the same task must not grow the list.
        if !wakers.iter().any(|parked| parked.will_wake(cx.waker())) {
            wakers.push(cx.waker().clone());
        }
        std::task::Poll::Pending
    }
}
127
128/// Request context passed to handler methods.
129///
130/// The context uses lifetime references to avoid `'static` requirements
131/// and Arc overhead. This enables:
132/// - Single-threaded async without Arc overhead
133/// - `!Send` types in handlers (important for some runtimes)
134/// - Users who need spawning can wrap in Arc themselves
135///
136/// Per the plan: "Request context - passed by reference, NO 'static requirement"
137pub struct Context<'a> {
138    /// The request ID for this operation.
139    pub request_id: &'a RequestId,
140    /// Optional progress token for reporting progress.
141    pub progress_token: Option<&'a ProgressToken>,
142    /// Client capabilities negotiated during initialization.
143    pub client_caps: &'a ClientCapabilities,
144    /// Server capabilities advertised during initialization.
145    pub server_caps: &'a ServerCapabilities,
146    /// Peer for sending notifications.
147    peer: &'a dyn Peer,
148    /// Cancellation token for this request.
149    cancel: CancellationToken,
150}
151
152impl<'a> Context<'a> {
153    /// Create a new context with all required references.
154    #[must_use]
155    pub fn new(
156        request_id: &'a RequestId,
157        progress_token: Option<&'a ProgressToken>,
158        client_caps: &'a ClientCapabilities,
159        server_caps: &'a ServerCapabilities,
160        peer: &'a dyn Peer,
161    ) -> Self {
162        Self {
163            request_id,
164            progress_token,
165            client_caps,
166            server_caps,
167            peer,
168            cancel: CancellationToken::new(),
169        }
170    }
171
172    /// Create a new context with a custom cancellation token.
173    #[must_use]
174    pub fn with_cancellation(
175        request_id: &'a RequestId,
176        progress_token: Option<&'a ProgressToken>,
177        client_caps: &'a ClientCapabilities,
178        server_caps: &'a ServerCapabilities,
179        peer: &'a dyn Peer,
180        cancel: CancellationToken,
181    ) -> Self {
182        Self {
183            request_id,
184            progress_token,
185            client_caps,
186            server_caps,
187            peer,
188            cancel,
189        }
190    }
191
192    /// Check if the request has been cancelled.
193    #[must_use]
194    pub fn is_cancelled(&self) -> bool {
195        self.cancel.is_cancelled()
196    }
197
198    /// Get a future that completes when the request is cancelled.
199    pub fn cancelled(&self) -> impl Future<Output = ()> + '_ {
200        self.cancel.cancelled()
201    }
202
203    /// Get the cancellation token for this context.
204    #[must_use]
205    pub fn cancellation_token(&self) -> &CancellationToken {
206        &self.cancel
207    }
208
209    /// Send a notification to the client.
210    ///
211    /// # Arguments
212    ///
213    /// * `method` - The notification method name
214    /// * `params` - Optional notification parameters
215    ///
216    /// # Errors
217    ///
218    /// Returns an error if the notification could not be sent.
219    pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<(), McpError> {
220        let notification = if let Some(p) = params {
221            Notification::with_params(method.to_string(), p)
222        } else {
223            Notification::new(method.to_string())
224        };
225        self.peer.notify(notification).await
226    }
227
228    /// Report progress for this operation.
229    ///
230    /// This sends a progress notification to the client if a progress token
231    /// was provided with the request.
232    ///
233    /// # Arguments
234    ///
235    /// * `current` - Current progress value
236    /// * `total` - Total progress value (if known)
237    /// * `message` - Optional progress message
238    ///
239    /// # Errors
240    ///
241    /// Returns an error if the notification could not be sent.
242    pub async fn progress(
243        &self,
244        current: u64,
245        total: Option<u64>,
246        message: Option<&str>,
247    ) -> Result<(), McpError> {
248        let Some(token) = self.progress_token else {
249            // No progress token, silently succeed
250            return Ok(());
251        };
252
253        let params = serde_json::json!({
254            "progressToken": token,
255            "progress": current,
256            "total": total,
257            "message": message,
258        });
259
260        self.notify("notifications/progress", Some(params)).await
261    }
262}
263
264impl std::fmt::Debug for Context<'_> {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        f.debug_struct("Context")
267            .field("request_id", &self.request_id)
268            .field("progress_token", &self.progress_token)
269            .field("client_caps", &self.client_caps)
270            .field("server_caps", &self.server_caps)
271            .field("is_cancelled", &self.is_cancelled())
272            .finish()
273    }
274}
275
276/// A no-op peer implementation for testing.
277///
278/// This peer silently accepts all notifications without sending them anywhere.
279#[derive(Debug, Clone, Copy)]
280pub struct NoOpPeer;
281
282impl Peer for NoOpPeer {
283    fn notify(&self, _notification: Notification) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
284        Box::pin(async { Ok(()) })
285    }
286}
287
288/// Owned data for creating contexts.
289///
290/// This struct holds owned copies of all the data needed to create a Context.
291/// It's useful when you need to create contexts from owned data.
292pub struct ContextData {
293    /// The request ID.
294    pub request_id: RequestId,
295    /// Optional progress token.
296    pub progress_token: Option<ProgressToken>,
297    /// Client capabilities.
298    pub client_caps: ClientCapabilities,
299    /// Server capabilities.
300    pub server_caps: ServerCapabilities,
301}
302
303impl ContextData {
304    /// Create a new context data struct.
305    #[must_use]
306    pub fn new(
307        request_id: RequestId,
308        client_caps: ClientCapabilities,
309        server_caps: ServerCapabilities,
310    ) -> Self {
311        Self {
312            request_id,
313            progress_token: None,
314            client_caps,
315            server_caps,
316        }
317    }
318
319    /// Set the progress token.
320    #[must_use]
321    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
322        self.progress_token = Some(token);
323        self
324    }
325
326    /// Create a context from this data with the given peer.
327    #[must_use]
328    pub fn to_context<'a>(&'a self, peer: &'a dyn Peer) -> Context<'a> {
329        Context::new(
330            &self.request_id,
331            self.progress_token.as_ref(),
332            &self.client_caps,
333            &self.server_caps,
334            peer,
335        )
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// A token starts out clear and reports cancellation after `cancel`.
    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert_eq!(token.is_cancelled(), false);

        token.cancel();
        assert_eq!(token.is_cancelled(), true);
    }

    /// A context built without a progress token exposes none and is not
    /// cancelled.
    #[test]
    fn test_context_creation() {
        let id = RequestId::Number(1);
        let client = ClientCapabilities::default();
        let server = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(&id, None, &client, &server, &peer);

        assert!(!ctx.is_cancelled());
        assert!(matches!(ctx.progress_token, None));
    }

    /// A progress token passed to `Context::new` is visible on the context.
    #[test]
    fn test_context_with_progress_token() {
        let id = RequestId::Number(1);
        let token = ProgressToken::String("token".to_string());
        let client = ClientCapabilities::default();
        let server = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(&id, Some(&token), &client, &server, &peer);

        assert!(matches!(ctx.progress_token, Some(_)));
    }

    /// `ContextData::to_context` carries the stored progress token over.
    #[test]
    fn test_context_data() {
        let token = ProgressToken::String("test".to_string());
        let data = ContextData::new(
            RequestId::Number(42),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
        )
        .with_progress_token(token);

        let peer = NoOpPeer;
        let ctx = data.to_context(&peer);

        assert!(matches!(ctx.progress_token, Some(_)));
    }
}