agent_sdk/
user_interaction.rs

1//! User interaction types and tools.
2//!
3//! This module provides types and tools for agent-user interaction:
4//!
5//! - [`ConfirmationRequest`] / [`ConfirmationResponse`] - For tool confirmations
6//! - [`QuestionRequest`] / [`QuestionResponse`] - For agent-initiated questions
7//! - [`AskUserQuestionTool`] - Tool that allows agents to ask questions
8//!
9//! # Confirmation Flow
10//!
11//! When an agent needs to execute a tool that requires confirmation:
12//!
13//! 1. Agent hooks create a [`ConfirmationRequest`]
14//! 2. UI displays the request to the user
15//! 3. User responds with [`ConfirmationResponse`]
16//! 4. Agent proceeds based on the response
17//!
18//! # Question Flow
19//!
20//! When an agent needs to ask the user a question:
21//!
22//! 1. Agent calls `ask_user` tool with a [`QuestionRequest`]
23//! 2. UI displays the question to the user
24//! 3. User responds with [`QuestionResponse`]
25//! 4. Agent receives the answer and continues
26//!
27//! # Example
28//!
29//! ```no_run
30//! use agent_sdk::user_interaction::{AskUserQuestionTool, QuestionRequest, QuestionResponse};
31//! use tokio::sync::mpsc;
32//!
33//! // Create channels for communication
34//! let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
35//!
36//! // The tool can be registered with the agent's tool registry
37//! // The UI handles requests and responses through the channels
38//! ```
39
40use crate::{Tool, ToolContext, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use async_trait::async_trait;
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, json};
45use tokio::sync::mpsc;
46
47/// Request for user confirmation of a tool execution.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ConfirmationRequest {
50    /// Name of the tool requiring confirmation.
51    pub tool_name: String,
52
53    /// Human-readable description of the action.
54    pub description: String,
55
56    /// Preview of the tool input (JSON formatted).
57    pub input_preview: String,
58
59    /// The tier of the tool (serialized as string).
60    pub tier: String,
61
62    /// Agent's recent reasoning text that led to this tool call.
63    pub context: Option<String>,
64}
65
66impl ConfirmationRequest {
67    /// Creates a new confirmation request.
68    #[must_use]
69    pub fn new(
70        tool_name: impl Into<String>,
71        description: impl Into<String>,
72        input_preview: impl Into<String>,
73        tier: ToolTier,
74    ) -> Self {
75        Self {
76            tool_name: tool_name.into(),
77            description: description.into(),
78            input_preview: input_preview.into(),
79            tier: format!("{tier:?}"),
80            context: None,
81        }
82    }
83
84    /// Adds context to the request.
85    #[must_use]
86    pub fn with_context(mut self, context: impl Into<String>) -> Self {
87        self.context = Some(context.into());
88        self
89    }
90}
91
92/// Response to a confirmation request.
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum ConfirmationResponse {
96    /// User approved the tool execution.
97    Approved,
98
99    /// User denied the tool execution.
100    Denied,
101
102    /// User wants to approve all future requests for this tool.
103    ApproveAll,
104}
105
106/// Request for user to answer a question from the agent.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct QuestionRequest {
109    /// The question text to display.
110    pub question: String,
111
112    /// Optional header/category for the question.
113    pub header: Option<String>,
114
115    /// Available options (if multiple choice).
116    /// Empty means free-form text input.
117    pub options: Vec<QuestionOption>,
118
119    /// Whether multiple options can be selected.
120    pub multi_select: bool,
121}
122
123impl QuestionRequest {
124    /// Creates a new free-form question request.
125    #[must_use]
126    pub fn new(question: impl Into<String>) -> Self {
127        Self {
128            question: question.into(),
129            header: None,
130            options: Vec::new(),
131            multi_select: false,
132        }
133    }
134
135    /// Creates a new multiple-choice question request.
136    #[must_use]
137    pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
138        Self {
139            question: question.into(),
140            header: None,
141            options,
142            multi_select: false,
143        }
144    }
145
146    /// Adds a header to the question.
147    #[must_use]
148    pub fn with_header(mut self, header: impl Into<String>) -> Self {
149        self.header = Some(header.into());
150        self
151    }
152
153    /// Enables multi-select mode.
154    #[must_use]
155    pub const fn with_multi_select(mut self) -> Self {
156        self.multi_select = true;
157        self
158    }
159}
160
161/// An option in a multiple-choice question.
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct QuestionOption {
164    /// The display label for this option.
165    pub label: String,
166
167    /// Optional description explaining this option.
168    pub description: Option<String>,
169}
170
171impl QuestionOption {
172    /// Creates a new option with just a label.
173    #[must_use]
174    pub fn new(label: impl Into<String>) -> Self {
175        Self {
176            label: label.into(),
177            description: None,
178        }
179    }
180
181    /// Creates a new option with label and description.
182    #[must_use]
183    pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
184        Self {
185            label: label.into(),
186            description: Some(description.into()),
187        }
188    }
189}
190
191/// Response to a question request.
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct QuestionResponse {
194    /// The user's answer (text or selected option labels).
195    pub answer: String,
196
197    /// Whether the user cancelled/skipped the question.
198    pub cancelled: bool,
199}
200
201impl QuestionResponse {
202    /// Creates a new successful response.
203    #[must_use]
204    pub fn success(answer: impl Into<String>) -> Self {
205        Self {
206            answer: answer.into(),
207            cancelled: false,
208        }
209    }
210
211    /// Creates a cancelled response.
212    #[must_use]
213    pub const fn cancelled() -> Self {
214        Self {
215            answer: String::new(),
216            cancelled: true,
217        }
218    }
219}
220
221/// Tool that allows the agent to ask questions to the user.
222///
223/// This is essential for:
224/// - Clarifying ambiguous requirements
225/// - Offering choices between approaches
226/// - Getting user preferences
227/// - Confirming before major changes
228pub struct AskUserQuestionTool {
229    /// Channel to send questions to the UI.
230    question_tx: mpsc::Sender<QuestionRequest>,
231
232    /// Channel to receive answers from the UI.
233    question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
234}
235
236impl AskUserQuestionTool {
237    /// Creates a new tool with the given channels.
238    #[must_use]
239    pub fn new(
240        question_tx: mpsc::Sender<QuestionRequest>,
241        question_rx: mpsc::Receiver<QuestionResponse>,
242    ) -> Self {
243        Self {
244            question_tx,
245            question_rx: tokio::sync::Mutex::new(question_rx),
246        }
247    }
248
249    /// Creates a new tool with fresh channels.
250    ///
251    /// Returns `(tool, request_receiver, response_sender)` where:
252    /// - `tool` is the `AskUserQuestionTool` instance
253    /// - `request_receiver` receives questions from the agent
254    /// - `response_sender` sends user answers back to the agent
255    #[must_use]
256    pub fn with_channels(
257        buffer_size: usize,
258    ) -> (
259        Self,
260        mpsc::Receiver<QuestionRequest>,
261        mpsc::Sender<QuestionResponse>,
262    ) {
263        let (request_tx, request_rx) = mpsc::channel(buffer_size);
264        let (response_tx, response_rx) = mpsc::channel(buffer_size);
265
266        let tool = Self::new(request_tx, response_rx);
267        (tool, request_rx, response_tx)
268    }
269}
270
271/// Input schema for the `AskUserQuestion` tool.
272#[derive(Debug, Deserialize, Serialize)]
273struct AskUserInput {
274    /// The question to ask the user.
275    question: String,
276
277    /// Optional header/category for the question.
278    #[serde(default)]
279    header: Option<String>,
280
281    /// Optional list of choices for multiple-choice questions.
282    #[serde(default)]
283    options: Vec<OptionInput>,
284
285    /// Whether multiple options can be selected.
286    #[serde(default)]
287    multi_select: bool,
288}
289
290/// Input for a single option.
291#[derive(Debug, Deserialize, Serialize)]
292struct OptionInput {
293    /// The display label.
294    label: String,
295
296    /// Optional description.
297    #[serde(default)]
298    description: Option<String>,
299}
300
301#[async_trait]
302impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
303    fn name(&self) -> &'static str {
304        "ask_user"
305    }
306
307    fn description(&self) -> &'static str {
308        "Ask the user a question to get clarification, preferences, or choices. \
309         Use this when you need user input before proceeding. For yes/no confirmations \
310         of dangerous operations, tool confirmation will be shown automatically - \
311         use this tool for open-ended questions or when offering choices."
312    }
313
314    fn input_schema(&self) -> Value {
315        json!({
316            "type": "object",
317            "required": ["question"],
318            "properties": {
319                "question": {
320                    "type": "string",
321                    "description": "The question to ask the user. Be clear and specific."
322                },
323                "header": {
324                    "type": "string",
325                    "description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
326                },
327                "options": {
328                    "type": "array",
329                    "description": "Optional list of choices for multiple-choice questions",
330                    "items": {
331                        "type": "object",
332                        "required": ["label"],
333                        "properties": {
334                            "label": {
335                                "type": "string",
336                                "description": "The option text to display"
337                            },
338                            "description": {
339                                "type": "string",
340                                "description": "Optional explanation of this option"
341                            }
342                        }
343                    }
344                },
345                "multi_select": {
346                    "type": "boolean",
347                    "description": "Whether multiple options can be selected (default: false)"
348                }
349            }
350        })
351    }
352
353    fn tier(&self) -> ToolTier {
354        // Questions don't modify anything, but they do require user interaction
355        ToolTier::Observe
356    }
357
358    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
359        // Parse input
360        let input: AskUserInput =
361            serde_json::from_value(input).context("Invalid input for ask_user tool")?;
362
363        // Build request
364        let request = QuestionRequest {
365            question: input.question.clone(),
366            header: input.header,
367            options: input
368                .options
369                .into_iter()
370                .map(|o| QuestionOption {
371                    label: o.label,
372                    description: o.description,
373                })
374                .collect(),
375            multi_select: input.multi_select,
376        };
377
378        // Send question to UI
379        self.question_tx
380            .send(request)
381            .await
382            .context("Failed to send question to UI - channel closed")?;
383
384        // Wait for response
385        let response = {
386            let mut rx = self.question_rx.lock().await;
387            rx.recv()
388                .await
389                .context("Failed to receive answer from UI - channel closed")?
390        };
391
392        if response.cancelled {
393            Ok(ToolResult::error(
394                "User cancelled the question without providing an answer.",
395            ))
396        } else {
397            Ok(ToolResult::success(format!(
398                "User answered: {}",
399                response.answer
400            )))
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    /// Stands in for the UI: waits for one question, lets `inspect` check it,
    /// then sends `reply` back to the tool.
    fn fake_ui(
        mut questions: mpsc::Receiver<QuestionRequest>,
        answers: mpsc::Sender<QuestionResponse>,
        inspect: impl FnOnce(&QuestionRequest) + Send + 'static,
        reply: QuestionResponse,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            if let Some(question) = questions.recv().await {
                inspect(&question);
                answers.send(reply).await.unwrap();
            }
        })
    }

    #[test]
    fn test_confirmation_request_new() {
        let req =
            ConfirmationRequest::new("write", "Write to file: foo.txt", "{}", ToolTier::Confirm);
        assert_eq!(req.tool_name, "write");
        assert_eq!(req.context, None);
    }

    #[test]
    fn test_confirmation_request_with_context() {
        let req = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
            .with_context("Agent was fixing a bug");
        assert_eq!(req.context.as_deref(), Some("Agent was fixing a bug"));
    }

    #[test]
    fn test_confirmation_response_serialization() {
        let cases = [
            (ConfirmationResponse::Approved, "\"approved\""),
            (ConfirmationResponse::Denied, "\"denied\""),
            (ConfirmationResponse::ApproveAll, "\"approve_all\""),
        ];
        for (response, expected) in cases {
            assert_eq!(serde_json::to_string(&response).unwrap(), expected);
        }
    }

    #[test]
    fn test_question_request_new() {
        let req = QuestionRequest::new("What color?");
        assert_eq!(req.question, "What color?");
        assert!(req.options.is_empty());
        assert!(!req.multi_select);
    }

    #[test]
    fn test_question_request_with_options() {
        let options = vec![
            QuestionOption::new("React"),
            QuestionOption::with_description("Vue", "Progressive framework"),
        ];
        let req = QuestionRequest::with_options("Which framework?", options)
            .with_header("Framework")
            .with_multi_select();

        assert_eq!(req.options.len(), 2);
        assert!(req.multi_select);
        assert_eq!(req.header.as_deref(), Some("Framework"));
    }

    #[test]
    fn test_question_response() {
        let answered = QuestionResponse::success("Blue");
        assert!(!answered.cancelled);
        assert_eq!(answered.answer, "Blue");

        assert!(QuestionResponse::cancelled().cancelled);
    }

    #[tokio::test]
    async fn test_ask_user_tool_creation() {
        let (tool, _questions, _answers) = AskUserQuestionTool::with_channels(10);

        // Name the `Ctx` parameter explicitly; nothing else pins it down here.
        assert_eq!(Tool::<()>::name(&tool), "ask_user");
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[tokio::test]
    async fn test_ask_user_tool_execute() {
        let (tool, questions, answers) = AskUserQuestionTool::with_channels(10);
        let ui = fake_ui(
            questions,
            answers,
            |q| assert_eq!(q.question, "What color?"),
            QuestionResponse::success("Blue"),
        );

        let ctx = ToolContext::new(());
        let result = tool
            .execute(&ctx, json!({ "question": "What color?" }))
            .await
            .unwrap();

        ui.await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Blue"));
    }

    #[tokio::test]
    async fn test_ask_user_with_options() {
        let (tool, questions, answers) = AskUserQuestionTool::with_channels(10);
        let ui = fake_ui(
            questions,
            answers,
            |q| {
                assert_eq!(q.options.len(), 2);
                assert_eq!(q.options[0].label, "Option A");
            },
            QuestionResponse::success("Option A"),
        );

        let ctx = ToolContext::new(());
        let input = json!({
            "question": "Which option?",
            "options": [
                {"label": "Option A", "description": "First choice"},
                {"label": "Option B", "description": "Second choice"}
            ]
        });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_ask_user_cancelled() {
        let (tool, questions, answers) = AskUserQuestionTool::with_channels(10);
        let ui = fake_ui(questions, answers, |_| {}, QuestionResponse::cancelled());

        let ctx = ToolContext::new(());
        let result = tool
            .execute(&ctx, json!({ "question": "Continue?" }))
            .await
            .unwrap();

        ui.await.unwrap();
        assert!(!result.success);
        assert!(result.output.contains("cancelled"));
    }
}
580}