use crate::{
Assembler, AssemblerData, IMM_REG, OFFSET, REGISTER_LIMIT,
float_slice::FloatSliceAssembler, mmap::Mmap, reg,
};
use dynasmrt::{DynasmApi, DynasmError, DynasmLabelApi, dynasm};
/// Number of `f32` lanes processed per loop iteration (one 128-bit NEON
/// register's worth; every vector op below uses the `.s4` arrangement)
pub const SIMD_WIDTH: usize = 4;
/// Bytes of frame reserved below the memory-slot spill area, holding the
/// saved FP/LR, `d8`-`d15`, the `q8`-`q31` scratch save area used by
/// `call_fn_unary`/`call_fn_binary` (0x50..0x1c0), and `x20`-`x24`
/// (0x200..0x228) — see `init` for the exact layout
const STACK_SIZE: u32 = 0x230;
#[expect(clippy::useless_conversion)]
impl Assembler for FloatSliceAssembler {
    /// Scalar element type; each instruction operates on 4 of these at once
    type Data = f32;

    /// Emits the function prologue and the top of the evaluation loop.
    ///
    /// Register convention (inferred from the uses below — confirm against
    /// the caller): `x0` = array of input pointers, `x1` = array of output
    /// pointers, `x2` = remaining element count, `x3` = running byte offset
    /// into every input/output buffer.
    ///
    /// Callee-saved state (`x29`/`x30`, `d8`-`d15`, `x20`-`x24`) is spilled
    /// to fixed offsets in the freshly reserved frame; the loop label `L` is
    /// defined here and the matching back-branch plus exit label `E` are
    /// emitted in `finalize`.
    fn init(mmap: Mmap, slot_count: usize) -> Self {
        let mut out = AssemblerData::new(mmap);
        // Reserve room for `slot_count` spill slots above the fixed header
        out.prepare_stack(slot_count, STACK_SIZE as usize);
        dynasm!(out.ops
            // Standard frame: save FP/LR, establish frame pointer
            ; stp x29, x30, [sp, 0x0]
            ; mov x29, sp
            // d8-d15 are callee-saved under AAPCS64
            ; stp d8, d9, [sp, 0x10]
            ; stp d10, d11, [sp, 0x20]
            ; stp d12, d13, [sp, 0x30]
            ; stp d14, d15, [sp, 0x40]
            // x20-x24 are used by call_fn_unary/call_fn_binary as scratch,
            // so preserve the caller's values here
            ; str x20, [sp, 0x200]
            ; str x21, [sp, 0x208]
            ; str x22, [sp, 0x210]
            ; str x23, [sp, 0x218]
            ; str x24, [sp, 0x220]
            // x3 = byte offset into each in/out buffer, starts at 0
            ; mov x3, 0
            // Top of the per-4-elements loop
            ; ->L:
            // Exit when no elements remain (E is defined in `finalize`)
            ; cmp x2, 0
            ; b.eq ->E
        );
        Self(out)
    }

    /// Rough per-clause machine-code size estimate (bytes), presumably used
    /// to pre-size the mmap — TODO confirm against the caller
    fn bytes_per_clause() -> usize {
        10
    }

    /// Loads a 128-bit value from memory slot `src_mem` into `dst_reg`.
    fn build_load(&mut self, dst_reg: u8, src_mem: u32) {
        assert!((dst_reg as usize) < REGISTER_LIMIT);
        // Slots live above the fixed STACK_SIZE header in the frame
        let sp_offset = self.0.stack_pos(src_mem) + STACK_SIZE;
        // Offset must fit the scaled load/store immediate encoding
        assert!(sp_offset < 65536);
        dynasm!(self.0.ops
            ; ldr Q(reg(dst_reg)), [sp, sp_offset]
        )
    }

    /// Stores `src_reg` (128 bits) into memory slot `dst_mem`.
    fn build_store(&mut self, dst_mem: u32, src_reg: u8) {
        assert!((src_reg as usize) < REGISTER_LIMIT);
        let sp_offset = self.0.stack_pos(dst_mem) + STACK_SIZE;
        assert!(sp_offset < 65536);
        dynasm!(self.0.ops
            ; str Q(reg(src_reg)), [sp, sp_offset]
        )
    }

    /// Loads 4 consecutive floats of input variable `src_arg` into `out_reg`:
    /// fetches the `src_arg`-th input pointer from the array at `x0`, then
    /// reads at the current loop offset `x3`.
    fn build_input(&mut self, out_reg: u8, src_arg: u32) {
        // Keep src_arg * 8 within the scaled immediate range of `ldr x`
        assert!(src_arg < 16384 / 8);
        dynasm!(self.0.ops
            ; ldr x4, [x0, src_arg * 8]
            ; add x4, x4, x3 ; ldr Q(reg(out_reg)), [x4]
        );
    }

    /// Writes 4 floats from `arg_reg` to output slot `out_index`: fetches the
    /// `out_index`-th output pointer from the array at `x1`, offset by `x3`.
    fn build_output(&mut self, arg_reg: u8, out_index: u32) {
        assert!(out_index < 16384 / 8);
        dynasm!(self.0.ops
            ; ldr x4, [x1, out_index * 8]
            ; add x4, x4, x3 ; str Q(reg(arg_reg)), [x4] );
    }

    // The transcendental ops below have no single-instruction NEON form, so
    // each trampolines out to a scalar Rust function once per lane via
    // `call_fn_unary` / `call_fn_binary`.

    /// Per-lane sine via a scalar callback
    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_sin(f: f32) -> f32 {
            f.sin()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_sin);
    }

    /// Per-lane cosine via a scalar callback
    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_cos(f: f32) -> f32 {
            f.cos()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_cos);
    }

    /// Per-lane tangent via a scalar callback
    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_tan(f: f32) -> f32 {
            f.tan()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_tan);
    }

    /// Per-lane arcsine via a scalar callback
    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_asin(f: f32) -> f32 {
            f.asin()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_asin);
    }

    /// Per-lane arccosine via a scalar callback
    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_acos(f: f32) -> f32 {
            f.acos()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_acos);
    }

    /// Per-lane arctangent via a scalar callback
    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_atan(f: f32) -> f32 {
            f.atan()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_atan);
    }

    /// Per-lane exponential via a scalar callback
    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_exp(f: f32) -> f32 {
            f.exp()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_exp);
    }

    /// Per-lane natural log via a scalar callback
    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_ln(f: f32) -> f32 {
            f.ln()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_ln);
    }

    /// Register-to-register copy (full 128 bits)
    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops ; mov V(reg(out_reg)).b16, V(reg(lhs_reg)).b16)
    }

    /// Lanewise negation
    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops ; fneg V(reg(out_reg)).s4, V(reg(lhs_reg)).s4)
    }

    /// Lanewise absolute value
    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops ; fabs V(reg(out_reg)).s4, V(reg(lhs_reg)).s4)
    }

    /// Lanewise reciprocal: splat 1.0 into scratch v7, then divide
    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            ; fmov s7, 1.0
            ; dup v7.s4, v7.s[0]
            ; fdiv V(reg(out_reg)).s4, v7.s4, V(reg(lhs_reg)).s4
        )
    }

    /// Lanewise square root
    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops ; fsqrt V(reg(out_reg)).s4, V(reg(lhs_reg)).s4)
    }

    /// Lanewise square (x * x)
    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            ; fmul V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
        )
    }

    /// Lanewise floor via round-toward-minus-infinity int conversion.
    ///
    /// `fcvtms` + `scvtf` would map NaN lanes to 0, so v6 is built as a
    /// NaN mask (`fcmeq x, x` is false only for NaN, then inverted) and
    /// OR'd in afterwards, forcing NaN lanes to all-ones (a NaN pattern).
    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
            ; mvn v6.b16, v6.b16
            ; fcvtms V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
            ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16
        );
    }

    /// Lanewise ceil; same NaN-preservation trick as `build_floor`, but
    /// converting with round-toward-plus-infinity (`fcvtps`)
    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
            ; mvn v6.b16, v6.b16
            ; fcvtps V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
            ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16
        );
    }

    /// Lanewise round; same NaN-preservation trick as `build_floor`, but
    /// converting with round-to-nearest-ties-away (`fcvtas`)
    fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmeq v6.s4, V(reg(lhs_reg)).s4, V(reg(lhs_reg)).s4
            ; mvn v6.b16, v6.b16
            ; fcvtas V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
            ; scvtf V(reg(out_reg)).s4, V(reg(out_reg)).s4
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v6.b16
        );
    }

    /// Lanewise addition
    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fadd V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Lanewise subtraction
    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fsub V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Lanewise multiplication
    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fmul V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Lanewise division
    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fdiv V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Per-lane atan2(lhs, rhs) via a scalar callback
    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        extern "C" fn float_atan2(y: f32, x: f32) -> f32 {
            y.atan2(x)
        }
        self.call_fn_binary(out_reg, lhs_reg, rhs_reg, float_atan2);
    }

    /// Lanewise maximum
    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fmax V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Lanewise minimum
    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fmin V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Lanewise modulo: out = lhs - floor(lhs / |rhs|) * |rhs|
    /// (`frintm` rounds toward minus infinity; v6/v7 are scratch)
    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fabs v6.s4, V(reg(rhs_reg)).s4
            ; fdiv v7.s4, V(reg(lhs_reg)).s4, v6.s4
            ; frintm v7.s4, v7.s4 ; fmul v7.s4, v7.s4, v6.s4
            ; fsub V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, v7.s4
        )
    }

    /// Lanewise logical NOT: 1.0 where the input lane's bit pattern is all
    /// zeros (i.e. +0.0 — integer `cmeq` against 0), 0.0 elsewhere
    fn build_not(&mut self, out_reg: u8, arg_reg: u8) {
        dynasm!(self.0.ops
            ; cmeq v6.s4, V(reg(arg_reg)).s4, 0
            ; fmov S(reg(out_reg)), 1.0
            ; dup V(reg(out_reg)).s4, V(reg(out_reg)).s[0]
            ; and V(reg(out_reg)).b16, V(reg(out_reg)).b16, v6.b16
        );
    }

    /// Lanewise short-circuit AND: out = lhs where lhs is zero, rhs
    /// elsewhere (v6 = "lhs == 0" mask, v7 = its complement)
    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; cmeq v6.s4, V(reg(lhs_reg)).s4, 0
            ; mvn v7.b16, v6.b16
            ; and v6.b16, v6.b16, V(reg(lhs_reg)).b16
            ; and v7.b16, v7.b16, V(reg(rhs_reg)).b16
            ; orr V(reg(out_reg)).b16, v6.b16, v7.b16
        );
    }

    /// Lanewise short-circuit OR: out = lhs where lhs is non-zero, rhs
    /// elsewhere (same masks as `build_and`, selection swapped)
    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; cmeq v6.s4, V(reg(lhs_reg)).s4, 0
            ; mvn v7.b16, v6.b16
            ; and v7.b16, v7.b16, V(reg(lhs_reg)).b16
            ; and v6.b16, v6.b16, V(reg(rhs_reg)).b16
            ; orr V(reg(out_reg)).b16, v6.b16, v7.b16
        );
    }

    /// Lanewise three-way comparison: -1.0 if lhs < rhs, 1.0 if lhs > rhs,
    /// 0.0 if equal, NaN if either operand is NaN.
    ///
    /// v6 ends up as the "either is NaN" mask, v4/v5 as the less/greater
    /// masks; each mask is AND'ed with a splatted constant and the results
    /// are OR'ed together (v4-v7 and w9 are scratch).
    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmeq v6.S4, V(reg(lhs_reg)).S4, V(reg(lhs_reg)).S4
            ; fcmeq v7.S4, V(reg(rhs_reg)).S4, V(reg(rhs_reg)).S4
            ; and v6.b16, v6.b16, v7.b16
            ; mvn v6.b16, v6.b16
            ; fcmgt v4.S4, V(reg(rhs_reg)).S4, V(reg(lhs_reg)).S4
            ; fcmgt v5.S4, V(reg(lhs_reg)).S4, V(reg(rhs_reg)).S4
            ; fmov s7, -1.0
            ; dup v7.s4, v7.s[0]
            ; and V(reg(out_reg)).B16, v4.B16, v7.B16
            ; fmov s7, 1.0
            ; dup v7.s4, v7.s[0]
            ; and v5.B16, v5.B16, v7.B16
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v5.B16
            ; mov w9, f32::NAN.to_bits()
            ; dup v7.s4, w9
            ; and v7.b16, v7.b16, v6.b16
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v7.b16
        )
    }

    /// Splats the constant `imm` across all lanes of the reserved immediate
    /// register, choosing the shortest movz/movk sequence for the bit
    /// pattern, and returns that register's index adjusted by `OFFSET` to
    /// match the numbering expected by `reg(..)`.
    fn load_imm(&mut self, imm: f32) -> u8 {
        let imm_u32 = imm.to_bits();
        if imm_u32 & 0xFFFF == 0 {
            // Only the high halfword is non-zero: one movz with shift
            dynasm!(self.0.ops
                ; movz w9, imm_u32 >> 16, lsl 16
                ; dup V(IMM_REG).s4, w9
            );
        } else if imm_u32 & 0xFFFF_0000 == 0 {
            // Only the low halfword is non-zero: one plain movz
            dynasm!(self.0.ops
                ; movz w9, imm_u32 & 0xFFFF
                ; dup V(IMM_REG).s4, w9
            );
        } else {
            // General case: movz high halfword, movk low halfword
            dynasm!(self.0.ops
                ; movz w9, imm_u32 >> 16, lsl 16
                ; movk w9, imm_u32 & 0xFFFF
                ; dup V(IMM_REG).s4, w9
            );
        }
        IMM_REG.wrapping_sub(OFFSET)
    }

    /// Emits the bottom of the loop and the epilogue, then hands off to the
    /// inner `AssemblerData::finalize`.
    ///
    /// The loop bookkeeping consumes SIMD_WIDTH (4) elements per iteration:
    /// `x2` (count) drops by 4 and `x3` (byte offset) advances by 16
    /// (4 lanes x 4 bytes). The epilogue restores everything spilled in
    /// `init`; stack teardown and `ret` are presumably emitted by
    /// `self.0.finalize()` — not visible here.
    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        dynasm!(self.0.ops
            ; sub x2, x2, 4
            ; add x3, x3, 16
            ; b ->L
            ; ->E:
            ; ldp x29, x30, [sp, 0x0]
            ; ldp d8, d9, [sp, 0x10]
            ; ldp d10, d11, [sp, 0x20]
            ; ldp d12, d13, [sp, 0x30]
            ; ldp d14, d15, [sp, 0x40]
            ; ldr x20, [sp, 0x200]
            ; ldr x21, [sp, 0x208]
            ; ldr x22, [sp, 0x210]
            ; ldr x23, [sp, 0x218]
            ; ldr x24, [sp, 0x220]
        );
        self.0.finalize()
    }
}
#[expect(clippy::useless_conversion)]
impl FloatSliceAssembler {
    /// Emits a trampoline that applies the scalar function `f` to each of
    /// the four lanes of `arg_reg`, writing results to `out_reg`.
    ///
    /// Protocol:
    /// - `x0`-`x3` (the evaluator's live state) are parked in callee-saved
    ///   `x20`-`x23` (these were spilled to the frame in `init`), so they
    ///   survive the calls.
    /// - `q8`-`q31` are spilled to the fixed save area at `[sp, 0x50..]`,
    ///   since calls may clobber SIMD registers.
    /// - The argument lanes are staged in v8/v9: under AAPCS64 only the low
    ///   64 bits (`d8`/`d9`) of those registers are callee-saved, and the
    ///   staging uses exactly `v8.s[0..1]` / `v9.s[0..1]`, so pending lane
    ///   values survive each `blr`.
    /// - `f`'s absolute address is built in `x24` with movz/movk.
    fn call_fn_unary(
        &mut self,
        out_reg: u8,
        arg_reg: u8,
        f: extern "C" fn(f32) -> f32,
    ) {
        let addr = f as usize;
        dynasm!(self.0.ops
            // Park x0-x3 in callee-saved registers
            ; mov x20, x0
            ; mov x21, x1
            ; mov x22, x2
            ; mov x23, x3
            // Spill all SIMD registers that could hold live tape values
            ; stp q8, q9, [sp, 0x50]
            ; stp q10, q11, [sp, 0x70]
            ; stp q12, q13, [sp, 0x90]
            ; stp q14, q15, [sp, 0xb0]
            ; stp q16, q17, [sp, 0xd0]
            ; stp q18, q19, [sp, 0xf0]
            ; stp q20, q21, [sp, 0x110]
            ; stp q22, q23, [sp, 0x130]
            ; stp q24, q25, [sp, 0x150]
            ; stp q26, q27, [sp, 0x170]
            ; stp q28, q29, [sp, 0x190]
            ; stp q30, q31, [sp, 0x1b0]
            // Materialize the 64-bit callback address in x24
            ; movz x24, (addr >> 48) as u32 & 0xFFFF, lsl 48
            ; movk x24, (addr >> 32) as u32 & 0xFFFF, lsl 32
            ; movk x24, (addr >> 16) as u32 & 0xFFFF, lsl 16
            ; movk x24, addr as u32 & 0xFFFF
            // Stage lanes: v8 = lanes 0-1, v9 = lanes 2-3 (low 64 bits only)
            ; mov v0.b16, V(reg(arg_reg)).b16
            ; mov d8, v0.d[0]
            ; mov d9, v0.d[1]
            // Call f on each lane in turn; s0 is both argument and result
            ; mov s0, v8.s[0]
            ; blr x24
            ; mov v8.s[0], v0.s[0]
            ; mov s0, v8.s[1]
            ; blr x24
            ; mov v8.s[1], v0.s[0]
            ; mov s0, v9.s[0]
            ; blr x24
            ; mov v9.s[0], v0.s[0]
            ; mov s0, v9.s[1]
            ; blr x24
            ; mov v9.s[1], v0.s[0]
            // Reassemble the 4 results into v0
            ; mov v0.d[0], v8.d[0]
            ; mov v0.d[1], v9.d[0]
            // Restore the spilled SIMD registers
            ; ldp q8, q9, [sp, 0x50]
            ; ldp q10, q11, [sp, 0x70]
            ; ldp q12, q13, [sp, 0x90]
            ; ldp q14, q15, [sp, 0xb0]
            ; ldp q16, q17, [sp, 0xd0]
            ; ldp q18, q19, [sp, 0xf0]
            ; ldp q20, q21, [sp, 0x110]
            ; ldp q22, q23, [sp, 0x130]
            ; ldp q24, q25, [sp, 0x150]
            ; ldp q26, q27, [sp, 0x170]
            ; ldp q28, q29, [sp, 0x190]
            ; ldp q30, q31, [sp, 0x1b0]
            ; mov V(reg(out_reg)).b16, v0.b16
            // Restore the evaluator's x0-x3
            ; mov x0, x20
            ; mov x1, x21
            ; mov x2, x22
            ; mov x3, x23
        );
    }

    /// Binary-operator variant of `call_fn_unary`: applies the scalar
    /// function `f(lhs, rhs)` lane by lane.
    ///
    /// Same save/restore protocol; lhs lanes are staged in v8/v9 and rhs
    /// lanes in v10/v11 (again only the callee-saved low 64 bits are used,
    /// so values survive each `blr`).
    fn call_fn_binary(
        &mut self,
        out_reg: u8,
        lhs_reg: u8,
        rhs_reg: u8,
        f: extern "C" fn(f32, f32) -> f32,
    ) {
        let addr = f as usize;
        dynasm!(self.0.ops
            // Park x0-x3 in callee-saved registers
            ; mov x20, x0
            ; mov x21, x1
            ; mov x22, x2
            ; mov x23, x3
            // Spill all SIMD registers that could hold live tape values
            ; stp q8, q9, [sp, 0x50]
            ; stp q10, q11, [sp, 0x70]
            ; stp q12, q13, [sp, 0x90]
            ; stp q14, q15, [sp, 0xb0]
            ; stp q16, q17, [sp, 0xd0]
            ; stp q18, q19, [sp, 0xf0]
            ; stp q20, q21, [sp, 0x110]
            ; stp q22, q23, [sp, 0x130]
            ; stp q24, q25, [sp, 0x150]
            ; stp q26, q27, [sp, 0x170]
            ; stp q28, q29, [sp, 0x190]
            ; stp q30, q31, [sp, 0x1b0]
            // Materialize the 64-bit callback address in x24
            ; movz x24, (addr >> 48) as u32 & 0xFFFF, lsl 48
            ; movk x24, (addr >> 32) as u32 & 0xFFFF, lsl 32
            ; movk x24, (addr >> 16) as u32 & 0xFFFF, lsl 16
            ; movk x24, addr as u32 & 0xFFFF
            // Stage lanes: v8/v9 = lhs, v10/v11 = rhs (low 64 bits only)
            ; mov v0.b16, V(reg(lhs_reg)).b16
            ; mov v1.b16, V(reg(rhs_reg)).b16
            ; mov d8, v0.d[0]
            ; mov d9, v0.d[1]
            ; mov d10, v1.d[0]
            ; mov d11, v1.d[1]
            // Call f(lane_lhs, lane_rhs) per lane; s0/s1 args, s0 result
            ; mov s0, v8.s[0]
            ; mov s1, v10.s[0]
            ; blr x24
            ; mov v8.s[0], v0.s[0]
            ; mov s0, v8.s[1]
            ; mov s1, v10.s[1]
            ; blr x24
            ; mov v8.s[1], v0.s[0]
            ; mov s0, v9.s[0]
            ; mov s1, v11.s[0]
            ; blr x24
            ; mov v9.s[0], v0.s[0]
            ; mov s0, v9.s[1]
            ; mov s1, v11.s[1]
            ; blr x24
            ; mov v9.s[1], v0.s[0]
            // Reassemble the 4 results into v0
            ; mov v0.d[0], v8.d[0]
            ; mov v0.d[1], v9.d[0]
            // Restore the spilled SIMD registers
            ; ldp q8, q9, [sp, 0x50]
            ; ldp q10, q11, [sp, 0x70]
            ; ldp q12, q13, [sp, 0x90]
            ; ldp q14, q15, [sp, 0xb0]
            ; ldp q16, q17, [sp, 0xd0]
            ; ldp q18, q19, [sp, 0xf0]
            ; ldp q20, q21, [sp, 0x110]
            ; ldp q22, q23, [sp, 0x130]
            ; ldp q24, q25, [sp, 0x150]
            ; ldp q26, q27, [sp, 0x170]
            ; ldp q28, q29, [sp, 0x190]
            ; ldp q30, q31, [sp, 0x1b0]
            ; mov V(reg(out_reg)).b16, v0.b16
            // Restore the evaluator's x0-x3
            ; mov x0, x20
            ; mov x1, x21
            ; mov x2, x22
            ; mov x3, x23
        );
    }
}