#[cfg(feature = "guard")]
pub mod content;
#[cfg(feature = "guard")]
pub mod llm;
#[cfg(feature = "guard")]
pub mod rule;
use crate::error::Result;
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/// Which leg of an exchange a piece of guarded content belongs to.
///
/// The `Display` form is a snake_case name (`"input"`, `"output"`,
/// `"tool_input"`, `"tool_output"`). No serde rename attributes are applied,
/// so serialization uses the Rust variant names as declared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GuardDirection {
    /// Content entering the system (presumably the caller/user prompt — confirm against callers).
    Input,
    /// Content leaving the system (presumably the model's response — confirm against callers).
    Output,
    /// Content passed to a tool (presumably tool-call arguments — confirm against callers).
    ToolInput,
    /// Content returned by a tool (presumably tool results — confirm against callers).
    ToolOutput,
}
impl std::fmt::Display for GuardDirection {
    /// Writes the snake_case name of the direction (e.g. `tool_input`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            GuardDirection::Input => "input",
            GuardDirection::Output => "output",
            GuardDirection::ToolInput => "tool_input",
            GuardDirection::ToolOutput => "tool_output",
        };
        // `pad` (unlike `write!` with a literal) honours width, fill, alignment
        // and precision flags, so `format!("{:>12}", dir)` lines up as expected.
        f.pad(name)
    }
}
/// Verdict produced by a single [`Guard`] or aggregated by
/// [`GuardManager::check_all`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GuardResult {
    /// No objection was raised.
    Pass,
    /// The content is rejected; this is the only variant for which
    /// [`GuardResult::is_blocked`] returns `true`.
    Block {
        /// Explanation of why the content was rejected.
        reason: String,
    },
    /// Non-blocking concerns were raised.
    Warn {
        /// Every concern collected, in the order it was recorded.
        reasons: Vec<String>,
    },
}
impl GuardResult {
    /// Reports whether this verdict rejects the content.
    ///
    /// `Pass` and `Warn` both yield `false`; only `Block` yields `true`.
    pub fn is_blocked(&self) -> bool {
        match self {
            GuardResult::Block { .. } => true,
            GuardResult::Pass | GuardResult::Warn { .. } => false,
        }
    }
}
/// A single content check that can be registered with a [`GuardManager`].
///
/// The `Send + Sync` bound lets [`GuardManager::check_all`] hand each guard
/// to its own spawned task.
pub trait Guard: Send + Sync {
    /// Identifier used in log fields and in warning messages.
    fn name(&self) -> &str;

    /// Inspects `content` travelling in `direction` and produces a verdict.
    ///
    /// The returned future may borrow both `self` and `content`. When run via
    /// [`GuardManager::check_all`], an `Err` is downgraded to a warning rather
    /// than failing the whole check.
    fn check<'g>(
        &'g self,
        content: &'g str,
        direction: GuardDirection,
    ) -> BoxFuture<'g, Result<GuardResult>>;
}
/// An ordered collection of [`Guard`]s that are evaluated together by
/// [`GuardManager::check_all`].
///
/// Cloning is cheap: only the `Arc` handles are copied, so clones share the
/// same guard instances.
#[derive(Clone)]
pub struct GuardManager {
    guards: Vec<Arc<dyn Guard>>,
}
impl Default for GuardManager {
fn default() -> Self {
Self::new()
}
}
impl GuardManager {
    /// Upper bound on how many guard checks run at the same time within one
    /// `check_all` call.
    const MAX_CONCURRENT_CHECKS: usize = 16;

    /// Creates a manager with no guards registered.
    pub fn new() -> Self {
        Self { guards: Vec::new() }
    }

    /// Registers `guard`. `check_all` consults guards in registration order
    /// when collecting results.
    pub fn add(&mut self, guard: Arc<dyn Guard>) {
        self.guards.push(guard);
    }

    /// Creates a manager from an existing list of guards, keeping their order.
    pub fn from_guards(guards: Vec<Arc<dyn Guard>>) -> Self {
        Self { guards }
    }

    /// Returns `true` when no guards are registered.
    pub fn is_empty(&self) -> bool {
        self.guards.is_empty()
    }

    /// Runs every registered guard against `content` concurrently and merges
    /// their verdicts.
    ///
    /// Each guard runs on its own Tokio task, with at most 16 checks in flight
    /// at once. Results are consumed in registration order:
    ///
    /// * the first guard (in registration order) that returns `Block` decides
    ///   the result, and all outstanding checks are cancelled;
    /// * `Warn` reasons from every guard are concatenated;
    /// * a guard whose check returns `Err` contributes a
    ///   `"<name> error: <err>"` warning instead of failing the call.
    ///
    /// Returns `Pass` when there are no guards or nothing was raised, otherwise
    /// `Warn` with the collected reasons.
    ///
    /// If this future returns early or is dropped before completion, every
    /// remaining check is cancelled rather than left running in the background.
    ///
    /// # Errors
    ///
    /// Returns `ReactError::Other` if a guard task panics (fails to join).
    pub async fn check_all(&self, content: &str, direction: GuardDirection) -> Result<GuardResult> {
        if self.guards.is_empty() {
            return Ok(GuardResult::Pass);
        }

        // One shared, reference-counted copy of the text instead of one
        // `String` allocation per guard.
        let content: Arc<str> = Arc::from(content);
        let semaphore = Arc::new(tokio::sync::Semaphore::new(Self::MAX_CONCURRENT_CHECKS));
        let cancel = CancellationToken::new();
        // Spawned tasks are detached once their handles are dropped. This guard
        // cancels them on every exit path: an early `Block` return, a panicked
        // task propagated through `?`, or the caller dropping this future.
        let _cancel_on_exit = cancel.clone().drop_guard();

        let handles: Vec<_> = self
            .guards
            .iter()
            .map(|guard| {
                let guard = Arc::clone(guard);
                let content = Arc::clone(&content);
                let semaphore = Arc::clone(&semaphore);
                let cancel = cancel.clone();
                tokio::spawn(async move {
                    // The permit is awaited inside the task (not in the spawning
                    // loop) so result collection starts immediately and checks still
                    // queued on the semaphore remain cancellable.
                    let run = async {
                        // The semaphore is never closed, so this is always `Ok`;
                        // holding the value keeps the permit for the whole check.
                        let _permit = semaphore.acquire_owned().await;
                        guard.check(&content, direction).await
                    };
                    let result = tokio::select! {
                        // A cancelled check is reported as a neutral `Pass`.
                        _ = cancel.cancelled() => Ok(GuardResult::Pass),
                        r = run => r,
                    };
                    (guard.name().to_string(), result)
                })
            })
            .collect();

        let mut warnings = Vec::new();
        for (i, handle) in handles.into_iter().enumerate() {
            let (guard_name, result) = handle.await.map_err(|e| {
                crate::error::ReactError::Other(format!("Guard task {} panicked: {}", i, e))
            })?;
            match result {
                Ok(GuardResult::Block { reason }) => {
                    // Stop the remaining checks as early as possible.
                    cancel.cancel();
                    tracing::warn!(
                        guard = guard_name,
                        direction = %direction,
                        reason = %reason,
                        "Guard blocked content"
                    );
                    return Ok(GuardResult::Block { reason });
                }
                Ok(GuardResult::Warn { reasons }) => {
                    warnings.extend(reasons);
                }
                Ok(GuardResult::Pass) => {}
                Err(e) => {
                    // A failing guard is surfaced as a warning, not a hard error.
                    tracing::error!(guard = guard_name, error = %e, "Guard check error");
                    warnings.push(format!("{} error: {}", guard_name, e));
                }
            }
        }

        if warnings.is_empty() {
            Ok(GuardResult::Pass)
        } else {
            Ok(GuardResult::Warn { reasons: warnings })
        }
    }
}