Skip to main content

heartbit_core/tool/builtins/
question.rs

1#![allow(missing_docs)]
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8
9use crate::error::Error;
10use crate::llm::types::ToolDefinition;
11use crate::tool::{Tool, ToolOutput};
12
13// --- Types ---
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct QuestionRequest {
17    pub questions: Vec<Question>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Question {
22    pub question: String,
23    pub header: String,
24    pub options: Vec<QuestionOption>,
25    pub multiple: bool,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct QuestionOption {
30    pub label: String,
31    pub description: String,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct QuestionResponse {
36    /// Per-question list of selected labels.
37    pub answers: Vec<Vec<String>>,
38}
39
40/// Callback type for agent-to-user structured questions.
41pub type OnQuestion = dyn Fn(QuestionRequest) -> Pin<Box<dyn Future<Output = Result<QuestionResponse, Error>> + Send>>
42    + Send
43    + Sync;
44
45// --- Tool ---
46
47/// Builtin tool that pauses the agent to ask the user structured questions.
48///
49/// When the agent needs clarification before proceeding it can call this tool
50/// with a list of questions, each with a header and a set of labelled options.
51/// Execution is suspended until the registered `OnQuestion` callback returns
52/// the user's answers, enabling interactive human-in-the-loop flows within an
53/// otherwise autonomous run. Each question may allow single or multiple
54/// selections via the `multiple` flag.
55pub struct QuestionTool {
56    on_question: Arc<OnQuestion>,
57}
58
59impl QuestionTool {
60    pub fn new(on_question: Arc<OnQuestion>) -> Self {
61        Self { on_question }
62    }
63}
64
65impl Tool for QuestionTool {
66    fn definition(&self) -> ToolDefinition {
67        ToolDefinition {
68            name: "question".into(),
69            description: "Ask the user structured questions with predefined options. \
70                          Use this when you need clarification or a decision from the user."
71                .into(),
72            input_schema: json!({
73                "type": "object",
74                "properties": {
75                    "questions": {
76                        "type": "array",
77                        "items": {
78                            "type": "object",
79                            "properties": {
80                                "question": {
81                                    "type": "string",
82                                    "description": "The question to ask"
83                                },
84                                "header": {
85                                    "type": "string",
86                                    "description": "Short label (max 12 chars)"
87                                },
88                                "options": {
89                                    "type": "array",
90                                    "minItems": 2,
91                                    "items": {
92                                        "type": "object",
93                                        "properties": {
94                                            "label": {"type": "string"},
95                                            "description": {"type": "string"}
96                                        },
97                                        "required": ["label", "description"]
98                                    }
99                                },
100                                "multiple": {
101                                    "type": "boolean",
102                                    "description": "Allow multiple selections"
103                                }
104                            },
105                            "required": ["question", "header", "options", "multiple"]
106                        }
107                    }
108                },
109                "required": ["questions"]
110            }),
111        }
112    }
113
114    fn execute(
115        &self,
116        _ctx: &crate::ExecutionContext,
117        input: serde_json::Value,
118    ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
119        Box::pin(async move {
120            let questions_value = input
121                .get("questions")
122                .ok_or_else(|| Error::Agent("questions is required".into()))?;
123
124            let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
125                .map_err(|e| Error::Agent(format!("Invalid questions format: {e}")))?;
126
127            if questions.is_empty() {
128                return Ok(ToolOutput::error("At least one question is required."));
129            }
130            for q in &questions {
131                if q.options.len() < 2 {
132                    return Ok(ToolOutput::error(format!(
133                        "Question '{}' must have at least 2 options.",
134                        q.header
135                    )));
136                }
137            }
138
139            let request = QuestionRequest {
140                questions: questions.clone(),
141            };
142            let response = match (self.on_question)(request).await {
143                Ok(r) => r,
144                Err(e) => return Ok(ToolOutput::error(format!("Question failed: {e}"))),
145            };
146
147            if response.answers.len() != questions.len() {
148                return Ok(ToolOutput::error(format!(
149                    "Expected {} answers but got {}",
150                    questions.len(),
151                    response.answers.len()
152                )));
153            }
154
155            // Format answers
156            let mut output = String::new();
157            for (i, q) in questions.iter().enumerate() {
158                let answers = &response.answers[i];
159                output.push_str(&format!("{}: {}\n", q.question, answers.join(", ")));
160            }
161
162            Ok(ToolOutput::success(output))
163        })
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn definition_has_correct_name() {
        let callback: Arc<OnQuestion> = Arc::new(|_| {
            Box::pin(async {
                Ok(QuestionResponse {
                    answers: vec![vec!["A".into()]],
                })
            })
        });
        let tool = QuestionTool::new(callback);
        assert_eq!(tool.definition().name, "question");
    }

    #[tokio::test]
    async fn question_tool_asks_and_returns() {
        let callback: Arc<OnQuestion> = Arc::new(|req| {
            Box::pin(async move {
                let mut answers = Vec::new();
                for q in &req.questions {
                    answers.push(vec![q.options[0].label.clone()]);
                }
                Ok(QuestionResponse { answers })
            })
        });

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Which color?",
                        "header": "Color",
                        "options": [
                            {"label": "Red", "description": "A warm color"},
                            {"label": "Blue", "description": "A cool color"}
                        ],
                        "multiple": false
                    }]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Red"));
    }

    #[tokio::test]
    async fn question_tool_formats_each_answer_on_its_own_line() {
        let callback: Arc<OnQuestion> = Arc::new(|_| {
            Box::pin(async {
                Ok(QuestionResponse {
                    answers: vec![vec!["Blue".into()], vec!["S".into(), "M".into()]],
                })
            })
        });

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [
                        {
                            "question": "Which color?",
                            "header": "Color",
                            "options": [
                                {"label": "Red", "description": "A warm color"},
                                {"label": "Blue", "description": "A cool color"}
                            ],
                            "multiple": false
                        },
                        {
                            "question": "Which sizes?",
                            "header": "Sizes",
                            "options": [
                                {"label": "S", "description": "Small"},
                                {"label": "M", "description": "Medium"},
                                {"label": "L", "description": "Large"}
                            ],
                            "multiple": true
                        }
                    ]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Which color?: Blue\n"), "got: {}", result.content);
        assert!(result.content.contains("Which sizes?: S, M\n"), "got: {}", result.content);
        // Answers are reported in request order.
        let color_at = result.content.find("Which color?").unwrap();
        let sizes_at = result.content.find("Which sizes?").unwrap();
        assert!(color_at < sizes_at);
    }

    #[tokio::test]
    async fn question_tool_missing_questions_is_err() {
        let callback: Arc<OnQuestion> =
            Arc::new(|_| Box::pin(async { Ok(QuestionResponse { answers: vec![] }) }));

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await;
        assert!(matches!(result, Err(Error::Agent(_))));
    }

    #[tokio::test]
    async fn question_tool_malformed_questions_is_err() {
        let callback: Arc<OnQuestion> =
            Arc::new(|_| Box::pin(async { Ok(QuestionResponse { answers: vec![] }) }));

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"questions": "not an array"}),
            )
            .await;
        assert!(matches!(result, Err(Error::Agent(_))));
    }

    #[tokio::test]
    async fn question_tool_empty_questions() {
        let callback: Arc<OnQuestion> =
            Arc::new(|_| Box::pin(async { Ok(QuestionResponse { answers: vec![] }) }));

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"questions": []}),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("At least one question"));
    }

    #[tokio::test]
    async fn question_with_too_few_options_rejected() {
        let callback: Arc<OnQuestion> =
            Arc::new(|_| Box::pin(async { Ok(QuestionResponse { answers: vec![] }) }));

        let tool = QuestionTool::new(callback);

        // Zero options
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Pick one",
                        "header": "Choice",
                        "options": [],
                        "multiple": false
                    }]
                }),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("at least 2 options"));

        // One option (also rejected)
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Pick one",
                        "header": "Choice",
                        "options": [{"label": "Only", "description": "Single option"}],
                        "multiple": false
                    }]
                }),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("at least 2 options"));
    }

    #[tokio::test]
    async fn question_tool_rejects_mismatched_answer_count() {
        // Callback returns 2 answers but only 1 question asked
        let callback: Arc<OnQuestion> = Arc::new(|_| {
            Box::pin(async {
                Ok(QuestionResponse {
                    answers: vec![vec!["A".into()], vec!["B".into()]],
                })
            })
        });

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Pick one",
                        "header": "Choice",
                        "options": [
                            {"label": "A", "description": "Option A"},
                            {"label": "B", "description": "Option B"}
                        ],
                        "multiple": false
                    }]
                }),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(
            result.content.contains("Expected 1 answers but got 2"),
            "got: {}",
            result.content
        );
    }

    #[tokio::test]
    async fn question_tool_callback_error_returns_tool_error() {
        let callback: Arc<OnQuestion> =
            Arc::new(|_| Box::pin(async { Err(Error::Agent("User cancelled".into())) }));

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Pick one",
                        "header": "Choice",
                        "options": [
                            {"label": "A", "description": "Option A"},
                            {"label": "B", "description": "Option B"}
                        ],
                        "multiple": false
                    }]
                }),
            )
            .await
            .unwrap(); // Should not propagate error
        assert!(result.is_error);
        assert!(result.content.contains("User cancelled"));
    }
}