use crate::error::Result;
use crate::guardrails::traits::{GuardrailResult, InputGuardrail, OutputGuardrail};
use crate::model::SharedModel;
use crate::model::types::{ChatRequest, ChatResponse, Message};
/// Guardrail that delegates content-policy decisions to an LLM classifier.
///
/// The same instance can screen both user input ([`InputGuardrail`]) and
/// model output ([`OutputGuardrail`]); both paths run the identical
/// PASS / BLOCK classification prompt.
pub struct ContentPolicyGuardrail {
    // Model used as the safety classifier.
    model: SharedModel,
    // Natural-language policy text interpolated verbatim into the prompt.
    policy: String,
}
impl ContentPolicyGuardrail {
    /// Creates a guardrail that evaluates content against `policy` using
    /// `model` as the classifier.
    pub fn new(model: SharedModel, policy: impl Into<String>) -> Self {
        let policy = policy.into();
        Self { model, policy }
    }

    /// Asks the model whether `content` violates the configured policy.
    ///
    /// The classifier is instructed to answer exactly `PASS` or
    /// `BLOCK: <reason>`. Anything unrecognizable is treated as a block
    /// (fail closed) with the raw verdict included in the reason.
    ///
    /// # Errors
    /// Propagates any error from the underlying model call.
    async fn evaluate(&self, content: &str) -> Result<GuardrailResult> {
        let prompt = format!(
            "You are a content safety classifier. Evaluate whether the following content \
            violates the content policy.\n\n\
            Policy: {}\n\n\
            Content to evaluate:\n{}\n\n\
            Respond with EXACTLY one of:\n\
            PASS\n\
            BLOCK: <brief reason>\n\n\
            Your response:",
            self.policy, content
        );
        let request = ChatRequest::new(vec![Message::user(&prompt)]);
        let response = self.model.generate_erased(&request).await?;

        // Case-insensitive verdict matching: compare against the uppercased
        // text but slice the original so the reason keeps its casing.
        let verdict = response.text().trim().to_string();
        let normalized = verdict.to_uppercase();

        if normalized.starts_with("PASS") {
            return Ok(GuardrailResult::Pass);
        }
        if !normalized.starts_with("BLOCK") {
            // Fail closed: an unparseable classifier verdict blocks rather
            // than silently passing.
            return Ok(GuardrailResult::Block(format!(
                "content policy classifier returned ambiguous result: {}",
                verdict
            )));
        }
        // Everything after the 5-byte "BLOCK" token, minus the ':' separator,
        // is the human-readable reason; fall back to a generic message when
        // the model gave none. (`get` is safe: the prefix is ASCII here.)
        let reason = verdict
            .get(5..)
            .map(|rest| rest.trim_start_matches(':').trim())
            .filter(|rest| !rest.is_empty())
            .unwrap_or("content policy violation");
        Ok(GuardrailResult::Block(reason.to_string()))
    }
}
impl InputGuardrail for ContentPolicyGuardrail {
    /// Screens raw user input with the LLM policy classifier; prior
    /// conversation messages are not consulted.
    async fn check(&self, input: &str, _messages: &[Message]) -> Result<GuardrailResult> {
        let verdict = self.evaluate(input).await?;
        Ok(verdict)
    }
}
impl OutputGuardrail for ContentPolicyGuardrail {
    /// Screens a generated model response with the LLM policy classifier.
    async fn check(&self, response: &ChatResponse) -> Result<GuardrailResult> {
        let text = response.text();
        self.evaluate(text).await
    }
}