use winnow::{
ModalResult, Parser as _,
ascii::{dec_uint, hex_uint, multispace0},
combinator::{alt, cut_err, delimited, opt, preceded, repeat, separated_pair},
error::{
AddContext, ContextError, FromExternalError, ParseError, ParserError, StrContext,
StrContextValue,
},
stream::AsChar as _,
token::{one_of, take_while},
};
/// Prefix operators recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// Pointer dereference, written `*operand`.
    Dereference,
    /// `sizeof(expression)`.
    SizeOf,
}

/// Infix arithmetic operators recognised by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// `+`
    Add,
    /// `-`
    Subtract,
    /// `*`
    Multiply,
    /// `/`
    Divide,
}

/// A unary operator applied to a single operand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnaryExpression<'a> {
    /// The operator being applied.
    pub operator: UnaryOperator,
    /// The operand the operator applies to.
    pub expression: Box<Expression<'a>>,
}

/// A binary operator applied to a left-hand and a right-hand operand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinaryExpression<'a> {
    /// The operator joining the two operands.
    pub operator: BinaryOperator,
    /// Left-hand operand.
    pub lhs: Box<Expression<'a>>,
    /// Right-hand operand.
    pub rhs: Box<Expression<'a>>,
}

/// Parsed expression tree. Identifiers borrow from the input string (lifetime `'a`).
///
/// `PartialEq`/`Eq` are derived so that parse results can be compared structurally,
/// e.g. when asserting the shape of a tree in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expression<'a> {
    /// The `return` keyword.
    Return,
    /// An unsigned integer literal.
    Constant(u64),
    /// A non-keyword identifier, borrowed from the input.
    Identifier(&'a str),
    /// A prefix operator application.
    UnaryExpression(UnaryExpression<'a>),
    /// An infix operator application.
    BinaryExpression(BinaryExpression<'a>),
}
/// Input consumed by every parser in this module: the not-yet-parsed remainder of the text.
type Stream<'i> = &'i str;

/// Single bound bundling every error capability the parsers need.
///
/// Parsers are written as `fn p<'i, E: ErrorType<'i>>` instead of repeating the three
/// winnow error traits on each one. The bounds are declared as `where Self:` clauses,
/// which the language treats as supertraits.
trait ErrorType<'i>
where
    Self: ParserError<Stream<'i>>,
    Self: AddContext<Stream<'i>, StrContext>,
    Self: FromExternalError<Stream<'i>, std::num::ParseIntError>,
{
}

/// Blanket impl: any error type that satisfies all three bounds is an `ErrorType`.
impl<
    'i,
    T: ParserError<Stream<'i>>
        + AddContext<Stream<'i>, StrContext>
        + FromExternalError<Stream<'i>, std::num::ParseIntError>,
> ErrorType<'i> for T
{
}
#[derive(Debug)]
pub struct Error {
message: String,
span: std::ops::Range<usize>,
input: String,
}
impl Error {
fn from_parse(error: ParseError<&str, ContextError>) -> Self {
Self {
message: error.inner().to_string(),
input: error.input().to_string(),
span: error.char_span(),
}
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let message = annotate_snippets::Level::ERROR
.primary_title(&self.message)
.element(
annotate_snippets::Snippet::source(&self.input)
.annotation(annotate_snippets::AnnotationKind::Primary.span(self.span.clone())),
);
let renderer = annotate_snippets::Renderer::plain();
renderer.render(&[message]).fmt(f)
}
}
impl std::error::Error for Error {}
/// Parses an unsigned integer literal into [`Expression::Constant`].
///
/// Two forms are accepted, tried in this order:
/// - `0x` followed by hexadecimal digits. Once `0x` has matched, missing or invalid
///   digits are a hard (`cut_err`) failure.
/// - decimal digits, optionally preceded by a single `+`.
fn constant<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let hexadecimal = preceded("0x", cut_err(hex_uint))
        .context(StrContext::Label("digit"))
        .context(StrContext::Expected(StrContextValue::Description(
            "hexadecimal",
        )));
    let decimal = preceded(opt(one_of(['+'])), dec_uint)
        .context(StrContext::Label("digit"))
        .context(StrContext::Expected(StrContextValue::Description(
            "decimal",
        )));

    let value: u64 = alt((hexadecimal, decimal)).parse_next(input)?;
    Ok(Expression::Constant(value))
}
fn return_<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
"return".value(Expression::Return).parse_next(input)
}
/// Returns `true` when `input` may be used as an identifier.
///
/// Returns `false` when `input` is one of the reserved words `sizeof` or `return`.
fn check_keyword(input: &str) -> bool {
    const RESERVED: [&str; 2] = ["sizeof", "return"];
    !RESERVED.contains(&input)
}
/// Recognises an identifier and returns the matched slice of the input.
///
/// An identifier is one letter or `_`, followed by any number of letters, digits, or
/// `_`. Letter/digit classification uses winnow's `AsChar::is_alpha` and
/// `AsChar::is_alphanum`. Keywords are not filtered out here.
fn identifier<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<&'i str, E> {
    let is_head = |c: char| c == '_' || c.is_alpha();
    let is_tail = |c: char| c == '_' || c.is_alphanum();
    (one_of(is_head), take_while(0.., is_tail))
        .take()
        .parse_next(input)
}
/// Parses a non-keyword identifier into [`Expression::Identifier`].
///
/// Identifiers rejected by `check_keyword` fail with a recoverable error, so other
/// alternatives can still be tried.
fn ident<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let name = identifier::<E>.verify(check_keyword).parse_next(input)?;
    Ok(Expression::Identifier(name))
}
fn unary_operator<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
alt((
separated_pair(
"*",
multispace0,
cut_err(expression).context(StrContext::Expected(StrContextValue::Description(
"expression",
))),
),
separated_pair(
"sizeof",
multispace0,
cut_err(parentheses).context(StrContext::Expected(StrContextValue::Description(
"parenthesized expression",
))),
),
))
.map(|(op, expr)| {
Expression::UnaryExpression(UnaryExpression {
operator: match op {
"*" => UnaryOperator::Dereference,
"sizeof" => UnaryOperator::SizeOf,
_ => unreachable!("unknown unary operator"),
},
expression: Box::new(expr),
})
})
.parse_next(input)
}
/// Combines an accumulated left operand with one `(operator char, right operand)`
/// pair into a [`Expression::BinaryExpression`].
///
/// Used as the folding step for left-associative operator chains.
///
/// # Panics
///
/// Panics if `op` is not one of `+`, `-`, `*`, `/`. Callers only pass characters
/// produced by `one_of` over those sets.
fn binary_op<'i>(lhs: Expression<'i>, (op, rhs): (char, Expression<'i>)) -> Expression<'i> {
    let operator = match op {
        '+' => BinaryOperator::Add,
        '-' => BinaryOperator::Subtract,
        '*' => BinaryOperator::Multiply,
        '/' => BinaryOperator::Divide,
        _ => unreachable!("unknown operator"),
    };
    let node = BinaryExpression {
        operator,
        lhs: Box::new(lhs),
        rhs: Box::new(rhs),
    };
    Expression::BinaryExpression(node)
}
/// Parses `( expression )` and returns the inner expression, without adding a
/// wrapper node.
///
/// After the opening `(` and a valid inner expression, a missing `)` is a hard
/// (`cut_err`) failure that reports `)` as expected.
fn parentheses<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let closing = cut_err(')').context(StrContext::Expected(StrContextValue::CharLiteral(')')));
    delimited('(', expression, closing).parse_next(input)
}
/// Parses one operand of a binary operator, skipping surrounding whitespace.
///
/// The alternatives are tried in this order: `return`, identifier, parenthesised
/// expression, unary operator, integer constant.
fn entity<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let operand = alt((return_, ident, parentheses, unary_operator, constant));
    delimited(multispace0, operand, multispace0)
        .context(StrContext::Label("expression"))
        .parse_next(input)
}
/// Parses a left-associative chain of `*` and `/` over `entity` operands.
/// For example, `a * b / c` becomes `(a * b) / c`.
///
/// The trailing `(operator, operand)` pairs are collected first and then folded onto
/// the leading operand. The parsed left subtree is moved into the result rather than
/// deep-cloned. Cloning it would cost time proportional to the subtree's size, which
/// grows with parenthesis nesting.
fn mul_div<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let first = entity::<E>.parse_next(input)?;
    let rest: Vec<(char, Expression<'i>)> =
        repeat(0.., (one_of(['*', '/']), entity::<E>)).parse_next(input)?;
    Ok(rest.into_iter().fold(first, binary_op))
}
/// Parses a left-associative chain of `+` and `-` over `mul_div` operands, so `*`/`/`
/// bind tighter. For example, `a - b + c * d` becomes `(a - b) + (c * d)`.
///
/// As in `mul_div`, the trailing `(operator, operand)` pairs are collected and folded
/// by move onto the leading operand, so the already-parsed left subtree is never
/// deep-cloned.
fn add_sub<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let first = mul_div::<E>.parse_next(input)?;
    let rest: Vec<(char, Expression<'i>)> =
        repeat(0.., (one_of(['+', '-']), mul_div::<E>)).parse_next(input)?;
    Ok(rest.into_iter().fold(first, binary_op))
}
fn expression<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
add_sub.parse_next(input)
}
pub fn parse(input: &str) -> Result<Expression<'_>, Error> {
expression::<ContextError>
.parse(input)
.map_err(Error::from_parse)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: each representative expression must parse without error.
    #[test]
    fn small_examples_parse() {
        const SAMPLES: [&str; 6] = [
            "return",
            "nSize",
            "*lpNumberOfBytesRead",
            "_ElementSize * _ElementCount",
            "sizeof(WORD) * 3",
            "(sizeof(SID) - sizeof(DWORD) + (SubAuthorityCount) * sizeof(DWORD))",
        ];
        for input in SAMPLES {
            match parse(input) {
                Ok(_) => {}
                Err(err) => panic!("parse `{input}` failed: {err}"),
            }
        }
    }

    /// Regression guard: parses every non-blank line of the corpus and checks the
    /// exact number of lines that succeed and that fail.
    #[test]
    fn full_corpus_partition() {
        const EXPECTED_OK: usize = 757;
        const EXPECTED_ERR: usize = 25;
        let corpus = include_str!("../../../../assets/tests/sal.txt");

        let mut ok_count = 0usize;
        let mut err_lines = Vec::new();
        let non_blank = corpus
            .lines()
            .enumerate()
            .map(|(lineno, line)| (lineno, line.trim()))
            .filter(|(_, line)| !line.is_empty());
        for (lineno, line) in non_blank {
            if let Err(err) = parse(line) {
                err_lines.push(format!("line {}: `{line}`: {err}", lineno + 1));
            } else {
                ok_count += 1;
            }
        }

        let err_count = err_lines.len();
        assert_eq!(
            ok_count,
            EXPECTED_OK,
            "expected {EXPECTED_OK} OK lines, got {ok_count} (ERR: {err_count})\n\
             failures:\n{}",
            err_lines.join("\n")
        );
        assert_eq!(
            err_count, EXPECTED_ERR,
            "expected {EXPECTED_ERR} ERR lines, got {err_count} (OK: {ok_count})"
        );
    }
}