use base64::{Engine as _, engine::general_purpose};
use std::io::Write as _;
use crate::error::EncodeError;
use wafrift_types::hash::{FNV_OFFSET_64, FNV_PRIME_64};
/// Output of [`chunked_split`]: the encoded body together with the headers
/// that are meant to accompany it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkedBody {
/// Encoded body bytes.
pub body: Vec<u8>,
/// `(name, value)` header pairs to send alongside `body`.
pub required_headers: Vec<(String, String)>,
}
/// Appends a percent-encoded NUL byte (`%00`) to `payload`.
///
/// When the payload contains a `.` the appended suffix is `%00.jpg`;
/// otherwise it is just `%00`.
///
/// # Errors
///
/// Returns [`EncodeError::InvalidUtf8`] if `payload` is not valid UTF-8.
pub fn null_byte_inject(payload: impl AsRef<[u8]>) -> Result<String, EncodeError> {
    let text = match std::str::from_utf8(payload.as_ref()) {
        Ok(text) => text,
        Err(_) => return Err(EncodeError::InvalidUtf8),
    };
    // `.` is ASCII, so searching the validated text for the char is
    // equivalent to searching the raw bytes for `b'.'`.
    let suffix = if text.contains('.') {
        "%00.jpg"
    } else {
        "%00"
    };
    Ok(format!("{text}{suffix}"))
}
/// Rewrites ASCII non-alphanumeric characters as percent-encoded two-byte
/// overlong UTF-8 sequences (for example `/` becomes `%C0%AF`).
///
/// ASCII letters and digits, and all non-ASCII characters, are copied through
/// unchanged.
///
/// # Errors
///
/// Returns [`EncodeError::InvalidUtf8`] if `payload` is not valid UTF-8.
pub fn overlong_utf8(payload: impl AsRef<[u8]>) -> Result<String, EncodeError> {
    let text = match std::str::from_utf8(payload.as_ref()) {
        Ok(text) => text,
        Err(_) => return Err(EncodeError::InvalidUtf8),
    };

    let mut encoded = String::with_capacity(text.len());
    for ch in text.chars() {
        if ch.is_ascii() && !ch.is_ascii_alphanumeric() {
            let byte = ch as u8;
            // Two-byte form: 110xxxxx 10xxxxxx, with the 7-bit value split 1 + 6.
            let lead = 0xC0 | (byte >> 6);
            let continuation = 0x80 | (byte & 0x3F);
            encoded.push_str(&format!("%{lead:02X}%{continuation:02X}"));
        } else {
            encoded.push(ch);
        }
    }
    Ok(encoded)
}
/// Rewrites ASCII non-alphanumeric characters as percent-encoded three-byte
/// overlong UTF-8 sequences (for example `/` becomes `%E0%80%AF`).
///
/// ASCII letters and digits, and all non-ASCII characters, are copied through
/// unchanged.
///
/// # Errors
///
/// Returns [`EncodeError::InvalidUtf8`] if `payload` is not valid UTF-8.
pub fn overlong_utf8_more(payload: impl AsRef<[u8]>) -> Result<String, EncodeError> {
    let text = match std::str::from_utf8(payload.as_ref()) {
        Ok(text) => text,
        Err(_) => return Err(EncodeError::InvalidUtf8),
    };

    let mut encoded = String::with_capacity(text.len());
    for ch in text.chars() {
        if ch.is_ascii() && !ch.is_ascii_alphanumeric() {
            let byte = ch as u8;
            // Three-byte form: fixed 0xE0 lead, then two continuation bytes
            // carrying the high bit(s) and the low six bits respectively.
            let high = 0x80 | (byte >> 6);
            let low = 0x80 | (byte & 0x3F);
            encoded.push_str(&format!("%E0%{high:02X}%{low:02X}"));
        } else {
            encoded.push(ch);
        }
    }
    Ok(encoded)
}
pub fn chunked_split(
payload: impl AsRef<[u8]>,
chunk_size: usize,
) -> Result<ChunkedBody, EncodeError> {
let payload = payload.as_ref();
if payload.is_empty() {
return Ok(ChunkedBody {
body: Vec::new(),
required_headers: vec![("Transfer-Encoding".to_string(), "chunked".to_string())],
});
}
let chunk_size = chunk_size.max(1);
let mut result: Vec<u8> = Vec::with_capacity(payload.len() + 64);
for chunk in payload.chunks(chunk_size) {
let _ = write!(&mut result, "{:x}\r\n", chunk.len());
result.extend_from_slice(chunk);
result.extend_from_slice(b"\r\n");
}
result.extend_from_slice(b"0\r\n\r\n");
Ok(ChunkedBody {
body: result,
required_headers: vec![("Transfer-Encoding".to_string(), "chunked".to_string())],
})
}
/// Prefixes `payload` with an extra query parameter.
///
/// * If `payload` contains `=`, the text before the first `=` is reused as the
///   key and the result is `"{key}=safe&{payload}"`.
/// * Otherwise an 8-letter lowercase decoy key is derived from an FNV-style
///   hash of the payload bytes and the result is `"{decoy}=1&{payload}"`. The
///   decoy depends only on the payload, so the output is deterministic.
///
/// # Errors
///
/// Returns [`EncodeError::InvalidUtf8`] if `payload` is not valid UTF-8.
pub fn parameter_pollute(payload: impl AsRef<[u8]>) -> Result<String, EncodeError> {
    let payload = payload.as_ref();
    let payload_str = std::str::from_utf8(payload).map_err(|_| EncodeError::InvalidUtf8)?;

    // Splitting the already-validated `&str` cannot fail, so no second UTF-8
    // check (and no unreachable error path) is needed for the key.
    if let Some((key, _)) = payload_str.split_once('=') {
        return Ok(format!("{key}=safe&{payload_str}"));
    }

    // XOR each byte into the state, then multiply by the prime.
    let hash = payload.iter().fold(FNV_OFFSET_64, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME_64)
    });

    // Map each of the eight hash bytes, least-significant first, onto `a..=z`.
    let decoy: String = hash
        .to_le_bytes()
        .iter()
        .map(|byte| char::from(b'a' + byte % 26))
        .collect();

    Ok(format!("{decoy}=1&{payload_str}"))
}
/// Encodes `payload` with the `general_purpose::STANDARD` Base64 engine
/// (padded output, e.g. `"hello"` becomes `"aGVsbG8="`).
pub fn base64_encode(payload: impl AsRef<[u8]>) -> String {
    let bytes = payload.as_ref();
    general_purpose::STANDARD.encode(bytes)
}
/// Encodes `payload` with the `general_purpose::URL_SAFE_NO_PAD` Base64 engine
/// (URL-safe alphabet, no trailing `=` padding).
pub fn base64_url_encode(payload: impl AsRef<[u8]>) -> String {
    let bytes = payload.as_ref();
    general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
/// Encodes `payload` as a hexadecimal string via `hex::encode`
/// (e.g. `"ABC"` becomes `"414243"`).
pub fn hex_encode(payload: impl AsRef<[u8]>) -> String {
    let bytes = payload.as_ref();
    hex::encode(bytes)
}
pub use wafrift_types::utf7::{utf7_decode, utf7_encode};
pub fn gzip_encode(payload: impl AsRef<[u8]>) -> Result<String, EncodeError> {
let payload = payload.as_ref();
let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
encoder
.write_all(payload)
.map_err(|e| EncodeError::InvalidConfig(format!("gzip failed: {e}")))?;
let bytes = encoder
.finish()
.map_err(|e| EncodeError::InvalidConfig(format!("gzip failed: {e}")))?;
Ok(general_purpose::STANDARD.encode(bytes))
}
pub fn deflate_encode(payload: impl AsRef<[u8]>) -> Result<String, EncodeError> {
let payload = payload.as_ref();
let mut encoder =
flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default());
encoder
.write_all(payload)
.map_err(|e| EncodeError::InvalidConfig(format!("deflate failed: {e}")))?;
let bytes = encoder
.finish()
.map_err(|e| EncodeError::InvalidConfig(format!("deflate failed: {e}")))?;
Ok(general_purpose::STANDARD.encode(bytes))
}
// Unit tests for the encoders defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A payload containing `.` gets the `%00.jpg` suffix.
#[test]
fn null_byte_with_extension() {
assert_eq!(null_byte_inject("file.php").unwrap(), "file.php%00.jpg");
}
// A payload without `.` gets only `%00`.
#[test]
fn null_byte_without_extension() {
assert_eq!(null_byte_inject("payload").unwrap(), "payload%00");
}
// `/` (0x2F) in two-byte overlong form is C0 AF.
#[test]
fn overlong_utf8_slash() {
let result = overlong_utf8("/").unwrap();
assert_eq!(result, "%C0%AF");
}
// `/` (0x2F) in three-byte overlong form is E0 80 AF.
#[test]
fn overlong_utf8_more_slash() {
let result = overlong_utf8_more("/").unwrap();
assert_eq!(result, "%E0%80%AF");
}
// Punctuation at or above 0x40 sets bit 6, which must land in the second
// byte as a valid continuation byte; the decoded value must round-trip.
#[test]
fn overlong_utf8_more_punctuation_above_0x40_uses_valid_continuation_bytes() {
for ch in ['@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] {
let s = ch.to_string();
let encoded = overlong_utf8_more(&s).unwrap();
assert!(
encoded.starts_with("%E0%"),
"{ch:?} should use 3-byte form, got: {encoded}"
);
// Parse the `%XX%XX%XX` output back into raw bytes.
let bytes: Vec<u8> = encoded
.split('%')
.filter(|s| !s.is_empty())
.map(|s| u8::from_str_radix(s, 16).unwrap())
.collect();
assert_eq!(bytes.len(), 3, "expected 3 bytes for {ch:?}");
assert_eq!(bytes[0], 0xE0, "lead byte wrong for {ch:?}");
assert!(
(0x80..=0xBF).contains(&bytes[1]),
"{ch:?} 2nd byte 0x{:02X} outside valid continuation range",
bytes[1]
);
assert!(
(0x80..=0xBF).contains(&bytes[2]),
"{ch:?} 3rd byte 0x{:02X} outside valid continuation range",
bytes[2]
);
// Reassemble the code point from the two 6-bit continuation payloads.
let codepoint = ((bytes[1] & 0x3F) as u32) << 6 | (bytes[2] & 0x3F) as u32;
assert_eq!(
codepoint, ch as u32,
"decoded codepoint 0x{codepoint:X} != original 0x{:X}",
ch as u32
);
}
}
// ASCII letters and digits pass through untouched.
#[test]
fn overlong_utf8_more_preserves_alphanumerics_verbatim() {
assert_eq!(overlong_utf8_more("abc123").unwrap(), "abc123");
}
// Smoke test: CRLF framing, terminating chunk, and the required header.
#[test]
fn chunked_split_produces_valid_chunks() {
let result = chunked_split("SELECT * FROM users", 3).unwrap();
let body = String::from_utf8(result.body.clone()).unwrap();
assert!(body.contains("\r\n"));
assert!(body.ends_with("0\r\n\r\n"));
assert_eq!(
result.required_headers,
vec![("Transfer-Encoding".to_string(), "chunked".to_string())]
);
}
// Walks the encoded body chunk by chunk and checks each declared size and
// data slice, using a payload that contains non-UTF-8 bytes.
#[test]
fn chunked_split_byte_lengths_correct() {
let payload = b"abc\x80\x81defgh";
let result = chunked_split(payload, 3).unwrap();
let mut i = 0;
let mut chunk_count = 0;
let expected_chunk_sizes = [3_usize, 3, 3, 1];
while i < result.body.len() {
// Locate the CRLF that ends the hex size line starting at `i`.
let size_end = result.body[i..]
.windows(2)
.position(|w| w == b"\r\n")
.unwrap_or(result.body.len() - i)
+ i;
let size_str = std::str::from_utf8(&result.body[i..size_end]).unwrap();
// A size of "0" is the terminating chunk.
if size_str == "0" {
break;
}
let size = usize::from_str_radix(size_str, 16).unwrap();
assert_eq!(size, expected_chunk_sizes[chunk_count]);
// Skip the CRLF after the size line to reach the chunk data.
let data_start = size_end + 2;
let data_end = data_start + size;
assert_eq!(
&result.body[data_start..data_end],
&payload[chunk_count * 3..chunk_count * 3 + size]
);
// Skip the CRLF that follows the chunk data.
i = data_end + 2;
chunk_count += 1;
}
assert_eq!(chunk_count, 4);
}
// An empty payload yields an empty body (no terminating chunk).
#[test]
fn chunked_split_empty() {
let result = chunked_split("", 3).unwrap();
assert!(result.body.is_empty());
}
// With `key=value` input the key is duplicated with a `safe` value in front.
#[test]
fn parameter_pollution_with_key_value() {
let result = parameter_pollute("user=' OR 1=1--").unwrap();
assert!(result.starts_with("user=safe&"));
assert!(result.contains("user=' OR 1=1--"));
}
// Without `=`, an 8-letter lowercase decoy key is generated; it must be
// deterministic for a given payload and differ between payloads.
#[test]
fn parameter_pollution_without_equals() {
let result = parameter_pollute("payload").unwrap();
assert!(result.ends_with("&payload"));
assert!(!result.contains("_wafrift_decoy"));
let decoy = result
.strip_suffix("=1&payload")
.expect("decoy=1&payload shape");
assert_eq!(decoy.len(), 8, "decoy must be 8 chars: {result}");
assert!(
decoy.bytes().all(|b| b.is_ascii_lowercase()),
"decoy must be [a-z]{{8}}: {result}"
);
assert_eq!(result, parameter_pollute("payload").unwrap());
assert_ne!(result, parameter_pollute("payloae").unwrap());
}
// Standard Base64 output keeps `=` padding.
#[test]
fn base64_standard() {
assert_eq!(base64_encode("hello"), "aGVsbG8=");
}
// URL-safe Base64 output has no padding.
#[test]
fn base64_url_safe() {
assert_eq!(base64_url_encode("hello+++"), "aGVsbG8rKys");
}
// Each input byte becomes two hex digits.
#[test]
fn hex_encode_basic() {
assert_eq!(hex_encode("ABC"), "414243");
}
// Plain ASCII is unchanged, `+` is escaped as `+-`, and non-ASCII text
// starts a shifted section with `+`.
#[test]
fn utf7_rfc2152_basic() {
assert_eq!(utf7_encode("Hello"), "Hello");
assert_eq!(utf7_encode("A+B"), "A+-B");
assert!(utf7_encode("日本語").starts_with('+'));
}
// NOTE(review): despite the name, this only checks that the `+` and `-`
// delimiters are present; it never calls `utf7_decode`.
#[test]
fn utf7_rfc2152_decodeable() {
let encoded = utf7_encode("日本語");
assert!(encoded.contains('+'));
assert!(encoded.contains('-'));
}
// Base64-decode then gunzip must reproduce the original bytes.
#[test]
fn gzip_roundtrip() {
let original = b"SELECT * FROM users";
let encoded = gzip_encode(original).unwrap();
assert!(!encoded.is_empty());
let decoded = general_purpose::STANDARD.decode(&encoded).unwrap();
let mut decoder = flate2::read::GzDecoder::new(&decoded[..]);
let mut decompressed = Vec::new();
std::io::Read::read_to_end(&mut decoder, &mut decompressed).unwrap();
assert_eq!(decompressed, original);
}
// Base64-decode then inflate must reproduce the original bytes.
#[test]
fn deflate_roundtrip() {
let original = b"SELECT * FROM users";
let encoded = deflate_encode(original).unwrap();
assert!(!encoded.is_empty());
let decoded = general_purpose::STANDARD.decode(&encoded).unwrap();
let mut decoder = flate2::read::DeflateDecoder::new(&decoded[..]);
let mut decompressed = Vec::new();
std::io::Read::read_to_end(&mut decoder, &mut decompressed).unwrap();
assert_eq!(decompressed, original);
}
}