1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LogLevel, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
/// Notification emitted by server-side code for delivery over the
/// notification channel (see [`NotificationSender`]).
#[derive(Debug, Clone)]
pub enum ServerNotification {
    /// Progress update carrying the progress parameters.
    Progress(ProgressParams),
    /// Log message carrying the logging parameters (including its level).
    LogMessage(LoggingMessageParams),
    /// A single resource was updated.
    ResourceUpdated {
        /// URI identifying the resource that was updated.
        uri: String,
    },
    /// The set of resources changed; carries no payload.
    ResourcesListChanged,
    /// The set of tools changed; carries no payload.
    ToolsListChanged,
    /// The set of prompts changed; carries no payload.
    PromptsListChanged,
    /// A task's status changed; details are in the payload.
    TaskStatusChanged(crate::protocol::TaskStatusParams),
}
105
/// Sending half of a bounded tokio channel that carries [`ServerNotification`]s.
pub type NotificationSender = mpsc::Sender<ServerNotification>;

/// Receiving half of a bounded tokio channel that carries [`ServerNotification`]s.
pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
111
112pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
114 mpsc::channel(buffer)
115}
116
/// Asynchronous interface for issuing requests to the client and awaiting
/// the typed result.
///
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`
/// (see [`ClientRequesterHandle`]).
#[async_trait]
pub trait ClientRequester: Send + Sync {
    /// Issues a sampling request described by `params` and resolves to the
    /// resulting [`CreateMessageResult`], or an error if the request fails.
    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;

    /// Issues an elicitation request described by `params` and resolves to
    /// the resulting [`ElicitResult`], or an error if the request fails.
    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
}
141
/// Shared, type-erased handle to a [`ClientRequester`] implementation.
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
144
/// A request queued for delivery, bundled with the one-shot channel on which
/// its outcome is reported back to the issuer.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Identifier assigned to this request.
    pub id: RequestId,
    /// Method name of the request (e.g. `"sampling/createMessage"`).
    pub method: String,
    /// Request parameters, already serialized to JSON.
    pub params: serde_json::Value,
    /// One-shot sender used by whoever processes the request to deliver either
    /// the raw JSON result or an error. Dropping it without sending makes the
    /// waiting side observe a closed channel.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
157
/// Sending half of a bounded tokio channel that carries [`OutgoingRequest`]s.
pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;

/// Receiving half of a bounded tokio channel that carries [`OutgoingRequest`]s.
pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
163
164pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
166 mpsc::channel(buffer)
167}
168
/// [`ClientRequester`] implementation that forwards requests over an
/// [`OutgoingRequestSender`] channel.
///
/// Cloning is cheap: clones share the same channel and, because the counter
/// lives behind an `Arc`, the same request-id sequence.
#[derive(Clone)]
pub struct ChannelClientRequester {
    /// Channel on which outgoing requests are queued.
    request_tx: OutgoingRequestSender,
    /// Counter from which numeric request ids are drawn.
    next_id: Arc<AtomicI64>,
}
175
176impl ChannelClientRequester {
177 pub fn new(request_tx: OutgoingRequestSender) -> Self {
179 Self {
180 request_tx,
181 next_id: Arc::new(AtomicI64::new(1)),
182 }
183 }
184
185 fn next_request_id(&self) -> RequestId {
186 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
187 RequestId::Number(id)
188 }
189}
190
191#[async_trait]
192impl ClientRequester for ChannelClientRequester {
193 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
194 let id = self.next_request_id();
195 let params_json = serde_json::to_value(¶ms)
196 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
197
198 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
199
200 let request = OutgoingRequest {
201 id: id.clone(),
202 method: "sampling/createMessage".to_string(),
203 params: params_json,
204 response_tx,
205 };
206
207 self.request_tx
208 .send(request)
209 .await
210 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
211
212 let response = response_rx.await.map_err(|_| {
213 Error::Internal("Failed to receive response: channel closed".to_string())
214 })??;
215
216 serde_json::from_value(response)
217 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
218 }
219
220 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
221 let id = self.next_request_id();
222 let params_json = serde_json::to_value(¶ms)
223 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
224
225 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
226
227 let request = OutgoingRequest {
228 id: id.clone(),
229 method: "elicitation/create".to_string(),
230 params: params_json,
231 response_tx,
232 };
233
234 self.request_tx
235 .send(request)
236 .await
237 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
238
239 let response = response_rx.await.map_err(|_| {
240 Error::Internal("Failed to receive response: channel closed".to_string())
241 })??;
242
243 serde_json::from_value(response)
244 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
245 }
246}
247
/// Per-request state handed to request-handling code.
///
/// Cloning is shallow for the shared parts: the cancellation flag, the
/// extension map and the log-level threshold live behind `Arc`s, so clones
/// observe the same cancellation state.
#[derive(Clone)]
pub struct RequestContext {
    /// Identifier of the request this context belongs to.
    request_id: RequestId,
    /// Token attached to progress notifications; progress reporting is a
    /// no-op while this is `None`.
    progress_token: Option<ProgressToken>,
    /// Cancellation flag, shared with clones and handed-out cancellation tokens.
    cancelled: Arc<AtomicBool>,
    /// Channel for outgoing notifications; notifications are skipped while
    /// this is `None`.
    notification_tx: Option<NotificationSender>,
    /// Handle used for sampling and elicitation requests, if configured.
    client_requester: Option<ClientRequesterHandle>,
    /// Type-keyed extension values attached to this context.
    extensions: Arc<Extensions>,
    /// Shared threshold consulted when sending log messages; `None` disables
    /// level filtering.
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}
266
/// Type-keyed bag of values: at most one value is stored per concrete type.
///
/// Values are held behind `Arc`, so cloning the container clones pointers,
/// not the stored values.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty container.
    pub fn new() -> Self {
        Self {
            map: Default::default(),
        }
    }

    /// Stores `val`, replacing any previously stored value of the same type.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let stored: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, stored);
    }

    /// Returns a reference to the stored value of type `T`, if there is one.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let slot = self.map.get(&std::any::TypeId::of::<T>())?;
        slot.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.map.get(&std::any::TypeId::of::<T>()).is_some()
    }

    /// Copies every entry of `other` into `self`.
    ///
    /// When both hold a value of the same type, the one from `other` wins.
    /// Only the `Arc` pointers are cloned.
    pub fn merge(&mut self, other: &Extensions) {
        let incoming = other
            .map
            .iter()
            .map(|(type_id, value)| (*type_id, Arc::clone(value)));
        self.map.extend(incoming);
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the number of stored entries; the values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Extensions");
        out.field("len", &self.map.len());
        out.finish()
    }
}
320
321impl std::fmt::Debug for RequestContext {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 f.debug_struct("RequestContext")
324 .field("request_id", &self.request_id)
325 .field("progress_token", &self.progress_token)
326 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
327 .finish()
328 }
329}
330
331impl RequestContext {
332 pub fn new(request_id: RequestId) -> Self {
334 Self {
335 request_id,
336 progress_token: None,
337 cancelled: Arc::new(AtomicBool::new(false)),
338 notification_tx: None,
339 client_requester: None,
340 extensions: Arc::new(Extensions::new()),
341 min_log_level: None,
342 }
343 }
344
345 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
347 self.progress_token = Some(token);
348 self
349 }
350
351 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
353 self.notification_tx = Some(tx);
354 self
355 }
356
357 pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
362 self.min_log_level = Some(level);
363 self
364 }
365
366 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
368 self.client_requester = Some(requester);
369 self
370 }
371
372 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
376 self.extensions = extensions;
377 self
378 }
379
380 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
396 self.extensions.get::<T>()
397 }
398
399 pub fn extensions_mut(&mut self) -> &mut Extensions {
404 Arc::make_mut(&mut self.extensions)
405 }
406
407 pub fn extensions(&self) -> &Extensions {
409 &self.extensions
410 }
411
412 pub fn request_id(&self) -> &RequestId {
414 &self.request_id
415 }
416
417 pub fn progress_token(&self) -> Option<&ProgressToken> {
419 self.progress_token.as_ref()
420 }
421
422 pub fn is_cancelled(&self) -> bool {
424 self.cancelled.load(Ordering::Relaxed)
425 }
426
427 pub fn cancel(&self) {
429 self.cancelled.store(true, Ordering::Relaxed);
430 }
431
432 pub fn cancellation_token(&self) -> CancellationToken {
434 CancellationToken {
435 cancelled: self.cancelled.clone(),
436 }
437 }
438
439 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
443 let Some(token) = &self.progress_token else {
444 return;
445 };
446 let Some(tx) = &self.notification_tx else {
447 return;
448 };
449
450 let params = ProgressParams {
451 progress_token: token.clone(),
452 progress,
453 total,
454 message: message.map(|s| s.to_string()),
455 meta: None,
456 };
457
458 let _ = tx.try_send(ServerNotification::Progress(params));
460 }
461
462 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
466 let Some(token) = &self.progress_token else {
467 return;
468 };
469 let Some(tx) = &self.notification_tx else {
470 return;
471 };
472
473 let params = ProgressParams {
474 progress_token: token.clone(),
475 progress,
476 total,
477 message: message.map(|s| s.to_string()),
478 meta: None,
479 };
480
481 let _ = tx.try_send(ServerNotification::Progress(params));
482 }
483
484 pub fn send_log(&self, params: LoggingMessageParams) {
501 let Some(tx) = &self.notification_tx else {
502 return;
503 };
504
505 if let Some(min_level) = &self.min_log_level
510 && let Ok(min) = min_level.read()
511 && params.level > *min
512 {
513 return;
514 }
515
516 let _ = tx.try_send(ServerNotification::LogMessage(params));
517 }
518
519 pub fn can_sample(&self) -> bool {
524 self.client_requester.is_some()
525 }
526
527 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
551 let requester = self.client_requester.as_ref().ok_or_else(|| {
552 Error::Internal("Sampling not available: no client requester configured".to_string())
553 })?;
554
555 requester.sample(params).await
556 }
557
558 pub fn can_elicit(&self) -> bool {
564 self.client_requester.is_some()
565 }
566
567 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
599 let requester = self.client_requester.as_ref().ok_or_else(|| {
600 Error::Internal("Elicitation not available: no client requester configured".to_string())
601 })?;
602
603 requester.elicit(ElicitRequestParams::Form(params)).await
604 }
605
606 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
636 let requester = self.client_requester.as_ref().ok_or_else(|| {
637 Error::Internal("Elicitation not available: no client requester configured".to_string())
638 })?;
639
640 requester.elicit(ElicitRequestParams::Url(params)).await
641 }
642
643 pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
667 use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
668
669 let params = ElicitFormParams {
670 mode: Some(ElicitMode::Form),
671 message: message.into(),
672 requested_schema: ElicitFormSchema::new().boolean_field_with_default(
673 "confirm",
674 Some("Confirm this action"),
675 true,
676 false,
677 ),
678 meta: None,
679 };
680
681 let result = self.elicit_form(params).await?;
682 Ok(result.action == ElicitAction::Accept)
683 }
684}
685
/// Cheaply clonable handle onto a shared cancellation flag.
///
/// All clones refer to the same flag, so cancelling through one of them is
/// observed by every other holder.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Reports whether the shared flag has been set.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Sets the shared flag. Calling this more than once has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
703
/// Builder for [`RequestContext`].
///
/// Every field starts out unset; `request_id` must be supplied before the
/// context is built.
#[derive(Default)]
pub struct RequestContextBuilder {
    /// Request id for the context; required when building.
    request_id: Option<RequestId>,
    /// Optional progress token.
    progress_token: Option<ProgressToken>,
    /// Optional notification channel.
    notification_tx: Option<NotificationSender>,
    /// Optional client requester handle.
    client_requester: Option<ClientRequesterHandle>,
    /// Optional shared log-level threshold.
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}
713
714impl RequestContextBuilder {
715 pub fn new() -> Self {
717 Self::default()
718 }
719
720 pub fn request_id(mut self, id: RequestId) -> Self {
722 self.request_id = Some(id);
723 self
724 }
725
726 pub fn progress_token(mut self, token: ProgressToken) -> Self {
728 self.progress_token = Some(token);
729 self
730 }
731
732 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
734 self.notification_tx = Some(tx);
735 self
736 }
737
738 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
740 self.client_requester = Some(requester);
741 self
742 }
743
744 pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
746 self.min_log_level = Some(level);
747 self
748 }
749
750 pub fn build(self) -> RequestContext {
754 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
755 if let Some(token) = self.progress_token {
756 ctx = ctx.with_progress_token(token);
757 }
758 if let Some(tx) = self.notification_tx {
759 ctx = ctx.with_notification_sender(tx);
760 }
761 if let Some(requester) = self.client_requester {
762 ctx = ctx.with_client_requester(requester);
763 }
764 if let Some(level) = self.min_log_level {
765 ctx = ctx.with_min_log_level(level);
766 }
767 ctx
768 }
769}
770
#[cfg(test)]
mod tests {
    use super::*;

    /// Context for request id 1 with nothing else configured.
    fn bare_context() -> RequestContext {
        RequestContext::new(RequestId::Number(1))
    }

    /// Channel-backed requester handle, plus the receiver that keeps its
    /// channel open for the duration of a test.
    fn channel_requester() -> (ClientRequesterHandle, OutgoingRequestReceiver) {
        let (request_tx, request_rx) = outgoing_request_channel(10);
        let handle: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        (handle, request_rx)
    }

    /// Log payload at `level` with a null data value.
    fn log_at(level: LogLevel) -> LoggingMessageParams {
        LoggingMessageParams::new(level, serde_json::Value::Null)
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    #[track_caller]
    fn assert_err_contains<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains(needle));
    }

    // Cancelling through the context is visible through a token taken earlier.
    #[test]
    fn test_cancellation() {
        let context = bare_context();
        assert!(!context.is_cancelled());

        let handle = context.cancellation_token();
        assert!(!handle.is_cancelled());

        context.cancel();
        assert!(context.is_cancelled());
        assert!(handle.is_cancelled());
    }

    // With both a token and a sender, a progress notification is delivered.
    #[tokio::test]
    async fn test_progress_reporting() {
        let (sender, mut receiver) = notification_channel(10);

        let context = bare_context()
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(sender);

        context
            .report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let ServerNotification::Progress(params) = receiver.recv().await.unwrap() else {
            panic!("Expected Progress notification");
        };
        assert_eq!(params.progress, 50.0);
        assert_eq!(params.total, Some(100.0));
        assert_eq!(params.message.as_deref(), Some("Halfway"));
    }

    // Without a progress token nothing is sent.
    #[tokio::test]
    async fn test_progress_no_token() {
        let (sender, mut receiver) = notification_channel(10);

        let context = bare_context().with_notification_sender(sender);

        context.report_progress(50.0, Some(100.0), None).await;

        assert!(receiver.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (sender, _receiver) = notification_channel(10);
        let expected_id = RequestId::String("req-1".to_string());

        let context = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(sender)
            .build();

        assert_eq!(context.request_id(), &expected_id);
        assert!(context.progress_token().is_some());
    }

    #[test]
    fn test_can_sample_without_requester() {
        assert!(!bare_context().can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (handle, _request_rx) = channel_requester();

        let context = bare_context().with_client_requester(handle);
        assert!(context.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let context = bare_context();
        let request = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        assert_err_contains(context.sample(request).await, "Sampling not available");
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (handle, _request_rx) = channel_requester();

        let context = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(handle)
            .build();

        assert!(context.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        assert!(!bare_context().can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (handle, _request_rx) = channel_requester();

        let context = bare_context().with_client_requester(handle);
        assert!(context.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let context = bare_context();
        let form = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        assert_err_contains(
            context.elicit_form(form).await,
            "Elicitation not available",
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let context = bare_context();
        let request = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        assert_err_contains(
            context.elicit_url(request).await,
            "Elicitation not available",
        );
    }

    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let context = bare_context();

        assert_err_contains(
            context.confirm("Are you sure?").await,
            "Elicitation not available",
        );
    }

    // With a Warning threshold, Error and Warning are delivered while Info
    // and Debug are dropped.
    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (sender, mut receiver) = notification_channel(10);
        let threshold = Arc::new(RwLock::new(LogLevel::Warning));

        let context = bare_context()
            .with_notification_sender(sender)
            .with_min_log_level(Arc::clone(&threshold));

        context.send_log(log_at(LogLevel::Error));
        assert!(
            receiver.try_recv().is_ok(),
            "Error should pass through Warning filter"
        );

        context.send_log(log_at(LogLevel::Warning));
        assert!(
            receiver.try_recv().is_ok(),
            "Warning should pass through Warning filter"
        );

        context.send_log(log_at(LogLevel::Info));
        assert!(
            receiver.try_recv().is_err(),
            "Info should be filtered by Warning filter"
        );

        context.send_log(log_at(LogLevel::Debug));
        assert!(
            receiver.try_recv().is_err(),
            "Debug should be filtered by Warning filter"
        );
    }

    // Changing the shared threshold takes effect on the next send.
    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (sender, mut receiver) = notification_channel(10);
        let threshold = Arc::new(RwLock::new(LogLevel::Error));

        let context = bare_context()
            .with_notification_sender(sender)
            .with_min_log_level(Arc::clone(&threshold));

        context.send_log(log_at(LogLevel::Info));
        assert!(
            receiver.try_recv().is_err(),
            "Info should be filtered at Error level"
        );

        *threshold.write().unwrap() = LogLevel::Debug;

        context.send_log(log_at(LogLevel::Info));
        assert!(
            receiver.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    // Without a threshold every level is delivered.
    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (sender, mut receiver) = notification_channel(10);

        let context = bare_context().with_notification_sender(sender);

        context.send_log(log_at(LogLevel::Debug));
        assert!(
            receiver.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }
}