1use std::collections::HashSet;
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use super::{Tool, ToolContext, ToolOutput};
7use crate::error::Result;
8use crate::ui::{SelectOption, UserInterface};
9
10pub struct AskTool;
11
12#[async_trait]
13impl Tool for AskTool {
14 fn name(&self) -> &str {
15 "ask_user"
16 }
17 fn label(&self) -> &str {
18 "Ask User"
19 }
20 fn description(&self) -> &str {
21 "Ask the user a question. Use choices for single- or multi-select questions."
22 }
23 fn parameters(&self) -> serde_json::Value {
24 json!({
25 "type": "object",
26 "properties": {
27 "question": { "type": "string" },
28 "choices": {
29 "type": "array",
30 "description": "Choices the user can select from. Use [\"Yes\", \"No\"] for yes/no questions.",
31 "items": { "type": "string" }
32 },
33 "multi_select": { "type": "boolean" },
34 "allow_other": { "type": "boolean" },
35 "placeholder": { "type": "string" }
36 },
37 "required": ["question"]
38 })
39 }
40 fn is_readonly(&self) -> bool {
41 true
42 }
43
44 async fn execute(
45 &self,
46 _call_id: &str,
47 params: serde_json::Value,
48 ctx: ToolContext,
49 ) -> Result<ToolOutput> {
50 if !ctx.ui.has_ui() {
51 return Ok(ToolOutput::error(
52 "Cannot access ask_user tool in this mode. Proceed with an explicit assumption if low-risk, or record a blocker/decision if consequential.",
53 ));
54 }
55
56 let Some(question) = params["question"]
57 .as_str()
58 .map(str::trim)
59 .filter(|q| !q.is_empty())
60 else {
61 return Ok(ToolOutput::error("Missing required parameter: question"));
62 };
63
64 let choices = match parse_choices(¶ms) {
65 Ok(choices) => choices,
66 Err(message) => return Ok(ToolOutput::error(message)),
67 };
68 let multi_select = params["multi_select"].as_bool().unwrap_or(false);
69 let allow_other = params["allow_other"].as_bool().unwrap_or(false);
70 let placeholder = params["placeholder"].as_str().unwrap_or("");
71
72 match choices {
73 Some(mut choices) => {
74 if allow_other {
75 choices.push(SelectOption {
76 label: "Other...".to_string(),
77 description: Some("Type a custom answer".to_string()),
78 });
79 }
80
81 if multi_select {
82 execute_multi_select(&*ctx.ui, question, placeholder, &choices, allow_other)
83 .await
84 } else {
85 execute_single_select(&*ctx.ui, question, placeholder, &choices, allow_other)
86 .await
87 }
88 }
89 None => match ctx.ui.input(question, placeholder).await {
90 Some(answer) => Ok(answer_output(answer, true)),
91 None => Ok(skipped_output(false)),
92 },
93 }
94 }
95}
96
97fn parse_choices(
98 params: &serde_json::Value,
99) -> std::result::Result<Option<Vec<SelectOption>>, String> {
100 let Some(value) = params.get("choices") else {
101 return Ok(None);
102 };
103 let Some(values) = value.as_array() else {
104 return Err("choices must be an array of strings".to_string());
105 };
106 if values.is_empty() {
107 return Err("choices must not be empty".to_string());
108 }
109 if values.len() > 50 {
110 return Err("choices must contain at most 50 items".to_string());
111 }
112
113 let mut seen = HashSet::new();
114 let mut choices = Vec::with_capacity(values.len());
115 for (index, value) in values.iter().enumerate() {
116 let Some(label) = value.as_str().map(str::trim).filter(|s| !s.is_empty()) else {
117 return Err(format!("choices[{index}] must be a non-empty string"));
118 };
119 if label.len() > 200 {
120 return Err(format!("choices[{index}] is too long"));
121 }
122 if !seen.insert(label.to_string()) {
123 return Err(format!("duplicate choice: {label}"));
124 }
125 choices.push(SelectOption {
126 label: label.to_string(),
127 description: None,
128 });
129 }
130
131 Ok(Some(choices))
132}
133
134async fn execute_single_select(
135 ui: &dyn UserInterface,
136 question: &str,
137 placeholder: &str,
138 choices: &[SelectOption],
139 allow_other: bool,
140) -> Result<ToolOutput> {
141 match ui.select(question, choices).await {
142 Some(index) if allow_other && index == choices.len() - 1 => {
143 match ui.input("Enter your answer:", placeholder).await {
144 Some(answer) => Ok(tool_text_with_details(
145 &answer,
146 json!({
147 "answered": true,
148 "skipped": false,
149 "answer": answer,
150 "answers": [answer],
151 "other": true,
152 "multi_select": false
153 }),
154 )),
155 None => Ok(skipped_output(false)),
156 }
157 }
158 Some(index) if index < choices.len() => {
159 let answer = choices[index].label.clone();
160 Ok(tool_text_with_details(
161 &answer,
162 json!({
163 "answered": true,
164 "skipped": false,
165 "answer": answer,
166 "answers": [answer],
167 "choice_index": index,
168 "choice_indices": [index],
169 "other": false,
170 "multi_select": false
171 }),
172 ))
173 }
174 _ => Ok(skipped_output(false)),
175 }
176}
177
178async fn execute_multi_select(
179 ui: &dyn UserInterface,
180 question: &str,
181 placeholder: &str,
182 choices: &[SelectOption],
183 allow_other: bool,
184) -> Result<ToolOutput> {
185 let Some(indices) = ui.multi_select_with_context(question, "", choices).await else {
186 return Ok(skipped_output(true));
187 };
188 if indices.is_empty() {
189 return Ok(skipped_output(true));
190 }
191
192 let other_index = choices.len().saturating_sub(1);
193 let mut answers = Vec::new();
194 let mut choice_indices = Vec::new();
195 let mut other = false;
196 for index in indices {
197 if index >= choices.len() {
198 continue;
199 }
200 if allow_other && index == other_index {
201 other = true;
202 if let Some(answer) = ui.input("Enter your answer:", placeholder).await {
203 answers.push(answer);
204 choice_indices.push(index);
205 }
206 } else {
207 answers.push(choices[index].label.clone());
208 choice_indices.push(index);
209 }
210 }
211
212 if answers.is_empty() {
213 return Ok(skipped_output(true));
214 }
215
216 let text = answers.join(", ");
217 Ok(tool_text_with_details(
218 &text,
219 json!({
220 "answered": true,
221 "skipped": false,
222 "answer": text,
223 "answers": answers,
224 "choice_indices": choice_indices,
225 "other": other,
226 "multi_select": true
227 }),
228 ))
229}
230
231fn tool_text_with_details(text: &str, details: serde_json::Value) -> ToolOutput {
232 let mut output = ToolOutput::text(text);
233 output.details = details;
234 output
235}
236
237fn answer_output(answer: String, free_text: bool) -> ToolOutput {
238 tool_text_with_details(
239 &answer,
240 json!({
241 "answered": true,
242 "skipped": false,
243 "answer": answer,
244 "answers": [answer],
245 "free_text": free_text,
246 "multi_select": false
247 }),
248 )
249}
250
251fn skipped_output(multi_select: bool) -> ToolOutput {
252 tool_text_with_details(
253 "User skipped",
254 json!({
255 "answered": false,
256 "skipped": true,
257 "multi_select": multi_select
258 }),
259 )
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use crate::ui::NullInterface;
    use std::sync::{Arc, Mutex};

    /// Builds a throwaway `ToolContext` around `ui`. The channel receivers are
    /// dropped when this returns; the `AskTool` code under test never sends on them.
    fn test_ctx<T: crate::ui::UserInterface + 'static>(ui: Arc<T>) -> ToolContext {
        let (update_tx, _update_rx) = tokio::sync::mpsc::channel(16);
        let (command_tx, _command_rx) = tokio::sync::mpsc::channel(16);
        let ui: Arc<dyn crate::ui::UserInterface> = ui;
        ToolContext {
            cwd: std::path::PathBuf::from("/tmp"),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx,
            command_tx,
            ui,
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
            run_policy: Default::default(),
            supporting_provenance: Vec::new(),
        }
    }

    /// Runs `AskTool` once and unwraps the `Result` (tool-level errors live in the output).
    async fn run(call_id: &str, params: serde_json::Value, ctx: ToolContext) -> ToolOutput {
        let tool = AskTool;
        tool.execute(call_id, params, ctx).await.unwrap()
    }

    #[tokio::test]
    async fn ask_null_interface_returns_error() {
        let output = run(
            "c1",
            json!({"question": "What color?"}),
            test_ctx(Arc::new(NullInterface)),
        )
        .await;

        assert!(output.is_error);
        assert!(extract_text(&output).contains("Cannot access ask_user tool in this mode"));
    }

    #[tokio::test]
    async fn ask_missing_question_returns_error() {
        let output = run("c3", json!({}), test_ctx(MockUi::new())).await;

        assert!(output.is_error);
        assert!(extract_text(&output).contains("Missing required parameter: question"));
    }

    #[tokio::test]
    async fn ask_single_choice_returns_structured_answer() {
        let mock = MockUi::new().with_select(1);
        let output = run(
            "c4",
            json!({"question": "Pick", "choices": ["Red", "Blue"]}),
            test_ctx(mock),
        )
        .await;
        let details = &output.details;

        assert_eq!(extract_text(&output), "Blue");
        assert_eq!(details["answered"], true);
        assert_eq!(details["choice_index"], 1);
        assert_eq!(details["multi_select"], false);
    }

    #[tokio::test]
    async fn ask_multi_select_returns_structured_answers() {
        let mock = MockUi::new().with_multi_select(vec![0, 2]);
        let output = run(
            "c5",
            json!({"question": "Pick", "choices": ["Red", "Blue", "Green"], "multi_select": true}),
            test_ctx(mock),
        )
        .await;
        let details = &output.details;

        assert_eq!(extract_text(&output), "Red, Green");
        assert_eq!(details["answers"][0], "Red");
        assert_eq!(details["answers"][1], "Green");
        assert_eq!(details["multi_select"], true);
    }

    #[tokio::test]
    async fn ask_free_text_uses_placeholder() {
        let mock = MockUi::new().with_input("typed");
        let output = run(
            "c6",
            json!({"question": "Name?", "placeholder": "e.g. Atlas"}),
            test_ctx(Arc::clone(&mock)),
        )
        .await;

        assert_eq!(extract_text(&output), "typed");
        let seen_placeholder = mock.last_placeholder.lock().unwrap();
        assert_eq!(seen_placeholder.as_deref(), Some("e.g. Atlas"));
    }

    #[tokio::test]
    async fn ask_rejects_duplicate_choices() {
        let output = run(
            "c7",
            json!({"question": "Pick", "choices": ["Red", "Red"]}),
            test_ctx(MockUi::new()),
        )
        .await;

        assert!(output.is_error);
        assert!(extract_text(&output).contains("duplicate choice"));
    }

    /// Scripted UI: each prompt returns the canned value configured via the
    /// `with_*` builders (or `None` when unset) and records the last placeholder.
    #[derive(Default)]
    struct MockUi {
        select: Mutex<Option<usize>>,
        multi_select: Mutex<Option<Vec<usize>>>,
        input: Mutex<Option<String>>,
        last_placeholder: Mutex<Option<String>>,
    }

    impl MockUi {
        fn new() -> Arc<Self> {
            Arc::default()
        }

        fn with_select(self: Arc<Self>, value: usize) -> Arc<Self> {
            *self.select.lock().unwrap() = Some(value);
            self
        }

        fn with_multi_select(self: Arc<Self>, value: Vec<usize>) -> Arc<Self> {
            *self.multi_select.lock().unwrap() = Some(value);
            self
        }

        fn with_input(self: Arc<Self>, value: &str) -> Arc<Self> {
            *self.input.lock().unwrap() = Some(value.to_owned());
            self
        }
    }

    #[async_trait]
    impl crate::ui::UserInterface for MockUi {
        fn has_ui(&self) -> bool {
            true
        }

        async fn notify(&self, _: &str, _: crate::ui::NotifyLevel) {}

        async fn confirm(&self, _: &str, _: &str) -> Option<bool> {
            None
        }

        async fn select_with_context(&self, _: &str, _: &str, _: &[SelectOption]) -> Option<usize> {
            let canned = *self.select.lock().unwrap();
            canned
        }

        async fn multi_select_with_context(
            &self,
            _: &str,
            _: &str,
            _: &[SelectOption],
        ) -> Option<Vec<usize>> {
            let canned = self.multi_select.lock().unwrap().clone();
            canned
        }

        async fn input_with_context(&self, _: &str, _: &str, placeholder: &str) -> Option<String> {
            *self.last_placeholder.lock().unwrap() = Some(placeholder.to_owned());
            let canned = self.input.lock().unwrap().clone();
            canned
        }

        async fn set_status(&self, _: &str, _: Option<&str>) {}

        async fn set_widget(&self, _: &str, _: Option<crate::ui::WidgetContent>) {}

        async fn custom(&self, _: crate::ui::ComponentSpec) -> Option<serde_json::Value> {
            None
        }
    }

    /// Concatenates every text block of `output`, newline-separated; non-text blocks are ignored.
    fn extract_text(output: &ToolOutput) -> String {
        let mut pieces: Vec<&str> = Vec::new();
        for block in output.content.iter() {
            if let imp_llm::ContentBlock::Text { text } = block {
                pieces.push(text.as_str());
            }
        }
        pieces.join("\n")
    }
}