Skip to main content

construct/tools/
ask_user.rs

//! Interactive user prompting tool for cross-channel confirmations.
//!
//! Exposes `ask_user` as an agent-callable tool that sends a question to a
//! messaging channel and waits for the user's response. The tool holds a
//! late-binding channel map handle that is populated once channels are
//! initialized (after tool construction). This mirrors the pattern used by
//! [`ReactionTool`](super::reaction::ReactionTool).

9use super::traits::{Tool, ToolResult};
10use crate::channels::traits::{Channel, ChannelMessage, SendMessage};
11use crate::security::SecurityPolicy;
12use crate::security::policy::ToolOperation;
13use async_trait::async_trait;
14use parking_lot::RwLock;
15use serde_json::json;
16use std::collections::HashMap;
17use std::sync::Arc;
18
19/// Shared handle giving tools late-bound access to the live channel map.
20pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
21
/// Fallback reply timeout for `ask_user`, in seconds (five minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
24
25/// Agent-callable tool for sending a question to a user and waiting for their response.
26pub struct AskUserTool {
27    security: Arc<SecurityPolicy>,
28    channels: ChannelMapHandle,
29}
30
31impl AskUserTool {
32    /// Create a new ask_user tool with an empty channel map.
33    /// Call [`channel_map_handle`] and write to the returned handle once channels
34    /// are available.
35    pub fn new(security: Arc<SecurityPolicy>) -> Self {
36        Self {
37            security,
38            channels: Arc::new(RwLock::new(HashMap::new())),
39        }
40    }
41
42    /// Return the shared handle so callers can populate it after channel init.
43    pub fn channel_map_handle(&self) -> ChannelMapHandle {
44        Arc::clone(&self.channels)
45    }
46
47    /// Convenience: populate the channel map from a pre-built map.
48    pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
49        *self.channels.write() = map;
50    }
51}
52
/// Render a question, with optional numbered choices, as markdown text.
///
/// The question is wrapped in `**…**`. When `choices` is non-empty, each choice
/// is listed as `1. …`, `2. …` between blank lines, followed by a hint that the
/// user may reply with a number or free text. An empty slice is treated exactly
/// like `None`, so the user is never told to pick from an empty list.
fn format_question(question: &str, choices: Option<&[String]>) -> String {
    let mut lines = vec![format!("**{question}**")];

    if let Some(choices) = choices.filter(|c| !c.is_empty()) {
        lines.push(String::new());
        lines.extend(
            choices
                .iter()
                .enumerate()
                .map(|(i, choice)| format!("{}. {choice}", i + 1)),
        );
        lines.push(String::new());
        lines.push("_Reply with a number or type your answer._".to_string());
    }

    lines.join("\n")
}
69
70#[async_trait]
71impl Tool for AskUserTool {
72    fn name(&self) -> &str {
73        "ask_user"
74    }
75
76    fn description(&self) -> &str {
77        "Ask the user a question and wait for their response. \
78         Sends the question to a messaging channel and blocks until the user replies \
79         or the timeout expires. Optionally provide choices for structured responses."
80    }
81
82    fn parameters_schema(&self) -> serde_json::Value {
83        json!({
84            "type": "object",
85            "properties": {
86                "question": {
87                    "type": "string",
88                    "description": "The question to ask the user"
89                },
90                "choices": {
91                    "type": "array",
92                    "items": { "type": "string" },
93                    "description": "Optional list of choices (renders as buttons on Telegram, numbered list on CLI)"
94                },
95                "timeout_secs": {
96                    "type": "integer",
97                    "description": "Seconds to wait for a response (default: 300)"
98                },
99                "channel": {
100                    "type": "string",
101                    "description": "Target channel name. Defaults to the first available channel if omitted."
102                }
103            },
104            "required": ["question"]
105        })
106    }
107
108    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
109        // Security gate: Act operation
110        if let Err(e) = self
111            .security
112            .enforce_tool_operation(ToolOperation::Act, "ask_user")
113        {
114            return Ok(ToolResult {
115                success: false,
116                output: String::new(),
117                error: Some(format!("Action blocked: {e}")),
118            });
119        }
120
121        // Parse required params
122        let question = args
123            .get("question")
124            .and_then(|v| v.as_str())
125            .map(|s| s.trim())
126            .filter(|s| !s.is_empty())
127            .ok_or_else(|| anyhow::anyhow!("Missing 'question' parameter"))?
128            .to_string();
129
130        let choices: Option<Vec<String>> = args.get("choices").and_then(|v| {
131            v.as_array().map(|arr| {
132                arr.iter()
133                    .filter_map(|item| item.as_str().map(|s| s.trim().to_string()))
134                    .filter(|s| !s.is_empty())
135                    .collect()
136            })
137        });
138
139        let timeout_secs = args
140            .get("timeout_secs")
141            .and_then(|v| v.as_u64())
142            .unwrap_or(DEFAULT_TIMEOUT_SECS);
143
144        let requested_channel = args
145            .get("channel")
146            .and_then(|v| v.as_str())
147            .map(|s| s.trim().to_string());
148
149        // Resolve channel from handle — block-scoped to drop the RwLock guard
150        // before any `.await` (parking_lot guards are !Send).
151        let (channel_name, channel): (String, Arc<dyn Channel>) = {
152            let channels = self.channels.read();
153            if channels.is_empty() {
154                return Ok(ToolResult {
155                    success: false,
156                    output: String::new(),
157                    error: Some("No channels available yet (channels not initialized)".to_string()),
158                });
159            }
160            if let Some(ref name) = requested_channel {
161                let ch = channels.get(name.as_str()).cloned().ok_or_else(|| {
162                    let available: Vec<String> = channels.keys().cloned().collect();
163                    anyhow::anyhow!(
164                        "Channel '{}' not found. Available: {}",
165                        name,
166                        available.join(", ")
167                    )
168                })?;
169                (name.clone(), ch)
170            } else {
171                let (name, ch) = channels.iter().next().ok_or_else(|| {
172                    anyhow::anyhow!("No channels available. Configure at least one channel.")
173                })?;
174                (name.clone(), ch.clone())
175            }
176        };
177
178        // Format and send the question
179        let text = format_question(&question, choices.as_deref());
180        let msg = SendMessage::new(&text, "");
181        if let Err(e) = channel.send(&msg).await {
182            return Ok(ToolResult {
183                success: false,
184                output: String::new(),
185                error: Some(format!(
186                    "Failed to send question to channel '{channel_name}': {e}"
187                )),
188            });
189        }
190
191        // Listen for user response with timeout
192        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChannelMessage>(1);
193        let timeout = std::time::Duration::from_secs(timeout_secs);
194
195        // Spawn a listener task on the channel
196        let listen_channel = Arc::clone(&channel);
197        let listen_handle = tokio::spawn(async move { listen_channel.listen(tx).await });
198
199        let response = tokio::time::timeout(timeout, rx.recv()).await;
200
201        // Abort the listener once we have a response or timeout
202        listen_handle.abort();
203
204        match response {
205            Ok(Some(msg)) => Ok(ToolResult {
206                success: true,
207                output: msg.content,
208                error: None,
209            }),
210            Ok(None) => Ok(ToolResult {
211                success: false,
212                output: "TIMEOUT".to_string(),
213                error: Some("Channel closed before receiving a response".to_string()),
214            }),
215            Err(_) => Ok(ToolResult {
216                success: false,
217                output: "TIMEOUT".to_string(),
218                error: Some(format!(
219                    "No response received within {timeout_secs} seconds"
220                )),
221            }),
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Log of message bodies a stub channel was asked to send, shared with tests.
    type SentLog = Arc<RwLock<Vec<String>>>;

    /// A stub channel that records sent messages but never produces incoming messages.
    struct SilentChannel {
        channel_name: String,
        sent: SentLog,
    }

    impl SilentChannel {
        fn new(name: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl Channel for SilentChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Never sends anything — simulates no user response
            tokio::time::sleep(std::time::Duration::from_secs(600)).await;
            Ok(())
        }
    }

    /// A stub channel that immediately responds with a canned message.
    struct RespondingChannel {
        channel_name: String,
        response: String,
        sent: SentLog,
    }

    impl RespondingChannel {
        fn new(name: &str, response: &str) -> Self {
            Self {
                channel_name: name.to_string(),
                response: response.to_string(),
                sent: Arc::new(RwLock::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl Channel for RespondingChannel {
        fn name(&self) -> &str {
            &self.channel_name
        }

        async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
            self.sent.write().push(message.content.clone());
            Ok(())
        }

        async fn listen(
            &self,
            tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            let msg = ChannelMessage {
                id: "resp_1".to_string(),
                sender: "user".to_string(),
                reply_target: "user".to_string(),
                content: self.response.clone(),
                channel: self.channel_name.clone(),
                timestamp: 1000,
                thread_ts: None,
                interruption_scope_id: None,
                attachments: vec![],
            };
            let _ = tx.send(msg).await;
            Ok(())
        }
    }

    /// A stub channel whose `send` always fails.
    struct FailingChannel;

    #[async_trait]
    impl Channel for FailingChannel {
        fn name(&self) -> &str {
            "broken"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            anyhow::bail!("network down")
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// A stub channel whose listener exits immediately without delivering anything.
    struct ClosingChannel;

    #[async_trait]
    impl Channel for ClosingChannel {
        fn name(&self) -> &str {
            "closing"
        }

        async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
            Ok(())
        }

        async fn listen(
            &self,
            _tx: tokio::sync::mpsc::Sender<ChannelMessage>,
        ) -> anyhow::Result<()> {
            // Returning drops the sender, closing the response channel.
            Ok(())
        }
    }

    fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> AskUserTool {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let map: HashMap<String, Arc<dyn Channel>> = channels
            .into_iter()
            .map(|(name, ch)| (name.to_string(), ch))
            .collect();
        tool.populate(map);
        tool
    }

    // ── Metadata tests ──

    #[test]
    fn tool_name_and_description() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        assert_eq!(tool.name(), "ask_user");
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("question"));
    }

    #[test]
    fn parameter_schema_validation() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["question"].is_object());
        assert!(schema["properties"]["choices"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());
        assert!(schema["properties"]["channel"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.iter().any(|v| v == "question"));
        // choices, timeout_secs, channel are optional
        assert!(!required.iter().any(|v| v == "choices"));
        assert!(!required.iter().any(|v| v == "timeout_secs"));
        assert!(!required.iter().any(|v| v == "channel"));
    }

    #[test]
    fn spec_matches_metadata() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let spec = tool.spec();
        assert_eq!(spec.name, "ask_user");
        assert_eq!(spec.description, tool.description());
        assert!(spec.parameters["required"].is_array());
    }

    #[test]
    fn populate_replaces_previous_channels() {
        let tool = make_tool_with_channels(vec![(
            "old",
            Arc::new(SilentChannel::new("old")) as Arc<dyn Channel>,
        )]);
        let mut replacement: HashMap<String, Arc<dyn Channel>> = HashMap::new();
        replacement.insert("new".to_string(), Arc::new(SilentChannel::new("new")));
        tool.populate(replacement);

        let handle = tool.channel_map_handle();
        let channels = handle.read();
        assert!(channels.contains_key("new"));
        assert!(!channels.contains_key("old"));
    }

    // ── Format question tests ──

    #[test]
    fn format_question_without_choices() {
        let text = format_question("Are you sure?", None);
        assert!(text.contains("Are you sure?"));
        assert!(!text.contains("1."));
    }

    #[test]
    fn format_question_with_choices() {
        let choices = vec!["Yes".to_string(), "No".to_string(), "Maybe".to_string()];
        let text = format_question("Continue?", Some(&choices));
        assert!(text.contains("Continue?"));
        assert!(text.contains("1. Yes"));
        assert!(text.contains("2. No"));
        assert!(text.contains("3. Maybe"));
        assert!(text.contains("Reply with a number"));
    }

    // ── Execute tests ──

    #[tokio::test]
    async fn execute_rejects_missing_question() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool.execute(json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn execute_rejects_empty_question() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(SilentChannel::new("test")) as Arc<dyn Channel>,
        )]);
        let result = tool.execute(json!({ "question": "  " })).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn empty_channels_returns_not_initialized() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("not initialized"));
    }

    #[tokio::test]
    async fn unknown_channel_returns_error() {
        let tool = make_tool_with_channels(vec![(
            "slack",
            Arc::new(SilentChannel::new("slack")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({ "question": "Hello?", "channel": "nonexistent" }))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn timeout_returns_timeout_output() {
        let channel = Arc::new(SilentChannel::new("test"));
        let tool =
            make_tool_with_channels(vec![("test", Arc::clone(&channel) as Arc<dyn Channel>)]);
        let result = tool
            .execute(json!({
                "question": "Confirm?",
                "timeout_secs": 1
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert_eq!(result.output, "TIMEOUT");
        assert!(result.error.as_deref().unwrap().contains("1 seconds"));
        // The question is still delivered even though nobody answers.
        let sent = channel.sent.read();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].contains("Confirm?"));
    }

    #[tokio::test]
    async fn successful_response_flow() {
        let tool = make_tool_with_channels(vec![(
            "test",
            Arc::new(RespondingChannel::new("test", "Yes, proceed!")) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Should we deploy?",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "Yes, proceed!");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn successful_response_with_choices() {
        let telegram = Arc::new(RespondingChannel::new("telegram", "2"));
        let tool = make_tool_with_channels(vec![(
            "telegram",
            Arc::clone(&telegram) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({
                "question": "Pick an option",
                "choices": ["Option A", "Option B"],
                "channel": "telegram",
                "timeout_secs": 5
            }))
            .await
            .unwrap();
        assert!(result.success, "error: {:?}", result.error);
        assert_eq!(result.output, "2");

        // Exactly one formatted question, listing the choices, reached the channel.
        let sent = telegram.sent.read();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].contains("Pick an option"));
        assert!(sent[0].contains("1. Option A"));
        assert!(sent[0].contains("2. Option B"));
    }

    #[tokio::test]
    async fn send_failure_is_reported() {
        let tool = make_tool_with_channels(vec![(
            "broken",
            Arc::new(FailingChannel) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({ "question": "Proceed?" }))
            .await
            .unwrap();
        assert!(!result.success);
        let error = result.error.unwrap();
        assert!(
            error.contains("Failed to send question to channel 'broken'"),
            "error: {error}"
        );
        assert!(error.contains("network down"), "error: {error}");
    }

    #[tokio::test]
    async fn listener_exit_without_reply_is_reported() {
        let tool = make_tool_with_channels(vec![(
            "closing",
            Arc::new(ClosingChannel) as Arc<dyn Channel>,
        )]);
        let result = tool
            .execute(json!({ "question": "Anyone there?", "timeout_secs": 5 }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap()
            .contains("Channel closed before receiving a response"));
    }

    #[tokio::test]
    async fn channel_map_handle_allows_late_binding() {
        let tool = AskUserTool::new(Arc::new(SecurityPolicy::default()));
        let handle = tool.channel_map_handle();

        // Initially empty — tool reports not initialized
        let result = tool.execute(json!({ "question": "Hello?" })).await.unwrap();
        assert!(!result.success);

        // Populate via the handle
        {
            let mut map = handle.write();
            map.insert(
                "cli".to_string(),
                Arc::new(RespondingChannel::new("cli", "ok")) as Arc<dyn Channel>,
            );
        }

        // Now the tool can route to the channel
        let result = tool
            .execute(json!({ "question": "Hello?", "timeout_secs": 5 }))
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "ok");
    }
}