nika_engine/runtime/
hitl.rs1use async_trait::async_trait;
25use std::time::Duration;
26use thiserror::Error;
27
28#[derive(Debug, Error)]
30pub enum HitlError {
31 #[error("User cancelled prompt")]
33 Cancelled,
34
35 #[error("Prompt timed out after {0:?}")]
37 Timeout(Duration),
38
39 #[error("HITL handler not available: {0}")]
41 NotAvailable(String),
42
43 #[error("HITL error: {0}")]
45 Other(String),
46}
47
/// A question to put to a human, plus optional hints on how to answer it.
///
/// Build one with [`HitlRequest::new`] and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct HitlRequest {
    /// Text presented to the user.
    pub message: String,
    /// Fallback answer. `DefaultHitlHandler` returns it verbatim; other
    /// handlers decide for themselves how to use it.
    pub default: Option<String>,
    /// Optional time limit for the prompt; honoring it is up to the handler.
    pub timeout: Option<Duration>,
    /// Optional list of suggested answers; whether it is enforced is up to
    /// the handler.
    pub choices: Option<Vec<String>>,
}

impl HitlRequest {
    /// Creates a request carrying only `message`; `default`, `timeout` and
    /// `choices` are all `None`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            default: None,
            timeout: None,
            choices: None,
        }
    }

    /// Sets the fallback answer, replacing any previous one.
    // `#[must_use]`: the builder consumes `self`, so dropping the return value
    // silently discards the whole request.
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn with_default(mut self, default: impl Into<String>) -> Self {
        self.default = Some(default.into());
        self
    }

    /// Sets the time limit for the prompt, replacing any previous one.
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the list of choices, replacing any previous one.
    ///
    /// An empty vector is stored as `Some(vec![])`, not `None`.
    #[must_use = "builder methods return the updated request; the original is consumed"]
    pub fn with_choices(mut self, choices: Vec<String>) -> Self {
        self.choices = Some(choices);
        self
    }
}
90
/// The answer obtained for a prompt.
#[derive(Debug, Clone)]
pub struct HitlResponse {
    /// The answer text.
    pub response: String,
    /// `true` when the value was built via [`HitlResponse::from_default`]
    /// (a fallback), `false` when built via [`HitlResponse::new`].
    pub default_used: bool,
}

impl HitlResponse {
    /// Wraps an explicitly supplied answer (`default_used == false`).
    pub fn new(response: impl Into<String>) -> Self {
        let response = response.into();
        Self {
            response,
            default_used: false,
        }
    }

    /// Wraps a fallback/default answer (`default_used == true`).
    pub fn from_default(default: impl Into<String>) -> Self {
        let response: String = default.into();
        Self {
            default_used: true,
            response,
        }
    }
}
117
118#[async_trait]
124pub trait HitlHandler: Send + Sync {
125 async fn prompt(&self, request: HitlRequest) -> Result<HitlResponse, HitlError>;
139}
140
141#[derive(Debug, Default)]
145pub struct DefaultHitlHandler;
146
147#[async_trait]
148impl HitlHandler for DefaultHitlHandler {
149 async fn prompt(&self, request: HitlRequest) -> Result<HitlResponse, HitlError> {
150 match request.default {
151 Some(default) => Ok(HitlResponse::from_default(default)),
152 None => Err(HitlError::NotAvailable(
153 "No default provided and running in headless mode".to_string(),
154 )),
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that never `.await` use plain `#[test]` so they don't spin up a
    // tokio runtime; only handler tests that await use `#[tokio::test]`.

    #[test]
    fn test_hitl_request_builder() {
        let request = HitlRequest::new("Enter your name")
            .with_default("Anonymous")
            .with_timeout(Duration::from_secs(30))
            .with_choices(vec!["Alice".to_string(), "Bob".to_string()]);

        assert_eq!(request.message, "Enter your name");
        assert_eq!(request.default, Some("Anonymous".to_string()));
        assert_eq!(request.timeout, Some(Duration::from_secs(30)));
        assert_eq!(
            request.choices,
            Some(vec!["Alice".to_string(), "Bob".to_string()])
        );
    }

    #[test]
    fn test_hitl_request_new_has_no_options() {
        let request = HitlRequest::new("Plain");
        assert_eq!(request.message, "Plain");
        assert!(request.default.is_none());
        assert!(request.timeout.is_none());
        assert!(request.choices.is_none());
    }

    #[test]
    fn test_hitl_response_new() {
        let response = HitlResponse::new("user input");
        assert_eq!(response.response, "user input");
        assert!(!response.default_used);
    }

    #[test]
    fn test_hitl_response_from_default() {
        let response = HitlResponse::from_default("default value");
        assert_eq!(response.response, "default value");
        assert!(response.default_used);
    }

    #[tokio::test]
    async fn test_default_handler_uses_default() {
        let handler = DefaultHitlHandler;
        let request = HitlRequest::new("Test prompt").with_default("default");

        let response = handler.prompt(request).await.unwrap();
        assert_eq!(response.response, "default");
        assert!(response.default_used);
    }

    #[tokio::test]
    async fn test_default_handler_ignores_choices_and_timeout() {
        // The default is returned as-is even when it is not among the choices.
        let request = HitlRequest::new("Pick")
            .with_default("x")
            .with_choices(vec!["a".to_string()])
            .with_timeout(Duration::from_millis(1));

        let response = DefaultHitlHandler.prompt(request).await.unwrap();
        assert_eq!(response.response, "x");
        assert!(response.default_used);
    }

    #[tokio::test]
    async fn test_default_handler_errors_without_default() {
        let handler = DefaultHitlHandler;
        let request = HitlRequest::new("Test prompt");

        let result = handler.prompt(request).await;
        assert!(
            matches!(&result, Err(HitlError::NotAvailable(msg)) if msg.contains("headless")),
            "unexpected result: {:?}",
            result
        );
    }

    #[test]
    fn test_hitl_error_display() {
        assert_eq!(HitlError::Cancelled.to_string(), "User cancelled prompt");
        assert_eq!(
            HitlError::Timeout(Duration::from_secs(30)).to_string(),
            "Prompt timed out after 30s"
        );
        assert_eq!(
            HitlError::NotAvailable("test".to_string()).to_string(),
            "HITL handler not available: test"
        );
        assert_eq!(
            HitlError::Other("boom".to_string()).to_string(),
            "HITL error: boom"
        );
    }

    #[tokio::test]
    async fn test_custom_hitl_handler() {
        struct CustomHandler {
            fixed_response: String,
        }

        #[async_trait]
        impl HitlHandler for CustomHandler {
            async fn prompt(&self, _request: HitlRequest) -> Result<HitlResponse, HitlError> {
                Ok(HitlResponse::new(&self.fixed_response))
            }
        }

        let handler = CustomHandler {
            fixed_response: "custom_response".to_string(),
        };
        let request = HitlRequest::new("Test");
        let response = handler.prompt(request).await.unwrap();

        assert_eq!(response.response, "custom_response");
        assert!(!response.default_used);
    }
}