// autoagents_guardrails/guard.rs

1use std::{
2    fmt,
3    sync::atomic::{AtomicU64, Ordering},
4    time::SystemTime,
5};
6
7use async_trait::async_trait;
8use autoagents_llm::{
9    ToolCall,
10    chat::{ChatMessage, StructuredOutputFormat, Tool, Usage},
11    completion::CompletionRequest,
12};
13use serde_json::Value;
14
15use crate::policy::{GuardCategory, GuardSeverity};
16
/// Process-wide counter that hands out `GuardContext::request_id` values; the
/// first id issued is 1.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Placeholder text written over content by the default redaction helpers
/// (`redact_all`, `redact_text_only`).
pub const DEFAULT_REDACTED_TEXT: &str = "[redacted by guardrails]";
19
20/// Immutable metadata attached to each guardrails evaluation.
21#[derive(Debug, Clone)]
22pub struct GuardContext {
23    pub request_id: u64,
24    pub operation: GuardOperation,
25    pub created_at: SystemTime,
26}
27
28impl GuardContext {
29    pub fn new(operation: GuardOperation) -> Self {
30        Self {
31            request_id: REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed),
32            operation,
33            created_at: SystemTime::now(),
34        }
35    }
36}
37
/// The kind of LLM call a guardrails evaluation is attached to.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GuardOperation {
    Chat,
    ChatWithTools,
    ChatWithWebSearch,
    ChatStream,
    ChatStreamStruct,
    ChatStreamWithTools,
    Complete,
}

/// Renders each operation as its lowercase snake_case label.
impl fmt::Display for GuardOperation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` emits the label verbatim; width/padding flags on the
        // formatter are not applied.
        f.write_str(match *self {
            Self::Chat => "chat",
            Self::ChatWithTools => "chat_with_tools",
            Self::ChatWithWebSearch => "chat_with_web_search",
            Self::ChatStream => "chat_stream",
            Self::ChatStreamStruct => "chat_stream_struct",
            Self::ChatStreamWithTools => "chat_stream_with_tools",
            Self::Complete => "complete",
        })
    }
}
64
65/// A rule hit returned by a guard implementation.
66#[derive(Debug, Clone)]
67pub struct GuardViolation {
68    pub rule_id: String,
69    pub category: GuardCategory,
70    pub severity: GuardSeverity,
71    pub message: String,
72    pub metadata: Option<Value>,
73}
74
75impl GuardViolation {
76    pub fn new(
77        rule_id: impl Into<String>,
78        category: GuardCategory,
79        severity: GuardSeverity,
80        message: impl Into<String>,
81    ) -> Self {
82        Self {
83            rule_id: rule_id.into(),
84            category,
85            severity,
86            message: message.into(),
87            metadata: None,
88        }
89    }
90
91    pub fn with_metadata(mut self, metadata: Value) -> Self {
92        self.metadata = Some(metadata);
93        self
94    }
95}
96
97/// Decision returned by each guard invocation.
98#[derive(Debug, Clone)]
99pub enum GuardDecision {
100    /// No issue found.
101    Pass,
102    /// Guard mutated payload in-place and wants processing to continue.
103    Modify { violation: Option<GuardViolation> },
104    /// Guard found a violation and wants policy handling.
105    Reject(GuardViolation),
106}
107
108impl GuardDecision {
109    pub fn pass() -> Self {
110        Self::Pass
111    }
112
113    pub fn modify() -> Self {
114        Self::Modify { violation: None }
115    }
116
117    pub fn reject(violation: GuardViolation) -> Self {
118        Self::Reject(violation)
119    }
120}
121
/// Failure reported by a guard implementation.
#[derive(Debug, Clone)]
pub struct GuardError {
    /// Text emitted when the error is displayed.
    pub message: String,
}

impl GuardError {
    /// Wraps anything convertible into a `String` as a guard error.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

/// Displays exactly the stored message, with no prefix or decoration.
impl fmt::Display for GuardError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { message } = self;
        f.write_str(message)
    }
}

impl std::error::Error for GuardError {}
143
/// Trait implemented by request/input guardrails.
///
/// Implementors must be `Send + Sync + 'static`. The `#[async_trait]`
/// attribute lets `inspect` be declared as an `async fn`.
#[async_trait]
pub trait InputGuard: Send + Sync + 'static {
    /// Stable identifier used for diagnostics.
    fn name(&self) -> &'static str;

    /// Inspect and optionally mutate input payload.
    ///
    /// * `input` - payload under inspection; passed mutably so the guard can
    ///   edit it in place.
    /// * `context` - metadata for the current evaluation.
    ///
    /// Returns the guard's [`GuardDecision`] (`Pass`, `Modify` or `Reject`).
    ///
    /// # Errors
    ///
    /// Returns a [`GuardError`] emitted by the guard implementation.
    async fn inspect(
        &self,
        input: &mut GuardedInput,
        context: &GuardContext,
    ) -> Result<GuardDecision, GuardError>;
}
157
/// Trait implemented by response/output guardrails.
///
/// Implementors must be `Send + Sync + 'static`. The `#[async_trait]`
/// attribute lets `inspect` be declared as an `async fn`.
#[async_trait]
pub trait OutputGuard: Send + Sync + 'static {
    /// Stable identifier used for diagnostics.
    fn name(&self) -> &'static str;

    /// Inspect and optionally mutate output payload.
    ///
    /// * `output` - payload under inspection; passed mutably so the guard can
    ///   edit it in place.
    /// * `context` - metadata for the current evaluation.
    ///
    /// Returns the guard's [`GuardDecision`] (`Pass`, `Modify` or `Reject`).
    ///
    /// # Errors
    ///
    /// Returns a [`GuardError`] emitted by the guard implementation.
    async fn inspect(
        &self,
        output: &mut GuardedOutput,
        context: &GuardContext,
    ) -> Result<GuardDecision, GuardError>;
}
171
/// Owned chat payload used by input guards.
///
/// Carried by the `GuardedInput::Chat` variant.
#[derive(Debug, Clone)]
pub struct ChatGuardInput {
    /// Chat messages of the request. `GuardedInput::redact_with` overwrites
    /// the `content` of every message in this list.
    pub messages: Vec<ChatMessage>,
    /// Optional tool list; not touched by the redaction helpers in this file.
    pub tools: Option<Vec<Tool>>,
    /// Optional structured-output format; not touched by the redaction helpers
    /// in this file.
    pub json_schema: Option<StructuredOutputFormat>,
}
179
/// Owned completion payload used by input guards.
///
/// Carried by the `GuardedInput::Completion` variant.
#[derive(Debug, Clone)]
pub struct CompletionGuardInput {
    /// The completion request. `GuardedInput::redact_with` overwrites its
    /// `prompt` field.
    pub request: CompletionRequest,
    /// Optional structured-output format; not touched by the redaction helpers
    /// in this file.
    pub json_schema: Option<StructuredOutputFormat>,
}
186
/// Owned web search payload used by input guards.
///
/// Carried by the `GuardedInput::WebSearch` variant.
#[derive(Debug, Clone)]
pub struct WebSearchGuardInput {
    /// Web search input text. `GuardedInput::redact_with` replaces it wholesale.
    pub input: String,
}
192
193/// Input payload union passed through input guards.
194#[derive(Debug, Clone)]
195pub enum GuardedInput {
196    Chat(ChatGuardInput),
197    Completion(CompletionGuardInput),
198    WebSearch(WebSearchGuardInput),
199}
200
201impl GuardedInput {
202    /// Redact every text field using the default placeholder.
203    pub fn redact_all(&mut self) {
204        self.redact_with(DEFAULT_REDACTED_TEXT);
205    }
206
207    /// Redact every text field with a custom replacement string.
208    pub fn redact_with(&mut self, replacement: &str) {
209        match self {
210            GuardedInput::Chat(chat) => {
211                for message in &mut chat.messages {
212                    message.content = replacement.to_string();
213                }
214            }
215            GuardedInput::Completion(completion) => {
216                completion.request.prompt = replacement.to_string();
217            }
218            GuardedInput::WebSearch(web) => {
219                web.input = replacement.to_string();
220            }
221        }
222    }
223}
224
/// Materialized chat output payload used by output guards.
///
/// Carried by the `GuardedOutput::Chat` variant.
#[derive(Debug, Clone)]
pub struct ChatGuardOutput {
    /// Response text, if any. The redaction helpers on `GuardedOutput` set
    /// this to `Some(replacement)`.
    pub text: Option<String>,
    /// Tool calls, if any. Cleared by `GuardedOutput::redact_with`.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Thinking text, if any. Cleared by `GuardedOutput::redact_with`.
    pub thinking: Option<String>,
    /// Usage information, if any; not touched by the redaction helpers in
    /// this file.
    pub usage: Option<Usage>,
}
233
/// Materialized completion output payload used by output guards.
///
/// Carried by the `GuardedOutput::Completion` variant.
#[derive(Debug, Clone)]
pub struct CompletionGuardOutput {
    /// Completion text; replaced wholesale by the redaction helpers on
    /// `GuardedOutput`.
    pub text: String,
}
239
240/// Output payload union passed through output guards.
241#[derive(Debug, Clone)]
242pub enum GuardedOutput {
243    Chat(ChatGuardOutput),
244    Completion(CompletionGuardOutput),
245}
246
247impl GuardedOutput {
248    /// Redact output content using the default placeholder and clear optional
249    /// chat-specific metadata.
250    pub fn redact_all(&mut self) {
251        self.redact_with(DEFAULT_REDACTED_TEXT);
252    }
253
254    /// Redact output content with a custom replacement and clear optional
255    /// chat-specific metadata.
256    pub fn redact_with(&mut self, replacement: &str) {
257        match self {
258            GuardedOutput::Chat(chat) => {
259                chat.text = Some(replacement.to_string());
260                chat.thinking = None;
261                chat.tool_calls = None;
262            }
263            GuardedOutput::Completion(completion) => {
264                completion.text = replacement.to_string();
265            }
266        }
267    }
268
269    /// Redact only text fields while preserving non-text chat metadata.
270    pub fn redact_text_only(&mut self) {
271        match self {
272            GuardedOutput::Chat(chat) => {
273                chat.text = Some(DEFAULT_REDACTED_TEXT.to_string());
274            }
275            GuardedOutput::Completion(completion) => {
276                completion.text = DEFAULT_REDACTED_TEXT.to_string();
277            }
278        }
279    }
280}