1use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use serde::{Deserialize, Serialize};
43use serde_json::{Value, json};
44use std::sync::atomic::{AtomicU64, Ordering};
45use tokio::sync::mpsc;
46use tokio_util::sync::CancellationToken;
47
/// Process-wide counter that feeds [`next_request_id`]; it starts at zero.
static QUESTION_SEQ: AtomicU64 = AtomicU64::new(0);

/// Returns the next question identifier, formatted as `ask-<n>`.
///
/// Every call increments [`QUESTION_SEQ`] by one, so successive calls in the
/// same process yield `ask-0`, `ask-1`, `ask-2`, ... (distinct until the
/// 64-bit counter wraps).
fn next_request_id() -> String {
    // `Relaxed` suffices: the counter only has to hand out distinct numbers,
    // it does not order any other memory accesses.
    let mut id = String::from("ask-");
    id.push_str(&QUESTION_SEQ.fetch_add(1, Ordering::Relaxed).to_string());
    id
}
56
57async fn await_matching_response(
61 rx: &mut mpsc::Receiver<QuestionResponse>,
62 request_id: &str,
63 cancel_token: &CancellationToken,
64) -> Result<Option<QuestionResponse>> {
65 loop {
66 tokio::select! {
67 biased;
68 () = cancel_token.cancelled() => return Ok(None),
69 received = rx.recv() => {
70 let response = received
71 .context("Failed to receive answer from UI - channel closed")?;
72 if response.request_id.is_empty() || response.request_id == request_id {
77 return Ok(Some(response));
78 }
79 }
80 }
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ConfirmationRequest {
87 pub tool_name: String,
89
90 pub description: String,
92
93 pub input_preview: String,
95
96 pub tier: String,
98
99 pub context: Option<String>,
101}
102
103impl ConfirmationRequest {
104 #[must_use]
106 pub fn new(
107 tool_name: impl Into<String>,
108 description: impl Into<String>,
109 input_preview: impl Into<String>,
110 tier: ToolTier,
111 ) -> Self {
112 Self {
113 tool_name: tool_name.into(),
114 description: description.into(),
115 input_preview: input_preview.into(),
116 tier: format!("{tier:?}"),
117 context: None,
118 }
119 }
120
121 #[must_use]
123 pub fn with_context(mut self, context: impl Into<String>) -> Self {
124 self.context = Some(context.into());
125 self
126 }
127}
128
129#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
131#[serde(rename_all = "snake_case")]
132pub enum ConfirmationResponse {
133 Approved,
135
136 Denied,
138
139 ApproveAll,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct QuestionRequest {
146 #[serde(default)]
153 pub request_id: String,
154
155 pub question: String,
157
158 pub header: Option<String>,
160
161 pub options: Vec<QuestionOption>,
164
165 pub multi_select: bool,
167}
168
169impl QuestionRequest {
170 #[must_use]
172 pub fn new(question: impl Into<String>) -> Self {
173 Self {
174 request_id: String::new(),
175 question: question.into(),
176 header: None,
177 options: Vec::new(),
178 multi_select: false,
179 }
180 }
181
182 #[must_use]
184 pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
185 Self {
186 request_id: String::new(),
187 question: question.into(),
188 header: None,
189 options,
190 multi_select: false,
191 }
192 }
193
194 #[must_use]
196 pub fn with_header(mut self, header: impl Into<String>) -> Self {
197 self.header = Some(header.into());
198 self
199 }
200
201 #[must_use]
203 pub const fn with_multi_select(mut self) -> Self {
204 self.multi_select = true;
205 self
206 }
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct QuestionOption {
212 pub label: String,
214
215 pub description: Option<String>,
217}
218
219impl QuestionOption {
220 #[must_use]
222 pub fn new(label: impl Into<String>) -> Self {
223 Self {
224 label: label.into(),
225 description: None,
226 }
227 }
228
229 #[must_use]
231 pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
232 Self {
233 label: label.into(),
234 description: Some(description.into()),
235 }
236 }
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct QuestionResponse {
242 #[serde(default)]
249 pub request_id: String,
250
251 pub answer: String,
253
254 pub cancelled: bool,
256}
257
258impl QuestionResponse {
259 #[must_use]
261 pub fn success(answer: impl Into<String>) -> Self {
262 Self {
263 request_id: String::new(),
264 answer: answer.into(),
265 cancelled: false,
266 }
267 }
268
269 #[must_use]
271 pub const fn cancelled() -> Self {
272 Self {
273 request_id: String::new(),
274 answer: String::new(),
275 cancelled: true,
276 }
277 }
278
279 #[must_use]
281 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
282 self.request_id = request_id.into();
283 self
284 }
285}
286
287pub struct AskUserQuestionTool {
295 question_tx: mpsc::Sender<QuestionRequest>,
297
298 question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
300}
301
302impl AskUserQuestionTool {
303 #[must_use]
305 pub fn new(
306 question_tx: mpsc::Sender<QuestionRequest>,
307 question_rx: mpsc::Receiver<QuestionResponse>,
308 ) -> Self {
309 Self {
310 question_tx,
311 question_rx: tokio::sync::Mutex::new(question_rx),
312 }
313 }
314
315 #[must_use]
322 pub fn with_channels(
323 buffer_size: usize,
324 ) -> (
325 Self,
326 mpsc::Receiver<QuestionRequest>,
327 mpsc::Sender<QuestionResponse>,
328 ) {
329 let (request_tx, request_rx) = mpsc::channel(buffer_size);
330 let (response_tx, response_rx) = mpsc::channel(buffer_size);
331
332 let tool = Self::new(request_tx, response_rx);
333 (tool, request_rx, response_tx)
334 }
335}
336
337#[derive(Debug, Deserialize, Serialize)]
339struct AskUserInput {
340 question: String,
342
343 #[serde(default)]
345 header: Option<String>,
346
347 #[serde(default)]
349 options: Vec<OptionInput>,
350
351 #[serde(default)]
353 multi_select: bool,
354}
355
356#[derive(Debug, Deserialize, Serialize)]
358struct OptionInput {
359 label: String,
361
362 #[serde(default)]
364 description: Option<String>,
365}
366
367impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
368 type Name = PrimitiveToolName;
369
370 fn name(&self) -> PrimitiveToolName {
371 PrimitiveToolName::AskUser
372 }
373
374 fn display_name(&self) -> &'static str {
375 "Ask User"
376 }
377
378 fn description(&self) -> &'static str {
379 "Ask the user a question to get clarification, preferences, or choices. \
380 Use this when you need user input before proceeding. For yes/no confirmations \
381 of dangerous operations, tool confirmation will be shown automatically - \
382 use this tool for open-ended questions or when offering choices."
383 }
384
385 fn input_schema(&self) -> Value {
386 json!({
387 "type": "object",
388 "required": ["question"],
389 "properties": {
390 "question": {
391 "type": "string",
392 "description": "The question to ask the user. Be clear and specific."
393 },
394 "header": {
395 "type": "string",
396 "description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
397 },
398 "options": {
399 "type": "array",
400 "description": "Optional list of choices for multiple-choice questions",
401 "items": {
402 "type": "object",
403 "required": ["label"],
404 "properties": {
405 "label": {
406 "type": "string",
407 "description": "The option text to display"
408 },
409 "description": {
410 "type": "string",
411 "description": "Optional explanation of this option"
412 }
413 }
414 }
415 },
416 "multi_select": {
417 "type": "boolean",
418 "description": "Whether multiple options can be selected (default: false)"
419 }
420 }
421 })
422 }
423
424 fn tier(&self) -> ToolTier {
425 ToolTier::Observe
427 }
428
429 async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
430 let input: AskUserInput =
432 serde_json::from_value(input).context("Invalid input for ask_user tool")?;
433
434 let request_id = next_request_id();
437 let request = QuestionRequest {
438 request_id: request_id.clone(),
439 question: input.question.clone(),
440 header: input.header,
441 options: input
442 .options
443 .into_iter()
444 .map(|o| QuestionOption {
445 label: o.label,
446 description: o.description,
447 })
448 .collect(),
449 multi_select: input.multi_select,
450 };
451
452 let cancel_token = ctx.cancel_token().unwrap_or_default();
455
456 let response = {
460 let mut rx = self.question_rx.lock().await;
461
462 self.question_tx
463 .send(request)
464 .await
465 .context("Failed to send question to UI - channel closed")?;
466
467 await_matching_response(&mut rx, &request_id, &cancel_token).await?
468 };
469
470 match response {
471 Some(response) if response.cancelled => Ok(ToolResult::error(
472 "User cancelled the question without providing an answer.",
473 )),
474 Some(response) => Ok(ToolResult::success(format!(
475 "User answered: {}",
476 response.answer
477 ))),
478 None => Ok(ToolResult::error(
479 "Question cancelled before the user answered.",
480 )),
481 }
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    /// A new confirmation request keeps the tool name and carries no context.
    #[test]
    fn test_confirmation_request_new() {
        let request =
            ConfirmationRequest::new("write", "Write to file: foo.txt", "{}", ToolTier::Confirm);

        assert_eq!(request.tool_name, "write");
        assert!(request.context.is_none());
    }

    /// `with_context` stores the given context text.
    #[test]
    fn test_confirmation_request_with_context() {
        let request = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
            .with_context("Agent was fixing a bug");

        assert_eq!(request.context.as_deref(), Some("Agent was fixing a bug"));
    }

    /// Every variant serializes to its snake_case name.
    #[test]
    fn test_confirmation_response_serialization() {
        let cases = [
            (ConfirmationResponse::Approved, "\"approved\""),
            (ConfirmationResponse::Denied, "\"denied\""),
            (ConfirmationResponse::ApproveAll, "\"approve_all\""),
        ];

        for (variant, expected) in cases {
            assert_eq!(serde_json::to_string(&variant).unwrap(), expected);
        }
    }

    /// `QuestionRequest::new` yields a free-form, single-select question.
    #[test]
    fn test_question_request_new() {
        let request = QuestionRequest::new("What color?");

        assert_eq!(request.question, "What color?");
        assert!(request.options.is_empty());
        assert!(!request.multi_select);
    }

    /// Builder methods set options, header and multi-select.
    #[test]
    fn test_question_request_with_options() {
        let options = vec![
            QuestionOption::new("React"),
            QuestionOption::with_description("Vue", "Progressive framework"),
        ];
        let request = QuestionRequest::with_options("Which framework?", options)
            .with_header("Framework")
            .with_multi_select();

        assert_eq!(request.options.len(), 2);
        assert!(request.multi_select);
        assert_eq!(request.header.as_deref(), Some("Framework"));
    }

    /// `success` and `cancelled` set the `cancelled` flag accordingly.
    #[test]
    fn test_question_response() {
        let answered = QuestionResponse::success("Blue");
        assert!(!answered.cancelled);
        assert_eq!(answered.answer, "Blue");

        let dismissed = QuestionResponse::cancelled();
        assert!(dismissed.cancelled);
    }

    /// The tool reports the expected name and tier.
    #[tokio::test]
    async fn test_ask_user_tool_creation() {
        let (tool, _rx, _tx) = AskUserQuestionTool::with_channels(10);

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::AskUser);
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    /// A plain question is forwarded and the answer shows up in the output.
    #[tokio::test]
    async fn test_ask_user_tool_execute() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        // Stand-in for the UI: answer the first question with "Blue".
        let ui = tokio::spawn(async move {
            let Some(request) = request_rx.recv().await else {
                return;
            };
            assert_eq!(request.question, "What color?");
            response_tx
                .send(QuestionResponse::success("Blue"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let input = json!({ "question": "What color?" });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Blue"));
    }

    /// Options from the input are forwarded in order.
    #[tokio::test]
    async fn test_ask_user_with_options() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let ui = tokio::spawn(async move {
            let Some(request) = request_rx.recv().await else {
                return;
            };
            assert_eq!(request.options.len(), 2);
            assert_eq!(request.options[0].label, "Option A");
            response_tx
                .send(QuestionResponse::success("Option A"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let input = json!({
            "question": "Which option?",
            "options": [
                {"label": "Option A", "description": "First choice"},
                {"label": "Option B", "description": "Second choice"}
            ]
        });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();
        assert!(result.success);
    }

    /// A reply flagged as cancelled yields an error result.
    #[tokio::test]
    async fn test_ask_user_cancelled() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let ui = tokio::spawn(async move {
            if request_rx.recv().await.is_none() {
                return;
            }
            response_tx
                .send(QuestionResponse::cancelled())
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let input = json!({ "question": "Continue?" });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();
        assert!(!result.success);
        assert!(result.output.contains("cancelled"));
    }

    /// A queued reply tagged with a foreign request id is skipped in favour
    /// of the reply tagged with the live request id.
    #[tokio::test]
    async fn test_ask_user_discards_stale_response() -> Result<()> {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        // Seed the response channel with a reply for some other request
        // before any question has been asked.
        let stale = QuestionResponse::success("STALE").with_request_id("stale-request");
        response_tx
            .send(stale)
            .await
            .ok()
            .context("seed stale response")?;

        let responder = response_tx.clone();
        let ui = tokio::spawn(async move {
            let request = request_rx.recv().await.context("no question received")?;
            let live = QuestionResponse::success("CORRECT").with_request_id(request.request_id);
            responder
                .send(live)
                .await
                .ok()
                .context("send live response")?;
            anyhow::Ok(())
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(&ctx, json!({ "question": "Which one?" }))
            .await?;

        ui.await.context("responder task panicked")??;

        assert!(result.success);
        assert!(result.output.contains("CORRECT"), "got: {}", result.output);
        assert!(
            !result.output.contains("STALE"),
            "stale answer must be discarded: {}",
            result.output
        );
        Ok(())
    }

    /// With an already-cancelled token `execute` returns an error result
    /// instead of waiting for a reply that never comes.
    #[tokio::test]
    async fn test_ask_user_returns_on_cancel() -> Result<()> {
        // Both channel ends stay alive but nobody ever answers.
        let (tool, _request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);

        let token = CancellationToken::new();
        token.cancel();
        let ctx = ToolContext::new(()).with_cancel_token(token);

        let result = tool
            .execute(&ctx, json!({ "question": "Hang forever?" }))
            .await?;

        assert!(!result.success);
        assert!(
            result.output.to_lowercase().contains("cancel"),
            "got: {}",
            result.output
        );
        Ok(())
    }
}