use crate::error::Error;
use crate::guardrail::{Guardrail, GuardrailOutcome};
use crate::llm::{ChatRequest, ChatResponse};
use std::sync::Arc;
use tracing::debug;
pub(crate) async fn check_pre_llm(
guardrails: &[Arc<dyn Guardrail>],
req: &ChatRequest,
) -> Result<(), Error> {
for g in guardrails {
match g.pre_llm(req).await {
GuardrailOutcome::Pass => {}
GuardrailOutcome::Refuse { reason } => {
debug!(guardrail = g.name(), %reason, "pre-LLM guardrail refused");
return Err(Error::Refused { reason });
}
GuardrailOutcome::Handoff { agent, reason } => {
debug!(guardrail = g.name(), %agent, %reason, "pre-LLM guardrail handoff");
return Err(Error::Handoff { agent, reason });
}
}
}
Ok(())
}
pub(crate) async fn check_post_llm(
guardrails: &[Arc<dyn Guardrail>],
req: &ChatRequest,
resp: &ChatResponse,
) -> Result<(), Error> {
for g in guardrails {
match g.post_llm(req, resp).await {
GuardrailOutcome::Pass => {}
GuardrailOutcome::Refuse { reason } => {
debug!(guardrail = g.name(), %reason, "post-LLM guardrail refused");
return Err(Error::Refused { reason });
}
GuardrailOutcome::Handoff { agent, reason } => {
debug!(guardrail = g.name(), %agent, %reason, "post-LLM guardrail handoff");
return Err(Error::Handoff { agent, reason });
}
}
}
Ok(())
}