adk_guardrail/
executor.rs1use crate::{Guardrail, GuardrailError, GuardrailResult, Result, Severity};
2use adk_core::Content;
3use futures::future::join_all;
4use std::sync::Arc;
5
6pub struct GuardrailSet {
8 guardrails: Vec<Arc<dyn Guardrail>>,
9}
10
11impl GuardrailSet {
12 pub fn new() -> Self {
13 Self { guardrails: Vec::new() }
14 }
15
16 pub fn with(mut self, guardrail: impl Guardrail + 'static) -> Self {
17 self.guardrails.push(Arc::new(guardrail));
18 self
19 }
20
21 pub fn with_arc(mut self, guardrail: Arc<dyn Guardrail>) -> Self {
22 self.guardrails.push(guardrail);
23 self
24 }
25
26 pub fn guardrails(&self) -> &[Arc<dyn Guardrail>] {
27 &self.guardrails
28 }
29
30 pub fn is_empty(&self) -> bool {
31 self.guardrails.is_empty()
32 }
33}
34
35impl Default for GuardrailSet {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41#[derive(Debug)]
43pub struct ExecutionResult {
44 pub passed: bool,
45 pub transformed_content: Option<Content>,
46 pub failures: Vec<(String, String, Severity)>, }
48
/// Stateless entry point for evaluating a [`GuardrailSet`]; see [`GuardrailExecutor::run`].
#[derive(Debug, Clone, Copy, Default)]
pub struct GuardrailExecutor;
51
52impl GuardrailExecutor {
53 pub async fn run(guardrails: &GuardrailSet, content: &Content) -> Result<ExecutionResult> {
55 if guardrails.is_empty() {
56 return Ok(ExecutionResult {
57 passed: true,
58 transformed_content: None,
59 failures: vec![],
60 });
61 }
62
63 let (parallel, sequential): (Vec<_>, Vec<_>) =
65 guardrails.guardrails().iter().partition(|g| g.run_parallel());
66
67 let mut current_content = content.clone();
68 let mut all_failures = Vec::new();
69
70 if !parallel.is_empty() {
72 let futures: Vec<_> = parallel
73 .iter()
74 .map(|g| Self::run_single(Arc::clone(g), ¤t_content))
75 .collect();
76
77 let results = join_all(futures).await;
78
79 for (guardrail, result) in parallel.iter().zip(results) {
80 match result {
81 GuardrailResult::Pass => {}
82 GuardrailResult::Fail { reason, severity } => {
83 all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
84 if severity == Severity::Critical && guardrail.fail_fast() {
86 return Err(GuardrailError::ValidationFailed {
87 name: guardrail.name().to_string(),
88 reason,
89 severity,
90 });
91 }
92 }
93 GuardrailResult::Transform { new_content, reason } => {
94 tracing::debug!(
95 guardrail = guardrail.name(),
96 reason = %reason,
97 "Content transformed"
98 );
99 current_content = new_content;
100 }
101 }
102 }
103 }
104
105 for guardrail in sequential {
107 let result = Self::run_single(Arc::clone(guardrail), ¤t_content).await;
108 match result {
109 GuardrailResult::Pass => {}
110 GuardrailResult::Fail { reason, severity } => {
111 all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
112 if severity == Severity::Critical && guardrail.fail_fast() {
113 return Err(GuardrailError::ValidationFailed {
114 name: guardrail.name().to_string(),
115 reason,
116 severity,
117 });
118 }
119 }
120 GuardrailResult::Transform { new_content, reason } => {
121 tracing::debug!(
122 guardrail = guardrail.name(),
123 reason = %reason,
124 "Content transformed"
125 );
126 current_content = new_content;
127 }
128 }
129 }
130
131 let passed =
132 all_failures.is_empty() || all_failures.iter().all(|(_, _, s)| *s == Severity::Low);
133
134 let was_transformed =
136 serde_json::to_string(¤t_content).ok() != serde_json::to_string(content).ok();
137 let transformed = if was_transformed { Some(current_content) } else { None };
138
139 Ok(ExecutionResult { passed, transformed_content: transformed, failures: all_failures })
140 }
141
142 async fn run_single(guardrail: Arc<dyn Guardrail>, content: &Content) -> GuardrailResult {
143 guardrail.validate(content).await
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Guardrail that accepts every input.
    struct PassGuardrail;

    #[async_trait::async_trait]
    impl Guardrail for PassGuardrail {
        fn name(&self) -> &str {
            "pass"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Pass
        }
    }

    /// Guardrail that always fails with the configured severity.
    struct FailGuardrail {
        severity: Severity,
    }

    #[async_trait::async_trait]
    impl Guardrail for FailGuardrail {
        fn name(&self) -> &str {
            "fail"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Fail { reason: "test failure".into(), severity: self.severity }
        }
    }

    /// Guardrail that always replaces the input with a user message containing `text`.
    struct TransformGuardrail {
        text: &'static str,
    }

    #[async_trait::async_trait]
    impl Guardrail for TransformGuardrail {
        fn name(&self) -> &str {
            "transform"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Transform {
                new_content: Content::new("user").with_text(self.text),
                reason: "test transform".into(),
            }
        }
    }

    /// The input shared by every test.
    fn hello() -> Content {
        Content::new("user").with_text("hello")
    }

    /// Serializes content so tests can compare values without requiring `PartialEq`.
    fn to_json(content: &Content) -> String {
        serde_json::to_string(content).expect("test content should serialize")
    }

    #[tokio::test]
    async fn test_empty_guardrails_pass() {
        let set = GuardrailSet::new();
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_pass_guardrail() {
        let set = GuardrailSet::new().with(PassGuardrail);
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_fail_guardrail_low_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Low });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "fail");
    }

    #[tokio::test]
    async fn test_fail_guardrail_high_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::High });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.failures.len(), 1);
    }

    #[tokio::test]
    async fn test_any_non_low_failure_fails() {
        let set = GuardrailSet::new()
            .with(FailGuardrail { severity: Severity::Low })
            .with(FailGuardrail { severity: Severity::High });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.failures.len(), 2);
    }

    #[tokio::test]
    async fn test_critical_early_exit() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Critical });
        match GuardrailExecutor::run(&set, &hello()).await {
            Err(GuardrailError::ValidationFailed { name, severity, .. }) => {
                assert_eq!(name, "fail");
                assert_eq!(severity, Severity::Critical);
            }
            other => panic!("expected a critical validation failure, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_transform_is_reported() {
        let set = GuardrailSet::new().with(TransformGuardrail { text: "[redacted]" });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        let transformed = result.transformed_content.expect("transform should be reported");
        assert_eq!(to_json(&transformed), to_json(&Content::new("user").with_text("[redacted]")));
    }

    #[tokio::test]
    async fn test_identity_transform_is_not_reported() {
        let set = GuardrailSet::new().with(TransformGuardrail { text: "hello" });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.transformed_content.is_none());
    }
}