use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use tokio::sync::mpsc;
use crate::error::{Error, Result};
use crate::protocol::{
CallToolResult, CancelTaskParams, CreateMessageParams, CreateMessageResult, ElicitFormParams,
ElicitRequestParams, ElicitResult, ElicitUrlParams, GetTaskInfoParams, GetTaskResultParams,
ListTasksParams, ListTasksResult, LogLevel, LoggingMessageParams, ProgressParams,
ProgressToken, RequestId, TaskObject, TaskStatus,
};
/// A notification emitted by the server side and delivered over a
/// [`NotificationSender`].
///
/// `RequestContext` produces `Progress` (from its progress-reporting methods)
/// and `LogMessage` (from `send_log`); the remaining variants are produced
/// elsewhere.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ServerNotification {
    /// A progress update carrying the request's progress token.
    Progress(ProgressParams),
    /// A log message that already passed the context's log-level filter.
    LogMessage(LoggingMessageParams),
    /// The resource at `uri` was updated.
    ResourceUpdated {
        /// URI of the updated resource.
        uri: String,
    },
    /// The list of resources changed.
    ResourcesListChanged,
    /// The list of tools changed.
    ToolsListChanged,
    /// The list of prompts changed.
    PromptsListChanged,
    /// A task's status changed.
    TaskStatusChanged(crate::protocol::TaskStatusParams),
}
/// Producer half of the server-notification channel.
pub type NotificationSender = mpsc::Sender<ServerNotification>;
/// Consumer half of the server-notification channel.
pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;

/// Creates a bounded channel that queues at most `buffer` [`ServerNotification`]s.
///
/// # Panics
///
/// Panics if `buffer` is zero (a requirement of `tokio::sync::mpsc::channel`).
pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
    let (sender, receiver) = mpsc::channel::<ServerNotification>(buffer);
    (sender, receiver)
}
#[async_trait]
pub trait ClientRequester: Send + Sync {
async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
async fn request(
&self,
method: String,
params: serde_json::Value,
) -> Result<serde_json::Value> {
let _ = (method, params);
Err(Error::Internal(
"ClientRequester does not support arbitrary requests".to_string(),
))
}
}
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
/// A server-to-client request waiting to be delivered, as queued by
/// [`ChannelClientRequester`] on an [`OutgoingRequestSender`].
///
/// Whoever receives it must answer through `response_tx`. If it is dropped
/// without a reply, the awaiting caller receives a
/// "Failed to receive response: channel closed" error.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Request id chosen by the sender (sequential numbers for `ChannelClientRequester`).
    pub id: RequestId,
    /// Method name, e.g. `sampling/createMessage` or `tasks/get`.
    pub method: String,
    /// Already-serialized request parameters.
    pub params: serde_json::Value,
    /// One-shot channel on which the client's result (or an error) must be sent.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
/// Producer half of the outgoing-request channel.
pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
/// Consumer half of the outgoing-request channel.
pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;

/// Creates a bounded channel that queues at most `buffer` [`OutgoingRequest`]s.
///
/// # Panics
///
/// Panics if `buffer` is zero (a requirement of `tokio::sync::mpsc::channel`).
pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
    let (sender, receiver) = mpsc::channel::<OutgoingRequest>(buffer);
    (sender, receiver)
}
/// [`ClientRequester`] that queues each request on an [`OutgoingRequestSender`]
/// and awaits the reply on a one-shot channel.
///
/// Clones share the id counter (it lives behind an `Arc`), so ids stay unique
/// across clones.
#[derive(Clone)]
pub struct ChannelClientRequester {
    /// Channel on which outgoing requests are queued.
    request_tx: OutgoingRequestSender,
    /// Next numeric request id to hand out; shared between clones.
    next_id: Arc<AtomicI64>,
}
impl ChannelClientRequester {
    /// Creates a requester that queues requests on `request_tx`; the first id is `1`.
    pub fn new(request_tx: OutgoingRequestSender) -> Self {
        let next_id = Arc::new(AtomicI64::new(1));
        Self { request_tx, next_id }
    }

    /// Allocates the next numeric request id.
    fn next_request_id(&self) -> RequestId {
        // Relaxed is enough: only the uniqueness of each returned value matters,
        // not its ordering relative to other memory operations.
        RequestId::Number(self.next_id.fetch_add(1, Ordering::Relaxed))
    }
}
impl ChannelClientRequester {
async fn dispatch(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
let id = self.next_request_id();
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let request = OutgoingRequest {
id,
method: method.to_string(),
params,
response_tx,
};
self.request_tx
.send(request)
.await
.map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
response_rx.await.map_err(|_| {
Error::Internal("Failed to receive response: channel closed".to_string())
})?
}
}
#[async_trait]
impl ClientRequester for ChannelClientRequester {
async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
let params_json = serde_json::to_value(¶ms)
.map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
let response = self.dispatch("sampling/createMessage", params_json).await?;
serde_json::from_value(response)
.map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
}
async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
let params_json = serde_json::to_value(¶ms)
.map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
let response = self.dispatch("elicitation/create", params_json).await?;
serde_json::from_value(response)
.map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
}
async fn request(
&self,
method: String,
params: serde_json::Value,
) -> Result<serde_json::Value> {
self.dispatch(&method, params).await
}
}
/// Per-request state handed to handlers.
///
/// It holds the request id, an optional progress token, a cancellation flag,
/// optional channels back to the client, typed extensions, and an optional
/// log-level filter.
///
/// Clones share the cancellation flag, the extensions map (until one side
/// calls `extensions_mut`), and the log-level handle, because each is held in
/// an `Arc`.
#[derive(Clone)]
pub struct RequestContext {
    /// Id of the request being handled.
    request_id: RequestId,
    /// Token attached to progress notifications; `None` turns progress reporting into a no-op.
    progress_token: Option<ProgressToken>,
    /// Cancellation flag shared with clones and `CancellationToken`s.
    cancelled: Arc<AtomicBool>,
    /// Destination for progress/log notifications; `None` turns them into no-ops.
    notification_tx: Option<NotificationSender>,
    /// Handle for server-to-client requests (sampling, elicitation, tasks, raw).
    client_requester: Option<ClientRequesterHandle>,
    /// Typed extension values; copied on write by `extensions_mut`.
    extensions: Arc<Extensions>,
    /// Shared minimum log level. `send_log` drops messages whose level orders
    /// after it, i.e. messages that are more verbose.
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}
/// Type-keyed bag of values attached to a request; holds at most one value per type.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Values keyed by their concrete `TypeId`; `Arc` makes cloning the bag cheap.
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty bag.
    pub fn new() -> Self {
        Extensions {
            map: Default::default(),
        }
    }

    /// Stores `val`, replacing any previously stored value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Returns the stored value of type `T`, if any.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let entry = self.map.get(&std::any::TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }

    /// Returns `true` if a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into `self`; on a type collision `other` wins.
    pub fn merge(&mut self, other: &Extensions) {
        self.map.extend(
            other
                .map
                .iter()
                .map(|(type_id, value)| (*type_id, Arc::clone(value))),
        );
    }

    /// Number of stored values.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when nothing is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the entry count, because the stored values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
impl std::fmt::Debug for RequestContext {
    /// Shows the request id, the progress token, and the current cancellation
    /// state. Channels, requester, extensions, and log level are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_cancelled = self.cancelled.load(Ordering::Relaxed);
        let mut out = f.debug_struct("RequestContext");
        out.field("request_id", &self.request_id);
        out.field("progress_token", &self.progress_token);
        out.field("cancelled", &is_cancelled);
        out.finish()
    }
}
impl RequestContext {
pub fn new(request_id: RequestId) -> Self {
Self {
request_id,
progress_token: None,
cancelled: Arc::new(AtomicBool::new(false)),
notification_tx: None,
client_requester: None,
extensions: Arc::new(Extensions::new()),
min_log_level: None,
}
}
pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
self.progress_token = Some(token);
self
}
pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
self.notification_tx = Some(tx);
self
}
pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
self.min_log_level = Some(level);
self
}
pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
self.client_requester = Some(requester);
self
}
pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
self.extensions = extensions;
self
}
pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.extensions.get::<T>()
}
pub fn extensions_mut(&mut self) -> &mut Extensions {
Arc::make_mut(&mut self.extensions)
}
pub fn extensions(&self) -> &Extensions {
&self.extensions
}
pub fn request_id(&self) -> &RequestId {
&self.request_id
}
pub fn progress_token(&self) -> Option<&ProgressToken> {
self.progress_token.as_ref()
}
pub fn is_cancelled(&self) -> bool {
self.cancelled.load(Ordering::Relaxed)
}
pub fn cancel(&self) {
self.cancelled.store(true, Ordering::Relaxed);
}
pub fn cancellation_token(&self) -> CancellationToken {
CancellationToken {
cancelled: self.cancelled.clone(),
}
}
pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
let Some(token) = &self.progress_token else {
return;
};
let Some(tx) = &self.notification_tx else {
return;
};
let params = ProgressParams {
progress_token: token.clone(),
progress,
total,
message: message.map(|s| s.to_string()),
meta: None,
};
let _ = tx.try_send(ServerNotification::Progress(params));
}
pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
let Some(token) = &self.progress_token else {
return;
};
let Some(tx) = &self.notification_tx else {
return;
};
let params = ProgressParams {
progress_token: token.clone(),
progress,
total,
message: message.map(|s| s.to_string()),
meta: None,
};
let _ = tx.try_send(ServerNotification::Progress(params));
}
pub fn send_log(&self, params: LoggingMessageParams) {
let Some(tx) = &self.notification_tx else {
return;
};
if let Some(min_level) = &self.min_log_level
&& let Ok(min) = min_level.read()
&& params.level > *min
{
return;
}
let _ = tx.try_send(ServerNotification::LogMessage(params));
}
pub fn can_sample(&self) -> bool {
self.client_requester.is_some()
}
pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
let requester = self.client_requester.as_ref().ok_or_else(|| {
Error::Internal("Sampling not available: no client requester configured".to_string())
})?;
requester.sample(params).await
}
pub fn can_elicit(&self) -> bool {
self.client_requester.is_some()
}
pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
let requester = self.client_requester.as_ref().ok_or_else(|| {
Error::Internal("Elicitation not available: no client requester configured".to_string())
})?;
requester.elicit(ElicitRequestParams::Form(params)).await
}
pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
let requester = self.client_requester.as_ref().ok_or_else(|| {
Error::Internal("Elicitation not available: no client requester configured".to_string())
})?;
requester.elicit(ElicitRequestParams::Url(params)).await
}
pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
let params = ElicitFormParams {
mode: Some(ElicitMode::Form),
message: message.into(),
requested_schema: ElicitFormSchema::new().boolean_field_with_default(
"confirm",
Some("Confirm this action"),
true,
false,
),
meta: None,
};
let result = self.elicit_form(params).await?;
Ok(result.action == ElicitAction::Accept)
}
pub async fn list_tasks(&self, status: Option<TaskStatus>) -> Result<ListTasksResult> {
let params = ListTasksParams {
status,
cursor: None,
meta: None,
};
let value = self
.request_raw("tasks/list", serde_json::to_value(¶ms)?)
.await?;
serde_json::from_value(value)
.map_err(|e| Error::Internal(format!("Failed to deserialize tasks/list: {e}")))
}
pub async fn get_task_info(&self, task_id: impl Into<String>) -> Result<TaskObject> {
let params = GetTaskInfoParams {
task_id: task_id.into(),
meta: None,
};
let value = self
.request_raw("tasks/get", serde_json::to_value(¶ms)?)
.await?;
serde_json::from_value(value)
.map_err(|e| Error::Internal(format!("Failed to deserialize tasks/get: {e}")))
}
pub async fn get_task_result(&self, task_id: impl Into<String>) -> Result<CallToolResult> {
let params = GetTaskResultParams {
task_id: task_id.into(),
meta: None,
};
let value = self
.request_raw("tasks/result", serde_json::to_value(¶ms)?)
.await?;
serde_json::from_value(value)
.map_err(|e| Error::Internal(format!("Failed to deserialize tasks/result: {e}")))
}
pub async fn cancel_task(
&self,
task_id: impl Into<String>,
reason: Option<String>,
) -> Result<TaskObject> {
let params = CancelTaskParams {
task_id: task_id.into(),
reason,
meta: None,
};
let value = self
.request_raw("tasks/cancel", serde_json::to_value(¶ms)?)
.await?;
serde_json::from_value(value)
.map_err(|e| Error::Internal(format!("Failed to deserialize tasks/cancel: {e}")))
}
pub async fn request_raw(
&self,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value> {
let requester = self.client_requester.as_ref().ok_or_else(|| {
Error::Internal(
"Client request not available: no client requester configured".to_string(),
)
})?;
requester.request(method.to_string(), params).await
}
}
/// Cloneable handle to a request's cancellation flag.
///
/// Tokens obtained from `RequestContext::cancellation_token` share the flag
/// with that context and with each other.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested through any sharing handle.
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation. Calling it more than once has no further effect.
    pub fn cancel(&self) {
        let flag = &*self.cancelled;
        flag.store(true, Ordering::Relaxed)
    }
}
/// Step-by-step builder for [`RequestContext`].
///
/// `request_id` is mandatory; every other setting is optional and left unset
/// when not provided.
#[derive(Default)]
pub struct RequestContextBuilder {
    request_id: Option<RequestId>,
    progress_token: Option<ProgressToken>,
    notification_tx: Option<NotificationSender>,
    client_requester: Option<ClientRequesterHandle>,
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}

impl RequestContextBuilder {
    /// Creates a builder with nothing set.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the (required) request id.
    pub fn request_id(self, id: RequestId) -> Self {
        Self {
            request_id: Some(id),
            ..self
        }
    }

    /// Sets the progress token.
    pub fn progress_token(self, token: ProgressToken) -> Self {
        Self {
            progress_token: Some(token),
            ..self
        }
    }

    /// Sets the notification sender.
    pub fn notification_sender(self, tx: NotificationSender) -> Self {
        Self {
            notification_tx: Some(tx),
            ..self
        }
    }

    /// Sets the client requester.
    pub fn client_requester(self, requester: ClientRequesterHandle) -> Self {
        Self {
            client_requester: Some(requester),
            ..self
        }
    }

    /// Sets the shared minimum log level.
    pub fn min_log_level(self, level: Arc<RwLock<LogLevel>>) -> Self {
        Self {
            min_log_level: Some(level),
            ..self
        }
    }

    /// Builds the context, applying every option that was set.
    ///
    /// # Panics
    ///
    /// Panics with `"request_id is required"` if
    /// [`request_id`](Self::request_id) was never called.
    pub fn build(self) -> RequestContext {
        let RequestContextBuilder {
            request_id,
            progress_token,
            notification_tx,
            client_requester,
            min_log_level,
        } = self;

        let ctx = RequestContext::new(request_id.expect("request_id is required"));
        let ctx = match progress_token {
            Some(token) => ctx.with_progress_token(token),
            None => ctx,
        };
        let ctx = match notification_tx {
            Some(tx) => ctx.with_notification_sender(tx),
            None => ctx,
        };
        let ctx = match client_requester {
            Some(requester) => ctx.with_client_requester(requester),
            None => ctx,
        };
        match min_log_level {
            Some(level) => ctx.with_min_log_level(level),
            None => ctx,
        }
    }
}
// Unit tests for the request context, its cancellation/progress/log plumbing,
// and the channel-backed client requester.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());
        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());
        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_cancels_context() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let token = ctx.cancellation_token();
        token.cancel();
        assert!(ctx.is_cancelled());
        assert!(ctx.clone().is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);
        ctx.report_progress(50.0, Some(100.0), Some("Halfway")).await;
        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[test]
    fn test_report_progress_sync_sends_notification() {
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(7))
            .with_notification_sender(tx);
        ctx.report_progress_sync(1.0, None, None);
        match rx.try_recv().unwrap() {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 1.0);
                assert_eq!(params.total, None);
                assert!(params.message.is_none());
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
        ctx.report_progress(50.0, Some(100.0), None).await;
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);
        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();
        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    fn test_extensions_are_copy_on_write() {
        let mut original = RequestContext::new(RequestId::Number(1));
        original.extensions_mut().insert(5u32);
        let mut copy = original.clone();
        copy.extensions_mut().insert(6u32);
        assert_eq!(original.extension::<u32>(), Some(&5));
        assert_eq!(copy.extension::<u32>(), Some(&6));
        assert_eq!(original.extensions().len(), 1);
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};
        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();
        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};
        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };
        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;
        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };
        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Warning));
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Error should pass through Warning filter");
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Warning,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Warning should pass through Warning filter");
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Info should be filtered by Warning filter");
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
    }

    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Error));
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered at Error level"
        );
        *min_level.write().unwrap() = LogLevel::Debug;
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }

    #[test]
    fn test_request_ids_are_sequential_and_shared_across_clones() {
        let (request_tx, _rx) = outgoing_request_channel(1);
        let requester = ChannelClientRequester::new(request_tx);
        let clone = requester.clone();
        assert_eq!(requester.next_request_id(), RequestId::Number(1));
        assert_eq!(clone.next_request_id(), RequestId::Number(2));
        assert_eq!(requester.next_request_id(), RequestId::Number(3));
    }

    /// Builds the JSON a client would return for a task in `status`.
    fn make_task_object(id: &str, status: TaskStatus) -> serde_json::Value {
        serde_json::json!({
            "taskId": id,
            "status": status,
            "createdAt": "2026-04-24T00:00:00Z",
            "lastUpdatedAt": "2026-04-24T00:00:00Z",
            "ttl": null
        })
    }

    /// Answers every queued request with `responder(method, params)`.
    ///
    /// A panicking responder drops the reply sender, so the caller sees a
    /// "Failed to receive response" error rather than the assertion message.
    fn spawn_mock_client(
        mut rx: OutgoingRequestReceiver,
        responder: impl Fn(&str, serde_json::Value) -> serde_json::Value + Send + 'static,
    ) {
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                let response = responder(&req.method, req.params);
                let _ = req.response_tx.send(Ok(response));
            }
        });
    }

    #[tokio::test]
    async fn test_get_task_info_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/get");
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Working)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        let info = ctx.get_task_info("task-123").await.unwrap();
        assert_eq!(info.task_id, "task-123");
        assert!(matches!(info.status, TaskStatus::Working));
    }

    #[tokio::test]
    async fn test_list_tasks_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/list");
            assert_eq!(params["status"], serde_json::json!("working"));
            serde_json::json!({
                "tasks": [
                    make_task_object("task-1", TaskStatus::Working),
                    make_task_object("task-2", TaskStatus::Working),
                ]
            })
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        let result = ctx.list_tasks(Some(TaskStatus::Working)).await.unwrap();
        assert_eq!(result.tasks.len(), 2);
        assert_eq!(result.tasks[0].task_id, "task-1");
    }

    #[tokio::test]
    async fn test_cancel_task_forwards_reason() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/cancel");
            assert_eq!(params["reason"], serde_json::json!("user requested"));
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Cancelled)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        let task = ctx
            .cancel_task("task-99", Some("user requested".into()))
            .await
            .unwrap();
        assert_eq!(task.task_id, "task-99");
        assert!(matches!(task.status, TaskStatus::Cancelled));
    }

    #[tokio::test]
    async fn test_get_task_info_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = ctx.get_task_info("task-1").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Client request not available")
        );
    }

    #[tokio::test]
    async fn test_request_raw_fails_when_client_channel_closed() {
        let (request_tx, rx) = outgoing_request_channel(1);
        drop(rx);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        let err = ctx
            .request_raw("tasks/list", serde_json::Value::Null)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Failed to send request"));
    }

    #[tokio::test]
    async fn test_request_raw_fails_when_response_dropped() {
        let (request_tx, mut rx) = outgoing_request_channel(1);
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                // Drop the reply sender without answering.
                drop(req.response_tx);
            }
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        let err = ctx
            .request_raw("tasks/list", serde_json::Value::Null)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Failed to receive response"));
    }

    #[tokio::test]
    async fn test_default_request_impl_errors() {
        struct OnlySampleAndElicit;
        #[async_trait]
        impl ClientRequester for OnlySampleAndElicit {
            async fn sample(&self, _: CreateMessageParams) -> Result<CreateMessageResult> {
                unreachable!()
            }
            async fn elicit(&self, _: ElicitRequestParams) -> Result<ElicitResult> {
                unreachable!()
            }
        }
        let requester: ClientRequesterHandle = Arc::new(OnlySampleAndElicit);
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        let err = ctx.get_task_info("x").await.unwrap_err();
        assert!(err.to_string().contains("does not support arbitrary"));
    }
}