use async_trait::async_trait;
use std::time::Duration;
use thiserror::Error;
/// Errors a [`HitlHandler`] can return from `prompt`.
///
/// The `Display` text of each variant comes from its `#[error]` attribute.
#[derive(Debug, Error)]
pub enum HitlError {
    /// The prompt was cancelled by the user.
    #[error("User cancelled prompt")]
    Cancelled,
    /// No answer arrived in time; carries the duration that elapsed / was allowed
    /// (rendered with `Debug`, e.g. `30s`).
    #[error("Prompt timed out after {0:?}")]
    Timeout(Duration),
    /// The handler cannot ask anyone, e.g. `DefaultHitlHandler` when the request
    /// has no default; the string explains why.
    #[error("HITL handler not available: {0}")]
    NotAvailable(String),
    /// Any other failure, described by the message.
    #[error("HITL error: {0}")]
    Other(String),
}
/// A request for input from a human operator.
///
/// Create one with [`HitlRequest::new`] and refine it with the `with_*`
/// builder methods, then pass it to a [`HitlHandler`], which decides how (or
/// whether) the question is actually asked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HitlRequest {
    /// Prompt text to present.
    pub message: String,
    /// Fallback answer; `DefaultHitlHandler` answers with this when it is set.
    pub default: Option<String>,
    /// Optional time limit; how (and whether) it is enforced is up to the handler.
    pub timeout: Option<Duration>,
    /// Optional list of candidate answers; interpretation is up to the handler.
    pub choices: Option<Vec<String>>,
}

impl HitlRequest {
    /// Creates a request with only a prompt message; every optional field is `None`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            default: None,
            timeout: None,
            choices: None,
        }
    }

    /// Sets the fallback answer, replacing any previous one.
    // `self` is consumed, so discarding the return value loses the whole request.
    #[must_use = "builder methods consume the request and return the updated one"]
    pub fn with_default(mut self, default: impl Into<String>) -> Self {
        self.default = Some(default.into());
        self
    }

    /// Sets the time limit, replacing any previous one.
    #[must_use = "builder methods consume the request and return the updated one"]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the candidate answers, replacing any previous list.
    #[must_use = "builder methods consume the request and return the updated one"]
    pub fn with_choices(mut self, choices: Vec<String>) -> Self {
        self.choices = Some(choices);
        self
    }
}
/// The answer produced by a [`HitlHandler`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HitlResponse {
    /// The answer text.
    pub response: String,
    /// `true` when the answer is the request's default rather than live input.
    pub default_used: bool,
}

impl HitlResponse {
    /// Wraps an answer that was actually supplied (`default_used == false`).
    #[must_use]
    pub fn new(response: impl Into<String>) -> Self {
        Self {
            response: response.into(),
            default_used: false,
        }
    }

    /// Wraps a fallback answer taken from the request's default (`default_used == true`).
    #[must_use]
    pub fn from_default(default: impl Into<String>) -> Self {
        Self {
            response: default.into(),
            default_used: true,
        }
    }
}
/// Strategy for asking a human for input ("human in the loop").
///
/// The `Send + Sync` bound allows one handler value to be shared between
/// threads.
#[async_trait]
pub trait HitlHandler: Send + Sync {
    /// Presents `request` and returns the resulting answer.
    ///
    /// # Errors
    ///
    /// Returns a [`HitlError`] when no answer can be produced; which variants
    /// are used is up to the implementation (the headless `DefaultHitlHandler`,
    /// for example, returns `NotAvailable` when the request has no default).
    async fn prompt(&self, request: HitlRequest) -> Result<HitlResponse, HitlError>;
}
#[derive(Debug, Default)]
pub struct DefaultHitlHandler;
#[async_trait]
impl HitlHandler for DefaultHitlHandler {
async fn prompt(&self, request: HitlRequest) -> Result<HitlResponse, HitlError> {
match request.default {
Some(default) => Ok(HitlResponse::from_default(default)),
None => Err(HitlError::NotAvailable(
"No default provided and running in headless mode".to_string(),
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor/builder/Display tests are synchronous, so they use plain
    // `#[test]`; only tests that `.await` a handler need the tokio runtime.

    #[test]
    fn test_hitl_request_builder() {
        let request = HitlRequest::new("Enter your name")
            .with_default("Anonymous")
            .with_timeout(Duration::from_secs(30))
            .with_choices(vec!["Alice".to_string(), "Bob".to_string()]);
        assert_eq!(request.message, "Enter your name");
        assert_eq!(request.default, Some("Anonymous".to_string()));
        assert_eq!(request.timeout, Some(Duration::from_secs(30)));
        assert_eq!(
            request.choices,
            Some(vec!["Alice".to_string(), "Bob".to_string()])
        );
    }

    #[test]
    fn test_hitl_request_new_has_no_options() {
        let request = HitlRequest::new("Plain prompt");
        assert_eq!(request.message, "Plain prompt");
        assert!(request.default.is_none());
        assert!(request.timeout.is_none());
        assert!(request.choices.is_none());
    }

    #[test]
    fn test_hitl_response_new() {
        let response = HitlResponse::new("user input");
        assert_eq!(response.response, "user input");
        assert!(!response.default_used);
    }

    #[test]
    fn test_hitl_response_from_default() {
        let response = HitlResponse::from_default("default value");
        assert_eq!(response.response, "default value");
        assert!(response.default_used);
    }

    #[tokio::test]
    async fn test_default_handler_uses_default() {
        let handler = DefaultHitlHandler;
        let request = HitlRequest::new("Test prompt").with_default("default");
        let response = handler.prompt(request).await.unwrap();
        assert_eq!(response.response, "default");
        assert!(response.default_used);
    }

    #[tokio::test]
    async fn test_default_handler_ignores_timeout_and_choices() {
        let handler = DefaultHitlHandler;
        let request = HitlRequest::new("Pick one")
            .with_timeout(Duration::from_millis(1))
            .with_choices(vec!["a".to_string(), "b".to_string()])
            .with_default("b");
        let response = handler.prompt(request).await.unwrap();
        assert_eq!(response.response, "b");
        assert!(response.default_used);
    }

    #[tokio::test]
    async fn test_default_handler_errors_without_default() {
        let handler = DefaultHitlHandler;
        let result = handler.prompt(HitlRequest::new("Test prompt")).await;
        assert!(matches!(result, Err(HitlError::NotAvailable(_))));
    }

    #[test]
    fn test_hitl_error_display() {
        // Pin the exact user-facing messages rather than substrings.
        assert_eq!(HitlError::Cancelled.to_string(), "User cancelled prompt");
        assert_eq!(
            HitlError::Timeout(Duration::from_secs(30)).to_string(),
            "Prompt timed out after 30s"
        );
        assert_eq!(
            HitlError::NotAvailable("test".to_string()).to_string(),
            "HITL handler not available: test"
        );
        assert_eq!(
            HitlError::Other("boom".to_string()).to_string(),
            "HITL error: boom"
        );
    }

    #[tokio::test]
    async fn test_custom_hitl_handler() {
        struct CustomHandler {
            fixed_response: String,
        }

        #[async_trait]
        impl HitlHandler for CustomHandler {
            async fn prompt(&self, _request: HitlRequest) -> Result<HitlResponse, HitlError> {
                Ok(HitlResponse::new(&self.fixed_response))
            }
        }

        let custom = CustomHandler {
            fixed_response: "custom_response".to_string(),
        };
        // Exercise dynamic dispatch as well, since handlers are pluggable.
        let handler: &dyn HitlHandler = &custom;
        let response = handler.prompt(HitlRequest::new("Test")).await.unwrap();
        assert_eq!(response.response, "custom_response");
        assert!(!response.default_used);
    }
}