mcpkit_server/
context.rs

1//! Request context for MCP handlers.
2//!
3//! The context provides access to the current request state and allows
4//! handlers to interact with the connection (sending notifications,
5//! progress updates, etc.).
6//!
7//! # Key Features
8//!
9//! - **Borrowing-friendly**: Uses lifetime references, NO `'static` requirement
10//! - **Progress reporting**: Send progress updates for long-running operations
11//! - **Cancellation**: Check if the request has been cancelled
12//! - **Notifications**: Send notifications back to the client via Peer trait
13//!
14//! # Example
15//!
16//! ```rust
17//! use mcpkit_server::{Context, NoOpPeer, ContextData};
18//! use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
19//! use mcpkit_core::protocol::RequestId;
20//! use mcpkit_core::protocol_version::ProtocolVersion;
21//!
22//! // Create test context data
23//! let data = ContextData::new(
24//!     RequestId::Number(1),
25//!     ClientCapabilities::default(),
26//!     ServerCapabilities::default(),
27//!     ProtocolVersion::LATEST,
28//! );
29//! let peer = NoOpPeer;
30//!
31//! // Create a context from the data
32//! let ctx = Context::new(
33//!     &data.request_id,
34//!     data.progress_token.as_ref(),
35//!     &data.client_caps,
36//!     &data.server_caps,
37//!     data.protocol_version,
38//!     &peer,
39//! );
40//!
41//! // Check for cancellation and protocol version
42//! assert!(!ctx.is_cancelled());
43//! assert!(ctx.protocol_version.supports_tasks());
44//! ```
45
46use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
47use mcpkit_core::error::McpError;
48use mcpkit_core::protocol::{Notification, ProgressToken, RequestId};
49use mcpkit_core::protocol_version::ProtocolVersion;
50use std::future::Future;
51use std::pin::Pin;
52use std::sync::atomic::{AtomicBool, Ordering};
53use std::sync::Arc;
54
55/// Trait for sending messages to the peer (client or server).
56///
57/// This trait abstracts over the transport layer, allowing the context
58/// to send notifications without knowing the underlying transport.
59pub trait Peer: Send + Sync {
60    /// Send a notification to the peer.
61    fn notify(
62        &self,
63        notification: Notification,
64    ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>>;
65}
66
/// A cancellation token for tracking request cancellation.
///
/// Clones share the same underlying state, so cancelling any clone is
/// observed by all of them. Besides the synchronous [`is_cancelled`] check,
/// [`cancelled`] hands out a future that is woken by [`cancel`] rather than
/// re-polling itself in a loop.
///
/// [`is_cancelled`]: CancellationToken::is_cancelled
/// [`cancelled`]: CancellationToken::cancelled
/// [`cancel`]: CancellationToken::cancel
#[derive(Clone)]
pub struct CancellationToken {
    state: Arc<CancelState>,
}

/// State shared by a token, all of its clones, and the futures it hands out.
#[derive(Debug, Default)]
struct CancelState {
    /// Set once cancellation has been requested; nothing ever resets it.
    cancelled: AtomicBool,
    /// Wakers of tasks currently waiting in a [`CancelledFuture`].
    wakers: std::sync::Mutex<Vec<std::task::Waker>>,
}

impl CancelState {
    /// Read the cancellation flag.
    fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Lock the waker list.
    ///
    /// Poisoning is ignored: the list is a plain `Vec` that stays valid even if
    /// a holder panicked, and cancellation must keep working afterwards.
    fn lock_wakers(&self) -> std::sync::MutexGuard<'_, Vec<std::task::Waker>> {
        self.wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl CancellationToken {
    /// Create a new, not-yet-cancelled token.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::new(CancelState::default()),
        }
    }

    /// Check if cancellation has been requested.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.state.is_cancelled()
    }

    /// Request cancellation and wake every task waiting on [`Self::cancelled`].
    ///
    /// Calling this more than once is harmless.
    pub fn cancel(&self) {
        // Publish the flag before draining: a concurrent `poll` either sees the
        // flag or has already pushed its waker into the list drained below.
        self.state.cancelled.store(true, Ordering::SeqCst);
        // Move the wakers out so the lock is released before running waker code.
        let waiting = std::mem::take(&mut *self.state.lock_wakers());
        for waker in waiting {
            waker.wake();
        }
    }

    /// Wait for cancellation.
    ///
    /// Returns a future that completes when cancellation is requested. While
    /// pending, the future registers the polling task's waker with this token
    /// and is woken by [`Self::cancel`]; it does not spin.
    #[must_use]
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            state: Arc::clone(&self.state),
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for CancellationToken {
    /// Formats as `CancellationToken { cancelled: <bool> }`; the waker list is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}

/// A future that completes when cancellation is requested.
///
/// Created by [`CancellationToken::cancelled`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CancelledFuture {
    state: Arc<CancelState>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let state = &self.state;
        // Fast path: no locking once cancellation has happened.
        if state.is_cancelled() {
            return std::task::Poll::Ready(());
        }

        let mut wakers = state.lock_wakers();
        // Re-check under the lock. `cancel` stores the flag before taking this
        // lock, so if we still read `false` here our waker will be drained.
        if state.is_cancelled() {
            return std::task::Poll::Ready(());
        }
        // Avoid piling up duplicates when the same task polls repeatedly.
        if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
            wakers.push(cx.waker().clone());
        }
        std::task::Poll::Pending
    }
}
139
140/// Request context passed to handler methods.
141///
142/// The context uses lifetime references to avoid `'static` requirements
143/// and Arc overhead. This enables:
144/// - Single-threaded async without Arc overhead
145/// - `!Send` types in handlers (important for some runtimes)
146/// - Users who need spawning can wrap in Arc themselves
147///
148/// Per the plan: "Request context - passed by reference, NO 'static requirement"
149pub struct Context<'a> {
150    /// The request ID for this operation.
151    pub request_id: &'a RequestId,
152    /// Optional progress token for reporting progress.
153    pub progress_token: Option<&'a ProgressToken>,
154    /// Client capabilities negotiated during initialization.
155    pub client_caps: &'a ClientCapabilities,
156    /// Server capabilities advertised during initialization.
157    pub server_caps: &'a ServerCapabilities,
158    /// The negotiated protocol version.
159    ///
160    /// Use this to check version-specific feature availability via
161    /// methods like `supports_tasks()`, `supports_elicitation()`, etc.
162    pub protocol_version: ProtocolVersion,
163    /// Peer for sending notifications.
164    peer: &'a dyn Peer,
165    /// Cancellation token for this request.
166    cancel: CancellationToken,
167}
168
169impl<'a> Context<'a> {
170    /// Create a new context with all required references.
171    #[must_use]
172    pub fn new(
173        request_id: &'a RequestId,
174        progress_token: Option<&'a ProgressToken>,
175        client_caps: &'a ClientCapabilities,
176        server_caps: &'a ServerCapabilities,
177        protocol_version: ProtocolVersion,
178        peer: &'a dyn Peer,
179    ) -> Self {
180        Self {
181            request_id,
182            progress_token,
183            client_caps,
184            server_caps,
185            protocol_version,
186            peer,
187            cancel: CancellationToken::new(),
188        }
189    }
190
191    /// Create a new context with a custom cancellation token.
192    #[must_use]
193    pub fn with_cancellation(
194        request_id: &'a RequestId,
195        progress_token: Option<&'a ProgressToken>,
196        client_caps: &'a ClientCapabilities,
197        server_caps: &'a ServerCapabilities,
198        protocol_version: ProtocolVersion,
199        peer: &'a dyn Peer,
200        cancel: CancellationToken,
201    ) -> Self {
202        Self {
203            request_id,
204            progress_token,
205            client_caps,
206            server_caps,
207            protocol_version,
208            peer,
209            cancel,
210        }
211    }
212
213    /// Check if the request has been cancelled.
214    #[must_use]
215    pub fn is_cancelled(&self) -> bool {
216        self.cancel.is_cancelled()
217    }
218
219    /// Get a future that completes when the request is cancelled.
220    pub fn cancelled(&self) -> impl Future<Output = ()> + '_ {
221        self.cancel.cancelled()
222    }
223
224    /// Get the cancellation token for this context.
225    #[must_use]
226    pub const fn cancellation_token(&self) -> &CancellationToken {
227        &self.cancel
228    }
229
230    /// Send a notification to the client.
231    ///
232    /// # Arguments
233    ///
234    /// * `method` - The notification method name
235    /// * `params` - Optional notification parameters
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if the notification could not be sent.
240    pub async fn notify(
241        &self,
242        method: &str,
243        params: Option<serde_json::Value>,
244    ) -> Result<(), McpError> {
245        let notification = if let Some(p) = params {
246            Notification::with_params(method.to_string(), p)
247        } else {
248            Notification::new(method.to_string())
249        };
250        self.peer.notify(notification).await
251    }
252
253    /// Report progress for this operation.
254    ///
255    /// This sends a progress notification to the client if a progress token
256    /// was provided with the request.
257    ///
258    /// # Arguments
259    ///
260    /// * `current` - Current progress value
261    /// * `total` - Total progress value (if known)
262    /// * `message` - Optional progress message
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if the notification could not be sent.
267    pub async fn progress(
268        &self,
269        current: u64,
270        total: Option<u64>,
271        message: Option<&str>,
272    ) -> Result<(), McpError> {
273        let Some(token) = self.progress_token else {
274            // No progress token, silently succeed
275            return Ok(());
276        };
277
278        let params = serde_json::json!({
279            "progressToken": token,
280            "progress": current,
281            "total": total,
282            "message": message,
283        });
284
285        self.notify("notifications/progress", Some(params)).await
286    }
287}
288
289impl std::fmt::Debug for Context<'_> {
290    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291        f.debug_struct("Context")
292            .field("request_id", &self.request_id)
293            .field("progress_token", &self.progress_token)
294            .field("client_caps", &self.client_caps)
295            .field("server_caps", &self.server_caps)
296            .field("protocol_version", &self.protocol_version)
297            .field("is_cancelled", &self.is_cancelled())
298            .finish()
299    }
300}
301
302/// A no-op peer implementation for testing.
303///
304/// This peer silently accepts all notifications without sending them anywhere.
305#[derive(Debug, Clone, Copy)]
306pub struct NoOpPeer;
307
308impl Peer for NoOpPeer {
309    fn notify(
310        &self,
311        _notification: Notification,
312    ) -> Pin<Box<dyn Future<Output = Result<(), McpError>> + Send + '_>> {
313        Box::pin(async { Ok(()) })
314    }
315}
316
317/// Owned data for creating contexts.
318///
319/// This struct holds owned copies of all the data needed to create a Context.
320/// It's useful when you need to create contexts from owned data.
321pub struct ContextData {
322    /// The request ID.
323    pub request_id: RequestId,
324    /// Optional progress token.
325    pub progress_token: Option<ProgressToken>,
326    /// Client capabilities.
327    pub client_caps: ClientCapabilities,
328    /// Server capabilities.
329    pub server_caps: ServerCapabilities,
330    /// The negotiated protocol version.
331    pub protocol_version: ProtocolVersion,
332}
333
334impl ContextData {
335    /// Create a new context data struct.
336    #[must_use]
337    pub const fn new(
338        request_id: RequestId,
339        client_caps: ClientCapabilities,
340        server_caps: ServerCapabilities,
341        protocol_version: ProtocolVersion,
342    ) -> Self {
343        Self {
344            request_id,
345            progress_token: None,
346            client_caps,
347            server_caps,
348            protocol_version,
349        }
350    }
351
352    /// Set the progress token.
353    #[must_use]
354    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
355        self.progress_token = Some(token);
356        self
357    }
358
359    /// Create a context from this data with the given peer.
360    #[must_use]
361    pub fn to_context<'a>(&'a self, peer: &'a dyn Peer) -> Context<'a> {
362        Context::new(
363            &self.request_id,
364            self.progress_token.as_ref(),
365            &self.client_caps,
366            &self.server_caps,
367            self.protocol_version,
368            peer,
369        )
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Waker that ignores wake-ups; enough to poll a future by hand.
    struct IgnoreWake;

    impl std::task::Wake for IgnoreWake {
        fn wake(self: std::sync::Arc<Self>) {}
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancelled_future_completes_after_cancel() {
        let waker = std::task::Waker::from(std::sync::Arc::new(IgnoreWake));
        let mut cx = std::task::Context::from_waker(&waker);
        let token = CancellationToken::new();
        let mut fut = token.cancelled();

        assert!(std::future::Future::poll(std::pin::Pin::new(&mut fut), &mut cx).is_pending());
        token.cancel();
        assert!(std::future::Future::poll(std::pin::Pin::new(&mut fut), &mut cx).is_ready());
    }

    #[test]
    fn test_context_creation() {
        let request_id = RequestId::Number(1);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(
            &request_id,
            None,
            &client_caps,
            &server_caps,
            ProtocolVersion::LATEST,
            &peer,
        );

        assert!(!ctx.is_cancelled());
        assert!(ctx.progress_token.is_none());
        assert_eq!(ctx.protocol_version, ProtocolVersion::LATEST);
    }

    #[test]
    fn test_context_shares_cancellation_token() {
        let request_id = RequestId::Number(7);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;
        let token = CancellationToken::new();

        let ctx = Context::with_cancellation(
            &request_id,
            None,
            &client_caps,
            &server_caps,
            ProtocolVersion::LATEST,
            &peer,
            token.clone(),
        );

        assert!(!ctx.is_cancelled());
        token.cancel();
        assert!(ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_cancelled());
    }

    #[test]
    fn test_context_with_progress_token() {
        let request_id = RequestId::Number(1);
        let progress_token = ProgressToken::String("token".to_string());
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        let peer = NoOpPeer;

        let ctx = Context::new(
            &request_id,
            Some(&progress_token),
            &client_caps,
            &server_caps,
            ProtocolVersion::V2025_03_26,
            &peer,
        );

        assert!(ctx.progress_token.is_some());
        assert_eq!(ctx.protocol_version, ProtocolVersion::V2025_03_26);
    }

    #[test]
    fn test_context_data() {
        let data = ContextData::new(
            RequestId::Number(42),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            ProtocolVersion::V2025_06_18,
        )
        .with_progress_token(ProgressToken::String("test".to_string()));

        let peer = NoOpPeer;
        let ctx = data.to_context(&peer);

        assert!(ctx.progress_token.is_some());
        assert_eq!(ctx.protocol_version, ProtocolVersion::V2025_06_18);
        // Test feature detection via protocol version
        assert!(ctx.protocol_version.supports_elicitation());
        assert!(!ctx.protocol_version.supports_tasks()); // Tasks require 2025-11-25
    }
}