use std::sync::Arc;
use async_trait::async_trait;
/// A streaming output guardrail provider.
///
/// A provider decides in [`OutputGuardrail::arm`] whether it applies to a given
/// context; when it does, it returns an [`OutputGuardrailRun`] that is then
/// checked via [`evaluate_guardrails`].
pub trait OutputGuardrail: Send + Sync {
    /// Identifier copied into [`ArmedGuardrail::guardrail_id`] by [`arm_guardrails`].
    fn id(&self) -> &str;
    /// Create a run for `ctx`, or return `None` to opt out; providers returning
    /// `None` are skipped by [`arm_guardrails`].
    fn arm(&self, ctx: &OutputGuardrailContext<'_>) -> Option<Box<dyn OutputGuardrailRun>>;
}
/// A run produced by [`OutputGuardrail::arm`] and checked repeatedly by
/// [`evaluate_guardrails`].
pub trait OutputGuardrailRun: Send {
    /// Inspect the output and decide whether it may continue.
    ///
    /// `accumulated` is the output so far and `delta` the latest chunk
    /// (NOTE(review): presumably `accumulated` already includes `delta` — confirm
    /// against the streaming caller). Takes `&mut self` so a run can keep state
    /// between calls.
    fn check(&mut self, accumulated: &str, delta: &str) -> GuardrailDecision;
}
/// Inputs passed to [`OutputGuardrail::arm`] when deciding whether to arm.
pub struct OutputGuardrailContext<'a> {
    /// System prompt supplied by the caller.
    pub system_prompt: &'a str,
    /// Guardrail configuration as raw JSON (NOTE(review): schema appears to be
    /// provider-defined — confirm).
    pub config: &'a serde_json::Value,
}
/// Outcome of a single guardrail check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardrailDecision {
    /// The guardrail has no objection.
    Pass,
    /// The guardrail objects; carries the reason and replacement text.
    Block(GuardrailBlock),
}

/// Details attached to a [`GuardrailDecision::Block`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardrailBlock {
    /// Machine-readable reason for the block (e.g. `"guardrail.moderation"`).
    pub reason_code: String,
    /// Text to surface in place of the blocked output; how it is applied is up
    /// to the caller.
    pub replacement: String,
}

impl GuardrailDecision {
    /// Convenience constructor for [`GuardrailDecision::Block`].
    #[must_use]
    pub fn block(reason_code: impl Into<String>, replacement: impl Into<String>) -> Self {
        GuardrailDecision::Block(GuardrailBlock {
            reason_code: reason_code.into(),
            replacement: replacement.into(),
        })
    }
}
/// An [`OutputGuardrailRun`] tagged with the ids needed to report a block.
pub struct ArmedGuardrail {
    /// Capability id reported in [`TrippedGuardrail::capability_id`] if this run blocks.
    pub capability_id: String,
    /// Guardrail id reported in [`TrippedGuardrail::guardrail_id`] if this run blocks.
    pub guardrail_id: String,
    /// The run that [`evaluate_guardrails`] checks.
    pub run: Box<dyn OutputGuardrailRun>,
}
/// Check every armed guardrail against the latest output and report the first
/// one that blocks.
///
/// Runs are checked in slice order. Evaluation stops at the first
/// [`GuardrailDecision::Block`], so runs after it are not called for this
/// `delta`. Returns `None` when every run passes, including when `runs` is empty.
///
/// Marked `#[must_use]`: discarding a `Some` would silently let blocked output
/// through.
#[must_use]
pub fn evaluate_guardrails(
    runs: &mut [ArmedGuardrail],
    accumulated: &str,
    delta: &str,
) -> Option<TrippedGuardrail> {
    runs.iter_mut()
        .find_map(|armed| match armed.run.check(accumulated, delta) {
            GuardrailDecision::Pass => None,
            GuardrailDecision::Block(block) => Some(TrippedGuardrail {
                capability_id: armed.capability_id.clone(),
                guardrail_id: armed.guardrail_id.clone(),
                block,
            }),
        })
}
/// Report of the guardrail that blocked, returned by [`evaluate_guardrails`] and
/// [`evaluate_post_generation_guardrails`].
#[derive(Debug, Clone)]
pub struct TrippedGuardrail {
    /// Capability id of the guardrail that blocked.
    pub capability_id: String,
    /// Id of the guardrail that blocked.
    pub guardrail_id: String,
    /// The reason code and replacement text the guardrail returned.
    pub block: GuardrailBlock,
}
/// Arm each provider for `ctx`, keeping only those that opt in.
///
/// `providers` pairs a capability id with its guardrail provider. Providers whose
/// [`OutputGuardrail::arm`] returns `None` are skipped. The rest are returned in
/// their original order, tagged with their capability id and
/// [`OutputGuardrail::id`].
#[must_use]
pub fn arm_guardrails(
    providers: &[(String, Arc<dyn OutputGuardrail>)],
    ctx: &OutputGuardrailContext<'_>,
) -> Vec<ArmedGuardrail> {
    providers
        .iter()
        .filter_map(|(cap_id, provider)| {
            // Arm first so the id strings are only allocated for providers that
            // actually produced a run.
            let run = provider.arm(ctx)?;
            Some(ArmedGuardrail {
                capability_id: cap_id.clone(),
                guardrail_id: provider.id().to_string(),
                run,
            })
        })
        .collect()
}
/// A guardrail evaluated against a whole message
/// ([`PostGenerationOutputContext::message_text`]) rather than a stream of deltas.
#[async_trait]
pub trait PostGenerationOutputGuardrail: Send + Sync {
    /// Identifier reported in [`TrippedGuardrail::guardrail_id`] when this guardrail blocks.
    fn id(&self) -> &str;
    /// Decide whether the message described by `ctx` may pass.
    async fn check_message(&self, ctx: &PostGenerationOutputContext<'_>) -> GuardrailDecision;
}
/// Inputs passed to [`PostGenerationOutputGuardrail::check_message`].
pub struct PostGenerationOutputContext<'a> {
    /// System prompt supplied by the caller.
    pub system_prompt: &'a str,
    /// Text of the message being checked.
    pub message_text: &'a str,
    /// Optional utility LLM service; may be `None`, so guardrails must handle
    /// its absence.
    pub utility_llm_service: Option<&'a Arc<dyn crate::UtilityLlmService>>,
}
/// A post-generation guardrail paired with the capability id reported if it blocks.
pub struct PostGenerationProvider {
    /// Copied into [`TrippedGuardrail::capability_id`] when `provider` blocks.
    pub capability_id: String,
    /// The guardrail consulted by [`evaluate_post_generation_guardrails`].
    pub provider: Arc<dyn PostGenerationOutputGuardrail>,
}
/// Check a message against each post-generation guardrail and return the first block.
///
/// Providers are awaited one at a time, in slice order. Evaluation stops at the
/// first [`GuardrailDecision::Block`], and later providers are not called.
/// Returns `None` if every provider passes or `providers` is empty.
pub async fn evaluate_post_generation_guardrails(
    providers: &[PostGenerationProvider],
    ctx: &PostGenerationOutputContext<'_>,
) -> Option<TrippedGuardrail> {
    for entry in providers.iter() {
        let decision = entry.provider.check_message(ctx).await;
        let block = match decision {
            GuardrailDecision::Block(block) => block,
            GuardrailDecision::Pass => continue,
        };
        let guardrail_id = entry.provider.id().to_string();
        return Some(TrippedGuardrail {
            capability_id: entry.capability_id.clone(),
            guardrail_id,
            block,
        });
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Run that blocks on every check.
    struct AlwaysBlock;
    impl OutputGuardrailRun for AlwaysBlock {
        fn check(&mut self, _accumulated: &str, _delta: &str) -> GuardrailDecision {
            GuardrailDecision::block("test_block", "[blocked]")
        }
    }

    /// Run that passes on every check.
    struct NeverBlock;
    impl OutputGuardrailRun for NeverBlock {
        fn check(&mut self, _accumulated: &str, _delta: &str) -> GuardrailDecision {
            GuardrailDecision::Pass
        }
    }

    /// Run that passes but records how many times it was checked.
    struct CountingPass(Arc<AtomicUsize>);
    impl OutputGuardrailRun for CountingPass {
        fn check(&mut self, _accumulated: &str, _delta: &str) -> GuardrailDecision {
            self.0.fetch_add(1, Ordering::SeqCst);
            GuardrailDecision::Pass
        }
    }

    /// Run that blocks only when the latest delta contains the needle.
    struct BlockOnDelta(&'static str);
    impl OutputGuardrailRun for BlockOnDelta {
        fn check(&mut self, _accumulated: &str, delta: &str) -> GuardrailDecision {
            if delta.contains(self.0) {
                GuardrailDecision::block("delta_match", "[filtered]")
            } else {
                GuardrailDecision::Pass
            }
        }
    }

    fn armed(cap: &str, guard: &str, run: Box<dyn OutputGuardrailRun>) -> ArmedGuardrail {
        ArmedGuardrail {
            capability_id: cap.to_string(),
            guardrail_id: guard.to_string(),
            run,
        }
    }

    #[test]
    fn evaluate_returns_first_block_in_order() {
        let mut runs = vec![
            armed("cap_a", "g_a", Box::new(NeverBlock)),
            armed("cap_b", "g_b", Box::new(AlwaysBlock)),
            armed("cap_c", "g_c", Box::new(AlwaysBlock)),
        ];
        let tripped = evaluate_guardrails(&mut runs, "any text", "delta").expect("blocked");
        assert_eq!(tripped.capability_id, "cap_b");
        assert_eq!(tripped.guardrail_id, "g_b");
        assert_eq!(tripped.block.reason_code, "test_block");
        assert_eq!(tripped.block.replacement, "[blocked]");
    }

    #[test]
    fn evaluate_returns_none_when_all_pass() {
        let mut runs = vec![
            armed("cap_a", "g_a", Box::new(NeverBlock)),
            armed("cap_b", "g_b", Box::new(NeverBlock)),
        ];
        assert!(evaluate_guardrails(&mut runs, "txt", "d").is_none());
    }

    #[test]
    fn evaluate_returns_none_for_no_runs() {
        let mut runs: Vec<ArmedGuardrail> = Vec::new();
        assert!(evaluate_guardrails(&mut runs, "txt", "d").is_none());
    }

    #[test]
    fn evaluate_stops_at_first_block() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let mut runs = vec![
            armed("cap_a", "g_a", Box::new(CountingPass(Arc::clone(&before)))),
            armed("cap_b", "g_b", Box::new(AlwaysBlock)),
            armed("cap_c", "g_c", Box::new(CountingPass(Arc::clone(&after)))),
        ];
        let tripped = evaluate_guardrails(&mut runs, "txt", "d").expect("blocked");
        assert_eq!(tripped.guardrail_id, "g_b");
        // Runs ahead of the blocker are consulted; runs after it are not.
        assert_eq!(before.load(Ordering::SeqCst), 1);
        assert_eq!(after.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn evaluate_forwards_delta_to_runs() {
        let mut runs = vec![armed("cap_a", "g_a", Box::new(BlockOnDelta("bad")))];
        assert!(evaluate_guardrails(&mut runs, "all good", "good").is_none());
        let tripped = evaluate_guardrails(&mut runs, "all good bad", "bad").expect("blocked");
        assert_eq!(tripped.capability_id, "cap_a");
        assert_eq!(tripped.block.reason_code, "delta_match");
        assert_eq!(tripped.block.replacement, "[filtered]");
    }

    /// Provider that arms a `NeverBlock` run only when `arms` is true.
    struct OptIn {
        id: &'static str,
        arms: bool,
    }
    impl OutputGuardrail for OptIn {
        fn id(&self) -> &str {
            self.id
        }
        fn arm(&self, _ctx: &OutputGuardrailContext<'_>) -> Option<Box<dyn OutputGuardrailRun>> {
            if self.arms {
                let run: Box<dyn OutputGuardrailRun> = Box::new(NeverBlock);
                Some(run)
            } else {
                None
            }
        }
    }

    fn provider(cap: &str, id: &'static str, arms: bool) -> (String, Arc<dyn OutputGuardrail>) {
        let guardrail: Arc<dyn OutputGuardrail> = Arc::new(OptIn { id, arms });
        (cap.to_string(), guardrail)
    }

    #[test]
    fn arm_guardrails_skips_providers_that_decline() {
        let config = serde_json::Value::Null;
        let ctx = OutputGuardrailContext {
            system_prompt: "",
            config: &config,
        };
        let providers = vec![
            provider("cap_a", "g_a", true),
            provider("cap_b", "g_b", false),
            provider("cap_c", "g_c", true),
        ];
        let armed_runs = arm_guardrails(&providers, &ctx);
        let ids: Vec<(&str, &str)> = armed_runs
            .iter()
            .map(|a| (a.capability_id.as_str(), a.guardrail_id.as_str()))
            .collect();
        assert_eq!(ids, vec![("cap_a", "g_a"), ("cap_c", "g_c")]);
    }

    struct PostPass;
    #[async_trait]
    impl PostGenerationOutputGuardrail for PostPass {
        fn id(&self) -> &str {
            "pass"
        }
        async fn check_message(&self, _ctx: &PostGenerationOutputContext<'_>) -> GuardrailDecision {
            GuardrailDecision::Pass
        }
    }

    struct PostBlock;
    #[async_trait]
    impl PostGenerationOutputGuardrail for PostBlock {
        fn id(&self) -> &str {
            "block"
        }
        async fn check_message(&self, _ctx: &PostGenerationOutputContext<'_>) -> GuardrailDecision {
            GuardrailDecision::block("guardrail.moderation", "[removed]")
        }
    }

    fn post_ctx<'a>(text: &'a str) -> PostGenerationOutputContext<'a> {
        PostGenerationOutputContext {
            system_prompt: "",
            message_text: text,
            utility_llm_service: None,
        }
    }

    #[tokio::test]
    async fn post_generation_returns_first_block_in_order() {
        let providers = vec![
            PostGenerationProvider {
                capability_id: "cap_a".to_string(),
                provider: Arc::new(PostPass),
            },
            PostGenerationProvider {
                capability_id: "cap_b".to_string(),
                provider: Arc::new(PostBlock),
            },
        ];
        let ctx = post_ctx("hello");
        let tripped = evaluate_post_generation_guardrails(&providers, &ctx)
            .await
            .expect("blocked");
        assert_eq!(tripped.capability_id, "cap_b");
        assert_eq!(tripped.guardrail_id, "block");
        assert_eq!(tripped.block.reason_code, "guardrail.moderation");
        assert_eq!(tripped.block.replacement, "[removed]");
    }

    #[tokio::test]
    async fn post_generation_returns_none_when_all_pass() {
        let providers = vec![PostGenerationProvider {
            capability_id: "cap_a".to_string(),
            provider: Arc::new(PostPass),
        }];
        let ctx = post_ctx("hello");
        assert!(
            evaluate_post_generation_guardrails(&providers, &ctx)
                .await
                .is_none()
        );
    }
}