1use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84#[derive(Debug, Clone)]
86pub enum ServerNotification {
87 Progress(ProgressParams),
89 LogMessage(LoggingMessageParams),
91 ResourceUpdated {
93 uri: String,
95 },
96 ResourcesListChanged,
98}
99
100pub type NotificationSender = mpsc::Sender<ServerNotification>;
102
103pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108 mpsc::channel(buffer)
109}
110
111#[async_trait]
121pub trait ClientRequester: Send + Sync {
122 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
126
127 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
134}
135
136pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
139#[derive(Debug)]
141pub struct OutgoingRequest {
142 pub id: RequestId,
144 pub method: String,
146 pub params: serde_json::Value,
148 pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
150}
151
152pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
154
155pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160 mpsc::channel(buffer)
161}
162
163#[derive(Clone)]
165pub struct ChannelClientRequester {
166 request_tx: OutgoingRequestSender,
167 next_id: Arc<AtomicI64>,
168}
169
170impl ChannelClientRequester {
171 pub fn new(request_tx: OutgoingRequestSender) -> Self {
173 Self {
174 request_tx,
175 next_id: Arc::new(AtomicI64::new(1)),
176 }
177 }
178
179 fn next_request_id(&self) -> RequestId {
180 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181 RequestId::Number(id)
182 }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188 let id = self.next_request_id();
189 let params_json = serde_json::to_value(¶ms)
190 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194 let request = OutgoingRequest {
195 id: id.clone(),
196 method: "sampling/createMessage".to_string(),
197 params: params_json,
198 response_tx,
199 };
200
201 self.request_tx
202 .send(request)
203 .await
204 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206 let response = response_rx.await.map_err(|_| {
207 Error::Internal("Failed to receive response: channel closed".to_string())
208 })??;
209
210 serde_json::from_value(response)
211 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212 }
213
214 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215 let id = self.next_request_id();
216 let params_json = serde_json::to_value(¶ms)
217 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221 let request = OutgoingRequest {
222 id: id.clone(),
223 method: "elicitation/create".to_string(),
224 params: params_json,
225 response_tx,
226 };
227
228 self.request_tx
229 .send(request)
230 .await
231 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233 let response = response_rx.await.map_err(|_| {
234 Error::Internal("Failed to receive response: channel closed".to_string())
235 })??;
236
237 serde_json::from_value(response)
238 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239 }
240}
241
242#[derive(Clone)]
244pub struct RequestContext {
245 request_id: RequestId,
247 progress_token: Option<ProgressToken>,
249 cancelled: Arc<AtomicBool>,
251 notification_tx: Option<NotificationSender>,
253 client_requester: Option<ClientRequesterHandle>,
255}
256
257impl std::fmt::Debug for RequestContext {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("RequestContext")
260 .field("request_id", &self.request_id)
261 .field("progress_token", &self.progress_token)
262 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
263 .finish()
264 }
265}
266
267impl RequestContext {
268 pub fn new(request_id: RequestId) -> Self {
270 Self {
271 request_id,
272 progress_token: None,
273 cancelled: Arc::new(AtomicBool::new(false)),
274 notification_tx: None,
275 client_requester: None,
276 }
277 }
278
279 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
281 self.progress_token = Some(token);
282 self
283 }
284
285 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
287 self.notification_tx = Some(tx);
288 self
289 }
290
291 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
293 self.client_requester = Some(requester);
294 self
295 }
296
297 pub fn request_id(&self) -> &RequestId {
299 &self.request_id
300 }
301
302 pub fn progress_token(&self) -> Option<&ProgressToken> {
304 self.progress_token.as_ref()
305 }
306
307 pub fn is_cancelled(&self) -> bool {
309 self.cancelled.load(Ordering::Relaxed)
310 }
311
312 pub fn cancel(&self) {
314 self.cancelled.store(true, Ordering::Relaxed);
315 }
316
317 pub fn cancellation_token(&self) -> CancellationToken {
319 CancellationToken {
320 cancelled: self.cancelled.clone(),
321 }
322 }
323
324 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
328 let Some(token) = &self.progress_token else {
329 return;
330 };
331 let Some(tx) = &self.notification_tx else {
332 return;
333 };
334
335 let params = ProgressParams {
336 progress_token: token.clone(),
337 progress,
338 total,
339 message: message.map(|s| s.to_string()),
340 };
341
342 let _ = tx.try_send(ServerNotification::Progress(params));
344 }
345
346 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
350 let Some(token) = &self.progress_token else {
351 return;
352 };
353 let Some(tx) = &self.notification_tx else {
354 return;
355 };
356
357 let params = ProgressParams {
358 progress_token: token.clone(),
359 progress,
360 total,
361 message: message.map(|s| s.to_string()),
362 };
363
364 let _ = tx.try_send(ServerNotification::Progress(params));
365 }
366
367 pub fn send_log(&self, params: LoggingMessageParams) {
385 let Some(tx) = &self.notification_tx else {
386 return;
387 };
388
389 let _ = tx.try_send(ServerNotification::LogMessage(params));
390 }
391
392 pub fn can_sample(&self) -> bool {
397 self.client_requester.is_some()
398 }
399
400 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
424 let requester = self.client_requester.as_ref().ok_or_else(|| {
425 Error::Internal("Sampling not available: no client requester configured".to_string())
426 })?;
427
428 requester.sample(params).await
429 }
430
431 pub fn can_elicit(&self) -> bool {
437 self.client_requester.is_some()
438 }
439
440 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
472 let requester = self.client_requester.as_ref().ok_or_else(|| {
473 Error::Internal("Elicitation not available: no client requester configured".to_string())
474 })?;
475
476 requester.elicit(ElicitRequestParams::Form(params)).await
477 }
478
479 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
509 let requester = self.client_requester.as_ref().ok_or_else(|| {
510 Error::Internal("Elicitation not available: no client requester configured".to_string())
511 })?;
512
513 requester.elicit(ElicitRequestParams::Url(params)).await
514 }
515}
516
/// Cloneable handle to a request's cancellation flag.
///
/// Every clone, and the `RequestContext` it came from, observes the same flag.
/// Cancelling through any of them is visible to all of them.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested through any sharer of the flag.
    pub fn is_cancelled(&self) -> bool {
        AtomicBool::load(&self.cancelled, Ordering::Relaxed)
    }

    /// Requests cancellation; this is idempotent.
    pub fn cancel(&self) {
        AtomicBool::store(&self.cancelled, true, Ordering::Relaxed)
    }
}
534
535#[derive(Default)]
537pub struct RequestContextBuilder {
538 request_id: Option<RequestId>,
539 progress_token: Option<ProgressToken>,
540 notification_tx: Option<NotificationSender>,
541 client_requester: Option<ClientRequesterHandle>,
542}
543
544impl RequestContextBuilder {
545 pub fn new() -> Self {
547 Self::default()
548 }
549
550 pub fn request_id(mut self, id: RequestId) -> Self {
552 self.request_id = Some(id);
553 self
554 }
555
556 pub fn progress_token(mut self, token: ProgressToken) -> Self {
558 self.progress_token = Some(token);
559 self
560 }
561
562 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
564 self.notification_tx = Some(tx);
565 self
566 }
567
568 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
570 self.client_requester = Some(requester);
571 self
572 }
573
574 pub fn build(self) -> RequestContext {
578 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
579 if let Some(token) = self.progress_token {
580 ctx = ctx.with_progress_token(token);
581 }
582 if let Some(tx) = self.notification_tx {
583 ctx = ctx.with_notification_sender(tx);
584 }
585 if let Some(requester) = self.client_requester {
586 ctx = ctx.with_client_requester(requester);
587 }
588 ctx
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_via_token_reaches_clones() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let clone = ctx.clone();

        ctx.cancellation_token().cancel();

        assert!(ctx.is_cancelled());
        assert!(clone.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway")).await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[test]
    fn test_progress_reporting_sync() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(7))
            .with_notification_sender(tx);

        ctx.report_progress_sync(1.0, None, None);

        match rx.try_recv().expect("notification should be queued") {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 1.0);
                assert_eq!(params.total, None);
                assert!(params.message.is_none());
            }
            other => panic!("Expected Progress notification, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_progress_dropped_when_channel_full() {
        let (tx, mut rx) = notification_channel(1);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(1))
            .with_notification_sender(tx);

        // The second report must be dropped rather than block the handler.
        ctx.report_progress(1.0, None, None).await;
        ctx.report_progress(2.0, None, None).await;

        assert!(matches!(
            rx.try_recv(),
            Ok(ServerNotification::Progress(p)) if p.progress == 1.0
        ));
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[tokio::test]
    async fn test_channel_requester_forwards_requests_with_sequential_ids() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let (request_tx, mut request_rx) = outgoing_request_channel(10);
        let requester = ChannelClientRequester::new(request_tx);

        // Simulated client: records each request, then drops the responder unanswered.
        let client = async {
            let mut ids = Vec::new();
            for _ in 0..2 {
                let request = request_rx.recv().await.expect("request should arrive");
                assert_eq!(request.method, "sampling/createMessage");
                ids.push(request.id);
                drop(request.response_tx);
            }
            ids
        };
        let server = async {
            for _ in 0..2 {
                let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
                let err = requester
                    .sample(params)
                    .await
                    .err()
                    .expect("no reply was sent, so sampling must fail");
                assert!(err.to_string().contains("Failed to receive response"));
            }
        };

        let (ids, ()) = tokio::join!(client, server);
        assert_eq!(ids, vec![RequestId::Number(1), RequestId::Number(2)]);
    }

    #[tokio::test]
    async fn test_channel_requester_closed_channel_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let (request_tx, request_rx) = outgoing_request_channel(10);
        drop(request_rx);
        let requester = ChannelClientRequester::new(request_tx);

        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
        let err = requester
            .sample(params)
            .await
            .err()
            .expect("request channel is closed, so sampling must fail");
        assert!(err.to_string().contains("Failed to send request"));
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }
}