1use agent_client_protocol::schema::{ContentBlock, SessionId};
2use std::path::{Path, PathBuf};
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6
7#[derive(Debug)]
9pub enum PromptCommand {
10 Prompt { session_id: SessionId, text: String, content: Option<Vec<ContentBlock>> },
11 Cancel { session_id: SessionId },
12 SetConfigOption { session_id: SessionId, config_id: String, value: String },
13 AuthenticateMcpServer { session_id: SessionId, server_name: String },
14 Authenticate { method_id: String },
15 ListSessions,
16 LoadSession { session_id: SessionId, cwd: PathBuf },
17 NewSession { cwd: std::path::PathBuf },
18}
19
20#[derive(Clone)]
22pub struct AcpPromptHandle {
23 pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
24}
25
26impl AcpPromptHandle {
27 pub fn noop() -> Self {
30 let (cmd_tx, rx) = mpsc::unbounded_channel();
31 std::mem::forget(rx);
32 Self { cmd_tx }
33 }
34
35 pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
38 let (cmd_tx, rx) = mpsc::unbounded_channel();
39 (Self { cmd_tx }, rx)
40 }
41
42 pub fn prompt(
43 &self,
44 session_id: &SessionId,
45 text: &str,
46 content: Option<Vec<ContentBlock>>,
47 ) -> Result<(), AcpClientError> {
48 self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
49 }
50
51 pub fn cancel(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
52 self.send(PromptCommand::Cancel { session_id: session_id.clone() })
53 }
54
55 pub fn set_config_option(
56 &self,
57 session_id: &SessionId,
58 config_id: &str,
59 value: &str,
60 ) -> Result<(), AcpClientError> {
61 self.send(PromptCommand::SetConfigOption {
62 session_id: session_id.clone(),
63 config_id: config_id.to_string(),
64 value: value.to_string(),
65 })
66 }
67
68 pub fn authenticate_mcp_server(&self, session_id: &SessionId, server_name: &str) -> Result<(), AcpClientError> {
69 self.send(PromptCommand::AuthenticateMcpServer {
70 session_id: session_id.clone(),
71 server_name: server_name.to_string(),
72 })
73 }
74
75 pub fn authenticate(&self, method_id: &str) -> Result<(), AcpClientError> {
76 self.send(PromptCommand::Authenticate { method_id: method_id.to_string() })
77 }
78
79 pub fn list_sessions(&self) -> Result<(), AcpClientError> {
80 self.send(PromptCommand::ListSessions)
81 }
82
83 pub fn load_session(&self, session_id: &SessionId, cwd: &Path) -> Result<(), AcpClientError> {
84 self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
85 }
86
87 pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
88 self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
89 }
90
91 fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
92 self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
93 }
94}
95
#[cfg(test)]
mod tests {
    use agent_client_protocol::schema::TextContent;

    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = SessionId::new("test");
        let cwd = Path::new("/tmp/project");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
        assert!(handle.set_config_option(&session_id, "model", "gpt-4o").is_ok());
        assert!(handle.authenticate_mcp_server(&session_id, "github").is_ok());
        assert!(handle.authenticate("oauth").is_ok());
        assert!(handle.list_sessions().is_ok());
        assert!(handle.load_session(&session_id, cwd).is_ok());
        assert!(handle.new_session(cwd).is_ok());
    }

    #[test]
    fn test_noop_handle_survives_clones_and_drops() {
        let first = AcpPromptHandle::noop();
        let cloned = first.clone();
        drop(first);
        let second = AcpPromptHandle::noop();
        let session_id = SessionId::new("test");

        assert!(cloned.cancel(&session_id).is_ok());
        assert!(second.cancel(&session_id).is_ok());
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            other => panic!("Expected Prompt command, got {other:?}"),
        }
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Cancel { session_id } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
            }
            other => panic!("Expected Cancel command, got {other:?}"),
        }
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::SetConfigOption { session_id, config_id, value } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            other => panic!("Expected SetConfigOption command, got {other:?}"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");
        let content = vec![ContentBlock::Text(TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content: Some(extra) } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            other => panic!("Expected Prompt command with content, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            other => panic!("Expected AuthenticateMcpServer command, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.authenticate("oauth").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Authenticate { method_id } => assert_eq!(method_id, "oauth"),
            other => panic!("Expected Authenticate command, got {other:?}"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-restore");
        let cwd = Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected LoadSession command, got {other:?}"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected NewSession command, got {other:?}"),
        }
    }

    #[test]
    fn test_commands_arrive_in_send_order() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.list_sessions().unwrap();
        handle.cancel(&session_id).unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::Cancel { .. }));
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_send_after_receiver_dropped_reports_agent_crashed() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);

        let err = handle.list_sessions().unwrap_err();

        assert!(matches!(err, AcpClientError::AgentCrashed(_)));
    }
}