use std::{
fmt,
sync::atomic::{AtomicU64, Ordering},
time::SystemTime,
};
use async_trait::async_trait;
use autoagents_llm::{
ToolCall,
chat::{ChatMessage, StructuredOutputFormat, Tool, Usage},
completion::CompletionRequest,
};
use serde_json::Value;
use crate::policy::{GuardCategory, GuardSeverity};
/// Process-wide counter that supplies `GuardContext::request_id`; it starts at 1
/// and is advanced by one on every `GuardContext::new` call.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Replacement text used by the `redact_all` / `redact_text_only` helpers in this
/// module when the caller does not supply its own replacement.
pub const DEFAULT_REDACTED_TEXT: &str = "[redacted by guardrails]";
/// Per-invocation metadata passed by reference to `InputGuard::inspect` and
/// `OutputGuard::inspect` alongside the payload being inspected.
#[derive(Debug, Clone)]
pub struct GuardContext {
/// Identifier drawn from the module-level `REQUEST_COUNTER` when the context is built.
pub request_id: u64,
/// Operation this context was created for.
pub operation: GuardOperation,
/// System time captured by `GuardContext::new` at construction.
pub created_at: SystemTime,
}
impl GuardContext {
    /// Creates a context for `operation`.
    ///
    /// The request id is the value of the module-level counter *before* it is
    /// incremented, and `created_at` is the current system time.
    pub fn new(operation: GuardOperation) -> Self {
        // `fetch_add` hands back the previous value, so concurrent callers
        // each observe a distinct id (until the u64 wraps).
        let request_id = REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let created_at = SystemTime::now();

        Self {
            request_id,
            operation,
            created_at,
        }
    }
}
/// Identifies the operation a [`GuardContext`] was created for.
///
/// The [`fmt::Display`] implementation renders each variant as a lowercase
/// snake_case label (shown on every variant below).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GuardOperation {
    /// Displayed as `chat`.
    Chat,
    /// Displayed as `chat_with_tools`.
    ChatWithTools,
    /// Displayed as `chat_with_web_search`.
    ChatWithWebSearch,
    /// Displayed as `chat_stream`.
    ChatStream,
    /// Displayed as `chat_stream_struct`.
    ChatStreamStruct,
    /// Displayed as `chat_stream_with_tools`.
    ChatStreamWithTools,
    /// Displayed as `complete`.
    Complete,
}

impl fmt::Display for GuardOperation {
    /// Writes the variant's snake_case label.
    ///
    /// Formatter flags are honoured: `format!("{:>8}", GuardOperation::Chat)`
    /// yields `"    chat"`. Output for a plain `{}` is the bare label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            GuardOperation::Chat => "chat",
            GuardOperation::ChatWithTools => "chat_with_tools",
            GuardOperation::ChatWithWebSearch => "chat_with_web_search",
            GuardOperation::ChatStream => "chat_stream",
            GuardOperation::ChatStreamStruct => "chat_stream_struct",
            GuardOperation::ChatStreamWithTools => "chat_stream_with_tools",
            GuardOperation::Complete => "complete",
        };
        // `pad` applies the caller's width/alignment/precision; `write_str`
        // would silently ignore them.
        f.pad(label)
    }
}
/// Details of a violation reported by a guard; carried by
/// `GuardDecision::Reject` and, optionally, by `GuardDecision::Modify`.
#[derive(Debug, Clone)]
pub struct GuardViolation {
/// Identifier of the rule that produced this violation.
pub rule_id: String,
/// Category assigned to the violation (type defined in `crate::policy`).
pub category: GuardCategory,
/// Severity assigned to the violation (type defined in `crate::policy`).
pub severity: GuardSeverity,
/// Free-form message supplied by whoever constructed the violation.
pub message: String,
/// Optional JSON payload; `None` unless set via `with_metadata` or directly.
pub metadata: Option<Value>,
}
impl GuardViolation {
    /// Creates a violation with no metadata attached.
    ///
    /// `rule_id` and `message` accept anything convertible into `String`.
    pub fn new(
        rule_id: impl Into<String>,
        category: GuardCategory,
        severity: GuardSeverity,
        message: impl Into<String>,
    ) -> Self {
        let rule_id = rule_id.into();
        let message = message.into();

        Self {
            rule_id,
            category,
            severity,
            message,
            metadata: None,
        }
    }

    /// Consumes the violation and returns it with `metadata` set, replacing
    /// any metadata that was already present.
    pub fn with_metadata(self, metadata: Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }
}
/// Outcome returned (inside `Result`) by `InputGuard::inspect` and
/// `OutputGuard::inspect`.
#[derive(Debug, Clone)]
pub enum GuardDecision {
/// Carries no data.
/// NOTE(review): presumably means the payload is accepted unchanged — confirm
/// against the code that consumes decisions.
Pass,
/// Carries an optional violation.
/// NOTE(review): presumably signals that the guard rewrote the payload in
/// place (guards receive it as `&mut`) — confirm against the consumer.
Modify { violation: Option<GuardViolation> },
/// Carries the violation that triggered the decision.
/// NOTE(review): presumably means the operation should be blocked — confirm
/// against the consumer.
Reject(GuardViolation),
}
impl GuardDecision {
pub fn pass() -> Self {
Self::Pass
}
pub fn modify() -> Self {
Self::Modify { violation: None }
}
pub fn reject(violation: GuardViolation) -> Self {
Self::Reject(violation)
}
}
/// Error type returned by `InputGuard::inspect` / `OutputGuard::inspect`.
///
/// It wraps a single message string and implements [`std::error::Error`].
#[derive(Debug, Clone)]
pub struct GuardError {
    /// Text emitted by the [`fmt::Display`] implementation.
    pub message: String,
}

impl GuardError {
    /// Creates an error from anything convertible into `String`.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl fmt::Display for GuardError {
    /// Writes the stored message.
    ///
    /// Width, alignment and precision flags on the formatter are honoured;
    /// a plain `{}` yields the message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` respects formatter flags, unlike `write_str` which drops them.
        f.pad(&self.message)
    }
}

impl std::error::Error for GuardError {}
/// Guard that inspects a `GuardedInput` payload.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait InputGuard: Send + Sync + 'static {
/// Static name of this guard.
/// NOTE(review): presumably used for logging/diagnostics — confirm at call sites.
fn name(&self) -> &'static str;
/// Inspects `input`, which is passed mutably so an implementation may edit it
/// in place, and returns a `GuardDecision`; `Err(GuardError)` reports a
/// failure of the guard itself.
async fn inspect(
&self,
input: &mut GuardedInput,
context: &GuardContext,
) -> Result<GuardDecision, GuardError>;
}
/// Guard that inspects a `GuardedOutput` payload.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait OutputGuard: Send + Sync + 'static {
/// Static name of this guard.
/// NOTE(review): presumably used for logging/diagnostics — confirm at call sites.
fn name(&self) -> &'static str;
/// Inspects `output`, which is passed mutably so an implementation may edit it
/// in place, and returns a `GuardDecision`; `Err(GuardError)` reports a
/// failure of the guard itself.
async fn inspect(
&self,
output: &mut GuardedOutput,
context: &GuardContext,
) -> Result<GuardDecision, GuardError>;
}
/// Payload of the `GuardedInput::Chat` variant.
#[derive(Debug, Clone)]
pub struct ChatGuardInput {
/// Chat messages; `GuardedInput::redact_with` overwrites each message's `content`.
pub messages: Vec<ChatMessage>,
/// Optional tool definitions.
pub tools: Option<Vec<Tool>>,
/// Optional structured-output format.
pub json_schema: Option<StructuredOutputFormat>,
}
/// Payload of the `GuardedInput::Completion` variant.
#[derive(Debug, Clone)]
pub struct CompletionGuardInput {
/// Completion request; `GuardedInput::redact_with` overwrites its `prompt`.
pub request: CompletionRequest,
/// Optional structured-output format.
pub json_schema: Option<StructuredOutputFormat>,
}
/// Payload of the `GuardedInput::WebSearch` variant.
#[derive(Debug, Clone)]
pub struct WebSearchGuardInput {
/// Input text; replaced wholesale by `GuardedInput::redact_with`.
pub input: String,
}
/// Input payload handed mutably to `InputGuard::inspect`, one variant per
/// kind of request.
#[derive(Debug, Clone)]
pub enum GuardedInput {
/// Chat payload (messages, optional tools, optional output schema).
Chat(ChatGuardInput),
/// Completion payload (request, optional output schema).
Completion(CompletionGuardInput),
/// Web-search payload (a single input string).
WebSearch(WebSearchGuardInput),
}
impl GuardedInput {
pub fn redact_all(&mut self) {
self.redact_with(DEFAULT_REDACTED_TEXT);
}
pub fn redact_with(&mut self, replacement: &str) {
match self {
GuardedInput::Chat(chat) => {
for message in &mut chat.messages {
message.content = replacement.to_string();
}
}
GuardedInput::Completion(completion) => {
completion.request.prompt = replacement.to_string();
}
GuardedInput::WebSearch(web) => {
web.input = replacement.to_string();
}
}
}
}
/// Payload of the `GuardedOutput::Chat` variant.
#[derive(Debug, Clone)]
pub struct ChatGuardOutput {
/// Response text, if any; set to the replacement by the redaction helpers.
pub text: Option<String>,
/// Tool calls, if any; cleared by `GuardedOutput::redact_with`.
pub tool_calls: Option<Vec<ToolCall>>,
/// Thinking text, if any; cleared by `GuardedOutput::redact_with`.
pub thinking: Option<String>,
/// Usage information, if any; not touched by the redaction helpers.
pub usage: Option<Usage>,
}
/// Payload of the `GuardedOutput::Completion` variant.
#[derive(Debug, Clone)]
pub struct CompletionGuardOutput {
/// Completion text; replaced wholesale by the redaction helpers.
pub text: String,
}
/// Output payload handed mutably to `OutputGuard::inspect`, one variant per
/// kind of response.
#[derive(Debug, Clone)]
pub enum GuardedOutput {
/// Chat output (text, tool calls, thinking, usage).
Chat(ChatGuardOutput),
/// Completion output (a single text string).
Completion(CompletionGuardOutput),
}
impl GuardedOutput {
pub fn redact_all(&mut self) {
self.redact_with(DEFAULT_REDACTED_TEXT);
}
pub fn redact_with(&mut self, replacement: &str) {
match self {
GuardedOutput::Chat(chat) => {
chat.text = Some(replacement.to_string());
chat.thinking = None;
chat.tool_calls = None;
}
GuardedOutput::Completion(completion) => {
completion.text = replacement.to_string();
}
}
}
pub fn redact_text_only(&mut self) {
match self {
GuardedOutput::Chat(chat) => {
chat.text = Some(DEFAULT_REDACTED_TEXT.to_string());
}
GuardedOutput::Completion(completion) => {
completion.text = DEFAULT_REDACTED_TEXT.to_string();
}
}
}
}