// ai_agents_hitl/handler.rs

use std::fmt;
use std::sync::Arc;

use async_trait::async_trait;

use super::types::{ApprovalRequest, ApprovalResult};
5
6/// Handler for human-in-the-loop approval requests.
7///
8/// Built-in handlers: `RejectAllHandler`, `AutoApproveHandler`, `CallbackHandler`,
9/// `LocalizedHandler`. For simple cases, use `create_handler()` or
10/// `create_localized_handler()` helpers instead of implementing this directly.
11#[async_trait]
12pub trait ApprovalHandler: Send + Sync {
13    /// Process an approval request and return the decision.
14    async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult;
15
16    /// Language preference for approval messages. Returns `None` by default.
17    fn preferred_language(&self) -> Option<String> {
18        None
19    }
20
21    /// Languages this handler can display. Returns `None` by default.
22    fn supported_languages(&self) -> Option<Vec<String>> {
23        None
24    }
25}
26
/// Fallback handler that refuses every approval request.
///
/// Intended for setups where no real approval handler has been configured:
/// its [`ApprovalHandler`] implementation logs a warning and returns a
/// rejection with the reason `"No approval handler configured"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RejectAllHandler;

impl RejectAllHandler {
    /// Creates a new `RejectAllHandler` (equivalent to `Default::default()`).
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
40
41#[async_trait]
42impl ApprovalHandler for RejectAllHandler {
43    async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
44        tracing::warn!(
45            "[HITL] Auto-rejecting: {} (no handler configured)",
46            request.message
47        );
48        ApprovalResult::rejected_with_reason("No approval handler configured")
49    }
50}
51
/// Handler that approves every request without consulting a human.
///
/// Intended for tests only: it bypasses the approval gate entirely. Its
/// [`ApprovalHandler`] implementation logs the request at info level and
/// returns an approval.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AutoApproveHandler;

impl AutoApproveHandler {
    /// Creates a new `AutoApproveHandler` (equivalent to `Default::default()`).
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
66
67#[async_trait]
68impl ApprovalHandler for AutoApproveHandler {
69    async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
70        tracing::info!("[HITL] Auto-approving: {}", request.message);
71        ApprovalResult::approved()
72    }
73}
74
75pub struct CallbackHandler<F>
76where
77    F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync,
78{
79    callback: F,
80}
81
82impl<F> CallbackHandler<F>
83where
84    F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync,
85{
86    pub fn new(callback: F) -> Self {
87        Self { callback }
88    }
89}
90
91#[async_trait]
92impl<F> ApprovalHandler for CallbackHandler<F>
93where
94    F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync,
95{
96    async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
97        (self.callback)(request)
98    }
99}
100
101pub struct LocalizedHandler {
102    inner: Arc<dyn ApprovalHandler>,
103    language: String,
104    supported: Option<Vec<String>>,
105}
106
107impl LocalizedHandler {
108    pub fn new(inner: Arc<dyn ApprovalHandler>, language: impl Into<String>) -> Self {
109        Self {
110            inner,
111            language: language.into(),
112            supported: None,
113        }
114    }
115
116    pub fn with_supported(mut self, languages: Vec<String>) -> Self {
117        self.supported = Some(languages);
118        self
119    }
120}
121
122#[async_trait]
123impl ApprovalHandler for LocalizedHandler {
124    async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
125        self.inner.request_approval(request).await
126    }
127
128    fn preferred_language(&self) -> Option<String> {
129        Some(self.language.clone())
130    }
131
132    fn supported_languages(&self) -> Option<Vec<String>> {
133        self.supported.clone()
134    }
135}
136
137pub fn create_handler<F>(callback: F) -> Arc<dyn ApprovalHandler>
138where
139    F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync + 'static,
140{
141    Arc::new(CallbackHandler::new(callback))
142}
143
144pub fn create_localized_handler<F>(
145    callback: F,
146    language: impl Into<String>,
147) -> Arc<dyn ApprovalHandler>
148where
149    F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync + 'static,
150{
151    Arc::new(LocalizedHandler::new(
152        Arc::new(CallbackHandler::new(callback)),
153        language,
154    ))
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ApprovalTrigger;

    /// Builds a minimal tool-triggered request with the message "Test approval".
    fn create_test_request() -> ApprovalRequest {
        ApprovalRequest::new(
            ApprovalTrigger::tool("test_tool", serde_json::json!({})),
            "Test approval",
        )
    }

    #[tokio::test]
    async fn test_reject_all_handler() {
        let handler = RejectAllHandler::new();
        let request = create_test_request();
        let result = handler.request_approval(request).await;
        assert!(result.is_rejected());
        assert!(handler.preferred_language().is_none());
        assert!(handler.supported_languages().is_none());
    }

    #[tokio::test]
    async fn test_auto_approve_handler() {
        let handler = AutoApproveHandler::new();
        let request = create_test_request();
        let result = handler.request_approval(request).await;
        assert!(result.is_approved());
        // AutoApproveHandler relies on the trait's default language hooks.
        assert!(handler.preferred_language().is_none());
        assert!(handler.supported_languages().is_none());
    }

    #[tokio::test]
    async fn test_callback_handler() {
        let handler = CallbackHandler::new(|_| ApprovalResult::approved());
        let request = create_test_request();
        let result = handler.request_approval(request).await;
        assert!(result.is_approved());
    }

    #[tokio::test]
    async fn test_callback_handler_with_rejection() {
        let handler = CallbackHandler::new(|req| {
            if req.message.contains("dangerous") {
                ApprovalResult::rejected_with_reason("Dangerous operation")
            } else {
                ApprovalResult::approved()
            }
        });

        let safe_request = ApprovalRequest::new(
            ApprovalTrigger::tool("safe", serde_json::json!({})),
            "Safe operation",
        );
        let result = handler.request_approval(safe_request).await;
        assert!(result.is_approved());

        let dangerous_request = ApprovalRequest::new(
            ApprovalTrigger::tool("danger", serde_json::json!({})),
            "dangerous operation",
        );
        let result = handler.request_approval(dangerous_request).await;
        assert!(result.is_rejected());
    }

    #[tokio::test]
    async fn test_create_handler_helper() {
        let handler = create_handler(|_| ApprovalResult::approved());
        let request = create_test_request();
        let result = handler.request_approval(request).await;
        assert!(result.is_approved());
        // Callback handlers do not override the language hooks.
        assert!(handler.preferred_language().is_none());
        assert!(handler.supported_languages().is_none());
    }

    #[tokio::test]
    async fn test_localized_handler() {
        let inner = Arc::new(AutoApproveHandler::new());
        let handler = LocalizedHandler::new(inner, "ko")
            .with_supported(vec!["ko".to_string(), "en".to_string()]);

        assert_eq!(handler.preferred_language(), Some("ko".to_string()));
        assert_eq!(
            handler.supported_languages(),
            Some(vec!["ko".to_string(), "en".to_string()])
        );

        let request = create_test_request();
        let result = handler.request_approval(request).await;
        assert!(result.is_approved());
    }

    #[tokio::test]
    async fn test_localized_handler_forwards_rejection() {
        // An approving inner handler cannot distinguish "delegates" from
        // "always approves", so also wrap a rejecting handler.
        let handler = LocalizedHandler::new(Arc::new(RejectAllHandler::new()), "en");

        assert_eq!(handler.preferred_language(), Some("en".to_string()));
        // Without `with_supported`, no supported-language list is reported.
        assert_eq!(handler.supported_languages(), None);

        let result = handler.request_approval(create_test_request()).await;
        assert!(result.is_rejected());
    }

    #[tokio::test]
    async fn test_create_localized_handler_helper() {
        let handler = create_localized_handler(|_| ApprovalResult::approved(), "ja");

        assert_eq!(handler.preferred_language(), Some("ja".to_string()));
        assert_eq!(handler.supported_languages(), None);

        let request = create_test_request();
        let result = handler.request_approval(request).await;
        assert!(result.is_approved());
    }

    #[test]
    fn test_reject_all_default() {
        let handler = RejectAllHandler::default();
        assert_eq!(std::mem::size_of_val(&handler), 0);
    }

    #[test]
    fn test_auto_approve_default() {
        let handler = AutoApproveHandler::default();
        assert_eq!(std::mem::size_of_val(&handler), 0);
    }
}