use unicode_normalization::UnicodeNormalization;
/// Text produced by [`normalize`], paired with a byte-level map back into the
/// original input.
#[derive(Debug, Clone)]
pub struct Normalized {
    /// Normalized text: stripped characters removed, Cyrillic lookalikes
    /// folded, NFKC-normalized and lowercased, followed by any surfaced base64
    /// payloads (each preceded by `'\n'`).
    pub text: String,
    /// `offsets[i]` is the byte offset in the original input of the character
    /// that produced byte `i` of `text`; `offsets.len() == text.len()`.
    pub offsets: Vec<usize>,
    /// Byte length of the original input.
    pub orig_len: usize,
}

impl Normalized {
    /// Maps the byte range `start..end` of `text` back to a byte range of the
    /// original input, returned as `(start, end)` with `start <= end`.
    ///
    /// The returned range covers every original character that contributed to
    /// the selected bytes, even when the selection stops partway through one
    /// character's expansion (e.g. half of an NFKC-expanded ligature).
    /// Stripped characters between the last selected character and the next
    /// kept one are included. Arguments may be given in either order, and
    /// indices past the end of `text` are clamped. An empty selection maps to
    /// an empty range at the corresponding original position.
    ///
    /// All bytes of a surfaced base64 payload map to the start of their base64
    /// run. A span inside a payload therefore starts at that run and extends to
    /// the next mapped position (or the end of the input).
    pub fn map_span(&self, start: usize, end: usize) -> (usize, usize) {
        let (start, end) = (start.min(end), start.max(end).min(self.offsets.len()));
        let Some(&o_start) = self.offsets.get(start) else {
            return (self.orig_len, self.orig_len);
        };
        if end <= start {
            return (o_start, o_start);
        }
        // The span ends after the original character that produced its last
        // byte. The next kept character starts at the next strictly larger
        // offset. Stop scanning when offsets drop: that is the jump from the
        // main text into a surfaced payload, whose offsets point back at its
        // base64 run. Reading `offsets[end]` directly would pull the span back
        // to that run.
        let last = self.offsets[end - 1];
        let o_end = self.offsets[end..]
            .iter()
            .copied()
            .take_while(|&o| o >= last)
            .find(|&o| o > last)
            .unwrap_or(self.orig_len);
        (o_start.min(o_end), o_start.max(o_end))
    }
}
/// Returns `true` for characters that `normalize` drops entirely.
///
/// This covers invisible characters that can be spliced into a word to break
/// phrase matching without changing how the text renders:
/// - zero-width spaces/joiners and directional marks;
/// - bidi embeddings, overrides and isolates;
/// - the word joiner and invisible math operators;
/// - the BOM / zero-width no-break space and the soft hyphen;
/// - the combining grapheme joiner, the Mongolian vowel separator and
///   variation selectors.
///
/// Every control character except `\n`, `\t` and `\r` is also stripped.
fn is_stripped(c: char) -> bool {
    let invisible = matches!(
        c,
        '\u{00AD}'                        // soft hyphen
            | '\u{034F}'                  // combining grapheme joiner
            | '\u{180E}'                  // Mongolian vowel separator
            | '\u{200B}'..='\u{200F}'     // ZWSP, ZWNJ, ZWJ, LRM, RLM
            | '\u{202A}'..='\u{202E}'     // bidi embeddings and overrides
            | '\u{2060}'..='\u{2064}'     // word joiner, invisible operators
            | '\u{2066}'..='\u{2069}'     // bidi isolates
            | '\u{FE00}'..='\u{FE0F}'     // variation selectors
            | '\u{FEFF}'                  // BOM / zero-width no-break space
    );
    // `char::is_control` covers only general category Cc, hence the list above.
    invisible || (c.is_control() && !matches!(c, '\n' | '\t' | '\r'))
}
/// Maps a Cyrillic letter that renders like a Latin letter to that Latin
/// letter, preserving case. Every other character is returned unchanged.
///
/// Uppercase forms must be folded here too. `normalize` folds *before*
/// lowercasing, so an uppercase lookalike such as `І` (U+0406) would otherwise
/// lowercase to the Cyrillic `і` (U+0456) and slip past the fold.
fn fold_homoglyph(c: char) -> char {
    match c {
        // Lowercase lookalikes.
        '\u{0430}' => 'a',
        '\u{0435}' => 'e',
        '\u{043E}' => 'o',
        '\u{0440}' => 'p',
        '\u{0441}' => 'c',
        '\u{0445}' => 'x',
        '\u{0455}' => 's',
        '\u{0456}' => 'i',
        '\u{0458}' => 'j',
        // Uppercase counterparts of the above.
        '\u{0410}' => 'A',
        '\u{0415}' => 'E',
        '\u{041E}' => 'O',
        '\u{0420}' => 'P',
        '\u{0421}' => 'C',
        '\u{0425}' => 'X',
        '\u{0405}' => 'S',
        '\u{0406}' => 'I',
        '\u{0408}' => 'J',
        // Uppercase-only lookalikes (their lowercase forms look different).
        '\u{0412}' => 'B',
        '\u{041A}' => 'K',
        '\u{041C}' => 'M',
        '\u{041D}' => 'H',
        '\u{0422}' => 'T',
        _ => c,
    }
}
/// Normalizes `input` for case- and obfuscation-insensitive phrase matching.
///
/// Each character of `input` goes through these steps, in order:
/// 1. characters for which `is_stripped` is true are dropped;
/// 2. Cyrillic lookalikes are folded to Latin via `fold_homoglyph`;
/// 3. the result is NFKC-normalized (e.g. fullwidth `Ｉ` becomes `I`);
/// 4. every resulting character is lowercased.
///
/// Each character is NFKC-normalized on its own so that every output byte can
/// be attributed to exactly one input character. As a consequence, sequences
/// that NFKC would only compose across characters are left uncomposed. An
/// example is a base letter followed by a separate combining mark.
///
/// Long base64 runs are then decoded and appended by `surface_base64`. In the
/// returned [`Normalized`], `offsets[i]` is the byte offset in `input` of the
/// character that produced byte `i` of `text`.
pub fn normalize(input: &str) -> Normalized {
    let mut text = String::with_capacity(input.len());
    let mut offsets: Vec<usize> = Vec::with_capacity(input.len());
    for (byte_idx, ch) in input.char_indices() {
        if is_stripped(ch) {
            continue;
        }
        // `iter::once` feeds the single char to NFKC without allocating a
        // `String` per input character.
        for nch in std::iter::once(fold_homoglyph(ch)).nfkc() {
            text.extend(nch.to_lowercase());
        }
        // Attribute every byte just appended to this input character. This
        // keeps `offsets.len() == text.len()` by construction.
        offsets.resize(text.len(), byte_idx);
    }
    surface_base64(input, &mut text, &mut offsets);
    Normalized {
        text,
        offsets,
        orig_len: input.len(),
    }
}
fn surface_base64(input: &str, text: &mut String, offsets: &mut Vec<usize>) {
use base64::Engine as _;
let bytes = input.as_bytes();
let is_b64 = |b: u8| b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'=';
let mut i = 0;
while i < bytes.len() {
if !is_b64(bytes[i]) {
i += 1;
continue;
}
let start = i;
while i < bytes.len() && is_b64(bytes[i]) {
i += 1;
}
let run = &input[start..i];
if run.len() < 24 {
continue;
}
if let Ok(decoded) = base64::engine::general_purpose::STANDARD
.decode(run.trim_end_matches('='))
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(run))
&& let Ok(s) = String::from_utf8(decoded)
&& s.chars()
.filter(|c| c.is_ascii_graphic() || *c == ' ')
.count()
* 2
>= s.len()
{
text.push('\n');
offsets.push(start);
let lowered = s.to_lowercase();
text.push_str(&lowered);
for _ in 0..lowered.len() {
offsets.push(start);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;

    // A zero-width space splitting a word is removed. The matched span maps
    // back to an original range covering the split word.
    #[test]
    fn strips_zero_width_and_maps_back() {
        let original = "ig\u{200B}nore previous";
        let n = normalize(original);
        assert!(n.text.contains("ignore previous"));
        let pos = n.text.find("ignore previous").unwrap();
        let (s, e) = n.map_span(pos, pos + "ignore previous".len());
        let recovered = &original[s..e];
        assert!(recovered.starts_with("ig"));
        assert!(recovered.contains("nore previous"));
    }

    // Lowercase Cyrillic lookalikes fold to their Latin counterparts.
    #[test]
    fn folds_cyrillic_homoglyphs() {
        let original = "\u{0456}gn\u{043E}re";
        let n = normalize(original);
        assert!(n.text.contains("ignore"), "got: {:?}", n.text);
    }

    // NFKC maps fullwidth Latin letters to ASCII. The input is spelled with
    // escapes (U+FF29 U+FF27 U+FF2E U+FF2F U+FF32 U+FF25, fullwidth "IGNORE")
    // so it cannot silently degrade into plain ASCII. That would let the test
    // pass without exercising NFKC at all.
    #[test]
    fn nfkc_normalizes_fullwidth() {
        let original = "\u{FF29}\u{FF27}\u{FF2E}\u{FF2F}\u{FF32}\u{FF25}";
        assert!(!original.is_ascii());
        let n = normalize(original);
        assert!(n.text.contains("ignore"), "got: {:?}", n.text);
        assert_eq!(n.offsets.len(), n.text.len());
        assert_eq!(n.map_span(0, n.text.len()), (0, original.len()));
    }

    #[test]
    fn lowercases_for_case_insensitive_match() {
        let n = normalize("IGNORE Previous");
        assert!(n.text.contains("ignore previous"));
    }

    // A long base64 run is decoded and appended. Its bytes map back to the
    // start of the run in the original input.
    #[test]
    fn surfaces_base64_block() {
        let b64 =
            base64::engine::general_purpose::STANDARD.encode("ignore all previous instructions");
        let original = format!("prefix {b64} suffix");
        let n = normalize(&original);
        assert!(n.text.contains("ignore all previous instructions"));
        assert_eq!(n.offsets.len(), n.text.len());
        let pos = n.text.find("ignore all previous").unwrap();
        let (s, _e) = n.map_span(pos, pos + 5);
        assert_eq!(s, "prefix ".len());
        assert!(original[s..].starts_with(&b64));
    }

    #[test]
    fn offsets_len_matches_text_len() {
        let n = normalize("hello world");
        assert_eq!(n.offsets.len(), n.text.len());
    }

    // Multi-byte characters before the match must not desync the offset map.
    #[test]
    fn offset_invariant_holds_with_non_ascii() {
        let original = "héllo ignore previous";
        let n = normalize(original);
        assert_eq!(
            n.offsets.len(),
            n.text.len(),
            "offset map desynced on non-ascii input"
        );
        let pos = n
            .text
            .find("ignore previous")
            .expect("phrase present after normalize");
        let (s, e) = n.map_span(pos, pos + "ignore previous".len());
        assert_eq!(
            &original[s..e],
            "ignore previous",
            "span mis-mapped: got {:?}",
            &original[s..e]
        );
    }
}