use agent_client_protocol as acp;
use tokio::sync::{mpsc, oneshot};
use super::{AcpRequest, AcpServerError};
/// Cloneable handle used to submit [`AcpRequest`]s over an unbounded channel.
///
/// The handle only owns the sending half of the channel; whoever holds the
/// matching receiver is responsible for processing the requests.
#[derive(Debug, Clone)]
pub struct AcpActorHandle {
    /// Sending half of the request channel.
    request_tx: mpsc::UnboundedSender<AcpRequest>,
}
impl AcpActorHandle {
pub fn new(request_tx: mpsc::UnboundedSender<AcpRequest>) -> Self {
Self { request_tx }
}
pub async fn send_session_notification(
&self,
notification: acp::SessionNotification,
) -> Result<(), AcpServerError> {
self.send_request(|tx| AcpRequest::SessionNotification {
notification: Box::new(notification),
response_tx: tx,
})
.await
}
pub async fn send_ext_notification(&self, notification: acp::ExtNotification) -> Result<(), AcpServerError> {
self.send_request(|tx| AcpRequest::ExtNotification { notification, response_tx: tx }).await
}
pub async fn request_permission(
&self,
request: acp::RequestPermissionRequest,
) -> Result<acp::RequestPermissionResponse, AcpServerError> {
self.send_request(|tx| AcpRequest::RequestPermission { request: Box::new(request), response_tx: tx }).await
}
pub async fn ext_method(&self, request: acp::ExtRequest) -> Result<acp::ExtResponse, AcpServerError> {
self.send_request(|tx| AcpRequest::ExtMethod { request, response_tx: tx }).await
}
async fn send_request<R>(
&self,
make_request: impl FnOnce(oneshot::Sender<Result<R, AcpServerError>>) -> AcpRequest,
) -> Result<R, AcpServerError> {
let (tx, rx) = oneshot::channel();
self.request_tx.send(make_request(tx)).map_err(|_| AcpServerError::ActorStopped)?;
rx.await.map_err(|_| AcpServerError::ActorStopped)?
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handle whose request receiver has already been dropped, so
    /// every send on it fails.
    fn stopped_handle() -> AcpActorHandle {
        let (request_tx, request_rx) = mpsc::unbounded_channel();
        drop(request_rx);
        AcpActorHandle::new(request_tx)
    }

    /// Raw JSON `null`, used as a placeholder extension payload.
    fn null_params() -> std::sync::Arc<serde_json::value::RawValue> {
        serde_json::from_str("null").expect("null is valid JSON")
    }

    /// Asserts that `result` is exactly `Err(AcpServerError::ActorStopped)`.
    #[track_caller]
    fn assert_actor_stopped<T>(result: Result<T, AcpServerError>) {
        assert!(matches!(result, Err(AcpServerError::ActorStopped)));
    }

    #[tokio::test]
    async fn test_handle_returns_error_when_actor_stopped() {
        let handle = stopped_handle();

        let text = acp::TextContent::new("test");
        let chunk = acp::ContentChunk::new(acp::ContentBlock::Text(text));
        let notification = acp::SessionNotification::new(
            acp::SessionId::new("test"),
            acp::SessionUpdate::AgentMessageChunk(chunk),
        );

        assert_actor_stopped(handle.send_session_notification(notification).await);
    }

    #[tokio::test]
    async fn test_ext_handle_returns_error_when_actor_stopped() {
        let handle = stopped_handle();

        let notification = acp::ExtNotification::new("test/method", null_params());

        assert_actor_stopped(handle.send_ext_notification(notification).await);
    }

    #[tokio::test]
    async fn test_ext_method_returns_error_when_actor_stopped() {
        let handle = stopped_handle();

        let request = acp::ExtRequest::new("test/method", null_params());

        assert_actor_stopped(handle.ext_method(request).await);
    }

    #[tokio::test]
    async fn test_request_permission_returns_error_when_actor_stopped() {
        let handle = stopped_handle();

        let tool_call = acp::ToolCallUpdate::new(
            acp::ToolCallId::new("tool_1"),
            acp::ToolCallUpdateFields::new(),
        );
        let allow_once = acp::PermissionOption::new(
            acp::PermissionOptionId::new("allow-once"),
            "Allow once",
            acp::PermissionOptionKind::AllowOnce,
        );
        let request = acp::RequestPermissionRequest::new(
            acp::SessionId::new("test"),
            tool_call,
            vec![allow_once],
        );

        assert_actor_stopped(handle.request_permission(request).await);
    }
}