1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LogLevel, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
/// Notification messages carried over the channel created by
/// [`notification_channel`].
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ServerNotification {
    /// Progress update; carries the [`ProgressParams`] payload.
    Progress(ProgressParams),
    /// Log message; carries the [`LoggingMessageParams`] payload.
    LogMessage(LoggingMessageParams),
    /// A single resource was updated.
    ResourceUpdated {
        /// URI of the updated resource.
        uri: String,
    },
    /// The list of resources changed.
    ResourcesListChanged,
    /// The list of tools changed.
    ToolsListChanged,
    /// The list of prompts changed.
    PromptsListChanged,
    /// A task's status changed; carries the `TaskStatusParams` payload.
    TaskStatusChanged(crate::protocol::TaskStatusParams),
}
106
107pub type NotificationSender = mpsc::Sender<ServerNotification>;
109
110pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
112
113pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
115 mpsc::channel(buffer)
116}
117
/// Abstraction for sending sampling and elicitation requests and awaiting
/// their results.
///
/// Implementors must be `Send + Sync`; they are shared as a
/// [`ClientRequesterHandle`].
#[async_trait]
pub trait ClientRequester: Send + Sync {
    /// Sends a sampling request built from `params` and returns the result.
    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;

    /// Sends an elicitation request built from `params` and returns the result.
    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
}

/// Shared, type-erased handle to a [`ClientRequester`] implementation.
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
145
/// A request queued on an [`OutgoingRequestSender`], bundled with the channel
/// on which its reply is delivered.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Identifier assigned to this request.
    pub id: RequestId,
    /// Method name of the request (e.g. `"sampling/createMessage"`).
    pub method: String,
    /// Request parameters, already serialized to JSON.
    pub params: serde_json::Value,
    /// One-shot channel on which the JSON result, or an error, is sent back
    /// to the caller awaiting the reply.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
158
159pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
161
162pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
164
165pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
167 mpsc::channel(buffer)
168}
169
/// [`ClientRequester`] implementation that queues requests on an
/// [`OutgoingRequestSender`] and waits for each reply on a one-shot channel.
///
/// Clones share one request-id counter because `next_id` lives behind an `Arc`.
#[derive(Clone)]
pub struct ChannelClientRequester {
    /// Channel on which outgoing requests are queued.
    request_tx: OutgoingRequestSender,
    /// Next numeric request id to hand out; initialised to 1 by `new`.
    next_id: Arc<AtomicI64>,
}
176
177impl ChannelClientRequester {
178 pub fn new(request_tx: OutgoingRequestSender) -> Self {
180 Self {
181 request_tx,
182 next_id: Arc::new(AtomicI64::new(1)),
183 }
184 }
185
186 fn next_request_id(&self) -> RequestId {
187 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
188 RequestId::Number(id)
189 }
190}
191
192#[async_trait]
193impl ClientRequester for ChannelClientRequester {
194 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
195 let id = self.next_request_id();
196 let params_json = serde_json::to_value(¶ms)
197 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
198
199 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
200
201 let request = OutgoingRequest {
202 id: id.clone(),
203 method: "sampling/createMessage".to_string(),
204 params: params_json,
205 response_tx,
206 };
207
208 self.request_tx
209 .send(request)
210 .await
211 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
212
213 let response = response_rx.await.map_err(|_| {
214 Error::Internal("Failed to receive response: channel closed".to_string())
215 })??;
216
217 serde_json::from_value(response)
218 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
219 }
220
221 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
222 let id = self.next_request_id();
223 let params_json = serde_json::to_value(¶ms)
224 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
225
226 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
227
228 let request = OutgoingRequest {
229 id: id.clone(),
230 method: "elicitation/create".to_string(),
231 params: params_json,
232 response_tx,
233 };
234
235 self.request_tx
236 .send(request)
237 .await
238 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
239
240 let response = response_rx.await.map_err(|_| {
241 Error::Internal("Failed to receive response: channel closed".to_string())
242 })??;
243
244 serde_json::from_value(response)
245 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
246 }
247}
248
/// Per-request state: identity, cancellation flag, and optional channels for
/// notifications and client requests.
///
/// Cloning is shallow for the shared parts: the cancellation flag, the
/// extensions map and the log-level cell are all behind `Arc`s.
#[derive(Clone)]
pub struct RequestContext {
    /// Id of the request this context belongs to.
    request_id: RequestId,
    /// Progress token, if one was attached; progress is only reported when set.
    progress_token: Option<ProgressToken>,
    /// Cancellation flag shared with clones and with tokens returned by
    /// `cancellation_token`.
    cancelled: Arc<AtomicBool>,
    /// Channel used to emit progress and log notifications, if configured.
    notification_tx: Option<NotificationSender>,
    /// Handle used for sampling and elicitation requests, if configured.
    client_requester: Option<ClientRequesterHandle>,
    /// Type-keyed extension values attached to this context.
    extensions: Arc<Extensions>,
    /// Shared level used by `send_log` to filter messages, if configured.
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}
267
/// Type-keyed map holding at most one value per concrete type.
///
/// Values are stored behind `Arc`s, so cloning the map clones the pointers
/// rather than the values.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `val` under its type, replacing any previous value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let stored: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, stored);
    }

    /// Returns a reference to the stored value of type `T`, if any.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let key = std::any::TypeId::of::<T>();
        let stored = self.map.get(&key)?;
        stored.downcast_ref::<T>()
    }

    /// Returns `true` if a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into `self`; on a type collision the
    /// entry from `other` wins.
    pub fn merge(&mut self, other: &Extensions) {
        let incoming = other
            .map
            .iter()
            .map(|(type_id, value)| (*type_id, Arc::clone(value)));
        self.map.extend(incoming);
    }

    /// Number of stored values.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if no values are stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the entry count; the stored values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.map.len();
        let mut out = f.debug_struct("Extensions");
        out.field("len", &entry_count);
        out.finish()
    }
}
331
332impl std::fmt::Debug for RequestContext {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 f.debug_struct("RequestContext")
335 .field("request_id", &self.request_id)
336 .field("progress_token", &self.progress_token)
337 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
338 .finish()
339 }
340}
341
342impl RequestContext {
343 pub fn new(request_id: RequestId) -> Self {
345 Self {
346 request_id,
347 progress_token: None,
348 cancelled: Arc::new(AtomicBool::new(false)),
349 notification_tx: None,
350 client_requester: None,
351 extensions: Arc::new(Extensions::new()),
352 min_log_level: None,
353 }
354 }
355
356 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
358 self.progress_token = Some(token);
359 self
360 }
361
362 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
364 self.notification_tx = Some(tx);
365 self
366 }
367
368 pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
373 self.min_log_level = Some(level);
374 self
375 }
376
377 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
379 self.client_requester = Some(requester);
380 self
381 }
382
383 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
387 self.extensions = extensions;
388 self
389 }
390
391 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
407 self.extensions.get::<T>()
408 }
409
410 pub fn extensions_mut(&mut self) -> &mut Extensions {
415 Arc::make_mut(&mut self.extensions)
416 }
417
418 pub fn extensions(&self) -> &Extensions {
420 &self.extensions
421 }
422
423 pub fn request_id(&self) -> &RequestId {
425 &self.request_id
426 }
427
428 pub fn progress_token(&self) -> Option<&ProgressToken> {
430 self.progress_token.as_ref()
431 }
432
433 pub fn is_cancelled(&self) -> bool {
435 self.cancelled.load(Ordering::Relaxed)
436 }
437
438 pub fn cancel(&self) {
440 self.cancelled.store(true, Ordering::Relaxed);
441 }
442
443 pub fn cancellation_token(&self) -> CancellationToken {
445 CancellationToken {
446 cancelled: self.cancelled.clone(),
447 }
448 }
449
450 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
454 let Some(token) = &self.progress_token else {
455 return;
456 };
457 let Some(tx) = &self.notification_tx else {
458 return;
459 };
460
461 let params = ProgressParams {
462 progress_token: token.clone(),
463 progress,
464 total,
465 message: message.map(|s| s.to_string()),
466 meta: None,
467 };
468
469 let _ = tx.try_send(ServerNotification::Progress(params));
471 }
472
473 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
477 let Some(token) = &self.progress_token else {
478 return;
479 };
480 let Some(tx) = &self.notification_tx else {
481 return;
482 };
483
484 let params = ProgressParams {
485 progress_token: token.clone(),
486 progress,
487 total,
488 message: message.map(|s| s.to_string()),
489 meta: None,
490 };
491
492 let _ = tx.try_send(ServerNotification::Progress(params));
493 }
494
495 pub fn send_log(&self, params: LoggingMessageParams) {
512 let Some(tx) = &self.notification_tx else {
513 return;
514 };
515
516 if let Some(min_level) = &self.min_log_level
521 && let Ok(min) = min_level.read()
522 && params.level > *min
523 {
524 return;
525 }
526
527 let _ = tx.try_send(ServerNotification::LogMessage(params));
528 }
529
530 pub fn can_sample(&self) -> bool {
535 self.client_requester.is_some()
536 }
537
538 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
562 let requester = self.client_requester.as_ref().ok_or_else(|| {
563 Error::Internal("Sampling not available: no client requester configured".to_string())
564 })?;
565
566 requester.sample(params).await
567 }
568
569 pub fn can_elicit(&self) -> bool {
575 self.client_requester.is_some()
576 }
577
578 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
610 let requester = self.client_requester.as_ref().ok_or_else(|| {
611 Error::Internal("Elicitation not available: no client requester configured".to_string())
612 })?;
613
614 requester.elicit(ElicitRequestParams::Form(params)).await
615 }
616
617 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
647 let requester = self.client_requester.as_ref().ok_or_else(|| {
648 Error::Internal("Elicitation not available: no client requester configured".to_string())
649 })?;
650
651 requester.elicit(ElicitRequestParams::Url(params)).await
652 }
653
654 pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
678 use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
679
680 let params = ElicitFormParams {
681 mode: Some(ElicitMode::Form),
682 message: message.into(),
683 requested_schema: ElicitFormSchema::new().boolean_field_with_default(
684 "confirm",
685 Some("Confirm this action"),
686 true,
687 false,
688 ),
689 meta: None,
690 };
691
692 let result = self.elicit_form(params).await?;
693 Ok(result.action == ElicitAction::Accept)
694 }
695}
696
/// Handle onto a cancellation flag held in an `Arc<AtomicBool>`.
///
/// Clones of a token observe and set the same flag.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once the shared flag has been set.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Sets the shared flag; the flag is never cleared by this type.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
714
715#[derive(Default)]
717pub struct RequestContextBuilder {
718 request_id: Option<RequestId>,
719 progress_token: Option<ProgressToken>,
720 notification_tx: Option<NotificationSender>,
721 client_requester: Option<ClientRequesterHandle>,
722 min_log_level: Option<Arc<RwLock<LogLevel>>>,
723}
724
725impl RequestContextBuilder {
726 pub fn new() -> Self {
728 Self::default()
729 }
730
731 pub fn request_id(mut self, id: RequestId) -> Self {
733 self.request_id = Some(id);
734 self
735 }
736
737 pub fn progress_token(mut self, token: ProgressToken) -> Self {
739 self.progress_token = Some(token);
740 self
741 }
742
743 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
745 self.notification_tx = Some(tx);
746 self
747 }
748
749 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
751 self.client_requester = Some(requester);
752 self
753 }
754
755 pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
757 self.min_log_level = Some(level);
758 self
759 }
760
761 pub fn build(self) -> RequestContext {
765 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
766 if let Some(token) = self.progress_token {
767 ctx = ctx.with_progress_token(token);
768 }
769 if let Some(tx) = self.notification_tx {
770 ctx = ctx.with_notification_sender(tx);
771 }
772 if let Some(requester) = self.client_requester {
773 ctx = ctx.with_client_requester(requester);
774 }
775 if let Some(level) = self.min_log_level {
776 ctx = ctx.with_min_log_level(level);
777 }
778 ctx
779 }
780}
781
// Unit tests for `RequestContext`, its builder, and the notification and
// client-requester plumbing defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // A context and a token obtained from it share one cancellation flag.
    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    // With both a token and a sender configured, progress arrives on the channel.
    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    // Without a progress token, reporting progress sends nothing.
    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // Nothing was queued, so the non-blocking receive must fail.
        assert!(rx.try_recv().is_err());
    }

    // The builder carries the request id and progress token into the context.
    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    // `can_sample` is false when no client requester is configured.
    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    // `can_sample` is true once a client requester is attached.
    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    // `sample` fails with a "Sampling not available" error without a requester.
    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    // The builder passes the client requester through to the context.
    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    // `can_elicit` is false when no client requester is configured.
    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    // `can_elicit` is true once a client requester is attached.
    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    // `elicit_form` fails with an "Elicitation not available" error without a requester.
    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // `elicit_url` fails with an "Elicitation not available" error without a requester.
    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // `confirm` goes through `elicit_form`, so it fails the same way without a requester.
    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // With a Warning minimum level, Error and Warning pass; Info and Debug are dropped.
    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Warning));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Error vs. a Warning minimum: expected to be delivered.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Error should pass through Warning filter");

        // Warning equals the minimum: expected to be delivered.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Warning,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Warning should pass through Warning filter");

        // Info vs. a Warning minimum: expected to be dropped.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Info should be filtered by Warning filter");

        // Debug vs. a Warning minimum: expected to be dropped.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
    }

    // Changing the shared level after the context is built changes filtering.
    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Error));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Info vs. an Error minimum: expected to be dropped.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered at Error level"
        );

        // Write through the shared handle; the context observes the new level.
        *min_level.write().unwrap() = LogLevel::Debug;

        // Info vs. a Debug minimum: expected to be delivered.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    // With no minimum level configured, nothing is filtered.
    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }
}