1use std::{
2 fmt,
3 sync::atomic::{AtomicU64, Ordering},
4 time::SystemTime,
5};
6
7use async_trait::async_trait;
8use autoagents_llm::{
9 ToolCall,
10 chat::{ChatMessage, StructuredOutputFormat, Tool, Usage},
11 completion::CompletionRequest,
12};
13use serde_json::Value;
14
15use crate::policy::{GuardCategory, GuardSeverity};
16
/// Process-wide source of request ids: starts at 1 and is bumped once per
/// `GuardContext::new` call.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Replacement text used by the `redact_all` and `redact_text_only` helpers on
/// guarded inputs and outputs.
pub const DEFAULT_REDACTED_TEXT: &str = "[redacted by guardrails]";
19
20#[derive(Debug, Clone)]
22pub struct GuardContext {
23 pub request_id: u64,
24 pub operation: GuardOperation,
25 pub created_at: SystemTime,
26}
27
28impl GuardContext {
29 pub fn new(operation: GuardOperation) -> Self {
30 Self {
31 request_id: REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed),
32 operation,
33 created_at: SystemTime::now(),
34 }
35 }
36}
37
/// The kind of LLM call that a guard is being run for.
///
/// The `Display` impl renders each variant as a snake_case label, listed on each
/// variant below.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GuardOperation {
    /// Displayed as `chat`.
    Chat,
    /// Displayed as `chat_with_tools`.
    ChatWithTools,
    /// Displayed as `chat_with_web_search`.
    ChatWithWebSearch,
    /// Displayed as `chat_stream`.
    ChatStream,
    /// Displayed as `chat_stream_struct`.
    ChatStreamStruct,
    /// Displayed as `chat_stream_with_tools`.
    ChatStreamWithTools,
    /// Displayed as `complete`.
    Complete,
}
49
50impl fmt::Display for GuardOperation {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 let value = match self {
53 GuardOperation::Chat => "chat",
54 GuardOperation::ChatWithTools => "chat_with_tools",
55 GuardOperation::ChatWithWebSearch => "chat_with_web_search",
56 GuardOperation::ChatStream => "chat_stream",
57 GuardOperation::ChatStreamStruct => "chat_stream_struct",
58 GuardOperation::ChatStreamWithTools => "chat_stream_with_tools",
59 GuardOperation::Complete => "complete",
60 };
61 f.write_str(value)
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct GuardViolation {
68 pub rule_id: String,
69 pub category: GuardCategory,
70 pub severity: GuardSeverity,
71 pub message: String,
72 pub metadata: Option<Value>,
73}
74
75impl GuardViolation {
76 pub fn new(
77 rule_id: impl Into<String>,
78 category: GuardCategory,
79 severity: GuardSeverity,
80 message: impl Into<String>,
81 ) -> Self {
82 Self {
83 rule_id: rule_id.into(),
84 category,
85 severity,
86 message: message.into(),
87 metadata: None,
88 }
89 }
90
91 pub fn with_metadata(mut self, metadata: Value) -> Self {
92 self.metadata = Some(metadata);
93 self
94 }
95}
96
97#[derive(Debug, Clone)]
99pub enum GuardDecision {
100 Pass,
102 Modify { violation: Option<GuardViolation> },
104 Reject(GuardViolation),
106}
107
108impl GuardDecision {
109 pub fn pass() -> Self {
110 Self::Pass
111 }
112
113 pub fn modify() -> Self {
114 Self::Modify { violation: None }
115 }
116
117 pub fn reject(violation: GuardViolation) -> Self {
118 Self::Reject(violation)
119 }
120}
121
/// Error produced when a guard fails to carry out its inspection.
///
/// It wraps a plain message; its `Display` output is exactly that message.
#[derive(Debug, Clone)]
pub struct GuardError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl GuardError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        GuardError { message }
    }
}

impl fmt::Display for GuardError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for GuardError {}
143
/// A guard that runs over the request side ([`GuardedInput`]) of an LLM call.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait InputGuard: Send + Sync + 'static {
    /// Identifier for this guard (presumably used in logs/diagnostics — confirm with callers).
    fn name(&self) -> &'static str;

    /// Inspects `input` for the request described by `context`.
    ///
    /// `input` is passed mutably, so the guard may rewrite it in place (for example with
    /// [`GuardedInput::redact_all`]). The returned [`GuardDecision`] reports the outcome.
    ///
    /// # Errors
    /// Returns [`GuardError`] when the guard itself cannot complete the inspection.
    /// NOTE(review): whether callers fail open or closed on `Err` is decided outside this file.
    async fn inspect(
        &self,
        input: &mut GuardedInput,
        context: &GuardContext,
    ) -> Result<GuardDecision, GuardError>;
}
157
/// A guard that runs over the response side ([`GuardedOutput`]) of an LLM call.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait OutputGuard: Send + Sync + 'static {
    /// Identifier for this guard (presumably used in logs/diagnostics — confirm with callers).
    fn name(&self) -> &'static str;

    /// Inspects `output` for the request described by `context`.
    ///
    /// `output` is passed mutably, so the guard may rewrite it in place (for example with
    /// [`GuardedOutput::redact_all`] or [`GuardedOutput::redact_text_only`]). The returned
    /// [`GuardDecision`] reports the outcome.
    ///
    /// # Errors
    /// Returns [`GuardError`] when the guard itself cannot complete the inspection.
    /// NOTE(review): whether callers fail open or closed on `Err` is decided outside this file.
    async fn inspect(
        &self,
        output: &mut GuardedOutput,
        context: &GuardContext,
    ) -> Result<GuardDecision, GuardError>;
}
171
/// Chat request payload exposed to input guards.
#[derive(Debug, Clone)]
pub struct ChatGuardInput {
    /// Conversation messages; `GuardedInput::redact_with` overwrites each message's `content`.
    pub messages: Vec<ChatMessage>,
    /// Tool definitions attached to the request, if any (left untouched by redaction).
    pub tools: Option<Vec<Tool>>,
    /// Structured-output format attached to the request, if any (left untouched by redaction).
    pub json_schema: Option<StructuredOutputFormat>,
}
179
/// Completion request payload exposed to input guards.
#[derive(Debug, Clone)]
pub struct CompletionGuardInput {
    /// The completion request; `GuardedInput::redact_with` overwrites its `prompt`.
    pub request: CompletionRequest,
    /// Structured-output format attached to the request, if any (left untouched by redaction).
    pub json_schema: Option<StructuredOutputFormat>,
}
186
/// Web-search request payload exposed to input guards.
#[derive(Debug, Clone)]
pub struct WebSearchGuardInput {
    /// Raw text of the web-search request; `GuardedInput::redact_with` overwrites it.
    pub input: String,
}
192
/// The request payload handed (mutably) to [`InputGuard::inspect`].
#[derive(Debug, Clone)]
pub enum GuardedInput {
    /// A chat request.
    Chat(ChatGuardInput),
    /// A text-completion request.
    Completion(CompletionGuardInput),
    /// A web-search request.
    WebSearch(WebSearchGuardInput),
}
200
201impl GuardedInput {
202 pub fn redact_all(&mut self) {
204 self.redact_with(DEFAULT_REDACTED_TEXT);
205 }
206
207 pub fn redact_with(&mut self, replacement: &str) {
209 match self {
210 GuardedInput::Chat(chat) => {
211 for message in &mut chat.messages {
212 message.content = replacement.to_string();
213 }
214 }
215 GuardedInput::Completion(completion) => {
216 completion.request.prompt = replacement.to_string();
217 }
218 GuardedInput::WebSearch(web) => {
219 web.input = replacement.to_string();
220 }
221 }
222 }
223}
224
/// Chat response payload exposed to output guards.
#[derive(Debug, Clone)]
pub struct ChatGuardOutput {
    /// Response text, if any.
    pub text: Option<String>,
    /// Tool calls contained in the response, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Reasoning/"thinking" text, if the response carried any (presumably provider-specific — confirm).
    pub thinking: Option<String>,
    /// Usage accounting for the call, if reported; no redaction helper in this file clears it.
    pub usage: Option<Usage>,
}
233
/// Completion response payload exposed to output guards.
#[derive(Debug, Clone)]
pub struct CompletionGuardOutput {
    /// The completion text.
    pub text: String,
}
239
/// The response payload handed (mutably) to [`OutputGuard::inspect`].
#[derive(Debug, Clone)]
pub enum GuardedOutput {
    /// A chat response.
    Chat(ChatGuardOutput),
    /// A text-completion response.
    Completion(CompletionGuardOutput),
}
246
247impl GuardedOutput {
248 pub fn redact_all(&mut self) {
251 self.redact_with(DEFAULT_REDACTED_TEXT);
252 }
253
254 pub fn redact_with(&mut self, replacement: &str) {
257 match self {
258 GuardedOutput::Chat(chat) => {
259 chat.text = Some(replacement.to_string());
260 chat.thinking = None;
261 chat.tool_calls = None;
262 }
263 GuardedOutput::Completion(completion) => {
264 completion.text = replacement.to_string();
265 }
266 }
267 }
268
269 pub fn redact_text_only(&mut self) {
271 match self {
272 GuardedOutput::Chat(chat) => {
273 chat.text = Some(DEFAULT_REDACTED_TEXT.to_string());
274 }
275 GuardedOutput::Completion(completion) => {
276 completion.text = DEFAULT_REDACTED_TEXT.to_string();
277 }
278 }
279 }
280}