Skip to main content

acp_utils/client/
prompt_handle.rs

1use agent_client_protocol as acp;
2use std::path::Path;
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6
7/// Commands sent from the main thread to the ACP `LocalSet` thread.
8#[derive(Debug)]
9pub enum PromptCommand {
10    Prompt {
11        session_id: acp::SessionId,
12        text: String,
13        content: Option<Vec<acp::ContentBlock>>,
14    },
15    Cancel {
16        session_id: acp::SessionId,
17    },
18    SetConfigOption {
19        session_id: acp::SessionId,
20        config_id: String,
21        value: String,
22    },
23    AuthenticateMcpServer {
24        session_id: acp::SessionId,
25        server_name: String,
26    },
27    Authenticate {
28        #[allow(dead_code)]
29        session_id: acp::SessionId,
30        method_id: String,
31    },
32    ListSessions,
33    LoadSession {
34        session_id: acp::SessionId,
35        cwd: std::path::PathBuf,
36    },
37    NewSession {
38        cwd: std::path::PathBuf,
39    },
40}
41
42/// Send-safe handle for issuing prompt commands to the !Send ACP connection.
43#[derive(Clone)]
44pub struct AcpPromptHandle {
45    pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
46}
47
48impl AcpPromptHandle {
49    /// Create a handle whose sends always succeed but are never read.
50    /// Useful for tests that don't care about prompt delivery.
51    pub fn noop() -> Self {
52        let (cmd_tx, rx) = mpsc::unbounded_channel();
53        std::mem::forget(rx);
54        Self { cmd_tx }
55    }
56
57    /// Create a handle paired with a receiver for inspecting sent commands.
58    /// Useful for tests that need to verify which commands were dispatched.
59    pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
60        let (cmd_tx, rx) = mpsc::unbounded_channel();
61        (Self { cmd_tx }, rx)
62    }
63
64    pub fn prompt(
65        &self,
66        session_id: &acp::SessionId,
67        text: &str,
68        content: Option<Vec<acp::ContentBlock>>,
69    ) -> Result<(), AcpClientError> {
70        self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
71    }
72
73    pub fn cancel(&self, session_id: &acp::SessionId) -> Result<(), AcpClientError> {
74        self.send(PromptCommand::Cancel { session_id: session_id.clone() })
75    }
76
77    pub fn set_config_option(
78        &self,
79        session_id: &acp::SessionId,
80        config_id: &str,
81        value: &str,
82    ) -> Result<(), AcpClientError> {
83        self.send(PromptCommand::SetConfigOption {
84            session_id: session_id.clone(),
85            config_id: config_id.to_string(),
86            value: value.to_string(),
87        })
88    }
89
90    pub fn authenticate_mcp_server(
91        &self,
92        session_id: &acp::SessionId,
93        server_name: &str,
94    ) -> Result<(), AcpClientError> {
95        self.send(PromptCommand::AuthenticateMcpServer {
96            session_id: session_id.clone(),
97            server_name: server_name.to_string(),
98        })
99    }
100
101    pub fn authenticate(&self, session_id: &acp::SessionId, method_id: &str) -> Result<(), AcpClientError> {
102        self.send(PromptCommand::Authenticate { session_id: session_id.clone(), method_id: method_id.to_string() })
103    }
104
105    pub fn list_sessions(&self) -> Result<(), AcpClientError> {
106        self.send(PromptCommand::ListSessions)
107    }
108
109    pub fn load_session(&self, session_id: &acp::SessionId, cwd: &Path) -> Result<(), AcpClientError> {
110        self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
111    }
112
113    pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
114        self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
115    }
116
117    fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
118        self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    // --- noop handle -------------------------------------------------------

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = acp::SessionId::new("test");
        let cwd = std::path::Path::new("/tmp/project");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
        assert!(handle.set_config_option(&session_id, "model", "gpt-4o").is_ok());
        assert!(handle.authenticate_mcp_server(&session_id, "github").is_ok());
        assert!(handle.authenticate(&session_id, "oauth").is_ok());
        assert!(handle.list_sessions().is_ok());
        assert!(handle.load_session(&session_id, cwd).is_ok());
        assert!(handle.new_session(cwd).is_ok());
    }

    #[test]
    fn test_noop_clone_keeps_working_after_original_dropped() {
        let handle = AcpPromptHandle::noop();
        let clone = handle.clone();
        drop(handle);

        assert!(clone.cancel(&acp::SessionId::new("test")).is_ok());
    }

    // --- per-command dispatch ----------------------------------------------

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            other => panic!("Expected Prompt command, got {other:?}"),
        }
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Cancel { session_id } => assert_eq!(session_id.0.as_ref(), "sess-1"),
            other => panic!("Expected Cancel command, got {other:?}"),
        }
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::SetConfigOption { session_id, config_id, value } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            other => panic!("Expected SetConfigOption command, got {other:?}"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");
        let content = vec![acp::ContentBlock::Text(acp::TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content: Some(extra) } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            other => panic!("Expected Prompt command with content, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            other => panic!("Expected AuthenticateMcpServer command, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.authenticate(&session_id, "oauth").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Authenticate { session_id, method_id } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(method_id, "oauth");
            }
            other => panic!("Expected Authenticate command, got {other:?}"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-restore");
        let cwd = std::path::Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected LoadSession command, got {other:?}"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = std::path::Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected NewSession command, got {other:?}"),
        }
    }

    // --- channel semantics -------------------------------------------------

    #[test]
    fn test_recording_preserves_send_order() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = std::path::Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();
        handle.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::NewSession { .. }));
        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_send_after_receiver_dropped_reports_agent_crashed() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);

        let err = handle.list_sessions().unwrap_err();

        assert!(matches!(err, AcpClientError::AgentCrashed(_)));
    }
}