use super::*;
use core::{cmp::Ordering, fmt::Write, marker::PhantomData, ptr::NonNull};
/// A borrowed, null-terminated string of bytes.
///
/// Only a pointer to the first byte is stored; the length is not, so
/// operations such as `bytes` read through the pointer up to the terminator.
/// The constructors in this file (`from_lit` and `TryFrom<&str>`) only accept
/// input that ends with `'\0'` and has no `'\0'` before that trailing run.
///
/// `#[repr(transparent)]` gives this type the same layout as `NonNull<u8>`.
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct ZStr<'a> {
    /// Non-null pointer to the first byte of the string data.
    pub(crate) nn: NonNull<u8>,
    /// Ties the borrow to `'a` without storing an actual reference.
    pub(crate) life: PhantomData<&'a [u8]>,
}
impl<'a> ZStr<'a> {
    /// Makes a `ZStr<'static>` from a `&'static str` literal.
    ///
    /// The literal must end with at least one `'\0'` and must not contain any
    /// `'\0'` before that trailing run of nulls. A literal made only of nulls
    /// (such as `"\0"`) is accepted and yields an empty string.
    ///
    /// Because this is a `const fn`, using it in a `const` context turns an
    /// invalid literal into a compile-time error.
    ///
    /// # Panics
    ///
    /// * If `s` has no trailing null byte (this includes `""`).
    /// * If `s` contains a null byte before its trailing nulls.
    #[inline]
    #[track_caller]
    pub const fn from_lit(s: &'static str) -> ZStr<'static> {
        let bytes = s.as_bytes();
        // Number of bytes before the trailing run of nulls. Counting down from
        // `len` (instead of starting at `len - 1`) cannot underflow for `""` or
        // for an all-null literal such as `"\0"`.
        let mut content_len = bytes.len();
        while content_len > 0 && bytes[content_len - 1] == 0 {
            content_len -= 1;
        }
        assert!(content_len < bytes.len(), "No trailing nulls.");
        let mut i = 0;
        while i < content_len {
            if bytes[i] == 0 {
                panic!("Input contains interior null.");
            }
            i += 1;
        }
        ZStr {
            // SAFETY: the pointer comes from a `&str`, and references are never null.
            nn: unsafe { NonNull::new_unchecked(s.as_ptr() as *mut u8) },
            life: PhantomData,
        }
    }

    /// Returns an iterator over the bytes of the string, excluding the terminator.
    ///
    /// The data is read directly through the stored pointer, so each call walks
    /// the string again from its first byte.
    #[inline]
    pub fn bytes(self) -> impl Iterator<Item = u8> + 'a {
        // SAFETY: the constructors in this file only build a `ZStr` from data that
        // ends in a null byte and lives for `'a`.
        // NOTE(review): assumes `ConstPtrIter::read_until_default` (defined elsewhere
        // in the crate) stops at the first `0` (`u8::default()`) — confirm.
        unsafe { ConstPtrIter::read_until_default(self.nn.as_ptr()) }
    }

    /// Returns an iterator over the `char`s of the string, decoding the output of
    /// [`bytes`](Self::bytes) with `CharDecoder`.
    #[inline]
    pub fn chars(self) -> impl Iterator<Item = char> + 'a {
        CharDecoder::from(self.bytes())
    }

    /// Returns the raw pointer to the first byte of the string.
    #[inline]
    #[must_use]
    pub const fn as_ptr(self) -> *const u8 {
        self.nn.as_ptr()
    }
}
impl<'a> TryFrom<&'a str> for ZStr<'a> {
    type Error = ZStringError;

    /// Borrows `value` as a `ZStr`.
    ///
    /// `value` must end with one or more `'\0'` and contain no other `'\0'`.
    ///
    /// # Errors
    ///
    /// * [`ZStringError::NoTrailingNulls`] if `value` does not end with `'\0'`
    ///   (this includes the empty string).
    /// * [`ZStringError::InteriorNulls`] if a `'\0'` appears before the trailing nulls.
    #[inline]
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        let content = value.trim_end_matches('\0');
        // Nothing was trimmed, so there is no terminator at all.
        if content.len() == value.len() {
            return Err(ZStringError::NoTrailingNulls);
        }
        if content.contains('\0') {
            return Err(ZStringError::InteriorNulls);
        }
        // A reference is never null, so this conversion cannot fail; the cast
        // keeps the address of the first byte.
        let first_byte = NonNull::from(value.as_bytes()).cast::<u8>();
        Ok(Self {
            nn: first_byte,
            life: PhantomData,
        })
    }
}
impl core::fmt::Display for ZStr<'_> {
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
for ch in self.chars() {
write!(f, "{ch}")?;
}
Ok(())
}
}
impl core::fmt::Debug for ZStr<'_> {
    /// Formats the string in double quotes, escaping characters much like
    /// `str`'s `Debug` does: `"`, `\` and control characters are escaped,
    /// while single quotes are written as-is.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_char('"')?;
        for ch in self.chars() {
            if ch == '\'' {
                // `char::escape_debug` escapes `'`, but `str`'s `Debug` does not.
                f.write_char(ch)?;
            } else {
                for escaped in ch.escape_debug() {
                    f.write_char(escaped)?;
                }
            }
        }
        f.write_char('"')
    }
}
impl core::fmt::Pointer for ZStr<'_> {
    /// Formats the address of the string's first byte exactly as the raw
    /// pointer would be formatted with `{:p}` (including the formatter's flags).
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let first_byte: *mut u8 = self.nn.as_ptr();
        core::fmt::Pointer::fmt(&first_byte, f)
    }
}
impl PartialEq<ZStr<'_>> for ZStr<'_> {
#[inline]
#[must_use]
fn eq(&self, other: &ZStr<'_>) -> bool {
if self.nn == other.nn {
true
} else {
self.bytes().eq(other.bytes())
}
}
}
impl PartialOrd<ZStr<'_>> for ZStr<'_> {
#[inline]
#[must_use]
fn partial_cmp(&self, other: &ZStr<'_>) -> Option<core::cmp::Ordering> {
if self.nn == other.nn {
Some(Ordering::Equal)
} else {
Some(self.bytes().cmp(other.bytes()))
}
}
}
impl PartialEq<&str> for ZStr<'_> {
#[inline]
#[must_use]
fn eq(&self, other: &&str) -> bool {
self.bytes().eq(other.as_bytes().iter().copied())
}
}
impl PartialOrd<&str> for ZStr<'_> {
#[inline]
#[must_use]
fn partial_cmp(&self, other: &&str) -> Option<core::cmp::Ordering> {
Some(self.bytes().cmp(other.as_bytes().iter().copied()))
}
}
#[cfg(feature = "alloc")]
impl PartialEq<ZString> for ZStr<'_> {
    /// Compares against the borrowed form of `other` (via `ZString::as_zstr`),
    /// using `ZStr`'s own equality.
    // `#[must_use]` is omitted: on a trait-impl method it has no effect.
    #[inline]
    fn eq(&self, other: &ZString) -> bool {
        self.eq(&other.as_zstr())
    }
}
#[cfg(feature = "alloc")]
impl PartialOrd<ZString> for ZStr<'_> {
    /// Orders against the borrowed form of `other` (via `ZString::as_zstr`),
    /// using `ZStr`'s own ordering.
    // `#[must_use]` is omitted: on a trait-impl method it has no effect.
    #[inline]
    fn partial_cmp(&self, other: &ZString) -> Option<core::cmp::Ordering> {
        self.partial_cmp(&other.as_zstr())
    }
}
impl core::hash::Hash for ZStr<'_> {
    /// Feeds the string's bytes to `state`, followed by a single `0` byte.
    ///
    /// The trailing `0` keeps the hash input prefix-free. A `ZStr`'s content
    /// never contains `0`, so tuples such as `("ab", "c")` and `("a", "bc")`
    /// no longer feed the hasher identical byte sequences. This is similar in
    /// spirit to the terminator byte `str` writes when hashed. Equal values
    /// (same bytes) still produce equal hash input.
    #[inline]
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        for b in self.bytes() {
            state.write_u8(b);
        }
        state.write_u8(0);
    }
}
/// Reasons a string was rejected as a null-terminated string
/// (returned, for example, by `ZStr::try_from`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ZStringError {
    /// The input did not end with at least one `'\0'`.
    NoTrailingNulls,
    /// The input contained a `'\0'` before its trailing nulls.
    InteriorNulls,
}

impl core::fmt::Display for ZStringError {
    /// Writes a short, lowercase description of the error.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            Self::NoTrailingNulls => "input has no trailing null",
            Self::InteriorNulls => "input contains an interior null",
        })
    }
}