Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: ElicitMode::Form,
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71
72use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81    ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
/// A notification to be sent to the client
///
/// Notifications travel over a [`NotificationSender`] (see
/// [`notification_channel`]). [`RequestContext`] emits the `Progress` and
/// `LogMessage` variants; the resource variants are produced elsewhere.
#[derive(Debug, Clone)]
pub enum ServerNotification {
    /// Progress update for a request
    Progress(ProgressParams),
    /// Log message notification
    LogMessage(LoggingMessageParams),
    /// A subscribed resource has been updated
    ResourceUpdated {
        /// The URI of the updated resource
        uri: String,
    },
    /// The list of available resources has changed
    ResourcesListChanged,
}
99
100/// Sender for server notifications
101pub type NotificationSender = mpsc::Sender<ServerNotification>;
102
103/// Receiver for server notifications
104pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106/// Create a new notification channel
107pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108    mpsc::channel(buffer)
109}
110
111// =============================================================================
112// Client Requests (Server -> Client)
113// =============================================================================
114
/// Trait for sending requests from server to client
///
/// This enables bidirectional communication where the server can request
/// actions from the client, such as sampling (LLM requests) and elicitation
/// (user input requests).
///
/// Implementors must be `Send + Sync` so one requester can be shared between
/// handlers through a [`ClientRequesterHandle`]. This module provides
/// [`ChannelClientRequester`], which forwards requests over an mpsc channel.
#[async_trait]
pub trait ClientRequester: Send + Sync {
    /// Send a sampling request to the client
    ///
    /// Returns the LLM completion result from the client.
    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;

    /// Send an elicitation request to the client
    ///
    /// This requests user input from the client. The request can be either
    /// form-based (structured input) or URL-based (redirect to external URL).
    ///
    /// Returns the elicitation result with the user's action and any submitted data.
    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
}

/// A clonable handle to a client requester
///
/// Cloning the handle only bumps the `Arc` reference count; all clones refer
/// to the same underlying requester.
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
/// Outgoing request to be sent to the client
///
/// Created by [`ChannelClientRequester`] (with `method` set to
/// `"sampling/createMessage"` or `"elicitation/create"`) and pushed onto an
/// [`OutgoingRequestSender`]. The consumer of the matching
/// [`OutgoingRequestReceiver`] must eventually complete or drop `response_tx`,
/// otherwise the requesting handler keeps waiting.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// The JSON-RPC request ID
    pub id: RequestId,
    /// The method name
    pub method: String,
    /// The request parameters as JSON
    pub params: serde_json::Value,
    /// Channel to send the response back
    ///
    /// `Ok(value)` is deserialized by the requester into the typed result; an
    /// `Err` is returned to the caller unchanged. Dropping the sender without
    /// replying makes the requester fail with an internal "channel closed" error.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
151
152/// Sender for outgoing requests to the client
153pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
154
155/// Receiver for outgoing requests (used by transport)
156pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158/// Create a new outgoing request channel
159pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160    mpsc::channel(buffer)
161}
162
163/// A client requester implementation that sends requests through a channel
164#[derive(Clone)]
165pub struct ChannelClientRequester {
166    request_tx: OutgoingRequestSender,
167    next_id: Arc<AtomicI64>,
168}
169
170impl ChannelClientRequester {
171    /// Create a new channel-based client requester
172    pub fn new(request_tx: OutgoingRequestSender) -> Self {
173        Self {
174            request_tx,
175            next_id: Arc::new(AtomicI64::new(1)),
176        }
177    }
178
179    fn next_request_id(&self) -> RequestId {
180        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181        RequestId::Number(id)
182    }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188        let id = self.next_request_id();
189        let params_json = serde_json::to_value(&params)
190            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194        let request = OutgoingRequest {
195            id: id.clone(),
196            method: "sampling/createMessage".to_string(),
197            params: params_json,
198            response_tx,
199        };
200
201        self.request_tx
202            .send(request)
203            .await
204            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206        let response = response_rx.await.map_err(|_| {
207            Error::Internal("Failed to receive response: channel closed".to_string())
208        })??;
209
210        serde_json::from_value(response)
211            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212    }
213
214    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215        let id = self.next_request_id();
216        let params_json = serde_json::to_value(&params)
217            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221        let request = OutgoingRequest {
222            id: id.clone(),
223            method: "elicitation/create".to_string(),
224            params: params_json,
225            response_tx,
226        };
227
228        self.request_tx
229            .send(request)
230            .await
231            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233        let response = response_rx.await.map_err(|_| {
234            Error::Internal("Failed to receive response: channel closed".to_string())
235        })??;
236
237        serde_json::from_value(response)
238            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239    }
240}
241
/// Context for a request, providing progress, cancellation, and client request support
///
/// Clones share the same cancellation flag (it lives behind an `Arc`), so
/// cancelling through any clone is visible to all of them.
#[derive(Clone)]
pub struct RequestContext {
    /// The request ID
    request_id: RequestId,
    /// Progress token (if provided by client)
    progress_token: Option<ProgressToken>,
    /// Cancellation flag, shared with clones and with every `CancellationToken`
    /// returned by `cancellation_token()`
    cancelled: Arc<AtomicBool>,
    /// Channel for sending notifications (progress and log messages)
    notification_tx: Option<NotificationSender>,
    /// Handle for sending requests to the client (for sampling, etc.)
    client_requester: Option<ClientRequesterHandle>,
    /// Extensions for passing data from router/middleware to handlers
    /// (copy-on-write through `Arc::make_mut` in `extensions_mut()`)
    extensions: Arc<Extensions>,
}
258
/// Type-erased extensions map for passing data to handlers.
///
/// Values are keyed by their concrete Rust type, so the map holds at most one
/// value per type. Router-level state and middleware-injected data reach tool
/// handlers through it via the `Extension<T>` extractor. Cloning the map is
/// shallow: stored values live behind `Arc`s and are shared, not copied.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Returns an empty extensions map.
    pub fn new() -> Self {
        Self {
            map: std::collections::HashMap::new(),
        }
    }

    /// Stores `val`, replacing any value of the same type `T` already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Borrows the stored value of type `T`, or `None` if no such value was inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let erased = self.map.get(&std::any::TypeId::of::<T>())?;
        erased.downcast_ref::<T>()
    }

    /// Returns `true` when a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into this map.
    ///
    /// For types present in both maps the value from `other` wins. Values are
    /// shared by reference count rather than deep-copied.
    pub fn merge(&mut self, other: &Extensions) {
        self.map.extend(
            other
                .map
                .iter()
                .map(|(type_id, value)| (*type_id, Arc::clone(value))),
        );
    }
}

impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stored values are type-erased and need not be `Debug`; report only the count.
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
312
313impl std::fmt::Debug for RequestContext {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("RequestContext")
316            .field("request_id", &self.request_id)
317            .field("progress_token", &self.progress_token)
318            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
319            .finish()
320    }
321}
322
323impl RequestContext {
324    /// Create a new request context
325    pub fn new(request_id: RequestId) -> Self {
326        Self {
327            request_id,
328            progress_token: None,
329            cancelled: Arc::new(AtomicBool::new(false)),
330            notification_tx: None,
331            client_requester: None,
332            extensions: Arc::new(Extensions::new()),
333        }
334    }
335
336    /// Set the progress token
337    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
338        self.progress_token = Some(token);
339        self
340    }
341
342    /// Set the notification sender
343    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
344        self.notification_tx = Some(tx);
345        self
346    }
347
348    /// Set the client requester for server-to-client requests
349    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
350        self.client_requester = Some(requester);
351        self
352    }
353
354    /// Set the extensions for this request context.
355    ///
356    /// Extensions allow router-level state and middleware data to flow to handlers.
357    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
358        self.extensions = extensions;
359        self
360    }
361
362    /// Get a reference to a value from the extensions map.
363    ///
364    /// Returns `None` if no value of the given type has been inserted.
365    ///
366    /// # Example
367    ///
368    /// ```rust,ignore
369    /// #[derive(Clone)]
370    /// struct CurrentUser { id: String }
371    ///
372    /// // In a handler:
373    /// if let Some(user) = ctx.extension::<CurrentUser>() {
374    ///     println!("User: {}", user.id);
375    /// }
376    /// ```
377    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
378        self.extensions.get::<T>()
379    }
380
381    /// Get a mutable reference to the extensions.
382    ///
383    /// This allows middleware to insert data that handlers can access via
384    /// the `Extension<T>` extractor.
385    pub fn extensions_mut(&mut self) -> &mut Extensions {
386        Arc::make_mut(&mut self.extensions)
387    }
388
389    /// Get a reference to the extensions.
390    pub fn extensions(&self) -> &Extensions {
391        &self.extensions
392    }
393
394    /// Get the request ID
395    pub fn request_id(&self) -> &RequestId {
396        &self.request_id
397    }
398
399    /// Get the progress token (if any)
400    pub fn progress_token(&self) -> Option<&ProgressToken> {
401        self.progress_token.as_ref()
402    }
403
404    /// Check if the request has been cancelled
405    pub fn is_cancelled(&self) -> bool {
406        self.cancelled.load(Ordering::Relaxed)
407    }
408
409    /// Mark the request as cancelled
410    pub fn cancel(&self) {
411        self.cancelled.store(true, Ordering::Relaxed);
412    }
413
414    /// Get a cancellation token that can be shared
415    pub fn cancellation_token(&self) -> CancellationToken {
416        CancellationToken {
417            cancelled: self.cancelled.clone(),
418        }
419    }
420
421    /// Report progress to the client
422    ///
423    /// This is a no-op if no progress token was provided or no notification sender is configured.
424    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
425        let Some(token) = &self.progress_token else {
426            return;
427        };
428        let Some(tx) = &self.notification_tx else {
429            return;
430        };
431
432        let params = ProgressParams {
433            progress_token: token.clone(),
434            progress,
435            total,
436            message: message.map(|s| s.to_string()),
437        };
438
439        // Best effort - don't block if channel is full
440        let _ = tx.try_send(ServerNotification::Progress(params));
441    }
442
443    /// Report progress synchronously (non-async version)
444    ///
445    /// This is a no-op if no progress token was provided or no notification sender is configured.
446    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
447        let Some(token) = &self.progress_token else {
448            return;
449        };
450        let Some(tx) = &self.notification_tx else {
451            return;
452        };
453
454        let params = ProgressParams {
455            progress_token: token.clone(),
456            progress,
457            total,
458            message: message.map(|s| s.to_string()),
459        };
460
461        let _ = tx.try_send(ServerNotification::Progress(params));
462    }
463
464    /// Send a log message notification to the client
465    ///
466    /// This is a no-op if no notification sender is configured.
467    ///
468    /// # Example
469    ///
470    /// ```rust,ignore
471    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
472    ///
473    /// async fn my_tool(ctx: RequestContext) {
474    ///     ctx.send_log(
475    ///         LoggingMessageParams::new(LogLevel::Info)
476    ///             .with_logger("my-tool")
477    ///             .with_data(serde_json::json!("Processing..."))
478    ///     );
479    /// }
480    /// ```
481    pub fn send_log(&self, params: LoggingMessageParams) {
482        let Some(tx) = &self.notification_tx else {
483            return;
484        };
485
486        let _ = tx.try_send(ServerNotification::LogMessage(params));
487    }
488
489    /// Check if sampling is available
490    ///
491    /// Returns true if a client requester is configured and the transport
492    /// supports bidirectional communication.
493    pub fn can_sample(&self) -> bool {
494        self.client_requester.is_some()
495    }
496
497    /// Request an LLM completion from the client
498    ///
499    /// This sends a `sampling/createMessage` request to the client and waits
500    /// for the response. The client is expected to forward this to an LLM
501    /// and return the result.
502    ///
503    /// Returns an error if sampling is not available (no client requester configured).
504    ///
505    /// # Example
506    ///
507    /// ```rust,ignore
508    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
509    ///
510    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
511    ///     let params = CreateMessageParams::new(
512    ///         vec![SamplingMessage::user("Summarize: ...")],
513    ///         500,
514    ///     );
515    ///
516    ///     let result = ctx.sample(params).await?;
517    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
518    /// }
519    /// ```
520    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
521        let requester = self.client_requester.as_ref().ok_or_else(|| {
522            Error::Internal("Sampling not available: no client requester configured".to_string())
523        })?;
524
525        requester.sample(params).await
526    }
527
528    /// Check if elicitation is available
529    ///
530    /// Returns true if a client requester is configured and the transport
531    /// supports bidirectional communication. Note that this only checks if
532    /// the mechanism is available, not whether the client supports elicitation.
533    pub fn can_elicit(&self) -> bool {
534        self.client_requester.is_some()
535    }
536
537    /// Request user input via a form from the client
538    ///
539    /// This sends an `elicitation/create` request to the client with a form schema.
540    /// The client renders the form to the user and returns their response.
541    ///
542    /// Returns an error if elicitation is not available (no client requester configured).
543    ///
544    /// # Example
545    ///
546    /// ```rust,ignore
547    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
548    ///
549    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
550    ///     let params = ElicitFormParams {
551    ///         mode: ElicitMode::Form,
552    ///         message: "Please enter your details".to_string(),
553    ///         requested_schema: ElicitFormSchema::new()
554    ///             .string_field("name", Some("Your name"), true),
555    ///         meta: None,
556    ///     };
557    ///
558    ///     let result = ctx.elicit_form(params).await?;
559    ///     match result.action {
560    ///         ElicitAction::Accept => {
561    ///             // Use result.content
562    ///             Ok(CallToolResult::text("Got your input!"))
563    ///         }
564    ///         _ => Ok(CallToolResult::text("User declined"))
565    ///     }
566    /// }
567    /// ```
568    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
569        let requester = self.client_requester.as_ref().ok_or_else(|| {
570            Error::Internal("Elicitation not available: no client requester configured".to_string())
571        })?;
572
573        requester.elicit(ElicitRequestParams::Form(params)).await
574    }
575
576    /// Request user input via URL redirect from the client
577    ///
578    /// This sends an `elicitation/create` request to the client with a URL.
579    /// The client directs the user to the URL for out-of-band input collection.
580    /// The server receives the result via a callback notification.
581    ///
582    /// Returns an error if elicitation is not available (no client requester configured).
583    ///
584    /// # Example
585    ///
586    /// ```rust,ignore
587    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
588    ///
589    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
590    ///     let params = ElicitUrlParams {
591    ///         mode: ElicitMode::Url,
592    ///         elicitation_id: "unique-id-123".to_string(),
593    ///         message: "Please authorize via the link".to_string(),
594    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
595    ///         meta: None,
596    ///     };
597    ///
598    ///     let result = ctx.elicit_url(params).await?;
599    ///     match result.action {
600    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
601    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
602    ///     }
603    /// }
604    /// ```
605    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
606        let requester = self.client_requester.as_ref().ok_or_else(|| {
607            Error::Internal("Elicitation not available: no client requester configured".to_string())
608        })?;
609
610        requester.elicit(ElicitRequestParams::Url(params)).await
611    }
612}
613
/// Shareable handle for observing or requesting cancellation of a request.
///
/// Clones share one underlying flag, so cancelling through any of them (or
/// through the `RequestContext` that produced the token) is seen by all.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation; calling it again has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed)
    }
}
631
632/// Builder for creating request contexts
633#[derive(Default)]
634pub struct RequestContextBuilder {
635    request_id: Option<RequestId>,
636    progress_token: Option<ProgressToken>,
637    notification_tx: Option<NotificationSender>,
638    client_requester: Option<ClientRequesterHandle>,
639}
640
641impl RequestContextBuilder {
642    /// Create a new builder
643    pub fn new() -> Self {
644        Self::default()
645    }
646
647    /// Set the request ID
648    pub fn request_id(mut self, id: RequestId) -> Self {
649        self.request_id = Some(id);
650        self
651    }
652
653    /// Set the progress token
654    pub fn progress_token(mut self, token: ProgressToken) -> Self {
655        self.progress_token = Some(token);
656        self
657    }
658
659    /// Set the notification sender
660    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
661        self.notification_tx = Some(tx);
662        self
663    }
664
665    /// Set the client requester for server-to-client requests
666    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
667        self.client_requester = Some(requester);
668        self
669    }
670
671    /// Build the request context
672    ///
673    /// Panics if request_id is not set.
674    pub fn build(self) -> RequestContext {
675        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
676        if let Some(token) = self.progress_token {
677            ctx = ctx.with_progress_token(token);
678        }
679        if let Some(tx) = self.notification_tx {
680            ctx = ctx.with_notification_sender(tx);
681        }
682        if let Some(requester) = self.client_requester {
683            ctx = ctx.with_client_requester(requester);
684        }
685        ctx
686    }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    /// Cancelling through a token must be visible to clones of the context.
    #[test]
    fn test_cancellation_shared_with_clones() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let clone = ctx.clone();

        ctx.cancellation_token().cancel();

        assert!(ctx.is_cancelled());
        assert!(clone.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway")).await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        // No progress token - should be a no-op
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // Channel should be empty
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_progress_sync_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(7))
            .with_notification_sender(tx);

        ctx.report_progress_sync(1.0, None, None);

        match rx.try_recv().unwrap() {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 1.0);
                assert_eq!(params.total, None);
                assert!(params.message.is_none());
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    /// Progress is best effort: a full channel drops the update instead of blocking.
    #[test]
    fn test_progress_dropped_when_channel_full() {
        let (tx, mut rx) = notification_channel(1);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(1))
            .with_notification_sender(tx);

        ctx.report_progress_sync(1.0, None, None);
        ctx.report_progress_sync(2.0, None, None);

        match rx.try_recv().unwrap() {
            ServerNotification::Progress(params) => assert_eq!(params.progress, 1.0),
            _ => panic!("Expected Progress notification"),
        }
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    /// `extensions_mut` must not leak modifications into clones sharing the map.
    #[test]
    fn test_extensions_copy_on_write() {
        let mut ctx = RequestContext::new(RequestId::Number(1));
        ctx.extensions_mut().insert(5u32);

        let mut clone = ctx.clone();
        clone.extensions_mut().insert(6u32);

        assert_eq!(ctx.extension::<u32>(), Some(&5));
        assert_eq!(clone.extension::<u32>(), Some(&6));
        assert!(!ctx.extensions().contains::<String>());
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// A closed outgoing channel must surface as an error, not a hang.
    #[tokio::test]
    async fn test_channel_requester_fails_when_channel_closed() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let (request_tx, rx) = outgoing_request_channel(1);
        drop(rx);
        let requester = ChannelClientRequester::new(request_tx);

        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
        let err = requester.sample(params).await.unwrap_err();
        assert!(err.to_string().contains("Failed to send request"));
    }

    /// Dropping the response sender without replying must fail the request.
    #[tokio::test]
    async fn test_channel_requester_fails_when_response_dropped() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let (request_tx, mut rx) = outgoing_request_channel(1);
        let requester = ChannelClientRequester::new(request_tx);

        let responder = tokio::spawn(async move {
            // Receive the request and drop it (and its response sender) unanswered.
            drop(rx.recv().await);
        });

        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
        let err = requester.sample(params).await.unwrap_err();
        assert!(err.to_string().contains("Failed to receive response"));
        responder.await.unwrap();
    }

    /// The request carries the expected method and ID, and a client-side error
    /// is passed back to the caller unchanged.
    #[tokio::test]
    async fn test_channel_requester_forwards_client_error() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let (request_tx, mut rx) = outgoing_request_channel(1);
        let requester = ChannelClientRequester::new(request_tx);

        let responder = tokio::spawn(async move {
            let request = rx.recv().await.unwrap();
            assert_eq!(request.method, "sampling/createMessage");
            assert_eq!(request.id, RequestId::Number(1));
            let _ = request
                .response_tx
                .send(Err(Error::Internal("client refused".to_string())));
        });

        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
        let err = requester.sample(params).await.unwrap_err();
        assert!(err.to_string().contains("client refused"));
        responder.await.unwrap();
    }
}