sigmd 0.1.0

Windows API signature metadata
Documentation
//! winnow grammar for SAL expression arguments.
//!
//! Grammar:
//!
//! ```text
//! expression  := add_sub
//! add_sub     := mul_div  ( ('+' | '-') mul_div )*
//! mul_div     := entity   ( ('*' | '/') entity   )*
//! entity      := ws? ( "return" | identifier | parentheses
//!                    | unary_operator | constant ) ws?
//! parentheses := '(' expression ')'
//! unary_operator := '*'    expression
//!                 | "sizeof" parentheses
//! identifier  := alpha_or_underscore (alphanum_or_underscore)*    ;; not "sizeof"/"return"
//! constant    := "0x" HEX_UINT | '+'? DEC_UINT    ;; parsed as u64
//! ```

use winnow::{
    ModalResult, Parser as _,
    ascii::{dec_uint, hex_uint, multispace0},
    combinator::{alt, cut_err, delimited, opt, preceded, repeat, separated_pair},
    error::{
        AddContext, ContextError, FromExternalError, ParseError, ParserError, StrContext,
        StrContextValue,
    },
    stream::AsChar as _,
    token::{one_of, take_while},
};

/// Unary operator recognized by the grammar.
///
/// This is a payload-free tag, so it is `Copy` and supports equality and
/// hashing; callers can compare and store operators without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// Pointer dereference (`*expr`).
    Dereference,

    /// `sizeof` operator applied to a parenthesized operand (`sizeof(expr)`).
    SizeOf,
}

/// Binary operator recognized by the grammar.
///
/// This is a payload-free tag, so it is `Copy` and supports equality and
/// hashing; callers can compare and store operators without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// Addition (`lhs + rhs`).
    Add,

    /// Subtraction (`lhs - rhs`).
    Subtract,

    /// Multiplication (`lhs * rhs`).
    Multiply,

    /// Division (`lhs / rhs`).
    Divide,
}

/// Unary expression in the parser AST.
///
/// Pairs a [`UnaryOperator`] with the single operand it applies to. The
/// lifetime `'a` is the lifetime of the parsed input, which identifiers inside
/// the operand borrow from.
#[derive(Debug, Clone)]
pub struct UnaryExpression<'a> {
    /// Operator applied to the operand.
    pub operator: UnaryOperator,

    /// Operand the operator applies to.
    ///
    /// Boxed because [`Expression`] is recursive and needs indirection.
    pub expression: Box<Expression<'a>>,
}

/// Binary expression in the parser AST.
///
/// Pairs a [`BinaryOperator`] with its two operands. The lifetime `'a` is the
/// lifetime of the parsed input, which identifiers inside the operands borrow
/// from.
#[derive(Debug, Clone)]
pub struct BinaryExpression<'a> {
    /// Operator applied to the operands.
    pub operator: BinaryOperator,

    /// Left-hand operand (boxed because [`Expression`] is recursive).
    pub lhs: Box<Expression<'a>>,

    /// Right-hand operand (boxed because [`Expression`] is recursive).
    pub rhs: Box<Expression<'a>>,
}

/// A parsed SAL expression.
///
/// This is the root AST node returned by [`parse`]. Identifier names are
/// borrowed from the input string, hence the lifetime parameter `'a`.
#[derive(Debug, Clone)]
pub enum Expression<'a> {
    /// The `return` keyword, referring to the function return value.
    Return,

    /// Integer literal (hexadecimal with a `0x` prefix, or decimal).
    Constant(u64),

    /// Parameter or type name, as a slice of the original input.
    Identifier(&'a str),

    /// Unary operator applied to an operand.
    UnaryExpression(UnaryExpression<'a>),

    /// Binary operator applied to two operands.
    BinaryExpression(BinaryExpression<'a>),
}

/// Parser input stream: a borrowed string slice.
///
/// Every parser function in this module receives it as `&mut Stream<'i>`.
type Stream<'i> = &'i str;

/// Bundles the winnow error traits used throughout this module.
///
/// The trait declares no items of its own; it exists so parser functions can
/// write a single `E: ErrorType<'i>` bound instead of repeating the three
/// supertraits: basic parser errors, `StrContext` annotations, and conversion
/// from `std::num::ParseIntError`.
trait ErrorType<'i>:
    ParserError<Stream<'i>>
    + AddContext<Stream<'i>, StrContext>
    + FromExternalError<Stream<'i>, std::num::ParseIntError>
{
}

// Blanket implementation: any type that already satisfies the three winnow
// error bounds is automatically an `ErrorType`, so no manual impls are needed.
impl<'i, T> ErrorType<'i> for T where
    T: ParserError<Stream<'i>>
        + AddContext<Stream<'i>, StrContext>
        + FromExternalError<Stream<'i>, std::num::ParseIntError>
{
}

/// Failure to parse a SAL expression argument.
///
/// Owns everything needed to render a diagnostic (message, offending span and
/// a copy of the input), so it does not borrow from the parsed string. Its
/// `Display` implementation renders an annotated snippet of the input.
#[derive(Debug)]
pub struct Error {
    /// Human-readable message from the underlying winnow context.
    message: String,

    /// Byte range in `input` that the diagnostic points at.
    span: std::ops::Range<usize>,

    /// Copy of the input the parser was run against.
    input: String,
}

impl Error {
    /// Builds an owned, renderable [`Error`] out of a winnow [`ParseError`].
    ///
    /// The message is the string form of the inner `ContextError`, the span is
    /// the one winnow reports for the character at the failure position, and
    /// the input is copied so the result does not borrow from the caller.
    fn from_parse(error: ParseError<&str, ContextError>) -> Self {
        let message = error.inner().to_string();
        let span = error.char_span();
        let input = error.input().to_string();

        Self {
            message,
            span,
            input,
        }
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = annotate_snippets::Level::ERROR
            .primary_title(&self.message)
            .element(
                annotate_snippets::Snippet::source(&self.input)
                    .annotation(annotate_snippets::AnnotationKind::Primary.span(self.span.clone())),
            );
        let renderer = annotate_snippets::Renderer::plain();
        renderer.render(&[message]).fmt(f)
    }
}

// Marker impl relying on the trait's default methods; no underlying `source`
// is exposed.
impl std::error::Error for Error {}

/// Parses an integer literal into [`Expression::Constant`].
///
/// Two forms are tried in order:
///
/// * hexadecimal, introduced by `0x`; once the prefix has matched, missing or
///   invalid digits are a hard error (`cut_err`) rather than a backtrack;
/// * decimal, optionally preceded by a single `+`.
fn constant<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let hexadecimal = preceded("0x", cut_err(hex_uint))
        .context(StrContext::Label("digit"))
        .context(StrContext::Expected(StrContextValue::Description(
            "hexadecimal",
        )));

    let decimal = preceded(opt(one_of(['+'])), dec_uint)
        .context(StrContext::Label("digit"))
        .context(StrContext::Expected(StrContextValue::Description(
            "decimal",
        )));

    alt((hexadecimal, decimal))
        .map(Expression::Constant)
        .parse_next(input)
}

/// Parses the literal `return` keyword.
fn return_<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    "return".value(Expression::Return).parse_next(input)
}

/// Tells whether `input` may be used as an identifier.
///
/// Returns `false` for the reserved words `sizeof` and `return` (exact,
/// case-sensitive match) and `true` for every other string.
fn check_keyword(input: &str) -> bool {
    match input {
        "sizeof" | "return" => false,
        _ => true,
    }
}

/// Parses an identifier and returns it as a slice of the input.
///
/// The first character must be a letter or `_`; it may be followed by any
/// number (including zero) of letters, digits or `_`. Keywords are not
/// filtered out here.
fn identifier<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<&'i str, E> {
    let head = one_of(|c: char| c == '_' || c.is_alpha());
    let tail = take_while(0.., |c: char| c == '_' || c.is_alphanum());

    // `take` discards the individual pieces and yields the whole matched slice.
    (head, tail).take().parse_next(input)
}

/// Parses an identifier into [`Expression::Identifier`].
///
/// Names rejected by [`check_keyword`] (the reserved words) do not match, so
/// the caller can try other alternatives.
fn ident<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let name = identifier.verify(check_keyword).parse_next(input)?;
    Ok(Expression::Identifier(name))
}

/// Parses a unary expression: `*expr` or `sizeof(expr)`.
fn unary_operator<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    alt((
        separated_pair(
            "*",
            multispace0,
            cut_err(expression).context(StrContext::Expected(StrContextValue::Description(
                "expression",
            ))),
        ),
        separated_pair(
            "sizeof",
            multispace0,
            cut_err(parentheses).context(StrContext::Expected(StrContextValue::Description(
                "parenthesized expression",
            ))),
        ),
    ))
    .map(|(op, expr)| {
        Expression::UnaryExpression(UnaryExpression {
            operator: match op {
                "*" => UnaryOperator::Dereference,
                "sizeof" => UnaryOperator::SizeOf,
                _ => unreachable!("unknown unary operator"),
            },
            expression: Box::new(expr),
        })
    })
    .parse_next(input)
}

/// Combines an accumulated left operand with the next `(operator, rhs)` pair
/// into an [`Expression::BinaryExpression`].
///
/// Used as the fold step of the operator chains, which makes them
/// left-associative. `op` must be one of `+`, `-`, `*`, `/`; any other
/// character is a programming error and panics.
fn binary_op<'i>(lhs: Expression<'i>, (op, rhs): (char, Expression<'i>)) -> Expression<'i> {
    let operator = match op {
        '+' => BinaryOperator::Add,
        '-' => BinaryOperator::Subtract,
        '*' => BinaryOperator::Multiply,
        '/' => BinaryOperator::Divide,
        _ => unreachable!("unknown operator"),
    };

    let node = BinaryExpression {
        operator,
        lhs: Box::new(lhs),
        rhs: Box::new(rhs),
    };
    Expression::BinaryExpression(node)
}

/// Parses `( expression )` and returns the inner expression.
///
/// After the opening parenthesis and the inner expression have matched, a
/// missing `)` is a hard error (`cut_err`) annotated with the expected
/// character.
fn parentheses<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let closing = cut_err(')').context(StrContext::Expected(StrContextValue::CharLiteral(')')));

    delimited('(', expression, closing).parse_next(input)
}

/// Parses one primary expression (the `entity` production of the grammar),
/// skipping whitespace on both sides.
///
/// Alternatives are tried in this order: `return`, identifier, parenthesized
/// expression, unary operator, constant.
fn entity<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let primary = alt((return_, ident, parentheses, unary_operator, constant));

    delimited(multispace0, primary, multispace0)
        .context(StrContext::Label("expression"))
        .parse_next(input)
}

/// Parses a chain of [`entity`] operands joined by `*` or `/`.
///
/// The first operand seeds the fold and each following `(operator, operand)`
/// pair is merged with [`binary_op`], giving a left-associative tree.
fn mul_div<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let first = entity.parse_next(input)?;
    let step = (one_of(['*', '/']), entity);

    repeat(0.., step)
        .fold(move || first.clone(), binary_op)
        .parse_next(input)
}

/// Parses a chain of [`mul_div`] operands joined by `+` or `-`.
///
/// The first operand seeds the fold and each following `(operator, operand)`
/// pair is merged with [`binary_op`], giving a left-associative tree. Because
/// the operands are `mul_div` chains, `*` and `/` bind tighter than `+` and `-`.
fn add_sub<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    let first = mul_div.parse_next(input)?;
    let step = (one_of(['+', '-']), mul_div);

    repeat(0.., step)
        .fold(move || first.clone(), binary_op)
        .parse_next(input)
}

/// Entry into the recursive expression parser.
fn expression<'i, E: ErrorType<'i>>(input: &mut Stream<'i>) -> ModalResult<Expression<'i>, E> {
    add_sub.parse_next(input)
}

/// Parses a SAL expression string into an [`Expression`].
pub fn parse(input: &str) -> Result<Expression<'_>, Error> {
    expression::<ContextError>
        .parse(input)
        .map_err(Error::from_parse)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A handful of representative expressions must all parse.
    #[test]
    fn small_examples_parse() {
        const SAMPLES: [&str; 6] = [
            "return",
            "nSize",
            "*lpNumberOfBytesRead",
            "_ElementSize * _ElementCount",
            "sizeof(WORD) * 3",
            "(sizeof(SID) - sizeof(DWORD) + (SubAuthorityCount) * sizeof(DWORD))",
        ];

        for input in SAMPLES {
            let _ = parse(input).unwrap_or_else(|err| panic!("parse `{input}` failed: {err}"));
        }
    }

    /// Pins how many corpus lines parse and how many are rejected.
    #[test]
    fn full_corpus_partition() {
        // Every non-empty line of the bundled corpus (assets/tests/sal.txt) is
        // parsed. Against that corpus the split is OK: 757, ERR: 25. The 25
        // failures are lines using unsupported syntax: `->` field access,
        // ternary operators (`?:`), `_Inexpressible_(...)`, struct field
        // paths, and `%` in identifiers.
        // Both counts are asserted exactly, so any regression or improvement
        // shows up without requiring every line to pass.
        const EXPECTED_OK: usize = 757;
        const EXPECTED_ERR: usize = 25;

        let corpus = include_str!("../../../../assets/tests/sal.txt");

        // 1-based line numbers paired with trimmed, non-empty lines.
        let candidates = corpus
            .lines()
            .enumerate()
            .map(|(index, raw)| (index + 1, raw.trim()))
            .filter(|(_, line)| !line.is_empty());

        let mut ok_count = 0usize;
        let mut err_lines = Vec::new();
        for (lineno, line) in candidates {
            if let Err(err) = parse(line) {
                err_lines.push(format!("line {lineno}: `{line}`: {err}"));
            } else {
                ok_count += 1;
            }
        }

        let err_count = err_lines.len();
        assert_eq!(
            ok_count,
            EXPECTED_OK,
            "expected {EXPECTED_OK} OK lines, got {ok_count} (ERR: {err_count})\n\
             failures:\n{}",
            err_lines.join("\n")
        );
        assert_eq!(
            err_count, EXPECTED_ERR,
            "expected {EXPECTED_ERR} ERR lines, got {err_count} (OK: {ok_count})"
        );
    }
}