ai_agents_hitl/
handler.rs1use async_trait::async_trait;
2use std::sync::Arc;
3
4use super::types::{ApprovalRequest, ApprovalResult};
5
6#[async_trait]
12pub trait ApprovalHandler: Send + Sync {
13 async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult;
15
16 fn preferred_language(&self) -> Option<String> {
18 None
19 }
20
21 fn supported_languages(&self) -> Option<Vec<String>> {
23 None
24 }
25}
26
/// Approval handler that turns down every request it is given.
///
/// This is a zero-sized unit type. It derives `Debug`, `Clone`, `Copy` and `Default` so it can
/// be logged, duplicated and default-constructed without any hand-written boilerplate.
#[derive(Debug, Clone, Copy, Default)]
pub struct RejectAllHandler;

impl RejectAllHandler {
    /// Creates the handler.
    ///
    /// Equivalent to [`RejectAllHandler::default`]; `const` so it can also initialise constants
    /// and statics.
    pub const fn new() -> Self {
        Self
    }
}
40
41#[async_trait]
42impl ApprovalHandler for RejectAllHandler {
43 async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
44 tracing::warn!(
45 "[HITL] Auto-rejecting: {} (no handler configured)",
46 request.message
47 );
48 ApprovalResult::rejected_with_reason("No approval handler configured")
49 }
50}
51
/// Approval handler that approves every request it is given.
///
/// This is a zero-sized unit type. It derives `Debug`, `Clone`, `Copy` and `Default` so it can
/// be logged, duplicated and default-constructed without any hand-written boilerplate.
#[derive(Debug, Clone, Copy, Default)]
pub struct AutoApproveHandler;

impl AutoApproveHandler {
    /// Creates the handler.
    ///
    /// Equivalent to [`AutoApproveHandler::default`]; `const` so it can also initialise
    /// constants and statics.
    pub const fn new() -> Self {
        Self
    }
}
66
67#[async_trait]
68impl ApprovalHandler for AutoApproveHandler {
69 async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
70 tracing::info!("[HITL] Auto-approving: {}", request.message);
71 ApprovalResult::approved()
72 }
73}
74
75pub struct CallbackHandler<F>
76where
77 F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync,
78{
79 callback: F,
80}
81
82impl<F> CallbackHandler<F>
83where
84 F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync,
85{
86 pub fn new(callback: F) -> Self {
87 Self { callback }
88 }
89}
90
91#[async_trait]
92impl<F> ApprovalHandler for CallbackHandler<F>
93where
94 F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync,
95{
96 async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
97 (self.callback)(request)
98 }
99}
100
101pub struct LocalizedHandler {
102 inner: Arc<dyn ApprovalHandler>,
103 language: String,
104 supported: Option<Vec<String>>,
105}
106
107impl LocalizedHandler {
108 pub fn new(inner: Arc<dyn ApprovalHandler>, language: impl Into<String>) -> Self {
109 Self {
110 inner,
111 language: language.into(),
112 supported: None,
113 }
114 }
115
116 pub fn with_supported(mut self, languages: Vec<String>) -> Self {
117 self.supported = Some(languages);
118 self
119 }
120}
121
122#[async_trait]
123impl ApprovalHandler for LocalizedHandler {
124 async fn request_approval(&self, request: ApprovalRequest) -> ApprovalResult {
125 self.inner.request_approval(request).await
126 }
127
128 fn preferred_language(&self) -> Option<String> {
129 Some(self.language.clone())
130 }
131
132 fn supported_languages(&self) -> Option<Vec<String>> {
133 self.supported.clone()
134 }
135}
136
137pub fn create_handler<F>(callback: F) -> Arc<dyn ApprovalHandler>
138where
139 F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync + 'static,
140{
141 Arc::new(CallbackHandler::new(callback))
142}
143
144pub fn create_localized_handler<F>(
145 callback: F,
146 language: impl Into<String>,
147) -> Arc<dyn ApprovalHandler>
148where
149 F: Fn(ApprovalRequest) -> ApprovalResult + Send + Sync + 'static,
150{
151 Arc::new(LocalizedHandler::new(
152 Arc::new(CallbackHandler::new(callback)),
153 language,
154 ))
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ApprovalTrigger;

    /// Builds a tool-triggered request for `tool`, passing an empty JSON object to the trigger
    /// and `message` as the request message.
    fn tool_request(tool: &'static str, message: &'static str) -> ApprovalRequest {
        let trigger = ApprovalTrigger::tool(tool, serde_json::json!({}));
        ApprovalRequest::new(trigger, message)
    }

    /// Request shared by the tests that do not inspect the message.
    fn create_test_request() -> ApprovalRequest {
        tool_request("test_tool", "Test approval")
    }

    #[tokio::test]
    async fn test_reject_all_handler() {
        let handler = RejectAllHandler::new();

        let outcome = handler.request_approval(create_test_request()).await;

        assert!(outcome.is_rejected());
        assert!(handler.preferred_language().is_none());
        assert!(handler.supported_languages().is_none());
    }

    #[tokio::test]
    async fn test_auto_approve_handler() {
        let handler = AutoApproveHandler::new();

        let outcome = handler.request_approval(create_test_request()).await;

        assert!(outcome.is_approved());
    }

    #[tokio::test]
    async fn test_callback_handler() {
        let handler = CallbackHandler::new(|_| ApprovalResult::approved());

        let outcome = handler.request_approval(create_test_request()).await;

        assert!(outcome.is_approved());
    }

    #[tokio::test]
    async fn test_callback_handler_with_rejection() {
        // The callback rejects only requests whose message contains "dangerous".
        let handler = CallbackHandler::new(|req| {
            if req.message.contains("dangerous") {
                ApprovalResult::rejected_with_reason("Dangerous operation")
            } else {
                ApprovalResult::approved()
            }
        });

        let safe_outcome = handler
            .request_approval(tool_request("safe", "Safe operation"))
            .await;
        assert!(safe_outcome.is_approved());

        let dangerous_outcome = handler
            .request_approval(tool_request("danger", "dangerous operation"))
            .await;
        assert!(dangerous_outcome.is_rejected());
    }

    #[tokio::test]
    async fn test_create_handler_helper() {
        let handler = create_handler(|_| ApprovalResult::approved());

        let outcome = handler.request_approval(create_test_request()).await;

        assert!(outcome.is_approved());
    }

    #[tokio::test]
    async fn test_localized_handler() {
        let languages = vec!["ko".to_string(), "en".to_string()];
        let inner = Arc::new(AutoApproveHandler::new());
        let handler = LocalizedHandler::new(inner, "ko").with_supported(languages.clone());

        assert_eq!(handler.preferred_language(), Some("ko".to_string()));
        assert_eq!(handler.supported_languages(), Some(languages));

        let outcome = handler.request_approval(create_test_request()).await;
        assert!(outcome.is_approved());
    }

    #[tokio::test]
    async fn test_create_localized_handler_helper() {
        let handler = create_localized_handler(|_| ApprovalResult::approved(), "ja");

        assert_eq!(handler.preferred_language(), Some("ja".to_string()));

        let outcome = handler.request_approval(create_test_request()).await;
        assert!(outcome.is_approved());
    }

    #[test]
    fn test_reject_all_default() {
        let handler = RejectAllHandler::default();
        assert_eq!(std::mem::size_of_val(&handler), 0);
    }

    #[test]
    fn test_auto_approve_default() {
        let handler = AutoApproveHandler::default();
        assert_eq!(std::mem::size_of_val(&handler), 0);
    }
}