1use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use tokio::sync::oneshot;
18
19use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
20
/// Shared slot that holds at most one [`PendingQuestionnaire`].
///
/// The slot lives behind an `Arc`, so clones of the bridge observe and mutate
/// the same state; access is serialized by a `parking_lot::Mutex`.
#[derive(Clone)]
pub struct QuestionnaireBridge {
    /// `None` while nothing is pending, `Some` while a questionnaire is waiting to be taken.
    inner: Arc<parking_lot::Mutex<Option<PendingQuestionnaire>>>,
}
28
29impl QuestionnaireBridge {
30 pub fn new() -> Self {
32 Self {
33 inner: Arc::new(parking_lot::Mutex::new(None)),
34 }
35 }
36
37 pub fn set(&self, pending: PendingQuestionnaire) -> bool {
41 let mut lock = self.inner.lock();
42 if lock.is_some() {
43 return false;
44 }
45 *lock = Some(pending);
46 true
47 }
48
49 pub fn try_take(&self) -> Option<PendingQuestionnaire> {
52 self.inner.lock().take()
53 }
54
55 pub fn has_pending(&self) -> bool {
57 self.inner.lock().is_some()
58 }
59}
60
61impl Default for QuestionnaireBridge {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
/// A questionnaire parked in the [`QuestionnaireBridge`], bundled with the
/// channel used to deliver its outcome.
pub struct PendingQuestionnaire {
    /// Questions to present.
    pub questions: Vec<Question>,
    /// One-shot sender for the [`QuestionnaireResponse`]; the tool awaits the
    /// matching receiver after placing this value in the bridge.
    pub responder: oneshot::Sender<QuestionnaireResponse>,
}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct Question {
81 pub id: String,
83 #[serde(default)]
85 pub label: String,
86 pub prompt: String,
88 #[serde(default)]
90 pub options: Vec<QuestionOption>,
91 #[serde(default = "default_true")]
93 pub allow_other: bool,
94 #[serde(default)]
96 pub multi_select: bool,
97}
98
/// Serde default provider that yields `true`; referenced by name from
/// `#[serde(default = "default_true")]` on `Question::allow_other`.
fn default_true() -> bool {
    true
}
102
/// One selectable option of a [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
    /// Value returned when this option is selected (per the parameter schema).
    pub value: String,
    /// Display label for the option.
    pub label: String,
    /// Optional description shown below the label (per the parameter schema).
    pub description: Option<String>,
}
113
/// Outcome sent back through `PendingQuestionnaire::responder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionnaireResponse {
    /// Collected answers; rendered with `format_answers` when `cancelled` is `false`.
    pub answers: Vec<Answer>,
    /// When `true`, the tool reports a user cancellation and ignores `answers`.
    pub cancelled: bool,
}
122
/// A single answer inside a [`QuestionnaireResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Answer {
    /// Identifier printed as the line prefix by `format_answers`
    /// (presumably the id of the question being answered; confirm with the producer).
    pub id: String,
    /// Answer value; not used by `format_answers`.
    pub value: String,
    /// Text that `format_answers` prints for this answer.
    pub label: String,
    /// `true` renders the answer as "user wrote" instead of "user selected".
    pub was_custom: bool,
    /// When present on a non-custom answer, printed as an "N." prefix before the label.
    /// NOTE(review): presumably the position of the chosen option; the tests use 1 for
    /// the first option, suggesting 1-based numbering — confirm with the producer.
    pub index: Option<usize>,
}
137
/// Agent tool that asks the user questions by parking a
/// [`PendingQuestionnaire`] in a shared [`QuestionnaireBridge`].
pub struct QuestionnaireTool {
    /// Bridge through which pending questionnaires are handed over.
    bridge: Arc<QuestionnaireBridge>,
}
144
145impl QuestionnaireTool {
146 pub fn new(bridge: Arc<QuestionnaireBridge>) -> Self {
148 Self { bridge }
149 }
150}
151
152impl Clone for QuestionnaireTool {
155 fn clone(&self) -> Self {
156 Self {
157 bridge: self.bridge.clone(),
158 }
159 }
160}
161
162#[async_trait]
163impl AgentTool for QuestionnaireTool {
164 fn name(&self) -> &str {
165 "questionnaire"
166 }
167
168 fn label(&self) -> &str {
169 "Questionnaire"
170 }
171
172 fn description(&self) -> &str {
173 "Ask the user one or more questions. Use for clarifying requirements, \
174 getting preferences, or confirming decisions. For single questions, \
175 shows a simple option list. For multiple questions, shows a tab-based \
176 interface."
177 }
178
179 fn parameters_schema(&self) -> serde_json::Value {
180 serde_json::json!({
181 "type": "object",
182 "properties": {
183 "questions": {
184 "type": "array",
185 "description": "Questions to ask the user",
186 "items": {
187 "type": "object",
188 "properties": {
189 "id": {
190 "type": "string",
191 "description": "Unique identifier for this question"
192 },
193 "label": {
194 "type": "string",
195 "description": "Short contextual label for tab bar (defaults to Q1, Q2)"
196 },
197 "prompt": {
198 "type": "string",
199 "description": "The full question text to display"
200 },
201 "options": {
202 "type": "array",
203 "description": "Available options to choose from. Can be empty for free-text questions.",
204 "default": [],
205 "items": {
206 "type": "object",
207 "properties": {
208 "value": {
209 "type": "string",
210 "description": "The value returned when selected"
211 },
212 "label": {
213 "type": "string",
214 "description": "Display label for the option"
215 },
216 "description": {
217 "type": "string",
218 "description": "Optional description shown below label"
219 }
220 },
221 "required": ["value", "label"]
222 }
223 },
224 "allowOther": {
225 "type": "boolean",
226 "description": "Allow 'Type something' option (default: true)",
227 "default": true
228 },
229 "multiSelect": {
230 "type": "boolean",
231 "description": "Allow multiple selections (default: false)",
232 "default": false
233 }
234 },
235 "required": ["id", "prompt"]
236 }
237 }
238 },
239 "required": ["questions"]
240 })
241 }
242
243 async fn execute(
244 &self,
245 _tool_call_id: &str,
246 params: serde_json::Value,
247 signal: Option<oneshot::Receiver<()>>,
248 _ctx: &ToolContext,
249 ) -> Result<AgentToolResult, ToolError> {
250 let questions = parse_questions(¶ms)?;
252
253 let (tx, rx) = oneshot::channel();
255
256 if !self.bridge.set(PendingQuestionnaire {
258 questions,
259 responder: tx,
260 }) {
261 return Ok(AgentToolResult::error(
262 "Another questionnaire is already pending",
263 ));
264 }
265
266 let result = select_with_abort(rx, signal, &self.bridge).await;
268
269 result
271 }
272}
273
274async fn select_with_abort(
276 rx: oneshot::Receiver<QuestionnaireResponse>,
277 signal: Option<oneshot::Receiver<()>>,
278 bridge: &QuestionnaireBridge,
279) -> Result<AgentToolResult, ToolError> {
280 let abort = async {
282 if let Some(sig) = signal {
283 let _ = sig.await;
284 } else {
285 std::future::pending::<()>().await;
286 }
287 };
288
289 tokio::select! {
290 response = rx => {
291 match response {
292 Ok(resp) => {
293 if resp.cancelled {
294 Ok(AgentToolResult::success("User cancelled the questionnaire"))
295 } else {
296 Ok(AgentToolResult::success(format_answers(&resp.answers)))
297 }
298 }
299 Err(_) => {
300 Ok(AgentToolResult::success("Questionnaire dismissed"))
302 }
303 }
304 }
305 () = abort => {
306 bridge.try_take();
308 Ok(AgentToolResult::success("Questionnaire cancelled by user interrupt"))
309 }
310 }
311}
312
313fn parse_questions(params: &serde_json::Value) -> Result<Vec<Question>, ToolError> {
315 let questions = params
316 .get("questions")
317 .and_then(|v| v.as_array())
318 .cloned()
319 .ok_or_else(|| "Missing or invalid 'questions' field".to_string())?;
320
321 let questions: Vec<Question> = questions
322 .into_iter()
323 .map(|v| serde_json::from_value(v).map_err(|e| e.to_string()))
324 .collect::<Result<Vec<_>, _>>()
325 .map_err(|e| format!("Invalid question: {}", e))?;
326
327 if questions.is_empty() {
328 return Err("At least one question is required".to_string());
329 }
330
331 let questions: Vec<Question> = questions
333 .into_iter()
334 .enumerate()
335 .map(|(i, mut q)| {
336 if q.label.is_empty() {
337 q.label = format!("Q{}", i + 1);
338 }
339 q
340 })
341 .collect();
342
343 let mut ids = std::collections::HashSet::new();
345 for q in &questions {
346 if !ids.insert(&q.id) {
347 return Err(format!("Duplicate question id: {}", q.id));
348 }
349 }
350
351 Ok(questions)
352}
353
354fn format_answers(answers: &[Answer]) -> String {
356 answers
357 .iter()
358 .map(|a| {
359 if a.was_custom {
360 format!("{}: user wrote: {}", a.id, a.label)
361 } else if let Some(idx) = a.index {
362 format!("{}: user selected: {}. {}", a.id, idx, a.label)
363 } else {
364 format!("{}: user selected: {}", a.id, a.label)
365 }
366 })
367 .collect::<Vec<_>>()
368 .join("\n")
369}
370
/// Unit tests for question parsing, answer formatting and the bridge slot.
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal valid question parses and picks up every default.
    #[test]
    fn test_parse_questions_valid() {
        let json = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "prompt": "Pick a language",
                    "options": [
                        { "value": "rust", "label": "Rust" },
                        { "value": "ts", "label": "TypeScript" }
                    ]
                }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        assert_eq!(questions.len(), 1);
        assert_eq!(questions[0].id, "lang");
        // No label supplied, so the positional default is used.
        assert_eq!(questions[0].label, "Q1");
        assert_eq!(questions[0].options.len(), 2);
        // Defaults: free-text allowed, single selection.
        assert!(questions[0].allow_other);
        assert!(!questions[0].multi_select);
    }

    /// An explicit label is kept as-is.
    #[test]
    fn test_parse_questions_with_label() {
        let json = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "label": "Language",
                    "prompt": "Pick a language"
                }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        assert_eq!(questions[0].label, "Language");
    }

    /// Omitting `options` yields an empty option list.
    #[test]
    fn test_parse_questions_empty_options() {
        let json = serde_json::json!({
            "questions": [
                {
                    "id": "name",
                    "prompt": "What's your project name?",
                    "allowOther": true
                }
            ]
        });
        let questions = parse_questions(&json).unwrap();
        assert_eq!(questions[0].options.len(), 0);
        assert!(questions[0].allow_other);
    }

    /// A missing `questions` key is reported by name.
    #[test]
    fn test_parse_questions_missing_questions() {
        let json = serde_json::json!({});
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("questions"));
    }

    /// An empty `questions` array is rejected.
    #[test]
    fn test_parse_questions_empty_array() {
        let json = serde_json::json!({ "questions": [] });
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("one question"));
    }

    /// Two questions sharing an id are rejected.
    #[test]
    fn test_parse_questions_duplicate_ids() {
        let json = serde_json::json!({
            "questions": [
                { "id": "a", "prompt": "Q1" },
                { "id": "a", "prompt": "Q2" }
            ]
        });
        let err = parse_questions(&json).unwrap_err();
        assert!(err.contains("Duplicate"));
    }

    /// A selected option is rendered with its index prefix.
    #[test]
    fn test_format_answers_selected() {
        let answers = vec![Answer {
            id: "lang".into(),
            value: "rust".into(),
            label: "Rust".into(),
            was_custom: false,
            index: Some(1),
        }];
        let text = format_answers(&answers);
        assert_eq!(text, "lang: user selected: 1. Rust");
    }

    /// A custom answer is rendered as "user wrote".
    #[test]
    fn test_format_answers_custom() {
        let answers = vec![Answer {
            id: "name".into(),
            value: "myproj".into(),
            label: "myproj".into(),
            was_custom: true,
            index: None,
        }];
        let text = format_answers(&answers);
        assert_eq!(text, "name: user wrote: myproj");
    }

    /// Several answers are joined with newlines, in order.
    #[test]
    fn test_format_answers_multi() {
        let answers = vec![
            Answer {
                id: "lang".into(),
                value: "rust".into(),
                label: "Rust".into(),
                was_custom: false,
                index: Some(1),
            },
            Answer {
                id: "db".into(),
                value: "pg".into(),
                label: "PostgreSQL".into(),
                was_custom: false,
                index: Some(2),
            },
            Answer {
                id: "auth".into(),
                value: "jwt".into(),
                label: "jwt".into(),
                was_custom: true,
                index: None,
            },
        ];
        let text = format_answers(&answers);
        assert_eq!(
            text,
            "lang: user selected: 1. Rust\ndb: user selected: 2. PostgreSQL\nauth: user wrote: jwt"
        );
    }

    /// `set` fills the slot, `try_take` empties it, and a second take finds nothing.
    #[test]
    fn test_bridge_set_take() {
        let bridge = QuestionnaireBridge::new();
        assert!(!bridge.has_pending());

        let (tx, _rx) = oneshot::channel();
        let pending = PendingQuestionnaire {
            questions: vec![],
            responder: tx,
        };
        assert!(bridge.set(pending));
        assert!(bridge.has_pending());

        let taken = bridge.try_take();
        assert!(taken.is_some());
        assert!(!bridge.has_pending());

        assert!(bridge.try_take().is_none());
    }

    /// A second `set` is rejected while the first questionnaire is still pending.
    #[test]
    fn test_bridge_set_idempotent() {
        let bridge = QuestionnaireBridge::new();
        let (tx1, _rx1) = oneshot::channel();
        let (tx2, _rx2) = oneshot::channel();

        // The first call must succeed; otherwise the rejection below proves nothing.
        assert!(bridge.set(PendingQuestionnaire {
            questions: vec![],
            responder: tx1,
        }));
        assert!(!bridge.set(PendingQuestionnaire {
            questions: vec![],
            responder: tx2
        }));
        // The original questionnaire is still the one held by the bridge.
        assert!(bridge.has_pending());
    }
}