Skip to main content

acp_utils/client/
prompt_handle.rs

1use agent_client_protocol as acp;
2use std::path::Path;
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6
/// Commands sent from the main thread to the ACP `LocalSet` thread.
///
/// Values are normally built by the methods on [`AcpPromptHandle`] and
/// delivered over its unbounded command channel.
#[derive(Debug)]
pub enum PromptCommand {
    /// Send a prompt to a session.
    Prompt {
        /// Target session.
        session_id: acp::SessionId,
        /// Prompt text.
        text: String,
        /// Optional extra content blocks sent alongside `text`; `None` for a
        /// text-only prompt.
        content: Option<Vec<acp::ContentBlock>>,
    },
    /// Request cancellation for a session.
    Cancel {
        /// Session to cancel.
        session_id: acp::SessionId,
    },
    /// Set a session configuration option (e.g. `config_id = "model"`).
    SetConfigOption {
        /// Target session.
        session_id: acp::SessionId,
        /// Identifier of the option to change.
        config_id: String,
        /// New value for the option.
        value: String,
    },
    /// Authenticate the MCP server named `server_name` within a session.
    AuthenticateMcpServer {
        /// Target session.
        session_id: acp::SessionId,
        /// Name of the MCP server to authenticate.
        server_name: String,
    },
    /// Authenticate using the method identified by `method_id`.
    Authenticate {
        // NOTE(review): the `allow` indicates this field is not currently read
        // by the command consumer; confirm whether it should be used or removed.
        #[allow(dead_code)]
        session_id: acp::SessionId,
        /// Identifier of the authentication method to use.
        method_id: String,
    },
    /// Request a list of sessions.
    ListSessions,
    /// Load an existing session by id, using `cwd` as its working directory.
    LoadSession {
        /// Session to load.
        session_id: acp::SessionId,
        /// Working directory for the loaded session.
        cwd: std::path::PathBuf,
    },
    /// Create a new session with `cwd` as its working directory.
    NewSession {
        /// Working directory for the new session.
        cwd: std::path::PathBuf,
    },
}
41
42/// Send-safe handle for issuing prompt commands to the !Send ACP connection.
43#[derive(Clone)]
44pub struct AcpPromptHandle {
45    pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
46}
47
48impl AcpPromptHandle {
49    /// Create a handle whose sends always succeed but are never read.
50    /// Useful for tests that don't care about prompt delivery.
51    pub fn noop() -> Self {
52        let (cmd_tx, rx) = mpsc::unbounded_channel();
53        std::mem::forget(rx);
54        Self { cmd_tx }
55    }
56
57    /// Create a handle paired with a receiver for inspecting sent commands.
58    /// Useful for tests that need to verify which commands were dispatched.
59    pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
60        let (cmd_tx, rx) = mpsc::unbounded_channel();
61        (Self { cmd_tx }, rx)
62    }
63
64    pub fn prompt(
65        &self,
66        session_id: &acp::SessionId,
67        text: &str,
68        content: Option<Vec<acp::ContentBlock>>,
69    ) -> Result<(), AcpClientError> {
70        self.send(PromptCommand::Prompt {
71            session_id: session_id.clone(),
72            text: text.to_string(),
73            content,
74        })
75    }
76
77    pub fn cancel(&self, session_id: &acp::SessionId) -> Result<(), AcpClientError> {
78        self.send(PromptCommand::Cancel {
79            session_id: session_id.clone(),
80        })
81    }
82
83    pub fn set_config_option(
84        &self,
85        session_id: &acp::SessionId,
86        config_id: &str,
87        value: &str,
88    ) -> Result<(), AcpClientError> {
89        self.send(PromptCommand::SetConfigOption {
90            session_id: session_id.clone(),
91            config_id: config_id.to_string(),
92            value: value.to_string(),
93        })
94    }
95
96    pub fn authenticate_mcp_server(
97        &self,
98        session_id: &acp::SessionId,
99        server_name: &str,
100    ) -> Result<(), AcpClientError> {
101        self.send(PromptCommand::AuthenticateMcpServer {
102            session_id: session_id.clone(),
103            server_name: server_name.to_string(),
104        })
105    }
106
107    pub fn authenticate(
108        &self,
109        session_id: &acp::SessionId,
110        method_id: &str,
111    ) -> Result<(), AcpClientError> {
112        self.send(PromptCommand::Authenticate {
113            session_id: session_id.clone(),
114            method_id: method_id.to_string(),
115        })
116    }
117
118    pub fn list_sessions(&self) -> Result<(), AcpClientError> {
119        self.send(PromptCommand::ListSessions)
120    }
121
122    pub fn load_session(
123        &self,
124        session_id: &acp::SessionId,
125        cwd: &Path,
126    ) -> Result<(), AcpClientError> {
127        self.send(PromptCommand::LoadSession {
128            session_id: session_id.clone(),
129            cwd: cwd.to_path_buf(),
130        })
131    }
132
133    pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
134        self.send(PromptCommand::NewSession {
135            cwd: cwd.to_path_buf(),
136        })
137    }
138
139    fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
140        self.cmd_tx
141            .send(cmd)
142            .map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = acp::SessionId::new("test");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());

        // A clone must keep working after the original handle is dropped.
        let clone = handle.clone();
        drop(handle);
        assert!(clone.list_sessions().is_ok());
    }

    #[test]
    fn test_send_fails_when_receiver_dropped() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);
        let session_id = acp::SessionId::new("sess-1");

        let err = handle.cancel(&session_id).unwrap_err();
        assert!(matches!(err, AcpClientError::AgentCrashed(_)));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt {
                session_id,
                text,
                content,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            other => panic!("Expected Prompt command, got {other:?}"),
        }
        // Exactly one command is queued per call.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Cancel { session_id } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
            }
            other => panic!("Expected Cancel command, got {other:?}"),
        }
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle
            .set_config_option(&session_id, "model", "gpt-4o")
            .unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::SetConfigOption {
                session_id,
                config_id,
                value,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            other => panic!("Expected SetConfigOption command, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle
            .authenticate_mcp_server(&session_id, "github")
            .unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::AuthenticateMcpServer {
                session_id,
                server_name,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            other => panic!("Expected AuthenticateMcpServer command, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.authenticate(&session_id, "oauth").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Authenticate {
                session_id,
                method_id,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(method_id, "oauth");
            }
            other => panic!("Expected Authenticate command, got {other:?}"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");
        let content = vec![acp::ContentBlock::Text(acp::TextContent::new("attached"))];

        handle
            .prompt(&session_id, "hello", Some(content.clone()))
            .unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt {
                session_id,
                text,
                content: Some(extra),
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            other => panic!("Expected Prompt command with content, got {other:?}"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-restore");
        let cwd = std::path::Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected LoadSession command, got {other:?}"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = std::path::Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected NewSession command, got {other:?}"),
        }
    }

    #[test]
    fn test_commands_arrive_in_send_order() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.prompt(&session_id, "first", None).unwrap();
        handle.cancel(&session_id).unwrap();
        handle.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::Prompt { .. }));
        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::Cancel { .. }));
        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
        assert!(rx.try_recv().is_err());
    }
}
288}