use bstr::B;
use either::Either;
use nom::{
branch::alt,
bytes::streaming::{escaped_transform, is_not, tag},
character::streaming::{char, digit1, one_of},
combinator::{map, map_res, opt, recognize},
error::context,
number::streaming::recognize_float,
sequence::{preceded, terminated, tuple},
};
use ordered_float::NotNan;
/// Result of a [`FromSql`] parser: nom's `IResult` over byte-slice input, using this
/// crate's [`crate::error::Error`] as the error type.
pub type IResult<'a, T> = nom::IResult<&'a [u8], T, crate::error::Error<'a>>;
/// Types that can be parsed from the textual SQL representation of a value.
///
/// The implementations in this module are built from nom's `streaming` combinators, so
/// input that ends before a value is complete generally yields `nom::Err::Incomplete`
/// rather than a hard error.
pub trait FromSql<'a>: Sized {
    /// Parses one value from the start of `s`, returning the unconsumed remainder
    /// together with the parsed value.
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self>;
}
/// Parses a boolean written as a single digit: `1` is `true`, `0` is `false`.
impl<'a> FromSql<'a> for bool {
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        // `one_of` only lets `0` or `1` through, so testing for `1` is sufficient.
        let digit = map(one_of("01"), |c: char| matches!(c, '1'));
        context("1 or 0", digit)(s)
    }
}
// Generates `impl FromSql for $type_name`, wrapping the parser in a nom `context` labelled
// "number (<type name>)". Any leading `#[doc = ...]` attributes are copied onto the impl.
macro_rules! number_impl {
    // Form 1: `Type { recognizer }`.
    // `recognizer` must yield the matched bytes (`&[u8]`). They are decoded as UTF-8 and
    // then converted with `str::parse`, both inside `map_res`, so a failure becomes a nom
    // parse error instead of a panic. The UTF-8 error is wrapped as `Either::Right` and the
    // `parse` error (e.g. an out-of-range integer) as `Either::Left` before being handed to
    // `map_res`.
    (
        $( #[doc = $com:expr] )*
        $type_name:ty
        $implementation:block
    ) => {
        $( #[doc = $com] )*
        impl<'a> FromSql<'a> for $type_name {
            fn from_sql(s: &'a [u8]) -> IResult<'a, $type_name> {
                context(
                    concat!("number (", stringify!($type_name), ")"),
                    map_res($implementation, |num: &[u8]| {
                        let s = std::str::from_utf8(num).map_err(Either::Right)?;
                        s.parse().map_err(Either::Left)
                    }),
                )(s)
            }
        }
    };
    // Form 2: `Type { parser } { post }`.
    // `parser` yields an intermediate value, which is passed through `post` with `map`.
    // `post` is infallible, so no further validation happens here.
    (
        $( #[doc = $com:expr] )*
        $type_name:ty
        $implementation:block
        $further_processing:block
    ) => {
        $( #[doc = $com] )*
        impl<'a> FromSql<'a> for $type_name {
            fn from_sql(s: &'a [u8]) -> IResult<'a, $type_name> {
                context(
                    concat!("number (", stringify!($type_name), ")"),
                    map($implementation, $further_processing),
                )(s)
            }
        }
    };
}
macro_rules! unsigned_int {
($t:ident) => {
number_impl! { $t { recognize(digit1) } }
};
}
unsigned_int!(u8);
unsigned_int!(u16);
unsigned_int!(u32);
unsigned_int!(u64);
macro_rules! signed_int {
($t:ident) => {
number_impl! { $t { recognize(tuple((opt(char('-')), digit1))) } }
};
}
signed_int!(i8);
signed_int!(i16);
signed_int!(i32);
signed_int!(i64);
// Implements `FromSql` for a float type `$t` and for `NotNan<$t>`:
// * `$t` accepts whatever nom's `recognize_float` matches and converts it with `str::parse`;
// * `NotNan<$t>` reuses the `$t` parser and wraps its result without a runtime NaN check.
macro_rules! float {
    ($t:ident) => {
        number_impl! {
            #[doc = concat!("Matches a float literal with [`recognize_float`] and parses it as a [`", stringify!($t), "`].")]
            $t { recognize_float }
        }
        number_impl! {
            #[doc = concat!("Parses an [`", stringify!($t), "`] and wraps it with [`NotNan::new_unchecked`].")]
            NotNan<$t> {
                <$t>::from_sql
            } {
                // SAFETY: `float` comes from `<$t>::from_sql`, i.e. from parsing text matched
                // by nom's `recognize_float`. That recognizer accepts only an optional sign,
                // digits, an optional fractional part and an optional exponent. It never
                // matches `nan`/`inf` spellings; only `recognize_float_or_exceptions` does.
                // Parsing such a decimal literal cannot produce NaN (overflow gives +/-inf),
                // so `new_unchecked`'s "not NaN" precondition holds. If the `$t` parser ever
                // changes to accept other spellings, switch to the checked `NotNan::new`.
                |float| unsafe { NotNan::new_unchecked(float) }
            }
        }
    };
}
float!(f32);
float!(f64);
/// Parses a single-quoted byte string, borrowing its contents from the input.
///
/// Escape sequences are *not* interpreted: everything between the opening quote and
/// the next `'` is returned verbatim. An empty literal (`''`) yields an empty slice.
impl<'a> FromSql<'a> for &'a [u8] {
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        // `is_not` needs at least one byte to succeed, so `opt` + default covers `''`.
        let contents = map(opt(is_not(B("'"))), Option::unwrap_or_default);
        let quoted = terminated(preceded(tag("'"), contents), tag("'"));
        context("byte string with no escape sequences", quoted)(s)
    }
}
/// Parses a single-quoted string without interpreting escape sequences, borrowing it
/// from the input. Fails if the quoted bytes are not valid UTF-8.
impl<'a> FromSql<'a> for &'a str {
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        let raw = <&'a [u8] as FromSql<'a>>::from_sql;
        context("string with no escape sequences", map_res(raw, std::str::from_utf8))(s)
    }
}
/// Parses a single-quoted string, decoding backslash escapes via the `Vec<u8>` parser,
/// then validates the decoded bytes as UTF-8.
impl<'a> FromSql<'a> for String {
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        let bytes = <Vec<u8> as FromSql<'a>>::from_sql;
        let utf8 = map_res(bytes, String::from_utf8);
        context("string", utf8)(s)
    }
}
/// Parses a single-quoted byte string, decoding backslash escapes.
///
/// Recognised escapes: `\0` (NUL), `\b` (backspace), `\t`, `\n`, `\r`, `\Z` (0x1A),
/// `\\`, `\'` and `\"`. Every other byte apart from `\` and `'` is copied through
/// unchanged. That includes an unescaped `"`, which is legal inside a single-quoted SQL
/// string. An empty literal (`''`) yields an empty vector.
impl<'a> FromSql<'a> for Vec<u8> {
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        context(
            "byte string",
            preceded(
                tag("'"),
                terminated(
                    map(
                        // `escaped_transform` matches nothing for an empty literal, so
                        // `opt` + `unwrap_or_default` turn that case into an empty vector.
                        opt(escaped_transform(
                            // Only `\` (starts an escape) and `'` (ends the literal) are
                            // special. A bare `"` must be accepted here too, otherwise
                            // literals such as `'say "hi"'` would fail to parse.
                            is_not(B("\\'")),
                            '\\',
                            map(one_of(B(r#"0btnrZ\'""#)), |b| match b {
                                '0' => B("\0"),
                                'b' => b"\x08",
                                't' => b"\t",
                                'n' => b"\n",
                                'r' => b"\r",
                                'Z' => b"\x1A",
                                '\\' => b"\\",
                                '\'' => b"'",
                                '"' => b"\"",
                                // `one_of` above only lets the listed characters through.
                                _ => unreachable!(),
                            }),
                        )),
                        |opt| opt.unwrap_or_default(),
                    ),
                    tag("'"),
                ),
            ),
        )(s)
    }
}
/// Parses the `NULL` keyword (matched case-sensitively with `tag`) as the unit value.
impl<'a> FromSql<'a> for () {
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        let null = map(tag("NULL"), |_: &'a [u8]| {});
        context("unit type", null)(s)
    }
}
/// Parses either `NULL` (giving `None`) or a `T` (giving `Some`).
///
/// `NULL` is tried first; `T`'s parser only runs if that alternative returns a
/// recoverable error.
impl<'a, T> FromSql<'a> for Option<T>
where
    T: FromSql<'a>,
{
    fn from_sql(s: &'a [u8]) -> IResult<'a, Self> {
        let null = context("“NULL”", map(<()>::from_sql, |()| None));
        let value = map(T::from_sql, Some);
        context("optional type", alt((null, value)))(s)
    }
}