1use serde::{Deserialize, Serialize};
22use std::sync::Arc;
23use std::sync::atomic::{AtomicBool, Ordering};
24use std::time::Duration;
25use tokio::sync::oneshot;
26
27use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
28use async_trait::async_trait;
29
/// Shared hand-off point between the `ask` tool and the interactive UI.
///
/// Cloning is cheap: clones share the same pending-ask slot, attached flag and session id
/// (all behind `Arc`); only `timeout` is copied by value.
#[derive(Clone)]
pub struct AskBridge {
    /// The single in-flight question batch, if any (filled by `set`, drained by `try_take`).
    inner: Arc<parking_lot::Mutex<Option<PendingAsk>>>,
    /// Raised by `attach_with_session`; `AskTool::execute` refuses to run while it is `false`.
    ui_attached: Arc<AtomicBool>,
    /// Session id recorded by `attach_with_session`; copied into every `PendingAsk`.
    session_id: Arc<parking_lot::Mutex<Option<String>>>,
    /// Timeout copied into every `PendingAsk`; `None` attaches no timeout.
    timeout: Option<Duration>,
}
50
51impl AskBridge {
52 pub fn new() -> Self {
54 Self {
55 inner: Arc::new(parking_lot::Mutex::new(None)),
56 ui_attached: Arc::new(AtomicBool::new(false)),
57 session_id: Arc::new(parking_lot::Mutex::new(None)),
58 timeout: None,
59 }
60 }
61
62 pub fn with_timeout(timeout: Option<Duration>) -> Self {
64 Self {
65 timeout,
66 ..Self::new()
67 }
68 }
69
70 pub fn attach_with_session(&self, session_id: impl Into<String>) {
77 let id = session_id.into();
78 debug_assert!(
79 !id.is_empty(),
80 "AskBridge::attach_with_session called with empty session_id"
81 );
82 *self.session_id.lock() = Some(id);
83 self.ui_attached.store(true, Ordering::SeqCst);
84 }
85
86 pub fn is_ui_attached(&self) -> bool {
88 self.ui_attached.load(Ordering::SeqCst)
89 }
90
91 #[cfg(any(test, debug_assertions))]
95 pub fn attach(&self) {
96 self.ui_attached.store(true, Ordering::SeqCst);
97 }
98
99 pub fn session_id(&self) -> Option<String> {
101 self.session_id.lock().clone()
102 }
103 pub fn timeout(&self) -> Option<Duration> {
105 self.timeout
106 }
107
108 pub fn set(&self, pending: PendingAsk) -> bool {
112 let mut lock = self.inner.lock();
113 if lock.is_some() {
114 return false;
115 }
116 *lock = Some(pending);
117 true
118 }
119
120 pub fn try_take(&self) -> Option<PendingAsk> {
123 self.inner.lock().take()
124 }
125
126 pub fn has_pending(&self) -> bool {
128 self.inner.lock().is_some()
129 }
130}
131
132impl Default for AskBridge {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
/// A question batch published through the bridge and waiting for the UI's reply.
pub struct PendingAsk {
    /// Questions to display (already validated by `parse_questions` when created by the tool).
    pub questions: Vec<Question>,
    /// Delivers the user's reply to the `AskTool::execute` call awaiting it.
    pub responder: oneshot::Sender<AskResponse>,
    /// Timeout configured on the bridge. NOTE(review): presumably the UI auto-selects the
    /// recommended option when it expires (cf. `AskResponse::timed_out`) — confirm in the UI.
    pub timeout: Option<Duration>,
    /// Session id the bridge was attached with when this ask was created.
    pub session_id: Option<String>,
}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct Question {
158 pub id: String,
160 #[serde(default)]
163 pub label: String,
164 pub prompt: String,
166 #[serde(default)]
168 pub options: Vec<QuestionOption>,
169 #[serde(default = "default_true")]
173 pub allow_other: bool,
174 #[serde(default)]
176 pub multi_select: bool,
177 #[serde(default)]
181 pub recommended: Option<usize>,
182}
183
/// Serde default for [`Question::allow_other`]: the "Other" entry is on unless disabled.
const fn default_true() -> bool {
    true
}
187
/// One selectable answer offered by a [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
    /// Value reported back when this option is selected.
    pub value: String,
    /// Short display label.
    pub label: String,
    /// Optional explanatory tradeoff shown below the label; `None` when omitted.
    pub description: Option<String>,
}
198
/// The UI's reply to a [`PendingAsk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AskResponse {
    /// Answers chosen by the user; rendered for the model by `format_answers`.
    pub answers: Vec<Answer>,
    /// `true` when the user dismissed the questions; the tool then ignores `answers`.
    pub cancelled: bool,
    /// `true` when the answers were auto-selected after the timeout; defaults to `false`.
    #[serde(default)]
    pub timed_out: bool,
}
210
/// The user's answer to a single [`Question`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Answer {
    /// `id` of the question being answered.
    pub id: String,
    /// Selected value; `format_answers` treats a value containing ',' as a multi-selection.
    pub value: String,
    /// Display text of the selection; this is what `format_answers` prints.
    pub label: String,
    /// `true` for a typed ("Other") answer; `format_answers` quotes its label.
    pub was_custom: bool,
    /// Index of the selected option. NOTE(review): `None` in the custom and multi-select
    /// fixtures; confirm the indexing base with the UI.
    pub index: Option<usize>,
}
225
/// The `ask` agent tool: forwards clarifying questions to the UI through an [`AskBridge`].
pub struct AskTool {
    /// Bridge shared with the UI side.
    bridge: Arc<AskBridge>,
}
232
233impl AskTool {
234 pub fn new(bridge: Arc<AskBridge>) -> Self {
236 Self { bridge }
237 }
238}
239
240impl Clone for AskTool {
243 fn clone(&self) -> Self {
244 Self {
245 bridge: self.bridge.clone(),
246 }
247 }
248}
249
250#[async_trait]
251impl AgentTool for AskTool {
252 fn name(&self) -> &str {
253 "ask"
254 }
255
256 fn label(&self) -> &str {
257 "Ask"
258 }
259
260 fn description(&self) -> &str {
261 "Ask the user a clarifying question when choices have materially \
262 different tradeoffs the user must decide. Default to action — pick \
263 the conservative/standard option and proceed when a reasonable \
264 default exists; only ask when the user must weigh the tradeoff. Do \
265 NOT include an 'Other' option — the UI appends 'Other (type your \
266 own)' automatically. Use 'recommended' (0-indexed) to mark the \
267 default; a '(Recommended)' suffix is added automatically. Set \
268 'multiSelect' true to allow multiple selections. Provide 2-5 \
269 concise options with short labels; put explanatory tradeoffs in \
270 'description'. Batch related questions in one call via 'questions'."
271 }
272
273 fn parameters_schema(&self) -> serde_json::Value {
274 serde_json::json!({
275 "type": "object",
276 "properties": {
277 "questions": {
278 "type": "array",
279 "description": "Questions to ask the user",
280 "items": {
281 "type": "object",
282 "properties": {
283 "id": {
284 "type": "string",
285 "description": "Unique identifier for this question"
286 },
287 "label": {
288 "type": "string",
289 "description": "Short contextual label (defaults to the id)"
290 },
291 "prompt": {
292 "type": "string",
293 "description": "The full question text to display"
294 },
295 "options": {
296 "type": "array",
297 "description": "Available options (2-5). Do NOT include 'Other' — the UI adds it automatically.",
298 "default": [],
299 "items": {
300 "type": "object",
301 "properties": {
302 "value": {
303 "type": "string",
304 "description": "The value returned when selected"
305 },
306 "label": {
307 "type": "string",
308 "description": "Short display label for the option"
309 },
310 "description": {
311 "type": "string",
312 "description": "Optional explanatory tradeoff shown below the label"
313 }
314 },
315 "required": ["value", "label"]
316 }
317 },
318 "allowOther": {
319 "type": "boolean",
320 "description": "Show 'Other (type your own)' (default: true)",
321 "default": true
322 },
323 "multiSelect": {
324 "type": "boolean",
325 "description": "Allow multiple selections (default: false)",
326 "default": false
327 },
328 "recommended": {
329 "type": "number",
330 "description": "Recommended option index (0-based). Marks the default and is used for timeout auto-selection.",
331 "minimum": 0
332 }
333 },
334 "required": ["id", "prompt"]
335 }
336 },
337 },
338 "required": ["questions"]
339 })
340 }
341
342 async fn execute(
343 &self,
344 _tool_call_id: &str,
345 params: serde_json::Value,
346 signal: Option<oneshot::Receiver<()>>,
347 _ctx: &ToolContext,
348 ) -> Result<AgentToolResult, ToolError> {
349 if !self.bridge.is_ui_attached() {
351 return Ok(AgentToolResult::error(
352 "Ask requires interactive TUI mode. \
353 Not available in --print or RPC mode.",
354 ));
355 }
356
357 let session_id = self.bridge.session_id();
364 debug_assert!(
365 session_id.as_deref().is_some_and(|s| !s.is_empty()),
366 "AskBridge was attached without a non-empty session_id; refusing to run"
367 );
368
369 let questions = parse_questions(¶ms)?;
371 let timeout = self.bridge.timeout();
372
373 let (tx, rx) = oneshot::channel();
375
376 if !self.bridge.set(PendingAsk {
378 questions,
379 responder: tx,
380 timeout,
381 session_id,
382 }) {
383 return Ok(AgentToolResult::error("Another ask is already pending"));
384 }
385
386 select_with_abort(rx, signal, &self.bridge).await
388 }
389}
390
391async fn select_with_abort(
393 rx: oneshot::Receiver<AskResponse>,
394 signal: Option<oneshot::Receiver<()>>,
395 bridge: &AskBridge,
396) -> Result<AgentToolResult, ToolError> {
397 let abort = async {
399 if let Some(sig) = signal {
400 let _ = sig.await;
401 } else {
402 std::future::pending::<()>().await;
403 }
404 };
405
406 tokio::select! {
407 response = rx => {
408 match response {
409 Ok(resp) => {
410 if resp.cancelled {
411 Ok(AgentToolResult::success("User cancelled the question"))
412 } else {
413 Ok(AgentToolResult::success(format_answers(
414 &resp.answers,
415 resp.timed_out,
416 )))
417 }
418 }
419 Err(_) => {
420 Ok(AgentToolResult::success("Question dismissed"))
422 }
423 }
424 }
425 () = abort => {
426 bridge.try_take();
428 Ok(AgentToolResult::success("Question cancelled by user interrupt"))
429 }
430 }
431}
432
433fn parse_questions(params: &serde_json::Value) -> Result<Vec<Question>, ToolError> {
435 let questions = params
436 .get("questions")
437 .and_then(|v| v.as_array())
438 .cloned()
439 .ok_or_else(|| "Missing or invalid 'questions' field".to_string())?;
440
441 let questions: Vec<Question> = questions
442 .into_iter()
443 .map(|v| serde_json::from_value(v).map_err(|e| e.to_string()))
444 .collect::<Result<Vec<_>, _>>()
445 .map_err(|e| format!("Invalid question: {}", e))?;
446
447 if questions.is_empty() {
448 return Err("At least one question is required".to_string());
449 }
450
451 let questions: Vec<Question> = questions
453 .into_iter()
454 .map(|mut q| {
455 if q.label.is_empty() {
456 q.label = q.id.clone();
457 }
458 q
459 })
460 .collect();
461
462 let mut ids = std::collections::HashSet::new();
464 for q in &questions {
465 if !ids.insert(&q.id) {
466 return Err(format!("Duplicate question id: {}", q.id));
467 }
468 }
469
470 Ok(questions)
471}
472
473pub fn format_answers(answers: &[Answer], timed_out: bool) -> String {
484 let suffix = if timed_out {
485 " (auto-selected after timeout)"
486 } else {
487 ""
488 };
489 answers
490 .iter()
491 .map(|a| {
492 let base = if a.was_custom {
493 format!("{}: \"{}\"", a.id, a.label)
494 } else if a.value.contains(',') {
495 let labels: Vec<&str> = a.label.split(", ").collect();
497 format!("{}: [{}]", a.id, labels.join(", "))
498 } else {
499 format!("{}: {}", a.id, a.label)
500 };
501 format!("{base}{suffix}")
502 })
503 .collect::<Vec<_>>()
504 .join("\n")
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Answer` from borrowed parts to keep the formatting tests compact.
    fn answer(id: &str, value: &str, label: &str, was_custom: bool, index: Option<usize>) -> Answer {
        Answer {
            id: id.to_string(),
            value: value.to_string(),
            label: label.to_string(),
            was_custom,
            index,
        }
    }

    /// Builds a question-less `PendingAsk` around `responder`.
    fn empty_pending(responder: oneshot::Sender<AskResponse>) -> PendingAsk {
        PendingAsk {
            questions: Vec::new(),
            responder,
            timeout: None,
            session_id: None,
        }
    }

    #[test]
    fn test_parse_questions_valid() {
        let params = serde_json::json!({
            "questions": [{
                "id": "lang",
                "prompt": "Pick a language",
                "options": [
                    { "value": "rust", "label": "Rust" },
                    { "value": "ts", "label": "TypeScript" }
                ]
            }]
        });
        let parsed = parse_questions(&params).unwrap();
        assert_eq!(parsed.len(), 1);
        let q = &parsed[0];
        assert_eq!(q.id, "lang");
        // An omitted label falls back to the id.
        assert_eq!(q.label, "lang");
        assert_eq!(q.options.len(), 2);
        assert!(q.allow_other);
        assert!(!q.multi_select);
    }

    #[test]
    fn test_parse_questions_with_label() {
        let params = serde_json::json!({
            "questions": [{ "id": "lang", "label": "Language", "prompt": "Pick a language" }]
        });
        let parsed = parse_questions(&params).unwrap();
        assert_eq!(parsed[0].label, "Language");
    }

    #[test]
    fn test_parse_questions_empty_options() {
        // Free-text question: no options, only the "Other" input.
        let params = serde_json::json!({
            "questions": [{
                "id": "name",
                "prompt": "What's your project name?",
                "allowOther": true
            }]
        });
        let parsed = parse_questions(&params).unwrap();
        assert!(parsed[0].options.is_empty());
        assert!(parsed[0].allow_other);
    }

    #[test]
    fn test_parse_questions_missing_questions() {
        let err = parse_questions(&serde_json::json!({})).unwrap_err();
        assert!(err.contains("questions"));
    }

    #[test]
    fn test_parse_questions_empty_array() {
        let err = parse_questions(&serde_json::json!({ "questions": [] })).unwrap_err();
        assert!(err.contains("one question"));
    }

    #[test]
    fn test_parse_questions_duplicate_ids() {
        let params = serde_json::json!({
            "questions": [
                { "id": "a", "prompt": "Q1" },
                { "id": "a", "prompt": "Q2" }
            ]
        });
        let err = parse_questions(&params).unwrap_err();
        assert!(err.contains("Duplicate"));
    }

    #[test]
    fn test_format_answers_single() {
        let answers = [answer("lang", "rust", "Rust", false, Some(1))];
        assert_eq!(format_answers(&answers, false), "lang: Rust");
    }

    #[test]
    fn test_format_answers_custom() {
        let answers = [answer("name", "myproj", "myproj", true, None)];
        assert_eq!(format_answers(&answers, false), "name: \"myproj\"");
    }

    #[test]
    fn test_format_answers_multi() {
        let answers = [answer("lang", "rust, go", "Rust, Go", false, None)];
        assert_eq!(format_answers(&answers, false), "lang: [Rust, Go]");
    }

    #[test]
    fn test_format_answers_timed_out() {
        let answers = [answer("auth", "oauth", "OAuth2", false, Some(2))];
        assert_eq!(
            format_answers(&answers, true),
            "auth: OAuth2 (auto-selected after timeout)"
        );
    }

    #[test]
    fn test_bridge_set_take() {
        let bridge = AskBridge::new();
        assert!(!bridge.has_pending());

        let (tx, _rx) = oneshot::channel();
        assert!(bridge.set(empty_pending(tx)));
        assert!(bridge.has_pending());

        assert!(bridge.try_take().is_some());
        assert!(!bridge.has_pending());

        // The slot is now empty, so a second take yields nothing.
        assert!(bridge.try_take().is_none());
    }

    #[test]
    fn test_bridge_set_idempotent() {
        let bridge = AskBridge::new();
        // Both receivers stay alive for the whole test so neither ask is abandoned.
        let (first_tx, _first_rx) = oneshot::channel();
        let (second_tx, _second_rx) = oneshot::channel();

        bridge.set(empty_pending(first_tx));
        assert!(!bridge.set(empty_pending(second_tx)));
    }

    #[test]
    fn test_ui_attached_flag() {
        let bridge = AskBridge::new();
        assert!(!bridge.is_ui_attached());
        bridge.attach();
        assert!(bridge.is_ui_attached());
    }

    #[test]
    fn test_bridge_with_timeout() {
        let thirty = Duration::from_secs(30);
        let bridge = AskBridge::with_timeout(Some(thirty));
        assert_eq!(bridge.timeout(), Some(thirty));
        assert!(!bridge.is_ui_attached());

        assert_eq!(AskBridge::new().timeout(), None);
    }

    #[test]
    fn test_question_deserializes_without_recommended() {
        let json = serde_json::json!({
            "id": "test",
            "prompt": "Test question?",
            "options": [{"value": "a", "label": "A"}]
        });
        let q: Question = serde_json::from_value(json).unwrap();
        assert_eq!(q.recommended, None);
    }

    #[test]
    fn test_question_deserializes_with_recommended() {
        let json = serde_json::json!({
            "id": "test",
            "prompt": "Test question?",
            "options": [{"value": "a", "label": "A"}, {"value": "b", "label": "B"}],
            "recommended": 1
        });
        let q: Question = serde_json::from_value(json).unwrap();
        assert_eq!(q.recommended, Some(1));
    }

    #[test]
    fn test_tool_name_is_ask() {
        let tool = AskTool::new(Arc::new(AskBridge::new()));
        assert_eq!(tool.name(), "ask");
        assert_eq!(tool.label(), "Ask");
    }

    #[test]
    fn test_attach_with_session_stores_id() {
        let bridge = AskBridge::new();
        assert!(!bridge.is_ui_attached());
        assert_eq!(bridge.session_id(), None);

        bridge.attach_with_session("tui");
        assert!(bridge.is_ui_attached());
        assert_eq!(bridge.session_id().as_deref(), Some("tui"));
    }

    #[test]
    fn test_format_answers_multi_with_comma_label() {
        // The comma in the value marks a multi-selection; the label is printed as-is.
        let answers = [answer("tags", "a,b", "A, B", false, None)];
        assert_eq!(format_answers(&answers, false), "tags: [A, B]");
    }

    #[test]
    fn test_format_answers_cancelled_marker() {
        // An empty answer still renders its id prefix.
        let answers = [answer("q1", "", "", false, None)];
        assert_eq!(format_answers(&answers, false), "q1: ");
    }
}