use crate::types::{PromptInput, SanitizedOutput};
use crate::AimdsError;
use async_trait::async_trait;
/// An asynchronous check that inspects a [`PromptInput`] and decides how it
/// should be handled, expressed as a [`SafetyVerdict`].
///
/// The `Send + Sync` supertraits let implementors be shared between threads,
/// for example behind an `Arc<dyn SafetyGate>`.
#[async_trait]
pub trait SafetyGate: Send + Sync {
    /// Inspects `input` and returns the verdict for it.
    ///
    /// # Errors
    ///
    /// Returns an [`AimdsError`] when the gate cannot produce a verdict; the
    /// concrete failure conditions are defined by each implementation.
    async fn inspect(&self, input: &PromptInput) -> Result<SafetyVerdict, AimdsError>;
}
/// The outcome of inspecting a prompt with a [`SafetyGate`].
///
/// Use [`SafetyVerdict::is_forwardable`] / [`SafetyVerdict::is_blocked`] to
/// branch on the outcome without matching every variant.
#[derive(Debug, Clone)]
pub enum SafetyVerdict {
    /// The input passed inspection and may be forwarded.
    Pass,
    /// The input must not be forwarded. The string explains why
    /// (presumably the triggering rule — TODO confirm convention with implementors).
    Block(String),
    /// The input may be forwarded in the sanitized form carried here.
    Redact(SanitizedOutput),
}
impl SafetyVerdict {
    /// Reports whether this verdict lets the input continue.
    ///
    /// This is the exact complement of [`SafetyVerdict::is_blocked`]: it is
    /// `true` for [`SafetyVerdict::Pass`] and [`SafetyVerdict::Redact`].
    pub fn is_forwardable(&self) -> bool {
        !self.is_blocked()
    }

    /// Reports whether this verdict is a [`SafetyVerdict::Block`].
    pub fn is_blocked(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here
        // (and therefore in `is_forwardable`) instead of silently defaulting.
        match self {
            SafetyVerdict::Block(_) => true,
            SafetyVerdict::Pass | SafetyVerdict::Redact(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
    // `super::*` already brings `SanitizedOutput`, `PromptInput`, `AimdsError`
    // and `async_trait` into scope via the parent module's imports.
    use super::*;
    use chrono::Utc;
    use uuid::Uuid;

    /// Builds a minimal `SanitizedOutput` for constructing `Redact` verdicts.
    fn fake_sanitized() -> SanitizedOutput {
        SanitizedOutput {
            original_id: Uuid::nil(),
            timestamp: Utc::now(),
            sanitized_content: String::new(),
            modifications: vec![],
            is_safe: true,
        }
    }

    #[test]
    fn pass_is_forwardable_and_not_blocked() {
        let v = SafetyVerdict::Pass;
        assert!(v.is_forwardable());
        assert!(!v.is_blocked());
    }

    #[test]
    fn redact_is_forwardable_and_not_blocked() {
        let v = SafetyVerdict::Redact(fake_sanitized());
        assert!(v.is_forwardable());
        assert!(!v.is_blocked());
    }

    #[test]
    fn block_is_blocked_and_not_forwardable() {
        let v = SafetyVerdict::Block("test rule".into());
        assert!(!v.is_forwardable());
        assert!(v.is_blocked());
    }

    #[test]
    fn safety_gate_is_object_safe() {
        struct AlwaysPass;

        #[async_trait]
        impl SafetyGate for AlwaysPass {
            async fn inspect(
                &self,
                _input: &PromptInput,
            ) -> Result<SafetyVerdict, AimdsError> {
                Ok(SafetyVerdict::Pass)
            }
        }

        // A generic `T: SafetyGate` bound only proves `AlwaysPass` implements
        // the trait; coercing to a trait object is what actually fails to
        // compile if `SafetyGate` stops being dyn-compatible.
        let gate: Box<dyn SafetyGate> = Box::new(AlwaysPass);

        // The `Send + Sync` supertraits must carry over to the trait object so
        // gates can be shared across threads (e.g. `Arc<dyn SafetyGate>`).
        fn assert_send_sync<T: Send + Sync + ?Sized>(_: &T) {}
        assert_send_sync(&*gate);
    }
}