1use super::traits::{Tool, ToolResult};
10use crate::channels::traits::{Channel, ChannelMessage, SendMessage};
11use crate::security::SecurityPolicy;
12use crate::security::policy::ToolOperation;
13use async_trait::async_trait;
14use parking_lot::RwLock;
15use serde_json::json;
16use std::collections::HashMap;
17use std::sync::Arc;
18
/// Shared, lock-protected map from channel name to channel implementation.
///
/// The handle is reference-counted, so a clone obtained through
/// `AskUserTool::channel_map_handle` observes (and can mutate) the same map
/// the tool reads when it executes.
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
21
/// Number of seconds to wait for a reply when the caller does not supply
/// a usable `timeout_secs` argument.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
24
/// Tool that sends a question to a messaging channel and waits for the reply.
pub struct AskUserTool {
    /// Policy consulted before every execution; a refusal aborts the call.
    security: Arc<SecurityPolicy>,
    /// Channels the question can be sent to, keyed by channel name.
    /// Starts empty and is filled in later via `populate` or the shared handle.
    channels: ChannelMapHandle,
}
30
31impl AskUserTool {
32 pub fn new(security: Arc<SecurityPolicy>) -> Self {
36 Self {
37 security,
38 channels: Arc::new(RwLock::new(HashMap::new())),
39 }
40 }
41
42 pub fn channel_map_handle(&self) -> ChannelMapHandle {
44 Arc::clone(&self.channels)
45 }
46
47 pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
49 *self.channels.write() = map;
50 }
51}
52
/// Renders the text that is sent to the user.
///
/// The question is wrapped in `**…**`. When `choices` contains at least one
/// entry, a blank line, a 1-based numbered list of the choices, another blank
/// line and a reply hint are appended. Lines are joined with `\n`.
///
/// `None` and an empty slice are treated the same way: only the question is
/// rendered, so the user is never told to "reply with a number" when there is
/// nothing to pick from.
fn format_question(question: &str, choices: Option<&[String]>) -> String {
    let mut lines = vec![format!("**{question}**")];

    // An empty list carries no options; skip the list scaffolding entirely.
    if let Some(choices) = choices.filter(|list| !list.is_empty()) {
        lines.push(String::new());
        lines.extend(
            choices
                .iter()
                .enumerate()
                .map(|(i, choice)| format!("{}. {choice}", i + 1)),
        );
        lines.push(String::new());
        lines.push("_Reply with a number or type your answer._".to_string());
    }

    lines.join("\n")
}
69
70#[async_trait]
71impl Tool for AskUserTool {
72 fn name(&self) -> &str {
73 "ask_user"
74 }
75
76 fn description(&self) -> &str {
77 "Ask the user a question and wait for their response. \
78 Sends the question to a messaging channel and blocks until the user replies \
79 or the timeout expires. Optionally provide choices for structured responses."
80 }
81
82 fn parameters_schema(&self) -> serde_json::Value {
83 json!({
84 "type": "object",
85 "properties": {
86 "question": {
87 "type": "string",
88 "description": "The question to ask the user"
89 },
90 "choices": {
91 "type": "array",
92 "items": { "type": "string" },
93 "description": "Optional list of choices (renders as buttons on Telegram, numbered list on CLI)"
94 },
95 "timeout_secs": {
96 "type": "integer",
97 "description": "Seconds to wait for a response (default: 300)"
98 },
99 "channel": {
100 "type": "string",
101 "description": "Target channel name. Defaults to the first available channel if omitted."
102 }
103 },
104 "required": ["question"]
105 })
106 }
107
108 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
109 if let Err(e) = self
111 .security
112 .enforce_tool_operation(ToolOperation::Act, "ask_user")
113 {
114 return Ok(ToolResult {
115 success: false,
116 output: String::new(),
117 error: Some(format!("Action blocked: {e}")),
118 });
119 }
120
121 let question = args
123 .get("question")
124 .and_then(|v| v.as_str())
125 .map(|s| s.trim())
126 .filter(|s| !s.is_empty())
127 .ok_or_else(|| anyhow::anyhow!("Missing 'question' parameter"))?
128 .to_string();
129
130 let choices: Option<Vec<String>> = args.get("choices").and_then(|v| {
131 v.as_array().map(|arr| {
132 arr.iter()
133 .filter_map(|item| item.as_str().map(|s| s.trim().to_string()))
134 .filter(|s| !s.is_empty())
135 .collect()
136 })
137 });
138
139 let timeout_secs = args
140 .get("timeout_secs")
141 .and_then(|v| v.as_u64())
142 .unwrap_or(DEFAULT_TIMEOUT_SECS);
143
144 let requested_channel = args
145 .get("channel")
146 .and_then(|v| v.as_str())
147 .map(|s| s.trim().to_string());
148
149 let (channel_name, channel): (String, Arc<dyn Channel>) = {
152 let channels = self.channels.read();
153 if channels.is_empty() {
154 return Ok(ToolResult {
155 success: false,
156 output: String::new(),
157 error: Some("No channels available yet (channels not initialized)".to_string()),
158 });
159 }
160 if let Some(ref name) = requested_channel {
161 let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
162 let available: Vec<String> = channels.keys().cloned().collect();
163 anyhow::anyhow!(
164 "Channel '{}' not found. Available: {}",
165 name,
166 available.join(", ")
167 )
168 })?;
169 (name.clone(), ch)
170 } else {
171 let (name, ch) = channels.iter().next().ok_or_else(|| {
172 anyhow::anyhow!("No channels available. Configure at least one channel.")
173 })?;
174 (name.clone(), ch.clone())
175 }
176 };
177
178 let text = format_question(&question, choices.as_deref());
180 let msg = SendMessage::new(&text, "");
181 if let Err(e) = channel.send(&msg).await {
182 return Ok(ToolResult {
183 success: false,
184 output: String::new(),
185 error: Some(format!(
186 "Failed to send question to channel '{channel_name}': {e}"
187 )),
188 });
189 }
190
191 let (tx, mut rx) = tokio::sync::mpsc::channel::<ChannelMessage>(1);
193 let timeout = std::time::Duration::from_secs(timeout_secs);
194
195 let listen_channel = Arc::clone(&channel);
197 let listen_handle = tokio::spawn(async move { listen_channel.listen(tx).await });
198
199 let response = tokio::time::timeout(timeout, rx.recv()).await;
200
201 listen_handle.abort();
203
204 match response {
205 Ok(Some(msg)) => Ok(ToolResult {
206 success: true,
207 output: msg.content,
208 error: None,
209 }),
210 Ok(None) => Ok(ToolResult {
211 success: false,
212 output: "TIMEOUT".to_string(),
213 error: Some("Channel closed before receiving a response".to_string()),
214 }),
215 Err(_) => Ok(ToolResult {
216 success: false,
217 output: "TIMEOUT".to_string(),
218 error: Some(format!(
219 "No response received within {timeout_secs} seconds"
220 )),
221 }),
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel double that records outgoing messages and never replies.
    struct SilentChannel {
        channel_name: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl SilentChannel {
        fn new(name: &str) -> Self {
            let channel_name = name.to_owned();
            let sent = Arc::default();
            Self { channel_name, sent }
        }
    }

    #[async_trait]
    impl Channel for SilentChannel {
        fn name(&self) -> &str {
            self.channel_name.as_str()
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            let mut log = self.sent.write();
            log.push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Keep the sender alive far longer than any test timeout so the
            // receiver observes silence rather than a closed channel.
            let nap = std::time::Duration::from_secs(600);
            tokio::time::sleep(nap).await;
            Ok(())
        }
    }

    /// Channel double that records outgoing messages and answers with a canned reply.
    struct RespondingChannel {
        channel_name: String,
        response: String,
        sent: Arc<RwLock<Vec<String>>>,
    }

    impl RespondingChannel {
        fn new(name: &str, response: &str) -> Self {
            let channel_name = name.to_owned();
            let response = response.to_owned();
            let sent = Arc::default();
            Self {
                channel_name,
                response,
                sent,
            }
        }
    }

    #[async_trait]
    impl Channel for RespondingChannel {
        fn name(&self) -> &str {
            self.channel_name.as_str()
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            let mut log = self.sent.write();
            log.push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            let reply = ChannelMessage {
                id: "resp_1".to_string(),
                sender: "user".to_string(),
                reply_target: "user".to_string(),
                content: self.response.clone(),
                channel: self.channel_name.clone(),
                timestamp: 1000,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: Vec::new(),
            };
            // A failed send only means the receiver is already gone.
            let _ = tx.send(reply).await;
            Ok(())
        }
    }

    /// Tool with the default security policy and no channels registered.
    fn bare_tool() -> AskUserTool {
        AskUserTool::new(Arc::new(SecurityPolicy::default()))
    }

    /// Tool with the default security policy and the given named channels.
    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> AskUserTool {
        let mut map: HashMap<String, Arc<dyn Channel>> = HashMap::with_capacity(channels.len());
        for (name, channel) in channels {
            map.insert(name.to_owned(), channel);
        }
        let tool = bare_tool();
        tool.populate(map);
        tool
    }

    /// Tool with exactly one registered channel.
    fn single_channel_tool(name: &str, channel: Arc<dyn Channel>) -> AskUserTool {
        make_tool_with_channels(vec![(name, channel)])
    }

    // ---- metadata -------------------------------------------------------

    #[test]
    fn tool_name_and_description() {
        let tool = bare_tool();
        assert_eq!(tool.name(), "ask_user");
        let description = tool.description();
        assert!(!description.is_empty());
        assert!(description.contains("question"));
    }

    #[test]
    fn parameter_schema_validation() {
        let schema = bare_tool().parameters_schema();
        assert_eq!(schema["type"], "object");
        for key in ["question", "choices", "timeout_secs", "channel"] {
            assert!(schema["properties"][key].is_object());
        }
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "question"));
        // Everything except the question itself is optional.
        for optional in ["choices", "timeout_secs", "channel"] {
            assert!(!required.iter().any(|v| v == optional));
        }
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = bare_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "ask_user");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }

    // ---- message formatting ---------------------------------------------

    #[test]
    fn format_question_without_choices() {
        let rendered = format_question("Are you sure?", None);
        assert!(rendered.contains("Are you sure?"));
        assert!(!rendered.contains("1."));
    }

    #[test]
    fn format_question_with_choices() {
        let options: Vec<String> = ["Yes", "No", "Maybe"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let rendered = format_question("Continue?", Some(&options));
        for expected in [
            "Continue?",
            "1. Yes",
            "2. No",
            "3. Maybe",
            "Reply with a number",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    // ---- execute: argument and channel validation -----------------------

    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let tool = single_channel_tool("test", Arc::new(SilentChannel::new("test")));
        let outcome = tool.execute(json!({})).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn execute_rejects_empty_question() {
        let tool = single_channel_tool("test", Arc::new(SilentChannel::new("test")));
        let outcome = tool.execute(json!({ "question": " " })).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = bare_tool();
        let outcome = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!outcome.success);
        let message = outcome.error.as_deref().unwrap();
        assert!(message.contains("not initialized"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = single_channel_tool("slack", Arc::new(SilentChannel::new("slack")));
        let args = json!({ "question": "Hello?", "channel": "nonexistent" });
        let outcome = tool.execute(args).await;
        assert!(outcome.is_err());
    }

    // ---- execute: reply handling ----------------------------------------

    #[tokio::test]
    async fn timeout_returns_timeout_output() {
        let tool = single_channel_tool("test", Arc::new(SilentChannel::new("test")));
        let args = json!({
            "question": "Confirm?",
            "timeout_secs": 1
        });
        let outcome = tool.execute(args).await.unwrap();
        assert!(!outcome.success);
        assert_eq!(outcome.output, "TIMEOUT");
        assert!(outcome.error.as_deref().unwrap().contains("1 seconds"));
    }

    #[tokio::test]
    async fn successful_response_flow() {
        let tool = single_channel_tool(
            "test",
            Arc::new(RespondingChannel::new("test", "Yes, proceed!")),
        );
        let args = json!({
            "question": "Should we deploy?",
            "timeout_secs": 5
        });
        let outcome = tool.execute(args).await.unwrap();
        assert!(outcome.success, "error: {:?}", outcome.error);
        assert_eq!(outcome.output, "Yes, proceed!");
        assert!(outcome.error.is_none());
    }

    #[tokio::test]
    async fn successful_response_with_choices() {
        let tool = single_channel_tool(
            "telegram",
            Arc::new(RespondingChannel::new("telegram", "2")),
        );
        let args = json!({
            "question": "Pick an option",
            "choices": ["Option A", "Option B"],
            "channel": "telegram",
            "timeout_secs": 5
        });
        let outcome = tool.execute(args).await.unwrap();
        assert!(outcome.success, "error: {:?}", outcome.error);
        assert_eq!(outcome.output, "2");
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let tool = bare_tool();
        let handle = tool.channel_map_handle();

        // Nothing is registered yet, so the first attempt must fail.
        let before = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!before.success);

        // Register a channel through the shared handle; the write guard is
        // released at the end of this statement, before the next await.
        let late: Arc<dyn Channel> = Arc::new(RespondingChannel::new("cli", "ok"));
        handle.write().insert("cli".to_string(), late);

        // The tool now sees the late-bound channel.
        let after = tool
            .execute(json!({ "question": "Hello?", "timeout_secs": 5 }))
            .await
            .unwrap();
        assert!(after.success);
        assert_eq!(after.output, "ok");
    }
}