Skip to main content

acp_utils/client/
prompt_handle.rs

1use agent_client_protocol::schema::{ContentBlock, SessionId};
2use std::path::{Path, PathBuf};
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6use crate::notifications::{
7    PromptSearchParams, SessionPreviewParams, WorkspaceListParams, WorkspaceMoveParams, WorkspaceMoveTarget,
8};
9
10/// Commands sent from the main thread to the ACP client task.
11#[derive(Debug)]
12pub enum PromptCommand {
13    Prompt { session_id: SessionId, text: String, content: Option<Vec<ContentBlock>> },
14    Cancel { session_id: SessionId },
15    SetConfigOption { session_id: SessionId, config_id: String, value: String },
16    AuthenticateMcpServer { session_id: SessionId, server_name: String },
17    Authenticate { method_id: String },
18    ListSessions,
19    LoadSession { session_id: SessionId, cwd: PathBuf },
20    NewSession { cwd: std::path::PathBuf },
21    SearchPrompts(PromptSearchParams),
22    SessionPreview(SessionPreviewParams),
23    ListWorkspaces(WorkspaceListParams),
24    MoveWorkspace(WorkspaceMoveParams),
25}
26
27/// Send-safe handle for issuing prompt commands to the ACP client task.
28#[derive(Clone)]
29pub struct AcpPromptHandle {
30    pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
31}
32
33impl AcpPromptHandle {
34    /// Create a handle whose sends always succeed but are never read.
35    /// Useful for tests that don't care about prompt delivery.
36    pub fn noop() -> Self {
37        let (cmd_tx, rx) = mpsc::unbounded_channel();
38        std::mem::forget(rx);
39        Self { cmd_tx }
40    }
41
42    /// Create a handle paired with a receiver for inspecting sent commands.
43    /// Useful for tests that need to verify which commands were dispatched.
44    pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
45        let (cmd_tx, rx) = mpsc::unbounded_channel();
46        (Self { cmd_tx }, rx)
47    }
48
49    pub fn prompt(
50        &self,
51        session_id: &SessionId,
52        text: &str,
53        content: Option<Vec<ContentBlock>>,
54    ) -> Result<(), AcpClientError> {
55        self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
56    }
57
58    pub fn cancel(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
59        self.send(PromptCommand::Cancel { session_id: session_id.clone() })
60    }
61
62    pub fn set_config_option(
63        &self,
64        session_id: &SessionId,
65        config_id: &str,
66        value: &str,
67    ) -> Result<(), AcpClientError> {
68        self.send(PromptCommand::SetConfigOption {
69            session_id: session_id.clone(),
70            config_id: config_id.to_string(),
71            value: value.to_string(),
72        })
73    }
74
75    pub fn authenticate_mcp_server(&self, session_id: &SessionId, server_name: &str) -> Result<(), AcpClientError> {
76        self.send(PromptCommand::AuthenticateMcpServer {
77            session_id: session_id.clone(),
78            server_name: server_name.to_string(),
79        })
80    }
81
82    pub fn authenticate(&self, method_id: &str) -> Result<(), AcpClientError> {
83        self.send(PromptCommand::Authenticate { method_id: method_id.to_string() })
84    }
85
86    pub fn list_sessions(&self) -> Result<(), AcpClientError> {
87        self.send(PromptCommand::ListSessions)
88    }
89
90    pub fn load_session(&self, session_id: &SessionId, cwd: &Path) -> Result<(), AcpClientError> {
91        self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
92    }
93
94    pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
95        self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
96    }
97
98    pub fn search_prompts(&self, params: PromptSearchParams) -> Result<(), AcpClientError> {
99        self.send(PromptCommand::SearchPrompts(params))
100    }
101
102    pub fn session_preview(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
103        self.send(PromptCommand::SessionPreview(SessionPreviewParams { session_id: session_id.0.to_string() }))
104    }
105
106    pub fn list_workspaces(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
107        self.send(PromptCommand::ListWorkspaces(WorkspaceListParams { session_id: session_id.0.to_string() }))
108    }
109
110    pub fn move_workspace(&self, session_id: &SessionId, target: WorkspaceMoveTarget) -> Result<(), AcpClientError> {
111        self.send(PromptCommand::MoveWorkspace(WorkspaceMoveParams { session_id: session_id.0.to_string(), target }))
112    }
113
114    fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
115        self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
116    }
117}
118
#[cfg(test)]
mod tests {
    use agent_client_protocol::schema::TextContent;

    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = SessionId::new("test");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
        assert!(handle.list_sessions().is_ok());
    }

    /// Once the receiver is gone, every send must surface `AgentCrashed`.
    #[test]
    fn test_send_fails_once_receiver_dropped() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);

        let err = handle.list_sessions().unwrap_err();
        assert!(matches!(err, AcpClientError::AgentCrashed(_)));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            other => panic!("Expected Prompt command, got {:?}", other),
        }
        // Exactly one command should have been queued.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Cancel { session_id } => assert_eq!(session_id.0.as_ref(), "sess-1"),
            other => panic!("Expected Cancel command, got {:?}", other),
        }
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::SetConfigOption { session_id, config_id, value } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            other => panic!("Expected SetConfigOption command, got {:?}", other),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            other => panic!("Expected AuthenticateMcpServer command, got {:?}", other),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.authenticate("oauth").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Authenticate { method_id } => assert_eq!(method_id, "oauth"),
            other => panic!("Expected Authenticate command, got {:?}", other),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");
        let content = vec![ContentBlock::Text(TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content: Some(extra) } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            other => panic!("Expected Prompt command with content, got {:?}", other),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-restore");
        let cwd = Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected LoadSession command, got {:?}", other),
        }
    }

    #[test]
    fn test_session_preview_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        handle.session_preview(&SessionId::new("sess-1")).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::SessionPreview(params) => assert_eq!(params.session_id, "sess-1"),
            other => panic!("Expected SessionPreview command, got {:?}", other),
        }
    }

    #[test]
    fn test_list_workspaces_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        handle.list_workspaces(&SessionId::new("sess-1")).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::ListWorkspaces(params) => assert_eq!(params.session_id, "sess-1"),
            other => panic!("Expected ListWorkspaces command, got {:?}", other),
        }
    }

    #[test]
    fn test_move_workspace_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        handle.move_workspace(&SessionId::new("sess-1"), WorkspaceMoveTarget::New { name: "ws".into() }).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::MoveWorkspace(params) => {
                assert_eq!(params.session_id, "sess-1");
                assert_eq!(params.target, WorkspaceMoveTarget::New { name: "ws".into() });
            }
            other => panic!("Expected MoveWorkspace command, got {:?}", other),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected NewSession command, got {:?}", other),
        }
    }
}
278}