use crate::num::NonZeroChar;
use std::ffi::CStr;
use std::num::NonZeroU8;
use std::ops::Range;
use std::str::Chars;
/// Errors and warnings that can be reported while unescaping or checking a
/// literal.
#[derive(Debug, PartialEq, Eq)]
pub enum EscapeError {
    /// A single unit was expected, but the input was empty.
    ZeroChars,
    /// A single unit was expected, but more input followed it.
    MoreThanOneChar,
    /// A backslash with nothing after it.
    LoneSlash,
    /// A backslash followed by a character that does not start a known escape.
    InvalidEscape,
    /// A bare `\r` in a non-raw literal.
    BareCarriageReturn,
    /// A bare `\r` in a raw literal.
    BareCarriageReturnInRawString,
    /// A character that is only allowed in escaped form appeared unescaped.
    EscapeOnlyChar,
    /// The input ended before both digits of a `\x` escape were read.
    TooShortHexEscape,
    /// A `\x` escape contains a character that is not a hexadecimal digit.
    InvalidCharInHexEscape,
    /// A `\x` escape denotes a non-ASCII value where only ASCII is allowed.
    OutOfRangeHexEscape,
    /// `\u` is not followed by `{`.
    NoBraceInUnicodeEscape,
    /// A `\u{...}` escape contains a character that is not a hexadecimal digit.
    InvalidCharInUnicodeEscape,
    /// `\u{}`: no digits between the braces.
    EmptyUnicodeEscape,
    /// The input ended before the closing `}` of a `\u{...}` escape.
    UnclosedUnicodeEscape,
    /// The first character after `\u{` is an underscore.
    LeadingUnderscoreUnicodeEscape,
    /// A `\u{...}` escape has more than six hexadecimal digits.
    OverlongUnicodeEscape,
    /// A `\u{...}` value is within range but is not a valid `char`.
    LoneSurrogateUnicodeEscape,
    /// A `\u{...}` value is greater than `char::MAX`.
    OutOfRangeUnicodeEscape,
    /// A `\u{...}` escape was used in a byte or byte string literal.
    UnicodeEscapeInByte,
    /// A non-ASCII character was used where a single byte is required.
    NonAsciiCharInByte,
    /// A NUL character or byte in a C string literal.
    NulInCStr,
    /// Warning: after a backslash-newline, the next line starts with
    /// whitespace that is not skipped.
    UnskippedWhitespaceWarning,
    /// Warning: a backslash-newline skipped more than one line.
    MultipleSkippedLinesWarning,
}

impl EscapeError {
    /// Returns `true` for hard errors and `false` for the two warning variants.
    pub fn is_fatal(&self) -> bool {
        match self {
            EscapeError::UnskippedWhitespaceWarning
            | EscapeError::MultipleSkippedLinesWarning => false,
            _ => true,
        }
    }
}
/// Checks the contents of a raw string literal.
///
/// `callback` is invoked once per character of `src` with that character's
/// byte range within `src` and either the character itself or
/// `EscapeError::BareCarriageReturnInRawString` for a bare `\r`.
pub fn check_raw_str(src: &str, callback: impl FnMut(Range<usize>, Result<char, EscapeError>)) {
str::check_raw(src, callback);
}
/// Checks the contents of a raw byte string literal.
///
/// `callback` is invoked once per character of `src` with that character's
/// byte range within `src` and either the character as a byte,
/// `EscapeError::NonAsciiCharInByte` for a non-ASCII character, or
/// `EscapeError::BareCarriageReturnInRawString` for a bare `\r`.
pub fn check_raw_byte_str(src: &str, callback: impl FnMut(Range<usize>, Result<u8, EscapeError>)) {
<[u8]>::check_raw(src, callback);
}
/// Checks the contents of a raw C string literal.
///
/// `callback` is invoked once per character of `src` with that character's
/// byte range within `src` and either the (non-NUL) character,
/// `EscapeError::NulInCStr` for a NUL character, or
/// `EscapeError::BareCarriageReturnInRawString` for a bare `\r`.
pub fn check_raw_c_str(
src: &str,
callback: impl FnMut(Range<usize>, Result<NonZeroChar, EscapeError>),
) {
CStr::check_raw(src, callback);
}
/// Character-by-character validation shared by the raw literal kinds.
///
/// Implementors choose the output unit and which characters are acceptable;
/// `check_raw` does the scanning.
trait CheckRaw {
    /// Unit handed to the callback for each accepted character.
    type RawUnit;

    /// Converts one source character into a unit, or says why it is rejected.
    fn char2raw_unit(c: char) -> Result<Self::RawUnit, EscapeError>;

    /// Visits every character of `src` in order and reports it to `callback`
    /// together with its byte range within `src`.
    ///
    /// A bare `\r` is always reported as
    /// `EscapeError::BareCarriageReturnInRawString`; every other character is
    /// passed through `char2raw_unit`.
    fn check_raw(
        src: &str,
        mut callback: impl FnMut(Range<usize>, Result<Self::RawUnit, EscapeError>),
    ) {
        for (start, c) in src.char_indices() {
            let outcome = if c == '\r' {
                Err(EscapeError::BareCarriageReturnInRawString)
            } else {
                Self::char2raw_unit(c)
            };
            callback(start..start + c.len_utf8(), outcome);
        }
    }
}
// Raw string literals: the unit is the character itself.
impl CheckRaw for str {
type RawUnit = char;
#[inline]
fn char2raw_unit(c: char) -> Result<Self::RawUnit, EscapeError> {
// Every character is accepted here, so this conversion never fails.
Ok(c)
}
}
// Raw byte string literals: the unit is a single byte.
impl CheckRaw for [u8] {
type RawUnit = u8;
#[inline]
fn char2raw_unit(c: char) -> Result<Self::RawUnit, EscapeError> {
// Only ASCII characters fit; anything else is `NonAsciiCharInByte`.
char2byte(c)
}
}
/// Converts an ASCII character to its byte value.
///
/// Returns `EscapeError::NonAsciiCharInByte` for any non-ASCII character.
#[inline]
fn char2byte(c: char) -> Result<u8, EscapeError> {
    match c {
        // The ASCII range is exactly U+0000..=U+007F, so the cast is lossless.
        '\0'..='\x7f' => Ok(c as u8),
        _ => Err(EscapeError::NonAsciiCharInByte),
    }
}
// Raw C string literals: the unit is a non-NUL character.
impl CheckRaw for CStr {
    type RawUnit = NonZeroChar;

    /// Accepts every character except NUL, which yields
    /// `EscapeError::NulInCStr`.
    #[inline]
    fn char2raw_unit(c: char) -> Result<Self::RawUnit, EscapeError> {
        match NonZeroChar::new(c) {
            Some(unit) => Ok(unit),
            None => Err(EscapeError::NulInCStr),
        }
    }
}
/// Unescapes the contents of a character literal into the `char` it denotes.
///
/// `src` must contain exactly one character or one escape sequence;
/// otherwise an `EscapeError` describing the problem is returned.
/// NOTE(review): `src` appears to be the text between the quotes (an
/// unescaped `'` is rejected) — confirm against callers.
#[inline]
pub fn unescape_char(src: &str) -> Result<char, EscapeError> {
str::unescape_single(&mut src.chars())
}
/// Unescapes the contents of a byte literal into the byte it denotes.
///
/// `src` must contain exactly one ASCII character or one escape sequence;
/// otherwise an `EscapeError` describing the problem is returned.
/// NOTE(review): `src` appears to be the text between the quotes (an
/// unescaped `'` is rejected) — confirm against callers.
#[inline]
pub fn unescape_byte(src: &str) -> Result<u8, EscapeError> {
<[u8]>::unescape_single(&mut src.chars())
}
/// Unescapes the contents of a string literal.
///
/// `callback` is invoked for each character or escape sequence of `src` with
/// its byte range within `src` and the resulting `char` or error. A
/// backslash-newline and the whitespace skipped after it produce no unit;
/// only warnings may be reported for them.
pub fn unescape_str(src: &str, callback: impl FnMut(Range<usize>, Result<char, EscapeError>)) {
str::unescape(src, callback)
}
/// Unescapes the contents of a byte string literal.
///
/// `callback` is invoked for each character or escape sequence of `src` with
/// its byte range within `src` and the resulting byte or error. A
/// backslash-newline and the whitespace skipped after it produce no unit;
/// only warnings may be reported for them.
pub fn unescape_byte_str(src: &str, callback: impl FnMut(Range<usize>, Result<u8, EscapeError>)) {
<[u8]>::unescape(src, callback)
}
/// Unescapes the contents of a C string literal.
///
/// `callback` is invoked for each character or escape sequence of `src` with
/// its byte range within `src` and the resulting `MixedUnit` or error. A
/// backslash-newline and the whitespace skipped after it produce no unit;
/// only warnings may be reported for them.
pub fn unescape_c_str(
src: &str,
callback: impl FnMut(Range<usize>, Result<MixedUnit, EscapeError>),
) {
CStr::unescape(src, callback)
}
/// Unit produced when unescaping a C string literal: either a non-NUL
/// character or a single non-NUL byte.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MixedUnit {
/// A non-NUL character. ASCII bytes converted from `NonZeroU8` also end up
/// in this variant.
Char(NonZeroChar),
/// A non-NUL byte. The `From<NonZeroU8>` conversion only produces this
/// variant for non-ASCII bytes.
HighByte(NonZeroU8),
}
impl From<NonZeroChar> for MixedUnit {
/// Wraps a non-NUL character in `MixedUnit::Char`.
#[inline]
fn from(c: NonZeroChar) -> Self {
MixedUnit::Char(c)
}
}
impl From<NonZeroU8> for MixedUnit {
    /// Converts a non-NUL byte: ASCII bytes become `MixedUnit::Char`, all
    /// other bytes become `MixedUnit::HighByte`.
    #[inline]
    fn from(byte: NonZeroU8) -> Self {
        match byte.get() {
            // `byte` is non-zero, so the resulting character is never NUL.
            ascii @ 0x00..=0x7f => MixedUnit::Char(NonZeroChar::new(char::from(ascii)).unwrap()),
            _ => MixedUnit::HighByte(byte),
        }
    }
}
impl TryFrom<char> for MixedUnit {
    type Error = EscapeError;

    /// Converts a character into `MixedUnit::Char`, rejecting NUL with
    /// `EscapeError::NulInCStr`.
    #[inline]
    fn try_from(c: char) -> Result<Self, EscapeError> {
        match NonZeroChar::new(c) {
            Some(nonzero) => Ok(MixedUnit::Char(nonzero)),
            None => Err(EscapeError::NulInCStr),
        }
    }
}
impl TryFrom<u8> for MixedUnit {
type Error = EscapeError;
#[inline]
fn try_from(byte: u8) -> Result<Self, EscapeError> {
NonZeroU8::new(byte)
.map(From::from)
.ok_or(EscapeError::NulInCStr)
}
}
/// Escape processing shared by the non-raw literal kinds.
///
/// Implementors choose the output unit and how each kind of escape maps onto
/// it; the provided methods do the scanning.
trait Unescape {
    /// Unit produced for each character or escape sequence.
    type Unit;

    /// Outcome of the `\0` escape.
    const ZERO_RESULT: Result<Self::Unit, EscapeError>;

    /// Converts the byte denoted by a simple escape (such as `\n`) into a unit.
    fn nonzero_byte2unit(b: NonZeroU8) -> Self::Unit;

    /// Converts an unescaped source character into a unit.
    fn char2unit(c: char) -> Result<Self::Unit, EscapeError>;

    /// Converts the value of a `\x` escape into a unit.
    fn hex2unit(b: u8) -> Result<Self::Unit, EscapeError>;

    /// Converts the outcome of a syntactically valid `\u{...}` escape into a
    /// unit.
    fn unicode2unit(r: Result<char, EscapeError>) -> Result<Self::Unit, EscapeError>;

    /// Reads exactly one unit (a character or an escape sequence) from `chars`.
    ///
    /// Fails with `ZeroChars` on empty input and with `MoreThanOneChar` when
    /// anything follows the first unit. An unescaped newline, tab or single
    /// quote is `EscapeOnlyChar`; a bare `\r` is `BareCarriageReturn`.
    fn unescape_single(chars: &mut Chars<'_>) -> Result<Self::Unit, EscapeError> {
        let first = chars.next().ok_or(EscapeError::ZeroChars)?;
        let unit = match first {
            '\\' => Self::unescape_1(chars)?,
            '\n' | '\t' | '\'' => return Err(EscapeError::EscapeOnlyChar),
            '\r' => return Err(EscapeError::BareCarriageReturn),
            other => Self::char2unit(other)?,
        };
        match chars.next() {
            Some(_) => Err(EscapeError::MoreThanOneChar),
            None => Ok(unit),
        }
    }

    /// Decodes one escape sequence; the leading backslash has already been
    /// consumed by the caller.
    fn unescape_1(chars: &mut Chars<'_>) -> Result<Self::Unit, EscapeError> {
        match chars.next().ok_or(EscapeError::LoneSlash)? {
            '0' => Self::ZERO_RESULT,
            'x' => Self::hex2unit(hex_escape(chars)?),
            'u' => {
                // Syntax errors bypass `unicode2unit`; only the range and
                // surrogate checks are routed through it.
                let value = unicode_escape(chars)?;
                let decoded = match char::from_u32(value) {
                    Some(ch) => Ok(ch),
                    None if value > char::MAX as u32 => {
                        Err(EscapeError::OutOfRangeUnicodeEscape)
                    }
                    None => Err(EscapeError::LoneSurrogateUnicodeEscape),
                };
                Self::unicode2unit(decoded)
            }
            other => simple_escape(other)
                .map(Self::nonzero_byte2unit)
                .map_err(|_| EscapeError::InvalidEscape),
        }
    }

    /// Unescapes all of `src`, reporting each unit to `callback` together with
    /// the byte range within `src` that produced it.
    ///
    /// A backslash immediately followed by a newline yields no unit: the
    /// newline and following ASCII whitespace are skipped, and only warnings
    /// may be reported for that span. An unescaped `"` is `EscapeOnlyChar`; a
    /// bare `\r` is `BareCarriageReturn`.
    fn unescape(
        src: &str,
        mut callback: impl FnMut(Range<usize>, Result<Self::Unit, EscapeError>),
    ) {
        let total = src.len();
        let mut chars = src.chars();
        while let Some(c) = chars.next() {
            // Byte offset of `c`: everything consumed so far minus `c` itself.
            let start = total - chars.as_str().len() - c.len_utf8();
            let outcome = match c {
                '\\' if chars.as_str().starts_with('\n') => {
                    // Consume the newline, then the whitespace after it.
                    let _ = chars.next();
                    skip_ascii_whitespace(&mut chars, start, |range, err| {
                        callback(range, Err(err))
                    });
                    continue;
                }
                '\\' => Self::unescape_1(&mut chars),
                '"' => Err(EscapeError::EscapeOnlyChar),
                '\r' => Err(EscapeError::BareCarriageReturn),
                other => Self::char2unit(other),
            };
            let end = total - chars.as_str().len();
            callback(start..end, outcome);
        }
    }
}
/// Maps the character following a backslash to the byte it denotes for the
/// simple escapes `\"`, `\n`, `\r`, `\t`, `\\` and `\'`.
///
/// Any other character is handed back unchanged as the `Err` value so the
/// caller can try the remaining escape kinds.
#[inline]
fn simple_escape(c: char) -> Result<NonZeroU8, char> {
    let byte = match c {
        '"' => b'"',
        'n' => b'\n',
        'r' => b'\r',
        't' => b'\t',
        '\\' => b'\\',
        '\'' => b'\'',
        other => return Err(other),
    };
    // None of the bytes above is zero, so this cannot fail.
    Ok(NonZeroU8::new(byte).unwrap())
}
/// Reads the two hexadecimal digits of a `\x` escape (the `\x` itself has
/// already been consumed) and returns the byte they denote.
///
/// Fails with `TooShortHexEscape` when the input ends early and with
/// `InvalidCharInHexEscape` on a non-hexadecimal character. The second digit
/// is not read if the first one is invalid.
#[inline]
fn hex_escape(chars: &mut impl Iterator<Item = char>) -> Result<u8, EscapeError> {
    let mut nibble = || -> Result<u32, EscapeError> {
        let c = chars.next().ok_or(EscapeError::TooShortHexEscape)?;
        c.to_digit(16).ok_or(EscapeError::InvalidCharInHexEscape)
    };
    let high = nibble()?;
    let low = nibble()?;
    // Both nibbles are below 16, so the combined value always fits in a byte.
    Ok((high << 4 | low) as u8)
}
/// Parses the `{...}` part of a `\u{...}` escape (the `\u` itself has already
/// been consumed) and returns the numeric value of its hexadecimal digits.
///
/// Underscores after the first digit are ignored. The returned value is not
/// checked for being a valid `char`; that is left to the caller.
#[inline]
fn unicode_escape(chars: &mut impl Iterator<Item = char>) -> Result<u32, EscapeError> {
    match chars.next() {
        Some('{') => {}
        _ => return Err(EscapeError::NoBraceInUnicodeEscape),
    }

    // The first character inside the braces must be a hexadecimal digit.
    let first = chars.next().ok_or(EscapeError::UnclosedUnicodeEscape)?;
    let mut value: u32 = match first {
        '_' => return Err(EscapeError::LeadingUnderscoreUnicodeEscape),
        '}' => return Err(EscapeError::EmptyUnicodeEscape),
        digit => digit
            .to_digit(16)
            .ok_or(EscapeError::InvalidCharInUnicodeEscape)?,
    };

    let mut n_digits = 1;
    for c in chars {
        match c {
            '_' => {}
            '}' => {
                // The digit count is only judged once the escape is known to
                // be closed, so syntax errors take precedence over it.
                return if n_digits > 6 {
                    Err(EscapeError::OverlongUnicodeEscape)
                } else {
                    Ok(value)
                };
            }
            other => {
                let digit: u32 = other
                    .to_digit(16)
                    .ok_or(EscapeError::InvalidCharInUnicodeEscape)?;
                n_digits += 1;
                // Beyond six digits the escape is already overlong; stop
                // accumulating and keep scanning for the closing brace.
                if n_digits <= 6 {
                    value = value * 16 + digit;
                }
            }
        }
    }

    // Ran out of input before seeing `}`.
    Err(EscapeError::UnclosedUnicodeEscape)
}
/// Advances `chars` past the run of spaces, tabs, `\n` and `\r` that follows a
/// backslash-newline.
///
/// `start` is the byte offset of the backslash. `callback` receives
/// `MultipleSkippedLinesWarning` when the skipped run contains another
/// newline, and `UnskippedWhitespaceWarning` when the first character left
/// over is whitespace that was not skipped.
#[inline]
fn skip_ascii_whitespace(
    chars: &mut Chars<'_>,
    start: usize,
    mut callback: impl FnMut(Range<usize>, EscapeError),
) {
    let tail = chars.as_str();
    let skipped_len = tail
        .bytes()
        .take_while(|&b| matches!(b, b' ' | b'\t' | b'\n' | b'\r'))
        .count();
    let (skipped, remainder) = tail.split_at(skipped_len);

    // The backslash and the newline account for the extra two bytes.
    let end = start + 2 + skipped_len;
    if skipped.bytes().any(|b| b == b'\n') {
        callback(start..end, EscapeError::MultipleSkippedLinesWarning);
    }

    *chars = remainder.chars();
    match remainder.chars().next() {
        // Extend the span over the offending character so it is included in
        // the reported range.
        Some(c) if c.is_whitespace() => callback(
            start..end + c.len_utf8(),
            EscapeError::UnskippedWhitespaceWarning,
        ),
        _ => {}
    }
}
// Character and string literals: the unit is a `char`.
impl Unescape for str {
    type Unit = char;

    /// `\0` is the NUL character.
    const ZERO_RESULT: Result<Self::Unit, EscapeError> = Ok('\0');

    /// A simple escape denotes the character with the same code as its byte.
    #[inline]
    fn nonzero_byte2unit(b: NonZeroU8) -> Self::Unit {
        char::from(b.get())
    }

    /// Every source character is accepted unchanged.
    #[inline]
    fn char2unit(c: char) -> Result<Self::Unit, EscapeError> {
        Ok(c)
    }

    /// A `\x` escape must denote an ASCII value here; anything else is
    /// `EscapeError::OutOfRangeHexEscape`.
    #[inline]
    fn hex2unit(b: u8) -> Result<Self::Unit, EscapeError> {
        match b {
            0x00..=0x7f => Ok(char::from(b)),
            _ => Err(EscapeError::OutOfRangeHexEscape),
        }
    }

    /// A `\u{...}` escape is passed through unchanged.
    #[inline]
    fn unicode2unit(r: Result<char, EscapeError>) -> Result<Self::Unit, EscapeError> {
        r
    }
}
// Byte and byte string literals: the unit is a single byte.
impl Unescape for [u8] {
type Unit = u8;
// `\0` is the zero byte.
const ZERO_RESULT: Result<Self::Unit, EscapeError> = Ok(b'\0');
#[inline]
fn nonzero_byte2unit(b: NonZeroU8) -> Self::Unit {
b.get()
}
#[inline]
fn char2unit(c: char) -> Result<Self::Unit, EscapeError> {
// Only ASCII source characters fit in a byte.
char2byte(c)
}
#[inline]
fn hex2unit(b: u8) -> Result<Self::Unit, EscapeError> {
// Any `\x` value is accepted, including values above 0x7F.
Ok(b)
}
#[inline]
fn unicode2unit(_r: Result<char, EscapeError>) -> Result<Self::Unit, EscapeError> {
// `\u{...}` is never allowed here, whatever value it denotes.
Err(EscapeError::UnicodeEscapeInByte)
}
}
impl Unescape for CStr {
type Unit = MixedUnit;
const ZERO_RESULT: Result<Self::Unit, EscapeError> = Err(EscapeError::NulInCStr);
#[inline]
fn nonzero_byte2unit(b: NonZeroU8) -> Self::Unit {
b.into()
}
#[inline]
fn char2unit(c: char) -> Result<Self::Unit, EscapeError> {
c.try_into()
}
#[inline]
fn hex2unit(byte: u8) -> Result<Self::Unit, EscapeError> {
byte.try_into()
}
#[inline]
fn unicode2unit(r: Result<char, EscapeError>) -> Result<Self::Unit, EscapeError> {
Self::char2unit(r?)
}
}
/// The kind of literal being processed.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so that a `Mode` can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Character literal.
    Char,
    /// Byte literal.
    Byte,
    /// String literal.
    Str,
    /// Raw string literal.
    RawStr,
    /// Byte string literal.
    ByteStr,
    /// Raw byte string literal.
    RawByteStr,
    /// C string literal.
    CStr,
    /// Raw C string literal.
    RawCStr,
}

impl Mode {
    /// Returns `true` for every string-like mode and `false` for `Char` and
    /// `Byte`.
    pub fn in_double_quotes(self) -> bool {
        // Kept exhaustive so that adding a variant forces a decision here.
        match self {
            Mode::Char | Mode::Byte => false,
            Mode::Str
            | Mode::RawStr
            | Mode::ByteStr
            | Mode::RawByteStr
            | Mode::CStr
            | Mode::RawCStr => true,
        }
    }

    /// Returns the literal prefix with rawness ignored: `""` for character
    /// and string modes, `"b"` for byte modes, `"c"` for C string modes.
    pub fn prefix_noraw(self) -> &'static str {
        match self {
            Mode::Char | Mode::Str | Mode::RawStr => "",
            Mode::Byte | Mode::ByteStr | Mode::RawByteStr => "b",
            Mode::CStr | Mode::RawCStr => "c",
        }
    }
}
pub fn check_for_errors(
src: &str,
mode: Mode,
mut error_callback: impl FnMut(Range<usize>, EscapeError),
) {
match mode {
Mode::Char => {
let mut chars = src.chars();
if let Err(e) = str::unescape_single(&mut chars) {
error_callback(0..(src.len() - chars.as_str().len()), e);
}
}
Mode::Byte => {
let mut chars = src.chars();
if let Err(e) = <[u8]>::unescape_single(&mut chars) {
error_callback(0..(src.len() - chars.as_str().len()), e);
}
}
Mode::Str => unescape_str(src, |range, res| {
if let Err(e) = res {
error_callback(range, e);
}
}),
Mode::ByteStr => unescape_byte_str(src, |range, res| {
if let Err(e) = res {
error_callback(range, e);
}
}),
Mode::CStr => unescape_c_str(src, |range, res| {
if let Err(e) = res {
error_callback(range, e);
}
}),
Mode::RawStr => check_raw_str(src, |range, res| {
if let Err(e) = res {
error_callback(range, e);
}
}),
Mode::RawByteStr => check_raw_byte_str(src, |range, res| {
if let Err(e) = res {
error_callback(range, e);
}
}),
Mode::RawCStr => check_raw_c_str(src, |range, res| {
if let Err(e) = res {
error_callback(range, e);
}
}),
}
}