1use agent_client_protocol::schema::{ContentBlock, SessionId};
2use std::path::{Path, PathBuf};
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6use crate::notifications::{
7 PromptSearchParams, SessionPreviewParams, WorkspaceListParams, WorkspaceMoveParams, WorkspaceMoveTarget,
8};
9
10#[derive(Debug)]
12pub enum PromptCommand {
13 Prompt { session_id: SessionId, text: String, content: Option<Vec<ContentBlock>> },
14 Cancel { session_id: SessionId },
15 SetConfigOption { session_id: SessionId, config_id: String, value: String },
16 AuthenticateMcpServer { session_id: SessionId, server_name: String },
17 Authenticate { method_id: String },
18 ListSessions,
19 LoadSession { session_id: SessionId, cwd: PathBuf },
20 NewSession { cwd: std::path::PathBuf },
21 SearchPrompts(PromptSearchParams),
22 SessionPreview(SessionPreviewParams),
23 ListWorkspaces(WorkspaceListParams),
24 MoveWorkspace(WorkspaceMoveParams),
25}
26
27#[derive(Clone)]
29pub struct AcpPromptHandle {
30 pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
31}
32
33impl AcpPromptHandle {
34 pub fn noop() -> Self {
37 let (cmd_tx, rx) = mpsc::unbounded_channel();
38 std::mem::forget(rx);
39 Self { cmd_tx }
40 }
41
42 pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
45 let (cmd_tx, rx) = mpsc::unbounded_channel();
46 (Self { cmd_tx }, rx)
47 }
48
49 pub fn prompt(
50 &self,
51 session_id: &SessionId,
52 text: &str,
53 content: Option<Vec<ContentBlock>>,
54 ) -> Result<(), AcpClientError> {
55 self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
56 }
57
58 pub fn cancel(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
59 self.send(PromptCommand::Cancel { session_id: session_id.clone() })
60 }
61
62 pub fn set_config_option(
63 &self,
64 session_id: &SessionId,
65 config_id: &str,
66 value: &str,
67 ) -> Result<(), AcpClientError> {
68 self.send(PromptCommand::SetConfigOption {
69 session_id: session_id.clone(),
70 config_id: config_id.to_string(),
71 value: value.to_string(),
72 })
73 }
74
75 pub fn authenticate_mcp_server(&self, session_id: &SessionId, server_name: &str) -> Result<(), AcpClientError> {
76 self.send(PromptCommand::AuthenticateMcpServer {
77 session_id: session_id.clone(),
78 server_name: server_name.to_string(),
79 })
80 }
81
82 pub fn authenticate(&self, method_id: &str) -> Result<(), AcpClientError> {
83 self.send(PromptCommand::Authenticate { method_id: method_id.to_string() })
84 }
85
86 pub fn list_sessions(&self) -> Result<(), AcpClientError> {
87 self.send(PromptCommand::ListSessions)
88 }
89
90 pub fn load_session(&self, session_id: &SessionId, cwd: &Path) -> Result<(), AcpClientError> {
91 self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
92 }
93
94 pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
95 self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
96 }
97
98 pub fn search_prompts(&self, params: PromptSearchParams) -> Result<(), AcpClientError> {
99 self.send(PromptCommand::SearchPrompts(params))
100 }
101
102 pub fn session_preview(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
103 self.send(PromptCommand::SessionPreview(SessionPreviewParams { session_id: session_id.0.to_string() }))
104 }
105
106 pub fn list_workspaces(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
107 self.send(PromptCommand::ListWorkspaces(WorkspaceListParams { session_id: session_id.0.to_string() }))
108 }
109
110 pub fn move_workspace(&self, session_id: &SessionId, target: WorkspaceMoveTarget) -> Result<(), AcpClientError> {
111 self.send(PromptCommand::MoveWorkspace(WorkspaceMoveParams { session_id: session_id.0.to_string(), target }))
112 }
113
114 fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
115 self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
116 }
117}
118
#[cfg(test)]
mod tests {
    use agent_client_protocol::schema::TextContent;

    use super::*;

    /// Pops the command a test expects to be waiting on the channel, panicking if there is none.
    fn next_command(rx: &mut mpsc::UnboundedReceiver<PromptCommand>) -> PromptCommand {
        rx.try_recv().unwrap()
    }

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let session_id = SessionId::new("test");
        let handle = AcpPromptHandle::noop();

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        let PromptCommand::Prompt { session_id, text, .. } = next_command(&mut rx) else {
            panic!("Expected Prompt command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(text, "hello");
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        assert!(matches!(next_command(&mut rx), PromptCommand::Cancel { .. }));
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        let PromptCommand::SetConfigOption { session_id, config_id, value } = next_command(&mut rx) else {
            panic!("Expected SetConfigOption command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(config_id, "model");
        assert_eq!(value, "gpt-4o");
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");
        let content = vec![ContentBlock::Text(TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        let PromptCommand::Prompt { session_id, text, content: Some(extra) } = next_command(&mut rx) else {
            panic!("Expected Prompt command with content");
        };
        assert_eq!(session_id.0.as_ref(), "sess-1");
        assert_eq!(text, "hello");
        assert_eq!(extra, content);
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        assert!(matches!(next_command(&mut rx), PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-restore");
        let cwd = Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        let PromptCommand::LoadSession { session_id, cwd } = next_command(&mut rx) else {
            panic!("Expected LoadSession command");
        };
        assert_eq!(session_id.0.as_ref(), "sess-restore");
        assert_eq!(cwd, PathBuf::from("/tmp/project"));
    }

    #[test]
    fn test_list_workspaces_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.list_workspaces(&session_id).unwrap();

        let PromptCommand::ListWorkspaces(params) = next_command(&mut rx) else {
            panic!("Expected ListWorkspaces command");
        };
        assert_eq!(params.session_id, "sess-1");
    }

    #[test]
    fn test_move_workspace_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.move_workspace(&session_id, WorkspaceMoveTarget::New { name: "ws".into() }).unwrap();

        let PromptCommand::MoveWorkspace(params) = next_command(&mut rx) else {
            panic!("Expected MoveWorkspace command");
        };
        let expected_target = WorkspaceMoveTarget::New { name: "ws".into() };
        assert_eq!(params.session_id, "sess-1");
        assert_eq!(params.target, expected_target);
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        let PromptCommand::NewSession { cwd } = next_command(&mut rx) else {
            panic!("Expected NewSession command");
        };
        assert_eq!(cwd, PathBuf::from("/tmp/project"));
    }
}