heartbit_core/tool/builtins/
question.rs1#![allow(missing_docs)]
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8
9use crate::error::Error;
10use crate::llm::types::ToolDefinition;
11use crate::tool::{Tool, ToolOutput};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct QuestionRequest {
17 pub questions: Vec<Question>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Question {
22 pub question: String,
23 pub header: String,
24 pub options: Vec<QuestionOption>,
25 pub multiple: bool,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct QuestionOption {
30 pub label: String,
31 pub description: String,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct QuestionResponse {
36 pub answers: Vec<Vec<String>>,
38}
39
40pub type OnQuestion = dyn Fn(QuestionRequest) -> Pin<Box<dyn Future<Output = Result<QuestionResponse, Error>> + Send>>
42 + Send
43 + Sync;
44
45pub struct QuestionTool {
56 on_question: Arc<OnQuestion>,
57}
58
59impl QuestionTool {
60 pub fn new(on_question: Arc<OnQuestion>) -> Self {
61 Self { on_question }
62 }
63}
64
65impl Tool for QuestionTool {
66 fn definition(&self) -> ToolDefinition {
67 ToolDefinition {
68 name: "question".into(),
69 description: "Ask the user structured questions with predefined options. \
70 Use this when you need clarification or a decision from the user."
71 .into(),
72 input_schema: json!({
73 "type": "object",
74 "properties": {
75 "questions": {
76 "type": "array",
77 "items": {
78 "type": "object",
79 "properties": {
80 "question": {
81 "type": "string",
82 "description": "The question to ask"
83 },
84 "header": {
85 "type": "string",
86 "description": "Short label (max 12 chars)"
87 },
88 "options": {
89 "type": "array",
90 "minItems": 2,
91 "items": {
92 "type": "object",
93 "properties": {
94 "label": {"type": "string"},
95 "description": {"type": "string"}
96 },
97 "required": ["label", "description"]
98 }
99 },
100 "multiple": {
101 "type": "boolean",
102 "description": "Allow multiple selections"
103 }
104 },
105 "required": ["question", "header", "options", "multiple"]
106 }
107 }
108 },
109 "required": ["questions"]
110 }),
111 }
112 }
113
114 fn execute(
115 &self,
116 _ctx: &crate::ExecutionContext,
117 input: serde_json::Value,
118 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
119 Box::pin(async move {
120 let questions_value = input
121 .get("questions")
122 .ok_or_else(|| Error::Agent("questions is required".into()))?;
123
124 let questions: Vec<Question> = serde_json::from_value(questions_value.clone())
125 .map_err(|e| Error::Agent(format!("Invalid questions format: {e}")))?;
126
127 if questions.is_empty() {
128 return Ok(ToolOutput::error("At least one question is required."));
129 }
130 for q in &questions {
131 if q.options.len() < 2 {
132 return Ok(ToolOutput::error(format!(
133 "Question '{}' must have at least 2 options.",
134 q.header
135 )));
136 }
137 }
138
139 let request = QuestionRequest {
140 questions: questions.clone(),
141 };
142 let response = match (self.on_question)(request).await {
143 Ok(r) => r,
144 Err(e) => return Ok(ToolOutput::error(format!("Question failed: {e}"))),
145 };
146
147 if response.answers.len() != questions.len() {
148 return Ok(ToolOutput::error(format!(
149 "Expected {} answers but got {}",
150 questions.len(),
151 response.answers.len()
152 )));
153 }
154
155 let mut output = String::new();
157 for (i, q) in questions.iter().enumerate() {
158 let answers = &response.answers[i];
159 output.push_str(&format!("{}: {}\n", q.question, answers.join(", ")));
160 }
161
162 Ok(ToolOutput::success(output))
163 })
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Callback that returns no answers; used where validation must fail before it matters.
    fn empty_callback() -> Arc<OnQuestion> {
        Arc::new(|_| Box::pin(async { Ok(QuestionResponse { answers: vec![] }) }))
    }

    /// A single well-formed, single-choice question with options "A" and "B".
    fn pick_one_input() -> serde_json::Value {
        json!({
            "questions": [{
                "question": "Pick one",
                "header": "Choice",
                "options": [
                    {"label": "A", "description": "Option A"},
                    {"label": "B", "description": "Option B"}
                ],
                "multiple": false
            }]
        })
    }

    #[test]
    fn definition_has_correct_name() {
        let tool = QuestionTool::new(empty_callback());
        let def = tool.definition();
        assert_eq!(def.name, "question");
        assert_eq!(def.input_schema["required"], json!(["questions"]));
    }

    #[tokio::test]
    async fn question_tool_asks_and_returns() {
        // Answer every question with its first option.
        let callback: Arc<OnQuestion> = Arc::new(|req| {
            Box::pin(async move {
                let answers = req
                    .questions
                    .iter()
                    .map(|q| vec![q.options[0].label.clone()])
                    .collect();
                Ok(QuestionResponse { answers })
            })
        });

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Which color?",
                        "header": "Color",
                        "options": [
                            {"label": "Red", "description": "A warm color"},
                            {"label": "Blue", "description": "A cool color"}
                        ],
                        "multiple": false
                    }]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(result.content.contains("Red"));
    }

    #[tokio::test]
    async fn question_tool_joins_multiple_selections() {
        let callback: Arc<OnQuestion> = Arc::new(|_| {
            Box::pin(async {
                Ok(QuestionResponse {
                    answers: vec![vec!["Red".into(), "Blue".into()]],
                })
            })
        });

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({
                    "questions": [{
                        "question": "Which colors?",
                        "header": "Colors",
                        "options": [
                            {"label": "Red", "description": "A warm color"},
                            {"label": "Blue", "description": "A cool color"}
                        ],
                        "multiple": true
                    }]
                }),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert!(
            result.content.contains("Which colors?: Red, Blue"),
            "got: {}",
            result.content
        );
    }

    #[tokio::test]
    async fn question_tool_empty_questions() {
        let tool = QuestionTool::new(empty_callback());
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"questions": []}),
            )
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("At least one question"));
    }

    #[tokio::test]
    async fn question_tool_missing_questions_is_err() {
        let tool = QuestionTool::new(empty_callback());
        let result = tool
            .execute(&crate::ExecutionContext::default(), json!({}))
            .await;
        assert!(matches!(
            result,
            Err(Error::Agent(ref msg)) if msg.contains("questions is required")
        ));
    }

    #[tokio::test]
    async fn question_tool_malformed_questions_is_err() {
        let tool = QuestionTool::new(empty_callback());
        // `header`, `options` and `multiple` are missing.
        let result = tool
            .execute(
                &crate::ExecutionContext::default(),
                json!({"questions": [{"question": "Pick one"}]}),
            )
            .await;
        assert!(matches!(
            result,
            Err(Error::Agent(ref msg)) if msg.contains("Invalid questions format")
        ));
    }

    #[tokio::test]
    async fn question_with_too_few_options_rejected() {
        let tool = QuestionTool::new(empty_callback());
        let ctx = crate::ExecutionContext::default();

        for options in [
            json!([]),
            json!([{"label": "Only", "description": "Single option"}]),
        ] {
            let result = tool
                .execute(
                    &ctx,
                    json!({
                        "questions": [{
                            "question": "Pick one",
                            "header": "Choice",
                            "options": options,
                            "multiple": false
                        }]
                    }),
                )
                .await
                .unwrap();
            assert!(result.is_error);
            assert!(result.content.contains("at least 2 options"));
        }
    }

    #[tokio::test]
    async fn question_tool_rejects_mismatched_answer_count() {
        // Two answers for a single question.
        let callback: Arc<OnQuestion> = Arc::new(|_| {
            Box::pin(async {
                Ok(QuestionResponse {
                    answers: vec![vec!["A".into()], vec!["B".into()]],
                })
            })
        });

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(&crate::ExecutionContext::default(), pick_one_input())
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(
            result.content.contains("Expected 1 answers but got 2"),
            "got: {}",
            result.content
        );
    }

    #[tokio::test]
    async fn question_tool_callback_error_returns_tool_error() {
        let callback: Arc<OnQuestion> =
            Arc::new(|_| Box::pin(async { Err(Error::Agent("User cancelled".into())) }));

        let tool = QuestionTool::new(callback);
        let result = tool
            .execute(&crate::ExecutionContext::default(), pick_one_input())
            .await
            .unwrap();
        assert!(result.is_error);
        assert!(result.content.contains("User cancelled"));
    }
}