Skip to main content

acp_utils/client/
prompt_handle.rs

1use agent_client_protocol::schema::{ContentBlock, SessionId};
2use std::path::{Path, PathBuf};
3use tokio::sync::mpsc;
4
5use super::error::AcpClientError;
6use crate::notifications::PromptSearchParams;
7
8/// Commands sent from the main thread to the ACP client task.
9#[derive(Debug)]
10pub enum PromptCommand {
11    Prompt { session_id: SessionId, text: String, content: Option<Vec<ContentBlock>> },
12    Cancel { session_id: SessionId },
13    SetConfigOption { session_id: SessionId, config_id: String, value: String },
14    AuthenticateMcpServer { session_id: SessionId, server_name: String },
15    Authenticate { method_id: String },
16    ListSessions,
17    LoadSession { session_id: SessionId, cwd: PathBuf },
18    NewSession { cwd: std::path::PathBuf },
19    SearchPrompts(PromptSearchParams),
20}
21
22/// Send-safe handle for issuing prompt commands to the ACP client task.
23#[derive(Clone)]
24pub struct AcpPromptHandle {
25    pub(crate) cmd_tx: mpsc::UnboundedSender<PromptCommand>,
26}
27
28impl AcpPromptHandle {
29    /// Create a handle whose sends always succeed but are never read.
30    /// Useful for tests that don't care about prompt delivery.
31    pub fn noop() -> Self {
32        let (cmd_tx, rx) = mpsc::unbounded_channel();
33        std::mem::forget(rx);
34        Self { cmd_tx }
35    }
36
37    /// Create a handle paired with a receiver for inspecting sent commands.
38    /// Useful for tests that need to verify which commands were dispatched.
39    pub fn recording() -> (Self, mpsc::UnboundedReceiver<PromptCommand>) {
40        let (cmd_tx, rx) = mpsc::unbounded_channel();
41        (Self { cmd_tx }, rx)
42    }
43
44    pub fn prompt(
45        &self,
46        session_id: &SessionId,
47        text: &str,
48        content: Option<Vec<ContentBlock>>,
49    ) -> Result<(), AcpClientError> {
50        self.send(PromptCommand::Prompt { session_id: session_id.clone(), text: text.to_string(), content })
51    }
52
53    pub fn cancel(&self, session_id: &SessionId) -> Result<(), AcpClientError> {
54        self.send(PromptCommand::Cancel { session_id: session_id.clone() })
55    }
56
57    pub fn set_config_option(
58        &self,
59        session_id: &SessionId,
60        config_id: &str,
61        value: &str,
62    ) -> Result<(), AcpClientError> {
63        self.send(PromptCommand::SetConfigOption {
64            session_id: session_id.clone(),
65            config_id: config_id.to_string(),
66            value: value.to_string(),
67        })
68    }
69
70    pub fn authenticate_mcp_server(&self, session_id: &SessionId, server_name: &str) -> Result<(), AcpClientError> {
71        self.send(PromptCommand::AuthenticateMcpServer {
72            session_id: session_id.clone(),
73            server_name: server_name.to_string(),
74        })
75    }
76
77    pub fn authenticate(&self, method_id: &str) -> Result<(), AcpClientError> {
78        self.send(PromptCommand::Authenticate { method_id: method_id.to_string() })
79    }
80
81    pub fn list_sessions(&self) -> Result<(), AcpClientError> {
82        self.send(PromptCommand::ListSessions)
83    }
84
85    pub fn load_session(&self, session_id: &SessionId, cwd: &Path) -> Result<(), AcpClientError> {
86        self.send(PromptCommand::LoadSession { session_id: session_id.clone(), cwd: cwd.to_path_buf() })
87    }
88
89    pub fn new_session(&self, cwd: &Path) -> Result<(), AcpClientError> {
90        self.send(PromptCommand::NewSession { cwd: cwd.to_path_buf() })
91    }
92
93    pub fn search_prompts(&self, params: PromptSearchParams) -> Result<(), AcpClientError> {
94        self.send(PromptCommand::SearchPrompts(params))
95    }
96
97    fn send(&self, cmd: PromptCommand) -> Result<(), AcpClientError> {
98        self.cmd_tx.send(cmd).map_err(|_| AcpClientError::AgentCrashed("command channel closed".into()))
99    }
100}
101
#[cfg(test)]
mod tests {
    use agent_client_protocol::schema::TextContent;

    use super::*;

    #[test]
    fn test_noop_handle_succeeds_silently() {
        let handle = AcpPromptHandle::noop();
        let session_id = SessionId::new("test");

        assert!(handle.prompt(&session_id, "hello", None).is_ok());
        assert!(handle.cancel(&session_id).is_ok());
    }

    #[test]
    fn test_send_fails_once_receiver_is_dropped() {
        // With the receiving task gone, every command must surface as AgentCrashed.
        let (handle, rx) = AcpPromptHandle::recording();
        drop(rx);
        let session_id = SessionId::new("sess-1");

        assert!(matches!(handle.cancel(&session_id), Err(AcpClientError::AgentCrashed(_))));
        assert!(matches!(handle.list_sessions(), Err(AcpClientError::AgentCrashed(_))));
    }

    #[test]
    fn test_cloned_handle_shares_channel() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let clone = handle.clone();

        clone.list_sessions().unwrap();

        assert!(matches!(rx.try_recv().unwrap(), PromptCommand::ListSessions));
    }

    #[test]
    fn test_prompt_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.prompt(&session_id, "hello", None).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Prompt { session_id, text, content } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert!(content.is_none());
            }
            _ => panic!("Expected Prompt command"),
        }
        // Exactly one command per call.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_cancel_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.cancel(&session_id).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Cancel { session_id } => assert_eq!(session_id.0.as_ref(), "sess-1"),
            _ => panic!("Expected Cancel command"),
        }
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_set_config_option_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.set_config_option(&session_id, "model", "gpt-4o").unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::SetConfigOption { session_id, config_id, value } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(config_id, "model");
                assert_eq!(value, "gpt-4o");
            }
            _ => panic!("Expected SetConfigOption command"),
        }
    }

    #[test]
    fn test_authenticate_mcp_server_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");

        handle.authenticate_mcp_server(&session_id, "github").unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(server_name, "github");
            }
            _ => panic!("Expected AuthenticateMcpServer command"),
        }
    }

    #[test]
    fn test_authenticate_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.authenticate("oauth").unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Authenticate { method_id } => assert_eq!(method_id, "oauth"),
            _ => panic!("Expected Authenticate command"),
        }
    }

    #[test]
    fn test_prompt_with_content_sends_blocks() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-1");
        let content = vec![ContentBlock::Text(TextContent::new("attached"))];

        handle.prompt(&session_id, "hello", Some(content.clone())).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::Prompt { session_id, text, content: Some(extra) } => {
                assert_eq!(session_id.0.as_ref(), "sess-1");
                assert_eq!(text, "hello");
                assert_eq!(extra, content);
            }
            _ => panic!("Expected Prompt command with content"),
        }
    }

    #[test]
    fn test_list_sessions_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();

        handle.list_sessions().unwrap();

        let cmd = rx.try_recv().unwrap();
        assert!(matches!(cmd, PromptCommand::ListSessions));
    }

    #[test]
    fn test_load_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let session_id = SessionId::new("sess-restore");
        let cwd = Path::new("/tmp/project");

        handle.load_session(&session_id, cwd).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::LoadSession { session_id, cwd } => {
                assert_eq!(session_id.0.as_ref(), "sess-restore");
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            _ => panic!("Expected LoadSession command"),
        }
    }

    #[test]
    fn test_new_session_sends_command() {
        let (handle, mut rx) = AcpPromptHandle::recording();
        let cwd = Path::new("/tmp/project");

        handle.new_session(cwd).unwrap();

        let cmd = rx.try_recv().unwrap();
        match cmd {
            PromptCommand::NewSession { cwd } => {
                assert_eq!(cwd, PathBuf::from("/tmp/project"));
            }
            _ => panic!("Expected NewSession command"),
        }
    }
}