1use crate::{Tool, ToolContext, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use async_trait::async_trait;
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, json};
45use tokio::sync::mpsc;
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ConfirmationRequest {
50 pub tool_name: String,
52
53 pub description: String,
55
56 pub input_preview: String,
58
59 pub tier: String,
61
62 pub context: Option<String>,
64}
65
66impl ConfirmationRequest {
67 #[must_use]
69 pub fn new(
70 tool_name: impl Into<String>,
71 description: impl Into<String>,
72 input_preview: impl Into<String>,
73 tier: ToolTier,
74 ) -> Self {
75 Self {
76 tool_name: tool_name.into(),
77 description: description.into(),
78 input_preview: input_preview.into(),
79 tier: format!("{tier:?}"),
80 context: None,
81 }
82 }
83
84 #[must_use]
86 pub fn with_context(mut self, context: impl Into<String>) -> Self {
87 self.context = Some(context.into());
88 self
89 }
90}
91
92#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum ConfirmationResponse {
96 Approved,
98
99 Denied,
101
102 ApproveAll,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct QuestionRequest {
109 pub question: String,
111
112 pub header: Option<String>,
114
115 pub options: Vec<QuestionOption>,
118
119 pub multi_select: bool,
121}
122
123impl QuestionRequest {
124 #[must_use]
126 pub fn new(question: impl Into<String>) -> Self {
127 Self {
128 question: question.into(),
129 header: None,
130 options: Vec::new(),
131 multi_select: false,
132 }
133 }
134
135 #[must_use]
137 pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
138 Self {
139 question: question.into(),
140 header: None,
141 options,
142 multi_select: false,
143 }
144 }
145
146 #[must_use]
148 pub fn with_header(mut self, header: impl Into<String>) -> Self {
149 self.header = Some(header.into());
150 self
151 }
152
153 #[must_use]
155 pub const fn with_multi_select(mut self) -> Self {
156 self.multi_select = true;
157 self
158 }
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct QuestionOption {
164 pub label: String,
166
167 pub description: Option<String>,
169}
170
171impl QuestionOption {
172 #[must_use]
174 pub fn new(label: impl Into<String>) -> Self {
175 Self {
176 label: label.into(),
177 description: None,
178 }
179 }
180
181 #[must_use]
183 pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
184 Self {
185 label: label.into(),
186 description: Some(description.into()),
187 }
188 }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct QuestionResponse {
194 pub answer: String,
196
197 pub cancelled: bool,
199}
200
201impl QuestionResponse {
202 #[must_use]
204 pub fn success(answer: impl Into<String>) -> Self {
205 Self {
206 answer: answer.into(),
207 cancelled: false,
208 }
209 }
210
211 #[must_use]
213 pub const fn cancelled() -> Self {
214 Self {
215 answer: String::new(),
216 cancelled: true,
217 }
218 }
219}
220
221pub struct AskUserQuestionTool {
229 question_tx: mpsc::Sender<QuestionRequest>,
231
232 question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
234}
235
236impl AskUserQuestionTool {
237 #[must_use]
239 pub fn new(
240 question_tx: mpsc::Sender<QuestionRequest>,
241 question_rx: mpsc::Receiver<QuestionResponse>,
242 ) -> Self {
243 Self {
244 question_tx,
245 question_rx: tokio::sync::Mutex::new(question_rx),
246 }
247 }
248
249 #[must_use]
256 pub fn with_channels(
257 buffer_size: usize,
258 ) -> (
259 Self,
260 mpsc::Receiver<QuestionRequest>,
261 mpsc::Sender<QuestionResponse>,
262 ) {
263 let (request_tx, request_rx) = mpsc::channel(buffer_size);
264 let (response_tx, response_rx) = mpsc::channel(buffer_size);
265
266 let tool = Self::new(request_tx, response_rx);
267 (tool, request_rx, response_tx)
268 }
269}
270
271#[derive(Debug, Deserialize, Serialize)]
273struct AskUserInput {
274 question: String,
276
277 #[serde(default)]
279 header: Option<String>,
280
281 #[serde(default)]
283 options: Vec<OptionInput>,
284
285 #[serde(default)]
287 multi_select: bool,
288}
289
290#[derive(Debug, Deserialize, Serialize)]
292struct OptionInput {
293 label: String,
295
296 #[serde(default)]
298 description: Option<String>,
299}
300
301#[async_trait]
302impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
303 fn name(&self) -> &'static str {
304 "ask_user"
305 }
306
307 fn description(&self) -> &'static str {
308 "Ask the user a question to get clarification, preferences, or choices. \
309 Use this when you need user input before proceeding. For yes/no confirmations \
310 of dangerous operations, tool confirmation will be shown automatically - \
311 use this tool for open-ended questions or when offering choices."
312 }
313
314 fn input_schema(&self) -> Value {
315 json!({
316 "type": "object",
317 "required": ["question"],
318 "properties": {
319 "question": {
320 "type": "string",
321 "description": "The question to ask the user. Be clear and specific."
322 },
323 "header": {
324 "type": "string",
325 "description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
326 },
327 "options": {
328 "type": "array",
329 "description": "Optional list of choices for multiple-choice questions",
330 "items": {
331 "type": "object",
332 "required": ["label"],
333 "properties": {
334 "label": {
335 "type": "string",
336 "description": "The option text to display"
337 },
338 "description": {
339 "type": "string",
340 "description": "Optional explanation of this option"
341 }
342 }
343 }
344 },
345 "multi_select": {
346 "type": "boolean",
347 "description": "Whether multiple options can be selected (default: false)"
348 }
349 }
350 })
351 }
352
353 fn tier(&self) -> ToolTier {
354 ToolTier::Observe
356 }
357
358 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
359 let input: AskUserInput =
361 serde_json::from_value(input).context("Invalid input for ask_user tool")?;
362
363 let request = QuestionRequest {
365 question: input.question.clone(),
366 header: input.header,
367 options: input
368 .options
369 .into_iter()
370 .map(|o| QuestionOption {
371 label: o.label,
372 description: o.description,
373 })
374 .collect(),
375 multi_select: input.multi_select,
376 };
377
378 self.question_tx
380 .send(request)
381 .await
382 .context("Failed to send question to UI - channel closed")?;
383
384 let response = {
386 let mut rx = self.question_rx.lock().await;
387 rx.recv()
388 .await
389 .context("Failed to receive answer from UI - channel closed")?
390 };
391
392 if response.cancelled {
393 Ok(ToolResult::error(
394 "User cancelled the question without providing an answer.",
395 ))
396 } else {
397 Ok(ToolResult::success(format!(
398 "User answered: {}",
399 response.answer
400 )))
401 }
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    #[test]
    fn test_confirmation_request_new() {
        let req =
            ConfirmationRequest::new("write", "Write to file: foo.txt", "{}", ToolTier::Confirm);
        assert_eq!(req.tool_name, "write");
        assert_eq!(req.description, "Write to file: foo.txt");
        assert_eq!(req.input_preview, "{}");
        assert_eq!(req.tier, "Confirm");
        assert!(req.context.is_none());
    }

    #[test]
    fn test_confirmation_request_with_context() {
        let req = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
            .with_context("Agent was fixing a bug");
        assert_eq!(req.context.as_deref(), Some("Agent was fixing a bug"));
    }

    #[test]
    fn test_confirmation_response_serialization() {
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::Approved).unwrap(),
            "\"approved\""
        );
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::Denied).unwrap(),
            "\"denied\""
        );
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::ApproveAll).unwrap(),
            "\"approve_all\""
        );
    }

    #[test]
    fn test_confirmation_response_deserialization() {
        for (raw, expected) in [
            ("\"approved\"", ConfirmationResponse::Approved),
            ("\"denied\"", ConfirmationResponse::Denied),
            ("\"approve_all\"", ConfirmationResponse::ApproveAll),
        ] {
            let parsed: ConfirmationResponse = serde_json::from_str(raw).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_question_request_new() {
        let req = QuestionRequest::new("What color?");
        assert_eq!(req.question, "What color?");
        assert!(req.header.is_none());
        assert!(req.options.is_empty());
        assert!(!req.multi_select);
    }

    #[test]
    fn test_question_request_with_options() {
        let req = QuestionRequest::with_options(
            "Which framework?",
            vec![
                QuestionOption::new("React"),
                QuestionOption::with_description("Vue", "Progressive framework"),
            ],
        )
        .with_header("Framework")
        .with_multi_select();

        assert_eq!(req.options.len(), 2);
        assert!(req.options[0].description.is_none());
        assert_eq!(
            req.options[1].description.as_deref(),
            Some("Progressive framework")
        );
        assert!(req.multi_select);
        assert_eq!(req.header.as_deref(), Some("Framework"));
    }

    #[test]
    fn test_question_response() {
        let success = QuestionResponse::success("Blue");
        assert!(!success.cancelled);
        assert_eq!(success.answer, "Blue");

        let cancelled = QuestionResponse::cancelled();
        assert!(cancelled.cancelled);
        assert!(cancelled.answer.is_empty());
    }

    // Creating channels does not need a runtime, so a plain test is enough.
    #[test]
    fn test_ask_user_tool_creation() {
        let (tool, _rx, _tx) = AskUserQuestionTool::with_channels(10);

        assert_eq!(Tool::<()>::name(&tool), "ask_user");
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[tokio::test]
    async fn test_ask_user_tool_execute() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            let request = request_rx.recv().await.expect("tool should send a question");
            assert_eq!(request.question, "What color?");
            response_tx
                .send(QuestionResponse::success("Blue"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "What color?"
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Blue"));
    }

    #[tokio::test]
    async fn test_ask_user_with_options() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            let request = request_rx.recv().await.expect("tool should send a question");
            assert_eq!(request.options.len(), 2);
            assert_eq!(request.options[0].label, "Option A");
            response_tx
                .send(QuestionResponse::success("Option A"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Which option?",
                    "options": [
                        {"label": "Option A", "description": "First choice"},
                        {"label": "Option B", "description": "Second choice"}
                    ]
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Option A"));
    }

    #[tokio::test]
    async fn test_ask_user_forwards_header_and_multi_select() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            let request = request_rx.recv().await.expect("tool should send a question");
            assert_eq!(request.header.as_deref(), Some("Library choice"));
            assert!(request.multi_select);
            assert_eq!(request.options[1].description.as_deref(), Some("Second"));
            assert!(request.options[0].description.is_none());
            response_tx
                .send(QuestionResponse::success("A, B"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Which libraries?",
                    "header": "Library choice",
                    "multi_select": true,
                    "options": [
                        {"label": "A"},
                        {"label": "B", "description": "Second"}
                    ]
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("A, B"));
    }

    #[tokio::test]
    async fn test_ask_user_cancelled() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            if request_rx.recv().await.is_some() {
                response_tx
                    .send(QuestionResponse::cancelled())
                    .await
                    .unwrap();
            }
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Continue?"
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(!result.success);
        assert!(result.output.contains("cancelled"));
    }

    #[tokio::test]
    async fn test_ask_user_invalid_input_is_error() {
        let (tool, _request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);
        let ctx = ToolContext::new(());

        // `question` is required by the schema.
        let result = tool.execute(&ctx, json!({ "header": "No question" })).await;
        let err = result.err().expect("missing question must be rejected");
        assert!(err.to_string().contains("Invalid input"));

        // Wrong type for `question`.
        let result = tool.execute(&ctx, json!({ "question": 42 })).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ask_user_ui_request_channel_closed_is_error() {
        let (tool, request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);
        drop(request_rx);

        let ctx = ToolContext::new(());
        let result = tool.execute(&ctx, json!({ "question": "Anyone there?" })).await;
        let err = result.err().expect("send to a closed UI channel must fail");
        assert!(err.to_string().contains("Failed to send question"));
    }

    #[tokio::test]
    async fn test_ask_user_ui_answer_channel_closed_is_error() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        // The UI receives the question but goes away without answering.
        let handle = tokio::spawn(async move {
            assert!(request_rx.recv().await.is_some());
            drop(response_tx);
        });

        let ctx = ToolContext::new(());
        let result = tool.execute(&ctx, json!({ "question": "Still there?" })).await;
        handle.await.unwrap();

        let err = result.err().expect("closed answer channel must fail");
        assert!(err.to_string().contains("Failed to receive answer"));
    }
}