1use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84#[derive(Debug, Clone)]
86pub enum ServerNotification {
87 Progress(ProgressParams),
89 LogMessage(LoggingMessageParams),
91 ResourceUpdated {
93 uri: String,
95 },
96 ResourcesListChanged,
98}
99
100pub type NotificationSender = mpsc::Sender<ServerNotification>;
102
103pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108 mpsc::channel(buffer)
109}
110
111#[async_trait]
121pub trait ClientRequester: Send + Sync {
122 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
126
127 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
134}
135
136pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
139#[derive(Debug)]
141pub struct OutgoingRequest {
142 pub id: RequestId,
144 pub method: String,
146 pub params: serde_json::Value,
148 pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
150}
151
152pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
154
155pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160 mpsc::channel(buffer)
161}
162
163#[derive(Clone)]
165pub struct ChannelClientRequester {
166 request_tx: OutgoingRequestSender,
167 next_id: Arc<AtomicI64>,
168}
169
170impl ChannelClientRequester {
171 pub fn new(request_tx: OutgoingRequestSender) -> Self {
173 Self {
174 request_tx,
175 next_id: Arc::new(AtomicI64::new(1)),
176 }
177 }
178
179 fn next_request_id(&self) -> RequestId {
180 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181 RequestId::Number(id)
182 }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188 let id = self.next_request_id();
189 let params_json = serde_json::to_value(¶ms)
190 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194 let request = OutgoingRequest {
195 id: id.clone(),
196 method: "sampling/createMessage".to_string(),
197 params: params_json,
198 response_tx,
199 };
200
201 self.request_tx
202 .send(request)
203 .await
204 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206 let response = response_rx.await.map_err(|_| {
207 Error::Internal("Failed to receive response: channel closed".to_string())
208 })??;
209
210 serde_json::from_value(response)
211 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212 }
213
214 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215 let id = self.next_request_id();
216 let params_json = serde_json::to_value(¶ms)
217 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221 let request = OutgoingRequest {
222 id: id.clone(),
223 method: "elicitation/create".to_string(),
224 params: params_json,
225 response_tx,
226 };
227
228 self.request_tx
229 .send(request)
230 .await
231 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233 let response = response_rx.await.map_err(|_| {
234 Error::Internal("Failed to receive response: channel closed".to_string())
235 })??;
236
237 serde_json::from_value(response)
238 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239 }
240}
241
242#[derive(Clone)]
244pub struct RequestContext {
245 request_id: RequestId,
247 progress_token: Option<ProgressToken>,
249 cancelled: Arc<AtomicBool>,
251 notification_tx: Option<NotificationSender>,
253 client_requester: Option<ClientRequesterHandle>,
255 extensions: Arc<Extensions>,
257}
258
/// A type-indexed map that stores at most one value per concrete type.
///
/// Values are kept behind `Arc`, so cloning an `Extensions` (or merging one
/// into another) shares the stored values rather than copying them.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self {
            map: std::collections::HashMap::new(),
        }
    }

    /// Stores `val`, replacing any previously stored value of the same type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Returns a reference to the stored value of type `T`, or `None` if absent.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let stored = self.map.get(&std::any::TypeId::of::<T>())?;
        stored.downcast_ref::<T>()
    }

    /// Returns `true` if a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into `self`.
    ///
    /// On a type conflict the value from `other` wins. Values are shared via
    /// `Arc`, not deep-copied.
    pub fn merge(&mut self, other: &Extensions) {
        self.map
            .extend(other.map.iter().map(|(key, value)| (*key, Arc::clone(value))));
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the number of stored entries; the values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
312
313impl std::fmt::Debug for RequestContext {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("RequestContext")
316 .field("request_id", &self.request_id)
317 .field("progress_token", &self.progress_token)
318 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
319 .finish()
320 }
321}
322
323impl RequestContext {
324 pub fn new(request_id: RequestId) -> Self {
326 Self {
327 request_id,
328 progress_token: None,
329 cancelled: Arc::new(AtomicBool::new(false)),
330 notification_tx: None,
331 client_requester: None,
332 extensions: Arc::new(Extensions::new()),
333 }
334 }
335
336 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
338 self.progress_token = Some(token);
339 self
340 }
341
342 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
344 self.notification_tx = Some(tx);
345 self
346 }
347
348 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
350 self.client_requester = Some(requester);
351 self
352 }
353
354 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
358 self.extensions = extensions;
359 self
360 }
361
362 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
378 self.extensions.get::<T>()
379 }
380
381 pub fn extensions_mut(&mut self) -> &mut Extensions {
386 Arc::make_mut(&mut self.extensions)
387 }
388
389 pub fn extensions(&self) -> &Extensions {
391 &self.extensions
392 }
393
394 pub fn request_id(&self) -> &RequestId {
396 &self.request_id
397 }
398
399 pub fn progress_token(&self) -> Option<&ProgressToken> {
401 self.progress_token.as_ref()
402 }
403
404 pub fn is_cancelled(&self) -> bool {
406 self.cancelled.load(Ordering::Relaxed)
407 }
408
409 pub fn cancel(&self) {
411 self.cancelled.store(true, Ordering::Relaxed);
412 }
413
414 pub fn cancellation_token(&self) -> CancellationToken {
416 CancellationToken {
417 cancelled: self.cancelled.clone(),
418 }
419 }
420
421 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
425 let Some(token) = &self.progress_token else {
426 return;
427 };
428 let Some(tx) = &self.notification_tx else {
429 return;
430 };
431
432 let params = ProgressParams {
433 progress_token: token.clone(),
434 progress,
435 total,
436 message: message.map(|s| s.to_string()),
437 };
438
439 let _ = tx.try_send(ServerNotification::Progress(params));
441 }
442
443 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
447 let Some(token) = &self.progress_token else {
448 return;
449 };
450 let Some(tx) = &self.notification_tx else {
451 return;
452 };
453
454 let params = ProgressParams {
455 progress_token: token.clone(),
456 progress,
457 total,
458 message: message.map(|s| s.to_string()),
459 };
460
461 let _ = tx.try_send(ServerNotification::Progress(params));
462 }
463
464 pub fn send_log(&self, params: LoggingMessageParams) {
482 let Some(tx) = &self.notification_tx else {
483 return;
484 };
485
486 let _ = tx.try_send(ServerNotification::LogMessage(params));
487 }
488
489 pub fn can_sample(&self) -> bool {
494 self.client_requester.is_some()
495 }
496
497 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
521 let requester = self.client_requester.as_ref().ok_or_else(|| {
522 Error::Internal("Sampling not available: no client requester configured".to_string())
523 })?;
524
525 requester.sample(params).await
526 }
527
528 pub fn can_elicit(&self) -> bool {
534 self.client_requester.is_some()
535 }
536
537 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
569 let requester = self.client_requester.as_ref().ok_or_else(|| {
570 Error::Internal("Elicitation not available: no client requester configured".to_string())
571 })?;
572
573 requester.elicit(ElicitRequestParams::Form(params)).await
574 }
575
576 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
606 let requester = self.client_requester.as_ref().ok_or_else(|| {
607 Error::Internal("Elicitation not available: no client requester configured".to_string())
608 })?;
609
610 requester.elicit(ElicitRequestParams::Url(params)).await
611 }
612}
613
/// A handle onto a request's cancellation flag.
///
/// All clones share one flag, so cancelling through any handle is visible to
/// the others. Tokens obtained from a `RequestContext` also share the flag
/// with that context.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Reports whether cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation.
    ///
    /// Calling it repeatedly is harmless. Nothing in this type clears the flag
    /// again.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
631
632#[derive(Default)]
634pub struct RequestContextBuilder {
635 request_id: Option<RequestId>,
636 progress_token: Option<ProgressToken>,
637 notification_tx: Option<NotificationSender>,
638 client_requester: Option<ClientRequesterHandle>,
639}
640
641impl RequestContextBuilder {
642 pub fn new() -> Self {
644 Self::default()
645 }
646
647 pub fn request_id(mut self, id: RequestId) -> Self {
649 self.request_id = Some(id);
650 self
651 }
652
653 pub fn progress_token(mut self, token: ProgressToken) -> Self {
655 self.progress_token = Some(token);
656 self
657 }
658
659 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
661 self.notification_tx = Some(tx);
662 self
663 }
664
665 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
667 self.client_requester = Some(requester);
668 self
669 }
670
671 pub fn build(self) -> RequestContext {
675 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
676 if let Some(token) = self.progress_token {
677 ctx = ctx.with_progress_token(token);
678 }
679 if let Some(tx) = self.notification_tx {
680 ctx = ctx.with_notification_sender(tx);
681 }
682 if let Some(requester) = self.client_requester {
683 ctx = ctx.with_client_requester(requester);
684 }
685 ctx
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    /// Cancelling through a token must be observed by the context and its clones.
    #[test]
    fn test_cancellation_token_propagates_to_context() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let cloned = ctx.clone();
        let token = ctx.cancellation_token();

        token.cancel();
        assert!(ctx.is_cancelled());
        assert!(cloned.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        assert!(rx.try_recv().is_err());
    }

    /// The synchronous variant must deliver the same notification shape.
    #[test]
    fn test_report_progress_sync() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(7))
            .with_notification_sender(tx);

        ctx.report_progress_sync(1.0, None, None);

        match rx.try_recv().unwrap() {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 1.0);
                assert_eq!(params.total, None);
                assert!(params.message.is_none());
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[test]
    fn test_extensions_insert_get_merge() {
        let mut ext = Extensions::new();
        assert!(!ext.contains::<u32>());

        ext.insert(7u32);
        assert_eq!(ext.get::<u32>(), Some(&7));
        assert!(!ext.contains::<String>());

        let mut other = Extensions::new();
        other.insert(9u32);
        other.insert("hello".to_string());
        ext.merge(&other);

        // Entries from `other` overwrite existing ones of the same type.
        assert_eq!(ext.get::<u32>(), Some(&9));
        assert_eq!(ext.get::<String>().map(String::as_str), Some("hello"));
    }

    /// `extensions_mut` must copy shared extensions instead of mutating them in place.
    #[test]
    fn test_extensions_mut_does_not_affect_shared_extensions() {
        let mut shared = Extensions::new();
        shared.insert(1u8);
        let shared = Arc::new(shared);

        let mut ctx =
            RequestContext::new(RequestId::Number(1)).with_extensions(Arc::clone(&shared));
        ctx.extensions_mut().insert(2u16);

        assert_eq!(ctx.extension::<u8>(), Some(&1));
        assert_eq!(ctx.extension::<u16>(), Some(&2));
        assert!(!shared.contains::<u16>());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }
}