1use agent_client_protocol as acp;
2use std::path::Path;
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6
7#[derive(Debug)]
9pub enum PromptCommand {
10 Prompt {
11 session_id: acp::SessionId,
12 text: String,
13 content: Option<Vec<acp::ContentBlock>>,
14 },
15 Cancel {
16 session_id: acp::SessionId,
17 },
18 SetConfigOption {
19 session_id: acp::SessionId,
20 config_id: String,
21 value: String,
22 },
23 AuthenticateMcpServer {
24 session_id: acp::SessionId,
25 server_name: String,
26 },
27 Authenticate {
28 #[allow(dead_code)]
29 session_id: acp::SessionId,
30 method_id: String,
31 },
32 ListSessions,
33 LoadSession {
34 session_id: acp::SessionId,
35 cwd: std::path::PathBuf,
36 },
37 NewSession {
38 cwd: std::path::PathBuf,
39 },
40}
41
42#[derive(Clone)]
44pub struct AcpPromptHandle {
45 pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
46}
47
48impl AcpPromptHandle {
49 pub fn noop() -> Self {
52 let (cmd_tx, rx) = mpsc::unbounded_channel();
53 std::mem::forget(rx);
54 Self { cmd_tx }
55 }
56
57 pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
60 let (cmd_tx, rx) = mpsc::unbounded_channel();
61 (Self { cmd_tx }, rx)
62 }
63
64 pub fn prompt(
65 &self,
66 session_id: &acp::SessionId,
67 text: &str,
68 content: Option<Vec<acp::ContentBlock>>,
69 ) -> Result<(), AcpClientError> {
70 self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
71 }
72
73 pub fn cancel(&self, session_id: &acp::SessionId) -> Result<(), AcpClientError> {
74 self.send(PromptCommand::Cancel { session_id: session_id.clone() })
75 }
76
77 pub fn set_config_option(
78 &self,
79 session_id: &acp::SessionId,
80 config_id: &str,
81 value: &str,
82 ) -> Result<(), AcpClientError> {
83 self.send(PromptCommand::SetConfigOption {
84 session_id: session_id.clone(),
85 config_id: config_id.to_string(),
86 value: value.to_string(),
87 })
88 }
89
90 pub fn authenticate_mcp_server(
91 &self,
92 session_id: &acp::SessionId,
93 server_name: &str,
94 ) -> Result<(), AcpClientError> {
95 self.send(PromptCommand::AuthenticateMcpServer {
96 session_id: session_id.clone(),
97 server_name: server_name.to_string(),
98 })
99 }
100
101 pub fn authenticate(&self, session_id: &acp::SessionId, method_id: &str) -> Result<(), AcpClientError> {
102 self.send(PromptCommand::Authenticate { session_id: session_id.clone(), method_id: method_id.to_string() })
103 }
104
105 pub fn list_sessions(&self) -> Result<(), AcpClientError> {
106 self.send(PromptCommand::ListSessions)
107 }
108
109 pub fn load_session(&self, session_id: &acp::SessionId, cwd: &Path) -> Result<(), AcpClientError> {
110 self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
111 }
112
113 pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
114 self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
115 }
116
117 fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
118 self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = acp::SessionId::new("test");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
        assert!(handle.list_sessions().is_ok());
    }

    #[test]
    fn test_noop_handle_survives_other_handles_being_dropped() {
        let first = AcpPromptHandle::noop();
        let second = AcpPromptHandle::noop().clone();
        drop(first);

        assert!(second.list_sessions().is_ok());
    }

    #[test]
    fn test_send_after_receiver_dropped_reports_agent_crashed() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);

        let err = handle.list_sessions().unwrap_err();
        assert!(matches!(err, AcpClientError::AgentCrashed(_)));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        let PromptCommand::Prompt { session_id, text, content } = rx.try_recv().unwrap() else {
            panic!("Expected Prompt command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(text, "hello");
        assert!(content.is_none());
        // Exactly one command per call.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        let PromptCommand::Cancel { session_id } = rx.try_recv().unwrap() else {
            panic!("Expected Cancel command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        let PromptCommand::SetConfigOption { session_id, config_id, value } = rx.try_recv().unwrap() else {
            panic!("Expected SetConfigOption command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(config_id, "model");
        assert_eq!(value, "gpt-4o");
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        let PromptCommand::AuthenticateMcpServer { session_id, server_name } = rx.try_recv().unwrap() else {
            panic!("Expected AuthenticateMcpServer command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(server_name, "github");
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");

        handle.authenticate(&session_id, "oauth").unwrap();

        let PromptCommand::Authenticate { session_id, method_id } = rx.try_recv().unwrap() else {
            panic!("Expected Authenticate command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(method_id, "oauth");
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-1");
        let content = vec![acp::ContentBlock::Text(acp::TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        let PromptCommand::Prompt { session_id, text, content: Some(extra) } = rx.try_recv().unwrap() else {
            panic!("Expected Prompt command with content");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(text, "hello");
        assert_eq!(extra, content);
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = acp::SessionId::new("sess-restore");
        let cwd = std::path::Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        let PromptCommand::LoadSession { session_id, cwd } = rx.try_recv().unwrap() else {
            panic!("Expected LoadSession command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-restore");
        assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = std::path::Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        let PromptCommand::NewSession { cwd } = rx.try_recv().unwrap() else {
            panic!("Expected NewSession command");
        };
        assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
    }
}