use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
/// Process-wide counter used to mint unique question request ids.
static QUESTION_SEQ: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh request id of the form `ask-<n>`, where `n` grows by one per call.
fn next_request_id() -> String {
    // `fetch_add` is an atomic read-modify-write, so every caller observes a distinct
    // value; no ordering with other memory is needed, hence `Relaxed`.
    format!("ask-{}", QUESTION_SEQ.fetch_add(1, Ordering::Relaxed))
}
/// Waits on `rx` until a response for `request_id` arrives or `cancel_token` fires.
///
/// Responses carrying a different, non-empty `request_id` are stale answers to earlier
/// questions and are skipped. A response with an empty `request_id` is accepted as a
/// match, so UIs that do not echo the id keep working.
///
/// Returns `Ok(None)` when cancelled. Cancellation is checked before each receive
/// (`biased`), so an already-cancelled token wins over a queued response.
///
/// # Errors
///
/// Fails if the response channel is closed before a matching response arrives.
async fn await_matching_response(
    rx: &mut mpsc::Receiver<QuestionResponse>,
    request_id: &str,
    cancel_token: &CancellationToken,
) -> Result<Option<QuestionResponse>> {
    loop {
        let next = tokio::select! {
            biased;
            () = cancel_token.cancelled() => None,
            received = rx.recv() => Some(received),
        };
        let Some(received) = next else {
            return Ok(None);
        };
        let response = received.context("Failed to receive answer from UI - channel closed")?;
        let is_ours = response.request_id.is_empty() || response.request_id == request_id;
        if is_ours {
            return Ok(Some(response));
        }
    }
}
/// A request asking the user to approve or deny a tool invocation.
///
/// NOTE(review): this file only defines the type; who sends it and how it is rendered
/// lives elsewhere — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfirmationRequest {
/// Name of the tool that wants to run (e.g. `"write"` in the tests).
pub tool_name: String,
/// Human-readable summary of the action (e.g. `"Write to file: foo.txt"`).
pub description: String,
/// Preview of the tool's input; presumably a serialized rendering of it — TODO confirm.
pub input_preview: String,
/// The tool's tier, stored as the `Debug` rendering of a `ToolTier` (see `new`).
pub tier: String,
/// Optional extra context explaining why the tool is being run.
pub context: Option<String>,
}
impl ConfirmationRequest {
    /// Builds a confirmation request without extra context.
    ///
    /// `tier` is stored as its `Debug` rendering, because the request carries the
    /// tier as a plain string.
    #[must_use]
    pub fn new(
        tool_name: impl Into<String>,
        description: impl Into<String>,
        input_preview: impl Into<String>,
        tier: ToolTier,
    ) -> Self {
        let tool_name = tool_name.into();
        let description = description.into();
        let input_preview = input_preview.into();
        let tier = format!("{tier:?}");
        Self {
            tool_name,
            description,
            input_preview,
            tier,
            context: None,
        }
    }

    /// Attaches (or replaces) the explanatory context shown with the request.
    #[must_use]
    pub fn with_context(self, context: impl Into<String>) -> Self {
        Self {
            context: Some(context.into()),
            ..self
        }
    }
}
/// The user's reply to a [`ConfirmationRequest`].
///
/// Serialized in snake_case: `"approved"`, `"denied"`, `"approve_all"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConfirmationResponse {
/// Run this invocation.
Approved,
/// Do not run this invocation.
Denied,
/// NOTE(review): presumably approves this and subsequent invocations without asking
/// again — the consumer is outside this file; confirm.
ApproveAll,
}
/// A question sent to the UI for the user to answer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionRequest {
/// Correlates this question with its [`QuestionResponse`]. `AskUserQuestionTool`
/// fills it with ids like `ask-0`; it is empty when built via `new`/`with_options`,
/// and defaults to empty when absent from serialized input.
#[serde(default)]
pub request_id: String,
/// The question text.
pub question: String,
/// Optional short header/category for the question.
pub header: Option<String>,
/// Choices to offer; empty for an open-ended question.
pub options: Vec<QuestionOption>,
/// Whether more than one option may be selected.
pub multi_select: bool,
}
impl QuestionRequest {
    /// Creates an open-ended question: no header, no options, single-select, empty id.
    #[must_use]
    pub fn new(question: impl Into<String>) -> Self {
        Self::with_options(question, Vec::new())
    }

    /// Creates a multiple-choice question with the given options (single-select, empty id).
    #[must_use]
    pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
        let question = question.into();
        Self {
            request_id: String::new(),
            question,
            header: None,
            options,
            multi_select: false,
        }
    }

    /// Sets the short header/category shown with the question.
    #[must_use]
    pub fn with_header(self, header: impl Into<String>) -> Self {
        Self {
            header: Some(header.into()),
            ..self
        }
    }

    /// Allows the user to pick more than one option.
    #[must_use]
    pub const fn with_multi_select(mut self) -> Self {
        // Field assignment (rather than struct-update syntax) keeps this usable in `const fn`.
        self.multi_select = true;
        self
    }
}
/// One selectable choice in a [`QuestionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
/// The option text to display.
pub label: String,
/// Optional explanation of this option.
pub description: Option<String>,
}
impl QuestionOption {
    /// Creates an option that has only a label.
    #[must_use]
    pub fn new(label: impl Into<String>) -> Self {
        let label = label.into();
        Self {
            label,
            description: None,
        }
    }

    /// Creates an option with a label and an explanatory description.
    #[must_use]
    pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
        let mut option = Self::new(label);
        option.description = Some(description.into());
        option
    }
}
/// The UI's reply to a [`QuestionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionResponse {
/// Id of the question being answered. An empty id is treated as matching whatever
/// question is currently awaited; a non-empty id that does not match is discarded
/// as stale. Defaults to empty when absent from serialized input.
#[serde(default)]
pub request_id: String,
/// The user's answer text (empty for a cancelled response).
pub answer: String,
/// `true` when the user dismissed the question without answering.
pub cancelled: bool,
}
impl QuestionResponse {
    /// A successful answer with an empty request id.
    #[must_use]
    pub fn success(answer: impl Into<String>) -> Self {
        let answer = answer.into();
        Self {
            request_id: String::new(),
            answer,
            cancelled: false,
        }
    }

    /// A response signalling that the user dismissed the question (empty answer and id).
    #[must_use]
    pub const fn cancelled() -> Self {
        Self {
            request_id: String::new(),
            answer: String::new(),
            cancelled: true,
        }
    }

    /// Tags the response with the id of the question it answers.
    #[must_use]
    pub fn with_request_id(self, request_id: impl Into<String>) -> Self {
        Self {
            request_id: request_id.into(),
            ..self
        }
    }
}
/// Tool that forwards questions to a UI and waits for the user's answer.
pub struct AskUserQuestionTool {
/// Outgoing questions to the UI.
question_tx: mpsc::Sender<QuestionRequest>,
/// Incoming answers from the UI. Behind an async mutex because `execute` holds it
/// across sending a question and awaiting its answer.
question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
}
impl AskUserQuestionTool {
    /// Wraps an existing question sender and answer receiver.
    #[must_use]
    pub fn new(
        question_tx: mpsc::Sender<QuestionRequest>,
        question_rx: mpsc::Receiver<QuestionResponse>,
    ) -> Self {
        let question_rx = tokio::sync::Mutex::new(question_rx);
        Self {
            question_tx,
            question_rx,
        }
    }

    /// Creates the tool together with the UI-side channel ends.
    ///
    /// Returns `(tool, questions_for_ui, answers_from_ui)`. The UI reads questions from
    /// the receiver and replies through the sender. Both channels use `buffer_size`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is 0 (a tokio bounded channel requires a positive capacity).
    #[must_use]
    pub fn with_channels(
        buffer_size: usize,
    ) -> (
        Self,
        mpsc::Receiver<QuestionRequest>,
        mpsc::Sender<QuestionResponse>,
    ) {
        let (tool_question_tx, ui_question_rx) = mpsc::channel(buffer_size);
        let (ui_answer_tx, tool_answer_rx) = mpsc::channel(buffer_size);
        (
            Self::new(tool_question_tx, tool_answer_rx),
            ui_question_rx,
            ui_answer_tx,
        )
    }
}
/// JSON input accepted by the `ask_user` tool; mirrors its `input_schema`.
#[derive(Debug, Deserialize, Serialize)]
struct AskUserInput {
/// Required question text.
question: String,
/// Optional short header/category.
#[serde(default)]
header: Option<String>,
/// Optional choices; empty when omitted.
#[serde(default)]
options: Vec<OptionInput>,
/// Whether several options may be chosen; `false` when omitted.
#[serde(default)]
multi_select: bool,
}
/// One option entry in [`AskUserInput`]; converted into a [`QuestionOption`].
#[derive(Debug, Deserialize, Serialize)]
struct OptionInput {
/// Required option text.
label: String,
/// Optional explanation; `None` when omitted.
#[serde(default)]
description: Option<String>,
}
impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
type Name = PrimitiveToolName;
fn name(&self) -> PrimitiveToolName {
PrimitiveToolName::AskUser
}
fn display_name(&self) -> &'static str {
"Ask User"
}
fn description(&self) -> &'static str {
"Ask the user a question to get clarification, preferences, or choices. \
Use this when you need user input before proceeding. For yes/no confirmations \
of dangerous operations, tool confirmation will be shown automatically - \
use this tool for open-ended questions or when offering choices."
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"required": ["question"],
"properties": {
"question": {
"type": "string",
"description": "The question to ask the user. Be clear and specific."
},
"header": {
"type": "string",
"description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
},
"options": {
"type": "array",
"description": "Optional list of choices for multiple-choice questions",
"items": {
"type": "object",
"required": ["label"],
"properties": {
"label": {
"type": "string",
"description": "The option text to display"
},
"description": {
"type": "string",
"description": "Optional explanation of this option"
}
}
}
},
"multi_select": {
"type": "boolean",
"description": "Whether multiple options can be selected (default: false)"
}
}
})
}
fn tier(&self) -> ToolTier {
ToolTier::Observe
}
async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
let input: AskUserInput =
serde_json::from_value(input).context("Invalid input for ask_user tool")?;
let request_id = next_request_id();
let request = QuestionRequest {
request_id: request_id.clone(),
question: input.question.clone(),
header: input.header,
options: input
.options
.into_iter()
.map(|o| QuestionOption {
label: o.label,
description: o.description,
})
.collect(),
multi_select: input.multi_select,
};
let cancel_token = ctx.cancel_token().unwrap_or_default();
let response = {
let mut rx = self.question_rx.lock().await;
self.question_tx
.send(request)
.await
.context("Failed to send question to UI - channel closed")?;
await_matching_response(&mut rx, &request_id, &cancel_token).await?
};
match response {
Some(response) if response.cancelled => Ok(ToolResult::error(
"User cancelled the question without providing an answer.",
)),
Some(response) => Ok(ToolResult::success(format!(
"User answered: {}",
response.answer
))),
None => Ok(ToolResult::error(
"Question cancelled before the user answered.",
)),
}
}
}
// Unit tests for the confirmation/question types and the ask_user tool's channel protocol.
#[cfg(test)]
mod tests {
use super::*;
use crate::Tool;
// `new` fills the given fields and leaves `context` unset.
#[test]
fn test_confirmation_request_new() {
let req =
ConfirmationRequest::new("write", "Write to file: foo.txt", "{}", ToolTier::Confirm);
assert_eq!(req.tool_name, "write");
assert!(req.context.is_none());
}
// `with_context` stores the provided context string.
#[test]
fn test_confirmation_request_with_context() {
let req = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
.with_context("Agent was fixing a bug");
assert!(req.context.is_some());
assert_eq!(req.context.unwrap(), "Agent was fixing a bug");
}
// Responses serialize as snake_case strings (`rename_all = "snake_case"`).
#[test]
fn test_confirmation_response_serialization() {
assert_eq!(
serde_json::to_string(&ConfirmationResponse::Approved).unwrap(),
"\"approved\""
);
assert_eq!(
serde_json::to_string(&ConfirmationResponse::Denied).unwrap(),
"\"denied\""
);
assert_eq!(
serde_json::to_string(&ConfirmationResponse::ApproveAll).unwrap(),
"\"approve_all\""
);
}
// An open-ended question has no options and is single-select.
#[test]
fn test_question_request_new() {
let req = QuestionRequest::new("What color?");
assert_eq!(req.question, "What color?");
assert!(req.options.is_empty());
assert!(!req.multi_select);
}
// Builder methods compose: options, header and multi-select.
#[test]
fn test_question_request_with_options() {
let req = QuestionRequest::with_options(
"Which framework?",
vec![
QuestionOption::new("React"),
QuestionOption::with_description("Vue", "Progressive framework"),
],
)
.with_header("Framework")
.with_multi_select();
assert_eq!(req.options.len(), 2);
assert!(req.multi_select);
assert_eq!(req.header.unwrap(), "Framework");
}
// `success` and `cancelled` set the `cancelled` flag accordingly.
#[test]
fn test_question_response() {
let success = QuestionResponse::success("Blue");
assert!(!success.cancelled);
assert_eq!(success.answer, "Blue");
let cancelled = QuestionResponse::cancelled();
assert!(cancelled.cancelled);
}
// The tool reports its name and the Observe tier.
#[tokio::test]
async fn test_ask_user_tool_creation() {
let (tool, _rx, _tx) = AskUserQuestionTool::with_channels(10);
assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::AskUser);
assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
}
// Round trip: a simulated UI task receives the question and replies with an
// untagged (empty request_id) answer, which the tool accepts.
#[tokio::test]
async fn test_ask_user_tool_execute() {
let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
let handle = tokio::spawn(async move {
if let Some(request) = request_rx.recv().await {
assert_eq!(request.question, "What color?");
response_tx
.send(QuestionResponse::success("Blue"))
.await
.unwrap();
}
});
let ctx = ToolContext::new(());
let result = tool
.execute(
&ctx,
json!({
"question": "What color?"
}),
)
.await
.unwrap();
handle.await.unwrap();
assert!(result.success);
assert!(result.output.contains("Blue"));
}
// Options from the JSON input are forwarded to the UI in order.
#[tokio::test]
async fn test_ask_user_with_options() {
let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
let handle = tokio::spawn(async move {
if let Some(request) = request_rx.recv().await {
assert_eq!(request.options.len(), 2);
assert_eq!(request.options[0].label, "Option A");
response_tx
.send(QuestionResponse::success("Option A"))
.await
.unwrap();
}
});
let ctx = ToolContext::new(());
let result = tool
.execute(
&ctx,
json!({
"question": "Which option?",
"options": [
{"label": "Option A", "description": "First choice"},
{"label": "Option B", "description": "Second choice"}
]
}),
)
.await
.unwrap();
handle.await.unwrap();
assert!(result.success);
}
// A user-cancelled response yields an error result mentioning cancellation.
#[tokio::test]
async fn test_ask_user_cancelled() {
let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
let handle = tokio::spawn(async move {
if request_rx.recv().await.is_some() {
response_tx
.send(QuestionResponse::cancelled())
.await
.unwrap();
}
});
let ctx = ToolContext::new(());
let result = tool
.execute(
&ctx,
json!({
"question": "Continue?"
}),
)
.await
.unwrap();
handle.await.unwrap();
assert!(!result.success);
assert!(result.output.contains("cancelled"));
}
// A pre-queued response tagged with a foreign request_id is skipped; the answer
// tagged with the live question's id is the one returned.
#[tokio::test]
async fn test_ask_user_discards_stale_response() -> Result<()> {
let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
response_tx
.send(QuestionResponse::success("STALE").with_request_id("stale-request"))
.await
.ok()
.context("seed stale response")?;
let responder = response_tx.clone();
let handle = tokio::spawn(async move {
let request = request_rx.recv().await.context("no question received")?;
responder
.send(QuestionResponse::success("CORRECT").with_request_id(request.request_id))
.await
.ok()
.context("send live response")?;
anyhow::Ok(())
});
let ctx = ToolContext::new(());
let result = tool
.execute(&ctx, json!({ "question": "Which one?" }))
.await?;
handle.await.context("responder task panicked")??;
assert!(result.success);
assert!(result.output.contains("CORRECT"), "got: {}", result.output);
assert!(
!result.output.contains("STALE"),
"stale answer must be discarded: {}",
result.output
);
Ok(())
}
// With an already-cancelled token and no UI responder, execute must return
// (not hang) with an error result mentioning cancellation.
#[tokio::test]
async fn test_ask_user_returns_on_cancel() -> Result<()> {
let (tool, _request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);
let token = CancellationToken::new();
token.cancel();
let ctx = ToolContext::new(()).with_cancel_token(token);
let result = tool
.execute(&ctx, json!({ "question": "Hang forever?" }))
.await?;
assert!(!result.success);
assert!(
result.output.to_lowercase().contains("cancel"),
"got: {}",
result.output
);
Ok(())
}
}