Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: Some(ElicitMode::Form),
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71
72use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81    ElicitUrlParams, LogLevel, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84/// A notification to be sent to the client
85#[derive(Debug, Clone)]
86pub enum ServerNotification {
87    /// Progress update for a request
88    Progress(ProgressParams),
89    /// Log message notification
90    LogMessage(LoggingMessageParams),
91    /// A subscribed resource has been updated
92    ResourceUpdated {
93        /// The URI of the updated resource
94        uri: String,
95    },
96    /// The list of available resources has changed
97    ResourcesListChanged,
98    /// The list of available tools has changed
99    ToolsListChanged,
100    /// The list of available prompts has changed
101    PromptsListChanged,
102    /// Task status has changed
103    TaskStatusChanged(crate::protocol::TaskStatusParams),
104}
105
106/// Sender for server notifications
107pub type NotificationSender = mpsc::Sender<ServerNotification>;
108
109/// Receiver for server notifications
110pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
111
112/// Create a new notification channel
113pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
114    mpsc::channel(buffer)
115}
116
117// =============================================================================
118// Client Requests (Server -> Client)
119// =============================================================================
120
121/// Trait for sending requests from server to client
122///
123/// This enables bidirectional communication where the server can request
124/// actions from the client, such as sampling (LLM requests) and elicitation
125/// (user input requests).
126#[async_trait]
127pub trait ClientRequester: Send + Sync {
128    /// Send a sampling request to the client
129    ///
130    /// Returns the LLM completion result from the client.
131    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
132
133    /// Send an elicitation request to the client
134    ///
135    /// This requests user input from the client. The request can be either
136    /// form-based (structured input) or URL-based (redirect to external URL).
137    ///
138    /// Returns the elicitation result with the user's action and any submitted data.
139    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
140}
141
142/// A clonable handle to a client requester
143pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
144
145/// Outgoing request to be sent to the client
146#[derive(Debug)]
147pub struct OutgoingRequest {
148    /// The JSON-RPC request ID
149    pub id: RequestId,
150    /// The method name
151    pub method: String,
152    /// The request parameters as JSON
153    pub params: serde_json::Value,
154    /// Channel to send the response back
155    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
156}
157
158/// Sender for outgoing requests to the client
159pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
160
161/// Receiver for outgoing requests (used by transport)
162pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
163
164/// Create a new outgoing request channel
165pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
166    mpsc::channel(buffer)
167}
168
169/// A client requester implementation that sends requests through a channel
170#[derive(Clone)]
171pub struct ChannelClientRequester {
172    request_tx: OutgoingRequestSender,
173    next_id: Arc<AtomicI64>,
174}
175
176impl ChannelClientRequester {
177    /// Create a new channel-based client requester
178    pub fn new(request_tx: OutgoingRequestSender) -> Self {
179        Self {
180            request_tx,
181            next_id: Arc::new(AtomicI64::new(1)),
182        }
183    }
184
185    fn next_request_id(&self) -> RequestId {
186        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
187        RequestId::Number(id)
188    }
189}
190
191#[async_trait]
192impl ClientRequester for ChannelClientRequester {
193    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
194        let id = self.next_request_id();
195        let params_json = serde_json::to_value(&params)
196            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
197
198        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
199
200        let request = OutgoingRequest {
201            id: id.clone(),
202            method: "sampling/createMessage".to_string(),
203            params: params_json,
204            response_tx,
205        };
206
207        self.request_tx
208            .send(request)
209            .await
210            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
211
212        let response = response_rx.await.map_err(|_| {
213            Error::Internal("Failed to receive response: channel closed".to_string())
214        })??;
215
216        serde_json::from_value(response)
217            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
218    }
219
220    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
221        let id = self.next_request_id();
222        let params_json = serde_json::to_value(&params)
223            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
224
225        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
226
227        let request = OutgoingRequest {
228            id: id.clone(),
229            method: "elicitation/create".to_string(),
230            params: params_json,
231            response_tx,
232        };
233
234        self.request_tx
235            .send(request)
236            .await
237            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
238
239        let response = response_rx.await.map_err(|_| {
240            Error::Internal("Failed to receive response: channel closed".to_string())
241        })??;
242
243        serde_json::from_value(response)
244            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
245    }
246}
247
248/// Context for a request, providing progress, cancellation, and client request support
249#[derive(Clone)]
250pub struct RequestContext {
251    /// The request ID
252    request_id: RequestId,
253    /// Progress token (if provided by client)
254    progress_token: Option<ProgressToken>,
255    /// Cancellation flag
256    cancelled: Arc<AtomicBool>,
257    /// Channel for sending notifications
258    notification_tx: Option<NotificationSender>,
259    /// Handle for sending requests to the client (for sampling, etc.)
260    client_requester: Option<ClientRequesterHandle>,
261    /// Extensions for passing data from router/middleware to handlers
262    extensions: Arc<Extensions>,
263    /// Minimum log level set by the client (shared with router for dynamic updates)
264    min_log_level: Option<Arc<RwLock<LogLevel>>>,
265}
266
267/// Type-erased extensions map for passing data to handlers.
268///
269/// Extensions allow router-level state and middleware-injected data to flow
270/// to tool handlers via the `Extension<T>` extractor.
271#[derive(Clone, Default)]
272pub struct Extensions {
273    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
274}
275
276impl Extensions {
277    /// Create an empty extensions map.
278    pub fn new() -> Self {
279        Self::default()
280    }
281
282    /// Insert a value into the extensions map.
283    ///
284    /// If a value of the same type already exists, it is replaced.
285    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
286        self.map.insert(std::any::TypeId::of::<T>(), Arc::new(val));
287    }
288
289    /// Get a reference to a value in the extensions map.
290    ///
291    /// Returns `None` if no value of the given type has been inserted.
292    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
293        self.map
294            .get(&std::any::TypeId::of::<T>())
295            .and_then(|val| val.downcast_ref::<T>())
296    }
297
298    /// Check if the extensions map contains a value of the given type.
299    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
300        self.map.contains_key(&std::any::TypeId::of::<T>())
301    }
302
303    /// Merge another extensions map into this one.
304    ///
305    /// Values from `other` will overwrite existing values of the same type.
306    pub fn merge(&mut self, other: &Extensions) {
307        for (k, v) in &other.map {
308            self.map.insert(*k, v.clone());
309        }
310    }
311}
312
313impl std::fmt::Debug for Extensions {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("Extensions")
316            .field("len", &self.map.len())
317            .finish()
318    }
319}
320
321impl std::fmt::Debug for RequestContext {
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        f.debug_struct("RequestContext")
324            .field("request_id", &self.request_id)
325            .field("progress_token", &self.progress_token)
326            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
327            .finish()
328    }
329}
330
331impl RequestContext {
332    /// Create a new request context
333    pub fn new(request_id: RequestId) -> Self {
334        Self {
335            request_id,
336            progress_token: None,
337            cancelled: Arc::new(AtomicBool::new(false)),
338            notification_tx: None,
339            client_requester: None,
340            extensions: Arc::new(Extensions::new()),
341            min_log_level: None,
342        }
343    }
344
345    /// Set the progress token
346    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
347        self.progress_token = Some(token);
348        self
349    }
350
351    /// Set the notification sender
352    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
353        self.notification_tx = Some(tx);
354        self
355    }
356
357    /// Set the minimum log level for filtering outgoing log notifications
358    ///
359    /// This is shared with the router so that `logging/setLevel` updates
360    /// are immediately visible to all request contexts.
361    pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
362        self.min_log_level = Some(level);
363        self
364    }
365
366    /// Set the client requester for server-to-client requests
367    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
368        self.client_requester = Some(requester);
369        self
370    }
371
372    /// Set the extensions for this request context.
373    ///
374    /// Extensions allow router-level state and middleware data to flow to handlers.
375    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
376        self.extensions = extensions;
377        self
378    }
379
380    /// Get a reference to a value from the extensions map.
381    ///
382    /// Returns `None` if no value of the given type has been inserted.
383    ///
384    /// # Example
385    ///
386    /// ```rust,ignore
387    /// #[derive(Clone)]
388    /// struct CurrentUser { id: String }
389    ///
390    /// // In a handler:
391    /// if let Some(user) = ctx.extension::<CurrentUser>() {
392    ///     println!("User: {}", user.id);
393    /// }
394    /// ```
395    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
396        self.extensions.get::<T>()
397    }
398
399    /// Get a mutable reference to the extensions.
400    ///
401    /// This allows middleware to insert data that handlers can access via
402    /// the `Extension<T>` extractor.
403    pub fn extensions_mut(&mut self) -> &mut Extensions {
404        Arc::make_mut(&mut self.extensions)
405    }
406
407    /// Get a reference to the extensions.
408    pub fn extensions(&self) -> &Extensions {
409        &self.extensions
410    }
411
412    /// Get the request ID
413    pub fn request_id(&self) -> &RequestId {
414        &self.request_id
415    }
416
417    /// Get the progress token (if any)
418    pub fn progress_token(&self) -> Option<&ProgressToken> {
419        self.progress_token.as_ref()
420    }
421
422    /// Check if the request has been cancelled
423    pub fn is_cancelled(&self) -> bool {
424        self.cancelled.load(Ordering::Relaxed)
425    }
426
427    /// Mark the request as cancelled
428    pub fn cancel(&self) {
429        self.cancelled.store(true, Ordering::Relaxed);
430    }
431
432    /// Get a cancellation token that can be shared
433    pub fn cancellation_token(&self) -> CancellationToken {
434        CancellationToken {
435            cancelled: self.cancelled.clone(),
436        }
437    }
438
439    /// Report progress to the client
440    ///
441    /// This is a no-op if no progress token was provided or no notification sender is configured.
442    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
443        let Some(token) = &self.progress_token else {
444            return;
445        };
446        let Some(tx) = &self.notification_tx else {
447            return;
448        };
449
450        let params = ProgressParams {
451            progress_token: token.clone(),
452            progress,
453            total,
454            message: message.map(|s| s.to_string()),
455            meta: None,
456        };
457
458        // Best effort - don't block if channel is full
459        let _ = tx.try_send(ServerNotification::Progress(params));
460    }
461
462    /// Report progress synchronously (non-async version)
463    ///
464    /// This is a no-op if no progress token was provided or no notification sender is configured.
465    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
466        let Some(token) = &self.progress_token else {
467            return;
468        };
469        let Some(tx) = &self.notification_tx else {
470            return;
471        };
472
473        let params = ProgressParams {
474            progress_token: token.clone(),
475            progress,
476            total,
477            message: message.map(|s| s.to_string()),
478            meta: None,
479        };
480
481        let _ = tx.try_send(ServerNotification::Progress(params));
482    }
483
484    /// Send a log message notification to the client
485    ///
486    /// This is a no-op if no notification sender is configured.
487    ///
488    /// # Example
489    ///
490    /// ```rust,ignore
491    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
492    ///
493    /// async fn my_tool(ctx: RequestContext) {
494    ///     ctx.send_log(
495    ///         LoggingMessageParams::new(LogLevel::Info, serde_json::json!("Processing..."))
496    ///             .with_logger("my-tool")
497    ///     );
498    /// }
499    /// ```
500    pub fn send_log(&self, params: LoggingMessageParams) {
501        let Some(tx) = &self.notification_tx else {
502            return;
503        };
504
505        // Filter by minimum log level set via logging/setLevel
506        // LogLevel derives Ord with Emergency < Alert < ... < Debug,
507        // so a message passes if its severity is at least the minimum
508        // (i.e., its ordinal is <= the minimum level's ordinal).
509        if let Some(min_level) = &self.min_log_level
510            && let Ok(min) = min_level.read()
511            && params.level > *min
512        {
513            return;
514        }
515
516        let _ = tx.try_send(ServerNotification::LogMessage(params));
517    }
518
519    /// Check if sampling is available
520    ///
521    /// Returns true if a client requester is configured and the transport
522    /// supports bidirectional communication.
523    pub fn can_sample(&self) -> bool {
524        self.client_requester.is_some()
525    }
526
527    /// Request an LLM completion from the client
528    ///
529    /// This sends a `sampling/createMessage` request to the client and waits
530    /// for the response. The client is expected to forward this to an LLM
531    /// and return the result.
532    ///
533    /// Returns an error if sampling is not available (no client requester configured).
534    ///
535    /// # Example
536    ///
537    /// ```rust,ignore
538    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
539    ///
540    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
541    ///     let params = CreateMessageParams::new(
542    ///         vec![SamplingMessage::user("Summarize: ...")],
543    ///         500,
544    ///     );
545    ///
546    ///     let result = ctx.sample(params).await?;
547    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
548    /// }
549    /// ```
550    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
551        let requester = self.client_requester.as_ref().ok_or_else(|| {
552            Error::Internal("Sampling not available: no client requester configured".to_string())
553        })?;
554
555        requester.sample(params).await
556    }
557
558    /// Check if elicitation is available
559    ///
560    /// Returns true if a client requester is configured and the transport
561    /// supports bidirectional communication. Note that this only checks if
562    /// the mechanism is available, not whether the client supports elicitation.
563    pub fn can_elicit(&self) -> bool {
564        self.client_requester.is_some()
565    }
566
567    /// Request user input via a form from the client
568    ///
569    /// This sends an `elicitation/create` request to the client with a form schema.
570    /// The client renders the form to the user and returns their response.
571    ///
572    /// Returns an error if elicitation is not available (no client requester configured).
573    ///
574    /// # Example
575    ///
576    /// ```rust,ignore
577    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
578    ///
579    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
580    ///     let params = ElicitFormParams {
581    ///         mode: Some(ElicitMode::Form),
582    ///         message: "Please enter your details".to_string(),
583    ///         requested_schema: ElicitFormSchema::new()
584    ///             .string_field("name", Some("Your name"), true),
585    ///         meta: None,
586    ///     };
587    ///
588    ///     let result = ctx.elicit_form(params).await?;
589    ///     match result.action {
590    ///         ElicitAction::Accept => {
591    ///             // Use result.content
592    ///             Ok(CallToolResult::text("Got your input!"))
593    ///         }
594    ///         _ => Ok(CallToolResult::text("User declined"))
595    ///     }
596    /// }
597    /// ```
598    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
599        let requester = self.client_requester.as_ref().ok_or_else(|| {
600            Error::Internal("Elicitation not available: no client requester configured".to_string())
601        })?;
602
603        requester.elicit(ElicitRequestParams::Form(params)).await
604    }
605
606    /// Request user input via URL redirect from the client
607    ///
608    /// This sends an `elicitation/create` request to the client with a URL.
609    /// The client directs the user to the URL for out-of-band input collection.
610    /// The server receives the result via a callback notification.
611    ///
612    /// Returns an error if elicitation is not available (no client requester configured).
613    ///
614    /// # Example
615    ///
616    /// ```rust,ignore
617    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
618    ///
619    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
620    ///     let params = ElicitUrlParams {
621    ///         mode: Some(ElicitMode::Url),
622    ///         elicitation_id: "unique-id-123".to_string(),
623    ///         message: "Please authorize via the link".to_string(),
624    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
625    ///         meta: None,
626    ///     };
627    ///
628    ///     let result = ctx.elicit_url(params).await?;
629    ///     match result.action {
630    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
631    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
632    ///     }
633    /// }
634    /// ```
635    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
636        let requester = self.client_requester.as_ref().ok_or_else(|| {
637            Error::Internal("Elicitation not available: no client requester configured".to_string())
638        })?;
639
640        requester.elicit(ElicitRequestParams::Url(params)).await
641    }
642
643    /// Request simple confirmation from the user.
644    ///
645    /// This is a convenience method for simple yes/no confirmation dialogs.
646    /// It creates an elicitation form with a single boolean "confirm" field
647    /// and returns `true` if the user accepts, `false` otherwise.
648    ///
649    /// Returns an error if elicitation is not available (no client requester configured).
650    ///
651    /// # Example
652    ///
653    /// ```rust,ignore
654    /// use tower_mcp::{RequestContext, CallToolResult};
655    ///
656    /// async fn delete_item(ctx: RequestContext) -> Result<CallToolResult> {
657    ///     let confirmed = ctx.confirm("Are you sure you want to delete this item?").await?;
658    ///     if confirmed {
659    ///         // Perform deletion
660    ///         Ok(CallToolResult::text("Item deleted"))
661    ///     } else {
662    ///         Ok(CallToolResult::text("Deletion cancelled"))
663    ///     }
664    /// }
665    /// ```
666    pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
667        use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
668
669        let params = ElicitFormParams {
670            mode: Some(ElicitMode::Form),
671            message: message.into(),
672            requested_schema: ElicitFormSchema::new().boolean_field_with_default(
673                "confirm",
674                Some("Confirm this action"),
675                true,
676                false,
677            ),
678            meta: None,
679        };
680
681        let result = self.elicit_form(params).await?;
682        Ok(result.action == ElicitAction::Accept)
683    }
684}
685
686/// A token that can be used to check for cancellation
687#[derive(Clone, Debug)]
688pub struct CancellationToken {
689    cancelled: Arc<AtomicBool>,
690}
691
692impl CancellationToken {
693    /// Check if cancellation has been requested
694    pub fn is_cancelled(&self) -> bool {
695        self.cancelled.load(Ordering::Relaxed)
696    }
697
698    /// Request cancellation
699    pub fn cancel(&self) {
700        self.cancelled.store(true, Ordering::Relaxed);
701    }
702}
703
704/// Builder for creating request contexts
705#[derive(Default)]
706pub struct RequestContextBuilder {
707    request_id: Option<RequestId>,
708    progress_token: Option<ProgressToken>,
709    notification_tx: Option<NotificationSender>,
710    client_requester: Option<ClientRequesterHandle>,
711    min_log_level: Option<Arc<RwLock<LogLevel>>>,
712}
713
714impl RequestContextBuilder {
715    /// Create a new builder
716    pub fn new() -> Self {
717        Self::default()
718    }
719
720    /// Set the request ID
721    pub fn request_id(mut self, id: RequestId) -> Self {
722        self.request_id = Some(id);
723        self
724    }
725
726    /// Set the progress token
727    pub fn progress_token(mut self, token: ProgressToken) -> Self {
728        self.progress_token = Some(token);
729        self
730    }
731
732    /// Set the notification sender
733    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
734        self.notification_tx = Some(tx);
735        self
736    }
737
738    /// Set the client requester for server-to-client requests
739    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
740        self.client_requester = Some(requester);
741        self
742    }
743
744    /// Set the minimum log level for filtering
745    pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
746        self.min_log_level = Some(level);
747        self
748    }
749
750    /// Build the request context
751    ///
752    /// Panics if request_id is not set.
753    pub fn build(self) -> RequestContext {
754        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
755        if let Some(token) = self.progress_token {
756            ctx = ctx.with_progress_token(token);
757        }
758        if let Some(tx) = self.notification_tx {
759            ctx = ctx.with_notification_sender(tx);
760        }
761        if let Some(requester) = self.client_requester {
762            ctx = ctx.with_client_requester(requester);
763        }
764        if let Some(level) = self.min_log_level {
765            ctx = ctx.with_min_log_level(level);
766        }
767        ctx
768    }
769}
770
771#[cfg(test)]
772mod tests {
773    use super::*;
774
775    #[test]
776    fn test_cancellation() {
777        let ctx = RequestContext::new(RequestId::Number(1));
778        assert!(!ctx.is_cancelled());
779
780        let token = ctx.cancellation_token();
781        assert!(!token.is_cancelled());
782
783        ctx.cancel();
784        assert!(ctx.is_cancelled());
785        assert!(token.is_cancelled());
786    }
787
788    #[tokio::test]
789    async fn test_progress_reporting() {
790        let (tx, mut rx) = notification_channel(10);
791
792        let ctx = RequestContext::new(RequestId::Number(1))
793            .with_progress_token(ProgressToken::Number(42))
794            .with_notification_sender(tx);
795
796        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
797            .await;
798
799        let notification = rx.recv().await.unwrap();
800        match notification {
801            ServerNotification::Progress(params) => {
802                assert_eq!(params.progress, 50.0);
803                assert_eq!(params.total, Some(100.0));
804                assert_eq!(params.message.as_deref(), Some("Halfway"));
805            }
806            _ => panic!("Expected Progress notification"),
807        }
808    }
809
810    #[tokio::test]
811    async fn test_progress_no_token() {
812        let (tx, mut rx) = notification_channel(10);
813
814        // No progress token - should be a no-op
815        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
816
817        ctx.report_progress(50.0, Some(100.0), None).await;
818
819        // Channel should be empty
820        assert!(rx.try_recv().is_err());
821    }
822
823    #[test]
824    fn test_builder() {
825        let (tx, _rx) = notification_channel(10);
826
827        let ctx = RequestContextBuilder::new()
828            .request_id(RequestId::String("req-1".to_string()))
829            .progress_token(ProgressToken::String("prog-1".to_string()))
830            .notification_sender(tx)
831            .build();
832
833        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
834        assert!(ctx.progress_token().is_some());
835    }
836
837    #[test]
838    fn test_can_sample_without_requester() {
839        let ctx = RequestContext::new(RequestId::Number(1));
840        assert!(!ctx.can_sample());
841    }
842
843    #[test]
844    fn test_can_sample_with_requester() {
845        let (request_tx, _rx) = outgoing_request_channel(10);
846        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
847
848        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
849        assert!(ctx.can_sample());
850    }
851
852    #[tokio::test]
853    async fn test_sample_without_requester_fails() {
854        use crate::protocol::{CreateMessageParams, SamplingMessage};
855
856        let ctx = RequestContext::new(RequestId::Number(1));
857        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
858
859        let result = ctx.sample(params).await;
860        assert!(result.is_err());
861        assert!(
862            result
863                .unwrap_err()
864                .to_string()
865                .contains("Sampling not available")
866        );
867    }
868
869    #[test]
870    fn test_builder_with_client_requester() {
871        let (request_tx, _rx) = outgoing_request_channel(10);
872        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
873
874        let ctx = RequestContextBuilder::new()
875            .request_id(RequestId::Number(1))
876            .client_requester(requester)
877            .build();
878
879        assert!(ctx.can_sample());
880    }
881
882    #[test]
883    fn test_can_elicit_without_requester() {
884        let ctx = RequestContext::new(RequestId::Number(1));
885        assert!(!ctx.can_elicit());
886    }
887
888    #[test]
889    fn test_can_elicit_with_requester() {
890        let (request_tx, _rx) = outgoing_request_channel(10);
891        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
892
893        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
894        assert!(ctx.can_elicit());
895    }
896
897    #[tokio::test]
898    async fn test_elicit_form_without_requester_fails() {
899        use crate::protocol::{ElicitFormSchema, ElicitMode};
900
901        let ctx = RequestContext::new(RequestId::Number(1));
902        let params = ElicitFormParams {
903            mode: Some(ElicitMode::Form),
904            message: "Enter details".to_string(),
905            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
906            meta: None,
907        };
908
909        let result = ctx.elicit_form(params).await;
910        assert!(result.is_err());
911        assert!(
912            result
913                .unwrap_err()
914                .to_string()
915                .contains("Elicitation not available")
916        );
917    }
918
919    #[tokio::test]
920    async fn test_elicit_url_without_requester_fails() {
921        use crate::protocol::ElicitMode;
922
923        let ctx = RequestContext::new(RequestId::Number(1));
924        let params = ElicitUrlParams {
925            mode: Some(ElicitMode::Url),
926            elicitation_id: "test-123".to_string(),
927            message: "Please authorize".to_string(),
928            url: "https://example.com/auth".to_string(),
929            meta: None,
930        };
931
932        let result = ctx.elicit_url(params).await;
933        assert!(result.is_err());
934        assert!(
935            result
936                .unwrap_err()
937                .to_string()
938                .contains("Elicitation not available")
939        );
940    }
941
942    #[tokio::test]
943    async fn test_confirm_without_requester_fails() {
944        let ctx = RequestContext::new(RequestId::Number(1));
945
946        let result = ctx.confirm("Are you sure?").await;
947        assert!(result.is_err());
948        assert!(
949            result
950                .unwrap_err()
951                .to_string()
952                .contains("Elicitation not available")
953        );
954    }
955
956    #[tokio::test]
957    async fn test_send_log_filtered_by_level() {
958        let (tx, mut rx) = notification_channel(10);
959        let min_level = Arc::new(RwLock::new(LogLevel::Warning));
960
961        let ctx = RequestContext::new(RequestId::Number(1))
962            .with_notification_sender(tx)
963            .with_min_log_level(min_level.clone());
964
965        // Error is more severe than Warning — should pass through
966        ctx.send_log(LoggingMessageParams::new(
967            LogLevel::Error,
968            serde_json::Value::Null,
969        ));
970        let msg = rx.try_recv();
971        assert!(msg.is_ok(), "Error should pass through Warning filter");
972
973        // Warning is equal to min level — should pass through
974        ctx.send_log(LoggingMessageParams::new(
975            LogLevel::Warning,
976            serde_json::Value::Null,
977        ));
978        let msg = rx.try_recv();
979        assert!(msg.is_ok(), "Warning should pass through Warning filter");
980
981        // Info is less severe than Warning — should be filtered
982        ctx.send_log(LoggingMessageParams::new(
983            LogLevel::Info,
984            serde_json::Value::Null,
985        ));
986        let msg = rx.try_recv();
987        assert!(msg.is_err(), "Info should be filtered by Warning filter");
988
989        // Debug is less severe than Warning — should be filtered
990        ctx.send_log(LoggingMessageParams::new(
991            LogLevel::Debug,
992            serde_json::Value::Null,
993        ));
994        let msg = rx.try_recv();
995        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
996    }
997
998    #[tokio::test]
999    async fn test_send_log_level_updates_dynamically() {
1000        let (tx, mut rx) = notification_channel(10);
1001        let min_level = Arc::new(RwLock::new(LogLevel::Error));
1002
1003        let ctx = RequestContext::new(RequestId::Number(1))
1004            .with_notification_sender(tx)
1005            .with_min_log_level(min_level.clone());
1006
1007        // Info should be filtered at Error level
1008        ctx.send_log(LoggingMessageParams::new(
1009            LogLevel::Info,
1010            serde_json::Value::Null,
1011        ));
1012        assert!(
1013            rx.try_recv().is_err(),
1014            "Info should be filtered at Error level"
1015        );
1016
1017        // Dynamically update to Debug (most permissive)
1018        *min_level.write().unwrap() = LogLevel::Debug;
1019
1020        // Now Info should pass through
1021        ctx.send_log(LoggingMessageParams::new(
1022            LogLevel::Info,
1023            serde_json::Value::Null,
1024        ));
1025        assert!(
1026            rx.try_recv().is_ok(),
1027            "Info should pass through after level changed to Debug"
1028        );
1029    }
1030
1031    #[tokio::test]
1032    async fn test_send_log_no_min_level_sends_all() {
1033        let (tx, mut rx) = notification_channel(10);
1034
1035        // No min_log_level set — all messages should pass through
1036        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
1037
1038        ctx.send_log(LoggingMessageParams::new(
1039            LogLevel::Debug,
1040            serde_json::Value::Null,
1041        ));
1042        assert!(
1043            rx.try_recv().is_ok(),
1044            "Debug should pass when no min level is set"
1045        );
1046    }
1047}