Skip to main content

imp_core/tools/
ask.rs

1use std::collections::HashSet;
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use super::{Tool, ToolContext, ToolOutput};
7use crate::error::Result;
8use crate::ui::{SelectOption, UserInterface};
9
10pub struct AskTool;
11
12#[async_trait]
13impl Tool for AskTool {
14    fn name(&self) -> &str {
15        "ask_user"
16    }
17    fn label(&self) -> &str {
18        "Ask User"
19    }
20    fn description(&self) -> &str {
21        "Ask the user a question. Use choices for single- or multi-select questions."
22    }
23    fn parameters(&self) -> serde_json::Value {
24        json!({
25            "type": "object",
26            "properties": {
27                "question": { "type": "string" },
28                "choices": {
29                    "type": "array",
30                    "description": "Choices the user can select from. Use [\"Yes\", \"No\"] for yes/no questions.",
31                    "items": { "type": "string" }
32                },
33                "multi_select": { "type": "boolean" },
34                "allow_other": { "type": "boolean" },
35                "placeholder": { "type": "string" }
36            },
37            "required": ["question"]
38        })
39    }
40    fn is_readonly(&self) -> bool {
41        true
42    }
43
44    async fn execute(
45        &self,
46        _call_id: &str,
47        params: serde_json::Value,
48        ctx: ToolContext,
49    ) -> Result<ToolOutput> {
50        if !ctx.ui.has_ui() {
51            return Ok(ToolOutput::error(
52                "Cannot access ask_user tool in this mode. Proceed with an explicit assumption if low-risk, or record a blocker/decision if consequential.",
53            ));
54        }
55
56        let Some(question) = params["question"]
57            .as_str()
58            .map(str::trim)
59            .filter(|q| !q.is_empty())
60        else {
61            return Ok(ToolOutput::error("Missing required parameter: question"));
62        };
63
64        let choices = match parse_choices(&params) {
65            Ok(choices) => choices,
66            Err(message) => return Ok(ToolOutput::error(message)),
67        };
68        let multi_select = params["multi_select"].as_bool().unwrap_or(false);
69        let allow_other = params["allow_other"].as_bool().unwrap_or(false);
70        let placeholder = params["placeholder"].as_str().unwrap_or("");
71
72        match choices {
73            Some(mut choices) => {
74                if allow_other {
75                    choices.push(SelectOption {
76                        label: "Other...".to_string(),
77                        description: Some("Type a custom answer".to_string()),
78                    });
79                }
80
81                if multi_select {
82                    execute_multi_select(&*ctx.ui, question, placeholder, &choices, allow_other)
83                        .await
84                } else {
85                    execute_single_select(&*ctx.ui, question, placeholder, &choices, allow_other)
86                        .await
87                }
88            }
89            None => match ctx.ui.input(question, placeholder).await {
90                Some(answer) => Ok(answer_output(answer, true)),
91                None => Ok(skipped_output(false)),
92            },
93        }
94    }
95}
96
97fn parse_choices(
98    params: &serde_json::Value,
99) -> std::result::Result<Option<Vec<SelectOption>>, String> {
100    let Some(value) = params.get("choices") else {
101        return Ok(None);
102    };
103    let Some(values) = value.as_array() else {
104        return Err("choices must be an array of strings".to_string());
105    };
106    if values.is_empty() {
107        return Err("choices must not be empty".to_string());
108    }
109    if values.len() > 50 {
110        return Err("choices must contain at most 50 items".to_string());
111    }
112
113    let mut seen = HashSet::new();
114    let mut choices = Vec::with_capacity(values.len());
115    for (index, value) in values.iter().enumerate() {
116        let Some(label) = value.as_str().map(str::trim).filter(|s| !s.is_empty()) else {
117            return Err(format!("choices[{index}] must be a non-empty string"));
118        };
119        if label.len() > 200 {
120            return Err(format!("choices[{index}] is too long"));
121        }
122        if !seen.insert(label.to_string()) {
123            return Err(format!("duplicate choice: {label}"));
124        }
125        choices.push(SelectOption {
126            label: label.to_string(),
127            description: None,
128        });
129    }
130
131    Ok(Some(choices))
132}
133
134async fn execute_single_select(
135    ui: &dyn UserInterface,
136    question: &str,
137    placeholder: &str,
138    choices: &[SelectOption],
139    allow_other: bool,
140) -> Result<ToolOutput> {
141    match ui.select(question, choices).await {
142        Some(index) if allow_other && index == choices.len() - 1 => {
143            match ui.input("Enter your answer:", placeholder).await {
144                Some(answer) => Ok(tool_text_with_details(
145                    &answer,
146                    json!({
147                        "answered": true,
148                        "skipped": false,
149                        "answer": answer,
150                        "answers": [answer],
151                        "other": true,
152                        "multi_select": false
153                    }),
154                )),
155                None => Ok(skipped_output(false)),
156            }
157        }
158        Some(index) if index < choices.len() => {
159            let answer = choices[index].label.clone();
160            Ok(tool_text_with_details(
161                &answer,
162                json!({
163                    "answered": true,
164                    "skipped": false,
165                    "answer": answer,
166                    "answers": [answer],
167                    "choice_index": index,
168                    "choice_indices": [index],
169                    "other": false,
170                    "multi_select": false
171                }),
172            ))
173        }
174        _ => Ok(skipped_output(false)),
175    }
176}
177
178async fn execute_multi_select(
179    ui: &dyn UserInterface,
180    question: &str,
181    placeholder: &str,
182    choices: &[SelectOption],
183    allow_other: bool,
184) -> Result<ToolOutput> {
185    let Some(indices) = ui.multi_select_with_context(question, "", choices).await else {
186        return Ok(skipped_output(true));
187    };
188    if indices.is_empty() {
189        return Ok(skipped_output(true));
190    }
191
192    let other_index = choices.len().saturating_sub(1);
193    let mut answers = Vec::new();
194    let mut choice_indices = Vec::new();
195    let mut other = false;
196    for index in indices {
197        if index >= choices.len() {
198            continue;
199        }
200        if allow_other && index == other_index {
201            other = true;
202            if let Some(answer) = ui.input("Enter your answer:", placeholder).await {
203                answers.push(answer);
204                choice_indices.push(index);
205            }
206        } else {
207            answers.push(choices[index].label.clone());
208            choice_indices.push(index);
209        }
210    }
211
212    if answers.is_empty() {
213        return Ok(skipped_output(true));
214    }
215
216    let text = answers.join(", ");
217    Ok(tool_text_with_details(
218        &text,
219        json!({
220            "answered": true,
221            "skipped": false,
222            "answer": text,
223            "answers": answers,
224            "choice_indices": choice_indices,
225            "other": other,
226            "multi_select": true
227        }),
228    ))
229}
230
231fn tool_text_with_details(text: &str, details: serde_json::Value) -> ToolOutput {
232    let mut output = ToolOutput::text(text);
233    output.details = details;
234    output
235}
236
237fn answer_output(answer: String, free_text: bool) -> ToolOutput {
238    tool_text_with_details(
239        &answer,
240        json!({
241            "answered": true,
242            "skipped": false,
243            "answer": answer,
244            "answers": [answer],
245            "free_text": free_text,
246            "multi_select": false
247        }),
248    )
249}
250
251fn skipped_output(multi_select: bool) -> ToolOutput {
252    tool_text_with_details(
253        "User skipped",
254        json!({
255            "answered": false,
256            "skipped": true,
257            "multi_select": multi_select
258        }),
259    )
260}
261
262#[cfg(test)]
263mod tests {
264    use super::*;
265    use crate::tools::ToolContext;
266    use crate::ui::NullInterface;
267    use std::sync::{Arc, Mutex};
268
269    fn test_ctx<T: crate::ui::UserInterface + 'static>(ui: Arc<T>) -> ToolContext {
270        let (tx, _rx) = tokio::sync::mpsc::channel(16);
271        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
272        ToolContext {
273            cwd: std::path::PathBuf::from("/tmp"),
274            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
275            update_tx: tx,
276            command_tx: cmd_tx,
277            ui: ui as Arc<dyn crate::ui::UserInterface>,
278            file_cache: Arc::new(crate::tools::FileCache::new()),
279            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
280            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
281            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
282            lua_tool_loader: None,
283            mode: crate::config::AgentMode::Full,
284            read_max_lines: 500,
285            turn_mana_review: Arc::new(std::sync::Mutex::new(
286                crate::mana_review::TurnManaReviewAccumulator::default(),
287            )),
288            config: Arc::new(crate::config::Config::default()),
289            run_policy: Default::default(),
290            supporting_provenance: Vec::new(),
291        }
292    }
293
294    #[tokio::test]
295    async fn ask_null_interface_returns_error() {
296        let tool = AskTool;
297        let result = tool
298            .execute(
299                "c1",
300                json!({"question": "What color?"}),
301                test_ctx(Arc::new(NullInterface)),
302            )
303            .await
304            .unwrap();
305
306        assert!(result.is_error);
307        let text = extract_text(&result);
308        assert!(text.contains("Cannot access ask_user tool in this mode"));
309    }
310
311    #[tokio::test]
312    async fn ask_missing_question_returns_error() {
313        let tool = AskTool;
314        let result = tool
315            .execute("c3", json!({}), test_ctx(Arc::new(MockUi::default())))
316            .await
317            .unwrap();
318
319        assert!(result.is_error);
320        assert!(extract_text(&result).contains("Missing required parameter: question"));
321    }
322
323    #[tokio::test]
324    async fn ask_single_choice_returns_structured_answer() {
325        let tool = AskTool;
326        let ui = MockUi::new().with_select(1);
327        let result = tool
328            .execute(
329                "c4",
330                json!({"question": "Pick", "choices": ["Red", "Blue"]}),
331                test_ctx(ui),
332            )
333            .await
334            .unwrap();
335
336        assert_eq!(extract_text(&result), "Blue");
337        assert_eq!(result.details["answered"], true);
338        assert_eq!(result.details["choice_index"], 1);
339        assert_eq!(result.details["multi_select"], false);
340    }
341
342    #[tokio::test]
343    async fn ask_multi_select_returns_structured_answers() {
344        let tool = AskTool;
345        let ui = MockUi::new().with_multi_select(vec![0, 2]);
346        let result = tool
347            .execute(
348                "c5",
349                json!({"question": "Pick", "choices": ["Red", "Blue", "Green"], "multi_select": true}),
350                test_ctx(ui),
351            )
352            .await
353            .unwrap();
354
355        assert_eq!(extract_text(&result), "Red, Green");
356        assert_eq!(result.details["answers"][0], "Red");
357        assert_eq!(result.details["answers"][1], "Green");
358        assert_eq!(result.details["multi_select"], true);
359    }
360
361    #[tokio::test]
362    async fn ask_free_text_uses_placeholder() {
363        let tool = AskTool;
364        let ui = MockUi::new().with_input("typed");
365        let result = tool
366            .execute(
367                "c6",
368                json!({"question": "Name?", "placeholder": "e.g. Atlas"}),
369                test_ctx(ui.clone()),
370            )
371            .await
372            .unwrap();
373
374        assert_eq!(extract_text(&result), "typed");
375        assert_eq!(
376            ui.last_placeholder.lock().unwrap().as_deref(),
377            Some("e.g. Atlas")
378        );
379    }
380
381    #[tokio::test]
382    async fn ask_rejects_duplicate_choices() {
383        let tool = AskTool;
384        let result = tool
385            .execute(
386                "c7",
387                json!({"question": "Pick", "choices": ["Red", "Red"]}),
388                test_ctx(Arc::new(MockUi::default())),
389            )
390            .await
391            .unwrap();
392
393        assert!(result.is_error);
394        assert!(extract_text(&result).contains("duplicate choice"));
395    }
396
397    #[derive(Default)]
398    struct MockUi {
399        select: Mutex<Option<usize>>,
400        multi_select: Mutex<Option<Vec<usize>>>,
401        input: Mutex<Option<String>>,
402        last_placeholder: Mutex<Option<String>>,
403    }
404
405    impl MockUi {
406        fn new() -> Arc<Self> {
407            Arc::new(Self::default())
408        }
409
410        fn with_select(self: Arc<Self>, value: usize) -> Arc<Self> {
411            *self.select.lock().unwrap() = Some(value);
412            self
413        }
414        fn with_multi_select(self: Arc<Self>, value: Vec<usize>) -> Arc<Self> {
415            *self.multi_select.lock().unwrap() = Some(value);
416            self
417        }
418        fn with_input(self: Arc<Self>, value: &str) -> Arc<Self> {
419            *self.input.lock().unwrap() = Some(value.to_string());
420            self
421        }
422    }
423
424    #[async_trait]
425    impl crate::ui::UserInterface for MockUi {
426        fn has_ui(&self) -> bool {
427            true
428        }
429        async fn notify(&self, _: &str, _: crate::ui::NotifyLevel) {}
430        async fn confirm(&self, _: &str, _: &str) -> Option<bool> {
431            None
432        }
433        async fn select_with_context(&self, _: &str, _: &str, _: &[SelectOption]) -> Option<usize> {
434            *self.select.lock().unwrap()
435        }
436        async fn multi_select_with_context(
437            &self,
438            _: &str,
439            _: &str,
440            _: &[SelectOption],
441        ) -> Option<Vec<usize>> {
442            self.multi_select.lock().unwrap().clone()
443        }
444        async fn input_with_context(&self, _: &str, _: &str, placeholder: &str) -> Option<String> {
445            *self.last_placeholder.lock().unwrap() = Some(placeholder.to_string());
446            self.input.lock().unwrap().clone()
447        }
448        async fn set_status(&self, _: &str, _: Option<&str>) {}
449        async fn set_widget(&self, _: &str, _: Option<crate::ui::WidgetContent>) {}
450        async fn custom(&self, _: crate::ui::ComponentSpec) -> Option<serde_json::Value> {
451            None
452        }
453    }
454
455    fn extract_text(output: &ToolOutput) -> String {
456        output
457            .content
458            .iter()
459            .filter_map(|b| match b {
460                imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
461                _ => None,
462            })
463            .collect::<Vec<_>>()
464            .join("\n")
465    }
466}