use crate::{
Assembler, AssemblerData, IMM_REG, OFFSET, REGISTER_LIMIT,
grad_slice::GradSliceAssembler, mmap::Mmap, reg,
};
use dynasmrt::{DynasmApi, DynasmError, DynasmLabelApi, dynasm};
use fidget_core::types::Grad;
/// Bytes of fixed stack frame laid out in `init`:
/// - `[sp, 0x000]`: saved x29 / x30
/// - `[sp, 0x010..=0x04f]`: saved d8-d15
/// - `[sp, 0x050..=0x1ff]`: q-register spill area used by
///   `call_fn_unary` / `call_fn_binary`
/// - `[sp, 0x200..=0x21f]`: saved x20-x23
///
/// Spill slots for values evicted from registers live *above* this
/// region (see the `+ STACK_SIZE` bias in `build_load`/`build_store`).
const STACK_SIZE: u32 = 0x220;
#[expect(clippy::useless_conversion)]
impl Assembler for GradSliceAssembler {
    // A `Grad` is handled here as the four f32 lanes of one 128-bit
    // SIMD register: lane 0 holds the value (scalar `S(..)` ops act on
    // it), lanes 1-3 hold the partial derivatives.
    type Data = Grad;

    /// Emits the function prologue and opens the per-element loop.
    ///
    /// Register conventions established here and relied on by every
    /// other `build_*` method:
    /// - `x0`: array of input slice pointers (indexed in `build_input`)
    /// - `x1`: array of output slice pointers (indexed in `build_output`)
    /// - `x2`: remaining element count (tested here, decremented in
    ///   `finalize`)
    /// - `x3`: byte offset of the current element within each slice
    ///   (starts at 0; `finalize` advances it by 16 = one `Grad`)
    ///
    /// The fixed frame layout is documented on `STACK_SIZE`.
    fn init(mmap: Mmap, slot_count: usize) -> Self {
        let mut out = AssemblerData::new(mmap);
        // Fixed frame plus one spill slot per evicted value above it.
        out.prepare_stack(slot_count, STACK_SIZE as usize);
        dynasm!(out.ops
            // Save frame pointer / link register and set up the frame
            ; stp x29, x30, [sp, 0x0]
            ; mov x29, sp
            // Preserve d8-d15 (restored in `finalize`)
            ; stp d8, d9, [sp, 0x10]
            ; stp d10, d11, [sp, 0x20]
            ; stp d12, d13, [sp, 0x30]
            ; stp d14, d15, [sp, 0x40]
            // Preserve x20-x23, which `call_fn_*` uses to stash x0-x3
            ; str x20, [sp, 0x200]
            ; str x21, [sp, 0x208]
            ; str x22, [sp, 0x210]
            ; str x23, [sp, 0x218]
            // Start at byte offset 0 into every slice
            ; mov x3, 0
            // Loop head; `finalize` branches back here
            ; ->L:
            // Exit once the remaining count reaches zero
            ; cmp x2, 0
            ; b.eq ->E
        );
        Self(out)
    }

    /// Conservative estimate of machine-code bytes emitted per clause
    /// (presumably used by the caller to size the mmap — the consumer
    /// is not visible from this file).
    fn bytes_per_clause() -> usize {
        20
    }

    /// Loads a spilled value from its stack slot into `dst_reg`.
    fn build_load(&mut self, dst_reg: u8, src_mem: u32) {
        assert!((dst_reg as usize) < REGISTER_LIMIT);
        // Spill slots live above the fixed frame, hence the bias.
        let sp_offset = self.0.stack_pos(src_mem) + STACK_SIZE;
        // Keep the offset within the load's immediate field — TODO
        // confirm this bound matches the exact `ldr Q` encoding range.
        assert!(sp_offset < 65536);
        dynasm!(self.0.ops
            ; ldr Q(reg(dst_reg)), [sp, sp_offset]
        )
    }

    /// Stores `src_reg` to its spill slot on the stack.
    fn build_store(&mut self, dst_mem: u32, src_reg: u8) {
        assert!((src_reg as usize) < REGISTER_LIMIT);
        let sp_offset = self.0.stack_pos(dst_mem) + STACK_SIZE;
        assert!(sp_offset < 65536);
        dynasm!(self.0.ops
            ; str Q(reg(src_reg)), [sp, sp_offset]
        )
    }

    /// Loads the current element of input slice `src_arg` into
    /// `out_reg`.
    fn build_input(&mut self, out_reg: u8, src_arg: u32) {
        // Keep the pointer-table index within the scaled immediate of
        // the `ldr x4, [x0, imm]` below.
        assert!(src_arg < 16384 / 8);
        dynasm!(self.0.ops
            // x4 = base pointer of the src_arg'th input slice
            ; ldr x4, [x0, src_arg * 8]
            // Advance to the current element
            ; add x4, x4, x3
            // NOTE(review): this zeroing looks redundant — the `ldr Q`
            // below overwrites all 128 bits of out_reg; confirm.
            ; eor V(reg(out_reg)).b16, V(reg(out_reg)).b16, V(reg(out_reg)).b16
            ; ldr Q(reg(out_reg)), [x4]
        );
    }

    /// Writes `arg_reg` to the current element of output slice
    /// `out_index`.
    fn build_output(&mut self, arg_reg: u8, out_index: u32) {
        assert!(out_index < 16384 / 8);
        dynasm!(self.0.ops
            // x4 = base pointer of the out_index'th output slice
            ; ldr x4, [x1, out_index * 8]
            ; add x4, x4, x3
            ; str Q(reg(arg_reg)), [x4]
        );
    }

    /// Sine: delegates to an out-of-line Rust helper.
    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn grad_sin(v: Grad) -> Grad {
            v.sin()
        }
        self.call_fn_unary(out_reg, lhs_reg, grad_sin);
    }

    /// Cosine: delegates to an out-of-line Rust helper.
    // NOTE(review): the inner helper naming is inconsistent
    // (`float_*` here and below vs `grad_*` in `build_sin`,
    // `build_floor`, etc.) — cosmetic only.
    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_cos(f: Grad) -> Grad {
            f.cos()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_cos);
    }

    /// Tangent: delegates to an out-of-line Rust helper.
    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_tan(f: Grad) -> Grad {
            f.tan()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_tan);
    }

    /// Arcsine: delegates to an out-of-line Rust helper.
    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_asin(f: Grad) -> Grad {
            f.asin()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_asin);
    }

    /// Arccosine: delegates to an out-of-line Rust helper.
    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_acos(f: Grad) -> Grad {
            f.acos()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_acos);
    }

    /// Arctangent: delegates to an out-of-line Rust helper.
    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_atan(f: Grad) -> Grad {
            f.atan()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_atan);
    }

    /// Exponential: delegates to an out-of-line Rust helper.
    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_exp(f: Grad) -> Grad {
            f.exp()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_exp);
    }

    /// Natural logarithm: delegates to an out-of-line Rust helper.
    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn float_ln(f: Grad) -> Grad {
            f.ln()
        }
        self.call_fn_unary(out_reg, lhs_reg, float_ln);
    }

    /// Copies all four lanes (value + gradients) unchanged.
    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops ; mov V(reg(out_reg)).b16, V(reg(lhs_reg)).b16)
    }

    /// Negation: a lane-wise `fneg` flips the value and, by linearity,
    /// every partial derivative.
    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops ; fneg V(reg(out_reg)).s4, V(reg(lhs_reg)).s4)
    }

    /// Absolute value: if the value lane is < 0, negate all lanes
    /// (flipping the gradients too); otherwise copy the operand.
    /// Branch targets are raw byte offsets:
    ///   `b.lt 12` skips the copy + jump and lands on the `fneg`;
    ///   `b 8` skips the `fneg` and lands after the sequence.
    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmp S(reg(lhs_reg)), 0.0
            ; b.lt 12
            ; mov V(reg(out_reg)).b16, V(reg(lhs_reg)).b16
            ; b 8
            ; fneg V(reg(out_reg)).s4, V(reg(lhs_reg)).s4
        )
    }

    /// Reciprocal: value lane becomes 1/v; gradient lanes become
    /// d(1/v) = -g / v^2.
    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            // v6 = -v^2, broadcast to all lanes
            ; fmul s6, S(reg(lhs_reg)), S(reg(lhs_reg))
            ; fneg s6, s6
            ; dup v6.s4, v6.s[0]
            // v7 = lanes / (-v^2); lane 0 is garbage, patched below
            ; fdiv v7.s4, V(reg(lhs_reg)).s4, v6.s4
            // s6 = 1/v
            ; fmov s6, 1.0
            ; fdiv s6, s6, S(reg(lhs_reg))
            ; mov V(reg(out_reg)).b16, v7.b16
            // Patch the value lane with 1/v
            ; mov V(reg(out_reg)).s[0], v6.s[0]
        )
    }

    /// Square root: value lane becomes sqrt(v); gradient lanes become
    /// g / (2 * sqrt(v)).
    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            // s6 = sqrt(v)
            ; fsqrt s6, S(reg(lhs_reg))
            // v7 = 2 * sqrt(v), broadcast to all lanes
            ; fmov s7, 2.0
            ; fmul s7, s6, s7
            ; dup v7.s4, v7.s[0]
            // out = lanes / (2 sqrt(v)); lane 0 patched below
            ; fdiv V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, v7.s4
            ; mov V(reg(out_reg)).S[0], v6.S[0]
        )
    }

    /// Square: value lane becomes v^2; gradient lanes become 2*v*g.
    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        dynasm!(self.0.ops
            // v7 = [1, 2, 2, 2]: per-lane scale (the value is
            // multiplied once; gradients carry the chain-rule 2)
            ; fmov s7, 2.0
            ; dup v7.s4, v7.s[0]
            ; fmov s6, 1.0
            ; mov v7.S[0], v6.S[0]
            // v6 = v broadcast to all lanes
            ; fmov w9, S(reg(lhs_reg))
            ; dup v6.s4, w9
            // out = [v, g1, g2, g3] * v * [1, 2, 2, 2]
            ; fmul V(reg(out_reg)).s4, v6.s4, V(reg(lhs_reg)).s4
            ; fmul V(reg(out_reg)).s4, v7.s4, V(reg(out_reg)).s4
        )
    }

    /// Floor: delegates to an out-of-line Rust helper.
    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn grad_floor(v: Grad) -> Grad {
            v.floor()
        }
        self.call_fn_unary(out_reg, lhs_reg, grad_floor);
    }

    /// Ceiling: delegates to an out-of-line Rust helper.
    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn grad_ceil(v: Grad) -> Grad {
            v.ceil()
        }
        self.call_fn_unary(out_reg, lhs_reg, grad_ceil);
    }

    /// Rounding: delegates to an out-of-line Rust helper.
    fn build_round(&mut self, out_reg: u8, lhs_reg: u8) {
        extern "C" fn grad_round(v: Grad) -> Grad {
            v.round()
        }
        self.call_fn_unary(out_reg, lhs_reg, grad_round);
    }

    /// Addition is lane-wise: values add and gradients add.
    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fadd V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Subtraction is lane-wise: values and gradients subtract.
    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fsub V(reg(out_reg)).s4, V(reg(lhs_reg)).s4, V(reg(rhs_reg)).s4
        )
    }

    /// Product rule: value lane l*r; gradient lanes l*rg + r*lg.
    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            // v5 = l (broadcast) * rhs = [l*r, l*rg1, l*rg2, l*rg3]
            ; dup v6.s4, V(reg(lhs_reg)).s[0]
            ; fmul v5.s4, v6.s4, V(reg(rhs_reg)).s4
            // Save l*r before the fmla below doubles lane 0
            ; fmov s7, s5
            // v5 += r (broadcast) * lhs -> gradient lanes complete
            ; dup v6.s4, V(reg(rhs_reg)).s[0]
            ; fmla v5.s4, v6.s4, V(reg(lhs_reg)).s4
            ; mov V(reg(out_reg)).b16, v5.b16
            // Restore the correct value lane (l*r, not 2*l*r)
            ; mov V(reg(out_reg)).s[0], v7.s[0]
        )
    }

    /// Quotient rule: value lane l/r; gradient lanes
    /// (r*lg - l*rg) / r^2.
    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            // v5 = r (broadcast) * lhs
            ; fmov w9, S(reg(rhs_reg))
            ; dup v6.s4, w9
            ; fmul v5.s4, v6.s4, V(reg(lhs_reg)).s4
            // v5 -= l (broadcast) * rhs
            ; fmov w9, S(reg(lhs_reg))
            ; dup v6.s4, w9
            ; fmls v5.s4, v6.s4, V(reg(rhs_reg)).s4
            // v5 /= r^2 (broadcast); lane 0 is garbage, patched below
            ; fmul s6, S(reg(rhs_reg)), S(reg(rhs_reg))
            ; fmov w9, s6
            ; dup v6.s4, w9
            ; fdiv v5.s4, v5.s4, v6.s4
            // s6 = l/r, the value lane
            ; fdiv s6, S(reg(lhs_reg)), S(reg(rhs_reg))
            ; mov V(reg(out_reg)).b16, v5.b16
            ; mov V(reg(out_reg)).s[0], v6.s[0]
        )
    }

    /// Arctangent of two arguments: delegates to a Rust helper.
    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        extern "C" fn grad_atan2(y: Grad, x: Grad) -> Grad {
            y.atan2(x)
        }
        self.call_fn_binary(out_reg, lhs_reg, rhs_reg, grad_atan2);
    }

    /// Maximum by value lane, propagating the winner's gradients
    /// wholesale; an unordered compare (either value NaN sets the V
    /// flag) yields a NaN value lane instead.
    /// Branch targets are raw byte offsets:
    ///   `b.vs 24` -> NaN path; `b.gt 12` -> take lhs;
    ///   `b 20` / `b 12` both land after the sequence.
    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmp S(reg(lhs_reg)), S(reg(rhs_reg))
            ; b.vs 24
            ; b.gt 12
            ; mov V(reg(out_reg)).b16, V(reg(rhs_reg)).b16
            ; b 20
            ; mov V(reg(out_reg)).b16, V(reg(lhs_reg)).b16
            ; b 12
            ; mov w9, f32::NAN.to_bits()
            ; fmov S(reg(out_reg)), w9
        )
    }

    /// Minimum by value lane; mirror image of `build_max`
    /// (`b.lt 12` takes lhs when it is the smaller operand).
    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            ; fcmp S(reg(lhs_reg)), S(reg(rhs_reg))
            ; b.vs 24
            ; b.lt 12
            ; mov V(reg(out_reg)).b16, V(reg(rhs_reg)).b16
            ; b 20
            ; mov V(reg(out_reg)).b16, V(reg(lhs_reg)).b16
            ; b 12
            ; mov w9, f32::NAN.to_bits()
            ; fmov S(reg(out_reg)), w9
        )
    }

    /// Euclidean remainder: delegates to an out-of-line Rust helper.
    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        extern "C" fn grad_modulo(lhs: Grad, rhs: Grad) -> Grad {
            lhs.rem_euclid(rhs)
        }
        self.call_fn_binary(out_reg, lhs_reg, rhs_reg, grad_modulo);
    }

    /// Logical NOT: value lane becomes 1.0 if the input value is 0.0,
    /// else 0.0; gradient lanes come out zero because the scalar
    /// `fmov` clears the upper lanes of out_reg before masking.
    fn build_not(&mut self, out_reg: u8, arg_reg: u8) {
        dynasm!(self.0.ops
            // s6 = all-ones mask iff value == 0.0
            ; fcmeq s6, S(reg(arg_reg)), 0.0
            ; fmov S(reg(out_reg)), 1.0
            ; and V(reg(out_reg)).b16, V(reg(out_reg)).b16, v6.b16
        );
    }

    /// Logical AND with short-circuit value semantics: the result is
    /// lhs (all lanes) when lhs's value is 0.0, otherwise rhs.
    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            // Broadcast mask: all-ones iff lhs value == 0.0
            ; fcmeq s6, S(reg(lhs_reg)), 0.0
            ; dup v6.s4, v6.s[0]
            ; mvn v7.b16, v6.b16
            // Bitwise select: (mask & lhs) | (!mask & rhs)
            ; and v6.b16, v6.b16, V(reg(lhs_reg)).b16
            ; and v7.b16, v7.b16, V(reg(rhs_reg)).b16
            ; orr V(reg(out_reg)).b16, v6.b16, v7.b16
        );
    }

    /// Logical OR with short-circuit value semantics: the result is
    /// lhs (all lanes) when lhs's value is nonzero, otherwise rhs.
    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            // Broadcast mask: all-ones iff lhs value == 0.0
            ; fcmeq s6, S(reg(lhs_reg)), 0.0
            ; dup v6.s4, v6.s[0]
            ; mvn v7.b16, v6.b16
            // Bitwise select: (!mask & lhs) | (mask & rhs)
            ; and v7.b16, v7.b16, V(reg(lhs_reg)).b16
            ; and v6.b16, v6.b16, V(reg(rhs_reg)).b16
            ; orr V(reg(out_reg)).b16, v6.b16, v7.b16
        );
    }

    /// Three-way compare of the value lanes: -1.0 if lhs < rhs, 1.0
    /// if lhs > rhs, 0.0 if equal, NaN if either side is NaN.
    /// Gradient lanes come out zero: each scalar `fmov` into s7
    /// clears v7's upper lanes before the masked ORs.
    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8) {
        dynasm!(self.0.ops
            // v6 = NaN mask: !(lhs == lhs && rhs == rhs), broadcast
            ; fcmeq s6, S(reg(lhs_reg)), S(reg(lhs_reg))
            ; dup v6.s4, v6.s[0]
            ; fcmeq s7, S(reg(rhs_reg)), S(reg(rhs_reg))
            ; dup v7.s4, v7.s[0]
            ; and v6.b16, v6.b16, v7.b16
            ; mvn v6.b16, v6.b16
            // v4 = (rhs > lhs) mask, v5 = (lhs > rhs) mask
            ; fcmgt s4, S(reg(rhs_reg)), S(reg(lhs_reg))
            ; dup v4.s4, v4.s[0]
            ; fcmgt s5, S(reg(lhs_reg)), S(reg(rhs_reg))
            ; dup v5.s4, v5.s[0]
            // out = (v4 & -1.0) | (v5 & 1.0) | (nan_mask & NaN)
            ; fmov s7, -1.0
            ; and V(reg(out_reg)).b16, v4.b16, v7.b16
            ; fmov s7, 1.0
            ; and v5.B16, v5.B16, v7.B16
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v5.B16
            ; mov w9, f32::NAN.to_bits()
            ; fmov s7, w9
            ; and v7.b16, v7.b16, v6.b16
            ; orr V(reg(out_reg)).B16, V(reg(out_reg)).B16, v7.b16
        )
    }

    /// Materializes an f32 constant into `IMM_REG`: one `movz` when
    /// either 16-bit half of the bit pattern is zero, otherwise a
    /// `movz` + `movk` pair. Returns the allocator-visible index of
    /// `IMM_REG` (hardware register number minus `OFFSET`).
    fn load_imm(&mut self, imm: f32) -> u8 {
        let imm_u32 = imm.to_bits();
        if imm_u32 & 0xFFFF == 0 {
            // Low half is zero: one shifted movz suffices
            dynasm!(self.0.ops
                ; movz w9, imm_u32 >> 16, lsl 16
                ; fmov S(IMM_REG), w9
            );
        } else if imm_u32 & 0xFFFF_0000 == 0 {
            // High half is zero: one unshifted movz suffices
            dynasm!(self.0.ops
                ; movz w9, imm_u32 & 0xFFFF
                ; fmov S(IMM_REG), w9
            );
        } else {
            // General case: build both halves
            dynasm!(self.0.ops
                ; movz w9, imm_u32 >> 16, lsl 16
                ; movk w9, imm_u32 & 0xFFFF
                ; fmov S(IMM_REG), w9
            );
        }
        IMM_REG.wrapping_sub(OFFSET)
    }

    /// Closes the loop opened in `init` — decrement the count in x2,
    /// advance the byte offset x3 by one 16-byte `Grad`, branch back —
    /// then emits the epilogue restoring all callee-saved state and
    /// finalizes the buffer into executable memory.
    ///
    /// NOTE(review): no stack-pointer restore or `ret` is emitted
    /// here; presumably `AssemblerData::finalize` appends them —
    /// confirm (not visible from this file).
    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        dynasm!(self.0.ops
            ; sub x2, x2, 1
            ; add x3, x3, 16
            ; b ->L
            // Loop exit target from `init`
            ; ->E:
            ; ldp x29, x30, [sp, 0x0]
            ; ldp d8, d9, [sp, 0x10]
            ; ldp d10, d11, [sp, 0x20]
            ; ldp d12, d13, [sp, 0x30]
            ; ldp d14, d15, [sp, 0x40]
            ; ldr x20, [sp, 0x200]
            ; ldr x21, [sp, 0x208]
            ; ldr x22, [sp, 0x210]
            ; ldr x23, [sp, 0x218]
        );
        self.0.finalize()
    }
}
#[expect(clippy::useless_conversion)]
impl GradSliceAssembler {
    /// Emits a call from the JIT'd code into an external unary helper
    /// `f: extern "C" fn(Grad) -> Grad`.
    ///
    /// The generated sequence:
    /// 1. stashes x0-x3 in callee-saved x20-x23 (the helper is free
    ///    to clobber argument registers);
    /// 2. spills q8-q31 to the frame's scratch area — the allocator
    ///    keeps live values there, and a call preserves at most the
    ///    low 64 bits of v8-v15;
    /// 3. loads `f`'s address into x0 via a movz/movk chain;
    /// 4. scatters the four lanes of `arg_reg` into s0-s3 (assumes
    ///    `Grad` is a four-f32 struct passed/returned in s0-s3 —
    ///    TODO confirm against `Grad`'s definition);
    /// 5. calls, restores the spilled q registers, gathers s0-s3
    ///    back into `out_reg`, and restores x0-x3.
    ///
    /// NOTE(review): the result is read from v0-v3 *after* the spill
    /// restore, so `reg(out_reg)` must never map into v0-v3 —
    /// presumably guaranteed by `OFFSET`; confirm.
    fn call_fn_unary(
        &mut self,
        out_reg: u8,
        arg_reg: u8,
        f: extern "C" fn(Grad) -> Grad,
    ) {
        let addr = f as usize;
        dynasm!(self.0.ops
            // Stash x0-x3 in callee-saved registers
            ; mov x20, x0
            ; mov x21, x1
            ; mov x22, x2
            ; mov x23, x3
            // Spill every allocator-visible q register
            ; stp q8, q9, [sp, 0x50]
            ; stp q10, q11, [sp, 0x70]
            ; stp q12, q13, [sp, 0x90]
            ; stp q14, q15, [sp, 0xb0]
            ; stp q16, q17, [sp, 0xd0]
            ; stp q18, q19, [sp, 0xf0]
            ; stp q20, q21, [sp, 0x110]
            ; stp q22, q23, [sp, 0x130]
            ; stp q24, q25, [sp, 0x150]
            ; stp q26, q27, [sp, 0x170]
            ; stp q28, q29, [sp, 0x190]
            ; stp q30, q31, [sp, 0x1b0]
            // Build the 64-bit helper address in x0
            ; movz x0, (addr >> 48) as u32 & 0xFFFF, lsl 48
            ; movk x0, (addr >> 32) as u32 & 0xFFFF, lsl 32
            ; movk x0, (addr >> 16) as u32 & 0xFFFF, lsl 16
            ; movk x0, addr as u32 & 0xFFFF
            // Scatter the argument's lanes into s0-s3
            ; mov s0, V(reg(arg_reg)).s[0]
            ; mov s1, V(reg(arg_reg)).s[1]
            ; mov s2, V(reg(arg_reg)).s[2]
            ; mov s3, V(reg(arg_reg)).s[3]
            ; blr x0
            // Restore the spilled q registers
            ; ldp q8, q9, [sp, 0x50]
            ; ldp q10, q11, [sp, 0x70]
            ; ldp q12, q13, [sp, 0x90]
            ; ldp q14, q15, [sp, 0xb0]
            ; ldp q16, q17, [sp, 0xd0]
            ; ldp q18, q19, [sp, 0xf0]
            ; ldp q20, q21, [sp, 0x110]
            ; ldp q22, q23, [sp, 0x130]
            ; ldp q24, q25, [sp, 0x150]
            ; ldp q26, q27, [sp, 0x170]
            ; ldp q28, q29, [sp, 0x190]
            ; ldp q30, q31, [sp, 0x1b0]
            // Gather the returned lanes from s0-s3 into out_reg
            ; mov V(reg(out_reg)).s[0], v0.s[0]
            ; mov V(reg(out_reg)).s[1], v1.s[0]
            ; mov V(reg(out_reg)).s[2], v2.s[0]
            ; mov V(reg(out_reg)).s[3], v3.s[0]
            // Restore x0-x3
            ; mov x0, x20
            ; mov x1, x21
            ; mov x2, x22
            ; mov x3, x23
        );
    }

    /// Binary analogue of `call_fn_unary` for
    /// `f: extern "C" fn(Grad, Grad) -> Grad`: lhs lanes go to s0-s3
    /// and rhs lanes to s4-s7 (same `Grad`-in-s-registers assumption
    /// as the unary path — TODO confirm).
    ///
    /// Unlike the unary path this also preserves q0-q2 across the
    /// call (slots 0x1d0..0x1ff), restoring them *after* the result
    /// has been gathered out of v0-v3.
    /// NOTE(review): q3 is not preserved even though q0-q2 are —
    /// confirm that asymmetry is intentional.
    fn call_fn_binary(
        &mut self,
        out_reg: u8,
        lhs_reg: u8,
        rhs_reg: u8,
        f: extern "C" fn(Grad, Grad) -> Grad,
    ) {
        let addr = f as usize;
        dynasm!(self.0.ops
            // Stash x0-x3 in callee-saved registers
            ; mov x20, x0
            ; mov x21, x1
            ; mov x22, x2
            ; mov x23, x3
            // Preserve q0-q2 (restored after the result is gathered)
            ; stp q0, q1, [sp, 0x1d0]
            ; str q2, [sp, 0x1f0]
            // Spill every allocator-visible q register
            ; stp q8, q9, [sp, 0x50]
            ; stp q10, q11, [sp, 0x70]
            ; stp q12, q13, [sp, 0x90]
            ; stp q14, q15, [sp, 0xb0]
            ; stp q16, q17, [sp, 0xd0]
            ; stp q18, q19, [sp, 0xf0]
            ; stp q20, q21, [sp, 0x110]
            ; stp q22, q23, [sp, 0x130]
            ; stp q24, q25, [sp, 0x150]
            ; stp q26, q27, [sp, 0x170]
            ; stp q28, q29, [sp, 0x190]
            ; stp q30, q31, [sp, 0x1b0]
            // Build the 64-bit helper address in x0
            ; movz x0, (addr >> 48) as u32 & 0xFFFF, lsl 48
            ; movk x0, (addr >> 32) as u32 & 0xFFFF, lsl 32
            ; movk x0, (addr >> 16) as u32 & 0xFFFF, lsl 16
            ; movk x0, addr as u32 & 0xFFFF
            // lhs lanes -> s0-s3, rhs lanes -> s4-s7; lhs lane 3 is
            // moved last, after rhs has been read out
            ; mov s0, V(reg(lhs_reg)).s[0]
            ; mov s1, V(reg(lhs_reg)).s[1]
            ; mov s2, V(reg(lhs_reg)).s[2]
            ; mov s4, V(reg(rhs_reg)).s[0]
            ; mov s5, V(reg(rhs_reg)).s[1]
            ; mov s6, V(reg(rhs_reg)).s[2]
            ; mov s7, V(reg(rhs_reg)).s[3]
            ; mov s3, V(reg(lhs_reg)).s[3]
            ; blr x0
            // Restore the spilled q registers
            ; ldp q8, q9, [sp, 0x50]
            ; ldp q10, q11, [sp, 0x70]
            ; ldp q12, q13, [sp, 0x90]
            ; ldp q14, q15, [sp, 0xb0]
            ; ldp q16, q17, [sp, 0xd0]
            ; ldp q18, q19, [sp, 0xf0]
            ; ldp q20, q21, [sp, 0x110]
            ; ldp q22, q23, [sp, 0x130]
            ; ldp q24, q25, [sp, 0x150]
            ; ldp q26, q27, [sp, 0x170]
            ; ldp q28, q29, [sp, 0x190]
            ; ldp q30, q31, [sp, 0x1b0]
            // Gather the returned lanes from s0-s3 into out_reg
            ; mov V(reg(out_reg)).s[0], v0.s[0]
            ; mov V(reg(out_reg)).s[1], v1.s[0]
            ; mov V(reg(out_reg)).s[2], v2.s[0]
            ; mov V(reg(out_reg)).s[3], v3.s[0]
            // Now restore q0-q2
            ; ldp q0, q1, [sp, 0x1d0]
            ; ldr q2, [sp, 0x1f0]
            // Restore x0-x3
            ; mov x0, x20
            ; mov x1, x21
            ; mov x2, x22
            ; mov x3, x23
        );
    }
}