use serde::de::DeserializeOwned;
use tokio::sync::mpsc;
use crate::{
error::{Error, Result},
request_handler::{RequestHandler, TransportSink},
schema::{self},
};
#[derive(Clone)]
pub struct ClientCtx {
pub(crate) notification_tx: mpsc::UnboundedSender<schema::ClientNotification>,
pub(crate) request_id: Option<schema::RequestId>,
}
impl ClientCtx {
pub(crate) fn new(notification_tx: mpsc::UnboundedSender<schema::ClientNotification>) -> Self {
Self {
notification_tx,
request_id: None,
}
}
pub fn notify(&self, notification: schema::ClientNotification) -> Result<()> {
self.notification_tx
.send(notification)
.map_err(|_| Error::InternalError("Failed to send notification".into()))?;
Ok(())
}
pub(crate) fn with_request_id(&self, request_id: schema::RequestId) -> Self {
let mut ctx = self.clone();
ctx.request_id = Some(request_id);
ctx
}
pub fn cancel(&self, reason: Option<String>) -> Result<()> {
if let Some(request_id) = &self.request_id {
self.notify(schema::ClientNotification::cancelled(
Some(request_id.clone()),
reason,
))
} else {
Err(Error::InternalError(
"No request ID available to cancel".into(),
))
}
}
}
#[derive(Clone)]
pub struct ServerCtx {
pub(crate) notification_tx: mpsc::UnboundedSender<schema::ServerNotification>,
request_handler: RequestHandler,
pub(crate) request_id: Option<schema::RequestId>,
}
impl ServerCtx {
pub(crate) fn new(
notification_tx: mpsc::UnboundedSender<schema::ServerNotification>,
transport_tx: Option<TransportSink>,
) -> Self {
Self {
notification_tx,
request_handler: RequestHandler::new(transport_tx, "srv-req".to_string()),
request_id: None,
}
}
pub fn notify(&self, notification: schema::ServerNotification) -> Result<()> {
self.notification_tx
.send(notification)
.map_err(|_| Error::InternalError("Failed to send notification".into()))?;
Ok(())
}
pub(crate) fn with_request_id(&self, request_id: schema::RequestId) -> Self {
let mut ctx = self.clone();
ctx.request_id = Some(request_id);
ctx
}
async fn request<T>(&self, request: schema::ServerRequest) -> Result<T>
where
T: DeserializeOwned,
{
self.request_handler.request(request).await
}
pub(crate) async fn handle_client_response(&self, response: schema::JSONRPCResponse) {
let handler = self.request_handler.clone();
handler.handle_response(response).await
}
pub fn cancel(&self, reason: Option<String>) -> Result<()> {
if let Some(request_id) = &self.request_id {
self.notify(schema::ServerNotification::cancelled(
Some(request_id.clone()),
reason,
))
} else {
Err(Error::InternalError(
"No request ID available to cancel".into(),
))
}
}
pub async fn ping(&self) -> Result<()> {
let _: schema::EmptyResult = self.request(schema::ServerRequest::ping()).await?;
Ok(())
}
pub async fn create_message(
&self,
params: schema::CreateMessageParams,
) -> Result<schema::CreateMessageResult> {
self.request(schema::ServerRequest::create_message(params))
.await
}
pub async fn list_roots(&self) -> Result<schema::ListRootsResult> {
self.request(schema::ServerRequest::list_roots()).await
}
pub async fn elicit(
&self,
params: schema::ElicitRequestParams,
) -> Result<schema::ElicitResult> {
self.request(schema::ServerRequest::elicit(params)).await
}
pub async fn get_task(
&self,
task_id: impl Into<String> + Send,
) -> Result<schema::GetTaskResult> {
self.request(schema::ServerRequest::get_task(task_id)).await
}
pub async fn get_task_payload(
&self,
task_id: impl Into<String> + Send,
) -> Result<schema::GetTaskPayloadResult> {
self.request(schema::ServerRequest::get_task_payload(task_id))
.await
}
pub async fn list_tasks(
&self,
cursor: impl Into<Option<schema::Cursor>> + Send,
) -> Result<schema::ListTasksResult> {
self.request(schema::ServerRequest::list_tasks(cursor.into()))
.await
}
pub async fn cancel_task(
&self,
task_id: impl Into<String> + Send,
) -> Result<schema::CancelTaskResult> {
self.request(schema::ServerRequest::cancel_task(task_id))
.await
}
}