Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: ElicitMode::Form,
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71
72use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81    ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84/// A notification to be sent to the client
85#[derive(Debug, Clone)]
86pub enum ServerNotification {
87    /// Progress update for a request
88    Progress(ProgressParams),
89    /// Log message notification
90    LogMessage(LoggingMessageParams),
91    /// A subscribed resource has been updated
92    ResourceUpdated {
93        /// The URI of the updated resource
94        uri: String,
95    },
96    /// The list of available resources has changed
97    ResourcesListChanged,
98}
99
100/// Sender for server notifications
101pub type NotificationSender = mpsc::Sender<ServerNotification>;
102
103/// Receiver for server notifications
104pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106/// Create a new notification channel
107pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108    mpsc::channel(buffer)
109}
110
111// =============================================================================
112// Client Requests (Server -> Client)
113// =============================================================================
114
115/// Trait for sending requests from server to client
116///
117/// This enables bidirectional communication where the server can request
118/// actions from the client, such as sampling (LLM requests) and elicitation
119/// (user input requests).
120#[async_trait]
121pub trait ClientRequester: Send + Sync {
122    /// Send a sampling request to the client
123    ///
124    /// Returns the LLM completion result from the client.
125    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
126
127    /// Send an elicitation request to the client
128    ///
129    /// This requests user input from the client. The request can be either
130    /// form-based (structured input) or URL-based (redirect to external URL).
131    ///
132    /// Returns the elicitation result with the user's action and any submitted data.
133    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
134}
135
136/// A clonable handle to a client requester
137pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
139/// Outgoing request to be sent to the client
140#[derive(Debug)]
141pub struct OutgoingRequest {
142    /// The JSON-RPC request ID
143    pub id: RequestId,
144    /// The method name
145    pub method: String,
146    /// The request parameters as JSON
147    pub params: serde_json::Value,
148    /// Channel to send the response back
149    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
150}
151
152/// Sender for outgoing requests to the client
153pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
154
155/// Receiver for outgoing requests (used by transport)
156pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158/// Create a new outgoing request channel
159pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160    mpsc::channel(buffer)
161}
162
163/// A client requester implementation that sends requests through a channel
164#[derive(Clone)]
165pub struct ChannelClientRequester {
166    request_tx: OutgoingRequestSender,
167    next_id: Arc<AtomicI64>,
168}
169
170impl ChannelClientRequester {
171    /// Create a new channel-based client requester
172    pub fn new(request_tx: OutgoingRequestSender) -> Self {
173        Self {
174            request_tx,
175            next_id: Arc::new(AtomicI64::new(1)),
176        }
177    }
178
179    fn next_request_id(&self) -> RequestId {
180        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181        RequestId::Number(id)
182    }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188        let id = self.next_request_id();
189        let params_json = serde_json::to_value(&params)
190            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194        let request = OutgoingRequest {
195            id: id.clone(),
196            method: "sampling/createMessage".to_string(),
197            params: params_json,
198            response_tx,
199        };
200
201        self.request_tx
202            .send(request)
203            .await
204            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206        let response = response_rx.await.map_err(|_| {
207            Error::Internal("Failed to receive response: channel closed".to_string())
208        })??;
209
210        serde_json::from_value(response)
211            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212    }
213
214    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215        let id = self.next_request_id();
216        let params_json = serde_json::to_value(&params)
217            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221        let request = OutgoingRequest {
222            id: id.clone(),
223            method: "elicitation/create".to_string(),
224            params: params_json,
225            response_tx,
226        };
227
228        self.request_tx
229            .send(request)
230            .await
231            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233        let response = response_rx.await.map_err(|_| {
234            Error::Internal("Failed to receive response: channel closed".to_string())
235        })??;
236
237        serde_json::from_value(response)
238            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239    }
240}
241
242/// Context for a request, providing progress, cancellation, and client request support
243#[derive(Clone)]
244pub struct RequestContext {
245    /// The request ID
246    request_id: RequestId,
247    /// Progress token (if provided by client)
248    progress_token: Option<ProgressToken>,
249    /// Cancellation flag
250    cancelled: Arc<AtomicBool>,
251    /// Channel for sending notifications
252    notification_tx: Option<NotificationSender>,
253    /// Handle for sending requests to the client (for sampling, etc.)
254    client_requester: Option<ClientRequesterHandle>,
255    /// Extensions for passing data from router/middleware to handlers
256    extensions: Arc<Extensions>,
257}
258
/// Type-erased map holding at most one value per concrete type.
///
/// Lets router-level state and middleware-injected data reach tool handlers
/// through the `Extension<T>` extractor.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Values keyed by `TypeId`. Stored behind `Arc`, so cloning the map shares
    /// the values instead of deep-copying them.
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Create an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Store `val`, replacing any value of the same type `T` already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        self.map.insert(key, Arc::new(val));
    }

    /// Borrow the stored value of type `T`, or `None` if none was inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let key = std::any::TypeId::of::<T>();
        let stored = self.map.get(&key)?;
        stored.downcast_ref::<T>()
    }

    /// Whether a value of type `T` is present.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copy every entry of `other` into `self`.
    ///
    /// On a type collision the value from `other` wins. Values are shared
    /// (`Arc` clones), not deep-copied.
    pub fn merge(&mut self, other: &Extensions) {
        self.map.extend(
            other
                .map
                .iter()
                .map(|(type_id, value)| (*type_id, Arc::clone(value))),
        );
    }
}

impl std::fmt::Debug for Extensions {
    /// Stored values are type-erased and not necessarily `Debug`, so only the
    /// entry count is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = self.map.len();
        f.debug_struct("Extensions").field("len", &entries).finish()
    }
}
312
313impl std::fmt::Debug for RequestContext {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("RequestContext")
316            .field("request_id", &self.request_id)
317            .field("progress_token", &self.progress_token)
318            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
319            .finish()
320    }
321}
322
323impl RequestContext {
324    /// Create a new request context
325    pub fn new(request_id: RequestId) -> Self {
326        Self {
327            request_id,
328            progress_token: None,
329            cancelled: Arc::new(AtomicBool::new(false)),
330            notification_tx: None,
331            client_requester: None,
332            extensions: Arc::new(Extensions::new()),
333        }
334    }
335
336    /// Set the progress token
337    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
338        self.progress_token = Some(token);
339        self
340    }
341
342    /// Set the notification sender
343    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
344        self.notification_tx = Some(tx);
345        self
346    }
347
348    /// Set the client requester for server-to-client requests
349    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
350        self.client_requester = Some(requester);
351        self
352    }
353
354    /// Set the extensions for this request context.
355    ///
356    /// Extensions allow router-level state and middleware data to flow to handlers.
357    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
358        self.extensions = extensions;
359        self
360    }
361
362    /// Get a reference to a value from the extensions map.
363    ///
364    /// Returns `None` if no value of the given type has been inserted.
365    ///
366    /// # Example
367    ///
368    /// ```rust,ignore
369    /// #[derive(Clone)]
370    /// struct CurrentUser { id: String }
371    ///
372    /// // In a handler:
373    /// if let Some(user) = ctx.extension::<CurrentUser>() {
374    ///     println!("User: {}", user.id);
375    /// }
376    /// ```
377    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
378        self.extensions.get::<T>()
379    }
380
381    /// Get a mutable reference to the extensions.
382    ///
383    /// This allows middleware to insert data that handlers can access via
384    /// the `Extension<T>` extractor.
385    pub fn extensions_mut(&mut self) -> &mut Extensions {
386        Arc::make_mut(&mut self.extensions)
387    }
388
389    /// Get a reference to the extensions.
390    pub fn extensions(&self) -> &Extensions {
391        &self.extensions
392    }
393
394    /// Get the request ID
395    pub fn request_id(&self) -> &RequestId {
396        &self.request_id
397    }
398
399    /// Get the progress token (if any)
400    pub fn progress_token(&self) -> Option<&ProgressToken> {
401        self.progress_token.as_ref()
402    }
403
404    /// Check if the request has been cancelled
405    pub fn is_cancelled(&self) -> bool {
406        self.cancelled.load(Ordering::Relaxed)
407    }
408
409    /// Mark the request as cancelled
410    pub fn cancel(&self) {
411        self.cancelled.store(true, Ordering::Relaxed);
412    }
413
414    /// Get a cancellation token that can be shared
415    pub fn cancellation_token(&self) -> CancellationToken {
416        CancellationToken {
417            cancelled: self.cancelled.clone(),
418        }
419    }
420
421    /// Report progress to the client
422    ///
423    /// This is a no-op if no progress token was provided or no notification sender is configured.
424    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
425        let Some(token) = &self.progress_token else {
426            return;
427        };
428        let Some(tx) = &self.notification_tx else {
429            return;
430        };
431
432        let params = ProgressParams {
433            progress_token: token.clone(),
434            progress,
435            total,
436            message: message.map(|s| s.to_string()),
437        };
438
439        // Best effort - don't block if channel is full
440        let _ = tx.try_send(ServerNotification::Progress(params));
441    }
442
443    /// Report progress synchronously (non-async version)
444    ///
445    /// This is a no-op if no progress token was provided or no notification sender is configured.
446    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
447        let Some(token) = &self.progress_token else {
448            return;
449        };
450        let Some(tx) = &self.notification_tx else {
451            return;
452        };
453
454        let params = ProgressParams {
455            progress_token: token.clone(),
456            progress,
457            total,
458            message: message.map(|s| s.to_string()),
459        };
460
461        let _ = tx.try_send(ServerNotification::Progress(params));
462    }
463
464    /// Send a log message notification to the client
465    ///
466    /// This is a no-op if no notification sender is configured.
467    ///
468    /// # Example
469    ///
470    /// ```rust,ignore
471    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
472    ///
473    /// async fn my_tool(ctx: RequestContext) {
474    ///     ctx.send_log(
475    ///         LoggingMessageParams::new(LogLevel::Info)
476    ///             .with_logger("my-tool")
477    ///             .with_data(serde_json::json!("Processing..."))
478    ///     );
479    /// }
480    /// ```
481    pub fn send_log(&self, params: LoggingMessageParams) {
482        let Some(tx) = &self.notification_tx else {
483            return;
484        };
485
486        let _ = tx.try_send(ServerNotification::LogMessage(params));
487    }
488
489    /// Check if sampling is available
490    ///
491    /// Returns true if a client requester is configured and the transport
492    /// supports bidirectional communication.
493    pub fn can_sample(&self) -> bool {
494        self.client_requester.is_some()
495    }
496
497    /// Request an LLM completion from the client
498    ///
499    /// This sends a `sampling/createMessage` request to the client and waits
500    /// for the response. The client is expected to forward this to an LLM
501    /// and return the result.
502    ///
503    /// Returns an error if sampling is not available (no client requester configured).
504    ///
505    /// # Example
506    ///
507    /// ```rust,ignore
508    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
509    ///
510    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
511    ///     let params = CreateMessageParams::new(
512    ///         vec![SamplingMessage::user("Summarize: ...")],
513    ///         500,
514    ///     );
515    ///
516    ///     let result = ctx.sample(params).await?;
517    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
518    /// }
519    /// ```
520    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
521        let requester = self.client_requester.as_ref().ok_or_else(|| {
522            Error::Internal("Sampling not available: no client requester configured".to_string())
523        })?;
524
525        requester.sample(params).await
526    }
527
528    /// Check if elicitation is available
529    ///
530    /// Returns true if a client requester is configured and the transport
531    /// supports bidirectional communication. Note that this only checks if
532    /// the mechanism is available, not whether the client supports elicitation.
533    pub fn can_elicit(&self) -> bool {
534        self.client_requester.is_some()
535    }
536
537    /// Request user input via a form from the client
538    ///
539    /// This sends an `elicitation/create` request to the client with a form schema.
540    /// The client renders the form to the user and returns their response.
541    ///
542    /// Returns an error if elicitation is not available (no client requester configured).
543    ///
544    /// # Example
545    ///
546    /// ```rust,ignore
547    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
548    ///
549    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
550    ///     let params = ElicitFormParams {
551    ///         mode: ElicitMode::Form,
552    ///         message: "Please enter your details".to_string(),
553    ///         requested_schema: ElicitFormSchema::new()
554    ///             .string_field("name", Some("Your name"), true),
555    ///         meta: None,
556    ///     };
557    ///
558    ///     let result = ctx.elicit_form(params).await?;
559    ///     match result.action {
560    ///         ElicitAction::Accept => {
561    ///             // Use result.content
562    ///             Ok(CallToolResult::text("Got your input!"))
563    ///         }
564    ///         _ => Ok(CallToolResult::text("User declined"))
565    ///     }
566    /// }
567    /// ```
568    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
569        let requester = self.client_requester.as_ref().ok_or_else(|| {
570            Error::Internal("Elicitation not available: no client requester configured".to_string())
571        })?;
572
573        requester.elicit(ElicitRequestParams::Form(params)).await
574    }
575
576    /// Request user input via URL redirect from the client
577    ///
578    /// This sends an `elicitation/create` request to the client with a URL.
579    /// The client directs the user to the URL for out-of-band input collection.
580    /// The server receives the result via a callback notification.
581    ///
582    /// Returns an error if elicitation is not available (no client requester configured).
583    ///
584    /// # Example
585    ///
586    /// ```rust,ignore
587    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
588    ///
589    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
590    ///     let params = ElicitUrlParams {
591    ///         mode: ElicitMode::Url,
592    ///         elicitation_id: "unique-id-123".to_string(),
593    ///         message: "Please authorize via the link".to_string(),
594    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
595    ///         meta: None,
596    ///     };
597    ///
598    ///     let result = ctx.elicit_url(params).await?;
599    ///     match result.action {
600    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
601    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
602    ///     }
603    /// }
604    /// ```
605    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
606        let requester = self.client_requester.as_ref().ok_or_else(|| {
607            Error::Internal("Elicitation not available: no client requester configured".to_string())
608        })?;
609
610        requester.elicit(ElicitRequestParams::Url(params)).await
611    }
612
613    /// Request simple confirmation from the user.
614    ///
615    /// This is a convenience method for simple yes/no confirmation dialogs.
616    /// It creates an elicitation form with a single boolean "confirm" field
617    /// and returns `true` if the user accepts, `false` otherwise.
618    ///
619    /// Returns an error if elicitation is not available (no client requester configured).
620    ///
621    /// # Example
622    ///
623    /// ```rust,ignore
624    /// use tower_mcp::{RequestContext, CallToolResult};
625    ///
626    /// async fn delete_item(ctx: RequestContext) -> Result<CallToolResult> {
627    ///     let confirmed = ctx.confirm("Are you sure you want to delete this item?").await?;
628    ///     if confirmed {
629    ///         // Perform deletion
630    ///         Ok(CallToolResult::text("Item deleted"))
631    ///     } else {
632    ///         Ok(CallToolResult::text("Deletion cancelled"))
633    ///     }
634    /// }
635    /// ```
636    pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
637        use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
638
639        let params = ElicitFormParams {
640            mode: ElicitMode::Form,
641            message: message.into(),
642            requested_schema: ElicitFormSchema::new().boolean_field_with_default(
643                "confirm",
644                Some("Confirm this action"),
645                true,
646                false,
647            ),
648            meta: None,
649        };
650
651        let result = self.elicit_form(params).await?;
652        Ok(result.action == ElicitAction::Accept)
653    }
654}
655
/// Shareable handle onto a request's cancellation flag.
///
/// Obtained from `RequestContext::cancellation_token`. Every clone, and the
/// originating context, reads and writes the same underlying flag.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    /// Flag shared with the owning request context.
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested through this
    /// token, any of its clones, or the owning context.
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Request cancellation. Calling it again has no further effect.
    pub fn cancel(&self) {
        let flag = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
673
674/// Builder for creating request contexts
675#[derive(Default)]
676pub struct RequestContextBuilder {
677    request_id: Option<RequestId>,
678    progress_token: Option<ProgressToken>,
679    notification_tx: Option<NotificationSender>,
680    client_requester: Option<ClientRequesterHandle>,
681}
682
683impl RequestContextBuilder {
684    /// Create a new builder
685    pub fn new() -> Self {
686        Self::default()
687    }
688
689    /// Set the request ID
690    pub fn request_id(mut self, id: RequestId) -> Self {
691        self.request_id = Some(id);
692        self
693    }
694
695    /// Set the progress token
696    pub fn progress_token(mut self, token: ProgressToken) -> Self {
697        self.progress_token = Some(token);
698        self
699    }
700
701    /// Set the notification sender
702    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
703        self.notification_tx = Some(tx);
704        self
705    }
706
707    /// Set the client requester for server-to-client requests
708    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
709        self.client_requester = Some(requester);
710        self
711    }
712
713    /// Build the request context
714    ///
715    /// Panics if request_id is not set.
716    pub fn build(self) -> RequestContext {
717        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
718        if let Some(token) = self.progress_token {
719            ctx = ctx.with_progress_token(token);
720        }
721        if let Some(tx) = self.notification_tx {
722            ctx = ctx.with_notification_sender(tx);
723        }
724        if let Some(requester) = self.client_requester {
725            ctx = ctx.with_client_requester(requester);
726        }
727        ctx
728    }
729}
730
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway")).await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        // No progress token - should be a no-op
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // Channel should be empty
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_report_progress_sync_drops_when_full() {
        let (tx, mut rx) = notification_channel(1);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(7))
            .with_notification_sender(tx);

        ctx.report_progress_sync(1.0, None, None);
        // The channel is full now: this report must be dropped, not block.
        ctx.report_progress_sync(2.0, None, None);

        match rx.try_recv() {
            Ok(ServerNotification::Progress(params)) => assert_eq!(params.progress, 1.0),
            other => panic!("Expected Progress notification, got {:?}", other),
        }
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    #[should_panic(expected = "request_id is required")]
    fn test_builder_without_request_id_panics() {
        let _ = RequestContextBuilder::new().build();
    }

    #[test]
    fn test_extensions_copy_on_write() {
        let mut base = Extensions::new();
        base.insert(1u32);

        let ctx = RequestContext::new(RequestId::Number(1)).with_extensions(Arc::new(base));
        let mut cloned = ctx.clone();
        cloned.extensions_mut().insert(2u32);
        cloned.extensions_mut().insert("extra");

        // Mutating the clone must not leak into the original context.
        assert_eq!(ctx.extension::<u32>(), Some(&1));
        assert_eq!(cloned.extension::<u32>(), Some(&2));
        assert!(!ctx.extensions().contains::<&'static str>());
        assert!(cloned.extensions().contains::<&'static str>());
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[tokio::test]
    async fn test_channel_requester_forwards_sampling_request() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let (request_tx, mut request_rx) = outgoing_request_channel(1);
        let requester = ChannelClientRequester::new(request_tx);
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        // Play the transport: check the request, then answer with an error.
        let respond = async {
            let request = request_rx.recv().await.expect("request should be enqueued");
            assert_eq!(request.method, "sampling/createMessage");
            assert_eq!(request.id, RequestId::Number(1));
            assert!(
                request
                    .response_tx
                    .send(Err(Error::Internal("client failed".to_string())))
                    .is_ok(),
                "requester should still be awaiting the response"
            );
        };

        let (result, ()) = tokio::join!(requester.sample(params), respond);
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }
}