//! bitcalc 0.3.0
//!
//! A calculator with bit operations.
mod ast;
mod eval;
mod lexer;
mod parser;

use std::fmt::{Binary, Display, LowerHex};

use clap::{Parser, ValueEnum};
use num_traits::{CheckedRem, CheckedShl, CheckedShr, PrimInt, ToBytes, Zero};
use rustyline::DefaultEditor;
use rustyline::error::ReadlineError;

use crate::{ast::Expr, parser::parse};

/// ANSI SGR ("Select Graphic Rendition") escape sequences used to style terminal output.
mod colors {
    use std::fmt;

    /// A single SGR escape sequence identified by its numeric parameter.
    ///
    /// Displaying it writes `ESC [ <n> m` to the formatter.
    pub struct Seq(u8);

    impl fmt::Display for Seq {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let Seq(code) = self;
            write!(f, "\u{1b}[{code}m")
        }
    }

    // Text attributes.
    pub const RESET: Seq = Seq(0);
    pub const BOLD: Seq = Seq(1);
    // SGR 2 is the "faint" attribute; the name reflects how it typically renders.
    pub const GRAY: Seq = Seq(2);
    pub const UNDERLINE: Seq = Seq(4);

    // Foreground colors.
    pub const RED: Seq = Seq(31);
    pub const GREEN: Seq = Seq(32);
    pub const BLUE: Seq = Seq(34);
}

use colors::*;

// Command-line options, parsed by clap's derive. The `///` comments on the fields are
// turned into the `--help` text by clap, so remarks for maintainers use plain `//`.
#[derive(Parser)]
struct Args {
    /// Set the number of bits to operate on
    // Initial bit width of the REPL; the `bits=N` command changes it at runtime.
    // The CLI default is spelled out here rather than taken from `Mode`'s `Default` impl.
    #[arg(long, value_enum, default_value_t = Mode::U32)]
    bits: Mode,

    /// Set the size of grouping of the numbers
    // Initial output grouping of the REPL; the `group=N` command changes it at runtime.
    #[arg(long, value_enum, default_value_t = Grouping::U8)]
    group: Grouping,
}

// Integer width used to evaluate expressions. The `///` comments on the variants are
// used by clap as the help for each possible value, and `#[value(name)]` sets the
// spelling accepted by `--bits`.
#[derive(Clone, Copy, Default, ValueEnum)]
enum Mode {
    /// 8 bits
    #[value(name = "8")]
    U8,
    /// 16 bits
    #[value(name = "16")]
    U16,
    /// 32 bits
    // The derived default matches the `--bits` default declared on `Args`, so
    // `Mode::default()` and the CLI agree on the starting width.
    #[default]
    #[value(name = "32")]
    U32,
    /// 64 bits
    #[value(name = "64")]
    U64,
}

// How the hex and binary output is split into `_`-separated groups. The `///` comments
// on the variants are used by clap as the help for each `--group` value.
#[derive(Clone, Copy, Default, ValueEnum)]
enum Grouping {
    /// No grouping
    // `none` is accepted too, matching the spelling of the REPL's `group=none` command.
    #[value(name = "no", alias = "none")]
    No,
    /// 4 bits and 1 hex digit grouping
    #[value(name = "4")]
    U4,
    /// 8 bits and 2 hex digits grouping
    #[default]
    #[value(name = "8")]
    U8,
}

fn main() {
    let args = Args::parse();
    let mut line_editor = DefaultEditor::new().unwrap();

    let mut mode = args.bits;
    let mut group = args.group;

    println!("Welcome to bitcalc!");
    println!("Type 'help' for available operations.");

    let mut answers = Vec::new();

    loop {
        println!();
        let prompt = format!("{GRAY}[{:>3}]{RESET} {BOLD}{GREEN}>{RESET} ", answers.len());
        let line = line_editor.readline(&prompt);
        match line {
            Ok(buffer) => {
                if buffer.trim().is_empty() {
                    continue;
                }
                if buffer == "q" || buffer == "quit" {
                    break;
                }
                if buffer == "h" || buffer == "help" {
                    println!("{UNDERLINE}{BOLD}Special commands{RESET}");
                    println!("  {BLUE}q{RESET}, {BLUE}quit{RESET}  quit the application");
                    println!("  {BLUE}h{RESET}, {BLUE}help{RESET}  display this help message");
                    println!(
                        "  {BLUE}bits=N{RESET}   set the number of bits to work on (8, 16, 32, 64)"
                    );
                    println!();
                    println!("{UNDERLINE}{BOLD}History values{RESET}");
                    println!("  {BLUE}_{RESET}        result of the last successful calculation");
                    println!(
                        "  {BLUE}_N{RESET}       get a value from the history with N being the index"
                    );
                    println!();
                    println!("{UNDERLINE}{BOLD}Arithmetic operators{RESET}");
                    println!("  {BLUE}+{RESET}        add");
                    println!("  {BLUE}-{RESET}        subtract");
                    println!("  {BLUE}*{RESET}        multiply");
                    println!("  {BLUE}/{RESET}        divide");
                    println!("  {BLUE}%{RESET}        remainder");
                    println!("  {BLUE}**{RESET}       exponentiation");
                    println!();
                    println!("{UNDERLINE}{BOLD}Bit manipulation operators{RESET}");
                    println!("  {BLUE}^{RESET}        bitwise exclusive or");
                    println!("  {BLUE}|{RESET}        bitwise or");
                    println!("  {BLUE}&{RESET}        bitwise and");
                    println!("  {BLUE}<<{RESET}       shift left");
                    println!("  {BLUE}>>{RESET}       shift right");
                    println!("  {BLUE}rotl{RESET}     rotate left");
                    println!("  {BLUE}rotr{RESET}     rotate right");
                    println!();
                    println!("{UNDERLINE}{BOLD}Order of operations{RESET}");
                    println!("  {BLUE}**{RESET}");
                    println!("  {BLUE}*{RESET}, {BLUE}/{RESET}, {BLUE}%{RESET}");
                    println!("  {BLUE}+{RESET}, {BLUE}-{RESET}");
                    println!(
                        "  {BLUE}<<{RESET}, {BLUE}>>{RESET}, {BLUE}rotl{RESET}, {BLUE}rotr{RESET}"
                    );
                    println!("  {BLUE}&{RESET}");
                    println!("  {BLUE}^{RESET}");
                    println!("  {BLUE}|{RESET}");
                    println!();
                    println!("Parentheses can be used to group operations.");
                    continue;
                }

                if let Some(rest) = buffer.strip_prefix("bits=") {
                    match rest.parse() {
                        Ok(b) => {
                            mode = match b {
                                8 => Mode::U8,
                                16 => Mode::U16,
                                32 => Mode::U32,
                                64 => Mode::U64,
                                _ => {
                                    println!("Error: number of bits must be 8, 16, 32 or 64.");
                                    continue;
                                }
                            };
                        }
                        Err(e) => {
                            println!("Error: {e}");
                        }
                    }
                    continue;
                }

                if let Some(rest) = buffer.strip_prefix("group=") {
                    if rest == "none" {
                        group = Grouping::No;
                        continue;
                    }
                    match rest.parse() {
                        Ok(b) => {
                            group = match b {
                                4 => Grouping::U4,
                                8 => Grouping::U8,
                                _ => {
                                    println!("Error: group must be 4 or 8.");
                                    continue;
                                }
                            };
                        }
                        Err(e) => {
                            println!("Error: {e}");
                        }
                    }
                    continue;
                }

                let parsed = parse(&buffer);
                match parsed {
                    Ok(expr) => {
                        do_eval(&mode, &group, &mut answers, &expr);
                    }
                    Err(err) => {
                        print_error(&buffer, err);
                    }
                }
            }
            Err(ReadlineError::Interrupted | ReadlineError::Eof) => {
                break;
            }
            x => {
                println!("Event: {:?}", x);
            }
        }
    }
}

/// Dispatches evaluation of `expr` to the integer type selected by `mode`.
///
/// The chosen type's `eval_and_print` prints the result using `group` and records
/// successful results in `answers`.
fn do_eval(mode: &Mode, group: &Grouping, answers: &mut Vec<u64>, expr: &Expr) {
    match *mode {
        Mode::U64 => eval_and_print::<u64>(group, answers, expr),
        Mode::U32 => eval_and_print::<u32>(group, answers, expr),
        Mode::U16 => eval_and_print::<u16>(group, answers, expr),
        Mode::U8 => eval_and_print::<u8>(group, answers, expr),
    }
}

/// Evaluates `expr` as an integer of type `I` and prints it in decimal, hex and binary.
///
/// The hex and binary digits are zero-padded to the full width of `I`. Depending on
/// `group`, they are split into `_`-separated groups. Each hex group is left-padded with
/// `_` to the width of the matching binary group, so the two lines line up column by
/// column. Leading zeros and separators are printed dimmed and the remaining digits
/// highlighted.
///
/// On success the value is appended to `answers` (widened to `u64`). On failure the
/// evaluation error is printed instead and `answers` is left untouched.
fn eval_and_print<
    I: Zero
        + PrimInt
        + TryFrom<u64>
        + TryInto<u64>
        + TryInto<u32>
        + TryInto<usize>
        + CheckedRem
        + CheckedShl
        + CheckedShr
        + ToBytes
        + Display
        + LowerHex
        + Binary,
>(
    group: &Grouping,
    answers: &mut Vec<u64>,
    expr: &Expr,
) {
    let x = match eval::eval::<I>(answers, expr) {
        Ok(x) => x,
        Err(e) => {
            println!("{RED}Error:{RESET} {e}");
            return;
        }
    };

    let bin_width = std::mem::size_of::<I>() * 8;
    let hex_width = bin_width / 4;

    // Digits per group; each hex digit covers 4 binary digits.
    let (hex_group, bin_group) = match group {
        Grouping::No => (None, None),
        Grouping::U4 => (Some(1), Some(4)),
        Grouping::U8 => (Some(2), Some(8)),
    };

    let printed_hex = format!("{x:0>hex_width$x}");
    let printed_hex = match hex_group {
        // Pad each hex group to the width of the binary group it sits above.
        Some(g) => group_digits(&printed_hex, g, g * 4),
        None => printed_hex,
    };

    let printed_bin = format!("{x:0>bin_width$b}");
    let printed_bin = match bin_group {
        Some(g) => group_digits(&printed_bin, g, g),
        None => printed_bin,
    };

    let (hex_dim, hex_bright) = split_insignificant_prefix(&printed_hex);
    let (bin_dim, bin_bright) = split_insignificant_prefix(&printed_bin);

    println!();
    println!("  {GRAY}{expr}{RESET} = {BLUE}{x}{RESET}");
    println!();
    println!("  Hex: {GRAY}0x{hex_dim}{RESET}{BLUE}{hex_bright}{RESET}");
    println!("  Bin: {GRAY}0b{bin_dim}{RESET}{BLUE}{bin_bright}{RESET}");

    // `do_eval` only instantiates this with u8, u16, u32 or u64, all of which widen to
    // u64 losslessly. `.ok()` is needed because the conversion error type is not `Debug`.
    let widened: u64 = x
        .try_into()
        .ok()
        .expect("results of at most 64 bits always fit in u64");
    answers.push(widened);
}

/// Splits an ASCII digit string into consecutive groups of `chunk` characters.
///
/// Each group is left-padded with `_` to at least `width` characters, and the groups are
/// joined with `_`.
fn group_digits(digits: &str, chunk: usize, width: usize) -> String {
    digits
        .as_bytes()
        .chunks(chunk)
        .map(|bytes| {
            // The digits come from integer formatting and are pure ASCII, so splitting on
            // byte boundaries never cuts a character in half.
            let part = std::str::from_utf8(bytes).expect("formatted digits are ASCII");
            format!("{part:_>width$}")
        })
        .collect::<Vec<_>>()
        .join("_")
}

/// Splits `s` right before its first character that is neither `0` nor `_`.
///
/// When there is no such character (the value is zero), the whole string is returned as
/// the first half.
fn split_insignificant_prefix(s: &str) -> (&str, &str) {
    match s.find(|c: char| c != '_' && c != '0') {
        Some(i) => s.split_at(i),
        None => (s, ""),
    }
}

/// Prints a parse error for `input` as an annotated source report using `ariadne`.
///
/// When the error carries a span, the report is anchored at that span and the offending
/// part of the input is labelled with the error message. Otherwise the report is
/// anchored at the start of the input and only the message is shown.
fn print_error(input: &str, err: parser::ParseError) {
    use ariadne::{ColorGenerator, Label, Report, ReportKind, Source};

    let message = err.kind().to_string();
    let span = err.span().map(|span| span.clone());

    // Anchor the report at the error location when known, so its header points there
    // instead of always at the first character.
    let location = span.clone().unwrap_or(0..0);
    let mut builder =
        Report::build(ReportKind::Error, ("input", location)).with_message(message.clone());

    if let Some(span) = span {
        let mut colors = ColorGenerator::new();
        builder = builder.with_label(
            Label::new(("input", span))
                .with_message(message)
                .with_color(colors.next()),
        );
    }

    builder
        .finish()
        .print(("input", Source::from(input)))
        .expect("failed to write the error report to stdout");
}