// bitcalc 0.3.0
//
// A calculator with bit operations.
use std::{fmt::Display, num::ParseIntError, ops::Range};

use crate::{
    ast::{Expr, Op, Val},
    lexer::{Lexer, Token},
};

pub fn parse(s: &str) -> Result<Expr, ParseError> {
    let mut lex = Lexer::new(s);
    let expression = expr(&mut lex)?;
    let res = next(&mut lex);
    match res {
        Ok((tok, span)) => Err(ParseError {
            kind: ParseErrorKind::Unexpected(tok.to_string()),
            span: Some(span),
        }),
        Err(err) if err.kind() == &ParseErrorKind::EndOfInput => Ok(expression),
        Err(err) => Err(err),
    }
}

/// Parses a full expression.
///
/// Entry point of the precedence chain: delegates to [`bit_or`], the
/// loosest-binding level handled by this parser.
fn expr(lex: &mut Lexer) -> Result<Expr, ParseError> {
    bit_or(lex)
}

/// Parses a bitwise-OR level: `bit_xor ('|' bit_or)?`.
///
/// Chains of `|` nest to the right, i.e. `a | b | c` becomes
/// `a | (b | c)`.
fn bit_or(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let lhs = bit_xor(lex)?;
    if !next_is(lex, Token::Pipe) {
        // No `|` follows: this level contributes nothing.
        return Ok(lhs);
    }
    let rhs = bit_or(lex)?;
    Ok(Expr::Op(Op::BitOr(Box::new(lhs), Box::new(rhs))))
}

/// Parses a bitwise-XOR level: `bit_and ('^' bit_xor)?`.
///
/// Chains of `^` nest to the right, i.e. `a ^ b ^ c` becomes
/// `a ^ (b ^ c)`.
fn bit_xor(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let lhs = bit_and(lex)?;
    if !next_is(lex, Token::Hat) {
        // No `^` follows: this level contributes nothing.
        return Ok(lhs);
    }
    let rhs = bit_xor(lex)?;
    Ok(Expr::Op(Op::BitXor(Box::new(lhs), Box::new(rhs))))
}

/// Parses a bitwise-AND level: `shift ('&' bit_and)?`.
///
/// Chains of `&` nest to the right, i.e. `a & b & c` becomes
/// `a & (b & c)`.
fn bit_and(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let lhs = shift(lex)?;
    if !next_is(lex, Token::Amp) {
        // No `&` follows: this level contributes nothing.
        return Ok(lhs);
    }
    let rhs = bit_and(lex)?;
    Ok(Expr::Op(Op::BitAnd(Box::new(lhs), Box::new(rhs))))
}

/// Parses the shift/rotate level: `sum (shift_or_rotate_op sum)*`.
///
/// The shift-left, shift-right, rotate-left and rotate-right operators all
/// share this precedence level and group left-to-right, so `a << b << c` is
/// `(a << b) << c`. The chain is folded iteratively; recursing on the right
/// operand would instead yield `a << (b << c)`.
///
/// # Errors
///
/// Propagates any error from parsing an operand.
fn shift(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let mut left = sum(lex)?;
    loop {
        // Each iteration consumes one operator plus its right operand and
        // wraps everything parsed so far as the new left operand.
        if next_is(lex, Token::ShiftLeft) {
            let right = sum(lex)?;
            left = Expr::Op(Op::ShiftLeft(Box::new(left), Box::new(right)));
        } else if next_is(lex, Token::ShiftRight) {
            let right = sum(lex)?;
            left = Expr::Op(Op::ShiftRight(Box::new(left), Box::new(right)));
        } else if next_is(lex, Token::RotLeft) {
            let right = sum(lex)?;
            left = Expr::Op(Op::RotLeft(Box::new(left), Box::new(right)));
        } else if next_is(lex, Token::RotRight) {
            let right = sum(lex)?;
            left = Expr::Op(Op::RotRight(Box::new(left), Box::new(right)));
        } else {
            return Ok(left);
        }
    }
}

/// Parses the additive level: `term (('+' | '-') term)*`.
///
/// Addition and subtraction group left-to-right, so `10 - 3 - 2` is
/// `(10 - 3) - 2` and `1 - 2 + 3` is `(1 - 2) + 3`. The chain is folded
/// iteratively; recursing on the right operand would instead produce
/// `10 - (3 - 2)`, which evaluates to a different value.
///
/// # Errors
///
/// Propagates any error from parsing an operand.
fn sum(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let mut left = term(lex)?;
    loop {
        // Each iteration consumes one operator plus its right operand and
        // wraps everything parsed so far as the new left operand.
        if next_is(lex, Token::Plus) {
            let right = term(lex)?;
            left = Expr::Op(Op::Add(Box::new(left), Box::new(right)));
        } else if next_is(lex, Token::Hyphen) {
            let right = term(lex)?;
            left = Expr::Op(Op::Sub(Box::new(left), Box::new(right)));
        } else {
            return Ok(left);
        }
    }
}

/// Parses the multiplicative level: `power (('*' | '/' | '%') power)*`.
///
/// Multiplication, division and modulo group left-to-right, so `8 / 4 * 2`
/// is `(8 / 4) * 2`. The chain is folded iteratively; recursing on the right
/// operand would instead produce `8 / (4 * 2)`, which evaluates to a
/// different value.
///
/// # Errors
///
/// Propagates any error from parsing an operand.
fn term(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let mut left = power(lex)?;
    loop {
        // Each iteration consumes one operator plus its right operand and
        // wraps everything parsed so far as the new left operand.
        if next_is(lex, Token::Star) {
            let right = power(lex)?;
            left = Expr::Op(Op::Mul(Box::new(left), Box::new(right)));
        } else if next_is(lex, Token::Slash) {
            let right = power(lex)?;
            left = Expr::Op(Op::Div(Box::new(left), Box::new(right)));
        } else if next_is(lex, Token::Percent) {
            let right = power(lex)?;
            left = Expr::Op(Op::Mod(Box::new(left), Box::new(right)));
        } else {
            return Ok(left);
        }
    }
}

/// Parses the exponentiation level: `negation ('**' power)?`.
///
/// The right operand recurses into `power` itself, so `a ** b ** c` is
/// parsed as `a ** (b ** c)`.
fn power(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let base = negation(lex)?;
    if !next_is(lex, Token::StarStar) {
        // No `**` follows: this level contributes nothing.
        return Ok(base);
    }
    let exponent = power(lex)?;
    Ok(Expr::Op(Op::Pow(Box::new(base), Box::new(exponent))))
}

/// Precedence level sitting between [`power`] and [`atom`].
///
/// NOTE(review): despite the name, no unary operator is parsed here; the
/// function currently just forwards to [`atom`].
fn negation(lex: &mut Lexer) -> Result<Expr, ParseError> {
    atom(lex)
}

/// Parses an atom: either a parenthesised expression or a single value.
///
/// # Errors
///
/// Propagates errors from the inner expression or value, and from [`take`]
/// when the closing parenthesis is missing.
fn atom(lex: &mut Lexer) -> Result<Expr, ParseError> {
    if !next_is(lex, Token::RoundLeft) {
        return value(lex);
    }

    // An opening parenthesis was consumed: parse what it encloses and
    // require the matching closing parenthesis.
    let inner = expr(lex)?;
    take(lex, Token::RoundRight)?;
    Ok(inner)
}

/// Consumes one token and turns it into a value expression.
///
/// Accepted tokens are decimal, hexadecimal and binary integer literals, the
/// `Underscore` token (last answer) and history references.
///
/// # Errors
///
/// * Whatever [`next`] reports if no token can be read.
/// * [`ParseErrorKind::InvalidNumber`] if a literal or history index cannot
///   be converted to a number.
/// * [`ParseErrorKind::Expected`] (`"value"`) for any other token.
fn value(lex: &mut Lexer) -> Result<Expr, ParseError> {
    let (tok, span) = next(lex)?;

    let val = match tok {
        Token::Underscore => Val::LastAnswer,
        Token::Integer(digits) => Val::Int(parse_int(digits, 10, span)?),
        // The first two bytes of these literals are skipped before the
        // digits are parsed.
        Token::Hex(literal) => Val::Hex(parse_int(&literal[2..], 16, span)?),
        Token::Binary(literal) => Val::Bin(parse_int(&literal[2..], 2, span)?),
        Token::History(reference) => {
            // The first byte of the reference is skipped; the rest is the index.
            let index = reference[1..]
                .parse()
                .map_err(|e: ParseIntError| ParseError {
                    kind: ParseErrorKind::InvalidNumber(e.to_string()),
                    span: Some(span),
                })?;
            Val::History(index)
        }
        _ => {
            return Err(ParseError {
                kind: ParseErrorKind::Expected("value".into()),
                span: Some(span),
            });
        }
    };

    Ok(Expr::Val(val))
}

/// Parses `s` as an unsigned integer in the given `base`.
///
/// Underscores in `s` are removed before parsing, so they can be used as
/// digit separators.
///
/// # Errors
///
/// Returns [`ParseErrorKind::InvalidNumber`] carrying the conversion error's
/// message, located at `span`, if the remaining text is not a valid `u64` in
/// that base.
fn parse_int(s: &str, base: u32, span: Range<usize>) -> Result<u64, ParseError> {
    let digits = s.replace('_', "");
    match u64::from_str_radix(&digits, base) {
        Ok(number) => Ok(number),
        Err(e) => Err(ParseError {
            kind: ParseErrorKind::InvalidNumber(e.to_string()),
            span: Some(span),
        }),
    }
}

/// Advances the lexer by one token and returns that token with its span.
///
/// # Errors
///
/// * [`ParseErrorKind::EndOfInput`] (without a span) when the lexer has no
///   more tokens.
/// * [`ParseErrorKind::InvalidToken`] (with the offending span) when the
///   lexer yields an error for the next piece of input.
fn next<'source>(lex: &mut Lexer<'source>) -> Result<(Token<'source>, Range<usize>), ParseError> {
    let Some((lexed, span)) = lex.next() else {
        return Err(ParseError {
            kind: ParseErrorKind::EndOfInput,
            span: None,
        });
    };

    match lexed {
        Ok(token) => Ok((token, span)),
        Err(()) => Err(ParseError {
            kind: ParseErrorKind::InvalidToken,
            span: Some(span),
        }),
    }
}

/// Consumes the next token if, and only if, it equals `tok`.
///
/// Returns `true` when the token was consumed and `false` when the lexer was
/// left untouched.
fn next_is(lex: &mut Lexer, tok: Token<'_>) -> bool {
    let matched = peek_is(lex, tok);
    if matched {
        // The peek above already saw a valid token, so advancing cannot fail.
        next(lex).unwrap();
    }
    matched
}

/// Looks at the next token without consuming it.
///
/// Returns `None` both when there is no further token and when the next
/// piece of input failed to lex.
fn peek<'a, 'source>(lex: &'a mut Lexer<'source>) -> Option<&'a Token<'source>> {
    let Some((Ok(token), _span)) = lex.peek() else {
        return None;
    };
    Some(token)
}

/// Reports whether the next token equals `token`, without consuming it.
///
/// Returns `false` when no valid token is available to peek at.
fn peek_is(lex: &mut Lexer, token: Token) -> bool {
    match peek(lex) {
        Some(lexed_token) => &token == lexed_token,
        None => false,
    }
}

/// Consumes the next token and requires it to equal `token`.
///
/// On success the span of the consumed token is returned.
///
/// # Errors
///
/// * Whatever [`next`] reports if no token can be read.
/// * [`ParseErrorKind::Expected`], naming `token` and located at the span of
///   the token actually found, if a different token was read.
fn take(lex: &mut Lexer, token: Token) -> Result<Range<usize>, ParseError> {
    let (found, span) = next(lex)?;
    if found != token {
        return Err(ParseError {
            kind: ParseErrorKind::Expected(token.to_string()),
            span: Some(span),
        });
    }
    Ok(span)
}

/// An error produced while parsing an expression.
#[derive(Debug)]
pub struct ParseError {
    /// What went wrong.
    kind: ParseErrorKind,
    /// Range in the input the error refers to; `None` when the error has no
    /// location, as is the case for end-of-input errors.
    span: Option<Range<usize>>,
}

/// The category of a [`ParseError`].
#[derive(Debug, PartialEq, Eq)]
pub enum ParseErrorKind {
    /// Something specific was required; the payload describes what.
    Expected(String),
    /// A numeric literal could not be converted; the payload is the
    /// conversion error's message.
    InvalidNumber(String),
    /// A token appeared where none was allowed; the payload is its text.
    Unexpected(String),
    /// The lexer could not produce a token for a piece of input.
    InvalidToken,
    /// The input ended while more tokens were needed.
    EndOfInput,
}

impl Display for ParseErrorKind {
    /// Writes a short, lower-case, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::InvalidToken => f.write_str("invalid token"),
            Self::EndOfInput => f.write_str("unexpected end of input"),
            Self::Expected(what) => write!(f, "expected {what}"),
            Self::InvalidNumber(reason) => write!(f, "invalid number `{reason}`"),
            Self::Unexpected(what) => write!(f, "unexpected {what}"),
        }
    }
}

impl ParseError {
    /// Returns the category of this error.
    pub fn kind(&self) -> &ParseErrorKind {
        let Self { kind, .. } = self;
        kind
    }

    /// Returns the input range this error refers to, if it has one.
    pub fn span(&self) -> &Option<Range<usize>> {
        let Self { span, .. } = self;
        span
    }
}