Skip to main content

acp_utils/client/prompt_handle.rs

1use agent_client_protocol::schema::{ContentBlock, SessionId};
2use std::path::{Path, PathBuf};
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6
7/// Commands sent from the main thread to the ACP client task.
8#[derive(Debug)]
9pub enum PromptCommand {
10    Prompt { session_id: SessionId, text: String, content: Option<Vec<ContentBlock>> },
11    Cancel { session_id: SessionId },
12    SetConfigOption { session_id: SessionId, config_id: String, value: String },
13    AuthenticateMcpServer { session_id: SessionId, server_name: String },
14    Authenticate { method_id: String },
15    ListSessions,
16    LoadSession { session_id: SessionId, cwd: PathBuf },
17    NewSession { cwd: std::path::PathBuf },
18}
19
20/// Send-safe handle for issuing prompt commands to the ACP client task.
21#[derive(Clone)]
22pub struct AcpPromptHandle {
23    pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
24}
25
26impl AcpPromptHandle {
27    /// Create a handle whose sends always succeed but are never read.
28    /// Useful for tests that don't care about prompt delivery.
29    pub fn noop() -> Self {
30        let (cmd_tx, rx) = mpsc::unbounded_channel();
31        std::mem::forget(rx);
32        Self { cmd_tx }
33    }
34
35    /// Create a handle paired with a receiver for inspecting sent commands.
36    /// Useful for tests that need to verify which commands were dispatched.
37    pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
38        let (cmd_tx, rx) = mpsc::unbounded_channel();
39        (Self { cmd_tx }, rx)
40    }
41
42    pub fn prompt(
43        &self,
44        session_id: &SessionId,
45        text: &str,
46        content: Option<Vec<ContentBlock>>,
47    ) -> Result<(), AcpClientError> {
48        self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
49    }
50
51    pub fn cancel(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
52        self.send(PromptCommand::Cancel { session_id: session_id.clone() })
53    }
54
55    pub fn set_config_option(
56        &self,
57        session_id: &SessionId,
58        config_id: &str,
59        value: &str,
60    ) -> Result<(), AcpClientError> {
61        self.send(PromptCommand::SetConfigOption {
62            session_id: session_id.clone(),
63            config_id: config_id.to_string(),
64            value: value.to_string(),
65        })
66    }
67
68    pub fn authenticate_mcp_server(&self, session_id: &SessionId, server_name: &str) -> Result<(), AcpClientError> {
69        self.send(PromptCommand::AuthenticateMcpServer {
70            session_id: session_id.clone(),
71            server_name: server_name.to_string(),
72        })
73    }
74
75    pub fn authenticate(&self, method_id: &str) -> Result<(), AcpClientError> {
76        self.send(PromptCommand::Authenticate { method_id: method_id.to_string() })
77    }
78
79    pub fn list_sessions(&self) -> Result<(), AcpClientError> {
80        self.send(PromptCommand::ListSessions)
81    }
82
83    pub fn load_session(&self, session_id: &SessionId, cwd: &Path) -> Result<(), AcpClientError> {
84        self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
85    }
86
87    pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
88        self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
89    }
90
91    fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
92        self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
93    }
94}
95
#[cfg(test)]
mod tests {
    use agent_client_protocol::schema::TextContent;

    use super::*;

    /// Every handle method must succeed on a `noop` handle, since its receiver is never dropped.
    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = SessionId::new("test");
        let cwd = Path::new("/tmp/project");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
        assert!(handle.set_config_option(&session_id, "model", "gpt-4o").is_ok());
        assert!(handle.authenticate_mcp_server(&session_id, "github").is_ok());
        assert!(handle.authenticate("oauth").is_ok());
        assert!(handle.list_sessions().is_ok());
        assert!(handle.load_session(&session_id, cwd).is_ok());
        assert!(handle.new_session(cwd).is_ok());
    }

    /// Once the receiver is gone, sends must surface as `AgentCrashed`.
    #[test]
    fn test_send_fails_when_receiver_dropped() {
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);
        let session_id = SessionId::new("sess-1");

        let result = handle.prompt(&session_id, "hello", None);
        assert!(matches!(result, Err(AcpClientError::AgentCrashed(_))));

        let result = handle.list_sessions();
        assert!(matches!(result, Err(AcpClientError::AgentCrashed(_))));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Prompt { session_id, text, content } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            _ => panic!("Expected Prompt command"),
        }
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Cancel { session_id } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
            }
            _ => panic!("Expected Cancel command"),
        }
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::SetConfigOption { session_id, config_id, value } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            _ => panic!("Expected SetConfigOption command"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            _ => panic!("Expected AuthenticateMcpServer command"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.authenticate("oauth").unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Authenticate { method_id } => {
                assert_eq!(method_id, "oauth");
            }
            _ => panic!("Expected Authenticate command"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");
        let content = vec![ContentBlock::Text(TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Prompt { session_id, text, content: Some(extra) } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            _ => panic!("Expected Prompt command with content"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-restore");
        let cwd = Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            _ => panic!("Expected LoadSession command"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            _ => panic!("Expected NewSession command"),
        }
    }
}