Skip to main content

imp_core/tools/
ask.rs

1use async_trait::async_trait;
2use serde::Deserialize;
3use serde_json::json;
4
5use super::{Tool, ToolContext, ToolOutput};
6use crate::error::Result;
7use crate::ui::SelectOption;
8
/// Tool that asks the user a question: a selection prompt when `options` are supplied,
/// otherwise a free-text prompt. Stateless, so it is freely copyable and defaultable.
#[derive(Debug, Default, Clone, Copy)]
pub struct AskTool;
10
11#[derive(Debug, Deserialize)]
12#[serde(untagged)]
13enum OptionItem {
14    Label(String),
15    Rich {
16        label: String,
17        description: Option<String>,
18    },
19}
20
21impl OptionItem {
22    #[allow(dead_code)]
23    fn into_select_option(self) -> SelectOption {
24        match self {
25            OptionItem::Label(label) => SelectOption {
26                label,
27                description: None,
28            },
29            OptionItem::Rich { label, description } => SelectOption { label, description },
30        }
31    }
32
33    fn label(&self) -> &str {
34        match self {
35            OptionItem::Label(l) => l,
36            OptionItem::Rich { label, .. } => label,
37        }
38    }
39}
40
41#[async_trait]
42impl Tool for AskTool {
43    fn name(&self) -> &str {
44        "ask"
45    }
46    fn label(&self) -> &str {
47        "Ask User"
48    }
49    fn description(&self) -> &str {
50        "Ask the user a question. Use options for multiple choice."
51    }
52    fn parameters(&self) -> serde_json::Value {
53        json!({
54            "type": "object",
55            "properties": {
56                "question": { "type": "string" },
57                "context": { "type": "string" },
58                "options": { "type": "array", "items": {} },
59                "multiSelect": { "type": "boolean" },
60                "allowOther": { "type": "boolean" },
61                "default": {},
62                "placeholder": { "type": "string" }
63            },
64            "required": ["question"]
65        })
66    }
67    fn is_readonly(&self) -> bool {
68        true
69    }
70
71    async fn execute(
72        &self,
73        _call_id: &str,
74        params: serde_json::Value,
75        ctx: ToolContext,
76    ) -> Result<ToolOutput> {
77        if !ctx.ui.has_ui() {
78            return Ok(ToolOutput::error("Cannot ask user in this mode"));
79        }
80
81        let question = match params["question"].as_str() {
82            Some(q) => q,
83            None => return Ok(ToolOutput::error("Missing required parameter: question")),
84        };
85
86        let context = params["context"].as_str().unwrap_or("");
87        let allow_other = params["allowOther"].as_bool().unwrap_or(false);
88        let _multi_select = params["multiSelect"].as_bool().unwrap_or(false);
89        let placeholder = params["placeholder"].as_str().unwrap_or("");
90
91        let title = question.to_string();
92
93        // If options are provided, use select; otherwise use text input
94        let raw_options: Option<Vec<OptionItem>> = params
95            .get("options")
96            .and_then(|v| serde_json::from_value(v.clone()).ok());
97
98        match raw_options {
99            Some(items) if !items.is_empty() => {
100                let mut options: Vec<SelectOption> = items
101                    .iter()
102                    .map(|item| SelectOption {
103                        label: item.label().to_string(),
104                        description: match item {
105                            OptionItem::Rich { description, .. } => description.clone(),
106                            OptionItem::Label(_) => None,
107                        },
108                    })
109                    .collect();
110
111                if allow_other {
112                    options.push(SelectOption {
113                        label: "Other...".to_string(),
114                        description: None,
115                    });
116                }
117
118                match ctx.ui.select_with_context(&title, context, &options).await {
119                    Some(idx) => {
120                        // If "Other..." was selected and allow_other is on
121                        if allow_other && idx == options.len() - 1 {
122                            match ctx.ui.input("Enter your answer:", placeholder).await {
123                                Some(text) => Ok(ToolOutput::text(text)),
124                                None => Ok(ToolOutput::text("User skipped")),
125                            }
126                        } else {
127                            Ok(ToolOutput::text(&options[idx].label))
128                        }
129                    }
130                    None => Ok(ToolOutput::text("User skipped")),
131                }
132            }
133            _ => {
134                // Free text input
135                match ctx
136                    .ui
137                    .input_with_context(&title, context, placeholder)
138                    .await
139                {
140                    Some(text) => Ok(ToolOutput::text(text)),
141                    None => Ok(ToolOutput::text("User skipped")),
142                }
143            }
144        }
145    }
146}
147
148/// Format options into a display string (for logging/debugging).
149pub fn format_options(options: &[SelectOption]) -> String {
150    options
151        .iter()
152        .enumerate()
153        .map(|(i, opt)| match &opt.description {
154            Some(desc) => format!("  {}. {} — {}", i + 1, opt.label, desc),
155            None => format!("  {}. {}", i + 1, opt.label),
156        })
157        .collect::<Vec<_>>()
158        .join("\n")
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use crate::ui::NullInterface;
    use std::sync::Arc;

    /// Builds a `ToolContext` rooted at `/tmp` whose UI is `ui`.
    /// Every other field gets a fresh value so tests stay independent.
    fn ctx_with_ui<U>(ui: U) -> ToolContext
    where
        U: crate::ui::UserInterface + Send + Sync + 'static,
    {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        ToolContext {
            cwd: std::path::PathBuf::from("/tmp"),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(ui),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
        }
    }

    /// Context with a non-interactive UI (`NullInterface`).
    fn test_ctx() -> ToolContext {
        ctx_with_ui(NullInterface)
    }

    #[tokio::test]
    async fn ask_null_interface_returns_error() {
        let tool = AskTool;
        let result = tool
            .execute("c1", json!({"question": "What color?"}), test_ctx())
            .await
            .unwrap();

        assert!(result.is_error);
        let text = extract_text(&result);
        assert!(text.contains("Cannot ask user in this mode"));
    }

    #[tokio::test]
    async fn ask_null_interface_with_options_returns_error() {
        let tool = AskTool;
        let result = tool
            .execute(
                "c2",
                json!({
                    "question": "Pick a color",
                    "options": ["red", "blue", "green"]
                }),
                test_ctx(),
            )
            .await
            .unwrap();

        assert!(result.is_error);
        let text = extract_text(&result);
        assert!(text.contains("Cannot ask user in this mode"));
    }

    #[tokio::test]
    async fn ask_missing_question_returns_error() {
        // MockUi reports has_ui() == true, so the first check is bypassed.
        let tool = AskTool;
        let result = tool
            .execute("c3", json!({}), ctx_with_ui(MockUi))
            .await
            .unwrap();

        assert!(result.is_error);
        let text = extract_text(&result);
        assert!(text.contains("Missing required parameter: question"));
    }

    #[tokio::test]
    async fn ask_dismissed_selection_reports_skip() {
        // MockUi's select returns None, i.e. the user dismissed the prompt.
        let tool = AskTool;
        let result = tool
            .execute(
                "c4",
                json!({
                    "question": "Pick a color",
                    "options": ["red", {"label": "blue", "description": "cool"}],
                    "allowOther": true
                }),
                ctx_with_ui(MockUi),
            )
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(extract_text(&result), "User skipped");
    }

    #[tokio::test]
    async fn ask_dismissed_free_text_reports_skip() {
        // No options: free-text path; MockUi's input returns None.
        let tool = AskTool;
        let result = tool
            .execute("c5", json!({"question": "Your name?"}), ctx_with_ui(MockUi))
            .await
            .unwrap();

        assert!(!result.is_error);
        assert_eq!(extract_text(&result), "User skipped");
    }

    #[test]
    fn format_options_plain() {
        let options = vec![
            SelectOption {
                label: "Red".into(),
                description: None,
            },
            SelectOption {
                label: "Blue".into(),
                description: None,
            },
        ];
        let formatted = format_options(&options);
        assert!(formatted.contains("1. Red"));
        assert!(formatted.contains("2. Blue"));
    }

    #[test]
    fn format_options_with_descriptions() {
        let options = vec![
            SelectOption {
                label: "Rust".into(),
                description: Some("Systems language".into()),
            },
            SelectOption {
                label: "Python".into(),
                description: Some("Scripting language".into()),
            },
        ];
        let formatted = format_options(&options);
        assert!(formatted.contains("Rust — Systems language"));
        assert!(formatted.contains("Python — Scripting language"));
    }

    #[test]
    fn option_item_parsing() {
        // String options
        let items: Vec<OptionItem> = serde_json::from_value(json!(["a", "b"])).unwrap();
        assert_eq!(items[0].label(), "a");
        assert_eq!(items[1].label(), "b");

        // Rich options
        let items: Vec<OptionItem> = serde_json::from_value(json!([
            {"label": "Rust", "description": "Fast"},
            {"label": "Go"}
        ]))
        .unwrap();
        assert_eq!(items[0].label(), "Rust");
        assert_eq!(items[1].label(), "Go");
        match &items[0] {
            OptionItem::Rich { description, .. } => {
                assert_eq!(description.as_deref(), Some("Fast"));
            }
            OptionItem::Label(_) => panic!("expected rich option"),
        }
        match &items[1] {
            OptionItem::Rich { description, .. } => assert!(description.is_none()),
            OptionItem::Label(_) => panic!("expected rich option"),
        }
    }

    #[test]
    fn option_item_into_select_option() {
        let opt = OptionItem::Rich {
            label: "Rust".into(),
            description: Some("Fast".into()),
        }
        .into_select_option();
        assert_eq!(opt.label, "Rust");
        assert_eq!(opt.description.as_deref(), Some("Fast"));

        let opt = OptionItem::Label("Go".into()).into_select_option();
        assert_eq!(opt.label, "Go");
        assert!(opt.description.is_none());
    }

    /// Mock UI whose `has_ui` returns true but whose interactions all return None.
    struct MockUi;

    #[async_trait]
    impl crate::ui::UserInterface for MockUi {
        fn has_ui(&self) -> bool {
            true
        }
        async fn notify(&self, _: &str, _: crate::ui::NotifyLevel) {}
        async fn confirm(&self, _: &str, _: &str) -> Option<bool> {
            None
        }
        async fn select_with_context(&self, _: &str, _: &str, _: &[SelectOption]) -> Option<usize> {
            None
        }
        async fn input_with_context(&self, _: &str, _: &str, _: &str) -> Option<String> {
            None
        }
        async fn set_status(&self, _: &str, _: Option<&str>) {}
        async fn set_widget(&self, _: &str, _: Option<crate::ui::WidgetContent>) {}
        async fn custom(&self, _: crate::ui::ComponentSpec) -> Option<serde_json::Value> {
            None
        }
    }

    /// Joins all text content blocks of `output` with newlines.
    fn extract_text(output: &ToolOutput) -> String {
        output
            .content
            .iter()
            .filter_map(|b| match b {
                imp_llm::ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("\n")
    }
}