Skip to main content

agent_sdk/
user_interaction.rs

//! User interaction types and tools.
//!
//! This module provides types and tools for agent-user interaction:
//!
//! - [`ConfirmationRequest`] / [`ConfirmationResponse`] - For tool confirmations
//! - [`QuestionRequest`] / [`QuestionResponse`] - For agent-initiated questions
//! - [`AskUserQuestionTool`] - Tool that allows agents to ask questions
//!
//! # Confirmation Flow
//!
//! When an agent needs to execute a tool that requires confirmation:
//!
//! 1. Agent hooks create a [`ConfirmationRequest`]
//! 2. UI displays the request to the user
//! 3. User responds with [`ConfirmationResponse`]
//! 4. Agent proceeds based on the response
//!
//! # Question Flow
//!
//! When an agent needs to ask the user a question:
//!
//! 1. Agent calls the `ask_user` tool, which sends a [`QuestionRequest`]
//!    carrying a fresh `request_id` to the UI
//! 2. UI displays the question to the user
//! 3. User responds with a [`QuestionResponse`] that echoes the same
//!    `request_id` (an empty id is still accepted from UIs that do not echo it)
//! 4. Agent receives the matching answer and continues; answers carrying a
//!    different id are discarded as stale
//!
//! # Example
//!
//! ```no_run
//! use agent_sdk::user_interaction::{AskUserQuestionTool, QuestionRequest, QuestionResponse};
//! use tokio::sync::mpsc;
//!
//! // Create channels for communication
//! let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
//!
//! // The tool can be registered with the agent's tool registry
//! // The UI handles requests and responses through the channels
//! ```
39
40use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use serde::{Deserialize, Serialize};
43use serde_json::{Value, json};
44use std::sync::atomic::{AtomicU64, Ordering};
45use tokio::sync::mpsc;
46use tokio_util::sync::CancellationToken;
47
/// Source of the numeric suffix used in question correlation ids.
///
/// A single process-wide counter that is only ever incremented.
static QUESTION_SEQ: AtomicU64 = AtomicU64::new(0);

/// Returns the next question correlation id, formatted as `ask-<n>`.
///
/// Each call atomically takes the current counter value and advances it by
/// one, so successive calls hand out `ask-0`, `ask-1`, and so on.
fn next_request_id() -> String {
    format!("ask-{}", QUESTION_SEQ.fetch_add(1, Ordering::Relaxed))
}
56
57/// Receive the response matching `request_id`, discarding stale answers (e.g.
58/// late replies to a previously cancelled question). Returns `None` if the run
59/// is cancelled before a matching answer arrives.
60async fn await_matching_response(
61    rx: &mut mpsc::Receiver<QuestionResponse>,
62    request_id: &str,
63    cancel_token: &CancellationToken,
64) -> Result<Option<QuestionResponse>> {
65    loop {
66        tokio::select! {
67            biased;
68            () = cancel_token.cancelled() => return Ok(None),
69            received = rx.recv() => {
70                let response = received
71                    .context("Failed to receive answer from UI - channel closed")?;
72                // Accept the matching answer. An empty id comes from a legacy
73                // UI that does not echo correlation ids, so accept it rather
74                // than hang. Any other non-matching id is a stale answer to a
75                // different (cancelled) question — discard it and keep waiting.
76                if response.request_id.is_empty() || response.request_id == request_id {
77                    return Ok(Some(response));
78                }
79            }
80        }
81    }
82}
83
84/// Request for user confirmation of a tool execution.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ConfirmationRequest {
87    /// Name of the tool requiring confirmation.
88    pub tool_name: String,
89
90    /// Human-readable description of the action.
91    pub description: String,
92
93    /// Preview of the tool input (JSON formatted).
94    pub input_preview: String,
95
96    /// The tier of the tool (serialized as string).
97    pub tier: String,
98
99    /// Agent's recent reasoning text that led to this tool call.
100    pub context: Option<String>,
101}
102
103impl ConfirmationRequest {
104    /// Creates a new confirmation request.
105    #[must_use]
106    pub fn new(
107        tool_name: impl Into<String>,
108        description: impl Into<String>,
109        input_preview: impl Into<String>,
110        tier: ToolTier,
111    ) -> Self {
112        Self {
113            tool_name: tool_name.into(),
114            description: description.into(),
115            input_preview: input_preview.into(),
116            tier: format!("{tier:?}"),
117            context: None,
118        }
119    }
120
121    /// Adds context to the request.
122    #[must_use]
123    pub fn with_context(mut self, context: impl Into<String>) -> Self {
124        self.context = Some(context.into());
125        self
126    }
127}
128
129/// Response to a confirmation request.
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
131#[serde(rename_all = "snake_case")]
132pub enum ConfirmationResponse {
133    /// User approved the tool execution.
134    Approved,
135
136    /// User denied the tool execution.
137    Denied,
138
139    /// User wants to approve all future requests for this tool.
140    ApproveAll,
141}
142
143/// Request for user to answer a question from the agent.
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct QuestionRequest {
146    /// Correlation id assigned by the tool when the question is dispatched.
147    ///
148    /// The UI must echo this value back on the matching [`QuestionResponse`] so
149    /// answers can be paired to their question even when multiple questions are
150    /// in flight or an earlier question was cancelled. An empty value means the
151    /// request has not been dispatched yet (or a legacy UI that does not echo).
152    #[serde(default)]
153    pub request_id: String,
154
155    /// The question text to display.
156    pub question: String,
157
158    /// Optional header/category for the question.
159    pub header: Option<String>,
160
161    /// Available options (if multiple choice).
162    /// Empty means free-form text input.
163    pub options: Vec<QuestionOption>,
164
165    /// Whether multiple options can be selected.
166    pub multi_select: bool,
167}
168
169impl QuestionRequest {
170    /// Creates a new free-form question request.
171    #[must_use]
172    pub fn new(question: impl Into<String>) -> Self {
173        Self {
174            request_id: String::new(),
175            question: question.into(),
176            header: None,
177            options: Vec::new(),
178            multi_select: false,
179        }
180    }
181
182    /// Creates a new multiple-choice question request.
183    #[must_use]
184    pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
185        Self {
186            request_id: String::new(),
187            question: question.into(),
188            header: None,
189            options,
190            multi_select: false,
191        }
192    }
193
194    /// Adds a header to the question.
195    #[must_use]
196    pub fn with_header(mut self, header: impl Into<String>) -> Self {
197        self.header = Some(header.into());
198        self
199    }
200
201    /// Enables multi-select mode.
202    #[must_use]
203    pub const fn with_multi_select(mut self) -> Self {
204        self.multi_select = true;
205        self
206    }
207}
208
209/// An option in a multiple-choice question.
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct QuestionOption {
212    /// The display label for this option.
213    pub label: String,
214
215    /// Optional description explaining this option.
216    pub description: Option<String>,
217}
218
219impl QuestionOption {
220    /// Creates a new option with just a label.
221    #[must_use]
222    pub fn new(label: impl Into<String>) -> Self {
223        Self {
224            label: label.into(),
225            description: None,
226        }
227    }
228
229    /// Creates a new option with label and description.
230    #[must_use]
231    pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
232        Self {
233            label: label.into(),
234            description: Some(description.into()),
235        }
236    }
237}
238
239/// Response to a question request.
240#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct QuestionResponse {
242    /// Correlation id copied from the [`QuestionRequest`] this answers.
243    ///
244    /// The tool discards responses whose id does not match the question it is
245    /// currently awaiting (e.g. a late answer to a cancelled question). An
246    /// empty value is accepted for backward compatibility with UIs that do not
247    /// echo the id.
248    #[serde(default)]
249    pub request_id: String,
250
251    /// The user's answer (text or selected option labels).
252    pub answer: String,
253
254    /// Whether the user cancelled/skipped the question.
255    pub cancelled: bool,
256}
257
258impl QuestionResponse {
259    /// Creates a new successful response.
260    #[must_use]
261    pub fn success(answer: impl Into<String>) -> Self {
262        Self {
263            request_id: String::new(),
264            answer: answer.into(),
265            cancelled: false,
266        }
267    }
268
269    /// Creates a cancelled response.
270    #[must_use]
271    pub const fn cancelled() -> Self {
272        Self {
273            request_id: String::new(),
274            answer: String::new(),
275            cancelled: true,
276        }
277    }
278
279    /// Sets the correlation id that pairs this response to its question.
280    #[must_use]
281    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
282        self.request_id = request_id.into();
283        self
284    }
285}
286
287/// Tool that allows the agent to ask questions to the user.
288///
289/// This is essential for:
290/// - Clarifying ambiguous requirements
291/// - Offering choices between approaches
292/// - Getting user preferences
293/// - Confirming before major changes
294pub struct AskUserQuestionTool {
295    /// Channel to send questions to the UI.
296    question_tx: mpsc::Sender<QuestionRequest>,
297
298    /// Channel to receive answers from the UI.
299    question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
300}
301
302impl AskUserQuestionTool {
303    /// Creates a new tool with the given channels.
304    #[must_use]
305    pub fn new(
306        question_tx: mpsc::Sender<QuestionRequest>,
307        question_rx: mpsc::Receiver<QuestionResponse>,
308    ) -> Self {
309        Self {
310            question_tx,
311            question_rx: tokio::sync::Mutex::new(question_rx),
312        }
313    }
314
315    /// Creates a new tool with fresh channels.
316    ///
317    /// Returns `(tool, request_receiver, response_sender)` where:
318    /// - `tool` is the `AskUserQuestionTool` instance
319    /// - `request_receiver` receives questions from the agent
320    /// - `response_sender` sends user answers back to the agent
321    #[must_use]
322    pub fn with_channels(
323        buffer_size: usize,
324    ) -> (
325        Self,
326        mpsc::Receiver<QuestionRequest>,
327        mpsc::Sender<QuestionResponse>,
328    ) {
329        let (request_tx, request_rx) = mpsc::channel(buffer_size);
330        let (response_tx, response_rx) = mpsc::channel(buffer_size);
331
332        let tool = Self::new(request_tx, response_rx);
333        (tool, request_rx, response_tx)
334    }
335}
336
337/// Input schema for the `AskUserQuestion` tool.
338#[derive(Debug, Deserialize, Serialize)]
339struct AskUserInput {
340    /// The question to ask the user.
341    question: String,
342
343    /// Optional header/category for the question.
344    #[serde(default)]
345    header: Option<String>,
346
347    /// Optional list of choices for multiple-choice questions.
348    #[serde(default)]
349    options: Vec<OptionInput>,
350
351    /// Whether multiple options can be selected.
352    #[serde(default)]
353    multi_select: bool,
354}
355
356/// Input for a single option.
357#[derive(Debug, Deserialize, Serialize)]
358struct OptionInput {
359    /// The display label.
360    label: String,
361
362    /// Optional description.
363    #[serde(default)]
364    description: Option<String>,
365}
366
367impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
368    type Name = PrimitiveToolName;
369
370    fn name(&self) -> PrimitiveToolName {
371        PrimitiveToolName::AskUser
372    }
373
374    fn display_name(&self) -> &'static str {
375        "Ask User"
376    }
377
378    fn description(&self) -> &'static str {
379        "Ask the user a question to get clarification, preferences, or choices. \
380         Use this when you need user input before proceeding. For yes/no confirmations \
381         of dangerous operations, tool confirmation will be shown automatically - \
382         use this tool for open-ended questions or when offering choices."
383    }
384
385    fn input_schema(&self) -> Value {
386        json!({
387            "type": "object",
388            "required": ["question"],
389            "properties": {
390                "question": {
391                    "type": "string",
392                    "description": "The question to ask the user. Be clear and specific."
393                },
394                "header": {
395                    "type": "string",
396                    "description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
397                },
398                "options": {
399                    "type": "array",
400                    "description": "Optional list of choices for multiple-choice questions",
401                    "items": {
402                        "type": "object",
403                        "required": ["label"],
404                        "properties": {
405                            "label": {
406                                "type": "string",
407                                "description": "The option text to display"
408                            },
409                            "description": {
410                                "type": "string",
411                                "description": "Optional explanation of this option"
412                            }
413                        }
414                    }
415                },
416                "multi_select": {
417                    "type": "boolean",
418                    "description": "Whether multiple options can be selected (default: false)"
419                }
420            }
421        })
422    }
423
424    fn tier(&self) -> ToolTier {
425        // Questions don't modify anything, but they do require user interaction
426        ToolTier::Observe
427    }
428
429    async fn execute(&self, ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
430        // Parse input
431        let input: AskUserInput =
432            serde_json::from_value(input).context("Invalid input for ask_user tool")?;
433
434        // Build request with a fresh correlation id so its answer can be paired
435        // back even if other questions are in flight or this one is cancelled.
436        let request_id = next_request_id();
437        let request = QuestionRequest {
438            request_id: request_id.clone(),
439            question: input.question.clone(),
440            header: input.header,
441            options: input
442                .options
443                .into_iter()
444                .map(|o| QuestionOption {
445                    label: o.label,
446                    description: o.description,
447                })
448                .collect(),
449            multi_select: input.multi_select,
450        };
451
452        // A fresh token never fires, so questions without a configured cancel
453        // token simply wait for the answer.
454        let cancel_token = ctx.cancel_token().unwrap_or_default();
455
456        // Hold the receiver lock across send+recv so concurrent `ask_user`
457        // calls are serialized: only one question is outstanding at a time and
458        // its answer cannot be consumed by a sibling call.
459        let response = {
460            let mut rx = self.question_rx.lock().await;
461
462            self.question_tx
463                .send(request)
464                .await
465                .context("Failed to send question to UI - channel closed")?;
466
467            await_matching_response(&mut rx, &request_id, &cancel_token).await?
468        };
469
470        match response {
471            Some(response) if response.cancelled => Ok(ToolResult::error(
472                "User cancelled the question without providing an answer.",
473            )),
474            Some(response) => Ok(ToolResult::success(format!(
475                "User answered: {}",
476                response.answer
477            ))),
478            None => Ok(ToolResult::error(
479                "Question cancelled before the user answered.",
480            )),
481        }
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    /// A freshly built confirmation request keeps the tool name and has no
    /// context.
    #[test]
    fn test_confirmation_request_new() {
        let request = ConfirmationRequest::new(
            "write",
            "Write to file: foo.txt",
            "{}",
            ToolTier::Confirm,
        );

        assert_eq!(request.tool_name, "write");
        assert!(request.context.is_none());
    }

    /// `with_context` stores the supplied context text.
    #[test]
    fn test_confirmation_request_with_context() {
        let request = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
            .with_context("Agent was fixing a bug");

        assert!(request.context.is_some());
        assert_eq!(request.context.as_deref(), Some("Agent was fixing a bug"));
    }

    /// Every response variant serializes to its snake_case name.
    #[test]
    fn test_confirmation_response_serialization() {
        let expectations = [
            (ConfirmationResponse::Approved, "\"approved\""),
            (ConfirmationResponse::Denied, "\"denied\""),
            (ConfirmationResponse::ApproveAll, "\"approve_all\""),
        ];

        for (variant, wire) in expectations {
            assert_eq!(serde_json::to_string(&variant).unwrap(), wire);
        }
    }

    /// A free-form question has no options and is single-select.
    #[test]
    fn test_question_request_new() {
        let request = QuestionRequest::new("What color?");

        assert_eq!(request.question, "What color?");
        assert!(request.options.is_empty());
        assert!(!request.multi_select);
    }

    /// The builder methods set options, header and multi-select.
    #[test]
    fn test_question_request_with_options() {
        let choices = vec![
            QuestionOption::new("React"),
            QuestionOption::with_description("Vue", "Progressive framework"),
        ];
        let request = QuestionRequest::with_options("Which framework?", choices)
            .with_header("Framework")
            .with_multi_select();

        assert_eq!(request.options.len(), 2);
        assert!(request.multi_select);
        assert_eq!(request.header.unwrap(), "Framework");
    }

    /// `success` and `cancelled` set the `cancelled` flag accordingly.
    #[test]
    fn test_question_response() {
        let answered = QuestionResponse::success("Blue");
        assert!(!answered.cancelled);
        assert_eq!(answered.answer, "Blue");

        let skipped = QuestionResponse::cancelled();
        assert!(skipped.cancelled);
    }

    /// The tool reports the expected name and tier.
    #[tokio::test]
    async fn test_ask_user_tool_creation() {
        let (tool, _requests, _responses) = AskUserQuestionTool::with_channels(10);

        // The trait is generic over the context type, so pin it to `()`.
        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::AskUser);
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    /// A plain answer from the UI comes back as a successful tool result.
    #[tokio::test]
    async fn test_ask_user_tool_execute() {
        let (tool, mut questions, answers) = AskUserQuestionTool::with_channels(10);

        // Stand-in UI: answer the first question that arrives.
        let ui = tokio::spawn(async move {
            let Some(request) = questions.recv().await else {
                return;
            };
            assert_eq!(request.question, "What color?");
            answers
                .send(QuestionResponse::success("Blue"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let input = json!({
            "question": "What color?"
        });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Blue"));
    }

    /// Options given in the tool input reach the UI unchanged.
    #[tokio::test]
    async fn test_ask_user_with_options() {
        let (tool, mut questions, answers) = AskUserQuestionTool::with_channels(10);

        let ui = tokio::spawn(async move {
            let Some(request) = questions.recv().await else {
                return;
            };
            assert_eq!(request.options.len(), 2);
            assert_eq!(request.options[0].label, "Option A");
            answers
                .send(QuestionResponse::success("Option A"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let input = json!({
            "question": "Which option?",
            "options": [
                {"label": "Option A", "description": "First choice"},
                {"label": "Option B", "description": "Second choice"}
            ]
        });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();
        assert!(result.success);
    }

    /// A cancelled response from the UI becomes an error result.
    #[tokio::test]
    async fn test_ask_user_cancelled() {
        let (tool, mut questions, answers) = AskUserQuestionTool::with_channels(10);

        let ui = tokio::spawn(async move {
            if questions.recv().await.is_none() {
                return;
            }
            answers.send(QuestionResponse::cancelled()).await.unwrap();
        });

        let ctx = ToolContext::new(());
        let input = json!({
            "question": "Continue?"
        });
        let result = tool.execute(&ctx, input).await.unwrap();

        ui.await.unwrap();
        assert!(!result.success);
        assert!(result.output.contains("cancelled"));
    }

    /// A queued answer with a foreign correlation id is skipped in favour of
    /// the one that echoes the live id.
    #[tokio::test]
    async fn test_ask_user_discards_stale_response() -> Result<()> {
        let (tool, mut questions, answers) = AskUserQuestionTool::with_channels(10);

        // Before the next question is even asked, a late reply to an earlier,
        // cancelled question is already waiting in the channel.
        let stale = QuestionResponse::success("STALE").with_request_id("stale-request");
        answers
            .send(stale)
            .await
            .ok()
            .context("seed stale response")?;

        let live_answers = answers.clone();
        let ui = tokio::spawn(async move {
            let request = questions.recv().await.context("no question received")?;
            // Reply with the id taken from the request, so the tool accepts it.
            let reply = QuestionResponse::success("CORRECT").with_request_id(request.request_id);
            live_answers
                .send(reply)
                .await
                .ok()
                .context("send live response")?;
            anyhow::Ok(())
        });

        let ctx = ToolContext::new(());
        let input = json!({ "question": "Which one?" });
        let result = tool.execute(&ctx, input).await?;

        ui.await.context("responder task panicked")??;

        assert!(result.success);
        assert!(result.output.contains("CORRECT"), "got: {}", result.output);
        assert!(
            !result.output.contains("STALE"),
            "stale answer must be discarded: {}",
            result.output
        );
        Ok(())
    }

    /// With the cancel token already fired, `execute` returns an error result
    /// instead of waiting for an answer.
    #[tokio::test]
    async fn test_ask_user_returns_on_cancel() -> Result<()> {
        // Both UI-side endpoints stay alive for the whole test, so neither
        // channel is ever closed; nothing ever answers the question.
        let (tool, _requests, _responses) = AskUserQuestionTool::with_channels(10);

        let token = CancellationToken::new();
        token.cancel();
        let ctx = ToolContext::new(()).with_cancel_token(token);

        let input = json!({ "question": "Hang forever?" });
        let result = tool.execute(&ctx, input).await?;

        assert!(!result.success);
        assert!(
            result.output.to_lowercase().contains("cancel"),
            "got: {}",
            result.output
        );
        Ok(())
    }
}
724}