use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::oneshot;
use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
/// Hand-off slot between [`QuestionnaireTool`] and whatever front-end displays
/// questionnaires (presumably the UI — the consumer is not visible here).
///
/// The slot holds at most one [`PendingQuestionnaire`] behind a mutex. Cloning the
/// bridge produces another handle to the *same* slot, because the mutex lives
/// behind an `Arc`.
#[derive(Clone)]
pub struct QuestionnaireBridge {
    inner: Arc<parking_lot::Mutex<Option<PendingQuestionnaire>>>,
}

impl QuestionnaireBridge {
    /// Creates a bridge whose slot starts out empty.
    pub fn new() -> Self {
        let slot = parking_lot::Mutex::new(None);
        Self {
            inner: Arc::new(slot),
        }
    }

    /// Places `pending` into the slot if it is currently empty.
    ///
    /// Returns `true` when stored. Returns `false` when another questionnaire is
    /// already waiting; in that case `pending` is dropped, which also drops its
    /// responder.
    pub fn set(&self, pending: PendingQuestionnaire) -> bool {
        let mut slot = self.inner.lock();
        let vacant = slot.is_none();
        if vacant {
            *slot = Some(pending);
        }
        vacant
    }

    /// Removes and returns the waiting questionnaire, leaving the slot empty.
    pub fn try_take(&self) -> Option<PendingQuestionnaire> {
        std::mem::take(&mut *self.inner.lock())
    }

    /// Reports whether a questionnaire is currently waiting in the slot.
    pub fn has_pending(&self) -> bool {
        matches!(*self.inner.lock(), Some(_))
    }
}

impl Default for QuestionnaireBridge {
    /// Equivalent to [`QuestionnaireBridge::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A questionnaire waiting in a [`QuestionnaireBridge`], together with the
/// channel used to send the user's response back to the waiting tool call.
pub struct PendingQuestionnaire {
/// Questions to display. When created by `QuestionnaireTool::execute` these
/// come from `parse_questions`.
pub questions: Vec<Question>,
/// One-shot sender for the user's response. If it is dropped without sending,
/// the waiting tool call reports "Questionnaire dismissed".
pub responder: oneshot::Sender<QuestionnaireResponse>,
}
/// A single question to present to the user.
///
/// Deserialized from one element of the tool's `questions` array. The JSON
/// schema advertised to the model spells two keys in camelCase (`allowOther`,
/// `multiSelect`). Those spellings are accepted as deserialization aliases so
/// the model's choices are honoured instead of being silently ignored; the
/// snake_case names are still accepted, and serialization is unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Question {
    /// Identifier for this question; `parse_questions` rejects duplicates.
    pub id: String,
    /// Short tab-bar label. Empty when omitted from the JSON.
    #[serde(default)]
    pub label: String,
    /// Full question text shown to the user.
    pub prompt: String,
    /// Selectable options; may be empty for free-text questions.
    #[serde(default)]
    pub options: Vec<QuestionOption>,
    /// Whether a custom, typed answer is allowed (defaults to `true`).
    #[serde(default = "default_true", alias = "allowOther")]
    pub allow_other: bool,
    /// Whether several options may be selected (defaults to `false`).
    #[serde(default, alias = "multiSelect")]
    pub multi_select: bool,
}

/// Serde default for [`Question::allow_other`].
fn default_true() -> bool {
    true
}
/// One selectable choice within a [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
/// Value associated with this option when it is selected.
pub value: String,
/// Display text for the option.
pub label: String,
/// Optional extra text shown below the label (per the tool's JSON schema).
pub description: Option<String>,
}
/// The user's reply to a questionnaire, delivered through
/// [`PendingQuestionnaire::responder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionnaireResponse {
/// Answers, reported to the model in this order by `format_answers`.
pub answers: Vec<Answer>,
/// When `true`, the tool reports a cancellation and ignores `answers`.
pub cancelled: bool,
}
/// One answer to one [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Answer {
/// Id of the question being answered (presumably a `Question::id` — confirm
/// against the producer). Printed as the line prefix by `format_answers`.
pub id: String,
/// Machine-readable value of the answer. Not included in the text produced
/// by `format_answers`.
pub value: String,
/// Human-readable answer text; this is what the model sees.
pub label: String,
/// `true` when the user typed the answer rather than picking an option.
pub was_custom: bool,
/// Option number printed before the label when present.
/// NOTE(review): the producer decides whether this is 0- or 1-based; the
/// tests in this file use 1 for the first option.
pub index: Option<usize>,
}
/// Agent tool that asks the user one or more questions and waits for answers.
///
/// Requests are published through a shared [`QuestionnaireBridge`]. Cloning
/// the tool is cheap (an `Arc` bump) and yields a handle to the same bridge.
#[derive(Clone)]
pub struct QuestionnaireTool {
    bridge: Arc<QuestionnaireBridge>,
}

impl QuestionnaireTool {
    /// Creates a tool that publishes its questionnaires through `bridge`.
    pub fn new(bridge: Arc<QuestionnaireBridge>) -> Self {
        Self { bridge }
    }
}
#[async_trait]
impl AgentTool for QuestionnaireTool {
    fn name(&self) -> &str {
        "questionnaire"
    }

    fn label(&self) -> &str {
        "Questionnaire"
    }

    fn description(&self) -> &str {
        "Ask the user one or more questions. Use for clarifying requirements, \
         getting preferences, or confirming decisions. For single questions, \
         shows a simple option list. For multiple questions, shows a tab-based \
         interface."
    }

    /// JSON schema for the tool arguments.
    ///
    /// Keys here must stay deserializable into [`Question`] /
    /// [`QuestionOption`].
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "questions": {
                    "type": "array",
                    "description": "Questions to ask the user",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {
                                "type": "string",
                                "description": "Unique identifier for this question"
                            },
                            "label": {
                                "type": "string",
                                "description": "Short contextual label for tab bar (defaults to Q1, Q2)"
                            },
                            "prompt": {
                                "type": "string",
                                "description": "The full question text to display"
                            },
                            "options": {
                                "type": "array",
                                "description": "Available options to choose from. Can be empty for free-text questions.",
                                "default": [],
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "value": {
                                            "type": "string",
                                            "description": "The value returned when selected"
                                        },
                                        "label": {
                                            "type": "string",
                                            "description": "Display label for the option"
                                        },
                                        "description": {
                                            "type": "string",
                                            "description": "Optional description shown below label"
                                        }
                                    },
                                    "required": ["value", "label"]
                                }
                            },
                            "allowOther": {
                                "type": "boolean",
                                "description": "Allow 'Type something' option (default: true)",
                                "default": true
                            },
                            "multiSelect": {
                                "type": "boolean",
                                "description": "Allow multiple selections (default: false)",
                                "default": false
                            }
                        },
                        "required": ["id", "prompt"]
                    }
                }
            },
            "required": ["questions"]
        })
    }

    /// Validates the questions, publishes them on the bridge, and waits for the
    /// user's response or an abort.
    ///
    /// # Errors
    ///
    /// Returns the `parse_questions` error message when `params` is malformed.
    /// A questionnaire that is already pending is reported as an error *result*
    /// (`Ok(AgentToolResult::error(..))`), not as `Err`.
    async fn execute(
        &self,
        _tool_call_id: &str,
        params: serde_json::Value,
        signal: Option<oneshot::Receiver<()>>,
        _ctx: &ToolContext,
    ) -> Result<AgentToolResult, ToolError> {
        let questions = parse_questions(&params)?;
        let (tx, rx) = oneshot::channel();
        let pending = PendingQuestionnaire {
            questions,
            responder: tx,
        };
        // `set` refuses (and drops `pending`) while another request is still queued.
        if !self.bridge.set(pending) {
            return Ok(AgentToolResult::error(
                "Another questionnaire is already pending",
            ));
        }
        select_with_abort(rx, signal, &self.bridge).await
    }
}
async fn select_with_abort(
rx: oneshot::Receiver<QuestionnaireResponse>,
signal: Option<oneshot::Receiver<()>>,
bridge: &QuestionnaireBridge,
) -> Result<AgentToolResult, ToolError> {
let abort = async {
if let Some(sig) = signal {
let _ = sig.await;
} else {
std::future::pending::<()>().await;
}
};
tokio::select! {
response = rx => {
match response {
Ok(resp) => {
if resp.cancelled {
Ok(AgentToolResult::success("User cancelled the questionnaire"))
} else {
Ok(AgentToolResult::success(format_answers(&resp.answers)))
}
}
Err(_) => {
Ok(AgentToolResult::success("Questionnaire dismissed"))
}
}
}
() = abort => {
bridge.try_take();
Ok(AgentToolResult::success("Questionnaire cancelled by user interrupt"))
}
}
}
fn parse_questions(params: &serde_json::Value) -> Result<Vec<Question>, ToolError> {
let questions = params
.get("questions")
.and_then(|v| v.as_array())
.map(|arr| arr.clone())
.ok_or_else(|| "Missing or invalid 'questions' field".to_string())?;
let questions: Vec<Question> = questions
.into_iter()
.map(|v| serde_json::from_value(v).map_err(|e| e.to_string()))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| format!("Invalid question: {}", e))?;
if questions.is_empty() {
return Err("At least one question is required".to_string());
}
let questions: Vec<Question> = questions
.into_iter()
.enumerate()
.map(|(i, mut q)| {
if q.label.is_empty() {
q.label = format!("Q{}", i + 1);
}
q
})
.collect();
let mut ids = std::collections::HashSet::new();
for q in &questions {
if !ids.insert(&q.id) {
return Err(format!("Duplicate question id: {}", q.id));
}
}
Ok(questions)
}
/// Renders answers as plain text for the model, one line per answer, joined by `\n`.
///
/// Line shapes:
/// * typed answer: `"<id>: user wrote: <label>"`
/// * pick with an index: `"<id>: user selected: <index>. <label>"`
/// * pick without an index: `"<id>: user selected: <label>"`
///
/// An empty slice yields an empty string.
fn format_answers(answers: &[Answer]) -> String {
    let mut rendered = String::new();
    for (pos, answer) in answers.iter().enumerate() {
        if pos > 0 {
            rendered.push('\n');
        }
        let line = match (answer.was_custom, answer.index) {
            (true, _) => format!("{}: user wrote: {}", answer.id, answer.label),
            (false, Some(idx)) => {
                format!("{}: user selected: {}. {}", answer.id, idx, answer.label)
            }
            (false, None) => format!("{}: user selected: {}", answer.id, answer.label),
        };
        rendered.push_str(&line);
    }
    rendered
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_questions_valid() {
        let json = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "prompt": "Pick a language",
                    "options": [
                        { "value": "rust", "label": "Rust" },
                        { "value": "ts", "label": "TypeScript" }
                    ]
                }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        assert_eq!(questions.len(), 1);
        assert_eq!(questions[0].id, "lang");
        assert_eq!(questions[0].prompt, "Pick a language");
        assert_eq!(questions[0].label, "Q1");
        assert_eq!(questions[0].options.len(), 2);
        assert!(questions[0].allow_other);
        assert!(!questions[0].multi_select);
    }

    #[test]
    fn test_parse_questions_with_label() {
        let json = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "label": "Language",
                    "prompt": "Pick a language"
                }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        assert_eq!(questions[0].label, "Language");
    }

    #[test]
    fn test_parse_questions_default_labels_are_positional() {
        let json = serde_json::json!({
            "questions": [
                { "id": "a", "prompt": "first" },
                { "id": "b", "label": "Custom", "prompt": "second" },
                { "id": "c", "prompt": "third" }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        let labels: Vec<&str> = questions.iter().map(|q| q.label.as_str()).collect();
        assert_eq!(labels, ["Q1", "Custom", "Q3"]);
    }

    #[test]
    fn test_parse_questions_empty_options() {
        let json = serde_json::json!({
            "questions": [
                {
                    "id": "name",
                    "prompt": "What's your project name?",
                    "allowOther": true
                }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        assert_eq!(questions[0].options.len(), 0);
        assert!(questions[0].allow_other);
    }

    #[test]
    fn test_parse_questions_missing_questions() {
        let json = serde_json::json!({});
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("questions"));
    }

    #[test]
    fn test_parse_questions_not_an_array() {
        let json = serde_json::json!({ "questions": "oops" });
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("questions"));
    }

    #[test]
    fn test_parse_questions_invalid_element() {
        // `prompt` is required.
        let json = serde_json::json!({ "questions": [ { "id": "a" } ] });
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("Invalid question"));
    }

    #[test]
    fn test_parse_questions_empty_array() {
        let json = serde_json::json!({ "questions": [] });
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("one question"));
    }

    #[test]
    fn test_parse_questions_duplicate_ids() {
        let json = serde_json::json!({
            "questions": [
                { "id": "a", "prompt": "Q1" },
                { "id": "a", "prompt": "Q2" }
            ]
        });
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("Duplicate"));
    }

    #[test]
    fn test_format_answers_selected() {
        let answers = vec![Answer {
            id: "lang".into(),
            value: "rust".into(),
            label: "Rust".into(),
            was_custom: false,
            index: Some(1),
        }];
        let text = format_answers(&answers);
        assert_eq!(text, "lang: user selected: 1. Rust");
    }

    #[test]
    fn test_format_answers_custom() {
        let answers = vec![Answer {
            id: "name".into(),
            value: "myproj".into(),
            label: "myproj".into(),
            was_custom: true,
            index: None,
        }];
        let text = format_answers(&answers);
        assert_eq!(text, "name: user wrote: myproj");
    }

    #[test]
    fn test_format_answers_multi() {
        let answers = vec![
            Answer {
                id: "lang".into(),
                value: "rust".into(),
                label: "Rust".into(),
                was_custom: false,
                index: Some(1),
            },
            Answer {
                id: "db".into(),
                value: "pg".into(),
                label: "PostgreSQL".into(),
                was_custom: false,
                index: Some(2),
            },
            Answer {
                id: "auth".into(),
                value: "jwt".into(),
                label: "jwt".into(),
                was_custom: true,
                index: None,
            },
        ];
        let text = format_answers(&answers);
        assert_eq!(
            text,
            "lang: user selected: 1. Rust\ndb: user selected: 2. PostgreSQL\nauth: user wrote: jwt"
        );
    }

    #[test]
    fn test_format_answers_empty() {
        assert_eq!(format_answers(&[]), "");
    }

    #[test]
    fn test_bridge_set_take() {
        let bridge = QuestionnaireBridge::new();
        assert!(!bridge.has_pending());
        let (tx, _rx) = oneshot::channel();
        let pending = PendingQuestionnaire {
            questions: vec![],
            responder: tx,
        };
        assert!(bridge.set(pending));
        assert!(bridge.has_pending());
        let taken = bridge.try_take();
        assert!(taken.is_some());
        assert!(!bridge.has_pending());
        assert!(bridge.try_take().is_none());
    }

    #[test]
    fn test_bridge_rejects_second_pending() {
        let bridge = QuestionnaireBridge::new();
        let (tx1, _rx1) = oneshot::channel();
        let (tx2, mut rx2) = oneshot::channel();
        assert!(bridge.set(PendingQuestionnaire {
            questions: vec![],
            responder: tx1,
        }));
        assert!(!bridge.set(PendingQuestionnaire {
            questions: vec![],
            responder: tx2,
        }));
        // The rejected request is dropped, so its receiver sees a closed channel.
        assert!(matches!(
            rx2.try_recv(),
            Err(oneshot::error::TryRecvError::Closed)
        ));
        // The first request is still the one held by the bridge.
        let kept = bridge
            .try_take()
            .expect("first questionnaire should remain pending");
        assert!(!kept.responder.is_closed());
    }

    #[test]
    fn test_bridge_clones_share_slot() {
        let bridge = QuestionnaireBridge::new();
        let other = bridge.clone();
        let (tx, _rx) = oneshot::channel();
        assert!(bridge.set(PendingQuestionnaire {
            questions: vec![],
            responder: tx,
        }));
        assert!(other.has_pending());
        assert!(other.try_take().is_some());
        assert!(!bridge.has_pending());
    }
}