adk_guardrail/
executor.rs

1use crate::{Guardrail, GuardrailError, GuardrailResult, Result, Severity};
2use adk_core::Content;
3use futures::future::join_all;
4use std::sync::Arc;
5
6/// A set of guardrails to run together
7pub struct GuardrailSet {
8    guardrails: Vec<Arc<dyn Guardrail>>,
9}
10
11impl GuardrailSet {
12    pub fn new() -> Self {
13        Self { guardrails: Vec::new() }
14    }
15
16    pub fn with(mut self, guardrail: impl Guardrail + 'static) -> Self {
17        self.guardrails.push(Arc::new(guardrail));
18        self
19    }
20
21    pub fn with_arc(mut self, guardrail: Arc<dyn Guardrail>) -> Self {
22        self.guardrails.push(guardrail);
23        self
24    }
25
26    pub fn guardrails(&self) -> &[Arc<dyn Guardrail>] {
27        &self.guardrails
28    }
29
30    pub fn is_empty(&self) -> bool {
31        self.guardrails.is_empty()
32    }
33}
34
35impl Default for GuardrailSet {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41/// Result of running a guardrail set
42#[derive(Debug)]
43pub struct ExecutionResult {
44    pub passed: bool,
45    pub transformed_content: Option<Content>,
46    pub failures: Vec<(String, String, Severity)>, // (name, reason, severity)
47}
48
/// Stateless entry point for running a [`GuardrailSet`]; see [`GuardrailExecutor::run`].
///
/// Guardrails that allow it are validated concurrently; guardrails whose
/// [`Guardrail::run_parallel`] returns `false` are validated one at a time afterwards.
#[derive(Debug, Clone, Copy, Default)]
pub struct GuardrailExecutor;
51
52impl GuardrailExecutor {
53    /// Run all guardrails in parallel, with early exit on critical failures
54    pub async fn run(guardrails: &GuardrailSet, content: &Content) -> Result<ExecutionResult> {
55        if guardrails.is_empty() {
56            return Ok(ExecutionResult {
57                passed: true,
58                transformed_content: None,
59                failures: vec![],
60            });
61        }
62
63        // Separate parallel and sequential guardrails
64        let (parallel, sequential): (Vec<_>, Vec<_>) =
65            guardrails.guardrails().iter().partition(|g| g.run_parallel());
66
67        let mut current_content = content.clone();
68        let mut all_failures = Vec::new();
69
70        // Run parallel guardrails
71        if !parallel.is_empty() {
72            let futures: Vec<_> = parallel
73                .iter()
74                .map(|g| Self::run_single(Arc::clone(g), &current_content))
75                .collect();
76
77            let results = join_all(futures).await;
78
79            for (guardrail, result) in parallel.iter().zip(results) {
80                match result {
81                    GuardrailResult::Pass => {}
82                    GuardrailResult::Fail { reason, severity } => {
83                        all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
84                        // Early exit on critical
85                        if severity == Severity::Critical && guardrail.fail_fast() {
86                            return Err(GuardrailError::ValidationFailed {
87                                name: guardrail.name().to_string(),
88                                reason,
89                                severity,
90                            });
91                        }
92                    }
93                    GuardrailResult::Transform { new_content, reason } => {
94                        tracing::debug!(
95                            guardrail = guardrail.name(),
96                            reason = %reason,
97                            "Content transformed"
98                        );
99                        current_content = new_content;
100                    }
101                }
102            }
103        }
104
105        // Run sequential guardrails
106        for guardrail in sequential {
107            let result = Self::run_single(Arc::clone(guardrail), &current_content).await;
108            match result {
109                GuardrailResult::Pass => {}
110                GuardrailResult::Fail { reason, severity } => {
111                    all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
112                    if severity == Severity::Critical && guardrail.fail_fast() {
113                        return Err(GuardrailError::ValidationFailed {
114                            name: guardrail.name().to_string(),
115                            reason,
116                            severity,
117                        });
118                    }
119                }
120                GuardrailResult::Transform { new_content, reason } => {
121                    tracing::debug!(
122                        guardrail = guardrail.name(),
123                        reason = %reason,
124                        "Content transformed"
125                    );
126                    current_content = new_content;
127                }
128            }
129        }
130
131        let passed =
132            all_failures.is_empty() || all_failures.iter().all(|(_, _, s)| *s == Severity::Low);
133
134        // Check if content was transformed by comparing serialized forms
135        let was_transformed =
136            serde_json::to_string(&current_content).ok() != serde_json::to_string(content).ok();
137        let transformed = if was_transformed { Some(current_content) } else { None };
138
139        Ok(ExecutionResult { passed, transformed_content: transformed, failures: all_failures })
140    }
141
142    async fn run_single(guardrail: Arc<dyn Guardrail>, content: &Content) -> GuardrailResult {
143        guardrail.validate(content).await
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Always passes.
    struct PassGuardrail;

    #[async_trait::async_trait]
    impl Guardrail for PassGuardrail {
        fn name(&self) -> &str {
            "pass"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Pass
        }
    }

    /// Always fails with the configured severity (default fail-fast and parallel settings).
    struct FailGuardrail {
        severity: Severity,
    }

    #[async_trait::async_trait]
    impl Guardrail for FailGuardrail {
        fn name(&self) -> &str {
            "fail"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Fail { reason: "test failure".into(), severity: self.severity }
        }
    }

    /// Always fails with the configured severity but opts out of fail-fast.
    struct NonFailFastGuardrail {
        severity: Severity,
    }

    #[async_trait::async_trait]
    impl Guardrail for NonFailFastGuardrail {
        fn name(&self) -> &str {
            "non-fail-fast"
        }
        fn fail_fast(&self) -> bool {
            false
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Fail { reason: "soft failure".into(), severity: self.severity }
        }
    }

    /// Always fails with the configured severity and must run sequentially.
    struct SequentialFailGuardrail {
        severity: Severity,
    }

    #[async_trait::async_trait]
    impl Guardrail for SequentialFailGuardrail {
        fn name(&self) -> &str {
            "sequential-fail"
        }
        fn run_parallel(&self) -> bool {
            false
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Fail { reason: "sequential failure".into(), severity: self.severity }
        }
    }

    /// Replaces the content with a fixed user message.
    struct TransformGuardrail {
        text: &'static str,
    }

    #[async_trait::async_trait]
    impl Guardrail for TransformGuardrail {
        fn name(&self) -> &str {
            "transform"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Transform {
                new_content: Content::new("user").with_text(self.text),
                reason: "test transform".into(),
            }
        }
    }

    fn hello() -> Content {
        Content::new("user").with_text("hello")
    }

    #[tokio::test]
    async fn test_empty_guardrails_pass() {
        let set = GuardrailSet::new();
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_pass_guardrail() {
        let set = GuardrailSet::new().with(PassGuardrail);
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_fail_guardrail_low_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Low });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed); // Low severity doesn't fail
        assert_eq!(result.failures.len(), 1);
    }

    #[tokio::test]
    async fn test_fail_guardrail_high_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::High });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(!result.passed);
    }

    #[tokio::test]
    async fn test_critical_early_exit() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Critical });
        let result = GuardrailExecutor::run(&set, &hello()).await;
        assert!(matches!(
            result,
            Err(GuardrailError::ValidationFailed { severity: Severity::Critical, .. })
        ));
    }

    #[tokio::test]
    async fn test_critical_without_fail_fast_is_collected() {
        let set = GuardrailSet::new().with(NonFailFastGuardrail { severity: Severity::Critical });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].2, Severity::Critical);
    }

    #[tokio::test]
    async fn test_sequential_guardrail_failure_is_recorded() {
        let set = GuardrailSet::new()
            .with(PassGuardrail)
            .with(SequentialFailGuardrail { severity: Severity::High });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "sequential-fail");
    }

    #[tokio::test]
    async fn test_sequential_critical_early_exit() {
        let set =
            GuardrailSet::new().with(SequentialFailGuardrail { severity: Severity::Critical });
        let result = GuardrailExecutor::run(&set, &hello()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_transform_returns_new_content() {
        let set = GuardrailSet::new().with(TransformGuardrail { text: "redacted" });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        let transformed = result.transformed_content.expect("content should be transformed");
        let expected = Content::new("user").with_text("redacted");
        assert_eq!(
            serde_json::to_string(&transformed).unwrap(),
            serde_json::to_string(&expected).unwrap()
        );
    }

    #[tokio::test]
    async fn test_identity_transform_reports_no_change() {
        let set = GuardrailSet::new().with(TransformGuardrail { text: "hello" });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.transformed_content.is_none());
    }
}