Skip to main content

acp_utils/server/actor_handle.rs

1use agent_client_protocol as acp;
2use tokio::sync::{mpsc, oneshot};
3
4use super::{AcpRequest, AcpServerError};
5
6/// Send-safe handle to communicate with an [`AcpActor`](super::AcpActor).
7#[derive(Clone, Debug)]
8pub struct AcpActorHandle {
9    request_tx: mpsc::UnboundedSender<AcpRequest>,
10}
11
12impl AcpActorHandle {
13    pub fn new(request_tx: mpsc::UnboundedSender<AcpRequest>) -> Self {
14        Self { request_tx }
15    }
16
17    pub async fn send_session_notification(
18        &self,
19        notification: acp::SessionNotification,
20    ) -> Result<(), AcpServerError> {
21        self.send_request(|tx| AcpRequest::SessionNotification {
22            notification: Box::new(notification),
23            response_tx: tx,
24        })
25        .await
26    }
27
28    pub async fn send_ext_notification(&self, notification: acp::ExtNotification) -> Result<(), AcpServerError> {
29        self.send_request(|tx| AcpRequest::ExtNotification { notification, response_tx: tx }).await
30    }
31
32    pub async fn request_permission(
33        &self,
34        request: acp::RequestPermissionRequest,
35    ) -> Result<acp::RequestPermissionResponse, AcpServerError> {
36        self.send_request(|tx| AcpRequest::RequestPermission { request: Box::new(request), response_tx: tx }).await
37    }
38
39    pub async fn ext_method(&self, request: acp::ExtRequest) -> Result<acp::ExtResponse, AcpServerError> {
40        self.send_request(|tx| AcpRequest::ExtMethod { request, response_tx: tx }).await
41    }
42
43    async fn send_request<R>(
44        &self,
45        make_request: impl FnOnce(oneshot::Sender<Result<R, AcpServerError>>) -> AcpRequest,
46    ) -> Result<R, AcpServerError> {
47        let (tx, rx) = oneshot::channel();
48        self.request_tx.send(make_request(tx)).map_err(|_| AcpServerError::ActorStopped)?;
49        rx.await.map_err(|_| AcpServerError::ActorStopped)?
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handle whose request receiver has already been dropped, so every
    /// request sent through it fails to enqueue.
    fn stopped_handle() -> AcpActorHandle {
        let (tx, rx) = mpsc::unbounded_channel();
        // Drop the receiver to simulate a stopped actor
        drop(rx);
        AcpActorHandle::new(tx)
    }

    /// JSON `null` as a raw params payload for extension notifications and methods.
    fn null_params() -> std::sync::Arc<serde_json::value::RawValue> {
        serde_json::from_str("null").expect("null is valid JSON")
    }

    /// A minimal extension notification used by several tests.
    fn ext_notification() -> acp::ExtNotification {
        acp::ExtNotification::new("test/method", null_params())
    }

    #[tokio::test]
    async fn test_handle_returns_error_when_actor_stopped() {
        let handle = stopped_handle();

        let notification = acp::SessionNotification::new(
            acp::SessionId::new("test"),
            acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text(
                acp::TextContent::new("test"),
            ))),
        );

        let result = handle.send_session_notification(notification).await;
        assert!(matches!(result, Err(AcpServerError::ActorStopped)));
    }

    #[tokio::test]
    async fn test_ext_handle_returns_error_when_actor_stopped() {
        let handle = stopped_handle();
        let result = handle.send_ext_notification(ext_notification()).await;
        assert!(matches!(result, Err(AcpServerError::ActorStopped)));
    }

    #[tokio::test]
    async fn test_ext_method_returns_error_when_actor_stopped() {
        let handle = stopped_handle();
        let request = acp::ExtRequest::new("test/method", null_params());
        let result = handle.ext_method(request).await;
        assert!(matches!(result, Err(AcpServerError::ActorStopped)));
    }

    #[tokio::test]
    async fn test_request_permission_returns_error_when_actor_stopped() {
        let handle = stopped_handle();

        let request = acp::RequestPermissionRequest::new(
            acp::SessionId::new("test"),
            acp::ToolCallUpdate::new(acp::ToolCallId::new("tool_1"), acp::ToolCallUpdateFields::new()),
            vec![acp::PermissionOption::new(
                acp::PermissionOptionId::new("allow-once"),
                "Allow once",
                acp::PermissionOptionKind::AllowOnce,
            )],
        );

        let result = handle.request_permission(request).await;
        assert!(matches!(result, Err(AcpServerError::ActorStopped)));
    }

    #[tokio::test]
    async fn test_handle_returns_error_when_actor_drops_request() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpActorHandle::new(tx);

        // The "actor" accepts the request but discards it, dropping the response
        // sender without ever replying.
        let actor = async move {
            let request = rx.recv().await;
            assert!(request.is_some(), "handle should have enqueued a request");
            drop(request);
        };

        // join! polls both futures on this task, so no `Send` bound is required.
        let (result, ()) = tokio::join!(handle.send_ext_notification(ext_notification()), actor);
        assert!(matches!(result, Err(AcpServerError::ActorStopped)));
    }

    #[tokio::test]
    async fn test_ext_notification_returns_actor_reply() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpActorHandle::new(tx);

        // The "actor" answers the notification successfully.
        let actor = async move {
            match rx.recv().await {
                Some(AcpRequest::ExtNotification { response_tx, .. }) => {
                    assert!(
                        response_tx.send(Ok(())).is_ok(),
                        "handle should still be awaiting the reply"
                    );
                }
                _ => panic!("expected an ExtNotification request"),
            }
        };

        let (result, ()) = tokio::join!(handle.send_ext_notification(ext_notification()), actor);
        assert!(matches!(result, Ok(())));
    }
}