1use crate::{PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use serde::{Deserialize, Serialize};
43use serde_json::{Value, json};
44use tokio::sync::mpsc;
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ConfirmationRequest {
49 pub tool_name: String,
51
52 pub description: String,
54
55 pub input_preview: String,
57
58 pub tier: String,
60
61 pub context: Option<String>,
63}
64
65impl ConfirmationRequest {
66 #[must_use]
68 pub fn new(
69 tool_name: impl Into<String>,
70 description: impl Into<String>,
71 input_preview: impl Into<String>,
72 tier: ToolTier,
73 ) -> Self {
74 Self {
75 tool_name: tool_name.into(),
76 description: description.into(),
77 input_preview: input_preview.into(),
78 tier: format!("{tier:?}"),
79 context: None,
80 }
81 }
82
83 #[must_use]
85 pub fn with_context(mut self, context: impl Into<String>) -> Self {
86 self.context = Some(context.into());
87 self
88 }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
93#[serde(rename_all = "snake_case")]
94pub enum ConfirmationResponse {
95 Approved,
97
98 Denied,
100
101 ApproveAll,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct QuestionRequest {
108 pub question: String,
110
111 pub header: Option<String>,
113
114 pub options: Vec<QuestionOption>,
117
118 pub multi_select: bool,
120}
121
122impl QuestionRequest {
123 #[must_use]
125 pub fn new(question: impl Into<String>) -> Self {
126 Self {
127 question: question.into(),
128 header: None,
129 options: Vec::new(),
130 multi_select: false,
131 }
132 }
133
134 #[must_use]
136 pub fn with_options(question: impl Into<String>, options: Vec<QuestionOption>) -> Self {
137 Self {
138 question: question.into(),
139 header: None,
140 options,
141 multi_select: false,
142 }
143 }
144
145 #[must_use]
147 pub fn with_header(mut self, header: impl Into<String>) -> Self {
148 self.header = Some(header.into());
149 self
150 }
151
152 #[must_use]
154 pub const fn with_multi_select(mut self) -> Self {
155 self.multi_select = true;
156 self
157 }
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct QuestionOption {
163 pub label: String,
165
166 pub description: Option<String>,
168}
169
170impl QuestionOption {
171 #[must_use]
173 pub fn new(label: impl Into<String>) -> Self {
174 Self {
175 label: label.into(),
176 description: None,
177 }
178 }
179
180 #[must_use]
182 pub fn with_description(label: impl Into<String>, description: impl Into<String>) -> Self {
183 Self {
184 label: label.into(),
185 description: Some(description.into()),
186 }
187 }
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct QuestionResponse {
193 pub answer: String,
195
196 pub cancelled: bool,
198}
199
200impl QuestionResponse {
201 #[must_use]
203 pub fn success(answer: impl Into<String>) -> Self {
204 Self {
205 answer: answer.into(),
206 cancelled: false,
207 }
208 }
209
210 #[must_use]
212 pub const fn cancelled() -> Self {
213 Self {
214 answer: String::new(),
215 cancelled: true,
216 }
217 }
218}
219
220pub struct AskUserQuestionTool {
228 question_tx: mpsc::Sender<QuestionRequest>,
230
231 question_rx: tokio::sync::Mutex<mpsc::Receiver<QuestionResponse>>,
233}
234
235impl AskUserQuestionTool {
236 #[must_use]
238 pub fn new(
239 question_tx: mpsc::Sender<QuestionRequest>,
240 question_rx: mpsc::Receiver<QuestionResponse>,
241 ) -> Self {
242 Self {
243 question_tx,
244 question_rx: tokio::sync::Mutex::new(question_rx),
245 }
246 }
247
248 #[must_use]
255 pub fn with_channels(
256 buffer_size: usize,
257 ) -> (
258 Self,
259 mpsc::Receiver<QuestionRequest>,
260 mpsc::Sender<QuestionResponse>,
261 ) {
262 let (request_tx, request_rx) = mpsc::channel(buffer_size);
263 let (response_tx, response_rx) = mpsc::channel(buffer_size);
264
265 let tool = Self::new(request_tx, response_rx);
266 (tool, request_rx, response_tx)
267 }
268}
269
270#[derive(Debug, Deserialize, Serialize)]
272struct AskUserInput {
273 question: String,
275
276 #[serde(default)]
278 header: Option<String>,
279
280 #[serde(default)]
282 options: Vec<OptionInput>,
283
284 #[serde(default)]
286 multi_select: bool,
287}
288
289#[derive(Debug, Deserialize, Serialize)]
291struct OptionInput {
292 label: String,
294
295 #[serde(default)]
297 description: Option<String>,
298}
299
300impl<Ctx: Send + Sync + 'static> Tool<Ctx> for AskUserQuestionTool {
301 type Name = PrimitiveToolName;
302
303 fn name(&self) -> PrimitiveToolName {
304 PrimitiveToolName::AskUser
305 }
306
307 fn display_name(&self) -> &'static str {
308 "Ask User"
309 }
310
311 fn description(&self) -> &'static str {
312 "Ask the user a question to get clarification, preferences, or choices. \
313 Use this when you need user input before proceeding. For yes/no confirmations \
314 of dangerous operations, tool confirmation will be shown automatically - \
315 use this tool for open-ended questions or when offering choices."
316 }
317
318 fn input_schema(&self) -> Value {
319 json!({
320 "type": "object",
321 "required": ["question"],
322 "properties": {
323 "question": {
324 "type": "string",
325 "description": "The question to ask the user. Be clear and specific."
326 },
327 "header": {
328 "type": "string",
329 "description": "Optional short header/category (e.g., 'Auth method', 'Library choice')"
330 },
331 "options": {
332 "type": "array",
333 "description": "Optional list of choices for multiple-choice questions",
334 "items": {
335 "type": "object",
336 "required": ["label"],
337 "properties": {
338 "label": {
339 "type": "string",
340 "description": "The option text to display"
341 },
342 "description": {
343 "type": "string",
344 "description": "Optional explanation of this option"
345 }
346 }
347 }
348 },
349 "multi_select": {
350 "type": "boolean",
351 "description": "Whether multiple options can be selected (default: false)"
352 }
353 }
354 })
355 }
356
357 fn tier(&self) -> ToolTier {
358 ToolTier::Observe
360 }
361
362 async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
363 let input: AskUserInput =
365 serde_json::from_value(input).context("Invalid input for ask_user tool")?;
366
367 let request = QuestionRequest {
369 question: input.question.clone(),
370 header: input.header,
371 options: input
372 .options
373 .into_iter()
374 .map(|o| QuestionOption {
375 label: o.label,
376 description: o.description,
377 })
378 .collect(),
379 multi_select: input.multi_select,
380 };
381
382 self.question_tx
384 .send(request)
385 .await
386 .context("Failed to send question to UI - channel closed")?;
387
388 let response = {
390 let mut rx = self.question_rx.lock().await;
391 rx.recv()
392 .await
393 .context("Failed to receive answer from UI - channel closed")?
394 };
395
396 if response.cancelled {
397 Ok(ToolResult::error(
398 "User cancelled the question without providing an answer.",
399 ))
400 } else {
401 Ok(ToolResult::success(format!(
402 "User answered: {}",
403 response.answer
404 )))
405 }
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tool;

    #[test]
    fn test_confirmation_request_new() {
        let req =
            ConfirmationRequest::new("write", "Write to file: foo.txt", "{}", ToolTier::Confirm);
        assert_eq!(req.tool_name, "write");
        assert!(req.context.is_none());
    }

    #[test]
    fn test_confirmation_request_with_context() {
        let req = ConfirmationRequest::new("write", "Write to file", "{}", ToolTier::Confirm)
            .with_context("Agent was fixing a bug");
        assert_eq!(req.context.as_deref(), Some("Agent was fixing a bug"));
    }

    #[test]
    fn test_confirmation_response_serialization() {
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::Approved).unwrap(),
            "\"approved\""
        );
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::Denied).unwrap(),
            "\"denied\""
        );
        assert_eq!(
            serde_json::to_string(&ConfirmationResponse::ApproveAll).unwrap(),
            "\"approve_all\""
        );
    }

    #[test]
    fn test_question_request_new() {
        let req = QuestionRequest::new("What color?");
        assert_eq!(req.question, "What color?");
        assert!(req.options.is_empty());
        assert!(!req.multi_select);
    }

    #[test]
    fn test_question_request_with_options() {
        let req = QuestionRequest::with_options(
            "Which framework?",
            vec![
                QuestionOption::new("React"),
                QuestionOption::with_description("Vue", "Progressive framework"),
            ],
        )
        .with_header("Framework")
        .with_multi_select();

        assert_eq!(req.options.len(), 2);
        assert!(req.multi_select);
        assert_eq!(req.header.as_deref(), Some("Framework"));
    }

    #[test]
    fn test_question_response() {
        let success = QuestionResponse::success("Blue");
        assert!(!success.cancelled);
        assert_eq!(success.answer, "Blue");

        let cancelled = QuestionResponse::cancelled();
        assert!(cancelled.cancelled);
        assert!(cancelled.answer.is_empty());
    }

    #[tokio::test]
    async fn test_ask_user_tool_creation() {
        let (tool, _rx, _tx) = AskUserQuestionTool::with_channels(10);

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::AskUser);
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Observe);
    }

    #[tokio::test]
    async fn test_ask_user_tool_execute() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            if let Some(request) = request_rx.recv().await {
                assert_eq!(request.question, "What color?");
                response_tx
                    .send(QuestionResponse::success("Blue"))
                    .await
                    .unwrap();
            }
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "What color?"
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Blue"));
    }

    #[tokio::test]
    async fn test_ask_user_with_options() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            if let Some(request) = request_rx.recv().await {
                assert_eq!(request.options.len(), 2);
                assert_eq!(request.options[0].label, "Option A");
                response_tx
                    .send(QuestionResponse::success("Option A"))
                    .await
                    .unwrap();
            }
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Which option?",
                    "options": [
                        {"label": "Option A", "description": "First choice"},
                        {"label": "Option B", "description": "Second choice"}
                    ]
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_ask_user_forwards_header_and_multi_select() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            let request = request_rx.recv().await.expect("question should be sent");
            assert_eq!(request.header.as_deref(), Some("Library choice"));
            assert!(request.multi_select);
            assert_eq!(request.options.len(), 2);
            assert_eq!(request.options[1].description.as_deref(), Some("Second"));
            response_tx
                .send(QuestionResponse::success("A, B"))
                .await
                .unwrap();
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Which libraries?",
                    "header": "Library choice",
                    "multi_select": true,
                    "options": [
                        {"label": "A", "description": "First"},
                        {"label": "B", "description": "Second"}
                    ]
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("A, B"));
    }

    #[tokio::test]
    async fn test_ask_user_cancelled() {
        let (tool, mut request_rx, response_tx) = AskUserQuestionTool::with_channels(10);

        let handle = tokio::spawn(async move {
            if request_rx.recv().await.is_some() {
                response_tx
                    .send(QuestionResponse::cancelled())
                    .await
                    .unwrap();
            }
        });

        let ctx = ToolContext::new(());
        let result = tool
            .execute(
                &ctx,
                json!({
                    "question": "Continue?"
                }),
            )
            .await
            .unwrap();

        handle.await.unwrap();
        assert!(!result.success);
        assert!(result.output.contains("cancelled"));
    }

    #[tokio::test]
    async fn test_ask_user_invalid_input_is_error() {
        let (tool, _request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);

        let ctx = ToolContext::new(());
        // `question` is required by the input schema.
        let result = tool
            .execute(&ctx, json!({ "header": "No question here" }))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ask_user_request_channel_closed_is_error() {
        let (tool, request_rx, _response_tx) = AskUserQuestionTool::with_channels(10);
        drop(request_rx);

        let ctx = ToolContext::new(());
        let result = tool
            .execute(&ctx, json!({ "question": "Anyone there?" }))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ask_user_response_channel_closed_is_error() {
        let (tool, _request_rx, response_tx) = AskUserQuestionTool::with_channels(10);
        drop(response_tx);

        let ctx = ToolContext::new(());
        let result = tool
            .execute(&ctx, json!({ "question": "Anyone there?" }))
            .await;

        assert!(result.is_err());
    }
}