Skip to main content

oxi_agent/tools/
questionnaire.rs

1//! Questionnaire tool — ask the user one or more questions via TUI overlay.
2//!
3//! Architecture:
4//! - `QuestionnaireBridge` is created in `oxi-cli` and shared (via `Arc`) between
5//!   `QuestionnaireTool` (agent thread) and `AppState` (TUI main thread).
6//! - When the tool executes, it creates a oneshot channel and stores (questions, sender)
7//!   in the bridge.
8//! - The TUI main loop polls the bridge, and if a pending questionnaire is found,
9//!   creates a `QuestionnaireOverlay` to display it.
10//! - User interaction drives the overlay to send a `QuestionnaireResponse` via the
11//!   oneshot `Sender`. The tool's `execute()` receives it via `rx.await`.
12//! - Abort (Ctrl+C) is handled via `tokio::select!` with the abort signal.
13
14use async_trait::async_trait;
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use tokio::sync::oneshot;
18
19use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
20
/// Shared bridge between the questionnaire tool (agent thread) and the TUI
/// overlay (main thread). Created in `oxi-cli`, injected into both the tool
/// and `AppState`.
///
/// The bridge is a single mutex-guarded slot holding at most one
/// [`PendingQuestionnaire`]. Cloning the bridge clones the inner `Arc`, so
/// every clone observes and mutates the same slot.
#[derive(Clone)]
pub struct QuestionnaireBridge {
    /// The shared slot: `None` when idle, `Some` while a questionnaire is
    /// waiting to be picked up.
    inner: Arc<parking_lot::Mutex<Option<PendingQuestionnaire>>>,
}
28
29impl QuestionnaireBridge {
30    /// Create a new empty bridge.
31    pub fn new() -> Self {
32        Self {
33            inner: Arc::new(parking_lot::Mutex::new(None)),
34        }
35    }
36
37    /// Store a pending questionnaire. Called by `QuestionnaireTool::execute`.
38    /// Returns `false` if another questionnaire is already pending (should not
39    /// happen in sequential tool execution, but guards against races).
40    pub fn set(&self, pending: PendingQuestionnaire) -> bool {
41        let mut lock = self.inner.lock();
42        if lock.is_some() {
43            return false;
44        }
45        *lock = Some(pending);
46        true
47    }
48
49    /// Try to take the pending questionnaire. Called by the TUI main loop polling.
50    /// Returns `None` if nothing is pending or already taken.
51    pub fn try_take(&self) -> Option<PendingQuestionnaire> {
52        self.inner.lock().take()
53    }
54
55    /// Returns `true` if a questionnaire is currently pending.
56    pub fn has_pending(&self) -> bool {
57        self.inner.lock().is_some()
58    }
59}
60
61impl Default for QuestionnaireBridge {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
/// A pending questionnaire waiting for user interaction.
///
/// The `responder` is a oneshot `Sender` — the overlay calls `send()` when
/// the user submits or cancels, and the tool's `rx.await` receives it.
/// Because the sender is single-use, a `PendingQuestionnaire` can be answered
/// at most once.
pub struct PendingQuestionnaire {
    /// Questions to display to the user.
    pub questions: Vec<Question>,
    /// Sender end of the response channel. Dropping this (without sending) is
    /// equivalent to user dismiss: the waiting tool reports
    /// "Questionnaire dismissed".
    pub responder: oneshot::Sender<QuestionnaireResponse>,
}
77
78/// A single question to ask the user.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct Question {
81    /// Unique identifier for this question.
82    pub id: String,
83    /// Short contextual label for the tab bar. Defaults to "Q1", "Q2", etc.
84    #[serde(default)]
85    pub label: String,
86    /// The full question text to display.
87    pub prompt: String,
88    /// Available options. Can be empty when `allow_other` is `true`.
89    #[serde(default)]
90    pub options: Vec<QuestionOption>,
91    /// Whether to show "Type something..." option. Defaults to `true`.
92    #[serde(default = "default_true")]
93    pub allow_other: bool,
94    /// Whether multiple options can be selected. Defaults to `false`.
95    #[serde(default)]
96    pub multi_select: bool,
97}
98
/// Serde default provider: yields `true` for boolean fields that should be
/// enabled when the key is absent (used by `Question::allow_other`).
fn default_true() -> bool {
    true
}
102
/// An option within a question.
///
/// `value` and `label` are mandatory in the tool's parameter schema;
/// `description` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionOption {
    /// Value returned when this option is selected.
    pub value: String,
    /// Display label for the option.
    pub label: String,
    /// Optional description shown below the label.
    pub description: Option<String>,
}
113
/// Response from user interaction, sent by the overlay through the oneshot
/// channel stored in `PendingQuestionnaire::responder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionnaireResponse {
    /// All answers collected.
    pub answers: Vec<Answer>,
    /// `true` if the user cancelled (Esc). When set, the tool reports the
    /// cancellation and does not format `answers`.
    pub cancelled: bool,
}
122
/// A single answer to a question.
///
/// The text handed back to the agent is built by `format_answers`, which
/// prints `id`, `label` and (for selected options) `index`; `value` is
/// carried along but not included in that text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Answer {
    /// Question ID this answer belongs to.
    pub id: String,
    /// The value selected or entered.
    pub value: String,
    /// Display label.
    pub label: String,
    /// `true` if the user typed custom text (allowOther).
    pub was_custom: bool,
    /// 1-based index of the selected option. `None` for custom input.
    pub index: Option<usize>,
}
137
138// ── Tool ───────────────────────────────────────────────────────────────────
139
/// The questionnaire tool — asks the user one or more questions via TUI overlay.
///
/// The tool itself holds no UI state; it only publishes questions to the
/// shared bridge and waits for the reply.
pub struct QuestionnaireTool {
    /// Shared hand-off slot through which questions reach the TUI.
    bridge: Arc<QuestionnaireBridge>,
}
144
145impl QuestionnaireTool {
146    /// Create a new `QuestionnaireTool` that communicates via the given bridge.
147    pub fn new(bridge: Arc<QuestionnaireBridge>) -> Self {
148        Self { bridge }
149    }
150}
151
152// `Clone` is needed because ToolRegistry stores `Arc<dyn AgentTool>`.
153// `QuestionnaireTool` is cheap to clone (only copies the Arc).
154impl Clone for QuestionnaireTool {
155    fn clone(&self) -> Self {
156        Self {
157            bridge: self.bridge.clone(),
158        }
159    }
160}
161
162#[async_trait]
163impl AgentTool for QuestionnaireTool {
164    fn name(&self) -> &str {
165        "questionnaire"
166    }
167
168    fn label(&self) -> &str {
169        "Questionnaire"
170    }
171
172    fn description(&self) -> &str {
173        "Ask the user one or more questions. Use for clarifying requirements, \
174         getting preferences, or confirming decisions. For single questions, \
175         shows a simple option list. For multiple questions, shows a tab-based \
176         interface."
177    }
178
179    fn parameters_schema(&self) -> serde_json::Value {
180        serde_json::json!({
181            "type": "object",
182            "properties": {
183                "questions": {
184                    "type": "array",
185                    "description": "Questions to ask the user",
186                    "items": {
187                        "type": "object",
188                        "properties": {
189                            "id": {
190                                "type": "string",
191                                "description": "Unique identifier for this question"
192                            },
193                            "label": {
194                                "type": "string",
195                                "description": "Short contextual label for tab bar (defaults to Q1, Q2)"
196                            },
197                            "prompt": {
198                                "type": "string",
199                                "description": "The full question text to display"
200                            },
201                            "options": {
202                                "type": "array",
203                                "description": "Available options to choose from. Can be empty for free-text questions.",
204                                "default": [],
205                                "items": {
206                                    "type": "object",
207                                    "properties": {
208                                        "value": {
209                                            "type": "string",
210                                            "description": "The value returned when selected"
211                                        },
212                                        "label": {
213                                            "type": "string",
214                                            "description": "Display label for the option"
215                                        },
216                                        "description": {
217                                            "type": "string",
218                                            "description": "Optional description shown below label"
219                                        }
220                                    },
221                                    "required": ["value", "label"]
222                                }
223                            },
224                            "allowOther": {
225                                "type": "boolean",
226                                "description": "Allow 'Type something' option (default: true)",
227                                "default": true
228                            },
229                            "multiSelect": {
230                                "type": "boolean",
231                                "description": "Allow multiple selections (default: false)",
232                                "default": false
233                            }
234                        },
235                        "required": ["id", "prompt"]
236                    }
237                }
238            },
239            "required": ["questions"]
240        })
241    }
242
243    async fn execute(
244        &self,
245        _tool_call_id: &str,
246        params: serde_json::Value,
247        signal: Option<oneshot::Receiver<()>>,
248        _ctx: &ToolContext,
249    ) -> Result<AgentToolResult, ToolError> {
250        // 1. Parse and validate
251        let questions = parse_questions(&params)?;
252
253        // 2. Create oneshot channel
254        let (tx, rx) = oneshot::channel();
255
256        // 3. Store in bridge — TUI polls it on the main thread
257        if !self.bridge.set(PendingQuestionnaire {
258            questions,
259            responder: tx,
260        }) {
261            return Ok(AgentToolResult::error(
262                "Another questionnaire is already pending",
263            ));
264        }
265
266        // 4. Wait for user response — handle abort via tokio::select!
267        let result = select_with_abort(rx, signal, &self.bridge).await;
268
269        // 5. Format result
270        result
271    }
272}
273
274/// Wait for either the questionnaire response or the abort signal.
275async fn select_with_abort(
276    rx: oneshot::Receiver<QuestionnaireResponse>,
277    signal: Option<oneshot::Receiver<()>>,
278    bridge: &QuestionnaireBridge,
279) -> Result<AgentToolResult, ToolError> {
280    // If no abort signal, use a future that never resolves
281    let abort = async {
282        if let Some(sig) = signal {
283            let _ = sig.await;
284        } else {
285            std::future::pending::<()>().await;
286        }
287    };
288
289    tokio::select! {
290        response = rx => {
291            match response {
292                Ok(resp) => {
293                    if resp.cancelled {
294                        Ok(AgentToolResult::success("User cancelled the questionnaire"))
295                    } else {
296                        Ok(AgentToolResult::success(format_answers(&resp.answers)))
297                    }
298                }
299                Err(_) => {
300                    // Sender was dropped without sending — overlay was closed without result
301                    Ok(AgentToolResult::success("Questionnaire dismissed"))
302                }
303            }
304        }
305        () = abort => {
306            // Abort signal received (Ctrl+C) — clean up bridge
307            bridge.try_take();
308            Ok(AgentToolResult::success("Questionnaire cancelled by user interrupt"))
309        }
310    }
311}
312
313/// Parse and validate the questionnaire parameters from JSON.
314fn parse_questions(params: &serde_json::Value) -> Result<Vec<Question>, ToolError> {
315    let questions = params
316        .get("questions")
317        .and_then(|v| v.as_array())
318        .cloned()
319        .ok_or_else(|| "Missing or invalid 'questions' field".to_string())?;
320
321    let questions: Vec<Question> = questions
322        .into_iter()
323        .map(|v| serde_json::from_value(v).map_err(|e| e.to_string()))
324        .collect::<Result<Vec<_>, _>>()
325        .map_err(|e| format!("Invalid question: {}", e))?;
326
327    if questions.is_empty() {
328        return Err("At least one question is required".to_string());
329    }
330
331    // Assign default labels if not provided
332    let questions: Vec<Question> = questions
333        .into_iter()
334        .enumerate()
335        .map(|(i, mut q)| {
336            if q.label.is_empty() {
337                q.label = format!("Q{}", i + 1);
338            }
339            q
340        })
341        .collect();
342
343    // Validate question IDs are unique
344    let mut ids = std::collections::HashSet::new();
345    for q in &questions {
346        if !ids.insert(&q.id) {
347            return Err(format!("Duplicate question id: {}", q.id));
348        }
349    }
350
351    Ok(questions)
352}
353
354/// Format answers into a human-readable text for the tool result.
355fn format_answers(answers: &[Answer]) -> String {
356    answers
357        .iter()
358        .map(|a| {
359            if a.was_custom {
360                format!("{}: user wrote: {}", a.id, a.label)
361            } else if let Some(idx) = a.index {
362                format!("{}: user selected: {}. {}", a.id, idx, a.label)
363            } else {
364                format!("{}: user selected: {}", a.id, a.label)
365            }
366        })
367        .collect::<Vec<_>>()
368        .join("\n")
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    // ── Fixtures ───────────────────────────────────────────────────────────

    /// An answer produced by picking a predefined option.
    fn picked(id: &str, value: &str, label: &str, index: usize) -> Answer {
        Answer {
            id: id.into(),
            value: value.into(),
            label: label.into(),
            was_custom: false,
            index: Some(index),
        }
    }

    /// An answer produced by typing free text (value and label coincide).
    fn typed(id: &str, text: &str) -> Answer {
        Answer {
            id: id.into(),
            value: text.into(),
            label: text.into(),
            was_custom: true,
            index: None,
        }
    }

    /// A questionnaire with no questions, plus the receiver that keeps its
    /// response channel open for the duration of a test.
    fn empty_pending() -> (
        PendingQuestionnaire,
        oneshot::Receiver<QuestionnaireResponse>,
    ) {
        let (responder, receiver) = oneshot::channel();
        let pending = PendingQuestionnaire {
            questions: vec![],
            responder,
        };
        (pending, receiver)
    }

    // ── parse_questions ────────────────────────────────────────────────────

    #[test]
    fn test_parse_questions_valid() {
        let params = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "prompt": "Pick a language",
                    "options": [
                        { "value": "rust", "label": "Rust" },
                        { "value": "ts", "label": "TypeScript" }
                    ]
                }
            ]
        });

        let parsed = parse_questions(&params).unwrap();
        assert_eq!(parsed.len(), 1);

        let first = &parsed[0];
        assert_eq!(first.id, "lang");
        assert_eq!(first.label, "Q1"); // positional default, not the id
        assert_eq!(first.options.len(), 2);
        assert!(first.allow_other); // default
        assert!(!first.multi_select); // default
    }

    #[test]
    fn test_parse_questions_with_label() {
        let params = serde_json::json!({
            "questions": [
                {
                    "id": "lang",
                    "label": "Language",
                    "prompt": "Pick a language"
                }
            ]
        });

        let parsed = parse_questions(&params).unwrap();
        assert_eq!(parsed[0].label, "Language");
    }

    #[test]
    fn test_parse_questions_empty_options() {
        // allowOther=true with no options describes a free-text question.
        let params = serde_json::json!({
            "questions": [
                {
                    "id": "name",
                    "prompt": "What's your project name?",
                    "allowOther": true
                }
            ]
        });

        let parsed = parse_questions(&params).unwrap();
        let only = &parsed[0];
        assert_eq!(only.options.len(), 0);
        assert!(only.allow_other);
    }

    #[test]
    fn test_parse_questions_missing_questions() {
        let params = serde_json::json!({});
        let message = parse_questions(&params).unwrap_err();
        assert!(message.contains("questions"));
    }

    #[test]
    fn test_parse_questions_empty_array() {
        let params = serde_json::json!({ "questions": [] });
        let message = parse_questions(&params).unwrap_err();
        assert!(message.contains("one question"));
    }

    #[test]
    fn test_parse_questions_duplicate_ids() {
        let params = serde_json::json!({
            "questions": [
                { "id": "a", "prompt": "Q1" },
                { "id": "a", "prompt": "Q2" }
            ]
        });
        let message = parse_questions(&params).unwrap_err();
        assert!(message.contains("Duplicate"));
    }

    // ── format_answers ─────────────────────────────────────────────────────

    #[test]
    fn test_format_answers_selected() {
        let rendered = format_answers(&[picked("lang", "rust", "Rust", 1)]);
        assert_eq!(rendered, "lang: user selected: 1. Rust");
    }

    #[test]
    fn test_format_answers_custom() {
        let rendered = format_answers(&[typed("name", "myproj")]);
        assert_eq!(rendered, "name: user wrote: myproj");
    }

    #[test]
    fn test_format_answers_multi() {
        let answers = vec![
            picked("lang", "rust", "Rust", 1),
            picked("db", "pg", "PostgreSQL", 2),
            typed("auth", "jwt"),
        ];
        let rendered = format_answers(&answers);
        assert_eq!(
            rendered,
            "lang: user selected: 1. Rust\ndb: user selected: 2. PostgreSQL\nauth: user wrote: jwt"
        );
    }

    // ── QuestionnaireBridge ────────────────────────────────────────────────

    #[test]
    fn test_bridge_set_take() {
        let bridge = QuestionnaireBridge::new();
        assert!(!bridge.has_pending());

        let (pending, _receiver) = empty_pending();
        assert!(bridge.set(pending));
        assert!(bridge.has_pending());

        assert!(bridge.try_take().is_some());
        assert!(!bridge.has_pending());

        // Nothing is left for a second take.
        assert!(bridge.try_take().is_none());
    }

    #[test]
    fn test_bridge_set_idempotent() {
        let bridge = QuestionnaireBridge::new();
        let (first, _first_rx) = empty_pending();
        let (second, _second_rx) = empty_pending();

        bridge.set(first);
        // The slot is occupied, so the second questionnaire is rejected.
        assert!(!bridge.set(second));
    }
}