#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
pub struct SimdEscaper {
simd_enabled: bool,
}
impl SimdEscaper {
pub fn new() -> Self {
Self {
simd_enabled: Self::detect_simd_support(),
}
}
#[cfg(target_arch = "x86_64")]
fn detect_simd_support() -> bool {
is_x86_feature_detected!("sse4.2") && is_x86_feature_detected!("avx2")
}
#[cfg(not(target_arch = "x86_64"))]
fn detect_simd_support() -> bool {
false
}
pub fn is_simd_enabled(&self) -> bool {
self.simd_enabled
}
pub fn escape_literal(&self, value: &str) -> String {
if self.simd_enabled && value.len() >= 16 {
self.escape_literal_simd(value)
} else {
self.escape_literal_scalar(value)
}
}
pub fn is_valid_iri(&self, iri: &str) -> bool {
if self.simd_enabled && iri.len() >= 16 {
self.is_valid_iri_simd(iri)
} else {
self.is_valid_iri_scalar(iri)
}
}
pub fn needs_escaping(&self, value: &str) -> bool {
if self.simd_enabled && value.len() >= 16 {
self.needs_escaping_simd(value)
} else {
self.needs_escaping_scalar(value)
}
}
pub fn count_escaped_chars(&self, value: &str) -> usize {
if self.simd_enabled && value.len() >= 16 {
self.count_escaped_chars_simd(value)
} else {
self.count_escaped_chars_scalar(value)
}
}
fn escape_literal_scalar(&self, value: &str) -> String {
value
.replace('\\', "\\\\")
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\t', "\\t")
.replace('"', "\\\"")
}
fn is_valid_iri_scalar(&self, iri: &str) -> bool {
!iri.chars().any(|c| {
matches!(
c,
'\0'..='\x1F' | ' ' | '<' | '>' | '"' | '{' | '}' | '|' | '\\' | '^' | '`'
)
})
}
fn needs_escaping_scalar(&self, value: &str) -> bool {
value
.chars()
.any(|c| matches!(c, '\\' | '\n' | '\r' | '\t' | '"'))
}
fn count_escaped_chars_scalar(&self, value: &str) -> usize {
value
.chars()
.filter(|c| matches!(c, '\\' | '\n' | '\r' | '\t' | '"'))
.count()
}
#[allow(dead_code)]
fn needs_escaping_bytes(&self, bytes: &[u8]) -> bool {
bytes
.iter()
.any(|&b| matches!(b, b'\\' | b'\n' | b'\r' | b'\t' | b'"'))
}
#[allow(dead_code)]
fn count_escaped_chars_bytes(&self, bytes: &[u8]) -> usize {
bytes
.iter()
.filter(|&&b| matches!(b, b'\\' | b'\n' | b'\r' | b'\t' | b'"'))
.count()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2,avx2")]
unsafe fn escape_literal_simd_impl(&self, value: &str) -> String {
let bytes = value.as_bytes();
let len = bytes.len();
let escaped_count = self.count_escaped_chars_simd(value);
let mut result = Vec::with_capacity(len + escaped_count);
let mut offset = 0;
let backslash = _mm_set1_epi8(b'\\' as i8);
let newline = _mm_set1_epi8(b'\n' as i8);
let carriage = _mm_set1_epi8(b'\r' as i8);
let tab = _mm_set1_epi8(b'\t' as i8);
let quote = _mm_set1_epi8(b'"' as i8);
while offset + 16 <= len {
let chunk = _mm_loadu_si128(bytes.as_ptr().add(offset) as *const __m128i);
let m1 = _mm_cmpeq_epi8(chunk, backslash);
let m2 = _mm_cmpeq_epi8(chunk, newline);
let m3 = _mm_cmpeq_epi8(chunk, carriage);
let m4 = _mm_cmpeq_epi8(chunk, tab);
let m5 = _mm_cmpeq_epi8(chunk, quote);
let match1 = _mm_or_si128(m1, m2);
let match2 = _mm_or_si128(m3, m4);
let match3 = _mm_or_si128(match1, match2);
let matches = _mm_or_si128(match3, m5);
let mask = _mm_movemask_epi8(matches) as u32;
if mask == 0 {
result.extend_from_slice(&bytes[offset..offset + 16]);
offset += 16;
} else {
for i in 0..16 {
let b = bytes[offset + i];
match b {
b'\\' => result.extend_from_slice(b"\\\\"),
b'\n' => result.extend_from_slice(b"\\n"),
b'\r' => result.extend_from_slice(b"\\r"),
b'\t' => result.extend_from_slice(b"\\t"),
b'"' => result.extend_from_slice(b"\\\""),
_ => result.push(b),
}
}
offset += 16;
}
}
for &b in &bytes[offset..] {
match b {
b'\\' => result.extend_from_slice(b"\\\\"),
b'\n' => result.extend_from_slice(b"\\n"),
b'\r' => result.extend_from_slice(b"\\r"),
b'\t' => result.extend_from_slice(b"\\t"),
b'"' => result.extend_from_slice(b"\\\""),
_ => result.push(b),
}
}
String::from_utf8_unchecked(result)
}
#[cfg(target_arch = "x86_64")]
fn escape_literal_simd(&self, value: &str) -> String {
unsafe { self.escape_literal_simd_impl(value) }
}
#[cfg(not(target_arch = "x86_64"))]
fn escape_literal_simd(&self, value: &str) -> String {
self.escape_literal_scalar(value)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2,avx2")]
unsafe fn is_valid_iri_simd_impl(&self, iri: &str) -> bool {
let bytes = iri.as_bytes();
let len = bytes.len();
let mut offset = 0;
while offset + 16 <= len {
let chunk = _mm_loadu_si128(bytes.as_ptr().add(offset) as *const __m128i);
let control_threshold = _mm_set1_epi8(0x20);
let is_control = _mm_cmplt_epi8(chunk, control_threshold);
let space = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(b' ' as i8));
let lt = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(b'<' as i8));
let gt = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(b'>' as i8));
let quote = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(b'"' as i8));
let backslash = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(b'\\' as i8));
let invalid1 = _mm_or_si128(is_control, space);
let invalid2 = _mm_or_si128(lt, gt);
let invalid3 = _mm_or_si128(quote, backslash);
let invalid4 = _mm_or_si128(invalid1, invalid2);
let invalid = _mm_or_si128(invalid4, invalid3);
let mask = _mm_movemask_epi8(invalid);
if mask != 0 {
return false;
}
offset += 16;
}
for b in &bytes[offset..] {
if *b < 0x20
|| matches!(
*b,
b' ' | b'<' | b'>' | b'"' | b'{' | b'}' | b'|' | b'\\' | b'^' | b'`'
)
{
return false;
}
}
true
}
#[cfg(target_arch = "x86_64")]
fn is_valid_iri_simd(&self, iri: &str) -> bool {
unsafe { self.is_valid_iri_simd_impl(iri) }
}
#[cfg(not(target_arch = "x86_64"))]
fn is_valid_iri_simd(&self, iri: &str) -> bool {
self.is_valid_iri_scalar(iri)
}
#[cfg(target_arch = "x86_64")]
fn needs_escaping_simd(&self, value: &str) -> bool {
let bytes = value.as_bytes();
let len = bytes.len();
let mut offset = 0;
unsafe {
let backslash = _mm_set1_epi8(b'\\' as i8);
let newline = _mm_set1_epi8(b'\n' as i8);
let carriage = _mm_set1_epi8(b'\r' as i8);
let tab = _mm_set1_epi8(b'\t' as i8);
let quote = _mm_set1_epi8(b'"' as i8);
while offset + 16 <= len {
let chunk = _mm_loadu_si128(bytes.as_ptr().add(offset) as *const __m128i);
let m1 = _mm_cmpeq_epi8(chunk, backslash);
let m2 = _mm_cmpeq_epi8(chunk, newline);
let m3 = _mm_cmpeq_epi8(chunk, carriage);
let m4 = _mm_cmpeq_epi8(chunk, tab);
let m5 = _mm_cmpeq_epi8(chunk, quote);
let match1 = _mm_or_si128(m1, m2);
let match2 = _mm_or_si128(m3, m4);
let match3 = _mm_or_si128(match1, match2);
let matches = _mm_or_si128(match3, m5);
let mask = _mm_movemask_epi8(matches);
if mask != 0 {
return true;
}
offset += 16;
}
}
self.needs_escaping_bytes(&bytes[offset..])
}
#[cfg(not(target_arch = "x86_64"))]
fn needs_escaping_simd(&self, value: &str) -> bool {
self.needs_escaping_scalar(value)
}
#[cfg(target_arch = "x86_64")]
fn count_escaped_chars_simd(&self, value: &str) -> usize {
let bytes = value.as_bytes();
let len = bytes.len();
let mut count = 0;
let mut offset = 0;
unsafe {
let backslash = _mm_set1_epi8(b'\\' as i8);
let newline = _mm_set1_epi8(b'\n' as i8);
let carriage = _mm_set1_epi8(b'\r' as i8);
let tab = _mm_set1_epi8(b'\t' as i8);
let quote = _mm_set1_epi8(b'"' as i8);
while offset + 16 <= len {
let chunk = _mm_loadu_si128(bytes.as_ptr().add(offset) as *const __m128i);
let m1 = _mm_cmpeq_epi8(chunk, backslash);
let m2 = _mm_cmpeq_epi8(chunk, newline);
let m3 = _mm_cmpeq_epi8(chunk, carriage);
let m4 = _mm_cmpeq_epi8(chunk, tab);
let m5 = _mm_cmpeq_epi8(chunk, quote);
let match1 = _mm_or_si128(m1, m2);
let match2 = _mm_or_si128(m3, m4);
let match3 = _mm_or_si128(match1, match2);
let matches = _mm_or_si128(match3, m5);
let mask = _mm_movemask_epi8(matches) as u32;
count += mask.count_ones() as usize;
offset += 16;
}
}
count + self.count_escaped_chars_bytes(&bytes[offset..])
}
#[cfg(not(target_arch = "x86_64"))]
fn count_escaped_chars_simd(&self, value: &str) -> usize {
self.count_escaped_chars_scalar(value)
}
}
impl Default for SimdEscaper {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_escaper_creation() {
let escaper = SimdEscaper::new();
assert!(escaper.is_simd_enabled() || !escaper.is_simd_enabled());
}
#[test]
fn test_escape_literal_simple() {
let escaper = SimdEscaper::new();
assert_eq!(escaper.escape_literal("hello"), "hello");
assert_eq!(escaper.escape_literal("hello world"), "hello world");
}
#[test]
fn test_escape_literal_special_chars() {
let escaper = SimdEscaper::new();
assert_eq!(escaper.escape_literal("hello\\world"), "hello\\\\world");
assert_eq!(escaper.escape_literal("hello\nworld"), "hello\\nworld");
assert_eq!(escaper.escape_literal("hello\rworld"), "hello\\rworld");
assert_eq!(escaper.escape_literal("hello\tworld"), "hello\\tworld");
assert_eq!(escaper.escape_literal("hello\"world"), "hello\\\"world");
}
#[test]
fn test_escape_literal_multiple() {
let escaper = SimdEscaper::new();
let input = "Line 1\nLine 2\r\nTab:\t\"Quote\"\\Backslash";
let expected = "Line 1\\nLine 2\\r\\nTab:\\t\\\"Quote\\\"\\\\Backslash";
assert_eq!(escaper.escape_literal(input), expected);
}
#[test]
fn test_escape_literal_long() {
let escaper = SimdEscaper::new();
let long_input = "a".repeat(100) + "\n" + &"b".repeat(100);
let expected = "a".repeat(100) + "\\n" + &"b".repeat(100);
assert_eq!(escaper.escape_literal(&long_input), expected);
}
#[test]
fn test_is_valid_iri() {
let escaper = SimdEscaper::new();
assert!(escaper.is_valid_iri("http://example.org/resource"));
assert!(escaper.is_valid_iri("https://example.org/path/to/resource"));
assert!(escaper.is_valid_iri("urn:isbn:0451450523"));
assert!(!escaper.is_valid_iri("http://example.org/<invalid>"));
assert!(!escaper.is_valid_iri("http://example.org/\"invalid\""));
assert!(!escaper.is_valid_iri("http://example.org/ space"));
assert!(!escaper.is_valid_iri("http://example.org/\ncontrol"));
}
#[test]
fn test_needs_escaping() {
let escaper = SimdEscaper::new();
assert!(!escaper.needs_escaping("hello world"));
assert!(escaper.needs_escaping("hello\nworld"));
assert!(escaper.needs_escaping("hello\"world"));
assert!(escaper.needs_escaping("hello\\world"));
assert!(escaper.needs_escaping("hello\tworld"));
assert!(escaper.needs_escaping("hello\rworld"));
}
#[test]
fn test_count_escaped_chars() {
let escaper = SimdEscaper::new();
assert_eq!(escaper.count_escaped_chars("hello"), 0);
assert_eq!(escaper.count_escaped_chars("hello\nworld"), 1);
assert_eq!(escaper.count_escaped_chars("\"hello\""), 2);
assert_eq!(escaper.count_escaped_chars("a\\b\nc\td\"e"), 4);
}
#[test]
fn test_simd_vs_scalar_consistency() {
let escaper = SimdEscaper::new();
let long_string = "long ".repeat(50) + "\n" + &"string".repeat(50);
let test_cases = vec![
"simple string",
"string with\nnewline",
"string with\ttab",
"string with\"quote",
"string with\\backslash",
"mixed\n\r\t\"\\all",
&long_string,
];
for case in test_cases {
let escaped = escaper.escape_literal(case);
let scalar_escaped = escaper.escape_literal_scalar(case);
assert_eq!(escaped, scalar_escaped, "Mismatch for: {}", case);
let needs_esc = escaper.needs_escaping(case);
let scalar_needs = escaper.needs_escaping_scalar(case);
assert_eq!(needs_esc, scalar_needs, "needs_escaping mismatch: {}", case);
let count = escaper.count_escaped_chars(case);
let scalar_count = escaper.count_escaped_chars_scalar(case);
assert_eq!(count, scalar_count, "count mismatch: {}", case);
}
}
#[test]
fn test_iri_validation_consistency() {
let escaper = SimdEscaper::new();
let long_iri = format!("http://example.org/{}", "a".repeat(100));
let test_iris = vec![
"http://example.org/valid",
"https://example.org/path",
"urn:isbn:0451450523",
"http://example.org/<invalid>",
"http://example.org/ space",
"http://example.org/\ncontrol",
&long_iri,
];
for iri in test_iris {
let valid = escaper.is_valid_iri(iri);
let scalar_valid = escaper.is_valid_iri_scalar(iri);
assert_eq!(valid, scalar_valid, "Mismatch for IRI: {}", iri);
}
}
}