Skip to main content

nika_engine/runtime/builtin/
prompt.rs

1//! nika:prompt - Human-In-The-Loop user input request.
2//!
3//! # Parameters
4//!
5//! ```json
6//! {
7//!   "message": "Please review this output",  // Prompt message for user
8//!   "default": "approve"                     // Default value (optional)
9//! }
10//! ```
11//!
12//! # Returns
13//!
14//! ```json
15//! {
16//!   "response": "user input here",
17//!   "default_used": false
18//! }
19//! ```
20//!
21//! # Note
22//!
23//! This tool requires runtime integration with a HITL channel.
24//! When invoked without a HITL channel (e.g., in headless mode),
25//! it returns the default value if provided, or errors.
26
27use super::BuiltinTool;
28use crate::error::NikaError;
29use crate::runtime::hitl::{HitlHandler, HitlRequest};
30use serde::{Deserialize, Serialize};
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34
35/// Parameters for nika:prompt tool.
36#[derive(Debug, Clone, Deserialize)]
37pub struct PromptParams {
38    /// The prompt message to display to the user.
39    pub message: String,
40    /// Default value to use if no input is provided (optional).
41    #[serde(default)]
42    pub default: Option<String>,
43}
44
45/// Response from nika:prompt tool.
46#[derive(Debug, Clone, Serialize)]
47pub struct PromptResponse {
48    /// The user's response (or default value).
49    pub response: String,
50    /// Whether the default value was used.
51    pub default_used: bool,
52}
53
54/// nika:prompt builtin tool.
55///
56/// Requests user input during workflow execution (HITL).
57/// In headless mode, returns the default value or errors.
58///
59/// Full HITL integration requires a channel for user interaction,
60/// which is set up when the Router is integrated with the TUI.
61pub struct PromptTool {
62    /// Whether running in headless mode (no TUI).
63    headless: bool,
64    /// Optional HITL handler for interactive mode.
65    handler: Option<Arc<dyn HitlHandler>>,
66}
67
68impl PromptTool {
69    /// Create a new PromptTool in headless mode.
70    /// In headless mode, the tool uses the default value or errors.
71    pub fn new_headless() -> Self {
72        Self {
73            headless: true,
74            handler: None,
75        }
76    }
77
78    /// Create a new PromptTool in interactive mode.
79    /// Full HITL support requires runtime integration.
80    pub fn new_interactive() -> Self {
81        Self {
82            headless: false,
83            handler: None,
84        }
85    }
86
87    /// Create a new PromptTool with a custom HITL handler.
88    /// The handler will be used for all prompt requests.
89    pub fn with_handler(handler: Arc<dyn HitlHandler>) -> Self {
90        Self {
91            headless: false,
92            handler: Some(handler),
93        }
94    }
95}
96
97impl Default for PromptTool {
98    fn default() -> Self {
99        Self::new_headless()
100    }
101}
102
103impl BuiltinTool for PromptTool {
104    fn name(&self) -> &'static str {
105        "prompt"
106    }
107
108    fn description(&self) -> &'static str {
109        "Request user input during workflow execution (HITL)"
110    }
111
112    fn parameters_schema(&self) -> serde_json::Value {
113        // OpenAI-compatible schema with additionalProperties: false
114        serde_json::json!({
115            "type": "object",
116            "properties": {
117                "message": {
118                    "type": "string",
119                    "description": "Prompt message to display to the user"
120                },
121                "default": {
122                    "type": "string",
123                    "description": "Default value if no input provided"
124                }
125            },
126            "required": ["message", "default"],
127            "additionalProperties": false
128        })
129    }
130
131    fn call<'a>(
132        &'a self,
133        args: String,
134    ) -> Pin<Box<dyn Future<Output = Result<String, NikaError>> + Send + 'a>> {
135        Box::pin(async move {
136            // Parse parameters
137            let params: PromptParams =
138                serde_json::from_str(&args).map_err(|e| NikaError::BuiltinInvalidParams {
139                    tool: "nika:prompt".into(),
140                    reason: format!("Invalid JSON parameters: {}", e),
141                })?;
142
143            // Validate message is not empty
144            if params.message.is_empty() {
145                return Err(NikaError::BuiltinInvalidParams {
146                    tool: "nika:prompt".into(),
147                    reason: "Prompt message cannot be empty".into(),
148                });
149            }
150
151            // In headless mode, use default value or error
152            if self.headless {
153                match params.default {
154                    Some(default) => {
155                        tracing::info!(
156                            target: "nika:prompt",
157                            message = %params.message,
158                            default = %default,
159                            "Using default value in headless mode"
160                        );
161                        let response = PromptResponse {
162                            response: default,
163                            default_used: true,
164                        };
165                        return serde_json::to_string(&response).map_err(|e| {
166                            NikaError::BuiltinToolError {
167                                tool: "nika:prompt".into(),
168                                reason: format!("Failed to serialize response: {}", e),
169                            }
170                        });
171                    }
172                    None => {
173                        return Err(NikaError::BuiltinToolError {
174                            tool: "nika:prompt".into(),
175                            reason: format!(
176                                "HITL required but running in headless mode. Prompt: '{}'",
177                                params.message
178                            ),
179                        });
180                    }
181                }
182            }
183
184            // Use HITL handler if provided
185            if let Some(handler) = &self.handler {
186                let request = HitlRequest::new(&params.message);
187                let request = if let Some(default) = params.default.clone() {
188                    request.with_default(default)
189                } else {
190                    request
191                };
192
193                let hitl_response =
194                    handler
195                        .prompt(request)
196                        .await
197                        .map_err(|e| NikaError::BuiltinToolError {
198                            tool: "nika:prompt".into(),
199                            reason: format!("HITL handler error: {}", e),
200                        })?;
201
202                let response = PromptResponse {
203                    response: hitl_response.response,
204                    default_used: hitl_response.default_used,
205                };
206
207                return serde_json::to_string(&response).map_err(|e| NikaError::BuiltinToolError {
208                    tool: "nika:prompt".into(),
209                    reason: format!("Failed to serialize response: {}", e),
210                });
211            }
212
213            // Fallback: interactive mode without handler uses default or errors
214            match params.default {
215                Some(default) => {
216                    tracing::warn!(
217                        target: "nika:prompt",
218                        message = %params.message,
219                        default = %default,
220                        "HITL handler not configured, using default value"
221                    );
222                    let response = PromptResponse {
223                        response: default,
224                        default_used: true,
225                    };
226                    serde_json::to_string(&response).map_err(|e| NikaError::BuiltinToolError {
227                        tool: "nika:prompt".into(),
228                        reason: format!("Failed to serialize response: {}", e),
229                    })
230                }
231                None => Err(NikaError::BuiltinToolError {
232                    tool: "nika:prompt".into(),
233                    reason: format!(
234                        "HITL handler not configured and no default provided. Prompt: '{}'",
235                        params.message
236                    ),
237                }),
238            }
239        })
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::hitl::{HitlError, HitlResponse};
    use async_trait::async_trait;

    /// Handler stub that answers every prompt with a fixed string.
    struct FixedHandler(&'static str);

    #[async_trait]
    impl HitlHandler for FixedHandler {
        async fn prompt(&self, _request: HitlRequest) -> Result<HitlResponse, HitlError> {
            Ok(HitlResponse::new(self.0))
        }
    }

    /// Handler stub that rejects every prompt as cancelled.
    struct CancellingHandler;

    #[async_trait]
    impl HitlHandler for CancellingHandler {
        async fn prompt(&self, _request: HitlRequest) -> Result<HitlResponse, HitlError> {
            Err(HitlError::Cancelled)
        }
    }

    /// Unwraps a successful tool call and parses its JSON output.
    fn parse_ok(result: Result<String, NikaError>) -> serde_json::Value {
        let raw = result.expect("tool call should succeed");
        serde_json::from_str(&raw).expect("tool output should be valid JSON")
    }

    /// Unwraps a failed tool call and returns the rendered error message.
    fn error_text(result: Result<String, NikaError>) -> String {
        result.expect_err("tool call should fail").to_string()
    }

    #[test]
    fn test_prompt_tool_name() {
        let tool = PromptTool::default();
        assert_eq!(tool.name(), "prompt");
    }

    #[test]
    fn test_prompt_tool_description() {
        let tool = PromptTool::default();
        assert!(tool.description().contains("HITL"));
    }

    #[test]
    fn test_prompt_tool_schema() {
        let tool = PromptTool::default();
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["message"].is_object());
        assert!(schema["properties"]["default"].is_object());
        assert!(schema["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("message")));
    }

    // Parameter parsing is synchronous, so these use plain `#[test]`.
    #[test]
    fn test_prompt_params_deserialization() {
        let json = r#"{"message": "Test prompt", "default": "default_value"}"#;
        let params: PromptParams = serde_json::from_str(json).unwrap();

        assert_eq!(params.message, "Test prompt");
        assert_eq!(params.default, Some("default_value".to_string()));
    }

    #[test]
    fn test_prompt_params_without_default() {
        let json = r#"{"message": "Test prompt"}"#;
        let params: PromptParams = serde_json::from_str(json).unwrap();

        assert_eq!(params.message, "Test prompt");
        assert_eq!(params.default, None);
    }

    #[test]
    fn test_prompt_params_null_default() {
        // An explicit `null` must be treated like a missing default.
        let json = r#"{"message": "Test prompt", "default": null}"#;
        let params: PromptParams = serde_json::from_str(json).unwrap();

        assert_eq!(params.message, "Test prompt");
        assert_eq!(params.default, None);
    }

    #[tokio::test]
    async fn test_prompt_headless_with_default() {
        let tool = PromptTool::new_headless();
        let result = tool
            .call(r#"{"message": "Approve?", "default": "yes"}"#.to_string())
            .await;

        let response = parse_ok(result);
        assert_eq!(response["response"], "yes");
        assert_eq!(response["default_used"], true);
    }

    #[tokio::test]
    async fn test_prompt_headless_without_default_errors() {
        let tool = PromptTool::new_headless();
        let result = tool.call(r#"{"message": "Approve?"}"#.to_string()).await;

        assert!(error_text(result).contains("headless mode"));
    }

    #[tokio::test]
    async fn test_prompt_headless_null_default_errors() {
        let tool = PromptTool::new_headless();
        let result = tool
            .call(r#"{"message": "Approve?", "default": null}"#.to_string())
            .await;

        assert!(error_text(result).contains("headless mode"));
    }

    #[tokio::test]
    async fn test_prompt_interactive_with_default() {
        let tool = PromptTool::new_interactive();
        let result = tool
            .call(r#"{"message": "Confirm?", "default": "confirmed"}"#.to_string())
            .await;

        // No handler is attached, so the default is the answer.
        let response = parse_ok(result);
        assert_eq!(response["response"], "confirmed");
        assert_eq!(response["default_used"], true);
    }

    #[tokio::test]
    async fn test_prompt_interactive_without_default_errors() {
        let tool = PromptTool::new_interactive();
        let result = tool
            .call(r#"{"message": "User input needed"}"#.to_string())
            .await;

        // No handler and no default: nothing can answer the prompt.
        assert!(error_text(result).contains("HITL handler not configured"));
    }

    #[tokio::test]
    async fn test_prompt_empty_message_errors() {
        let tool = PromptTool::default();
        let result = tool.call(r#"{"message": ""}"#.to_string()).await;

        assert!(error_text(result).contains("cannot be empty"));
    }

    #[tokio::test]
    async fn test_prompt_invalid_json() {
        let tool = PromptTool::default();
        let result = tool.call("not json".to_string()).await;

        assert!(error_text(result).contains("Invalid JSON parameters"));
    }

    #[tokio::test]
    async fn test_prompt_missing_message() {
        let tool = PromptTool::default();
        let result = tool.call(r#"{"default": "test"}"#.to_string()).await;

        assert!(error_text(result).contains("Invalid JSON parameters"));
    }

    // Tests for HitlHandler integration
    #[tokio::test]
    async fn test_prompt_with_hitl_handler_calls_handler() {
        let tool = PromptTool::with_handler(Arc::new(FixedHandler("user_input")));
        let result = tool
            .call(r#"{"message": "Enter something"}"#.to_string())
            .await;

        let response = parse_ok(result);
        assert_eq!(response["response"], "user_input");
        assert_eq!(response["default_used"], false);
    }

    #[tokio::test]
    async fn test_prompt_with_hitl_handler_ignores_default() {
        let tool = PromptTool::with_handler(Arc::new(FixedHandler("handler_response")));
        let result = tool
            .call(r#"{"message": "Confirm?", "default": "ignored_default"}"#.to_string())
            .await;

        // Handler response takes precedence over default
        let response = parse_ok(result);
        assert_eq!(response["response"], "handler_response");
        assert_eq!(response["default_used"], false);
    }

    #[tokio::test]
    async fn test_prompt_with_hitl_handler_error_propagates() {
        let tool = PromptTool::with_handler(Arc::new(CancellingHandler));
        let result = tool.call(r#"{"message": "Confirm?"}"#.to_string()).await;

        assert!(error_text(result).contains("HITL handler error"));
    }
}