use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::ops::Range;
use core::str;
use crate::allocator::*;
#[cfg(any(debug_assertions, test))]
use crate::error::{ValidationError, ValidationResult};
use crate::{Offset, StringId, StringIndex};
/// Alias for [`StringTable`] that forwards all four generic parameters unchanged.
///
/// The defaults written here (`u32` offsets, `u16` ids, no NUL padding, `Global`
/// allocator) are the same ones declared on [`StringTable`] itself.
pub type StringPool<O = u32, I = u16, const NULL_PADDED: bool = false, A = Global> =
StringTable<O, I, NULL_PADDED, A>;
/// Alias for [`StringTableIter`] that forwards the lifetime, offset type and
/// `NULL_PADDED` flag unchanged.
pub type StringPoolIter<'a, O = u32, const NULL_PADDED: bool = false> =
StringTableIter<'a, O, NULL_PADDED>;
/// Table of strings stored back to back in a single byte buffer.
///
/// String `i` occupies `bytes[offsets[i]..offsets[i + 1]]`, so a table with `n`
/// strings carries `n + 1` offsets; the final one acts as an end sentinel.
/// When `NULL_PADDED` is `true`, the last byte of every entry is a NUL
/// terminator that the accessors strip before handing out a `&str`.
///
/// Type parameters:
/// * `O` – integer type used to store each offset.
/// * `I` – integer type behind the [`StringId`] used for lookups.
/// * `NULL_PADDED` – whether each entry ends with a terminator byte.
/// * `A` – allocator owning both boxed slices.
pub struct StringTable<O = u32, I = u16, const NULL_PADDED: bool = false, A = Global>
where
    O: Offset,
    I: StringIndex,
    A: Allocator + Clone,
{
    /// Concatenated bytes of every entry (terminators included when `NULL_PADDED`).
    bytes: Box<[u8], A>,
    /// Start offset of every entry followed by one trailing end offset.
    offsets: Box<[O], A>,
    /// Records the id type without storing a value of it.
    _id: PhantomData<I>,
}
impl StringTable<u32, u16, false, Global> {
    /// Creates a table that holds no strings, backed by the `Global` allocator.
    ///
    /// Implemented for the default parameter set only (`u32` offsets, `u16` ids,
    /// no NUL padding); use [`StringTable::empty_in`] for any other combination.
    #[inline]
    pub fn empty() -> Self {
        let allocator = Global;
        Self::empty_in(allocator)
    }
}
impl<O: Offset, I: StringIndex, const NULL_PADDED: bool, A: Allocator + Clone>
StringTable<O, I, NULL_PADDED, A>
{
pub fn empty_in(allocator: A) -> Self {
let bytes = Vec::new_in(allocator.clone()).into_boxed_slice();
let mut offsets = Vec::with_capacity_in(1, allocator);
offsets.push(zero_offset::<O>());
Self {
bytes,
offsets: offsets.into_boxed_slice(),
_id: PhantomData,
}
}
/// Assembles a table directly from its two buffers without inspecting them.
///
/// The accessors on this type read both buffers with unchecked indexing and
/// `from_utf8_unchecked`, so callers must only pass data that `validate`
/// (available in debug and test builds) would accept.
#[inline]
pub(crate) fn from_parts_unchecked(bytes: Box<[u8], A>, offsets: Box<[O], A>) -> Self {
    let marker = PhantomData;
    Self {
        bytes,
        offsets,
        _id: marker,
    }
}
/// Returns the number of strings in the table.
///
/// This is one less than the number of stored offsets, because the final
/// offset is the end sentinel; an offset buffer with no entries counts as zero.
#[inline]
pub fn len(&self) -> usize {
    match self.offsets.len() {
        0 => 0,
        slots => slots - 1,
    }
}
/// Returns `true` when the table holds no strings.
#[inline]
pub fn is_empty(&self) -> bool {
    // At most one offset (the sentinel) means there is no complete entry.
    self.offsets.len() <= 1
}
#[inline]
pub fn get(&self, id: StringId<I>) -> Option<&str> {
let index = id.into_usize();
if index < self.len() {
let start = unsafe { self.offsets.get_unchecked(index) }.to_usize();
let end = unsafe { self.offsets.get_unchecked(index + 1) }.to_usize();
let logical_end = end.saturating_sub(usize::from(NULL_PADDED));
debug_assert!(logical_end >= start);
let bytes = unsafe { self.bytes.get_unchecked(start..logical_end) };
Some(unsafe { str::from_utf8_unchecked(bytes) })
} else {
None
}
}
/// Returns the string stored under `id` without checking that `id` is in range.
///
/// For NUL-padded tables the trailing terminator byte is not part of the
/// returned string.
///
/// # Safety
///
/// `id.into_usize()` must be less than [`len`](Self::len). Otherwise the
/// offset reads performed here are out of bounds, which is undefined behavior.
#[inline]
pub unsafe fn get_unchecked(&self, id: StringId<I>) -> &str {
    let slot = id.into_usize();
    // SAFETY: the caller guarantees `slot < self.len()`, so both `slot` and
    // `slot + 1` index into `offsets`.
    let (start, end) = unsafe {
        (
            self.offsets.get_unchecked(slot).to_usize(),
            self.offsets.get_unchecked(slot + 1).to_usize(),
        )
    };
    // In a NUL-padded table the entry's final byte is the terminator, not text.
    let text_end = if NULL_PADDED {
        end.saturating_sub(1)
    } else {
        end
    };
    debug_assert!(text_end >= start);
    // SAFETY: relies on the table invariants that stored offsets are ordered,
    // lie within `bytes`, and delimit valid UTF-8 (what `validate` checks).
    unsafe { str::from_utf8_unchecked(self.bytes.get_unchecked(start..text_end)) }
}
/// Returns an iterator that yields every string in id order.
#[inline]
pub fn iter(&self) -> StringTableIter<'_, O, NULL_PADDED> {
    let count = self.len();
    let first = self.offsets().as_ptr();
    // SAFETY: `count` never exceeds the number of stored offsets, so the
    // resulting pointer stays inside the offset slice or one past its end.
    let sentinel = unsafe { first.add(count) };
    StringTableIter {
        bytes: &self.bytes,
        cur_offset: first,
        max_offset: sentinel,
        remaining: count,
        _offsets: PhantomData,
    }
}
/// Returns the backing byte buffer exactly as stored.
///
/// No per-entry trimming is applied here, unlike [`get`](Self::get).
#[inline]
pub fn as_bytes(&self) -> &[u8] {
    &self.bytes[..]
}
/// Returns `true` if any string in the table equals `value`.
///
/// Performs a linear scan over [`iter`](Self::iter) and stops at the first match.
#[inline]
pub fn contains(&self, value: &str) -> bool {
    for candidate in self.iter() {
        if candidate == value {
            return true;
        }
    }
    false
}
/// Returns the stored offsets, including the trailing end sentinel.
#[inline]
pub fn offsets(&self) -> &[O] {
    &self.offsets[..]
}
/// Returns the range within [`as_bytes`](Self::as_bytes) that holds the text
/// of `id`, or `None` when `id` is out of range.
///
/// For NUL-padded tables the range stops before the terminator byte.
#[inline]
pub fn byte_range(&self, id: StringId<I>) -> Option<Range<usize>> {
    let slot = id.into_usize();
    if slot >= self.len() {
        return None;
    }
    // SAFETY: `slot < self.len()` and `len()` is one less than the number of
    // offsets, so `slot` and `slot + 1` are both valid indices.
    let start = unsafe { self.offsets.get_unchecked(slot) }.to_usize();
    let end = unsafe { self.offsets.get_unchecked(slot + 1) }.to_usize();
    // In a NUL-padded table the entry's final byte is the terminator, not text.
    let text_end = if NULL_PADDED {
        end.saturating_sub(1)
    } else {
        end
    };
    debug_assert!(text_end >= start);
    Some(start..text_end)
}
/// Checks the invariants the unchecked accessors depend on.
///
/// Only compiled in debug and test builds. The checks run in this order and
/// the first failure is returned:
///
/// 1. the byte length is representable by `O`;
/// 2. the largest string index is representable by `I`;
/// 3. at least one offset (the sentinel) exists;
/// 4. the last offset equals the byte length;
/// 5. walking the offsets front to back: each one is within the buffer and
///    not smaller than its predecessor, and the entry it closes is valid
///    UTF-8 (for NUL-padded tables the entry must be non-empty and end in a
///    zero byte, which is excluded from the UTF-8 check).
///
/// Bytes located before the first offset are not inspected.
#[cfg(any(debug_assertions, test))]
pub(crate) fn validate(&self) -> ValidationResult<()> {
    let total_bytes = self.bytes.len();
    if O::try_from_usize(total_bytes).is_none() {
        return Err(ValidationError::TooManyBytesForOffsetType {
            bytes: total_bytes,
            offset_type: O::TYPE_NAME,
        });
    }

    let strings = self.len();
    // Only the largest id, `strings - 1`, has to fit into `I`.
    if strings > 0 && I::try_from_usize(strings - 1).is_none() {
        return Err(ValidationError::TooManyStrings {
            strings,
            id_type: I::TYPE_NAME,
        });
    }

    let (first, rest) = match self.offsets.split_first() {
        Some(parts) => parts,
        None => return Err(ValidationError::MissingSentinelOffset),
    };

    // With a single offset, that offset is itself the sentinel.
    let sentinel = rest.last().unwrap_or(first).to_usize();
    if sentinel != total_bytes {
        return Err(ValidationError::LastOffsetMismatch {
            found: sentinel,
            expected: total_bytes,
        });
    }

    let mut previous = first.to_usize();
    if previous > total_bytes {
        return Err(ValidationError::OffsetOutOfBounds {
            index: 0,
            offset: previous,
            bytes_len: total_bytes,
        });
    }

    // `entry` is the index of the string closed by this offset; the offset
    // itself sits at position `entry + 1`.
    for (entry, offset) in rest.iter().enumerate() {
        let index = entry + 1;
        let current = offset.to_usize();
        if current > total_bytes {
            return Err(ValidationError::OffsetOutOfBounds {
                index,
                offset: current,
                bytes_len: total_bytes,
            });
        }
        if current < previous {
            return Err(ValidationError::OffsetsNotMonotonic {
                index,
                previous,
                current,
            });
        }

        let text_end = if NULL_PADDED {
            // A padded entry needs at least the terminator byte.
            if current == previous {
                return Err(ValidationError::NullPaddedStringMissingTerminatorByte {
                    index: entry,
                });
            }
            let terminator_index = current - 1;
            if self.bytes[terminator_index] != 0 {
                return Err(ValidationError::NullPaddedStringMissingTrailingNul {
                    index: entry,
                });
            }
            terminator_index
        } else {
            current
        };

        if str::from_utf8(&self.bytes[previous..text_end]).is_err() {
            return Err(ValidationError::InvalidUtf8 { index: entry });
        }
        previous = current;
    }

    Ok(())
}
}
pub struct StringTableIter<'a, O: Offset = u32, const NULL_PADDED: bool = false> {
bytes: &'a [u8],
cur_offset: *const O,
max_offset: *const O,
remaining: usize,
_offsets: PhantomData<&'a [O]>,
}
impl<'a, O: Offset, const NULL_PADDED: bool> Iterator for StringTableIter<'a, O, NULL_PADDED> {
    type Item = &'a str;

    /// Yields the next string, or `None` once the cursor has reached the
    /// sentinel offset.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.cur_offset == self.max_offset {
            return None;
        }

        let here = self.cur_offset;
        // SAFETY: `here` is before the sentinel, so it points at a stored
        // offset and is safe to read.
        let start = unsafe { (*here).to_usize() };
        // SAFETY: `here` is before the sentinel, so the following slot still
        // belongs to the offset slice (it is the sentinel at the latest).
        let following = unsafe { here.add(1) };
        self.cur_offset = following;
        // SAFETY: `following` points at a stored offset, as argued above.
        let end = unsafe { (*following).to_usize() };
        self.remaining -= 1;

        // In a NUL-padded table the entry's final byte is the terminator.
        let text_end = if NULL_PADDED {
            end.saturating_sub(1)
        } else {
            end
        };
        debug_assert!(text_end >= start);
        // SAFETY: relies on the table invariants that offsets are ordered, lie
        // within `bytes`, and delimit valid UTF-8.
        let text = unsafe { str::from_utf8_unchecked(self.bytes.get_unchecked(start..text_end)) };
        Some(text)
    }

    /// Reports the exact number of strings left to yield.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let exact = self.remaining;
        (exact, Some(exact))
    }
}
impl<O: Offset, const NULL_PADDED: bool> ExactSizeIterator for StringTableIter<'_, O, NULL_PADDED> {
    /// Returns the number of strings that have not been yielded yet.
    #[inline]
    fn len(&self) -> usize {
        let Self { remaining, .. } = self;
        *remaining
    }
}
// Once `next` sees `cur_offset == max_offset` it returns `None` without touching
// any state, so every later call returns `None` as well.
impl<O: Offset, const NULL_PADDED: bool> FusedIterator for StringTableIter<'_, O, NULL_PADDED> {}
/// Returns the value `0` converted to the offset type `O`.
///
/// # Panics
///
/// Panics if `O::try_from_usize(0)` returns `None`. A correct `Offset`
/// implementation is expected to always represent zero, so this only fires
/// for a broken implementation. The check is a panic rather than an
/// unchecked unwrap so that such a bug cannot turn into undefined behavior.
#[inline]
fn zero_offset<O: Offset>() -> O {
    O::try_from_usize(0)
        .expect("Offset::try_from_usize(0) returned None; every offset type must represent zero")
}
#[cfg(test)]
mod tests {
    use crate::allocator::{Global, Vec};
    use crate::error::{ValidationError, ValidationResult};
    use crate::{Offset, StringId, StringIndex, StringTable};

    /// Builds a table from raw parts without any checks and runs `validate` on it.
    fn validate_parts<O: Offset, I: StringIndex, const NULL_PADDED: bool>(
        bytes: Vec<u8, Global>,
        offsets: Vec<O, Global>,
    ) -> ValidationResult<()> {
        let table = StringTable::<O, I, NULL_PADDED, Global>::from_parts_unchecked(
            bytes.into_boxed_slice(),
            offsets.into_boxed_slice(),
        );
        table.validate()
    }

    // The single offset present is treated as the sentinel, and it does not
    // match the byte length, so the reported error is `LastOffsetMismatch`.
    #[test]
    fn validate_rejects_missing_sentinel() {
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"hello");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        let result = validate_parts::<u32, u32, false>(bytes, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::LastOffsetMismatch { .. })
        ));
    }

    // An offset buffer with no entries at all lacks even the sentinel.
    #[test]
    fn validate_rejects_empty_offsets() {
        let bytes = Vec::new_in(Global);
        let offsets: Vec<u32, Global> = Vec::new_in(Global);
        let result = validate_parts::<u32, u32, false>(bytes, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::MissingSentinelOffset)
        ));
    }

    #[test]
    fn validate_rejects_non_monotonic_offsets() {
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"abcd");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(3u32);
        offsets.push(2u32);
        offsets.push(4u32);
        let result = validate_parts::<u32, u32, false>(bytes, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::OffsetsNotMonotonic { .. })
        ));
    }

    // The sentinel matches the byte length, but an interior offset points past
    // the end of the buffer.
    #[test]
    fn validate_rejects_out_of_bounds_offset() {
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"ab");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(5u32);
        offsets.push(2u32);
        let result = validate_parts::<u32, u32, false>(bytes, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::OffsetOutOfBounds { .. })
        ));
    }

    #[test]
    fn validate_rejects_invalid_utf8() {
        let mut bytes = Vec::new_in(Global);
        bytes.push(0xFF);
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(1u32);
        let result = validate_parts::<u32, u32, false>(bytes, offsets);
        assert!(matches!(result, Err(ValidationError::InvalidUtf8 { .. })));
    }

    #[test]
    fn validate_null_padded_accepts_trailing_nul() {
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"hello\0");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(6u32);
        let result = validate_parts::<u32, u32, true>(bytes, offsets);
        assert!(result.is_ok());
    }

    #[test]
    fn validate_null_padded_rejects_missing_trailing_nul() {
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"hello");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(5u32);
        let result = validate_parts::<u32, u32, true>(bytes, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::NullPaddedStringMissingTrailingNul { .. })
        ));
    }

    // A NUL-padded entry of length zero has no room for its terminator byte.
    #[test]
    fn validate_null_padded_rejects_empty_entry() {
        let bytes = Vec::new_in(Global);
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(0u32);
        let result = validate_parts::<u32, u32, true>(bytes, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::NullPaddedStringMissingTerminatorByte { .. })
        ));
    }

    #[test]
    fn validate_rejects_offset_type_overflow() {
        // Three bytes fit comfortably into `u8` offsets.
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"abc");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u8);
        offsets.push(3u8);
        let result = validate_parts::<u8, u32, false>(bytes, offsets);
        assert!(result.is_ok());

        // A 300-byte buffer cannot be addressed by `u8` offsets.
        let mut too_big = Vec::new_in(Global);
        too_big.extend_from_slice(&[0u8; 300]);
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u8);
        offsets.push(u8::MAX);
        let result = validate_parts::<u8, u32, false>(too_big, offsets);
        assert!(matches!(
            result,
            Err(ValidationError::TooManyBytesForOffsetType { .. })
        ));
    }

    #[test]
    fn validate_rejects_id_type_overflow() {
        // 258 offsets describe 257 strings, so the largest id is 256, which
        // does not fit into `u8`.
        let bytes = Vec::new_in(Global).into_boxed_slice();
        let mut offsets = Vec::new_in(Global);
        for _ in 0..258 {
            offsets.push(0u32);
        }
        let table = StringTable::<u32, u8>::from_parts_unchecked(bytes, offsets.into_boxed_slice());
        let result = table.validate();
        assert!(matches!(
            result,
            Err(ValidationError::TooManyStrings {
                strings: 257,
                id_type: "u8"
            })
        ));
    }

    #[test]
    fn get_returns_none_for_invalid_id() {
        let table = StringTable::empty();
        assert_eq!(table.get(StringId::new(0)), None);
    }

    // Happy path: a well-formed table validates and exposes each entry through
    // `len`, `iter` and `contains`.
    #[test]
    fn iter_and_contains_see_every_string() {
        let mut bytes = Vec::new_in(Global);
        bytes.extend_from_slice(b"abcde");
        let mut offsets = Vec::new_in(Global);
        offsets.push(0u32);
        offsets.push(2u32);
        offsets.push(5u32);
        let table = StringTable::<u32, u16, false, Global>::from_parts_unchecked(
            bytes.into_boxed_slice(),
            offsets.into_boxed_slice(),
        );
        assert!(table.validate().is_ok());
        assert_eq!(table.len(), 2);
        assert!(!table.is_empty());

        let mut strings = table.iter();
        assert_eq!(strings.len(), 2);
        assert_eq!(strings.next(), Some("ab"));
        assert_eq!(strings.next(), Some("cde"));
        assert_eq!(strings.next(), None);

        assert!(table.contains("cde"));
        assert!(!table.contains("abc"));
    }
}