use crate::mmap::{Mmap, MmapWriter};
use fidget_core::{
Error,
compiler::RegOp,
context::{Context, Node},
eval::{
BulkEvaluator, BulkOutput, Function, MathFunction, Tape,
TracingEvaluator,
},
render::{RenderHints, TileSizes},
types::{Grad, Interval},
var::VarMap,
vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace},
};
use dynasmrt::{
AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi,
TargetKind, components::PatchLoc, dynasm,
};
use std::sync::Arc;
mod mmap;
mod permit;
pub(crate) use permit::WritePermit;
mod float_slice;
mod grad_slice;
mod interval;
mod point;
#[cfg(not(any(
target_os = "linux",
target_os = "macos",
target_os = "windows"
)))]
compile_error!(
"The `jit` module only builds on Linux, macOS, and Windows; \
please disable the `jit` feature"
);
#[cfg(target_arch = "aarch64")]
mod aarch64;
#[cfg(target_arch = "aarch64")]
use aarch64 as arch;
#[cfg(target_arch = "x86_64")]
mod x86_64;
#[cfg(target_arch = "x86_64")]
use x86_64 as arch;
/// Register limit of the active architecture backend; slots with an index at
/// or above this value are addressed on the stack (see `stack_pos`)
const REGISTER_LIMIT: usize = arch::REGISTER_LIMIT;
/// Backend-specific offset that `reg` adds to a tape-local register index
const OFFSET: u8 = arch::OFFSET;
/// Backend-specific register index named `IMM_REG`
// NOTE(review): presumably the register that holds loaded immediates — confirm
const IMM_REG: u8 = arch::IMM_REG;
/// Maps a tape-local register index to a hardware register number
///
/// The backend-specific [`OFFSET`] is added with wrapping arithmetic, so an
/// input whose sum overflows `u8` wraps around instead of panicking.
// NOTE(review): the wrap-around presumably lets callers name registers that
// sit below `OFFSET` — confirm against the architecture backends.
///
/// # Panics
/// Panics if the resulting register number is 32 or larger.
fn reg(r: u8) -> u8 {
    let out = OFFSET.wrapping_add(r);
    assert!(out < 32);
    out
}
/// Numeric value of `Choice::Left`, as a `u32`
const CHOICE_LEFT: u32 = Choice::Left as u32;
/// Numeric value of `Choice::Right`, as a `u32`
const CHOICE_RIGHT: u32 = Choice::Right as u32;
/// Numeric value of `Choice::Both`, as a `u32`
const CHOICE_BOTH: u32 = Choice::Both as u32;
/// Interface implemented by each code generator in this crate
///
/// `build_asm_fn_with_storage` walks a tape and calls one `build_*` method
/// per operation, then calls [`Assembler::finalize`] to obtain the finished
/// memory map.
trait Assembler {
    /// Data type that the generated function operates on; it parameterizes
    /// the function-pointer types built from the finished code
    type Data;

    /// Creates an assembler that writes into `m` for a tape using
    /// `slot_count` slots
    fn init(m: Mmap, slot_count: usize) -> Self;

    /// Estimated number of bytes of machine code emitted per tape clause
    ///
    /// Only used to size the memory map before assembly starts.
    fn bytes_per_clause() -> usize {
        8
    }

    // Memory, input, and output operations
    fn build_load(&mut self, dst_reg: u8, src_mem: u32);
    fn build_store(&mut self, dst_mem: u32, src_reg: u8);
    fn build_input(&mut self, out_reg: u8, src_arg: u32);
    fn build_output(&mut self, arg_reg: u8, out_index: u32);

    // Unary operations
    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Squares `lhs_reg`; by default this is a multiplication of the
    /// register with itself
    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        self.build_mul(out_reg, lhs_reg, lhs_reg)
    }

    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_round(&mut self, out_reg: u8, lhs_reg: u8);
    fn build_not(&mut self, out_reg: u8, lhs_reg: u8);

    // Binary operations
    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Adds an immediate to `lhs_reg`: the immediate is loaded with
    /// [`Assembler::load_imm`], then a register-register add is emitted
    fn build_add_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm_reg = self.load_imm(imm);
        self.build_add(out_reg, lhs_reg, imm_reg);
    }

    /// Computes `imm - arg`, with the immediate as the left-hand operand
    fn build_sub_imm_reg(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm_reg = self.load_imm(imm);
        self.build_sub(out_reg, imm_reg, arg);
    }

    /// Computes `arg - imm`, with the immediate as the right-hand operand
    fn build_sub_reg_imm(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm_reg = self.load_imm(imm);
        self.build_sub(out_reg, arg, imm_reg);
    }

    /// Multiplies `lhs_reg` by an immediate: the immediate is loaded with
    /// [`Assembler::load_imm`], then a register-register multiply is emitted
    fn build_mul_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm_reg = self.load_imm(imm);
        self.build_mul(out_reg, lhs_reg, imm_reg);
    }

    /// Loads an immediate and returns the index of the register holding it,
    /// suitable for passing as a register argument to the `build_*` methods
    fn load_imm(&mut self, imm: f32) -> u8;

    /// Consumes the assembler and returns the finished memory map, or the
    /// error reported while finalizing
    fn finalize(self) -> Result<Mmap, DynasmError>;
}
/// Trait for data types that declare a batch size for bulk evaluation
pub trait SimdSize {
/// Number of elements the bulk evaluator treats as one batch: output buffers
/// are sized to at least this many elements, and element counts passed to
/// the JIT function are multiples of it
const SIMD_SIZE: usize;
}
/// State shared by the architecture-specific stack helpers, parameterized by
/// the slot data type `T`
pub(crate) struct AssemblerData<T> {
/// Assembler that writes machine code into the underlying memory map
ops: MmapAssembler,
/// Number of bytes reserved on the stack; set by `prepare_stack` (always a
/// multiple of 16) and released again in `finalize`
mem_offset: usize,
/// Initialized to `false` by `new`
// NOTE(review): presumably set by the backends once callee-saved registers
// have been spilled; it is not written anywhere in this file — confirm
saved_callee_regs: bool,
/// Ties the struct to `T` without storing a value of that type
_p: std::marker::PhantomData<*const T>,
}
impl<T> AssemblerData<T> {
    /// Wraps `mmap` in a fresh assembler with no stack space reserved yet
    fn new(mmap: Mmap) -> Self {
        let ops = MmapAssembler::from(mmap);
        Self {
            ops,
            mem_offset: 0,
            saved_callee_regs: false,
            _p: std::marker::PhantomData,
        }
    }

    /// Computes the stack reservation and emits the code that claims it
    ///
    /// The reservation covers one `T` for every slot beyond
    /// [`REGISTER_LIMIT`], plus `stack_size` extra bytes, rounded up to a
    /// multiple of 16.
    fn prepare_stack(&mut self, slot_count: usize, stack_size: usize) {
        let spilled_slots = slot_count.saturating_sub(REGISTER_LIMIT);
        let bytes = spilled_slots * std::mem::size_of::<T>() + stack_size;
        self.mem_offset = bytes.next_multiple_of(16);
        self.push_stack();
    }

    /// Returns the byte offset of `slot` within the spill area
    ///
    /// # Panics
    /// Panics if `slot` is below [`REGISTER_LIMIT`] (i.e. it is not a
    /// stack-backed slot).
    fn stack_pos(&self, slot: u32) -> u32 {
        assert!(slot >= REGISTER_LIMIT as u32);
        let spill_index = slot - REGISTER_LIMIT as u32;
        spill_index * std::mem::size_of::<T>() as u32
    }
}
// x86_64 stack management: reserves and releases `mem_offset` bytes of stack
#[cfg(target_arch = "x86_64")]
impl<T> AssemblerData<T> {
/// Emits an instruction that lowers `rsp` by `self.mem_offset` bytes
fn push_stack(&mut self) {
dynasm!(self.ops
; sub rsp, self.mem_offset as i32
);
}
/// Emits the function epilogue, then finalizes the underlying assembler
///
/// The epilogue restores `rsp`, pops `rbp`, executes `vzeroupper`, and
/// returns.
///
/// # Errors
/// Propagates any error reported by `MmapAssembler::finalize`.
fn finalize(mut self) -> Result<Mmap, DynasmError> {
dynasm!(self.ops
// Undo the stack reservation made by `push_stack`
; add rsp, self.mem_offset as i32
// NOTE(review): the matching `push rbp` is presumably emitted by the
// function prelude in the backend — confirm
; pop rbp
// Zero the upper halves of the vector registers before returning
; vzeroupper
; ret
);
self.ops.finalize()
}
}
// aarch64 stack management: reserves and releases `mem_offset` bytes of stack
#[cfg(target_arch = "aarch64")]
#[allow(clippy::unnecessary_cast)] impl<T> AssemblerData<T> {
/// Emits code that lowers `sp` by `self.mem_offset` bytes
///
/// Offsets below 4096 are encoded directly in the `sub` instruction; offsets
/// below 65536 are first moved into `w28`.
///
/// # Panics
/// Panics if `self.mem_offset` is 65536 or larger.
// NOTE(review): the two thresholds presumably correspond to the 12-bit
// add/sub immediate and the 16-bit `mov` immediate — confirm.
// NOTE(review): this path clobbers `w28` while `finalize` uses `w9`; confirm
// that `x28` has already been saved when this runs.
fn push_stack(&mut self) {
if self.mem_offset < 4096 {
dynasm!(self.ops
; sub sp, sp, self.mem_offset as u32
);
} else if self.mem_offset < 65536 {
dynasm!(self.ops
; mov w28, self.mem_offset as u32
; sub sp, sp, w28
);
} else {
panic!("invalid mem offset: {} is too large", self.mem_offset);
}
}
/// Emits code that restores `sp` and returns, then finalizes the underlying
/// assembler
///
/// Mirrors `push_stack`: offsets below 4096 are encoded directly, offsets
/// below 65536 go through `w9`.
///
/// # Errors
/// Propagates any error reported by `MmapAssembler::finalize`.
///
/// # Panics
/// Panics if `self.mem_offset` is 65536 or larger.
fn finalize(mut self) -> Result<Mmap, DynasmError> {
// Undo the stack reservation made by `push_stack`
if self.mem_offset < 4096 {
dynasm!(self.ops
; add sp, sp, self.mem_offset as u32
);
} else if self.mem_offset < 65536 {
dynasm!(self.ops
; mov w9, self.mem_offset as u32
; add sp, sp, w9
);
} else {
panic!("invalid mem offset: {}", self.mem_offset);
}
dynasm!(self.ops
; ret
);
self.ops.finalize()
}
}
/// Relocation type of the active target architecture (x86_64 variant)
#[cfg(target_arch = "x86_64")]
type Relocation = dynasmrt::x64::X64Relocation;
/// Relocation type of the active target architecture (aarch64 variant)
#[cfg(target_arch = "aarch64")]
type Relocation = dynasmrt::aarch64::Aarch64Relocation;
/// Assembler that writes bytes directly into a memory map and tracks
/// single-letter (`A`-`Z`) labels and their pending relocations
struct MmapAssembler {
/// Destination for the emitted bytes
mmap: MmapWriter,
/// Offsets of global labels, indexed by letter (`A` = 0 … `Z` = 25)
global_labels: [Option<AssemblyOffset>; 26],
/// Offsets of local labels, indexed by letter; reset by `commit_local`
local_labels: [Option<AssemblyOffset>; 26],
/// Pending relocations against global labels, stored as
/// `(patch location, label index)`; holds at most 2 entries
global_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 2>,
/// Pending relocations against local labels, stored as
/// `(patch location, label index)`; holds at most 8 entries
local_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 8>,
}
/// Appends owned bytes to the assembler, one at a time
impl Extend<u8> for MmapAssembler {
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = u8>,
    {
        iter.into_iter().for_each(|byte| self.push(byte));
    }
}
/// Appends borrowed bytes to the assembler, one at a time
impl<'a> Extend<&'a u8> for MmapAssembler {
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = &'a u8>,
    {
        iter.into_iter().for_each(|&byte| self.push(byte));
    }
}
/// Byte-level emission API, backed directly by the memory-map writer
impl DynasmApi for MmapAssembler {
    /// Current write position, i.e. the number of bytes emitted so far
    #[inline(always)]
    fn offset(&self) -> AssemblyOffset {
        let written = self.mmap.len();
        AssemblyOffset(written)
    }

    /// Appends a single byte
    #[inline(always)]
    fn push(&mut self, byte: u8) {
        self.mmap.push(byte);
    }

    /// Pads with `with` until the write position is a multiple of
    /// `alignment`; does nothing if it already is
    #[inline(always)]
    fn align(&mut self, alignment: usize, with: u8) {
        let misalignment = self.offset().0 % alignment;
        if misalignment == 0 {
            return;
        }
        for _ in 0..(alignment - misalignment) {
            self.push(with);
        }
    }

    /// Appends `value` as four little-endian bytes
    #[inline(always)]
    fn push_u32(&mut self, value: u32) {
        let bytes = value.to_le_bytes();
        for byte in bytes {
            self.mmap.push(byte);
        }
    }
}
/// Converts a single-letter label name (`"A"` through `"Z"`) into an index
/// into the 26-entry label tables of [`MmapAssembler`]
///
/// `kind` names the label family (`"local"` or `"global"`) and is only used
/// to build the panic message.
///
/// # Panics
/// Panics if `name` is not exactly one byte long, or if that byte is not an
/// uppercase ASCII letter.
fn label_index(name: &'static str, kind: &str) -> u8 {
    if name.len() != 1 {
        panic!("{kind} label must be a single character");
    }
    // Bytes below b'A' wrap around to large values, so a single comparison
    // rejects everything outside `A..=Z`
    let c = name.as_bytes()[0].wrapping_sub(b'A');
    if c >= 26 {
        panic!("Invalid label {name}, must be A-Z");
    }
    c
}

/// Label and relocation API restricted to single-letter local and global
/// labels; dynamic labels and bare value relocations are rejected with a
/// panic
impl DynasmLabelApi for MmapAssembler {
    type Relocation = Relocation;

    /// Defines local label `name` at the current offset
    ///
    /// # Panics
    /// Panics if `name` is not a single letter `A`-`Z`, or if that local
    /// label is already defined.
    fn local_label(&mut self, name: &'static str) {
        let c = usize::from(label_index(name, "local"));
        if self.local_labels[c].is_some() {
            panic!("duplicate local label {name}");
        }
        self.local_labels[c] = Some(self.offset());
    }

    /// Defines global label `name` at the current offset
    ///
    /// # Panics
    /// Panics if `name` is not a single letter `A`-`Z`, or if that global
    /// label is already defined.
    fn global_label(&mut self, name: &'static str) {
        let c = usize::from(label_index(name, "global"));
        if self.global_labels[c].is_some() {
            panic!("duplicate global label {name}");
        }
        self.global_labels[c] = Some(self.offset());
    }

    /// Always panics: dynamic labels are not supported
    fn dynamic_label(&mut self, _id: DynamicLabel) {
        panic!("dynamic labels are not supported");
    }

    /// Records a relocation against global label `name` at the current
    /// offset; it is patched when the assembler is finalized
    ///
    /// # Panics
    /// Panics if `name` is not a single letter `A`-`Z`, or if the global
    /// relocation list is already full.
    fn global_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        let location = self.offset();
        let c = label_index(name, "global");
        self.global_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }

    /// Always panics: dynamic relocations are not supported
    fn dynamic_relocation(
        &mut self,
        _id: DynamicLabel,
        _target_offset: isize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("dynamic relocations are not supported");
    }

    /// Records a relocation against a local label that has not been defined
    /// yet
    ///
    /// # Panics
    /// Panics if `name` is not a single letter `A`-`Z`, if that local label
    /// is already defined, or if the local relocation list is already full.
    fn forward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        let c = label_index(name, "local");
        if self.local_labels[usize::from(c)].is_some() {
            panic!("invalid forward relocation: {name} already exists!");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }

    /// Records a relocation against a local label that is already defined
    ///
    /// # Panics
    /// Panics if `name` is not a single letter `A`-`Z`, if that local label
    /// is not defined, or if the local relocation list is already full.
    fn backward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        let c = label_index(name, "local");
        if self.local_labels[usize::from(c)].is_none() {
            panic!("invalid backward relocation: {name} does not exist");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }

    /// Always panics: bare value relocations are not implemented
    fn value_relocation(
        &mut self,
        _target: usize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("bare relocations not implemented");
    }
}
impl MmapAssembler {
    /// Patches every pending local relocation, then forgets all local labels
    ///
    /// # Errors
    /// Returns `DynasmError::ImpossibleRelocation` if a patch cannot be
    /// applied; in that case the local labels are left untouched.
    ///
    /// # Panics
    /// Panics if a pending relocation refers to a local label that was never
    /// defined.
    fn commit_local(&mut self) -> Result<(), DynasmError> {
        let baseaddr = self.mmap.as_ptr() as usize;
        for (loc, label) in self.local_relocs.take() {
            let target = self.local_labels[label as usize]
                .expect("invalid local label");
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            loc.patch(buf, baseaddr, target.0).map_err(|_| {
                DynasmError::ImpossibleRelocation(TargetKind::Local("oh no"))
            })?;
        }
        self.local_labels = [None; 26];
        Ok(())
    }

    /// Resolves all outstanding local and global relocations and returns the
    /// finished memory map
    ///
    /// # Errors
    /// Returns `DynasmError::ImpossibleRelocation` if any patch cannot be
    /// applied.
    ///
    /// # Panics
    /// Panics if a pending relocation refers to a label that was never
    /// defined.
    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        self.commit_local()?;

        let baseaddr = self.mmap.as_ptr() as usize;
        for (loc, label) in self.global_relocs.take() {
            let target =
                self.global_labels.get(label as usize).unwrap().unwrap();
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            loc.patch(buf, baseaddr, target.0).map_err(|_| {
                DynasmError::ImpossibleRelocation(TargetKind::Global("oh no"))
            })?;
        }

        let finished = self.mmap.finalize();
        Ok(finished)
    }
}
impl From<Mmap> for MmapAssembler {
fn from(mmap: Mmap) -> Self {
Self {
mmap: MmapWriter::from(mmap),
global_labels: [None; 26],
local_labels: [None; 26],
global_relocs: Default::default(),
local_relocs: Default::default(),
}
}
}
fn build_asm_fn_with_storage<A: Assembler>(
t: &VmData<REGISTER_LIMIT>,
mut s: Mmap,
) -> Mmap {
let size_estimate = t.len() * A::bytes_per_clause();
if size_estimate > 2 * s.capacity() {
s = Mmap::new(size_estimate).expect("failed to build mmap")
}
let mut asm = A::init(s, t.slot_count());
for op in t.iter_asm() {
match op {
RegOp::Load(reg, mem) => {
asm.build_load(reg, mem);
}
RegOp::Store(reg, mem) => {
asm.build_store(mem, reg);
}
RegOp::Input(out, i) => {
asm.build_input(out, i);
}
RegOp::Output(arg, i) => {
asm.build_output(arg, i);
}
RegOp::NegReg(out, arg) => {
asm.build_neg(out, arg);
}
RegOp::AbsReg(out, arg) => {
asm.build_abs(out, arg);
}
RegOp::RecipReg(out, arg) => {
asm.build_recip(out, arg);
}
RegOp::SqrtReg(out, arg) => {
asm.build_sqrt(out, arg);
}
RegOp::SinReg(out, arg) => {
asm.build_sin(out, arg);
}
RegOp::CosReg(out, arg) => {
asm.build_cos(out, arg);
}
RegOp::TanReg(out, arg) => {
asm.build_tan(out, arg);
}
RegOp::AsinReg(out, arg) => {
asm.build_asin(out, arg);
}
RegOp::AcosReg(out, arg) => {
asm.build_acos(out, arg);
}
RegOp::AtanReg(out, arg) => {
asm.build_atan(out, arg);
}
RegOp::ExpReg(out, arg) => {
asm.build_exp(out, arg);
}
RegOp::LnReg(out, arg) => {
asm.build_ln(out, arg);
}
RegOp::CopyReg(out, arg) => {
asm.build_copy(out, arg);
}
RegOp::SquareReg(out, arg) => {
asm.build_square(out, arg);
}
RegOp::FloorReg(out, arg) => {
asm.build_floor(out, arg);
}
RegOp::CeilReg(out, arg) => {
asm.build_ceil(out, arg);
}
RegOp::RoundReg(out, arg) => {
asm.build_round(out, arg);
}
RegOp::NotReg(out, arg) => {
asm.build_not(out, arg);
}
RegOp::AddRegReg(out, lhs, rhs) => {
asm.build_add(out, lhs, rhs);
}
RegOp::MulRegReg(out, lhs, rhs) => {
asm.build_mul(out, lhs, rhs);
}
RegOp::DivRegReg(out, lhs, rhs) => {
asm.build_div(out, lhs, rhs);
}
RegOp::AtanRegReg(out, lhs, rhs) => {
asm.build_atan2(out, lhs, rhs);
}
RegOp::SubRegReg(out, lhs, rhs) => {
asm.build_sub(out, lhs, rhs);
}
RegOp::MinRegReg(out, lhs, rhs) => {
asm.build_min(out, lhs, rhs);
}
RegOp::MaxRegReg(out, lhs, rhs) => {
asm.build_max(out, lhs, rhs);
}
RegOp::AddRegImm(out, arg, imm) => {
asm.build_add_imm(out, arg, imm);
}
RegOp::MulRegImm(out, arg, imm) => {
asm.build_mul_imm(out, arg, imm);
}
RegOp::DivRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_div(out, arg, reg);
}
RegOp::DivImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_div(out, reg, arg);
}
RegOp::AtanRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_atan2(out, arg, reg);
}
RegOp::AtanImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_atan2(out, reg, arg);
}
RegOp::SubImmReg(out, arg, imm) => {
asm.build_sub_imm_reg(out, arg, imm);
}
RegOp::SubRegImm(out, arg, imm) => {
asm.build_sub_reg_imm(out, arg, imm);
}
RegOp::MinRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_min(out, arg, reg);
}
RegOp::MaxRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_max(out, arg, reg);
}
RegOp::ModRegReg(out, lhs, rhs) => {
asm.build_mod(out, lhs, rhs);
}
RegOp::ModRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_mod(out, arg, reg);
}
RegOp::ModImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_mod(out, reg, arg);
}
RegOp::AndRegReg(out, lhs, rhs) => {
asm.build_and(out, lhs, rhs);
}
RegOp::AndRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_and(out, arg, reg);
}
RegOp::OrRegReg(out, lhs, rhs) => {
asm.build_or(out, lhs, rhs);
}
RegOp::OrRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_or(out, arg, reg);
}
RegOp::CopyImm(out, imm) => {
let reg = asm.load_imm(imm);
asm.build_copy(out, reg);
}
RegOp::CompareRegReg(out, lhs, rhs) => {
asm.build_compare(out, lhs, rhs);
}
RegOp::CompareRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_compare(out, arg, reg);
}
RegOp::CompareImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_compare(out, reg, arg);
}
}
}
asm.finalize().expect("failed to build JIT function")
}
/// Function handle that wraps a `GenericVmFunction` and compiles its tapes to
/// machine code through the `Function` trait implementation
#[derive(Clone)]
pub struct JitFunction(GenericVmFunction<REGISTER_LIMIT>);
impl JitFunction {
fn tracing_tape<A: Assembler>(
&self,
storage: Mmap,
) -> JitTracingFn<A::Data> {
let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
let ptr = f.as_ptr();
JitTracingFn {
mmap: f.into(),
vars: self.0.data().vars.clone(),
choice_count: self.0.choice_count(),
output_count: self.0.output_count(),
fn_trace: unsafe {
std::mem::transmute::<
*const std::ffi::c_void,
JitTracingFnPointer<A::Data>,
>(ptr)
},
}
}
fn bulk_tape<A: Assembler>(&self, storage: Mmap) -> JitBulkFn<A::Data> {
let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
let ptr = f.as_ptr();
JitBulkFn {
mmap: f.into(),
output_count: self.0.output_count(),
vars: self.0.data().vars.clone(),
fn_bulk: unsafe {
std::mem::transmute::<
*const std::ffi::c_void,
JitBulkFnPointer<A::Data>,
>(ptr)
},
}
}
}
/// `Function` implementation: tapes are JIT-compiled, everything else is
/// delegated to the wrapped `GenericVmFunction`
impl Function for JitFunction {
    type Trace = VmTrace;
    type Storage = VmData<REGISTER_LIMIT>;
    type Workspace = VmWorkspace<REGISTER_LIMIT>;
    type TapeStorage = Mmap;
    type IntervalEval = JitIntervalEval;
    type PointEval = JitPointEval;
    type FloatSliceEval = JitFloatSliceEval;
    type GradSliceEval = JitGradSliceEval;

    /// Builds a tracing tape for single-point (`f32`) evaluation
    #[inline]
    fn point_tape(&self, storage: Mmap) -> JitTracingFn<f32> {
        self.tracing_tape::<point::PointAssembler>(storage)
    }

    /// Builds a tracing tape for interval evaluation
    #[inline]
    fn interval_tape(&self, storage: Mmap) -> JitTracingFn<Interval> {
        self.tracing_tape::<interval::IntervalAssembler>(storage)
    }

    /// Builds a bulk tape for evaluating slices of `f32`
    #[inline]
    fn float_slice_tape(&self, storage: Mmap) -> JitBulkFn<f32> {
        self.bulk_tape::<float_slice::FloatSliceAssembler>(storage)
    }

    /// Builds a bulk tape for evaluating slices of `Grad`
    #[inline]
    fn grad_slice_tape(&self, storage: Mmap) -> JitBulkFn<Grad> {
        self.bulk_tape::<grad_slice::GradSliceAssembler>(storage)
    }

    /// Simplifies the wrapped function with `trace` and re-wraps the result
    #[inline]
    fn simplify(
        &self,
        trace: &Self::Trace,
        storage: Self::Storage,
        workspace: &mut Self::Workspace,
    ) -> Result<Self, Error> {
        let simplified = self.0.simplify(trace, storage, workspace)?;
        Ok(Self(simplified))
    }

    /// Hands back the wrapped function's storage, if it can be recycled
    #[inline]
    fn recycle(self) -> Option<Self::Storage> {
        let Self(inner) = self;
        inner.recycle()
    }

    /// Size reported by the wrapped function
    #[inline]
    fn size(&self) -> usize {
        self.0.size()
    }

    /// Variable map of the wrapped function
    #[inline]
    fn vars(&self) -> &VarMap {
        self.0.vars()
    }

    /// Returns `true` when the wrapped function records at least one choice
    #[inline]
    fn can_simplify(&self) -> bool {
        self.0.choice_count() != 0
    }
}
/// Rendering hints for JIT-compiled functions
impl RenderHints for JitFunction {
    /// Tile sizes used for 3D rendering: 64, 16, 8
    fn tile_sizes_3d() -> TileSizes {
        let sizes = TileSizes::new(&[64, 16, 8]);
        sizes.unwrap()
    }

    /// Tile sizes used for 2D rendering: 128, 16
    fn tile_sizes_2d() -> TileSizes {
        let sizes = TileSizes::new(&[128, 16]);
        sizes.unwrap()
    }

    /// Returns `true` for values of `d` that leave remainder 4 when divided
    /// by 8
    fn simplify_tree_during_meshing(d: usize) -> bool {
        matches!(d % 8, 4)
    }
}
/// Construction from a math context: builds the underlying
/// `GenericVmFunction` and wraps it
impl MathFunction for JitFunction {
    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
        let inner = GenericVmFunction::new(ctx, nodes)?;
        Ok(Self(inner))
    }
}
impl From<GenericVmFunction<REGISTER_LIMIT>> for JitFunction {
fn from(v: GenericVmFunction<REGISTER_LIMIT>) -> Self {
Self(v)
}
}
impl<'a> From<&'a JitFunction> for &'a GenericVmFunction<REGISTER_LIMIT> {
fn from(v: &'a JitFunction) -> Self {
&v.0
}
}
// Expands to the function-pointer type used for JIT-compiled code on x86_64:
// `unsafe fn(args)` becomes `unsafe extern "sysv64" fn(args)`
#[cfg(target_arch = "x86_64")]
macro_rules! jit_fn {
(unsafe fn($($args:tt)*)) => {
unsafe extern "sysv64" fn($($args)*)
};
}
// Expands to the function-pointer type used for JIT-compiled code on aarch64:
// `unsafe fn(args)` becomes `unsafe extern "C" fn(args)`
#[cfg(target_arch = "aarch64")]
macro_rules! jit_fn {
(unsafe fn($($args:tt)*)) => {
unsafe extern "C" fn($($args)*)
};
}
/// Reusable workspace for tracing evaluation
struct JitTracingEval<T> {
/// Choice buffer; reset to `Choice::Unknown` before each call and written by
/// the JIT function through a raw pointer
choices: VmTrace,
/// Output buffer; reset to NaN before each call and written by the JIT
/// function through a raw pointer
out: Vec<T>,
}
/// Starts with an empty trace and an empty output buffer; implemented by
/// hand so that `T` does not need to be `Default`
impl<T> Default for JitTracingEval<T> {
    fn default() -> Self {
        let choices = VmTrace::default();
        let out = Vec::new();
        Self { choices, out }
    }
}
/// Function-pointer type of a JIT-compiled tracing evaluator
///
/// As invoked by `JitTracingEval::eval`, the arguments are, in order: a
/// pointer to the input values, a pointer to the choice buffer, a pointer to
/// a single byte that is non-zero afterwards when the trace should be
/// returned, and a pointer to the output buffer.
pub type JitTracingFnPointer<T> = jit_fn!(
unsafe fn(
*const T, *mut u8, *mut u8, *mut T, )
);
/// Handle to a JIT-compiled tracing function and the metadata needed to call
/// it
#[derive(Clone)]
pub struct JitTracingFn<T> {
/// Memory map containing the machine code that `fn_trace` points into
mmap: Arc<Mmap>,
/// Number of entries the choice buffer must have when calling `fn_trace`
choice_count: usize,
/// Number of entries the output buffer must have when calling `fn_trace`
output_count: usize,
/// Variable map of the function this tape was built from
vars: Arc<VarMap>,
/// Entry point of the compiled code
fn_trace: JitTracingFnPointer<T>,
}
impl<T: Clone> Tape for JitTracingFn<T> {
    type Storage = Mmap;

    /// Returns the code buffer for reuse, or `None` if other clones of this
    /// tape still share it
    fn recycle(self) -> Option<Self::Storage> {
        let Self { mmap, .. } = self;
        Arc::into_inner(mmap)
    }

    /// Variable map of the function this tape was built from
    fn vars(&self) -> &VarMap {
        &*self.vars
    }

    /// Number of outputs written per evaluation
    fn output_count(&self) -> usize {
        self.output_count
    }
}
// SAFETY (NOTE(review)): these impls assert that a `JitTracingFn` may be moved
// to and shared between threads for any `T`. Presumably this holds because the
// code buffer is not modified after finalization — confirm against `Mmap`.
unsafe impl<T> Send for JitTracingFn<T> {}
unsafe impl<T> Sync for JitTracingFn<T> {}
impl<T: From<f32> + Clone> JitTracingEval<T> {
    /// Runs the compiled tracing function on `vars`
    ///
    /// Before the call, the choice buffer is resized to the tape's choice
    /// count and filled with `Choice::Unknown`, and the output buffer is
    /// resized to the tape's output count and filled with NaN.
    ///
    /// Returns the outputs, plus the recorded choices when the compiled
    /// function set its simplify flag to a non-zero value.
    fn eval(
        &mut self,
        tape: &JitTracingFn<T>,
        vars: &[T],
    ) -> (&[T], Option<&VmTrace>) {
        // Reset both buffers so no data from a previous call leaks through
        self.choices.resize(tape.choice_count, Choice::Unknown);
        self.choices.fill(Choice::Unknown);
        self.out.resize(tape.output_count, T::from(f32::NAN));
        self.out.fill(T::from(f32::NAN));

        let mut simplify: u8 = 0;
        // SAFETY (NOTE(review)): the choice and output buffers were sized
        // from the tape's own counts just above; `vars` is presumably
        // validated by the caller before reaching this point — confirm.
        unsafe {
            (tape.fn_trace)(
                vars.as_ptr(),
                self.choices.as_mut_ptr() as *mut u8,
                &mut simplify,
                self.out.as_mut_ptr(),
            )
        };

        let trace = (simplify != 0).then_some(&self.choices);
        (self.out.as_slice(), trace)
    }
}
/// Tracing evaluator for intervals, backed by JIT-compiled code
#[derive(Default)]
pub struct JitIntervalEval(JitTracingEval<Interval>);

impl TracingEvaluator for JitIntervalEval {
    type Data = Interval;
    type Tape = JitTracingFn<Interval>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    /// Validates `vars` against the tape's variable map, then runs the
    /// compiled function
    ///
    /// # Errors
    /// Returns the error from `check_tracing_arguments` if validation fails.
    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        let Self(inner) = self;
        Ok(inner.eval(tape, vars))
    }
}
/// Tracing evaluator for single `f32` points, backed by JIT-compiled code
#[derive(Default)]
pub struct JitPointEval(JitTracingEval<f32>);

impl TracingEvaluator for JitPointEval {
    type Data = f32;
    type Tape = JitTracingFn<f32>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    /// Validates `vars` against the tape's variable map, then runs the
    /// compiled function
    ///
    /// # Errors
    /// Returns the error from `check_tracing_arguments` if validation fails.
    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        let Self(inner) = self;
        Ok(inner.eval(tape, vars))
    }
}
/// Function-pointer type of a JIT-compiled bulk evaluator
///
/// As invoked by `JitBulkEval::eval`, the arguments are, in order: a pointer
/// to the table of input pointers, a pointer to the table of output pointers,
/// and the number of elements to process.
pub type JitBulkFnPointer<T> = jit_fn!(
unsafe fn(
*const *const T, *const *mut T, u64, )
);
/// Handle to a JIT-compiled bulk function and the metadata needed to call it
#[derive(Clone)]
pub struct JitBulkFn<T> {
/// Memory map containing the machine code that `fn_bulk` points into
mmap: Arc<Mmap>,
/// Variable map of the function this tape was built from
vars: Arc<VarMap>,
/// Number of output arrays written per evaluation
output_count: usize,
/// Entry point of the compiled code
fn_bulk: JitBulkFnPointer<T>,
}
impl<T: Clone> Tape for JitBulkFn<T> {
    type Storage = Mmap;

    /// Returns the code buffer for reuse, or `None` if other clones of this
    /// tape still share it
    fn recycle(self) -> Option<Self::Storage> {
        let Self { mmap, .. } = self;
        Arc::into_inner(mmap)
    }

    /// Variable map of the function this tape was built from
    fn vars(&self) -> &VarMap {
        &*self.vars
    }

    /// Number of output arrays written per evaluation
    fn output_count(&self) -> usize {
        self.output_count
    }
}
/// Largest `SimdSize::SIMD_SIZE` supported by the scratch buffers of
/// `JitBulkEval`; checked with an assertion before the scratch path is used
const MAX_SIMD_WIDTH: usize = 8;
/// Reusable workspace for bulk evaluation
struct JitBulkEval<T> {
/// Table of input pointers handed to the JIT function; rebuilt on every call
/// to `eval`
input_ptrs: Vec<*const T>,
/// Table of output pointers handed to the JIT function; rebuilt on every call
/// to `eval`
output_ptrs: Vec<*mut T>,
/// Per-variable padded copies of the inputs, used when fewer than
/// `SIMD_SIZE` elements are supplied
scratch: Vec<[T; MAX_SIMD_WIDTH]>,
/// One output buffer per tape output
out: Vec<Vec<T>>,
}
// SAFETY (NOTE(review)): the raw pointers stored in `input_ptrs` and
// `output_ptrs` make this type neither `Send` nor `Sync` automatically. They
// are cleared and rebuilt inside `eval` before every use, so they presumably
// never need to stay valid across calls — confirm.
unsafe impl<T> Sync for JitBulkEval<T> {}
unsafe impl<T> Send for JitBulkEval<T> {}
/// Starts with every buffer empty; implemented by hand so that `T` does not
/// need to be `Default`
impl<T> Default for JitBulkEval<T> {
    fn default() -> Self {
        Self {
            out: Vec::new(),
            scratch: Vec::new(),
            input_ptrs: Vec::new(),
            output_ptrs: Vec::new(),
        }
    }
}
// SAFETY (NOTE(review)): these impls assert that a `JitBulkFn` may be moved to
// and shared between threads for any `T`. Presumably this holds because the
// code buffer is not modified after finalization — confirm against `Mmap`.
unsafe impl<T> Send for JitBulkFn<T> {}
unsafe impl<T> Sync for JitBulkFn<T> {}
impl<T: From<f32> + Copy + SimdSize> JitBulkEval<T> {
    /// Calls the compiled function on the pointer tables currently staged in
    /// `input_ptrs` / `output_ptrs`, asking it to process `count` elements
    fn run_kernel(&self, tape: &JitBulkFn<T>, count: usize) {
        // SAFETY (NOTE(review)): the caller has just rebuilt both pointer
        // tables; every staged pointer is expected to be valid for `count`
        // elements — confirm against the generated code.
        unsafe {
            (tape.fn_bulk)(
                self.input_ptrs.as_ptr(),
                self.output_ptrs.as_ptr(),
                count as u64,
            );
        }
    }

    /// Evaluates the tape over arrays of input values
    ///
    /// The element count is taken from the first entry of `vars` (zero if
    /// `vars` is empty). Every output buffer is resized to at least
    /// `T::SIMD_SIZE` elements and filled with NaN before the compiled
    /// function runs.
    ///
    /// The compiled function is only ever asked to process a multiple of
    /// `T::SIMD_SIZE` elements:
    /// - with fewer than `T::SIMD_SIZE` elements, the inputs are copied into
    ///   NaN-padded scratch arrays and one batch is evaluated;
    /// - otherwise the largest multiple of `T::SIMD_SIZE` is evaluated in
    ///   place, and any remainder is covered by one extra batch over the
    ///   last `T::SIMD_SIZE` elements (overlapping the first pass).
    ///
    /// # Panics
    /// In the short-input path, panics if `T::SIMD_SIZE` exceeds
    /// `MAX_SIMD_WIDTH` or if an input array's length differs from the
    /// first one's.
    fn eval<V: std::ops::Deref<Target = [T]>>(
        &mut self,
        tape: &JitBulkFn<T>,
        vars: &[V],
    ) -> BulkOutput<'_, T> {
        let len = vars.first().map_or(0, |v| v.len());

        self.out.resize_with(tape.output_count(), Vec::new);
        for buf in self.out.iter_mut() {
            buf.resize(len.max(T::SIMD_SIZE), T::from(f32::NAN));
            buf.fill(T::from(f32::NAN));
        }

        if len < T::SIMD_SIZE {
            // Too few values to read directly from the caller's slices, so
            // stage them in padded scratch arrays instead
            assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH);
            self.scratch
                .resize(vars.len(), [T::from(f32::NAN); MAX_SIMD_WIDTH]);
            for (src, dst) in vars.iter().zip(self.scratch.iter_mut()) {
                dst[..len].copy_from_slice(src);
            }

            self.input_ptrs.clear();
            self.input_ptrs.extend(
                self.scratch[..vars.len()].iter().map(|s| s.as_ptr()),
            );
            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|o| o.as_mut_ptr()));

            self.run_kernel(tape, T::SIMD_SIZE);
        } else {
            // Largest multiple of the batch size that fits in `len`
            let bulk_len = (len / T::SIMD_SIZE) * T::SIMD_SIZE;

            self.input_ptrs.clear();
            self.input_ptrs.extend(vars.iter().map(|v| v.as_ptr()));
            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|o| o.as_mut_ptr()));

            self.run_kernel(tape, bulk_len);

            if len != bulk_len {
                // Cover the remainder with one more batch anchored at the
                // end of the arrays
                let tail_start = len - T::SIMD_SIZE;
                self.input_ptrs.clear();
                self.output_ptrs.clear();
                // SAFETY: this branch is only reached when
                // `len >= T::SIMD_SIZE`, so `tail_start` is within every
                // output buffer (each holds at least `len` elements).
                // NOTE(review): input arrays are presumably validated by the
                // caller to all hold `len` elements — confirm.
                unsafe {
                    self.input_ptrs.extend(
                        vars.iter().map(|v| v.as_ptr().add(tail_start)),
                    );
                    self.output_ptrs.extend(
                        self.out
                            .iter_mut()
                            .map(|o| o.as_mut_ptr().add(tail_start)),
                    );
                }
                self.run_kernel(tape, T::SIMD_SIZE);
            }
        }

        BulkOutput::new(&self.out, len)
    }
}
/// Bulk evaluator for slices of `f32`, backed by JIT-compiled code
#[derive(Default)]
pub struct JitFloatSliceEval(JitBulkEval<f32>);

impl BulkEvaluator for JitFloatSliceEval {
    type Data = f32;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    /// Validates `vars` against the tape's variable map, then runs the
    /// compiled function
    ///
    /// # Errors
    /// Returns the error from `check_bulk_arguments` if validation fails.
    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, f32>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        let Self(inner) = self;
        Ok(inner.eval(tape, vars))
    }
}
/// Bulk evaluator for slices of `Grad`, backed by JIT-compiled code
#[derive(Default)]
pub struct JitGradSliceEval(JitBulkEval<Grad>);

impl BulkEvaluator for JitGradSliceEval {
    type Data = Grad;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    /// Validates `vars` against the tape's variable map, then runs the
    /// compiled function
    ///
    /// # Errors
    /// Returns the error from `check_bulk_arguments` if validation fails.
    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, Grad>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        let Self(inner) = self;
        Ok(inner.eval(tape, vars))
    }
}
/// A `Shape` whose underlying function type is `JitFunction`
pub type JitShape = fidget_core::shape::Shape<JitFunction>;
#[cfg(test)]
mod test {
    use super::*;

    // Shared evaluator test suites, instantiated for the JIT backend
    fidget_core::grad_slice_tests!(JitFunction);
    fidget_core::interval_tests!(JitFunction);
    fidget_core::float_slice_tests!(JitFunction);
    fidget_core::point_tests!(JitFunction);

    /// Writes many words into an assembler that starts from a zero-sized
    /// map, then checks that every word reads back unchanged
    #[test]
    fn test_mmap_expansion() {
        const COUNT: u32 = 23456;

        let mut asm = MmapAssembler::from(Mmap::new(0).unwrap());
        (0..COUNT).for_each(|i| asm.push_u32(i));

        let mmap = asm.finalize().unwrap();
        let words = mmap.as_ptr() as *const u32;
        for i in 0..COUNT {
            let v = unsafe { *words.add(i as usize) };
            assert_eq!(v, i);
        }
    }
}