1use agent_client_protocol as acp;
2use std::path::Path;
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6
7#[derive(Debug)]
9pub enum PromptCommand {
10 Prompt {
11 session_id: acp::SessionId,
12 text: String,
13 content: Option<Vec<acp::ContentBlock>>,
14 },
15 Cancel {
16 session_id: acp::SessionId,
17 },
18 SetConfigOption {
19 session_id: acp::SessionId,
20 config_id: String,
21 value: String,
22 },
23 AuthenticateMcpServer {
24 session_id: acp::SessionId,
25 server_name: String,
26 },
27 Authenticate {
28 #[allow(dead_code)]
29 session_id: acp::SessionId,
30 method_id: String,
31 },
32 ListSessions,
33 LoadSession {
34 session_id: acp::SessionId,
35 cwd: std::path::PathBuf,
36 },
37 NewSession {
38 cwd: std::path::PathBuf,
39 },
40}
41
42#[derive(Clone)]
44pub struct AcpPromptHandle {
45 pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
46}
47
48impl AcpPromptHandle {
49 pub fn noop() -> Self {
52 let (cmd_tx, rx) = mpsc::unbounded_channel();
53 std::mem::forget(rx);
54 Self { cmd_tx }
55 }
56
57 pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
60 let (cmd_tx, rx) = mpsc::unbounded_channel();
61 (Self { cmd_tx }, rx)
62 }
63
64 pub fn prompt(
65 &self,
66 session_id: &acp::SessionId,
67 text: &str,
68 content: Option<Vec<acp::ContentBlock>>,
69 ) -> Result<(), AcpClientError> {
70 self.send(PromptCommand::Prompt {
71 session_id: session_id.clone(),
72 text: text.to_string(),
73 content,
74 })
75 }
76
77 pub fn cancel(&self, session_id: &acp::SessionId) -> Result<(), AcpClientError> {
78 self.send(PromptCommand::Cancel {
79 session_id: session_id.clone(),
80 })
81 }
82
83 pub fn set_config_option(
84 &self,
85 session_id: &acp::SessionId,
86 config_id: &str,
87 value: &str,
88 ) -> Result<(), AcpClientError> {
89 self.send(PromptCommand::SetConfigOption {
90 session_id: session_id.clone(),
91 config_id: config_id.to_string(),
92 value: value.to_string(),
93 })
94 }
95
96 pub fn authenticate_mcp_server(
97 &self,
98 session_id: &acp::SessionId,
99 server_name: &str,
100 ) -> Result<(), AcpClientError> {
101 self.send(PromptCommand::AuthenticateMcpServer {
102 session_id: session_id.clone(),
103 server_name: server_name.to_string(),
104 })
105 }
106
107 pub fn authenticate(
108 &self,
109 session_id: &acp::SessionId,
110 method_id: &str,
111 ) -> Result<(), AcpClientError> {
112 self.send(PromptCommand::Authenticate {
113 session_id: session_id.clone(),
114 method_id: method_id.to_string(),
115 })
116 }
117
118 pub fn list_sessions(&self) -> Result<(), AcpClientError> {
119 self.send(PromptCommand::ListSessions)
120 }
121
122 pub fn load_session(
123 &self,
124 session_id: &acp::SessionId,
125 cwd: &Path,
126 ) -> Result<(), AcpClientError> {
127 self.send(PromptCommand::LoadSession {
128 session_id: session_id.clone(),
129 cwd: cwd.to_path_buf(),
130 })
131 }
132
133 pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
134 self.send(PromptCommand::NewSession {
135 cwd: cwd.to_path_buf(),
136 })
137 }
138
139 fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
140 self.cmd_tx
141 .send(cmd)
142 .map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = acp::SessionId::new("test");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
    }

    #[test]
    fn test_noop_handle_clone_outlives_original() {
        let handle = AcpPromptHandle::noop();
        let clone = handle.clone();
        drop(handle);

        assert!(clone.list_sessions().is_ok());
        assert!(clone
            .new_session(std::path::Path::new("/tmp/project"))
            .is_ok());
    }

    #[test]
    fn test_recording_handle_captures_commands() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_send_after_receiver_dropped_reports_agent_crashed() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);
        let session_id = acp::SessionId::new("sess-1");

        let err = handle.cancel(&session_id).unwrap_err();
        assert!(matches!(err, AcpClientError::AgentCrashed(_)));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };
        let session_id = acp::SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Prompt {
                session_id, text, ..
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
            }
            _ => panic!("Expected Prompt command"),
        }
    }

    #[test]
    fn test_cancel_sends_command() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };
        let session_id = acp::SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::Cancel { .. }));
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };
        let session_id = acp::SessionId::new("sess-1");

        handle
            .set_config_option(&session_id, "model", "gpt-4o")
            .unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::SetConfigOption {
                session_id,
                config_id,
                value,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            _ => panic!("Expected SetConfigOption command"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle
            .authenticate_mcp_server(&session_id, "github")
            .unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::AuthenticateMcpServer {
                session_id,
                server_name,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            _ => panic!("Expected AuthenticateMcpServer command"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.authenticate(&session_id, "oauth").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Authenticate {
                session_id,
                method_id,
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(method_id, "oauth");
            }
            _ => panic!("Expected Authenticate command"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };
        let session_id = acp::SessionId::new("sess-1");
        let content = vec![acp::ContentBlock::Text(acp::TextContent::new("attached"))];

        handle
            .prompt(&session_id, "hello", Some(content.clone()))
            .unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Prompt {
                session_id,
                text,
                content: Some(extra),
            } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            _ => panic!("Expected Prompt command with content"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };

        handle.list_sessions().unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };
        let session_id = acp::SessionId::new("sess-restore");
        let cwd = std::path::Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            _ => panic!("Expected LoadSession command"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = AcpPromptHandle { cmd_tx: tx };
        let cwd = std::path::Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            _ => panic!("Expected NewSession command"),
        }
    }
}