Skip to main content

adk_guardrail/
executor.rs

1use crate::{Guardrail, GuardrailError, GuardrailResult, Result, Severity};
2use adk_core::Content;
3use futures::future::join_all;
4use std::sync::Arc;
5
6/// A collection of guardrails to execute together.
7///
8/// Use the builder-style [`with`](Self::with) method to add guardrails.
9pub struct GuardrailSet {
10    guardrails: Vec<Arc<dyn Guardrail>>,
11}
12
13impl GuardrailSet {
14    /// Create an empty guardrail set.
15    pub fn new() -> Self {
16        Self { guardrails: Vec::new() }
17    }
18
19    /// Add a guardrail (by value, automatically wrapped in `Arc`).
20    pub fn with(mut self, guardrail: impl Guardrail + 'static) -> Self {
21        self.guardrails.push(Arc::new(guardrail));
22        self
23    }
24
25    /// Add a pre-wrapped guardrail.
26    pub fn with_arc(mut self, guardrail: Arc<dyn Guardrail>) -> Self {
27        self.guardrails.push(guardrail);
28        self
29    }
30
31    /// Get a reference to the registered guardrails.
32    pub fn guardrails(&self) -> &[Arc<dyn Guardrail>] {
33        &self.guardrails
34    }
35
36    /// Returns `true` if no guardrails have been added.
37    pub fn is_empty(&self) -> bool {
38        self.guardrails.is_empty()
39    }
40}
41
42impl Default for GuardrailSet {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48/// Result of running a [`GuardrailSet`].
49#[derive(Debug)]
50pub struct ExecutionResult {
51    /// `true` if all guardrails passed (no critical failures).
52    pub passed: bool,
53    /// Content after guardrail transformations, or `None` if unchanged.
54    pub transformed_content: Option<Content>,
55    /// List of failures as `(guardrail_name, reason, severity)`.
56    pub failures: Vec<(String, String, Severity)>,
57}
58
/// Stateless executor that runs a [`GuardrailSet`] against some content.
///
/// Guardrails that opt into parallel execution are awaited concurrently; the
/// rest run sequentially afterwards.
#[derive(Debug, Clone, Copy, Default)]
pub struct GuardrailExecutor;
61
62impl GuardrailExecutor {
63    /// Run all guardrails in parallel, with early exit on critical failures
64    pub async fn run(guardrails: &GuardrailSet, content: &Content) -> Result<ExecutionResult> {
65        if guardrails.is_empty() {
66            return Ok(ExecutionResult {
67                passed: true,
68                transformed_content: None,
69                failures: vec![],
70            });
71        }
72
73        // Separate parallel and sequential guardrails
74        let (parallel, sequential): (Vec<_>, Vec<_>) =
75            guardrails.guardrails().iter().partition(|g| g.run_parallel());
76
77        let mut current_content = content.clone();
78        let mut all_failures = Vec::new();
79
80        // Run parallel guardrails
81        if !parallel.is_empty() {
82            let futures: Vec<_> = parallel
83                .iter()
84                .map(|g| Self::run_single(Arc::clone(g), &current_content))
85                .collect();
86
87            let results = join_all(futures).await;
88
89            for (guardrail, result) in parallel.iter().zip(results) {
90                match result {
91                    GuardrailResult::Pass => {}
92                    GuardrailResult::Fail { reason, severity } => {
93                        all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
94                        // Early exit on critical
95                        if severity == Severity::Critical && guardrail.fail_fast() {
96                            return Err(GuardrailError::ValidationFailed {
97                                name: guardrail.name().to_string(),
98                                reason,
99                                severity,
100                            });
101                        }
102                    }
103                    GuardrailResult::Transform { new_content, reason } => {
104                        tracing::debug!(
105                            guardrail = guardrail.name(),
106                            reason = %reason,
107                            "Content transformed"
108                        );
109                        current_content = new_content;
110                    }
111                }
112            }
113        }
114
115        // Run sequential guardrails
116        for guardrail in sequential {
117            let result = Self::run_single(Arc::clone(guardrail), &current_content).await;
118            match result {
119                GuardrailResult::Pass => {}
120                GuardrailResult::Fail { reason, severity } => {
121                    all_failures.push((guardrail.name().to_string(), reason.clone(), severity));
122                    if severity == Severity::Critical && guardrail.fail_fast() {
123                        return Err(GuardrailError::ValidationFailed {
124                            name: guardrail.name().to_string(),
125                            reason,
126                            severity,
127                        });
128                    }
129                }
130                GuardrailResult::Transform { new_content, reason } => {
131                    tracing::debug!(
132                        guardrail = guardrail.name(),
133                        reason = %reason,
134                        "Content transformed"
135                    );
136                    current_content = new_content;
137                }
138            }
139        }
140
141        let passed =
142            all_failures.is_empty() || all_failures.iter().all(|(_, _, s)| *s == Severity::Low);
143
144        // Check if content was transformed by comparing serialized forms
145        let was_transformed =
146            serde_json::to_string(&current_content).ok() != serde_json::to_string(content).ok();
147        let transformed = if was_transformed { Some(current_content) } else { None };
148
149        Ok(ExecutionResult { passed, transformed_content: transformed, failures: all_failures })
150    }
151
152    async fn run_single(guardrail: Arc<dyn Guardrail>, content: &Content) -> GuardrailResult {
153        guardrail.validate(content).await
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Guardrail that always passes.
    struct PassGuardrail;

    #[async_trait::async_trait]
    impl Guardrail for PassGuardrail {
        fn name(&self) -> &str {
            "pass"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Pass
        }
    }

    /// Guardrail that always fails with a configurable severity.
    struct FailGuardrail {
        severity: Severity,
    }

    #[async_trait::async_trait]
    impl Guardrail for FailGuardrail {
        fn name(&self) -> &str {
            "fail"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Fail { reason: "test failure".into(), severity: self.severity }
        }
    }

    /// Guardrail that always replaces the content with a fixed message.
    struct RedactGuardrail;

    #[async_trait::async_trait]
    impl Guardrail for RedactGuardrail {
        fn name(&self) -> &str {
            "redact"
        }
        async fn validate(&self, _: &Content) -> GuardrailResult {
            GuardrailResult::Transform {
                new_content: Content::new("user").with_text("[redacted]"),
                reason: "redacted".into(),
            }
        }
    }

    /// Input content shared by the tests.
    fn hello() -> Content {
        Content::new("user").with_text("hello")
    }

    #[tokio::test]
    async fn test_empty_guardrails_pass() {
        let set = GuardrailSet::new();
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_pass_guardrail() {
        let set = GuardrailSet::new().with(PassGuardrail);
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.transformed_content.is_none());
    }

    #[tokio::test]
    async fn test_fail_guardrail_low_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Low });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed); // Low severity doesn't fail
        assert_eq!(result.failures.len(), 1);
    }

    #[tokio::test]
    async fn test_fail_guardrail_high_severity() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::High });
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.failures.len(), 1);
        let (name, reason, severity) = &result.failures[0];
        assert_eq!(name, "fail");
        assert_eq!(reason, "test failure");
        assert_eq!(*severity, Severity::High);
    }

    #[tokio::test]
    async fn test_critical_early_exit() {
        let set = GuardrailSet::new().with(FailGuardrail { severity: Severity::Critical });
        let result = GuardrailExecutor::run(&set, &hello()).await;
        assert!(matches!(result, Err(GuardrailError::ValidationFailed { .. })));
    }

    #[tokio::test]
    async fn test_transform_is_reported() {
        let set = GuardrailSet::new().with(RedactGuardrail);
        let result = GuardrailExecutor::run(&set, &hello()).await.unwrap();
        assert!(result.passed);
        assert!(result.failures.is_empty());
        assert!(result.transformed_content.is_some());
    }
}