use crate::utils::{read_compressed_int, read_compressed_int_at};
use crate::Result;
use widestring::U16Str;
/// Borrowed view over the raw bytes of a user-string heap.
///
/// Entries are addressed by their byte offset within the heap and are decoded
/// lazily by [`UserStrings::get`] or by iterating with [`UserStrings::iter`].
/// The view is just a slice reference, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct UserStrings<'a> {
    /// Raw heap bytes, starting with the mandatory `0x00` byte at offset 0.
    data: &'a [u8],
}
impl<'a> UserStrings<'a> {
    /// Wraps the raw bytes of a user-string heap without decoding any entry.
    ///
    /// Only the heap header is validated: the first byte must be `0x00`.
    ///
    /// # Errors
    ///
    /// Returns an out-of-bounds error if `data` is empty or its first byte is
    /// not `0x00`.
    pub fn from(data: &'a [u8]) -> Result<UserStrings<'a>> {
        if data.first() != Some(&0) {
            return Err(out_of_bounds_error!());
        }
        Ok(UserStrings { data })
    }

    /// Returns the string stored at byte offset `index` of the heap.
    ///
    /// An entry is a compressed length prefix followed by that many bytes: an
    /// even number of UTF-16 bytes and then one trailing byte, which is not part
    /// of the returned string. An entry of length `1` is the empty string.
    ///
    /// # Errors
    ///
    /// - an out-of-bounds error if `index` is not inside the heap, or if the entry
    ///   (including its trailing byte) extends past the end of the heap;
    /// - a malformed error if the length prefix is zero or the UTF-16 part has an
    ///   odd number of bytes;
    /// - any error produced while decoding the length prefix.
    pub fn get(&self, index: usize) -> Result<&'a U16Str> {
        if index >= self.data.len() {
            return Err(out_of_bounds_error!());
        }

        let (total_bytes, compressed_length_size) = read_compressed_int_at(self.data, index)?;
        let data_start = index + compressed_length_size;

        if total_bytes == 0 {
            return Err(malformed_error!(
                "Invalid zero-length string at index {}",
                index
            ));
        }

        // The whole entry, trailing byte included, must lie inside the heap. This is
        // checked before the empty-string shortcut so that a truncated 1-byte entry
        // is rejected like any other truncated entry.
        let entry_fits = matches!(
            data_start.checked_add(total_bytes),
            Some(end) if end <= self.data.len()
        );
        if !entry_fits {
            return Err(out_of_bounds_error!());
        }

        if total_bytes == 1 {
            // Only the trailing byte is present: the empty string.
            static EMPTY_U16: [u16; 0] = [];
            return Ok(U16Str::from_slice(&EMPTY_U16));
        }

        let utf16_length = total_bytes - 1;
        if utf16_length % 2 != 0 {
            return Err(malformed_error!("Invalid UTF-16 length at index {}", index));
        }

        // In bounds: `utf16_length < total_bytes` and the entry fits (checked above).
        let utf16_data = &self.data[data_start..data_start + utf16_length];

        // SAFETY: `utf16_data` is an in-bounds sub-slice of `self.data`, so it is
        // initialised and valid for reads for `'a`. Its length is even (checked
        // above), so it covers exactly `utf16_data.len() / 2` `u16` values, and every
        // bit pattern is a valid `u16`. A pointer taken from a slice is never null,
        // so the `ok_or_else` branch is defensive only.
        //
        // NOTE(review): two requirements are NOT established here.
        // 1. Alignment: a `&[u16]` must be 2-byte aligned, but `data_start` can be odd
        //    and `self.data` itself carries no alignment guarantee, so this can create
        //    a misaligned reference, which is undefined behaviour. (This is what the
        //    `cast_ptr_alignment` allow silences.)
        // 2. Byte order: the heap holds little-endian UTF-16 (see the tests), but this
        //    reinterprets the bytes in native order, so big-endian targets mis-decode.
        // A sound fix needs an owned or `Cow` return type, i.e. a public API change.
        #[allow(clippy::cast_ptr_alignment)]
        let str_slice = unsafe {
            let ptr = utf16_data.as_ptr();
            core::ptr::slice_from_raw_parts(ptr.cast::<u16>(), utf16_data.len() / 2)
                .as_ref()
                .ok_or_else(|| malformed_error!("null pointer in user string slice conversion"))?
        };
        Ok(U16Str::from_slice(str_slice))
    }

    /// Returns an iterator over `(offset, string)` pairs for the heap's entries.
    ///
    /// Iteration starts after the leading `0x00` byte; entries that fail to decode
    /// are skipped (see [`UserStringsIterator`]).
    #[must_use]
    pub fn iter(&self) -> UserStringsIterator<'_> {
        UserStringsIterator::new(self)
    }

    /// Returns the raw heap bytes this view was created from.
    #[must_use]
    pub fn data(&self) -> &'a [u8] {
        self.data
    }

    /// Returns `true` if any decodable entry, converted lossily to UTF-8, equals `s`.
    #[must_use]
    pub fn contains(&self, s: &str) -> bool {
        self.iter().any(|(_, value)| Self::eq_lossy(value, s))
    }

    /// Returns the offset of the first decodable entry whose lossy UTF-8 form
    /// equals `s`.
    ///
    /// Returns `None` if no entry matches, or if the first match's offset does not
    /// fit in a `u32`.
    #[must_use]
    pub fn find(&self, s: &str) -> Option<u32> {
        self.iter()
            .find(|(_, value)| Self::eq_lossy(value, s))
            .and_then(|(offset, _)| u32::try_from(offset).ok())
    }

    /// Returns the offsets of every decodable entry whose lossy UTF-8 form equals
    /// `s`, in heap order. Offsets that do not fit in a `u32` are omitted.
    #[must_use]
    pub fn find_all(&self, s: &str) -> Vec<u32> {
        self.iter()
            .filter(|(_, value)| Self::eq_lossy(value, s))
            .filter_map(|(offset, _)| u32::try_from(offset).ok())
            .collect()
    }

    /// Same result as `value.to_string_lossy() == s`, without allocating a `String`.
    ///
    /// The decoding is the one `String::from_utf16_lossy` uses: unpaired surrogates
    /// become U+FFFD. Comparing the decoded `char`s with `s.chars()` is equivalent
    /// to comparing the UTF-8 strings.
    fn eq_lossy(value: &U16Str, s: &str) -> bool {
        char::decode_utf16(value.as_slice().iter().copied())
            .map(|unit| unit.unwrap_or(char::REPLACEMENT_CHARACTER))
            .eq(s.chars())
    }
}
impl<'a> IntoIterator for &'a UserStrings<'a> {
type Item = (usize, &'a U16Str);
type IntoIter = UserStringsIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator over the `(offset, string)` pairs of a [`UserStrings`] heap.
///
/// Created by [`UserStrings::iter`] or by iterating over `&UserStrings`. Cloning
/// it yields an independent cursor at the same position.
#[derive(Clone)]
pub struct UserStringsIterator<'a> {
    /// Heap being iterated.
    user_strings: &'a UserStrings<'a>,
    /// Byte offset of the next entry to try to decode.
    position: usize,
}
impl<'a> UserStringsIterator<'a> {
    /// Creates an iterator positioned at the first entry of `user_strings`.
    ///
    /// Offset 0 holds the `0x00` byte that `UserStrings::from` requires, so
    /// decoding starts at offset 1.
    pub(crate) fn new(user_strings: &'a UserStrings<'a>) -> Self {
        const FIRST_ENTRY_OFFSET: usize = 1;
        Self {
            position: FIRST_ENTRY_OFFSET,
            user_strings,
        }
    }
}
impl<'a> Iterator for UserStringsIterator<'a> {
    type Item = (usize, &'a U16Str);

    /// Yields the next `(offset, string)` pair that decodes successfully.
    ///
    /// Undecodable entries are skipped. If `MAX_RECOVERY_ATTEMPTS` entries fail
    /// within a single call, iteration stops for good: the cursor is moved to the
    /// end of the heap, so every later call also returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        const MAX_RECOVERY_ATTEMPTS: usize = 10;
        let data = self.user_strings.data;
        let mut recovery_attempts = 0;

        while self.position < data.len() {
            if recovery_attempts >= MAX_RECOVERY_ATTEMPTS {
                // Give up permanently rather than resuming on the next call.
                self.position = data.len();
                return None;
            }
            let start_position = self.position;

            // Decode the length prefix on a scratch cursor so that `self.position`
            // only moves by the amounts computed here, whatever the helper does to
            // its cursor on failure.
            let mut cursor = start_position;
            let Ok((total_bytes, compressed_length_size)) = read_compressed_int(data, &mut cursor)
            else {
                self.position = start_position + 1;
                recovery_attempts += 1;
                continue;
            };

            if total_bytes == 0 {
                // A zero length is not a valid entry; step over just the prefix.
                self.position = start_position + compressed_length_size;
                recovery_attempts += 1;
                continue;
            }

            // Saturate so a garbage length cannot overflow; an out-of-range end
            // simply terminates iteration at the top of the loop.
            let entry_end = start_position
                .saturating_add(compressed_length_size)
                .saturating_add(total_bytes);
            self.position = entry_end;

            match self.user_strings.get(start_position) {
                Ok(string) => return Some((start_position, string)),
                Err(_) => recovery_attempts += 1,
            }
        }
        None
    }
}

// `next` never moves the cursor backwards and pins it to the end of the heap when
// it gives up, so once it has returned `None` it keeps returning `None`.
impl core::iter::FusedIterator for UserStringsIterator<'_> {}
#[cfg(test)]
mod tests {
    use widestring::u16str;

    use super::*;

    #[test]
    fn crafted() {
        #[rustfmt::skip]
        let data: [u8; 29] = [
            0x00, 0x1b, 0x48, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x6c, 0x00, 0x6f, 0x00, 0x2c, 0x00, 0x20, 0x00, 0x57, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x6c, 0x00, 0x64, 0x00, 0x21, 0x00, 0x00
        ];
        let us_str = UserStrings::from(&data).unwrap();
        assert_eq!(us_str.get(1).unwrap(), u16str!("Hello, World!"));
    }

    #[test]
    fn invalid() {
        // An empty heap has no leading 0x00 byte.
        let data_empty: [u8; 0] = [];
        assert!(
            UserStrings::from(&data_empty).is_err(),
            "empty heap must be rejected"
        );

        // A heap whose first byte is not 0x00 is rejected.
        let data_invalid_first = [
            0x22, 0x1b, 0x48, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x6c, 0x00, 0x6f, 0x00, 0x2c, 0x00,
            0x20, 0x00, 0x57, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x6c, 0x00, 0x64, 0x00, 0x21, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        assert!(
            UserStrings::from(&data_invalid_first).is_err(),
            "non-zero first byte must be rejected"
        );

        // A valid header followed by a garbage length prefix: the heap is accepted,
        // but the entry must not decode.
        let data_invalid_entry = [0x00, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC];
        let us_str = UserStrings::from(&data_invalid_entry).unwrap();
        assert!(us_str.get(1).is_err(), "garbage entry must not decode");
    }

    #[test]
    fn get_rejects_bad_offsets() {
        let data = [0x00, 0x05, 0x48, 0x00, 0x69, 0x00, 0x00];
        let user_strings = UserStrings::from(&data).unwrap();
        // Offset 0 is the header byte, which reads as a zero length.
        assert!(user_strings.get(0).is_err());
        assert!(user_strings.get(data.len()).is_err());
        assert!(user_strings.get(usize::MAX).is_err());
        assert_eq!(user_strings.data(), &data[..]);
    }

    #[test]
    fn search_helpers() {
        let data = [
            0x00, // header
            0x05, 0x48, 0x00, 0x69, 0x00, 0x00, // "Hi" at offset 1
            0x07, 0x42, 0x00, 0x79, 0x00, 0x65, 0x00, 0x00, // "Bye" at offset 7
            0x05, 0x48, 0x00, 0x69, 0x00, 0x00, // "Hi" again at offset 15
        ];
        let user_strings = UserStrings::from(&data).unwrap();
        assert!(user_strings.contains("Bye"));
        assert!(!user_strings.contains("Hello"));
        assert_eq!(user_strings.find("Hi"), Some(1));
        assert_eq!(user_strings.find("Bye"), Some(7));
        assert_eq!(user_strings.find("Hello"), None);
        assert_eq!(user_strings.find_all("Hi"), vec![1, 15]);
        assert!(user_strings.find_all("Hello").is_empty());
    }

    #[test]
    fn search_uses_lossy_decoding() {
        // A lone high surrogate (0xD800, little-endian) decodes lossily to U+FFFD.
        let data = [0x00, 0x03, 0x00, 0xD8, 0x00];
        let user_strings = UserStrings::from(&data).unwrap();
        assert!(user_strings.contains("\u{FFFD}"));
        assert_eq!(user_strings.find("\u{FFFD}"), Some(1));
    }

    #[test]
    fn test_userstrings_iterator_basic() {
        let data = [
            0x00, // header
            0x05, 0x48, 0x00, 0x69, 0x00, 0x00, // "Hi"
        ];
        let user_strings = UserStrings::from(&data).unwrap();
        let mut iter = user_strings.iter();
        let first = iter.next().unwrap();
        assert_eq!(first.0, 1);
        assert_eq!(first.1.to_string_lossy(), "Hi");
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_userstrings_iterator_multiple() {
        let data = [
            0x00, // header
            0x05, 0x48, 0x00, 0x69, 0x00, 0x00, // "Hi" at offset 1
            0x07, 0x42, 0x00, 0x79, 0x00, 0x65, 0x00, 0x00, // "Bye" at offset 7
        ];
        let user_strings = UserStrings::from(&data).unwrap();
        let mut iter = user_strings.iter();
        let first = iter.next().unwrap();
        assert_eq!(first.0, 1);
        assert_eq!(first.1.to_string_lossy(), "Hi");
        let second = iter.next().unwrap();
        assert_eq!(second.0, 7);
        assert_eq!(second.1.to_string_lossy(), "Bye");
        assert!(iter.next().is_none());

        // `&UserStrings` is iterable directly and yields the same offsets.
        let offsets: Vec<usize> = (&user_strings).into_iter().map(|(offset, _)| offset).collect();
        assert_eq!(offsets, vec![1, 7]);
    }

    #[test]
    fn test_userstrings_iterator_empty_string() {
        let data = [
            0x00, // header
            0x01, 0x00, // "" at offset 1 (trailing byte only)
            0x05, 0x48, 0x00, 0x69, 0x00, 0x00, // "Hi" at offset 3
        ];
        let user_strings = UserStrings::from(&data).unwrap();
        let mut iter = user_strings.iter();
        let first = iter.next().unwrap();
        assert_eq!(first.0, 1);
        assert_eq!(first.1.to_string_lossy(), "");
        let second = iter.next().unwrap();
        assert_eq!(second.0, 3);
        assert_eq!(second.1.to_string_lossy(), "Hi");
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_userstrings_iterator_long_string() {
        let mut data = vec![0x00];
        data.push(0x0B);
        for _ in 0..5 {
            data.extend_from_slice(&[0x41, 0x00]);
        }
        data.push(0x00);
        let user_strings = UserStrings::from(&data).unwrap();
        let mut iter = user_strings.iter();
        let first = iter.next().unwrap();
        assert_eq!(first.0, 1);
        assert_eq!(first.1.to_string_lossy(), "AAAAA");
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_userstrings_iterator_truncated_data() {
        // The entry claims 7 bytes but only 3 follow.
        let data = [0x00, 0x07, 0x48, 0x00, 0x69];
        let user_strings = UserStrings::from(&data).unwrap();
        let mut iter = user_strings.iter();
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_userstrings_iterator_invalid_utf16_length() {
        // Length 4 leaves an odd 3-byte UTF-16 part (and the entry is truncated).
        let data = [0x00, 0x04, 0x48, 0x00, 0x69];
        let user_strings = UserStrings::from(&data).unwrap();
        let mut iter = user_strings.iter();
        assert!(iter.next().is_none());
    }
}