1use agent_client_protocol::schema::{ContentBlock, SessionId};
2use std::path::{Path, PathBuf};
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6use crate::notifications::PromptSearchParams;
7
8#[derive(Debug)]
10pub enum PromptCommand {
11 Prompt { session_id: SessionId, text: String, content: Option<Vec<ContentBlock>> },
12 Cancel { session_id: SessionId },
13 SetConfigOption { session_id: SessionId, config_id: String, value: String },
14 AuthenticateMcpServer { session_id: SessionId, server_name: String },
15 Authenticate { method_id: String },
16 ListSessions,
17 LoadSession { session_id: SessionId, cwd: PathBuf },
18 NewSession { cwd: std::path::PathBuf },
19 SearchPrompts(PromptSearchParams),
20}
21
22#[derive(Clone)]
24pub struct AcpPromptHandle {
25 pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
26}
27
28impl AcpPromptHandle {
29 pub fn noop() -> Self {
32 let (cmd_tx, rx) = mpsc::unbounded_channel();
33 std::mem::forget(rx);
34 Self { cmd_tx }
35 }
36
37 pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
40 let (cmd_tx, rx) = mpsc::unbounded_channel();
41 (Self { cmd_tx }, rx)
42 }
43
44 pub fn prompt(
45 &self,
46 session_id: &SessionId,
47 text: &str,
48 content: Option<Vec<ContentBlock>>,
49 ) -> Result<(), AcpClientError> {
50 self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
51 }
52
53 pub fn cancel(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
54 self.send(PromptCommand::Cancel { session_id: session_id.clone() })
55 }
56
57 pub fn set_config_option(
58 &self,
59 session_id: &SessionId,
60 config_id: &str,
61 value: &str,
62 ) -> Result<(), AcpClientError> {
63 self.send(PromptCommand::SetConfigOption {
64 session_id: session_id.clone(),
65 config_id: config_id.to_string(),
66 value: value.to_string(),
67 })
68 }
69
70 pub fn authenticate_mcp_server(&self, session_id: &SessionId, server_name: &str) -> Result<(), AcpClientError> {
71 self.send(PromptCommand::AuthenticateMcpServer {
72 session_id: session_id.clone(),
73 server_name: server_name.to_string(),
74 })
75 }
76
77 pub fn authenticate(&self, method_id: &str) -> Result<(), AcpClientError> {
78 self.send(PromptCommand::Authenticate { method_id: method_id.to_string() })
79 }
80
81 pub fn list_sessions(&self) -> Result<(), AcpClientError> {
82 self.send(PromptCommand::ListSessions)
83 }
84
85 pub fn load_session(&self, session_id: &SessionId, cwd: &Path) -> Result<(), AcpClientError> {
86 self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
87 }
88
89 pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
90 self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
91 }
92
93 pub fn search_prompts(&self, params: PromptSearchParams) -> Result<(), AcpClientError> {
94 self.send(PromptCommand::SearchPrompts(params))
95 }
96
97 fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
98 self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
99 }
100}
101
#[cfg(test)]
mod tests {
    use agent_client_protocol::schema::TextContent;

    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = SessionId::new("test");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
    }

    /// Once the receiver is gone, every send must surface as `AgentCrashed`
    /// instead of silently succeeding.
    #[test]
    fn test_send_fails_when_receiver_dropped() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);

        let result = handle.cancel(&SessionId::new("sess-1"));
        assert!(matches!(result, Err(AcpClientError::AgentCrashed(_))));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            other => panic!("Expected Prompt command, got {other:?}"),
        }
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Cancel { session_id } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
            }
            other => panic!("Expected Cancel command, got {other:?}"),
        }
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::SetConfigOption { session_id, config_id, value } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            other => panic!("Expected SetConfigOption command, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            other => panic!("Expected AuthenticateMcpServer command, got {other:?}"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.authenticate("oauth").unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Authenticate { method_id } => assert_eq!(method_id, "oauth"),
            other => panic!("Expected Authenticate command, got {other:?}"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");
        let content = vec![ContentBlock::Text(TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::Prompt { session_id, text, content: Some(extra) } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            other => panic!("Expected Prompt command with content, got {other:?}"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-restore");
        let cwd = Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected LoadSession command, got {other:?}"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        match rx.try_recv().unwrap() {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, std::path::PathBuf::from("/tmp/project"));
            }
            other => panic!("Expected NewSession command, got {other:?}"),
        }
    }
}