use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::error::Result;
use crate::hooks::RunContext;
use crate::message::Message;
/// Verdict produced by a single guardrail check.
///
/// `tripwire_triggered` is the pass/trip flag and `output_info` carries arbitrary
/// JSON details supplied by the check (`Value::Null` when none were attached).
///
/// `Default` yields the same value as [`GuardrailOutput::pass`]: not triggered,
/// with `Value::Null` info. `PartialEq` lets callers and tests compare verdicts.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct GuardrailOutput {
    /// `true` when the check tripped; also exposed via [`GuardrailOutput::is_triggered`].
    pub tripwire_triggered: bool,
    /// Free-form details attached by the check.
    pub output_info: Value,
}
impl GuardrailOutput {
    /// Assembles a verdict from its two parts; every public constructor funnels through here.
    const fn from_parts(tripwire_triggered: bool, output_info: Value) -> Self {
        Self {
            tripwire_triggered,
            output_info,
        }
    }

    /// A passing verdict with no attached details (`output_info` is `Value::Null`).
    #[must_use]
    pub const fn pass() -> Self {
        Self::from_parts(false, Value::Null)
    }

    /// A tripped verdict whose `output_info` is `info` converted into a `Value`.
    #[must_use]
    pub fn tripwire(info: impl Into<Value>) -> Self {
        Self::from_parts(true, info.into())
    }

    /// A passing verdict that still records `info` (converted into a `Value`) as `output_info`.
    #[must_use]
    pub fn pass_with_info(info: impl Into<Value>) -> Self {
        Self::from_parts(false, info.into())
    }

    /// Returns the `tripwire_triggered` flag.
    #[must_use]
    pub const fn is_triggered(&self) -> bool {
        self.tripwire_triggered
    }
}
/// Pluggable check executed by an [`InputGuardrail`] over an agent's input messages.
///
/// Implementors must be `Send + Sync`; [`InputGuardrail`] stores them behind an
/// `Arc<dyn InputGuardrailCheck>`.
#[async_trait]
pub trait InputGuardrailCheck: Send + Sync {
    /// Inspects `input` for the agent named `agent_name` and returns a verdict.
    ///
    /// # Errors
    ///
    /// Any `Err` returned here is propagated unchanged by [`InputGuardrail::run`].
    async fn check(
        &self,
        context: &RunContext,
        agent_name: &str,
        input: &[Message],
    ) -> Result<GuardrailOutput>;
}
/// A named input guardrail wrapping a shared, type-erased [`InputGuardrailCheck`].
///
/// Clones share the same `Arc`-held check; only the name `String` is duplicated.
/// `Debug` is implemented by hand (omitting `check`) because the trait object
/// carries no `Debug` bound.
#[derive(Clone)]
pub struct InputGuardrail {
    /// Identifier copied into every [`InputGuardrailResult`] this guardrail produces.
    name: String,
    /// Flag reported by `is_parallel()`; `new` sets it to `true`.
    run_in_parallel: bool,
    /// The check implementation; held in an `Arc` so clones share it.
    check: Arc<dyn InputGuardrailCheck>,
}
impl InputGuardrail {
    /// Creates a guardrail named `name` that delegates to `check`.
    ///
    /// The parallel flag starts as `true`; change it with [`InputGuardrail::run_in_parallel`].
    #[must_use]
    pub fn new(name: impl Into<String>, check: impl InputGuardrailCheck + 'static) -> Self {
        let name = name.into();
        let check: Arc<dyn InputGuardrailCheck> = Arc::new(check);
        Self {
            name,
            run_in_parallel: true,
            check,
        }
    }

    /// Builder-style setter for the flag later reported by [`InputGuardrail::is_parallel`].
    #[must_use]
    pub const fn run_in_parallel(mut self, parallel: bool) -> Self {
        self.run_in_parallel = parallel;
        self
    }

    /// The name supplied at construction.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Whether this guardrail is flagged to run in parallel (`true` unless overridden).
    #[must_use]
    pub const fn is_parallel(&self) -> bool {
        self.run_in_parallel
    }

    /// Runs the wrapped check and tags its verdict with this guardrail's name.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying [`InputGuardrailCheck::check`].
    pub async fn run(
        &self,
        context: &RunContext,
        agent_name: &str,
        input: &[Message],
    ) -> Result<InputGuardrailResult> {
        self.check
            .check(context, agent_name, input)
            .await
            .map(|verdict| InputGuardrailResult {
                guardrail_name: self.name.clone(),
                output: verdict,
            })
    }
}
/// Debug output lists `name` and `run_in_parallel`; the check object is elided as `..`.
impl std::fmt::Debug for InputGuardrail {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("InputGuardrail");
        builder.field("name", &self.name);
        builder.field("run_in_parallel", &self.run_in_parallel);
        builder.finish_non_exhaustive()
    }
}
/// Outcome of one [`InputGuardrail::run`] call: the guardrail's name plus its verdict.
#[derive(Debug, Clone)]
pub struct InputGuardrailResult {
    /// Name of the guardrail that produced `output`.
    pub guardrail_name: String,
    /// The verdict returned by the check.
    pub output: GuardrailOutput,
}
impl InputGuardrailResult {
    /// Returns `true` when the wrapped verdict tripped; delegates to
    /// [`GuardrailOutput::is_triggered`] so the flag is read in one place.
    #[must_use]
    pub const fn is_triggered(&self) -> bool {
        self.output.is_triggered()
    }
}
/// Pluggable check executed by an [`OutputGuardrail`] over an agent's output.
///
/// Implementors must be `Send + Sync`; [`OutputGuardrail`] stores them behind an
/// `Arc<dyn OutputGuardrailCheck>`.
#[async_trait]
pub trait OutputGuardrailCheck: Send + Sync {
    /// Inspects `output` for the agent named `agent_name` and returns a verdict.
    ///
    /// `output` is a JSON `Value`; presumably the agent's final output — confirm against callers.
    ///
    /// # Errors
    ///
    /// Any `Err` returned here is propagated unchanged by [`OutputGuardrail::run`].
    async fn check(
        &self,
        context: &RunContext,
        agent_name: &str,
        output: &Value,
    ) -> Result<GuardrailOutput>;
}
/// A named output guardrail wrapping a shared, type-erased [`OutputGuardrailCheck`].
///
/// Unlike `InputGuardrail`, it carries no parallel flag. Clones share the same
/// `Arc`-held check. `Debug` is implemented by hand (omitting `check`) because
/// the trait object carries no `Debug` bound.
#[derive(Clone)]
pub struct OutputGuardrail {
    /// Identifier copied into every [`OutputGuardrailResult`] this guardrail produces.
    name: String,
    /// The check implementation; held in an `Arc` so clones share it.
    check: Arc<dyn OutputGuardrailCheck>,
}
impl OutputGuardrail {
    /// Creates an output guardrail named `name` that delegates to `check`.
    #[must_use]
    pub fn new(name: impl Into<String>, check: impl OutputGuardrailCheck + 'static) -> Self {
        let name = name.into();
        let check: Arc<dyn OutputGuardrailCheck> = Arc::new(check);
        Self { name, check }
    }

    /// The name supplied at construction.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Runs the wrapped check against `output` and tags its verdict with this guardrail's name.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying [`OutputGuardrailCheck::check`].
    pub async fn run(
        &self,
        context: &RunContext,
        agent_name: &str,
        output: &Value,
    ) -> Result<OutputGuardrailResult> {
        self.check
            .check(context, agent_name, output)
            .await
            .map(|verdict| OutputGuardrailResult {
                guardrail_name: self.name.clone(),
                output: verdict,
            })
    }
}
/// Debug output lists `name`; the check object is elided as `..`.
impl std::fmt::Debug for OutputGuardrail {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("OutputGuardrail");
        builder.field("name", &self.name);
        builder.finish_non_exhaustive()
    }
}
/// Outcome of one [`OutputGuardrail::run`] call: the guardrail's name plus its verdict.
#[derive(Debug, Clone)]
pub struct OutputGuardrailResult {
    /// Name of the guardrail that produced `output`.
    pub guardrail_name: String,
    /// The verdict returned by the check.
    pub output: GuardrailOutput,
}
impl OutputGuardrailResult {
    /// Returns `true` when the wrapped verdict tripped; delegates to
    /// [`GuardrailOutput::is_triggered`] so the flag is read in one place.
    #[must_use]
    pub const fn is_triggered(&self) -> bool {
        self.output.is_triggered()
    }
}