namespace asm_code {
// Emits x86-64 assembly that divides `a` by `b`, leaving the quotient in `q`
// and the remainder in RDX (`r` is required to be RDX).
//
// Fast path: the quotient is estimated by multiplying `a` with a table entry
// looked up from the top bits of `b`, then shifting. The estimate is verified
// by recomputing the remainder; if any check fails, the code falls back to a
// real IDIV and adjusts the result so the remainder is non-negative.
//
// The first call also emits the `divide_table` data itself (guarded by a
// function-local static, so it is written only once per process).
//
// RAX and RDX are overwritten by the generated code.
// NOTE(review): the fast path feeds `b` to BSR and an arithmetic shift, so it
// presumably assumes `b` is a positive 64-bit value -- confirm with callers.
void divide_table(reg_alloc regs, reg_scalar a, reg_scalar b, reg_scalar q, reg_scalar r) {
    EXPAND_MACROS_SCOPE;

    // Request RAX from the allocator (presumably reserving it); it is used as
    // scratch throughout the generated code.
    regs.get_scalar(reg_rax);

    m.bind(a, "a");
    m.bind(b, "b");
    m.bind(q, "q");

    // One-operand IMUL and IDIV write RDX, so the remainder register must be RDX.
    assert(r.value==reg_rdx.value);

    static bool outputted_table=false;
    if (!outputted_table) {
        // Switch sections to emit the table, then switch back with ".text" below.
#ifdef CHIAOSX
        APPEND_M(str( ".text " ));
#else
        APPEND_M(str( ".text 1" ));
#endif
        APPEND_M(str( ".balign 64" ));
        APPEND_M(str( "divide_table:" ));

        const int expected_size=1<<divide_table_index_bits;
        const int max_index=bit_sequence(0, divide_table_index_bits);
        assert(max_index>=1);

        int num=0;
        auto add=[&](uint64 v) {
            APPEND_M(str( ".quad #", to_hex(v) ));
            ++num;
        };

        // Entry 0 is a placeholder; entry i (i>=1) is the high 64 bits of
        // floor((2^128-1)/i).
        add(0);
        for (int index=1;index<=max_index;++index) {
            uint128 v = (~uint128(0)) / uint128(index);
            v>>=64;
            add(v);
        }

        assert(num==expected_size);

        APPEND_M(str( ".text" ));
        outputted_table=true;
    }

    // q = (index of highest set bit of b) - (divide_table_index_bits-1).
    // If the subtraction borrows, the shift amount is clamped to 0.
    string b_shift_label=m.alloc_label();
    APPEND_M(str( "BSR `q, `b" ));
    APPEND_M(str( "SUB `q, #", to_hex(divide_table_index_bits-1) ));
    APPEND_M(str( "JNB #", b_shift_label ));
    APPEND_M(str( "XOR `q, `q" ));
    APPEND_M(str( "#:", b_shift_label ));

    // RAX = b >> q: the table index.
    APPEND_M(str( "SARX RAX, `b, `q" ));

    // RAX = divide_table[RAX].
    // Preprocessor directives must start their own line; placing them after a
    // statement on the same line does not compile.
#ifdef CHIAOSX
    // NOTE(review): RCX is overwritten here, but only RAX is requested from
    // `regs` in this function -- confirm RCX is free on this path.
    APPEND_M(str( "LEA RCX, [RIP+divide_table]" ));
    APPEND_M(str( "MOV RAX, [RCX+RAX*8]"));
#else
    APPEND_M(str( "MOV RAX, [divide_table+RAX*8]" ));
#endif

    // RDX:RAX = a * table entry (signed); the high half shifted right by q is
    // the quotient estimate.
    APPEND_M(str( "IMUL `a" ));
    APPEND_M(str( "SARX `q, RDX, `q" ));

    // Verify the estimate: RDX = a - q*b must not overflow and must be below b
    // (unsigned compare, so a negative remainder also fails the check).
    string wrong_remainder_label=m.alloc_label();
    APPEND_M(str( "MOV RAX, `q" ));
    APPEND_M(str( "IMUL RAX, `b" ));
    APPEND_M(str( "JO #", wrong_remainder_label ));
    APPEND_M(str( "MOV RDX, `a" ));
    APPEND_M(str( "SUB RDX, RAX" ));
    APPEND_M(str( "JO #", wrong_remainder_label ));
    APPEND_M(str( "CMP RDX, `b" ));
    APPEND_M(str( "JAE #", wrong_remainder_label ));

    string end_label=m.alloc_label();
    APPEND_M(str( "JMP #", end_label ));

    const bool asm_output_common_case_only=false;
    if (!asm_output_common_case_only) {
        // Slow path: sign-extend a into RDX:RAX and do a real signed division.
        APPEND_M(str( "#:", wrong_remainder_label ));
        APPEND_M(str( "MOV RDX, `a" ));
        APPEND_M(str( "SAR RDX, #", to_hex(63) ));
        APPEND_M(str( "MOV RAX, `a" ));
        APPEND_M(str( "IDIV `b" ));
        APPEND_M(str( "MOV `q, RAX" ));

        // If the remainder is negative, add b to it and decrement the quotient.
        APPEND_M(str( "CMP RDX, 0" ));
        APPEND_M(str( "JGE #", end_label ));
        APPEND_M(str( "ADD RDX, `b" ));
        APPEND_M(str( "DEC `q" ));
    }

    APPEND_M(str( "#:", end_label ));
}
// Two-element 64-bit masks.
// gcd_mask_approximate has only bit 63 set in each element; gcd_mask_exact is
// all zero.
// NOTE(review): presumably these are loaded into the `c_gcd_mask` operand of
// gcd_64_iteration -- confirm at the call sites.
const array<uint64, 2> gcd_mask_approximate={0x8000000000000000ull, 0x8000000000000000ull};
const array<uint64, 2> gcd_mask_exact={0ull, 0ull};
// Emits the assembly for one step of a 64-bit GCD loop.
//
//   a              - scalar registers bound as a_0 and a_1
//   uv             - vector registers bound as uv_0 and uv_1
//   c_gcd_mask     - vector register used as the VPTEST mask in the final check
//   ab_threshold   - scalar register; the step is skipped when a_1 <= ab_threshold
//   early_exit_label - label jumped to whenever a check fails
//
// On the non-exit path the generated code performs:
//   q = a_0 / a_1, r = a_0 % a_1 (via divide_table)
//   a_0 = a_1, a_1 = r
//   new_uv_1 = uv_0 - q*uv_1; uv_0 = uv_1; uv_1 = new_uv_1
void gcd_64_iteration(
    reg_alloc regs, reg_vector c_gcd_mask, array<reg_scalar, 2> a, array<reg_vector, 2> uv, reg_scalar ab_threshold,
    string early_exit_label
) {
    EXPAND_MACROS_SCOPE;

    m.bind(c_gcd_mask, "c_gcd_mask");
    m.bind(a, "a");
    m.bind(uv, "uv");
    m.bind(ab_threshold, "ab_threshold");

    // Scratch registers. The remainder is pinned to RDX, as divide_table requires.
    reg_scalar q_reg=regs.bind_scalar(m, "q");
    reg_scalar r_reg=regs.bind_scalar(m, "r", reg_rdx);
    reg_scalar tmp_a_reg=regs.bind_scalar(m, "tmp_a");
    reg_vector new_uv_1_reg=regs.bind_vector(m, "new_uv_1");
    reg_vector tmp_1_reg=regs.bind_vector(m, "tmp_1");
    reg_vector tmp_2_reg=regs.bind_vector(m, "tmp_2");

    // Exit if a_1 <= ab_threshold (unsigned compare).
    APPEND_M(str( "CMP `a_1, `ab_threshold" ));
    APPEND_M(str( "JBE #", early_exit_label ));

    // q = a_0 / a_1; r = a_0 % a_1.
    divide_table(regs, a[0], a[1], q_reg, r_reg);

    // Exit unless q survives a shift-left / arithmetic-shift-right round trip,
    // i.e. q fits in the number of bits implied by gcd_num_quotient_bits.
    const auto quotient_shift=63-gcd_num_quotient_bits;
    APPEND_M(str( "MOV `tmp_a, `q" ));
    APPEND_M(str( "SHL `tmp_a, #", to_hex(quotient_shift) ));
    APPEND_M(str( "SAR `tmp_a, #", to_hex(quotient_shift) ));
    APPEND_M(str( "CMP `tmp_a, `q" ));
    APPEND_M(str( "JNE #", early_exit_label ));

    // a_0 = a_1; a_1 = r.
    APPEND_M(str( "MOV `a_0, `a_1" ));
    APPEND_M(str( "MOV `a_1, `r" ));

    // new_uv_1 = uv_0 - q*uv_1, with q broadcast to every 64-bit lane.
    APPEND_M(str( "VMOVQ `new_uv_1_128, `q" ));
    APPEND_M(str( "VPBROADCASTQ `new_uv_1, `new_uv_1_128" ));
    APPEND_M(str( "VPMULDQ `new_uv_1, `new_uv_1, `uv_1" ));
    APPEND_M(str( "VPSUBQ `new_uv_1, `uv_0, `new_uv_1" ));

    // Range check on new_uv_1: add 1<<data_size to each lane and exit if any
    // bit of carry_mask other than bit data_size is set in the result.
    const auto data_size_bit=1ull<<data_size;
    const auto carry_check_mask=carry_mask & (~data_size_bit);
    APPEND_M(str( "VPADDQ `tmp_1, `new_uv_1, #", constant_address_uint64(data_size_bit, data_size_bit) ));
    APPEND_M(str( "VPTEST `tmp_1, #", constant_address_uint64(carry_check_mask, carry_check_mask) ));
    APPEND_M(str( "JNZ #", early_exit_label ));

    {
        // tmp_1 = a_0 + uv_1 (a_0 broadcast to every lane).
        APPEND_M(str( "VMOVQ `tmp_1_128, `a_0" ));
        APPEND_M(str( "VPBROADCASTQ `tmp_1, `tmp_1_128" ));
        APPEND_M(str( "VPADDQ `tmp_1, `tmp_1, `uv_1" ));

        // tmp_2 = a_1 + new_uv_1 (a_1 broadcast to every lane).
        APPEND_M(str( "VMOVQ `tmp_2_128, `a_1" ));
        APPEND_M(str( "VPBROADCASTQ `tmp_2, `tmp_2_128" ));
        APPEND_M(str( "VPADDQ `tmp_2, `tmp_2, `new_uv_1" ));

        // Exit if ((tmp_1 - tmp_2) | tmp_2) has any bit in common with c_gcd_mask.
        APPEND_M(str( "VPSUBQ `tmp_1, `tmp_1, `tmp_2" ));
        APPEND_M(str( "VPOR `tmp_1, `tmp_1, `tmp_2" ));
        APPEND_M(str( "VPTEST `tmp_1, `c_gcd_mask" ));
        APPEND_M(str( "JNZ #", early_exit_label ));
    }

    // Commit the update: uv_0 = uv_1; uv_1 = new_uv_1.
    APPEND_M(str( "VMOVDQU `uv_0, `uv_1" ));
    APPEND_M(str( "VMOVDQU `uv_1, `new_uv_1" ));
}
}