#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <stack>
#include <string>
#include <vector>

#include <asmjit/asmjit.h>

#include "parser.h"
#include "utils.h"
// Number of one-byte cells in the Brainfuck tape that optasmjit() allocates.
constexpr int MEMORY_SIZE{30000};
namespace {
// Kinds of operations a Brainfuck program is translated into.
//
// The first six mirror single Brainfuck characters (run-length collapsed);
// the LOOP_* kinds are specialized replacements for whole recognized loops;
// the JUMP_* kinds implement '[' and ']' for loops that were not specialized.
// Values are spelled out explicitly; INVALID_OP is the zero value.
enum class BfOpKind {
  INVALID_OP = 0,
  INC_PTR = 1,
  DEC_PTR = 2,
  INC_DATA = 3,
  DEC_DATA = 4,
  READ_STDIN = 5,
  WRITE_STDOUT = 6,
  LOOP_SET_TO_ZERO = 7,
  LOOP_MOVE_PTR = 8,
  LOOP_MOVE_DATA = 9,
  JUMP_IF_DATA_ZERO = 10,
  JUMP_IF_DATA_NOT_ZERO = 11,
};
struct BfOp {
BfOp(BfOpKind kind_param, int64_t argument_param)
: kind(kind_param), argument(argument_param) {}
BfOpKind kind = BfOpKind::INVALID_OP;
int64_t argument = 0;
};
// Returns a printable name for `kind`, used when dumping the op list.
//
// Never returns nullptr. A value outside the enumerators (e.g. produced by a
// bad cast) yields "UNKNOWN_OP". The result is streamed straight into
// std::ostream, and inserting a null `const char*` there is undefined
// behavior; with libstdc++ it puts the stream into a failed state and
// silences all later output.
const char* BfOpKind_name(BfOpKind kind) {
  switch (kind) {
    case BfOpKind::INC_PTR:
      return "INC_PTR";
    case BfOpKind::DEC_PTR:
      return "DEC_PTR";
    case BfOpKind::INC_DATA:
      return "INC_DATA";
    case BfOpKind::DEC_DATA:
      return "DEC_DATA";
    case BfOpKind::READ_STDIN:
      return "READ_STDIN";
    case BfOpKind::WRITE_STDOUT:
      return "WRITE_STDOUT";
    case BfOpKind::LOOP_SET_TO_ZERO:
      return "LOOP_SET_TO_ZERO";
    case BfOpKind::LOOP_MOVE_PTR:
      return "LOOP_MOVE_PTR";
    case BfOpKind::LOOP_MOVE_DATA:
      return "LOOP_MOVE_DATA";
    case BfOpKind::JUMP_IF_DATA_ZERO:
      return "JUMP_IF_DATA_ZERO";
    case BfOpKind::JUMP_IF_DATA_NOT_ZERO:
      return "JUMP_IF_DATA_NOT_ZERO";
    case BfOpKind::INVALID_OP:
      return "INVALID_OP";
  }
  // Only reachable for values outside the enumerators.
  return "UNKNOWN_OP";
}
std::vector<BfOp> optimize_loop(const std::vector<BfOp>& ops,
size_t loop_start) {
std::vector<BfOp> new_ops;
if (ops.size() - loop_start == 2) {
BfOp repeated_op = ops[loop_start + 1];
if (repeated_op.kind == BfOpKind::INC_DATA ||
repeated_op.kind == BfOpKind::DEC_DATA) {
new_ops.push_back(BfOp(BfOpKind::LOOP_SET_TO_ZERO, 0));
} else if (repeated_op.kind == BfOpKind::INC_PTR ||
repeated_op.kind == BfOpKind::DEC_PTR) {
new_ops.push_back(
BfOp(BfOpKind::LOOP_MOVE_PTR, repeated_op.kind == BfOpKind::INC_PTR
? repeated_op.argument
: -repeated_op.argument));
}
} else if (ops.size() - loop_start == 5) {
if (ops[loop_start + 1].kind == BfOpKind::DEC_DATA &&
ops[loop_start + 3].kind == BfOpKind::INC_DATA &&
ops[loop_start + 1].argument == 1 &&
ops[loop_start + 3].argument == 1) {
std::string s;
if (ops[loop_start + 2].kind == BfOpKind::INC_PTR &&
ops[loop_start + 4].kind == BfOpKind::DEC_PTR &&
ops[loop_start + 2].argument == ops[loop_start + 4].argument) {
new_ops.push_back(
BfOp(BfOpKind::LOOP_MOVE_DATA, ops[loop_start + 2].argument));
} else if (ops[loop_start + 2].kind == BfOpKind::DEC_PTR &&
ops[loop_start + 4].kind == BfOpKind::INC_PTR &&
ops[loop_start + 2].argument == ops[loop_start + 4].argument) {
new_ops.push_back(
BfOp(BfOpKind::LOOP_MOVE_DATA, -ops[loop_start + 2].argument));
}
}
}
return new_ops;
}
std::vector<BfOp> translate_program(const Program& p) {
size_t pc = 0;
size_t program_size = p.instructions.size();
std::vector<BfOp> ops;
std::stack<size_t> open_bracket_stack;
while (pc < program_size) {
char instruction = p.instructions[pc];
if (instruction == '[') {
open_bracket_stack.push(ops.size());
ops.push_back(BfOp(BfOpKind::JUMP_IF_DATA_ZERO, 0));
pc++;
} else if (instruction == ']') {
if (open_bracket_stack.empty()) {
DIE << "unmatched closing ']' at pc=" << pc;
}
size_t open_bracket_offset = open_bracket_stack.top();
open_bracket_stack.pop();
std::vector<BfOp> optimized_loop =
optimize_loop(ops, open_bracket_offset);
if (optimized_loop.empty()) {
ops[open_bracket_offset].argument = ops.size();
ops.push_back(
BfOp(BfOpKind::JUMP_IF_DATA_NOT_ZERO, open_bracket_offset));
} else {
ops.erase(ops.begin() + open_bracket_offset, ops.end());
ops.insert(ops.end(), optimized_loop.begin(), optimized_loop.end());
}
pc++;
} else {
size_t start = pc++;
while (pc < program_size && p.instructions[pc] == instruction) {
pc++;
}
size_t num_repeats = pc - start;
BfOpKind kind = BfOpKind::INVALID_OP;
switch (instruction) {
case '>':
kind = BfOpKind::INC_PTR;
break;
case '<':
kind = BfOpKind::DEC_PTR;
break;
case '+':
kind = BfOpKind::INC_DATA;
break;
case '-':
kind = BfOpKind::DEC_DATA;
break;
case ',':
kind = BfOpKind::READ_STDIN;
break;
case '.':
kind = BfOpKind::WRITE_STDOUT;
break;
default: { DIE << "bad char '" << instruction << "' at pc=" << start; }
}
ops.push_back(BfOp(kind, num_repeats));
}
}
return ops;
}
// Output helper called from the JIT-compiled code for '.': writes the byte
// `c` to stdout. putchar's return value (the byte or EOF) is ignored.
void myputchar(uint8_t c) {
  static_cast<void>(putchar(static_cast<int>(c)));
}
// Input helper called from the JIT-compiled code for ',': reads one byte from
// stdin. getchar() returns an int, and the conversion to uint8_t keeps its low
// 8 bits, so EOF (typically -1) arrives as 0xFF.
uint8_t mygetchar() {
  const int ch = getchar();
  return static_cast<uint8_t>(ch);
}
struct BracketLabels {
BracketLabels(const asmjit::Label& ol, const asmjit::Label& cl)
: open_label(ol), close_label(cl) {}
asmjit::Label open_label;
asmjit::Label close_label;
};
}
void optasmjit(const Program& p, bool verbose) {
std::vector<uint8_t> memory(MEMORY_SIZE, 0);
std::stack<BracketLabels> open_bracket_stack;
const std::vector<BfOp> ops = translate_program(p);
if (verbose) {
std::cout << "==== OPS ====\n";
for (size_t i = 0; i < ops.size(); ++i) {
std::cout << std::setw(4) << std::left << i << " ";
std::cout << BfOpKind_name(ops[i].kind) << " " << ops[i].argument << "\n";
}
std::cout << "=============\n";
}
asmjit::JitRuntime jit_runtime;
asmjit::CodeHolder code;
code.init(jit_runtime.getCodeInfo());
asmjit::X86Assembler assm(&code);
asmjit::X86Gp dataptr = asmjit::x86::r13;
assm.mov(dataptr, asmjit::x86::rdi);
for (size_t pc = 0; pc < ops.size(); ++pc) {
BfOp op = ops[pc];
switch (op.kind) {
case BfOpKind::INC_PTR:
assm.add(dataptr, op.argument);
break;
case BfOpKind::DEC_PTR:
assm.sub(dataptr, op.argument);
break;
case BfOpKind::INC_DATA:
assm.add(asmjit::x86::byte_ptr(dataptr), op.argument);
break;
case BfOpKind::DEC_DATA:
assm.sub(asmjit::x86::byte_ptr(dataptr), op.argument);
break;
case BfOpKind::WRITE_STDOUT:
for (int i = 0; i < op.argument; ++i) {
assm.movzx(asmjit::x86::rdi, asmjit::x86::byte_ptr(dataptr));
assm.call(asmjit::imm_ptr(myputchar));
}
break;
case BfOpKind::READ_STDIN:
for (int i = 0; i < op.argument; ++i) {
assm.call(asmjit::imm_ptr(mygetchar));
assm.mov(asmjit::x86::byte_ptr(dataptr), asmjit::x86::al);
}
break;
case BfOpKind::LOOP_SET_TO_ZERO:
assm.mov(asmjit::x86::byte_ptr(dataptr), 0);
break;
case BfOpKind::LOOP_MOVE_PTR: {
asmjit::Label loop = assm.newLabel();
asmjit::Label endloop = assm.newLabel();
assm.bind(loop);
assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
assm.jz(endloop);
if (op.argument < 0) {
assm.sub(dataptr, -op.argument);
} else {
assm.add(dataptr, op.argument);
}
assm.jmp(loop);
assm.bind(endloop);
break;
}
case BfOpKind::LOOP_MOVE_DATA: {
asmjit::Label skip_move = assm.newLabel();
assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
assm.jz(skip_move);
assm.mov(asmjit::x86::r14, dataptr);
if (op.argument < 0) {
assm.sub(asmjit::x86::r14, -op.argument);
} else {
assm.add(asmjit::x86::r14, op.argument);
}
assm.movzx(asmjit::x86::rax, asmjit::x86::byte_ptr(dataptr));
assm.add(asmjit::x86::byte_ptr(asmjit::x86::r14), asmjit::x86::al);
assm.mov(asmjit::x86::byte_ptr(dataptr), 0);
assm.bind(skip_move);
break;
}
case BfOpKind::JUMP_IF_DATA_ZERO: {
assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
asmjit::Label open_label = assm.newLabel();
asmjit::Label close_label = assm.newLabel();
assm.jz(close_label);
assm.bind(open_label);
open_bracket_stack.push(BracketLabels(open_label, close_label));
break;
}
case BfOpKind::JUMP_IF_DATA_NOT_ZERO: {
if (open_bracket_stack.empty()) {
DIE << "unmatched closing ']' at pc=" << pc;
}
BracketLabels labels = open_bracket_stack.top();
open_bracket_stack.pop();
assm.cmp(asmjit::x86::byte_ptr(dataptr), 0);
assm.jnz(labels.open_label);
assm.bind(labels.close_label);
break;
}
case BfOpKind::INVALID_OP:
DIE << "INVALID_OP encountered on pc=" << pc;
break;
}
}
assm.ret();
if (assm.isInErrorState()) {
DIE << "asmjit error: "
<< asmjit::DebugUtils::errorAsString(assm.getLastError());
}
code.sync();
asmjit::CodeBuffer& buf = code.getSectionEntry(0)->getBuffer();
std::vector<uint8_t> emitted_code(buf.getLength());
memcpy(emitted_code.data(), buf.getData(), buf.getLength());
using JittedFunc = void (*)(uint64_t);
JittedFunc func;
asmjit::Error err = jit_runtime.add(&func, &code);
if (err) {
DIE << "error calling jit_runtime.add";
}
Timer texec;
func((uint64_t)memory.data());
if (verbose) {
std::cout << "[-] Execution took: " << texec.elapsed() << "s)\n";
}
if (verbose) {
const char* filename = "/tmp/bjout.bin";
FILE* outfile = fopen(filename, "wb");
if (outfile) {
size_t n = emitted_code.size();
if (fwrite(emitted_code.data(), 1, n, outfile) == n) {
std::cout << "* emitted code to " << filename << "\n";
}
fclose(outfile);
}
std::cout << "* Memory nonzero locations:\n";
for (size_t i = 0, pcount = 0; i < memory.size(); ++i) {
if (memory[i]) {
std::cout << std::right << "[" << std::setw(3) << i
<< "] = " << std::setw(3) << std::left
<< static_cast<int32_t>(memory[i]) << " ";
pcount++;
if (pcount > 0 && pcount % 4 == 0) {
std::cout << "\n";
}
}
}
std::cout << "\n";
}
}
int main(int argc, const char** argv) {
bool verbose = false;
std::string bf_file_path;
parse_command_line(argc, argv, &bf_file_path, &verbose);
Timer t1;
std::ifstream file(bf_file_path);
if (!file) {
DIE << "unable to open file " << bf_file_path;
}
Program program = parse_from_stream(file);
if (verbose) {
std::cout << "Parsing took: " << t1.elapsed() << "s\n";
std::cout << "Length of program: " << program.instructions.size() << "\n";
std::cout << "Program:\n" << program.instructions << "\n";
}
if (verbose) {
std::cout << "[>] Running optasmjit:\n";
}
Timer t2;
optasmjit(program, verbose);
if (verbose) {
std::cout << "[<] Done (elapsed: " << t2.elapsed() << "s)\n";
}
return 0;
}