nika_engine/runtime/builtin/
prompt.rs1use super::BuiltinTool;
28use crate::error::NikaError;
29use crate::runtime::hitl::{HitlHandler, HitlRequest};
30use serde::{Deserialize, Serialize};
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34
35#[derive(Debug, Clone, Deserialize)]
37pub struct PromptParams {
38 pub message: String,
40 #[serde(default)]
42 pub default: Option<String>,
43}
44
45#[derive(Debug, Clone, Serialize)]
47pub struct PromptResponse {
48 pub response: String,
50 pub default_used: bool,
52}
53
54pub struct PromptTool {
62 headless: bool,
64 handler: Option<Arc<dyn HitlHandler>>,
66}
67
68impl PromptTool {
69 pub fn new_headless() -> Self {
72 Self {
73 headless: true,
74 handler: None,
75 }
76 }
77
78 pub fn new_interactive() -> Self {
81 Self {
82 headless: false,
83 handler: None,
84 }
85 }
86
87 pub fn with_handler(handler: Arc<dyn HitlHandler>) -> Self {
90 Self {
91 headless: false,
92 handler: Some(handler),
93 }
94 }
95}
96
97impl Default for PromptTool {
98 fn default() -> Self {
99 Self::new_headless()
100 }
101}
102
103impl BuiltinTool for PromptTool {
104 fn name(&self) -> &'static str {
105 "prompt"
106 }
107
108 fn description(&self) -> &'static str {
109 "Request user input during workflow execution (HITL)"
110 }
111
112 fn parameters_schema(&self) -> serde_json::Value {
113 serde_json::json!({
115 "type": "object",
116 "properties": {
117 "message": {
118 "type": "string",
119 "description": "Prompt message to display to the user"
120 },
121 "default": {
122 "type": "string",
123 "description": "Default value if no input provided"
124 }
125 },
126 "required": ["message", "default"],
127 "additionalProperties": false
128 })
129 }
130
131 fn call<'a>(
132 &'a self,
133 args: String,
134 ) -> Pin<Box<dyn Future<Output = Result<String, NikaError>> + Send + 'a>> {
135 Box::pin(async move {
136 let params: PromptParams =
138 serde_json::from_str(&args).map_err(|e| NikaError::BuiltinInvalidParams {
139 tool: "nika:prompt".into(),
140 reason: format!("Invalid JSON parameters: {}", e),
141 })?;
142
143 if params.message.is_empty() {
145 return Err(NikaError::BuiltinInvalidParams {
146 tool: "nika:prompt".into(),
147 reason: "Prompt message cannot be empty".into(),
148 });
149 }
150
151 if self.headless {
153 match params.default {
154 Some(default) => {
155 tracing::info!(
156 target: "nika:prompt",
157 message = %params.message,
158 default = %default,
159 "Using default value in headless mode"
160 );
161 let response = PromptResponse {
162 response: default,
163 default_used: true,
164 };
165 return serde_json::to_string(&response).map_err(|e| {
166 NikaError::BuiltinToolError {
167 tool: "nika:prompt".into(),
168 reason: format!("Failed to serialize response: {}", e),
169 }
170 });
171 }
172 None => {
173 return Err(NikaError::BuiltinToolError {
174 tool: "nika:prompt".into(),
175 reason: format!(
176 "HITL required but running in headless mode. Prompt: '{}'",
177 params.message
178 ),
179 });
180 }
181 }
182 }
183
184 if let Some(handler) = &self.handler {
186 let request = HitlRequest::new(¶ms.message);
187 let request = if let Some(default) = params.default.clone() {
188 request.with_default(default)
189 } else {
190 request
191 };
192
193 let hitl_response =
194 handler
195 .prompt(request)
196 .await
197 .map_err(|e| NikaError::BuiltinToolError {
198 tool: "nika:prompt".into(),
199 reason: format!("HITL handler error: {}", e),
200 })?;
201
202 let response = PromptResponse {
203 response: hitl_response.response,
204 default_used: hitl_response.default_used,
205 };
206
207 return serde_json::to_string(&response).map_err(|e| NikaError::BuiltinToolError {
208 tool: "nika:prompt".into(),
209 reason: format!("Failed to serialize response: {}", e),
210 });
211 }
212
213 match params.default {
215 Some(default) => {
216 tracing::warn!(
217 target: "nika:prompt",
218 message = %params.message,
219 default = %default,
220 "HITL handler not configured, using default value"
221 );
222 let response = PromptResponse {
223 response: default,
224 default_used: true,
225 };
226 serde_json::to_string(&response).map_err(|e| NikaError::BuiltinToolError {
227 tool: "nika:prompt".into(),
228 reason: format!("Failed to serialize response: {}", e),
229 })
230 }
231 None => Err(NikaError::BuiltinToolError {
232 tool: "nika:prompt".into(),
233 reason: format!(
234 "HITL handler not configured and no default provided. Prompt: '{}'",
235 params.message
236 ),
237 }),
238 }
239 })
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::hitl::{HitlError, HitlResponse};
    use async_trait::async_trait;

    /// Test double: a HITL handler that always answers with a fixed string.
    struct FixedAnswerHandler(&'static str);

    #[async_trait]
    impl HitlHandler for FixedAnswerHandler {
        async fn prompt(&self, _request: HitlRequest) -> Result<HitlResponse, HitlError> {
            Ok(HitlResponse::new(self.0))
        }
    }

    /// Test double: a HITL handler whose user always cancels the prompt.
    struct CancellingHandler;

    #[async_trait]
    impl HitlHandler for CancellingHandler {
        async fn prompt(&self, _request: HitlRequest) -> Result<HitlResponse, HitlError> {
            Err(HitlError::Cancelled)
        }
    }

    /// Unwraps a successful `call` result and parses its JSON payload.
    fn parse_response(result: Result<String, NikaError>) -> serde_json::Value {
        let json = result.expect("prompt call should succeed");
        serde_json::from_str(&json).expect("prompt response should be valid JSON")
    }

    /// Unwraps a failed `call` result into its display message.
    fn error_message(result: Result<String, NikaError>) -> String {
        result.expect_err("prompt call should fail").to_string()
    }

    #[test]
    fn test_prompt_tool_name() {
        let tool = PromptTool::default();
        assert_eq!(tool.name(), "prompt");
    }

    #[test]
    fn test_prompt_tool_description() {
        let tool = PromptTool::default();
        assert!(tool.description().contains("HITL"));
    }

    #[test]
    fn test_prompt_tool_schema() {
        let schema = PromptTool::default().parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["message"].is_object());
        assert!(schema["properties"]["default"].is_object());
        let required = schema["required"]
            .as_array()
            .expect("`required` should be an array");
        assert!(required.contains(&serde_json::json!("message")));
    }

    #[tokio::test]
    async fn test_prompt_headless_with_default() {
        let tool = PromptTool::new_headless();
        let result = tool
            .call(r#"{"message": "Approve?", "default": "yes"}"#.to_string())
            .await;

        let response = parse_response(result);
        assert_eq!(response["response"], "yes");
        assert_eq!(response["default_used"], true);
    }

    #[tokio::test]
    async fn test_prompt_headless_without_default_errors() {
        let tool = PromptTool::new_headless();
        let result = tool.call(r#"{"message": "Approve?"}"#.to_string()).await;

        assert!(error_message(result).contains("headless mode"));
    }

    #[tokio::test]
    async fn test_prompt_interactive_with_default() {
        // No handler configured, so the default is used.
        let tool = PromptTool::new_interactive();
        let result = tool
            .call(r#"{"message": "Confirm?", "default": "confirmed"}"#.to_string())
            .await;

        let response = parse_response(result);
        assert_eq!(response["response"], "confirmed");
        assert_eq!(response["default_used"], true);
    }

    #[tokio::test]
    async fn test_prompt_interactive_without_default_errors() {
        let tool = PromptTool::new_interactive();
        let result = tool
            .call(r#"{"message": "User input needed"}"#.to_string())
            .await;

        assert!(error_message(result).contains("HITL handler not configured"));
    }

    #[tokio::test]
    async fn test_prompt_empty_message_errors() {
        let tool = PromptTool::default();
        let result = tool.call(r#"{"message": ""}"#.to_string()).await;

        assert!(error_message(result).contains("cannot be empty"));
    }

    #[tokio::test]
    async fn test_prompt_invalid_json() {
        let tool = PromptTool::default();
        let result = tool.call("not json".to_string()).await;

        assert!(error_message(result).contains("Invalid JSON parameters"));
    }

    #[tokio::test]
    async fn test_prompt_missing_message() {
        let tool = PromptTool::default();
        let result = tool.call(r#"{"default": "test"}"#.to_string()).await;

        assert!(error_message(result).contains("Invalid JSON parameters"));
    }

    // Pure deserialization checks: no async runtime needed.
    #[test]
    fn test_prompt_params_deserialization() {
        let json = r#"{"message": "Test prompt", "default": "default_value"}"#;
        let params: PromptParams = serde_json::from_str(json).unwrap();

        assert_eq!(params.message, "Test prompt");
        assert_eq!(params.default, Some("default_value".to_string()));
    }

    #[test]
    fn test_prompt_params_without_default() {
        let json = r#"{"message": "Test prompt"}"#;
        let params: PromptParams = serde_json::from_str(json).unwrap();

        assert_eq!(params.message, "Test prompt");
        assert_eq!(params.default, None);
    }

    #[tokio::test]
    async fn test_prompt_with_hitl_handler_calls_handler() {
        let tool = PromptTool::with_handler(Arc::new(FixedAnswerHandler("user_input")));
        let result = tool
            .call(r#"{"message": "Enter something"}"#.to_string())
            .await;

        let response = parse_response(result);
        assert_eq!(response["response"], "user_input");
        assert_eq!(response["default_used"], false);
    }

    #[tokio::test]
    async fn test_prompt_with_hitl_handler_ignores_default() {
        let tool = PromptTool::with_handler(Arc::new(FixedAnswerHandler("handler_response")));
        let result = tool
            .call(r#"{"message": "Confirm?", "default": "ignored_default"}"#.to_string())
            .await;

        let response = parse_response(result);
        assert_eq!(response["response"], "handler_response");
        // The handler answered, so the caller-supplied default was not used.
        assert_eq!(response["default_used"], false);
    }

    #[tokio::test]
    async fn test_prompt_with_hitl_handler_error_propagates() {
        let tool = PromptTool::with_handler(Arc::new(CancellingHandler));
        let result = tool.call(r#"{"message": "Confirm?"}"#.to_string()).await;

        assert!(error_message(result).contains("HITL handler error"));
    }
}