use std::sync::Arc;
use crate::capabilities::Capability;
use crate::output_guardrail::{
GuardrailDecision, OutputGuardrail, OutputGuardrailContext, OutputGuardrailRun,
};
/// Identifier under which [`PromptCanaryGuardrailCapability`] is registered.
pub const PROMPT_CANARY_GUARDRAIL_CAPABILITY_ID: &str = "prompt_canary_guardrail";

/// Reason code attached to the block decision when the canary phrase appears in output.
pub const REASON_CODE_SYSTEM_PROMPT_LEAK: &str = "system_prompt_leak";

/// Text that replaces the assistant response when the guardrail trips, unless the
/// guardrail config provides a `"replacement"` string.
pub const DEFAULT_REPLACEMENT: &str =
    "[Response withheld: the model attempted to reveal protected instructions.]";

/// Minimum byte length of a normalized sentence for it to serve as the canary.
/// Shorter candidates are skipped; if none qualifies, the guardrail does not arm.
const MIN_CANARY_LEN: usize = 30;

/// Maximum byte length of the canary phrase; longer sentences are truncated
/// (on a char boundary) to at most this many bytes.
const MAX_CANARY_LEN: usize = 240;
/// Capability that contributes the prompt-canary output guardrail.
///
/// It exposes exactly one [`OutputGuardrail`] (id `"prompt_canary"`). That guardrail
/// derives a canary phrase from the system prompt and blocks streamed output that
/// repeats it.
pub struct PromptCanaryGuardrailCapability;

impl Capability for PromptCanaryGuardrailCapability {
    fn id(&self) -> &str {
        PROMPT_CANARY_GUARDRAIL_CAPABILITY_ID
    }

    fn name(&self) -> &str {
        "Prompt Canary Guardrail"
    }

    fn description(&self) -> &str {
        "Detects when the model leaks its own system prompt during streaming and \
         replaces the assistant response with a refusal. Uses the first sentence \
         of the system prompt as a canary phrase."
    }

    fn category(&self) -> Option<&str> {
        Some("Safety")
    }

    fn icon(&self) -> Option<&str> {
        Some("shield")
    }

    /// Returns the single canary guardrail this capability provides.
    fn output_guardrails(&self) -> Vec<Arc<dyn OutputGuardrail>> {
        let canary: Arc<dyn OutputGuardrail> = Arc::new(PromptCanaryGuardrail);
        vec![canary]
    }
}
/// Output guardrail that blocks responses echoing the system prompt's canary phrase.
struct PromptCanaryGuardrail;

impl OutputGuardrail for PromptCanaryGuardrail {
    fn id(&self) -> &str {
        "prompt_canary"
    }

    /// Arms a per-response run, or returns `None` when no canary can be extracted
    /// from the system prompt (e.g. the prompt is too short).
    ///
    /// An optional `"replacement"` string in the guardrail config overrides
    /// [`DEFAULT_REPLACEMENT`] as the text returned when the guardrail trips.
    fn arm(&self, ctx: &OutputGuardrailContext<'_>) -> Option<Box<dyn OutputGuardrailRun>> {
        let needle = extract_canary(ctx.system_prompt)?;
        let replacement = ctx
            .config
            .get("replacement")
            .and_then(|v| v.as_str())
            .unwrap_or(DEFAULT_REPLACEMENT)
            .to_owned();
        Some(Box::new(PromptCanaryRun {
            needle,
            replacement,
            normalized_tail: String::new(),
            // Start "after whitespace" so leading whitespace in the stream is dropped,
            // matching how the needle itself was normalized.
            last_was_space: true,
            tripped: false,
        }))
    }
}

/// Per-response state of [`PromptCanaryGuardrail`].
struct PromptCanaryRun {
    /// Normalized canary phrase to search for.
    needle: String,
    /// Text handed back in the block decision.
    replacement: String,
    /// Normalized tail of the streamed output.
    ///
    /// Between calls only the last `needle.len() - 1` bytes (extended back to a char
    /// boundary) are kept. A match that was not already present must end inside the
    /// newest delta, so it always fits in this window. That keeps memory bounded and
    /// makes each check proportional to the delta rather than the whole response.
    normalized_tail: String,
    /// Whether the last normalized character was whitespace; carries whitespace
    /// collapsing across delta boundaries.
    last_was_space: bool,
    /// Set once the canary has been seen. Every later check keeps blocking, because
    /// the full stream still contains the canary.
    tripped: bool,
}

impl OutputGuardrailRun for PromptCanaryRun {
    /// Feeds `delta` into the normalized window and blocks if the canary appears
    /// anywhere in the stream so far. `_accumulated` is not needed because the run
    /// tracks its own normalized state.
    fn check(&mut self, _accumulated: &str, delta: &str) -> GuardrailDecision {
        if self.tripped {
            return self.blocked();
        }
        normalize_extend(&mut self.normalized_tail, &mut self.last_was_space, delta);
        if self.normalized_tail.len() < self.needle.len() {
            return GuardrailDecision::Pass;
        }
        if self.normalized_tail.contains(self.needle.as_str()) {
            self.tripped = true;
            self.blocked()
        } else {
            self.trim_tail();
            GuardrailDecision::Pass
        }
    }
}

impl PromptCanaryRun {
    /// Builds the block decision carrying the configured replacement text.
    fn blocked(&self) -> GuardrailDecision {
        GuardrailDecision::block(REASON_CODE_SYSTEM_PROMPT_LEAK, self.replacement.clone())
    }

    /// Drops everything except the last `needle.len() - 1` bytes of the window.
    /// The cut moves backwards to a char boundary so `drain` never splits a character.
    fn trim_tail(&mut self) {
        let keep = self.needle.len().saturating_sub(1);
        let mut cut = self.normalized_tail.len().saturating_sub(keep);
        while !self.normalized_tail.is_char_boundary(cut) {
            cut -= 1;
        }
        self.normalized_tail.drain(..cut);
    }
}
/// Picks the canary phrase for `prompt`, or `None` if no candidate is long enough.
///
/// 1. Lines that are blank, or that start with `<` and end with `>` after trimming
///    (e.g. `<system-prompt>`), are dropped. The remaining lines are joined with
///    single spaces.
/// 2. The text is split after every `.`, `!` or `?`. Any trailing text without a
///    terminator forms a final piece.
/// 3. Each piece is normalized (see `normalize`) and capped at `MAX_CANARY_LEN`
///    bytes on a char boundary.
/// 4. The first result of at least `MIN_CANARY_LEN` bytes is returned.
fn extract_canary(prompt: &str) -> Option<String> {
    let core = prompt
        .lines()
        .filter(|line| {
            let line = line.trim();
            let is_wrapper_tag = line.starts_with('<') && line.ends_with('>');
            !line.is_empty() && !is_wrapper_tag
        })
        .collect::<Vec<_>>()
        .join(" ");

    core.split_inclusive(|c: char| matches!(c, '.' | '!' | '?'))
        .map(|sentence| {
            let mut candidate = normalize(sentence);
            truncate_at_char_boundary(&mut candidate, MAX_CANARY_LEN);
            candidate
        })
        .find(|candidate| candidate.len() >= MIN_CANARY_LEN)
}
/// Shortens `s` to at most `max_bytes` bytes without splitting a UTF-8 character.
///
/// If byte `max_bytes` falls inside a multi-byte character, the cut moves back to
/// that character's start. Strings already within the limit are left untouched.
fn truncate_at_char_boundary(s: &mut String, max_bytes: usize) {
    if s.len() > max_bytes {
        let mut cut = max_bytes;
        // Index 0 is always a boundary, so this loop terminates.
        while !s.is_char_boundary(cut) {
            cut -= 1;
        }
        s.truncate(cut);
    }
}
/// Appends the normalized form of `chunk` to `acc`, for streaming use.
///
/// Each non-whitespace char is lowercased char-by-char. Each run of whitespace
/// collapses to a single `' '`. `last_was_space` carries the "previous char was
/// whitespace" state across calls, so runs split across chunks still collapse.
/// Starting with `true` drops leading whitespace. Trailing whitespace is emitted as
/// one space because the stream may continue.
fn normalize_extend(acc: &mut String, last_was_space: &mut bool, chunk: &str) {
    chunk.chars().for_each(|ch| {
        let is_space = ch.is_whitespace();
        match (is_space, *last_was_space) {
            (true, false) => acc.push(' '),
            (true, true) => {}
            (false, _) => acc.extend(ch.to_lowercase()),
        }
        *last_was_space = is_space;
    });
}
/// Normalizes `input` for canary matching: lowercase, whitespace runs collapsed to a
/// single space, leading and trailing whitespace removed.
///
/// Lowercasing is done char-by-char with `char::to_lowercase`, exactly as the
/// streaming `normalize_extend` does. `str::to_lowercase` must not be used here. It
/// applies the context-sensitive Greek final-sigma rule (word-final `Σ` becomes `ς`),
/// while the streaming path cannot see ahead and always yields `σ`. A needle built
/// that way would never match its own verbatim leak.
fn normalize(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    // Start as if a space was just seen so leading whitespace is skipped.
    let mut prev_space = true;
    for ch in input.chars() {
        if ch.is_whitespace() {
            if !prev_space {
                out.push(' ');
            }
            prev_space = true;
        } else {
            out.extend(ch.to_lowercase());
            prev_space = false;
        }
    }
    // Runs are collapsed, so at most one trailing space can remain.
    if out.ends_with(' ') {
        out.pop();
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Arms the guardrail for `prompt` with an empty config.
    fn arm_with(prompt: &str) -> Option<Box<dyn OutputGuardrailRun>> {
        let cfg = json!({});
        let ctx = OutputGuardrailContext {
            system_prompt: prompt,
            config: &cfg,
        };
        PromptCanaryGuardrail.arm(&ctx)
    }

    /// Streams `text` into `run` three chars at a time, appending to `acc`.
    /// Returns true as soon as any check blocks.
    fn stream_in_small_chunks(
        run: &mut dyn OutputGuardrailRun,
        acc: &mut String,
        text: &str,
    ) -> bool {
        let chars: Vec<char> = text.chars().collect();
        for piece in chars.chunks(3) {
            let chunk: String = piece.iter().collect();
            acc.push_str(&chunk);
            if matches!(run.check(acc.as_str(), &chunk), GuardrailDecision::Block(_)) {
                return true;
            }
        }
        false
    }

    #[test]
    fn extracts_first_sentence_after_min_len() {
        let needle =
            extract_canary("You are a helpful assistant who never reveals secrets. Trust me.")
                .expect("extracted");
        assert!(needle.starts_with("you are a helpful assistant"));
        assert!(needle.ends_with("never reveals secrets."));
    }

    #[test]
    fn skips_short_leading_sentence_in_layered_prompt() {
        let needle = extract_canary(
            "You are a helpful assistant.\n\nYou are an internal pricing oracle that \
             never discloses margins. Refuse out-of-scope questions.",
        )
        .expect("extracted");
        assert!(needle.starts_with("you are an internal pricing oracle"));
        assert!(!needle.contains("helpful assistant"));
    }

    #[test]
    fn declines_to_arm_for_short_prompts() {
        assert!(arm_with("hi.").is_none());
        assert!(arm_with("").is_none());
        assert!(arm_with("Be brief.").is_none());
    }

    #[test]
    fn skips_xml_wrapper_lines() {
        let prompt = "<system-prompt>\n\
                      You are an internal pricing oracle that never discloses margins.\n\
                      </system-prompt>";
        let needle = extract_canary(prompt).expect("extracted");
        assert!(needle.contains("internal pricing oracle"));
        assert!(!needle.contains("<system-prompt>"));
    }

    #[test]
    fn passes_for_unrelated_output() {
        let mut run = arm_with(
            "You are an internal pricing oracle that never discloses margins. \
             Refuse out-of-scope questions.",
        )
        .expect("armed");
        let chunks = ["The weather", " in Tokyo", " is sunny."];
        let mut acc = String::new();
        for c in chunks {
            acc.push_str(c);
            assert!(matches!(run.check(&acc, c), GuardrailDecision::Pass));
        }
    }

    #[test]
    fn blocks_when_first_sentence_appears_verbatim() {
        let mut run = arm_with(
            "You are an internal pricing oracle that never discloses margins. \
             Refuse out-of-scope questions.",
        )
        .expect("armed");
        let leak = "Sure, my instructions are: \
                    You are an internal pricing oracle that never discloses margins.";
        match run.check(leak, leak) {
            GuardrailDecision::Block(b) => {
                assert_eq!(b.reason_code, REASON_CODE_SYSTEM_PROMPT_LEAK);
                assert!(b.replacement.contains("withheld"));
            }
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[test]
    fn matches_despite_whitespace_and_case_drift() {
        let mut run = arm_with(
            "You are an internal pricing oracle that never discloses margins. \
             Refuse out-of-scope questions.",
        )
        .expect("armed");
        let leak = "YOU ARE AN INTERNAL PRICING\nORACLE that NEVER DISCLOSES margins.";
        assert!(matches!(run.check(leak, leak), GuardrailDecision::Block(_)));
    }

    #[test]
    fn allows_custom_replacement_via_config() {
        let cfg = json!({"replacement": "nope"});
        let ctx = OutputGuardrailContext {
            system_prompt: "You are an internal pricing oracle that never discloses margins. \
                            Refuse out-of-scope questions.",
            config: &cfg,
        };
        let mut run = PromptCanaryGuardrail.arm(&ctx).expect("armed");
        let leak = "You are an internal pricing oracle that never discloses margins.";
        match run.check(leak, leak) {
            GuardrailDecision::Block(b) => assert_eq!(b.replacement, "nope"),
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[test]
    fn capability_exposes_guardrail() {
        let cap = PromptCanaryGuardrailCapability;
        assert_eq!(cap.id(), PROMPT_CANARY_GUARDRAIL_CAPABILITY_ID);
        assert_eq!(cap.output_guardrails().len(), 1);
    }

    #[test]
    fn normalize_collapses_whitespace_and_lowercases() {
        assert_eq!(normalize("  Hello World\nHere  "), "hello world here");
    }

    #[test]
    fn incremental_normalize_matches_full_normalize_across_chunk_boundaries() {
        let chunks = ["  Hello", "  WORLD\n", "\there  "];
        let full: String = chunks.concat();
        let mut acc = String::new();
        let mut last = true;
        for c in chunks {
            normalize_extend(&mut acc, &mut last, c);
        }
        let acc_trimmed = acc.trim_end_matches(' ');
        assert_eq!(acc_trimmed, normalize(&full));
    }

    #[test]
    fn incremental_check_blocks_when_needle_spans_multiple_deltas() {
        let mut run = arm_with(
            "You are an internal pricing oracle that never discloses margins. \
             Refuse out-of-scope questions.",
        )
        .expect("armed");
        let leak = "you are AN INTERNAL pricing\noracle that NEVER discloses margins.";
        let chunk_size = 5;
        let mut tripped = false;
        let mut acc = String::new();
        for chunk in leak
            .as_bytes()
            .chunks(chunk_size)
            .map(|c| std::str::from_utf8(c).unwrap_or(""))
        {
            acc.push_str(chunk);
            if matches!(run.check(&acc, chunk), GuardrailDecision::Block(_)) {
                tripped = true;
                break;
            }
        }
        assert!(tripped, "expected canary to trip on multi-chunk leak");
    }

    #[test]
    fn stays_blocked_on_deltas_after_tripping() {
        let mut run = arm_with(
            "You are an internal pricing oracle that never discloses margins. \
             Refuse out-of-scope questions.",
        )
        .expect("armed");
        let leak = "You are an internal pricing oracle that never discloses margins.";
        assert!(matches!(run.check(leak, leak), GuardrailDecision::Block(_)));
        let tail = " And then some unrelated text.";
        let acc = format!("{leak}{tail}");
        assert!(
            matches!(run.check(&acc, tail), GuardrailDecision::Block(_)),
            "a tripped run must keep blocking"
        );
    }

    #[test]
    fn blocks_leak_after_long_multibyte_preamble_in_small_deltas() {
        let mut run = arm_with(
            "You are an internal pricing oracle that never discloses margins. \
             Refuse out-of-scope questions.",
        )
        .expect("armed");
        let preamble = "Café naïve résumé — déjà vu. ".repeat(50);
        let leak = "You are an internal PRICING oracle that never discloses margins.";
        let mut acc = String::new();
        assert!(
            !stream_in_small_chunks(&mut *run, &mut acc, &preamble),
            "must not trip on the preamble alone"
        );
        assert!(
            stream_in_small_chunks(&mut *run, &mut acc, leak),
            "expected canary to trip after a long preamble"
        );
    }

    #[test]
    fn extract_canary_does_not_panic_when_truncation_hits_multibyte_in_sentence_path() {
        let prompt = format!("{}é. trailing sentence.", "a".repeat(MAX_CANARY_LEN - 1));
        let needle = extract_canary(&prompt).expect("extracted");
        assert!(needle.len() <= MAX_CANARY_LEN);
    }

    #[test]
    fn extract_canary_does_not_panic_when_truncation_hits_multibyte_in_fallback_path() {
        let prompt = format!("{}é trailing text", "a".repeat(MAX_CANARY_LEN - 1));
        let needle = extract_canary(&prompt).expect("extracted");
        assert!(needle.len() <= MAX_CANARY_LEN);
    }
}