#include "sljitLir.h"
#include <stdio.h>
#include <stdlib.h>
#define BF_CELL_SIZE 3000
#define BF_LOOP_LEVEL 256
/*
 * Return the next brainfuck command character from src, skipping every
 * other byte (comments). Returns EOF once the stream is exhausted.
 */
static int readvalid(FILE *src)
{
	static const char commands[] = "+-><.,[]";
	const char *p;
	int c;

	for (;;) {
		c = fgetc(src);
		if (c == EOF)
			return EOF;
		/* The NUL terminator ends the scan, so a 0 byte never matches. */
		for (p = commands; *p != '\0'; p++) {
			if (c == *p)
				return c;
		}
	}
}
/*
 * Read the next token from src.
 *
 * '+', '-', '>' and '<' are run-length merged: consecutive copies (comment
 * bytes in between are skipped by readvalid) are folded into one token and
 * the repeat count is stored in *ntok. '.', ',', '[' and ']' are always
 * single tokens with *ntok = 1.
 *
 * Returns the command character, or EOF (leaving *ntok untouched) when no
 * command remains.
 */
static int gettoken(FILE *src, int *ntok)
{
	int first = readvalid(src);
	int next;
	int run;

	switch (first) {
	case EOF:
		return EOF;
	case '.':
	case ',':
	case '[':
	case ']':
		/* I/O and loop brackets are never merged. */
		*ntok = 1;
		return first;
	default:
		break;
	}

	for (run = 1; (next = readvalid(src)) == first; run++)
		;
	/* Push back the first different command so the next call sees it. */
	if (next != EOF)
		ungetc(next, src);

	*ntok = run;
	return first;
}
/* Bookkeeping for one '[' whose matching ']' has not been seen yet. */
struct loop_node_st {
	/* Label placed just before the loop's cell test; ']' jumps back here. */
	struct sljit_label *loop_start;
	/* Conditional jump taken when the cell is zero; resolved at ']'. */
	struct sljit_jump *loop_end;
};

/* Compile-time stack of open loops; loop_sp counts the entries in use. */
static struct loop_node_st loop_stack[BF_LOOP_LEVEL];
static int loop_sp = 0;
static int loop_push(struct sljit_label *loop_start, struct sljit_jump *loop_end)
{
if (loop_sp >= BF_LOOP_LEVEL)
return -1;
loop_stack[loop_sp].loop_start = loop_start;
loop_stack[loop_sp].loop_end = loop_end;
loop_sp++;
return 0;
}
static int loop_pop(struct sljit_label **loop_start, struct sljit_jump **loop_end)
{
if (loop_sp <= 0)
return -1;
loop_sp--;
*loop_start = loop_stack[loop_sp].loop_start;
*loop_end = loop_stack[loop_sp].loop_end;
return 0;
}
/*
 * Allocation callback invoked from generated code: returns a zero-filled
 * block of nmemb * size bytes from calloc, or NULL on failure.
 */
static void *SLJIT_FUNC my_alloc(size_t nmemb, size_t size)
{
	void *block = calloc(nmemb, size);
	return block;
}
/* Output callback for '.': writes c (narrowed to int) to stdout; putchar's result is ignored. */
static void SLJIT_FUNC my_putchar(sljit_sw c)
{
	int ch = (int)c;
	(void)putchar(ch);
}
/* Input callback for ',': returns the next byte from stdin, or EOF. */
static sljit_sw SLJIT_FUNC my_getchar(void)
{
	int ch = getchar();
	return (sljit_sw)ch;
}
/* Release callback invoked from generated code to free the cell tape. */
static void SLJIT_FUNC my_free(void *mem)
{
	void *tape = mem;
	free(tape);
}
#define loop_empty() (loop_sp == 0)
static void *compile(FILE *src, sljit_uw *lcode)
{
void *code = NULL;
int chr;
int nchr;
struct sljit_compiler *C = sljit_create_compiler(NULL);
struct sljit_jump *end;
struct sljit_label *loop_start;
struct sljit_jump *loop_end;
int SP = SLJIT_S0;
int CELLS = SLJIT_S1;
sljit_emit_enter(C, 0, SLJIT_ARGS2V(W, W), 2, 2, 0);
sljit_emit_op2(C, SLJIT_XOR, SP, 0, SP, 0, SP, 0);
sljit_emit_op1(C, SLJIT_MOV, SLJIT_R0, 0, SLJIT_IMM, BF_CELL_SIZE);
sljit_emit_op1(C, SLJIT_MOV, SLJIT_R1, 0, SLJIT_IMM, 1);
sljit_emit_icall(C, SLJIT_CALL, SLJIT_ARGS2(P, W, W), SLJIT_IMM, SLJIT_FUNC_ADDR(my_alloc));
end = sljit_emit_cmp(C, SLJIT_EQUAL, SLJIT_R0, 0, SLJIT_IMM, 0);
sljit_emit_op1(C, SLJIT_MOV, CELLS, 0, SLJIT_R0, 0);
while ((chr = gettoken(src, &nchr)) != EOF) {
switch (chr) {
case '+':
case '-':
sljit_emit_op1(C, SLJIT_MOV_U8, SLJIT_R0, 0, SLJIT_MEM2(CELLS, SP), 0);
sljit_emit_op2(C, chr == '+' ? SLJIT_ADD : SLJIT_SUB,
SLJIT_R0, 0, SLJIT_R0, 0, SLJIT_IMM, nchr);
sljit_emit_op1(C, SLJIT_MOV_U8, SLJIT_MEM2(CELLS, SP), 0, SLJIT_R0, 0);
break;
case '>':
case '<':
sljit_emit_op2(C, chr == '>' ? SLJIT_ADD : SLJIT_SUB,
SP, 0, SP, 0, SLJIT_IMM, nchr);
break;
case '.':
sljit_emit_op1(C, SLJIT_MOV_U8, SLJIT_R0, 0, SLJIT_MEM2(CELLS, SP), 0);
sljit_emit_icall(C, SLJIT_CALL, SLJIT_ARGS1(W, W), SLJIT_IMM, SLJIT_FUNC_ADDR(my_putchar));
break;
case ',':
sljit_emit_icall(C, SLJIT_CALL, SLJIT_ARGS0(W), SLJIT_IMM, SLJIT_FUNC_ADDR(my_getchar));
sljit_emit_op1(C, SLJIT_MOV_U8, SLJIT_MEM2(CELLS, SP), 0, SLJIT_R0, 0);
break;
case '[':
loop_start = sljit_emit_label(C);
sljit_emit_op1(C, SLJIT_MOV_U8, SLJIT_R0, 0, SLJIT_MEM2(CELLS, SP), 0);
loop_end = sljit_emit_cmp(C, SLJIT_EQUAL, SLJIT_R0, 0, SLJIT_IMM, 0);
if (loop_push(loop_start, loop_end)) {
fprintf(stderr, "Too many loop level\n");
goto compile_failed;
}
break;
case ']':
if (loop_pop(&loop_start, &loop_end)) {
fprintf(stderr, "Unmatch loop ]\n");
goto compile_failed;
}
sljit_set_label(sljit_emit_jump(C, SLJIT_JUMP), loop_start);
sljit_set_label(loop_end, sljit_emit_label(C));
break;
}
}
if (!loop_empty()) {
fprintf(stderr, "Unmatch loop [\n");
goto compile_failed;
}
sljit_emit_op1(C, SLJIT_MOV, SLJIT_R0, 0, CELLS, 0);
sljit_emit_icall(C, SLJIT_CALL, SLJIT_ARGS1(P, P), SLJIT_IMM, SLJIT_FUNC_ADDR(my_free));
sljit_set_label(end, sljit_emit_label(C));
sljit_emit_return_void(C);
code = sljit_generate_code(C, 0, NULL);
if (lcode)
*lcode = sljit_get_generated_code_size(C);
compile_failed:
sljit_free_compiler(C);
return code;
}
/*
 * Type used by main() to call the code returned by compile(): no arguments,
 * no result. NOTE(review): the C callbacks in this file are declared with
 * SLJIT_FUNC but this pointer type is not; presumably harmless for a
 * zero-argument function — confirm against sljit's SLJIT_FUNC definition
 * for 32-bit targets.
 */
typedef void (*bf_entry_t)(void);
/*
 * Usage: <prog> <brainfuck program>
 *
 * Compiles the named file to native code, runs it once, and frees it.
 * Returns 0 on success and -1 on a usage, open, or compile error.
 */
int main(int argc, char **argv)
{
	const char *prog_name;
	void *code;
	bf_entry_t entry;
	FILE *fp;

	/* argv[0] is a null pointer when argc == 0; never hand that to "%s". */
	prog_name = (argc > 0 && argv[0] != NULL) ? argv[0] : "brainfuck";

	if (argc < 2) {
		fprintf(stderr, "Usage: %s <brainfuck program>\n", prog_name);
		return -1;
	}

	fp = fopen(argv[1], "rb");
	if (!fp) {
		perror("open");
		return -1;
	}

	/* The whole program is translated up front, so the file can be closed before running. */
	code = compile(fp, NULL);
	fclose(fp);

	if (!code) {
		fprintf(stderr, "[Fatal]: Compile failed\n");
		return -1;
	}

	entry = (bf_entry_t)code;
	entry();

	sljit_free_code(code, NULL);
	return 0;
}