Skip to main content

agent_sdk/
user_interaction.rs

1//! User interaction types and tools.
2//!
3//! This module provides types and tools for agent-user interaction:
4//!
5//! - [`ConfirmationRequest`] / [`ConfirmationResponse`] - For tool confirmations
6//! - [`QuestionRequest`] / [`QuestionResponse`] - For agent-initiated questions
7//! - [`AskUserQuestionTool`] - Tool that allows agents to ask questions
8//!
9//! # Confirmation Flow
10//!
11//! When an agent needs to execute a tool that requires confirmation:
12//!
13//! 1. Agent hooks create a [`ConfirmationRequest`]
14//! 2. UI displays the request to the user
15//! 3. User responds with [`ConfirmationResponse`]
16//! 4. Agent proceeds based on the response
17//!
18//! # Question Flow
19//!
20//! When an agent needs to ask the user a question:
21//!
22//! 1. Agent calls `ask_user` tool with a [`QuestionRequest`]
23//! 2. UI displays the question to the user
24//! 3. User responds with [`QuestionResponse`]
25//! 4. Agent receives the answer and continues
26//!
27//! # Example
28//!
29//! ```no_run
30//! use agent_sdk::user_interaction::{AskUserQuestionTool, QuestionRequest, QuestionResponse};
31//! use tokio::sync::mpsc;
32//!
33//! // Create channels for communication
34//! let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
35//!
36//! // The tool can be registered with the agent's tool registry
37//! // The UI handles requests and responses through the channels
38//! ```
39
40use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use serde::{Deserialize, Serialize};
43use serde_json::{Value, json};
44use tokio::sync::mpsc;
45
46/// Request for user confirmation of a tool execution.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ConfirmationRequest {
49    /// Name of the tool requiring confirmation.
50    pub tool_name: String,
51
52    /// Human-readable description of the action.
53    pub description: String,
54
55    /// Preview of the tool input (JSON formatted).
56    pub input_preview: String,
57
58    /// The tier of the tool (serialized as string).
59    pub tier: String,
60
61    /// Agent's recent reasoning text that led to this tool call.
62    pub context: Option<String>,
63}
64
65impl ConfirmationRequest {
66    /// Creates a new confirmation request.
67    #[must_use]
68    pub fn new(
69        tool_name: impl Into<String>,
70        description: impl Into<String>,
71        input_preview: impl Into<String>,
72        tier: ToolTier,
73    ) -> Self {
74        Self {
75            tool_name: tool_name.into(),
76            description: description.into(),
77            input_preview: input_preview.into(),
78            tier: format!("{tier:?}"),
79            context: None,
80        }
81    }
82
83    /// Adds context to the request.
84    #[must_use]
85    pub fn with_context(mut self, context: impl Into<String>) -> Self {
86        self.context = Some(context.into());
87        self
88    }
89}
90
91/// Response to a confirmation request.
92#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
93#[serde(rename_all = "snake_case")]
94pub enum ConfirmationResponse {
95    /// User approved the tool execution.
96    Approved,
97
98    /// User denied the tool execution.
99    Denied,
100
101    /// User wants to approve all future requests for this tool.
102    ApproveAll,
103}
104
105/// Request for user to answer a question from the agent.
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct QuestionRequest {
108    /// The question text to display.
109    pub question: String,
110
111    /// Optional header/category for the question.
112    pub header: Option<String>,
113
114    /// Available options (if multiple choice).
115    /// Empty means free-form text input.
116    pub options: Vec<QuestionOption>,
117
118    /// Whether multiple options can be selected.
119    pub multi_select: bool,
120}
121
122impl QuestionRequest {
123    /// Creates a new free-form question request.
124    #[must_use]
125    pub fn new(question: impl Into<String>) -> Self {
126        Self {
127            question: question.into(),
128            header: None,
129            options: Vec::new(),
130            multi_select: false,
131        }
132    }
133
134    /// Creates a new multiple-choice question request.
135    #[must_use]
136    pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
137        Self {
138            question: question.into(),
139            header: None,
140            options,
141            multi_select: false,
142        }
143    }
144
145    /// Adds a header to the question.
146    #[must_use]
147    pub fn with_header(mut self, header: impl Into<String>) -> Self {
148        self.header = Some(header.into());
149        self
150    }
151
152    /// Enables multi-select mode.
153    #[must_use]
154    pub const fn with_multi_select(mut self) -> Self {
155        self.multi_select = true;
156        self
157    }
158}
159
160/// An option in a multiple-choice question.
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct QuestionOption {
163    /// The display label for this option.
164    pub label: String,
165
166    /// Optional description explaining this option.
167    pub description: Option<String>,
168}
169
170impl QuestionOption {
171    /// Creates a new option with just a label.
172    #[must_use]
173    pub fn new(label: impl Into<String>) -> Self {
174        Self {
175            label: label.into(),
176            description: None,
177        }
178    }
179
180    /// Creates a new option with label and description.
181    #[must_use]
182    pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
183        Self {
184            label: label.into(),
185            description: Some(description.into()),
186        }
187    }
188}
189
190/// Response to a question request.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct QuestionResponse {
193    /// The user's answer (text or selected option labels).
194    pub answer: String,
195
196    /// Whether the user cancelled/skipped the question.
197    pub cancelled: bool,
198}
199
200impl QuestionResponse {
201    /// Creates a new successful response.
202    #[must_use]
203    pub fn success(answer: impl Into<String>) -> Self {
204        Self {
205            answer: answer.into(),
206            cancelled: false,
207        }
208    }
209
210    /// Creates a cancelled response.
211    #[must_use]
212    pub const fn cancelled() -> Self {
213        Self {
214            answer: String::new(),
215            cancelled: true,
216        }
217    }
218}
219
220/// Tool that allows the agent to ask questions to the user.
221///
222/// This is essential for:
223/// - Clarifying ambiguous requirements
224/// - Offering choices between approaches
225/// - Getting user preferences
226/// - Confirming before major changes
227pub struct AskUserQuestionTool {
228    /// Channel to send questions to the UI.
229    question_tx: mpsc::Sender<QuestionRequest>,
230
231    /// Channel to receive answers from the UI.
232    question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
233}
234
235impl AskUserQuestionTool {
236    /// Creates a new tool with the given channels.
237    #[must_use]
238    pub fn new(
239        question_tx: mpsc::Sender<QuestionRequest>,
240        question_rx: mpsc::Receiver<QuestionResponse>,
241    ) -> Self {
242        Self {
243            question_tx,
244            question_rx: tokio::sync::Mutex::new(question_rx),
245        }
246    }
247
248    /// Creates a new tool with fresh channels.
249    ///
250    /// Returns `(tool, request_receiver, response_sender)` where:
251    /// - `tool` is the `AskUserQuestionTool` instance
252    /// - `request_receiver` receives questions from the agent
253    /// - `response_sender` sends user answers back to the agent
254    #[must_use]
255    pub fn with_channels(
256        buffer_size: usize,
257    ) -> (
258        Self,
259        mpsc::Receiver<QuestionRequest>,
260        mpsc::Sender<QuestionResponse>,
261    ) {
262        let (request_tx, request_rx) = mpsc::channel(buffer_size);
263        let (response_tx, response_rx) = mpsc::channel(buffer_size);
264
265        let tool = Self::new(request_tx, response_rx);
266        (tool, request_rx, response_tx)
267    }
268}
269
270/// Input schema for the `AskUserQuestion` tool.
271#[derive(Debug, Deserialize, Serialize)]
272struct AskUserInput {
273    /// The question to ask the user.
274    question: String,
275
276    /// Optional header/category for the question.
277    #[serde(default)]
278    header: Option<String>,
279
280    /// Optional list of choices for multiple-choice questions.
281    #[serde(default)]
282    options: Vec<OptionInput>,
283
284    /// Whether multiple options can be selected.
285    #[serde(default)]
286    multi_select: bool,
287}
288
289/// Input for a single option.
290#[derive(Debug, Deserialize, Serialize)]
291struct OptionInput {
292    /// The display label.
293    label: String,
294
295    /// Optional description.
296    #[serde(default)]
297    description: Option<String>,
298}
299
300impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
301    type Name = PrimitiveToolName;
302
303    fn name(&self) -> PrimitiveToolName {
304        PrimitiveToolName::AskUser
305    }
306
307    fn display_name(&self) -> &'static str {
308        "Ask User"
309    }
310
311    fn description(&self) -> &'static str {
312        "Ask the user a question to get clarification, preferences, or choices. \
313         Use this when you need user input before proceeding. For yes/no confirmations \
314         of dangerous operations, tool confirmation will be shown automatically - \
315         use this tool for open-ended questions or when offering choices."
316    }
317
318    fn input_schema(&self) -> Value {
319        json!({
320            "type": "object",
321            "required": ["question"],
322            "properties": {
323                "question": {
324                    "type": "string",
325                    "description": "The question to ask the user. Be clear and specific."
326                },
327                "header": {
328                    "type": "string",
329                    "description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
330                },
331                "options": {
332                    "type": "array",
333                    "description": "Optional list of choices for multiple-choice questions",
334                    "items": {
335                        "type": "object",
336                        "required": ["label"],
337                        "properties": {
338                            "label": {
339                                "type": "string",
340                                "description": "The option text to display"
341                            },
342                            "description": {
343                                "type": "string",
344                                "description": "Optional explanation of this option"
345                            }
346                        }
347                    }
348                },
349                "multi_select": {
350                    "type": "boolean",
351                    "description": "Whether multiple options can be selected (default: false)"
352                }
353            }
354        })
355    }
356
357    fn tier(&self) -> ToolTier {
358        // Questions don't modify anything, but they do require user interaction
359        ToolTier::Observe
360    }
361
362    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
363        // Parse input
364        let input: AskUserInput =
365            serde_json::from_value(input).context("Invalid input for ask_user tool")?;
366
367        // Build request
368        let request = QuestionRequest {
369            question: input.question.clone(),
370            header: input.header,
371            options: input
372                .options
373                .into_iter()
374                .map(|o| QuestionOption {
375                    label: o.label,
376                    description: o.description,
377                })
378                .collect(),
379            multi_select: input.multi_select,
380        };
381
382        // Send question to UI
383        self.question_tx
384            .send(request)
385            .await
386            .context("Failed to send question to UI - channel closed")?;
387
388        // Wait for response
389        let response = {
390            let mut rx = self.question_rx.lock().await;
391            rx.recv()
392                .await
393                .context("Failed to receive answer from UI - channel closed")?
394        };
395
396        if response.cancelled {
397            Ok(ToolResult::error(
398                "User cancelled the question without providing an answer.",
399            ))
400        } else {
401            Ok(ToolResult::success(format!(
402                "User answered: {}",
403                response.answer
404            )))
405        }
406    }
407}
408
/// Unit tests for the interaction types and for `AskUserQuestionTool`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    #[test]
    fn test_confirmation_request_new() {
        let req =
            ConfirmationRequest::new("write", "Write to file: foo.txt", "{}", ToolTier::Confirm);
        assert_eq!(req.tool_name, "write");
        assert_eq!(req.description, "Write to file: foo.txt");
        assert_eq!(req.input_preview, "{}");
        // The tier is stored as its `Debug` text.
        assert_eq!(req.tier, format!("{:?}", ToolTier::Confirm));
        assert!(req.context.is_none());
    }

    #[test]
    fn test_confirmation_request_with_context() {
        let req = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
            .with_context("Agent was fixing a bug");
        assert_eq!(req.context.as_deref(), Some("Agent was fixing a bug"));
    }

    #[test]
    fn test_confirmation_response_serialization() {
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::Approved).unwrap(),
            "\"approved\""
        );
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::Denied).unwrap(),
            "\"denied\""
        );
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::ApproveAll).unwrap(),
            "\"approve_all\""
        );
    }

    #[test]
    fn test_question_request_new() {
        let req = QuestionRequest::new("What color?");
        assert_eq!(req.question, "What color?");
        assert!(req.header.is_none());
        assert!(req.options.is_empty());
        assert!(!req.multi_select);
    }

    #[test]
    fn test_question_request_with_options() {
        let req = QuestionRequest::with_options(
            "Which framework?",
            vec![
                QuestionOption::new("React"),
                QuestionOption::with_description("Vue", "Progressive framework"),
            ],
        )
        .with_header("Framework")
        .with_multi_select();

        assert_eq!(req.options.len(), 2);
        assert!(req.options[0].description.is_none());
        assert_eq!(
            req.options[1].description.as_deref(),
            Some("Progressive framework")
        );
        assert!(req.multi_select);
        assert_eq!(req.header.as_deref(), Some("Framework"));
    }

    #[test]
    fn test_question_response() {
        let success = QuestionResponse::success("Blue");
        assert!(!success.cancelled);
        assert_eq!(success.answer, "Blue");

        let cancelled = QuestionResponse::cancelled();
        assert!(cancelled.cancelled);
        assert!(cancelled.answer.is_empty());
    }

    #[tokio::test]
    async fn test_ask_user_tool_creation() {
        let (tool, _rx, _tx) = AskUserQuestionTool::with_channels(10);

        // Use Tool<()> explicitly to satisfy type inference
        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::AskUser);
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[tokio::test]
    async fn test_ask_user_tool_execute() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        // Play the UI: the question must arrive, otherwise the task fails loudly.
        let handle = tokio::spawn(async move {
            let request = request_rx
                .recv()
                .await
                .expect("tool should send a question");
            assert_eq!(request.question, "What color?");
            assert!(request.options.is_empty());
            response_tx
                .send(QuestionResponse::success("Blue"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "What color?"
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Blue"));
    }

    #[tokio::test]
    async fn test_ask_user_with_options() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            let request = request_rx
                .recv()
                .await
                .expect("tool should send a question");
            assert_eq!(request.options.len(), 2);
            assert_eq!(request.options[0].label, "Option A");
            assert_eq!(request.options[1].label, "Option B");
            response_tx
                .send(QuestionResponse::success("Option A"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Which option?",
                    "options": [
                        {"label": "Option A", "description": "First choice"},
                        {"label": "Option B", "description": "Second choice"}
                    ]
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_ask_user_cancelled() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            let _request = request_rx
                .recv()
                .await
                .expect("tool should send a question");
            response_tx
                .send(QuestionResponse::cancelled())
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Continue?"
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(!result.success);
        assert!(result.output.contains("cancelled"));
    }

    #[tokio::test]
    async fn test_ask_user_invalid_input() {
        let (tool, _request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);

        // `question` is required, so this input must be rejected.
        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "header": "No question here"
                }),
            )
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ask_user_request_channel_closed() {
        let (tool, request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);
        // With the UI side of the question channel gone, sending must fail.
        drop(request_rx);

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Anyone there?"
                }),
            )
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ask_user_response_channel_closed() {
        let (tool, _request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
        // The question can still be queued, but no answer can ever arrive.
        drop(response_tx);

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Anyone there?"
                }),
            )
            .await;

        assert!(result.is_err());
    }
}