use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;
use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
use async_trait::async_trait;
/// Shared hand-off point between the `ask` tool and the interactive UI.
///
/// The tool parks at most one [`PendingAsk`] here and awaits the answer; the
/// UI side takes it out and replies through the embedded oneshot sender.
/// Clones share the pending slot, the attached flag, and the session id
/// (all behind `Arc`); `timeout` is a plain copied value.
#[derive(Clone)]
pub struct AskBridge {
    /// Slot for the single outstanding question batch, if any.
    inner: Arc<parking_lot::Mutex<Option<PendingAsk>>>,
    /// Set once a UI attaches; the tool returns an error result while unset.
    ui_attached: Arc<AtomicBool>,
    /// Session id recorded by `attach_with_session`.
    session_id: Arc<parking_lot::Mutex<Option<String>>>,
    /// Timeout copied into every `PendingAsk` the tool creates.
    timeout: Option<Duration>,
}

impl AskBridge {
    /// Creates a detached bridge: nothing pending, no session, no timeout.
    pub fn new() -> Self {
        let empty_slot = parking_lot::Mutex::new(None);
        let no_session = parking_lot::Mutex::new(None);
        Self {
            inner: Arc::new(empty_slot),
            ui_attached: Arc::new(AtomicBool::new(false)),
            session_id: Arc::new(no_session),
            timeout: None,
        }
    }

    /// Creates a detached bridge whose asks carry the given timeout.
    pub fn with_timeout(timeout: Option<Duration>) -> Self {
        let mut bridge = Self::new();
        bridge.timeout = timeout;
        bridge
    }

    /// Records the UI session id and marks the UI as attached.
    ///
    /// In debug builds an empty `session_id` triggers an assertion failure.
    pub fn attach_with_session(&self, session_id: impl Into<String>) {
        let session: String = session_id.into();
        debug_assert!(
            !session.is_empty(),
            "AskBridge::attach_with_session called with empty session_id"
        );
        // Publish the id before flipping the flag so readers that see
        // `ui_attached == true` also find the session id in place.
        let _ = self.session_id.lock().replace(session);
        self.ui_attached.store(true, Ordering::SeqCst);
    }

    /// Whether a UI has attached to this bridge.
    pub fn is_ui_attached(&self) -> bool {
        self.ui_attached.load(Ordering::SeqCst)
    }

    /// Test/debug helper: marks the UI attached without recording a session id.
    ///
    /// Note that the tool's `execute` debug-asserts a non-empty session id, so
    /// running the tool on a bridge attached only this way trips that assert
    /// in debug builds.
    #[cfg(any(test, debug_assertions))]
    pub fn attach(&self) {
        self.ui_attached.store(true, Ordering::SeqCst);
    }

    /// Returns a copy of the recorded session id, if any.
    pub fn session_id(&self) -> Option<String> {
        let guard = self.session_id.lock();
        guard.as_ref().cloned()
    }

    /// The timeout configured at construction.
    pub fn timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Stores `pending` if the slot is empty.
    ///
    /// Returns `false` (dropping `pending`, and with it its responder) when
    /// another ask is already queued.
    pub fn set(&self, pending: PendingAsk) -> bool {
        let mut slot = self.inner.lock();
        if slot.is_none() {
            *slot = Some(pending);
            true
        } else {
            false
        }
    }

    /// Removes and returns the queued ask, leaving the slot empty.
    pub fn try_take(&self) -> Option<PendingAsk> {
        let mut slot = self.inner.lock();
        slot.take()
    }

    /// Whether an ask is currently queued.
    pub fn has_pending(&self) -> bool {
        matches!(*self.inner.lock(), Some(_))
    }
}

impl Default for AskBridge {
    /// Same as [`AskBridge::new`].
    fn default() -> Self {
        AskBridge::new()
    }
}
/// A batch of questions waiting for the UI, plus the channel used to deliver
/// the user's reply back to the waiting tool call.
pub struct PendingAsk {
    /// Questions to display (when created by `AskTool`, these come from `parse_questions`).
    pub questions: Vec<Question>,
    /// Send exactly one `AskResponse` here. Dropping it without sending makes
    /// the tool report "Question dismissed".
    pub responder: oneshot::Sender<AskResponse>,
    /// Timeout copied from the bridge. NOTE(review): enforcement (and
    /// auto-selecting `recommended`) happens UI-side, which is not visible here.
    pub timeout: Option<Duration>,
    /// Session id of the attached UI at the time the ask was created.
    pub session_id: Option<String>,
}
/// A single question presented to the user.
///
/// Deserialization accepts both the snake_case field names and the camelCase
/// names advertised in the tool's JSON schema (`allowOther`, `multiSelect`).
/// Serialization still emits snake_case so existing consumers are unaffected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Question {
    /// Unique identifier; echoed back in `Answer::id`.
    pub id: String,
    /// Short contextual label; `parse_questions` fills it with `id` when empty.
    #[serde(default)]
    pub label: String,
    /// Full question text.
    pub prompt: String,
    /// Selectable options; may be empty for free-text questions.
    #[serde(default)]
    pub options: Vec<QuestionOption>,
    /// Whether the UI offers an "Other (type your own)" entry (default `true`).
    // The schema documents `allowOther`; without the alias the model's value
    // was silently ignored and the default always applied.
    #[serde(default = "default_true", alias = "allowOther")]
    pub allow_other: bool,
    /// Whether more than one option may be selected (default `false`).
    // The schema documents `multiSelect`; without the alias multi-select
    // could never be enabled by the model.
    #[serde(default, alias = "multiSelect")]
    pub multi_select: bool,
    /// 0-based index of the recommended option, if any.
    #[serde(default)]
    pub recommended: Option<usize>,
}
/// Serde default for `Question::allow_other`: the "Other" entry is shown
/// unless the caller explicitly disables it.
const fn default_true() -> bool { true }
/// One selectable choice within a [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
    /// Value reported back when this option is chosen.
    pub value: String,
    /// Short display label.
    pub label: String,
    /// Optional explanatory text (e.g. tradeoffs); may be omitted in the JSON.
    pub description: Option<String>,
}
/// The UI's reply to a [`PendingAsk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AskResponse {
    /// Answers to the asked questions. NOTE(review): completeness and ordering
    /// are up to the UI; nothing here enforces one answer per question.
    pub answers: Vec<Answer>,
    /// The user dismissed the prompt; the tool then ignores `answers`.
    pub cancelled: bool,
    /// Answers were auto-selected because the timeout elapsed (defaults to
    /// `false` when absent from the JSON).
    #[serde(default)]
    pub timed_out: bool,
}
/// The user's answer to one [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Answer {
    /// Id of the question being answered.
    pub id: String,
    /// Selected option value(s). `format_answers` treats a value containing a
    /// comma as a multi-selection.
    pub value: String,
    /// Display text of the selection; `format_answers` quotes it when
    /// `was_custom` is set (i.e. it is the user's own typed text).
    pub label: String,
    /// The user typed a custom answer instead of picking an option.
    pub was_custom: bool,
    /// Index of the chosen option. NOTE(review): presumably `None` for
    /// custom or multi-select answers — confirm with the UI.
    pub index: Option<usize>,
}
/// The `ask` agent tool: forwards clarifying questions to the interactive UI
/// through a shared [`AskBridge`] and waits for the user's answers.
///
/// Cloning is cheap and yields a tool that shares the same bridge.
#[derive(Clone)]
pub struct AskTool {
    bridge: Arc<AskBridge>,
}

impl AskTool {
    /// Creates the tool around a bridge shared with the UI.
    pub fn new(bridge: Arc<AskBridge>) -> Self {
        Self { bridge }
    }
}
#[async_trait]
impl AgentTool for AskTool {
    fn name(&self) -> &str {
        "ask"
    }

    fn label(&self) -> &str {
        "Ask"
    }

    fn description(&self) -> &str {
        "Ask the user a clarifying question when choices have materially \
         different tradeoffs the user must decide. Default to action — pick \
         the conservative/standard option and proceed when a reasonable \
         default exists; only ask when the user must weigh the tradeoff. Do \
         NOT include an 'Other' option — the UI appends 'Other (type your \
         own)' automatically. Use 'recommended' (0-indexed) to mark the \
         default; a '(Recommended)' suffix is added automatically. Set \
         'multiSelect' true to allow multiple selections. Provide 2-5 \
         concise options with short labels; put explanatory tradeoffs in \
         'description'. Batch related questions in one call via 'questions'."
    }

    /// JSON schema for the tool parameters.
    ///
    /// `recommended` is declared as `"integer"` because it deserializes into
    /// `usize`; serde rejects fractional/float JSON numbers for integer fields.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "questions": {
                    "type": "array",
                    "description": "Questions to ask the user",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {
                                "type": "string",
                                "description": "Unique identifier for this question"
                            },
                            "label": {
                                "type": "string",
                                "description": "Short contextual label (defaults to the id)"
                            },
                            "prompt": {
                                "type": "string",
                                "description": "The full question text to display"
                            },
                            "options": {
                                "type": "array",
                                "description": "Available options (2-5). Do NOT include 'Other' — the UI adds it automatically.",
                                "default": [],
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "value": {
                                            "type": "string",
                                            "description": "The value returned when selected"
                                        },
                                        "label": {
                                            "type": "string",
                                            "description": "Short display label for the option"
                                        },
                                        "description": {
                                            "type": "string",
                                            "description": "Optional explanatory tradeoff shown below the label"
                                        }
                                    },
                                    "required": ["value", "label"]
                                }
                            },
                            "allowOther": {
                                "type": "boolean",
                                "description": "Show 'Other (type your own)' (default: true)",
                                "default": true
                            },
                            "multiSelect": {
                                "type": "boolean",
                                "description": "Allow multiple selections (default: false)",
                                "default": false
                            },
                            "recommended": {
                                "type": "integer",
                                "description": "Recommended option index (0-based). Marks the default and is used for timeout auto-selection.",
                                "minimum": 0
                            }
                        },
                        "required": ["id", "prompt"]
                    }
                }
            },
            "required": ["questions"]
        })
    }

    /// Queues the questions on the bridge and waits for the UI's reply or an
    /// interrupt.
    ///
    /// Returns an error *result* (not `Err`) when no UI is attached or another
    /// ask is already pending. Returns `Err` when the parameters fail
    /// validation in `parse_questions`.
    async fn execute(
        &self,
        _tool_call_id: &str,
        params: serde_json::Value,
        signal: Option<oneshot::Receiver<()>>,
        _ctx: &ToolContext,
    ) -> Result<AgentToolResult, ToolError> {
        if !self.bridge.is_ui_attached() {
            return Ok(AgentToolResult::error(
                "Ask requires interactive TUI mode. \
                 Not available in --print or RPC mode.",
            ));
        }

        let session_id = self.bridge.session_id();
        debug_assert!(
            session_id.as_deref().is_some_and(|s| !s.is_empty()),
            "AskBridge was attached without a non-empty session_id; refusing to run"
        );

        let questions = parse_questions(&params)?;
        let timeout = self.bridge.timeout();
        let (tx, rx) = oneshot::channel();
        if !self.bridge.set(PendingAsk {
            questions,
            responder: tx,
            timeout,
            session_id,
        }) {
            return Ok(AgentToolResult::error("Another ask is already pending"));
        }

        select_with_abort(rx, signal, &self.bridge).await
    }
}
/// Waits for the UI's response to the queued ask, or for an interrupt.
///
/// Mapping of outcomes to tool results:
/// - response with `cancelled` set  -> "User cancelled the question"
/// - any other response             -> formatted answers
/// - responder dropped without send -> "Question dismissed"
/// - abort signal resolves          -> queued ask removed from the bridge,
///   "Question cancelled by user interrupt"
///
/// The abort branch fires when the signal's sender either sends or is dropped,
/// because the awaited result is discarded. Without a signal, only the
/// response branch can complete.
async fn select_with_abort(
    rx: oneshot::Receiver<AskResponse>,
    signal: Option<oneshot::Receiver<()>>,
    bridge: &AskBridge,
) -> Result<AgentToolResult, ToolError> {
    let interrupted = async move {
        match signal {
            Some(sig) => {
                let _ = sig.await;
            }
            None => std::future::pending::<()>().await,
        }
    };

    tokio::select! {
        outcome = rx => Ok(match outcome {
            Err(_) => AgentToolResult::success("Question dismissed"),
            Ok(resp) if resp.cancelled => {
                AgentToolResult::success("User cancelled the question")
            }
            Ok(resp) => AgentToolResult::success(format_answers(&resp.answers, resp.timed_out)),
        }),
        () = interrupted => {
            // Drop the still-queued ask (if the UI has not taken it yet) so a
            // stale question is not shown after the interrupt.
            let _ = bridge.try_take();
            Ok(AgentToolResult::success("Question cancelled by user interrupt"))
        }
    }
}
fn parse_questions(params: &serde_json::Value) -> Result<Vec<Question>, ToolError> {
let questions = params
.get("questions")
.and_then(|v| v.as_array())
.cloned()
.ok_or_else(|| "Missing or invalid 'questions' field".to_string())?;
let questions: Vec<Question> = questions
.into_iter()
.map(|v| serde_json::from_value(v).map_err(|e| e.to_string()))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| format!("Invalid question: {}", e))?;
if questions.is_empty() {
return Err("At least one question is required".to_string());
}
let questions: Vec<Question> = questions
.into_iter()
.map(|mut q| {
if q.label.is_empty() {
q.label = q.id.clone();
}
q
})
.collect();
let mut ids = std::collections::HashSet::new();
for q in &questions {
if !ids.insert(&q.id) {
return Err(format!("Duplicate question id: {}", q.id));
}
}
Ok(questions)
}
/// Renders answers as newline-separated `id: label` lines for the model.
///
/// - Custom (typed) answers are quoted: `id: "text"`.
/// - Answers whose `value` contains a comma are treated as multi-selections
///   and wrapped in brackets: `id: [A, B]`.
/// - When `timed_out` is set, each line gets an
///   " (auto-selected after timeout)" suffix so the model knows the user did
///   not actively choose.
pub fn format_answers(answers: &[Answer], timed_out: bool) -> String {
    let suffix = if timed_out {
        " (auto-selected after timeout)"
    } else {
        ""
    };
    answers
        .iter()
        .map(|a| {
            if a.was_custom {
                format!("{}: \"{}\"{suffix}", a.id, a.label)
            } else if a.value.contains(',') {
                // The label already lists the selections, so it is emitted
                // verbatim. The old split(", ")/join(", ") round-trip was a
                // no-op that only allocated.
                format!("{}: [{}]{suffix}", a.id, a.label)
            } else {
                format!("{}: {}{suffix}", a.id, a.label)
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `responder` in a `PendingAsk` with no questions, timeout, or session.
    fn bare_pending(responder: oneshot::Sender<AskResponse>) -> PendingAsk {
        PendingAsk {
            questions: vec![],
            responder,
            timeout: None,
            session_id: None,
        }
    }

    /// Shorthand constructor for an `Answer`.
    fn answer(id: &str, value: &str, label: &str, was_custom: bool, index: Option<usize>) -> Answer {
        Answer {
            id: id.into(),
            value: value.into(),
            label: label.into(),
            was_custom,
            index,
        }
    }

    #[test]
    fn test_parse_questions_valid() {
        let params = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "prompt": "Pick a language",
                    "options": [
                        { "value": "rust", "label": "Rust" },
                        { "value": "ts", "label": "TypeScript" }
                    ]
                }
            ]
        });
        let parsed = parse_questions(&params).unwrap();
        assert_eq!(parsed.len(), 1);
        let first = &parsed[0];
        assert_eq!(first.id, "lang");
        // An omitted label falls back to the id.
        assert_eq!(first.label, "lang");
        assert_eq!(first.options.len(), 2);
        assert!(first.allow_other);
        assert!(!first.multi_select);
    }

    #[test]
    fn test_parse_questions_with_label() {
        let params = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "label": "Language",
                    "prompt": "Pick a language"
                }
            ]
        });
        let parsed = parse_questions(&params).unwrap();
        assert_eq!(parsed[0].label, "Language");
    }

    #[test]
    fn test_parse_questions_empty_options() {
        let params = serde_json::json!({
            "questions": [
                {
                    "id": "name",
                    "prompt": "What's your project name?",
                    "allowOther": true
                }
            ]
        });
        let parsed = parse_questions(&params).unwrap();
        let only = &parsed[0];
        assert_eq!(only.options.len(), 0);
        assert!(only.allow_other);
    }

    #[test]
    fn test_parse_questions_missing_questions() {
        let message = parse_questions(&serde_json::json!({})).unwrap_err();
        assert!(message.contains("questions"));
    }

    #[test]
    fn test_parse_questions_empty_array() {
        let message = parse_questions(&serde_json::json!({ "questions": [] })).unwrap_err();
        assert!(message.contains("one question"));
    }

    #[test]
    fn test_parse_questions_duplicate_ids() {
        let params = serde_json::json!({
            "questions": [
                { "id": "a", "prompt": "Q1" },
                { "id": "a", "prompt": "Q2" }
            ]
        });
        let message = parse_questions(&params).unwrap_err();
        assert!(message.contains("Duplicate"));
    }

    #[test]
    fn test_format_answers_single() {
        let answers = [answer("lang", "rust", "Rust", false, Some(1))];
        assert_eq!(format_answers(&answers, false), "lang: Rust");
    }

    #[test]
    fn test_format_answers_custom() {
        let answers = [answer("name", "myproj", "myproj", true, None)];
        assert_eq!(format_answers(&answers, false), "name: \"myproj\"");
    }

    #[test]
    fn test_format_answers_multi() {
        let answers = [answer("lang", "rust, go", "Rust, Go", false, None)];
        assert_eq!(format_answers(&answers, false), "lang: [Rust, Go]");
    }

    #[test]
    fn test_format_answers_timed_out() {
        let answers = [answer("auth", "oauth", "OAuth2", false, Some(2))];
        assert_eq!(
            format_answers(&answers, true),
            "auth: OAuth2 (auto-selected after timeout)"
        );
    }

    #[test]
    fn test_bridge_set_take() {
        let bridge = AskBridge::new();
        assert!(!bridge.has_pending());

        let (tx, _rx) = oneshot::channel();
        assert!(bridge.set(bare_pending(tx)));
        assert!(bridge.has_pending());

        assert!(bridge.try_take().is_some());
        assert!(!bridge.has_pending());
        // A second take finds the slot empty.
        assert!(bridge.try_take().is_none());
    }

    #[test]
    fn test_bridge_set_idempotent() {
        let bridge = AskBridge::new();
        let (first_tx, _first_rx) = oneshot::channel();
        let (second_tx, _second_rx) = oneshot::channel();
        bridge.set(bare_pending(first_tx));
        // The slot is occupied, so a second set is refused.
        assert!(!bridge.set(bare_pending(second_tx)));
    }

    #[test]
    fn test_ui_attached_flag() {
        let bridge = AskBridge::new();
        assert!(!bridge.is_ui_attached());
        bridge.attach();
        assert!(bridge.is_ui_attached());
    }

    #[test]
    fn test_bridge_with_timeout() {
        let timed = AskBridge::with_timeout(Some(Duration::from_secs(30)));
        assert_eq!(timed.timeout(), Some(Duration::from_secs(30)));
        assert!(!timed.is_ui_attached());

        let untimed = AskBridge::new();
        assert_eq!(untimed.timeout(), None);
    }

    #[test]
    fn test_question_deserializes_without_recommended() {
        let raw = serde_json::json!({
            "id": "test",
            "prompt": "Test question?",
            "options": [{"value": "a", "label": "A"}]
        });
        let question: Question = serde_json::from_value(raw).unwrap();
        assert_eq!(question.recommended, None);
    }

    #[test]
    fn test_question_deserializes_with_recommended() {
        let raw = serde_json::json!({
            "id": "test",
            "prompt": "Test question?",
            "options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}],
            "recommended": 1
        });
        let question: Question = serde_json::from_value(raw).unwrap();
        assert_eq!(question.recommended, Some(1));
    }

    #[test]
    fn test_tool_name_is_ask() {
        let tool = AskTool::new(Arc::new(AskBridge::new()));
        assert_eq!(tool.name(), "ask");
        assert_eq!(tool.label(), "Ask");
    }

    #[test]
    fn test_attach_with_session_stores_id() {
        let bridge = AskBridge::new();
        assert!(!bridge.is_ui_attached());
        assert_eq!(bridge.session_id(), None);

        bridge.attach_with_session("tui");
        assert!(bridge.is_ui_attached());
        assert_eq!(bridge.session_id().as_deref(), Some("tui"));
    }

    #[test]
    fn test_format_answers_multi_with_comma_label() {
        let answers = [answer("tags", "a,b", "A, B", false, None)];
        assert_eq!(format_answers(&answers, false), "tags: [A, B]");
    }

    #[test]
    fn test_format_answers_cancelled_marker() {
        let answers = [answer("q1", "", "", false, None)];
        assert_eq!(format_answers(&answers, false), "q1: ");
    }
}