Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: Some(ElicitMode::Form),
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71
72use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81    ElicitUrlParams, LogLevel, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84/// A notification to be sent to the client
85#[derive(Debug, Clone)]
86#[non_exhaustive]
87pub enum ServerNotification {
88    /// Progress update for a request
89    Progress(ProgressParams),
90    /// Log message notification
91    LogMessage(LoggingMessageParams),
92    /// A subscribed resource has been updated
93    ResourceUpdated {
94        /// The URI of the updated resource
95        uri: String,
96    },
97    /// The list of available resources has changed
98    ResourcesListChanged,
99    /// The list of available tools has changed
100    ToolsListChanged,
101    /// The list of available prompts has changed
102    PromptsListChanged,
103    /// Task status has changed
104    TaskStatusChanged(crate::protocol::TaskStatusParams),
105}
106
107/// Sender for server notifications
108pub type NotificationSender = mpsc::Sender<ServerNotification>;
109
110/// Receiver for server notifications
111pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
112
113/// Create a new notification channel
114pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
115    mpsc::channel(buffer)
116}
117
118// =============================================================================
119// Client Requests (Server -> Client)
120// =============================================================================
121
122/// Trait for sending requests from server to client
123///
124/// This enables bidirectional communication where the server can request
125/// actions from the client, such as sampling (LLM requests) and elicitation
126/// (user input requests).
127#[async_trait]
128pub trait ClientRequester: Send + Sync {
129    /// Send a sampling request to the client
130    ///
131    /// Returns the LLM completion result from the client.
132    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
133
134    /// Send an elicitation request to the client
135    ///
136    /// This requests user input from the client. The request can be either
137    /// form-based (structured input) or URL-based (redirect to external URL).
138    ///
139    /// Returns the elicitation result with the user's action and any submitted data.
140    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
141}
142
143/// A clonable handle to a client requester
144pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
145
146/// Outgoing request to be sent to the client
147#[derive(Debug)]
148pub struct OutgoingRequest {
149    /// The JSON-RPC request ID
150    pub id: RequestId,
151    /// The method name
152    pub method: String,
153    /// The request parameters as JSON
154    pub params: serde_json::Value,
155    /// Channel to send the response back
156    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
157}
158
159/// Sender for outgoing requests to the client
160pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
161
162/// Receiver for outgoing requests (used by transport)
163pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
164
165/// Create a new outgoing request channel
166pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
167    mpsc::channel(buffer)
168}
169
170/// A client requester implementation that sends requests through a channel
171#[derive(Clone)]
172pub struct ChannelClientRequester {
173    request_tx: OutgoingRequestSender,
174    next_id: Arc<AtomicI64>,
175}
176
177impl ChannelClientRequester {
178    /// Create a new channel-based client requester
179    pub fn new(request_tx: OutgoingRequestSender) -> Self {
180        Self {
181            request_tx,
182            next_id: Arc::new(AtomicI64::new(1)),
183        }
184    }
185
186    fn next_request_id(&self) -> RequestId {
187        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
188        RequestId::Number(id)
189    }
190}
191
192#[async_trait]
193impl ClientRequester for ChannelClientRequester {
194    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
195        let id = self.next_request_id();
196        let params_json = serde_json::to_value(&params)
197            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
198
199        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
200
201        let request = OutgoingRequest {
202            id: id.clone(),
203            method: "sampling/createMessage".to_string(),
204            params: params_json,
205            response_tx,
206        };
207
208        self.request_tx
209            .send(request)
210            .await
211            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
212
213        let response = response_rx.await.map_err(|_| {
214            Error::Internal("Failed to receive response: channel closed".to_string())
215        })??;
216
217        serde_json::from_value(response)
218            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
219    }
220
221    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
222        let id = self.next_request_id();
223        let params_json = serde_json::to_value(&params)
224            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
225
226        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
227
228        let request = OutgoingRequest {
229            id: id.clone(),
230            method: "elicitation/create".to_string(),
231            params: params_json,
232            response_tx,
233        };
234
235        self.request_tx
236            .send(request)
237            .await
238            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
239
240        let response = response_rx.await.map_err(|_| {
241            Error::Internal("Failed to receive response: channel closed".to_string())
242        })??;
243
244        serde_json::from_value(response)
245            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
246    }
247}
248
249/// Context for a request, providing progress, cancellation, and client request support
250#[derive(Clone)]
251pub struct RequestContext {
252    /// The request ID
253    request_id: RequestId,
254    /// Progress token (if provided by client)
255    progress_token: Option<ProgressToken>,
256    /// Cancellation flag
257    cancelled: Arc<AtomicBool>,
258    /// Channel for sending notifications
259    notification_tx: Option<NotificationSender>,
260    /// Handle for sending requests to the client (for sampling, etc.)
261    client_requester: Option<ClientRequesterHandle>,
262    /// Extensions for passing data from router/middleware to handlers
263    extensions: Arc<Extensions>,
264    /// Minimum log level set by the client (shared with router for dynamic updates)
265    min_log_level: Option<Arc<RwLock<LogLevel>>>,
266}
267
268/// Type-erased extensions map for passing data to handlers.
269///
270/// Extensions allow router-level state and middleware-injected data to flow
271/// to tool handlers via the `Extension<T>` extractor.
272#[derive(Clone, Default)]
273pub struct Extensions {
274    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
275}
276
277impl Extensions {
278    /// Create an empty extensions map.
279    pub fn new() -> Self {
280        Self::default()
281    }
282
283    /// Insert a value into the extensions map.
284    ///
285    /// If a value of the same type already exists, it is replaced.
286    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
287        self.map.insert(std::any::TypeId::of::<T>(), Arc::new(val));
288    }
289
290    /// Get a reference to a value in the extensions map.
291    ///
292    /// Returns `None` if no value of the given type has been inserted.
293    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
294        self.map
295            .get(&std::any::TypeId::of::<T>())
296            .and_then(|val| val.downcast_ref::<T>())
297    }
298
299    /// Check if the extensions map contains a value of the given type.
300    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
301        self.map.contains_key(&std::any::TypeId::of::<T>())
302    }
303
304    /// Merge another extensions map into this one.
305    ///
306    /// Values from `other` will overwrite existing values of the same type.
307    pub fn merge(&mut self, other: &Extensions) {
308        for (k, v) in &other.map {
309            self.map.insert(*k, v.clone());
310        }
311    }
312
313    /// Returns the number of entries in the extensions map.
314    pub fn len(&self) -> usize {
315        self.map.len()
316    }
317
318    /// Returns `true` if the extensions map contains no entries.
319    pub fn is_empty(&self) -> bool {
320        self.map.is_empty()
321    }
322}
323
324impl std::fmt::Debug for Extensions {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        f.debug_struct("Extensions")
327            .field("len", &self.map.len())
328            .finish()
329    }
330}
331
332impl std::fmt::Debug for RequestContext {
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        f.debug_struct("RequestContext")
335            .field("request_id", &self.request_id)
336            .field("progress_token", &self.progress_token)
337            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
338            .finish()
339    }
340}
341
342impl RequestContext {
343    /// Create a new request context
344    pub fn new(request_id: RequestId) -> Self {
345        Self {
346            request_id,
347            progress_token: None,
348            cancelled: Arc::new(AtomicBool::new(false)),
349            notification_tx: None,
350            client_requester: None,
351            extensions: Arc::new(Extensions::new()),
352            min_log_level: None,
353        }
354    }
355
356    /// Set the progress token
357    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
358        self.progress_token = Some(token);
359        self
360    }
361
362    /// Set the notification sender
363    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
364        self.notification_tx = Some(tx);
365        self
366    }
367
368    /// Set the minimum log level for filtering outgoing log notifications
369    ///
370    /// This is shared with the router so that `logging/setLevel` updates
371    /// are immediately visible to all request contexts.
372    pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
373        self.min_log_level = Some(level);
374        self
375    }
376
377    /// Set the client requester for server-to-client requests
378    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
379        self.client_requester = Some(requester);
380        self
381    }
382
383    /// Set the extensions for this request context.
384    ///
385    /// Extensions allow router-level state and middleware data to flow to handlers.
386    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
387        self.extensions = extensions;
388        self
389    }
390
391    /// Get a reference to a value from the extensions map.
392    ///
393    /// Returns `None` if no value of the given type has been inserted.
394    ///
395    /// # Example
396    ///
397    /// ```rust,ignore
398    /// #[derive(Clone)]
399    /// struct CurrentUser { id: String }
400    ///
401    /// // In a handler:
402    /// if let Some(user) = ctx.extension::<CurrentUser>() {
403    ///     println!("User: {}", user.id);
404    /// }
405    /// ```
406    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
407        self.extensions.get::<T>()
408    }
409
410    /// Get a mutable reference to the extensions.
411    ///
412    /// This allows middleware to insert data that handlers can access via
413    /// the `Extension<T>` extractor.
414    pub fn extensions_mut(&mut self) -> &mut Extensions {
415        Arc::make_mut(&mut self.extensions)
416    }
417
418    /// Get a reference to the extensions.
419    pub fn extensions(&self) -> &Extensions {
420        &self.extensions
421    }
422
423    /// Get the request ID
424    pub fn request_id(&self) -> &RequestId {
425        &self.request_id
426    }
427
428    /// Get the progress token (if any)
429    pub fn progress_token(&self) -> Option<&ProgressToken> {
430        self.progress_token.as_ref()
431    }
432
433    /// Check if the request has been cancelled
434    pub fn is_cancelled(&self) -> bool {
435        self.cancelled.load(Ordering::Relaxed)
436    }
437
438    /// Mark the request as cancelled
439    pub fn cancel(&self) {
440        self.cancelled.store(true, Ordering::Relaxed);
441    }
442
443    /// Get a cancellation token that can be shared
444    pub fn cancellation_token(&self) -> CancellationToken {
445        CancellationToken {
446            cancelled: self.cancelled.clone(),
447        }
448    }
449
450    /// Report progress to the client
451    ///
452    /// This is a no-op if no progress token was provided or no notification sender is configured.
453    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
454        let Some(token) = &self.progress_token else {
455            return;
456        };
457        let Some(tx) = &self.notification_tx else {
458            return;
459        };
460
461        let params = ProgressParams {
462            progress_token: token.clone(),
463            progress,
464            total,
465            message: message.map(|s| s.to_string()),
466            meta: None,
467        };
468
469        // Best effort - don't block if channel is full
470        let _ = tx.try_send(ServerNotification::Progress(params));
471    }
472
473    /// Report progress synchronously (non-async version)
474    ///
475    /// This is a no-op if no progress token was provided or no notification sender is configured.
476    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
477        let Some(token) = &self.progress_token else {
478            return;
479        };
480        let Some(tx) = &self.notification_tx else {
481            return;
482        };
483
484        let params = ProgressParams {
485            progress_token: token.clone(),
486            progress,
487            total,
488            message: message.map(|s| s.to_string()),
489            meta: None,
490        };
491
492        let _ = tx.try_send(ServerNotification::Progress(params));
493    }
494
495    /// Send a log message notification to the client
496    ///
497    /// This is a no-op if no notification sender is configured.
498    ///
499    /// # Example
500    ///
501    /// ```rust,ignore
502    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
503    ///
504    /// async fn my_tool(ctx: RequestContext) {
505    ///     ctx.send_log(
506    ///         LoggingMessageParams::new(LogLevel::Info, serde_json::json!("Processing..."))
507    ///             .with_logger("my-tool")
508    ///     );
509    /// }
510    /// ```
511    pub fn send_log(&self, params: LoggingMessageParams) {
512        let Some(tx) = &self.notification_tx else {
513            return;
514        };
515
516        // Filter by minimum log level set via logging/setLevel
517        // LogLevel derives Ord with Emergency < Alert < ... < Debug,
518        // so a message passes if its severity is at least the minimum
519        // (i.e., its ordinal is <= the minimum level's ordinal).
520        if let Some(min_level) = &self.min_log_level
521            && let Ok(min) = min_level.read()
522            && params.level > *min
523        {
524            return;
525        }
526
527        let _ = tx.try_send(ServerNotification::LogMessage(params));
528    }
529
530    /// Check if sampling is available
531    ///
532    /// Returns true if a client requester is configured and the transport
533    /// supports bidirectional communication.
534    pub fn can_sample(&self) -> bool {
535        self.client_requester.is_some()
536    }
537
538    /// Request an LLM completion from the client
539    ///
540    /// This sends a `sampling/createMessage` request to the client and waits
541    /// for the response. The client is expected to forward this to an LLM
542    /// and return the result.
543    ///
544    /// Returns an error if sampling is not available (no client requester configured).
545    ///
546    /// # Example
547    ///
548    /// ```rust,ignore
549    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
550    ///
551    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
552    ///     let params = CreateMessageParams::new(
553    ///         vec![SamplingMessage::user("Summarize: ...")],
554    ///         500,
555    ///     );
556    ///
557    ///     let result = ctx.sample(params).await?;
558    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
559    /// }
560    /// ```
561    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
562        let requester = self.client_requester.as_ref().ok_or_else(|| {
563            Error::Internal("Sampling not available: no client requester configured".to_string())
564        })?;
565
566        requester.sample(params).await
567    }
568
569    /// Check if elicitation is available
570    ///
571    /// Returns true if a client requester is configured and the transport
572    /// supports bidirectional communication. Note that this only checks if
573    /// the mechanism is available, not whether the client supports elicitation.
574    pub fn can_elicit(&self) -> bool {
575        self.client_requester.is_some()
576    }
577
578    /// Request user input via a form from the client
579    ///
580    /// This sends an `elicitation/create` request to the client with a form schema.
581    /// The client renders the form to the user and returns their response.
582    ///
583    /// Returns an error if elicitation is not available (no client requester configured).
584    ///
585    /// # Example
586    ///
587    /// ```rust,ignore
588    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
589    ///
590    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
591    ///     let params = ElicitFormParams {
592    ///         mode: Some(ElicitMode::Form),
593    ///         message: "Please enter your details".to_string(),
594    ///         requested_schema: ElicitFormSchema::new()
595    ///             .string_field("name", Some("Your name"), true),
596    ///         meta: None,
597    ///     };
598    ///
599    ///     let result = ctx.elicit_form(params).await?;
600    ///     match result.action {
601    ///         ElicitAction::Accept => {
602    ///             // Use result.content
603    ///             Ok(CallToolResult::text("Got your input!"))
604    ///         }
605    ///         _ => Ok(CallToolResult::text("User declined"))
606    ///     }
607    /// }
608    /// ```
609    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
610        let requester = self.client_requester.as_ref().ok_or_else(|| {
611            Error::Internal("Elicitation not available: no client requester configured".to_string())
612        })?;
613
614        requester.elicit(ElicitRequestParams::Form(params)).await
615    }
616
617    /// Request user input via URL redirect from the client
618    ///
619    /// This sends an `elicitation/create` request to the client with a URL.
620    /// The client directs the user to the URL for out-of-band input collection.
621    /// The server receives the result via a callback notification.
622    ///
623    /// Returns an error if elicitation is not available (no client requester configured).
624    ///
625    /// # Example
626    ///
627    /// ```rust,ignore
628    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
629    ///
630    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
631    ///     let params = ElicitUrlParams {
632    ///         mode: Some(ElicitMode::Url),
633    ///         elicitation_id: "unique-id-123".to_string(),
634    ///         message: "Please authorize via the link".to_string(),
635    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
636    ///         meta: None,
637    ///     };
638    ///
639    ///     let result = ctx.elicit_url(params).await?;
640    ///     match result.action {
641    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
642    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
643    ///     }
644    /// }
645    /// ```
646    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
647        let requester = self.client_requester.as_ref().ok_or_else(|| {
648            Error::Internal("Elicitation not available: no client requester configured".to_string())
649        })?;
650
651        requester.elicit(ElicitRequestParams::Url(params)).await
652    }
653
654    /// Request simple confirmation from the user.
655    ///
656    /// This is a convenience method for simple yes/no confirmation dialogs.
657    /// It creates an elicitation form with a single boolean "confirm" field
658    /// and returns `true` if the user accepts, `false` otherwise.
659    ///
660    /// Returns an error if elicitation is not available (no client requester configured).
661    ///
662    /// # Example
663    ///
664    /// ```rust,ignore
665    /// use tower_mcp::{RequestContext, CallToolResult};
666    ///
667    /// async fn delete_item(ctx: RequestContext) -> Result<CallToolResult> {
668    ///     let confirmed = ctx.confirm("Are you sure you want to delete this item?").await?;
669    ///     if confirmed {
670    ///         // Perform deletion
671    ///         Ok(CallToolResult::text("Item deleted"))
672    ///     } else {
673    ///         Ok(CallToolResult::text("Deletion cancelled"))
674    ///     }
675    /// }
676    /// ```
677    pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
678        use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
679
680        let params = ElicitFormParams {
681            mode: Some(ElicitMode::Form),
682            message: message.into(),
683            requested_schema: ElicitFormSchema::new().boolean_field_with_default(
684                "confirm",
685                Some("Confirm this action"),
686                true,
687                false,
688            ),
689            meta: None,
690        };
691
692        let result = self.elicit_form(params).await?;
693        Ok(result.action == ElicitAction::Accept)
694    }
695}
696
697/// A token that can be used to check for cancellation
698#[derive(Clone, Debug)]
699pub struct CancellationToken {
700    cancelled: Arc<AtomicBool>,
701}
702
703impl CancellationToken {
704    /// Check if cancellation has been requested
705    pub fn is_cancelled(&self) -> bool {
706        self.cancelled.load(Ordering::Relaxed)
707    }
708
709    /// Request cancellation
710    pub fn cancel(&self) {
711        self.cancelled.store(true, Ordering::Relaxed);
712    }
713}
714
715/// Builder for creating request contexts
716#[derive(Default)]
717pub struct RequestContextBuilder {
718    request_id: Option<RequestId>,
719    progress_token: Option<ProgressToken>,
720    notification_tx: Option<NotificationSender>,
721    client_requester: Option<ClientRequesterHandle>,
722    min_log_level: Option<Arc<RwLock<LogLevel>>>,
723}
724
725impl RequestContextBuilder {
726    /// Create a new builder
727    pub fn new() -> Self {
728        Self::default()
729    }
730
731    /// Set the request ID
732    pub fn request_id(mut self, id: RequestId) -> Self {
733        self.request_id = Some(id);
734        self
735    }
736
737    /// Set the progress token
738    pub fn progress_token(mut self, token: ProgressToken) -> Self {
739        self.progress_token = Some(token);
740        self
741    }
742
743    /// Set the notification sender
744    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
745        self.notification_tx = Some(tx);
746        self
747    }
748
749    /// Set the client requester for server-to-client requests
750    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
751        self.client_requester = Some(requester);
752        self
753    }
754
755    /// Set the minimum log level for filtering
756    pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
757        self.min_log_level = Some(level);
758        self
759    }
760
761    /// Build the request context
762    ///
763    /// Panics if request_id is not set.
764    pub fn build(self) -> RequestContext {
765        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
766        if let Some(token) = self.progress_token {
767            ctx = ctx.with_progress_token(token);
768        }
769        if let Some(tx) = self.notification_tx {
770            ctx = ctx.with_notification_sender(tx);
771        }
772        if let Some(requester) = self.client_requester {
773            ctx = ctx.with_client_requester(requester);
774        }
775        if let Some(level) = self.min_log_level {
776            ctx = ctx.with_min_log_level(level);
777        }
778        ctx
779    }
780}
781
782#[cfg(test)]
783mod tests {
784    use super::*;
785
786    #[test]
787    fn test_cancellation() {
788        let ctx = RequestContext::new(RequestId::Number(1));
789        assert!(!ctx.is_cancelled());
790
791        let token = ctx.cancellation_token();
792        assert!(!token.is_cancelled());
793
794        ctx.cancel();
795        assert!(ctx.is_cancelled());
796        assert!(token.is_cancelled());
797    }
798
799    #[tokio::test]
800    async fn test_progress_reporting() {
801        let (tx, mut rx) = notification_channel(10);
802
803        let ctx = RequestContext::new(RequestId::Number(1))
804            .with_progress_token(ProgressToken::Number(42))
805            .with_notification_sender(tx);
806
807        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
808            .await;
809
810        let notification = rx.recv().await.unwrap();
811        match notification {
812            ServerNotification::Progress(params) => {
813                assert_eq!(params.progress, 50.0);
814                assert_eq!(params.total, Some(100.0));
815                assert_eq!(params.message.as_deref(), Some("Halfway"));
816            }
817            _ => panic!("Expected Progress notification"),
818        }
819    }
820
821    #[tokio::test]
822    async fn test_progress_no_token() {
823        let (tx, mut rx) = notification_channel(10);
824
825        // No progress token - should be a no-op
826        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
827
828        ctx.report_progress(50.0, Some(100.0), None).await;
829
830        // Channel should be empty
831        assert!(rx.try_recv().is_err());
832    }
833
834    #[test]
835    fn test_builder() {
836        let (tx, _rx) = notification_channel(10);
837
838        let ctx = RequestContextBuilder::new()
839            .request_id(RequestId::String("req-1".to_string()))
840            .progress_token(ProgressToken::String("prog-1".to_string()))
841            .notification_sender(tx)
842            .build();
843
844        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
845        assert!(ctx.progress_token().is_some());
846    }
847
848    #[test]
849    fn test_can_sample_without_requester() {
850        let ctx = RequestContext::new(RequestId::Number(1));
851        assert!(!ctx.can_sample());
852    }
853
854    #[test]
855    fn test_can_sample_with_requester() {
856        let (request_tx, _rx) = outgoing_request_channel(10);
857        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
858
859        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
860        assert!(ctx.can_sample());
861    }
862
863    #[tokio::test]
864    async fn test_sample_without_requester_fails() {
865        use crate::protocol::{CreateMessageParams, SamplingMessage};
866
867        let ctx = RequestContext::new(RequestId::Number(1));
868        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
869
870        let result = ctx.sample(params).await;
871        assert!(result.is_err());
872        assert!(
873            result
874                .unwrap_err()
875                .to_string()
876                .contains("Sampling not available")
877        );
878    }
879
880    #[test]
881    fn test_builder_with_client_requester() {
882        let (request_tx, _rx) = outgoing_request_channel(10);
883        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
884
885        let ctx = RequestContextBuilder::new()
886            .request_id(RequestId::Number(1))
887            .client_requester(requester)
888            .build();
889
890        assert!(ctx.can_sample());
891    }
892
893    #[test]
894    fn test_can_elicit_without_requester() {
895        let ctx = RequestContext::new(RequestId::Number(1));
896        assert!(!ctx.can_elicit());
897    }
898
899    #[test]
900    fn test_can_elicit_with_requester() {
901        let (request_tx, _rx) = outgoing_request_channel(10);
902        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
903
904        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
905        assert!(ctx.can_elicit());
906    }
907
908    #[tokio::test]
909    async fn test_elicit_form_without_requester_fails() {
910        use crate::protocol::{ElicitFormSchema, ElicitMode};
911
912        let ctx = RequestContext::new(RequestId::Number(1));
913        let params = ElicitFormParams {
914            mode: Some(ElicitMode::Form),
915            message: "Enter details".to_string(),
916            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
917            meta: None,
918        };
919
920        let result = ctx.elicit_form(params).await;
921        assert!(result.is_err());
922        assert!(
923            result
924                .unwrap_err()
925                .to_string()
926                .contains("Elicitation not available")
927        );
928    }
929
930    #[tokio::test]
931    async fn test_elicit_url_without_requester_fails() {
932        use crate::protocol::ElicitMode;
933
934        let ctx = RequestContext::new(RequestId::Number(1));
935        let params = ElicitUrlParams {
936            mode: Some(ElicitMode::Url),
937            elicitation_id: "test-123".to_string(),
938            message: "Please authorize".to_string(),
939            url: "https://example.com/auth".to_string(),
940            meta: None,
941        };
942
943        let result = ctx.elicit_url(params).await;
944        assert!(result.is_err());
945        assert!(
946            result
947                .unwrap_err()
948                .to_string()
949                .contains("Elicitation not available")
950        );
951    }
952
953    #[tokio::test]
954    async fn test_confirm_without_requester_fails() {
955        let ctx = RequestContext::new(RequestId::Number(1));
956
957        let result = ctx.confirm("Are you sure?").await;
958        assert!(result.is_err());
959        assert!(
960            result
961                .unwrap_err()
962                .to_string()
963                .contains("Elicitation not available")
964        );
965    }
966
967    #[tokio::test]
968    async fn test_send_log_filtered_by_level() {
969        let (tx, mut rx) = notification_channel(10);
970        let min_level = Arc::new(RwLock::new(LogLevel::Warning));
971
972        let ctx = RequestContext::new(RequestId::Number(1))
973            .with_notification_sender(tx)
974            .with_min_log_level(min_level.clone());
975
976        // Error is more severe than Warning — should pass through
977        ctx.send_log(LoggingMessageParams::new(
978            LogLevel::Error,
979            serde_json::Value::Null,
980        ));
981        let msg = rx.try_recv();
982        assert!(msg.is_ok(), "Error should pass through Warning filter");
983
984        // Warning is equal to min level — should pass through
985        ctx.send_log(LoggingMessageParams::new(
986            LogLevel::Warning,
987            serde_json::Value::Null,
988        ));
989        let msg = rx.try_recv();
990        assert!(msg.is_ok(), "Warning should pass through Warning filter");
991
992        // Info is less severe than Warning — should be filtered
993        ctx.send_log(LoggingMessageParams::new(
994            LogLevel::Info,
995            serde_json::Value::Null,
996        ));
997        let msg = rx.try_recv();
998        assert!(msg.is_err(), "Info should be filtered by Warning filter");
999
1000        // Debug is less severe than Warning — should be filtered
1001        ctx.send_log(LoggingMessageParams::new(
1002            LogLevel::Debug,
1003            serde_json::Value::Null,
1004        ));
1005        let msg = rx.try_recv();
1006        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
1007    }
1008
1009    #[tokio::test]
1010    async fn test_send_log_level_updates_dynamically() {
1011        let (tx, mut rx) = notification_channel(10);
1012        let min_level = Arc::new(RwLock::new(LogLevel::Error));
1013
1014        let ctx = RequestContext::new(RequestId::Number(1))
1015            .with_notification_sender(tx)
1016            .with_min_log_level(min_level.clone());
1017
1018        // Info should be filtered at Error level
1019        ctx.send_log(LoggingMessageParams::new(
1020            LogLevel::Info,
1021            serde_json::Value::Null,
1022        ));
1023        assert!(
1024            rx.try_recv().is_err(),
1025            "Info should be filtered at Error level"
1026        );
1027
1028        // Dynamically update to Debug (most permissive)
1029        *min_level.write().unwrap() = LogLevel::Debug;
1030
1031        // Now Info should pass through
1032        ctx.send_log(LoggingMessageParams::new(
1033            LogLevel::Info,
1034            serde_json::Value::Null,
1035        ));
1036        assert!(
1037            rx.try_recv().is_ok(),
1038            "Info should pass through after level changed to Debug"
1039        );
1040    }
1041
1042    #[tokio::test]
1043    async fn test_send_log_no_min_level_sends_all() {
1044        let (tx, mut rx) = notification_channel(10);
1045
1046        // No min_log_level set — all messages should pass through
1047        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
1048
1049        ctx.send_log(LoggingMessageParams::new(
1050            LogLevel::Debug,
1051            serde_json::Value::Null,
1052        ));
1053        assert!(
1054            rx.try_recv().is_ok(),
1055            "Debug should pass when no min level is set"
1056        );
1057    }
1058}