use alloc::vec::Vec;
use core::borrow::{Borrow, BorrowMut};
use core::error::Error;
use core::fmt;
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use crate::cstr::CStr;
use crate::encoding::{AlwaysValid, Encoding, NullTerminable, ValidateError};
use crate::str::Str;
use crate::string::{OwnValidateError, String};
/// Reason a [`CString`] could not be built from a byte buffer.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum CStringErrorCause {
    /// The encoding's validation rejected the bytes.
    Invalid(ValidateError),
    /// The bytes contain a null byte.
    HasNull {
        /// Index of the first null byte in the input.
        idx: usize,
    },
}

impl CStringErrorCause {
    /// Writes a short, human-readable description of this cause into `f`.
    fn write_cause(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            CStringErrorCause::Invalid(_) => "validation failed",
            CStringErrorCause::HasNull { .. } => "null byte encountered",
        };
        f.write_str(description)
    }
}
/// Error returned by [`CString::new`] when the input bytes are rejected.
///
/// It keeps the rejected bytes, recoverable with [`CStringError::into_vec`],
/// together with the [`CStringErrorCause`] explaining the rejection.
#[derive(Debug, PartialEq)]
pub struct CStringError {
    bytes: Vec<u8>,
    cause: CStringErrorCause,
}

impl fmt::Display for CStringError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Error while creating `CString`: ")?;
        self.cause.write_cause(f)
    }
}

impl Error for CStringError {
    /// Exposes the encoding's validation error when that was the cause.
    /// A stray null byte has no underlying error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self.cause {
            CStringErrorCause::Invalid(ref inner) => Some(inner),
            CStringErrorCause::HasNull { .. } => None,
        }
    }
}

impl CStringError {
    /// Returns why the bytes were rejected.
    pub fn cause(&self) -> &CStringErrorCause {
        let Self { cause, .. } = self;
        cause
    }

    /// Consumes the error and hands back the rejected bytes.
    pub fn into_vec(self) -> Vec<u8> {
        let Self { bytes, .. } = self;
        bytes
    }
}
/// Error returned when converting bytes that already contain a null byte
/// into a [`CString`].
///
/// The rejected bytes are kept and can be recovered with
/// [`NulError::into_vec`].
#[derive(Debug, PartialEq)]
pub struct NulError {
    bytes: Vec<u8>,
    nul_pos: usize,
}

impl NulError {
    /// Index of the null byte that caused the rejection.
    pub fn nul_position(&self) -> usize {
        let Self { nul_pos, .. } = self;
        *nul_pos
    }

    /// Consumes the error and hands back the rejected bytes.
    pub fn into_vec(self) -> Vec<u8> {
        let Self { bytes, .. } = self;
        bytes
    }
}

impl fmt::Display for NulError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Error while creating `CString` from `String`: null byte encountered")
    }
}

// No underlying cause exists, so the default `Error::source` (returning
// `None`) is kept.
impl Error for NulError {}
/// An owned, null-terminated byte string in the encoding `E`.
///
/// Every constructor in this module builds the `Vec<u8>` field through
/// `from_vec_unchecked`, which appends the null terminator. `into_bytes`
/// strips it again. The checked constructors reject input that already
/// contains a null byte. The `PhantomData<E>` field stores no data; it only
/// records the encoding.
pub struct CString<E>(PhantomData<E>, Vec<u8>);
impl<E: Encoding + NullTerminable> CString<E> {
    /// Wraps `bytes` as a `CString` without checking its contents and
    /// appends the trailing null terminator.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `bytes` contains no null byte and is valid
    /// in the encoding `E`. These are the conditions [`CString::new`] checks
    /// before delegating here; this function checks neither.
    pub unsafe fn from_vec_unchecked(mut bytes: Vec<u8>) -> CString<E> {
        // Grow by exactly one byte. A plain `push` on a full vector would
        // double its capacity just to store the terminator.
        bytes.reserve_exact(1);
        bytes.push(0);
        CString(PhantomData, bytes)
    }

    /// Creates a `CString` from `bytes` after checking that they contain no
    /// null byte and are valid for the encoding `E`.
    ///
    /// # Errors
    ///
    /// Returns a [`CStringError`] that owns the original bytes in two cases:
    /// - [`CStringErrorCause::HasNull`]: `bytes` contains a null byte. The
    ///   error records the index of the first one.
    /// - [`CStringErrorCause::Invalid`]: `E::validate` rejects `bytes`.
    ///
    /// The null check runs first, so input failing both reports `HasNull`.
    pub fn new<T>(bytes: T) -> Result<CString<E>, CStringError>
    where
        T: Into<Vec<u8>>,
    {
        let bytes: Vec<u8> = bytes.into();
        if let Some(idx) = bytes.iter().position(|&b| b == 0) {
            return Err(CStringError {
                bytes,
                cause: CStringErrorCause::HasNull { idx },
            });
        }
        if let Err(e) = E::validate(&bytes) {
            return Err(CStringError {
                bytes,
                cause: CStringErrorCause::Invalid(e),
            });
        }
        // SAFETY: the checks above established that `bytes` holds no null
        // byte and that `E::validate` accepts it.
        Ok(unsafe { Self::from_vec_unchecked(bytes) })
    }

    /// Converts into an owned [`String<E>`] through the `Into<String<E>>`
    /// conversion.
    pub fn into_string(self) -> String<E> {
        self.into()
    }

    /// Consumes the `CString` and returns its bytes without the trailing
    /// null terminator.
    pub fn into_bytes(mut self) -> Vec<u8> {
        // The buffer always ends with the terminator, so `pop` removes
        // exactly that byte.
        self.1.pop();
        self.1
    }

    /// Consumes the `CString` and returns its bytes, including the trailing
    /// null terminator.
    pub fn into_bytes_with_nul(self) -> Vec<u8> {
        self.1
    }

    /// Converts a standard-library [`alloc::ffi::CString`] after validating
    /// its bytes for the encoding `E`.
    ///
    /// # Errors
    ///
    /// Returns an [`OwnValidateError`] if `E::validate` rejects the bytes.
    /// The error is built from the validation error and the bytes, without
    /// their terminator.
    pub fn from_std(value: alloc::ffi::CString) -> Result<Self, OwnValidateError> {
        let bytes = value.into_bytes();
        match E::validate(&bytes) {
            // SAFETY: std's `CString` guarantees there is no interior null
            // byte, and `E::validate` has just accepted the bytes.
            Ok(_) => Ok(unsafe { CString::from_vec_unchecked(bytes) }),
            Err(err) => Err(OwnValidateError::new(err, bytes)),
        }
    }

    /// Converts into a standard-library [`alloc::ffi::CString`].
    pub fn into_std(self) -> alloc::ffi::CString {
        // SAFETY: `self.1` ends with the terminator appended by
        // `from_vec_unchecked`, and every constructor rejects (or requires
        // the absence of) interior null bytes. It therefore holds exactly one
        // null, at the end, which is the contract of
        // `from_vec_with_nul_unchecked`. Handing the buffer over unchanged
        // avoids popping the terminator only for std to push it back.
        unsafe { alloc::ffi::CString::from_vec_with_nul_unchecked(self.1) }
    }
}
impl<E: NullTerminable + AlwaysValid> CString<E> {
    /// Creates a `CString` for an `AlwaysValid` encoding.
    ///
    /// The bytes go through `String::from_bytes_infallible`, so only the
    /// null-byte check can fail.
    ///
    /// # Errors
    ///
    /// Returns a [`NulError`] if `bytes` contains a null byte.
    pub fn new_valid<T>(bytes: T) -> Result<CString<E>, NulError>
    where
        T: Into<Vec<u8>>,
    {
        let raw: Vec<u8> = bytes.into();
        let string: String<E> = String::from_bytes_infallible(raw);
        Self::try_from(string)
    }
}
impl<E: NullTerminable> fmt::Debug for CString<E> {
    /// Formats exactly like the borrowed [`CStr<E>`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let borrowed: &CStr<E> = self;
        fmt::Debug::fmt(borrowed, f)
    }
}
impl<E: NullTerminable> Default for CString<E> {
fn default() -> Self {
unsafe { CString::from_vec_unchecked(Vec::new()) }
}
}
impl<E: NullTerminable> PartialEq for CString<E> {
    /// Compares the raw buffers, terminator included.
    fn eq(&self, other: &Self) -> bool {
        let Self(_, lhs) = self;
        let Self(_, rhs) = other;
        lhs == rhs
    }
}
// `PartialEq` compares the raw byte buffers, which is reflexive, so the
// `Eq` marker is sound.
impl<E: NullTerminable> Eq for CString<E> {}
impl<E: NullTerminable> Hash for CString<E> {
    /// Hashes the bytes exposed by `as_bytes` on the dereferenced string.
    /// Equal buffers produce equal `as_bytes`, so this agrees with
    /// `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let content = self.as_bytes();
        content.hash(state);
    }
}
impl<E: NullTerminable> Deref for CString<E> {
    type Target = CStr<E>;

    /// Borrows the buffer, terminator included, as a [`CStr<E>`].
    fn deref(&self) -> &Self::Target {
        let Self(_, buffer) = self;
        // SAFETY: `buffer` ends with the terminator appended by
        // `from_vec_unchecked`. The bytes before it were checked, or promised
        // by the caller of the unchecked constructor, to be null-free and
        // valid for `E`. This is presumably what
        // `from_bytes_with_nul_unchecked` requires; confirm against `CStr`.
        unsafe { CStr::from_bytes_with_nul_unchecked(buffer) }
    }
}
impl<E: NullTerminable> DerefMut for CString<E> {
    /// Mutably borrows the buffer, terminator included, as a [`CStr<E>`].
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(_, buffer) = self;
        // SAFETY: same invariant as `Deref`. The buffer ends with the
        // terminator appended by `from_vec_unchecked` and was built from
        // null-free bytes valid for `E`.
        unsafe { CStr::from_bytes_with_nul_unchecked_mut(buffer) }
    }
}
impl<E: NullTerminable> AsRef<CStr<E>> for CString<E> {
    /// Same as dereferencing: borrows the contents as a [`CStr<E>`].
    fn as_ref(&self) -> &CStr<E> {
        &**self
    }
}
impl<E: NullTerminable> AsMut<CStr<E>> for CString<E> {
    /// Same as mutably dereferencing: borrows the contents as a
    /// `&mut CStr<E>`.
    fn as_mut(&mut self) -> &mut CStr<E> {
        &mut **self
    }
}
impl<E: NullTerminable> AsRef<Str<E>> for CString<E> {
fn as_ref(&self) -> &Str<E> {
self
}
}
impl<E: NullTerminable> Borrow<CStr<E>> for CString<E> {
    /// Borrows the contents as a [`CStr<E>`] via `Deref`.
    fn borrow(&self) -> &CStr<E> {
        &**self
    }
}
impl<E: NullTerminable> BorrowMut<CStr<E>> for CString<E> {
    /// Mutably borrows the contents as a [`CStr<E>`] via `DerefMut`.
    fn borrow_mut(&mut self) -> &mut CStr<E> {
        &mut **self
    }
}
impl<E: NullTerminable> TryFrom<String<E>> for CString<E> {
    type Error = NulError;

    /// Appends a null terminator to the bytes of `value`.
    ///
    /// Fails with a [`NulError`] if `value` already contains a null byte. The
    /// error records the index of the first one and returns the bytes.
    fn try_from(value: String<E>) -> Result<Self, Self::Error> {
        let bytes = value.into_bytes();
        let first_nul = bytes.iter().position(|&byte| byte == 0);
        match first_nul {
            Some(nul_pos) => Err(NulError { bytes, nul_pos }),
            // SAFETY: no null byte is present. A `String<E>` is assumed to
            // already hold bytes valid for `E`.
            None => Ok(unsafe { CString::from_vec_unchecked(bytes) }),
        }
    }
}
impl<E: NullTerminable> TryFrom<alloc::ffi::CString> for CString<E> {
type Error = OwnValidateError;
fn try_from(value: alloc::ffi::CString) -> Result<Self, Self::Error> {
Self::from_std(value)
}
}
impl<E: NullTerminable> From<CString<E>> for alloc::ffi::CString {
fn from(value: CString<E>) -> Self {
value.into_std()
}
}