// adk_guardrail/executor.rs
use crate::{Guardrail, GuardrailError, GuardrailResult, Result, Severity};
use adk_core::Content;
use futures::future::join_all;
use std::sync::Arc;

6pub struct GuardrailSet {
10 guardrails: Vec<Arc<dyn Guardrail>>,
11}
12
13impl GuardrailSet {
14 pub fn new() -> Self {
16 Self { guardrails: Vec::new() }
17 }
18
19 pub fn with(mut self, guardrail: impl Guardrail + 'static) -> Self {
21 self.guardrails.push(Arc::new(guardrail));
22 self
23 }
24
25 pub fn with_arc(mut self, guardrail: Arc<dyn Guardrail>) -> Self {
27 self.guardrails.push(guardrail);
28 self
29 }
30
31 pub fn guardrails(&self) -> &[Arc<dyn Guardrail>] {
33 &self.guardrails
34 }
35
36 pub fn is_empty(&self) -> bool {
38 self.guardrails.is_empty()
39 }
40}
41
42impl Default for GuardrailSet {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48#[derive(Debug)]
50pub struct ExecutionResult {
51 pub passed: bool,
53 pub transformed_content: Option<Content>,
55 pub failures: Vec<(String, String, Severity)>,
57}
58
/// Stateless runner for a [`GuardrailSet`]; see [`GuardrailExecutor::run`].
pub struct GuardrailExecutor;

62impl GuardrailExecutor {
63 pub async fn run(guardrails: &GuardrailSet, content: &Content) -> Result<ExecutionResult> {
65 if guardrails.is_empty() {
66 return Ok(ExecutionResult {
67 passed: true,
68 transformed_content: None,
69 failures: vec![],
70 });
71 }
72
73 let (parallel, sequential): (Vec<_>, Vec<_>) =
75 guardrails.guardrails().iter().partition(|g| g.run_parallel());
76
77 let mut current_content = content.clone();
78 let mut all_failures = Vec::new();
79
80 if !parallel.is_empty() {
82 let futures: Vec<_> = parallel
83 .iter()
84 .map(|g| Self::run_single(Arc::clone(g), ¤t_content))
85 .collect();
86
87 let results = join_all(futures).await;
88
89 for (guardrail, result) in parallel.iter().zip(results) {
90 match result {
91 GuardrailResult::Pass => {}
92 GuardrailResult::Fail { reason, severity } => {
93 all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
94 if severity == Severity::Critical && guardrail.fail_fast() {
96 return Err(GuardrailError::ValidationFailed {
97 name: guardrail.name().to_string(),
98 reason,
99 severity,
100 });
101 }
102 }
103 GuardrailResult::Transform { new_content, reason } => {
104 tracing::debug!(
105 guardrail = guardrail.name(),
106 reason = %reason,
107 "Content transformed"
108 );
109 current_content = new_content;
110 }
111 }
112 }
113 }
114
115 for guardrail in sequential {
117 let result = Self::run_single(Arc::clone(guardrail), ¤t_content).await;
118 match result {
119 GuardrailResult::Pass => {}
120 GuardrailResult::Fail { reason, severity } => {
121 all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
122 if severity == Severity::Critical && guardrail.fail_fast() {
123 return Err(GuardrailError::ValidationFailed {
124 name: guardrail.name().to_string(),
125 reason,
126 severity,
127 });
128 }
129 }
130 GuardrailResult::Transform { new_content, reason } => {
131 tracing::debug!(
132 guardrail = guardrail.name(),
133 reason = %reason,
134 "Content transformed"
135 );
136 current_content = new_content;
137 }
138 }
139 }
140
141 let passed =
142 all_failures.is_empty() || all_failures.iter().all(|(_, _, s)| *s == Severity::Low);
143
144 let was_transformed =
146 serde_json::to_string(¤t_content).ok() != serde_json::to_string(content).ok();
147 let transformed = if was_transformed { Some(current_content) } else { None };
148
149 Ok(ExecutionResult { passed, transformed_content: transformed, failures: all_failures })
150 }
151
152 async fn run_single(guardrail: Arc<dyn Guardrail>, content: &Content) -> GuardrailResult {
153 guardrail.validate(content).await
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Always passes.
    struct PassGuardrail;

    #[async_trait::async_trait]
    impl Guardrail for PassGuardrail {
        fn name(&self) -> &str {
            "pass"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Pass
        }
    }

    /// Always fails with the configured severity.
    struct FailGuardrail {
        severity: Severity,
    }

    #[async_trait::async_trait]
    impl Guardrail for FailGuardrail {
        fn name(&self) -> &str {
            "fail"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Fail { reason: "test failure".into(), severity: self.severity }
        }
    }

    /// Replaces whatever it receives with a fixed placeholder message.
    struct RedactGuardrail;

    #[async_trait::async_trait]
    impl Guardrail for RedactGuardrail {
        fn name(&self) -> &str {
            "redact"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Transform {
                new_content: Content::new("user").with_text("[redacted]"),
                reason: "redacted".into(),
            }
        }
    }

    fn sample() -> Content {
        Content::new("user").with_text("hello")
    }

    #[tokio::test]
    async fn test_empty_guardrails_pass() {
        let set = GuardrailSet::new();
        let result = GuardrailExecutor::run(&set, &sample()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_pass_guardrail() {
        let set = GuardrailSet::new().with(PassGuardrail);
        let result = GuardrailExecutor::run(&set, &sample()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_fail_guardrail_low_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Low });
        let result = GuardrailExecutor::run(&set, &sample()).await.unwrap();
        assert!(result.passed);
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "fail");
        assert_eq!(result.failures[0].2, Severity::Low);
    }

    #[tokio::test]
    async fn test_fail_guardrail_high_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::High });
        let result = GuardrailExecutor::run(&set, &sample()).await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.failures.len(), 1);
    }

    #[tokio::test]
    async fn test_critical_early_exit() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Critical });
        let result = GuardrailExecutor::run(&set, &sample()).await;
        assert!(matches!(
            result,
            Err(GuardrailError::ValidationFailed { severity: Severity::Critical, .. })
        ));
    }

    #[tokio::test]
    async fn test_transform_reports_new_content() {
        let set = GuardrailSet::new().with(RedactGuardrail);
        let result = GuardrailExecutor::run(&set, &sample()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_some());
    }
}