use crate::{
compiler::RegOp,
context::{Context, Node},
eval::{
BulkEvaluator, MathShape, Shape, Tape, TracingEvaluator,
TransformedShape,
},
jit::mmap::Mmap,
shape::RenderHints,
types::{Grad, Interval},
vm::{Choice, GenericVmShape, VmData, VmTrace, VmWorkspace},
Error,
};
use dynasmrt::{
components::PatchLoc, dynasm, AssemblyOffset, DynamicLabel, DynasmApi,
DynasmError, DynasmLabelApi, TargetKind,
};
use nalgebra::Matrix4;
mod mmap;
mod float_slice;
mod grad_slice;
mod interval;
mod point;
#[cfg(not(any(
target_os = "linux",
target_os = "macos",
target_os = "windows"
)))]
compile_error!(
"The `jit` module only builds on Linux, macOS, and Windows; \
please disable the `jit` feature"
);
#[cfg(target_arch = "aarch64")]
mod aarch64;
#[cfg(target_arch = "aarch64")]
use aarch64 as arch;
#[cfg(target_arch = "x86_64")]
mod x86_64;
#[cfg(target_arch = "x86_64")]
use x86_64 as arch;
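// Architecture-specific tuning parameters: the number of registers available
// to the allocator, the offset applied when mapping VM registers onto machine
// registers, and the machine register reserved for immediate values.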
const REGISTER_LIMIT: usize = arch::REGISTER_LIMIT;
const OFFSET: u8 = arch::OFFSET;
const IMM_REG: u8 = arch::IMM_REG;
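// Register index type expected by `dynasm!` on each architecture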
#[cfg(target_arch = "aarch64")]
type RegIndex = u32;
#[cfg(target_arch = "x86_64")]
type RegIndex = u8;
/// Converts a VM register index into a machine register index
fn reg(r: u8) -> RegIndex {
    // Widen before adding, so that the sum cannot wrap around in `u8`
    let out = r as u32 + OFFSET as u32;
    assert!(out < 32);
    out as RegIndex
}
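// `Choice` discriminants as `u32` constants, for use in generated assembly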
const CHOICE_LEFT: u32 = Choice::Left as u32;
const CHOICE_RIGHT: u32 = Choice::Right as u32;
const CHOICE_BOTH: u32 = Choice::Both as u32;
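/// Trait for generating machine assembly from VM clauses
///
/// Implementations exist for each evaluation flavor (point, interval, float
/// slice, gradient slice) on each supported architecture.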
trait Assembler {
type Data;
fn init(m: Mmap, slot_count: usize) -> Self;
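    /// Upper bound on the number of machine-code bytes emitted per clause,
    /// used to pre-size the output buffer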
    fn bytes_per_clause() -> usize {
        8
    }
fn build_load(&mut self, dst_reg: u8, src_mem: u32);
fn build_store(&mut self, dst_mem: u32, src_reg: u8);
fn build_input(&mut self, out_reg: u8, src_arg: u8);
fn build_copy(&mut self, out_reg: u8, lhs_reg: u8);
fn build_neg(&mut self, out_reg: u8, lhs_reg: u8);
fn build_abs(&mut self, out_reg: u8, lhs_reg: u8);
fn build_recip(&mut self, out_reg: u8, lhs_reg: u8);
fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8);
fn build_sin(&mut self, out_reg: u8, lhs_reg: u8);
fn build_cos(&mut self, out_reg: u8, lhs_reg: u8);
fn build_tan(&mut self, out_reg: u8, lhs_reg: u8);
fn build_asin(&mut self, out_reg: u8, lhs_reg: u8);
fn build_acos(&mut self, out_reg: u8, lhs_reg: u8);
fn build_atan(&mut self, out_reg: u8, lhs_reg: u8);
fn build_exp(&mut self, out_reg: u8, lhs_reg: u8);
fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);
fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
self.build_mul(out_reg, lhs_reg, lhs_reg)
}
fn build_floor(&mut self, out_reg: u8, lhs_reg: u8);
fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8);
fn build_round(&mut self, out_reg: u8, lhs_reg: u8);
fn build_not(&mut self, out_reg: u8, lhs_reg: u8);
fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);
fn build_add_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
let imm = self.load_imm(imm);
self.build_add(out_reg, lhs_reg, imm);
}
fn build_sub_imm_reg(&mut self, out_reg: u8, arg: u8, imm: f32) {
let imm = self.load_imm(imm);
self.build_sub(out_reg, imm, arg);
}
fn build_sub_reg_imm(&mut self, out_reg: u8, arg: u8, imm: f32) {
let imm = self.load_imm(imm);
self.build_sub(out_reg, arg, imm);
}
fn build_mul_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
let imm = self.load_imm(imm);
self.build_mul(out_reg, lhs_reg, imm);
}
fn load_imm(&mut self, imm: f32) -> u8;
fn finalize(self, out_reg: u8) -> Result<Mmap, Error>;
}
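/// Trait for types with a known SIMD lane count on the current platform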
pub trait SimdSize {
const SIMD_SIZE: usize;
}
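/// State shared by all assemblers: the output stream, the reserved stack
/// size, and whether callee-saved registers have been pushed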
pub(crate) struct AssemblerData<T> {
ops: MmapAssembler,
mem_offset: usize,
saved_callee_regs: bool,
_p: std::marker::PhantomData<*const T>,
}
impl<T> AssemblerData<T> {
fn new(mmap: Mmap) -> Self {
Self {
ops: MmapAssembler::from(mmap),
mem_offset: 0,
saved_callee_regs: false,
_p: std::marker::PhantomData,
}
}
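    /// Reserves stack space for memory slots that spill out of registers,
    /// plus `stack_size` scratch bytes, rounded up to 16-byte alignment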
fn prepare_stack(&mut self, slot_count: usize, stack_size: usize) {
let mem = slot_count.saturating_sub(REGISTER_LIMIT)
* std::mem::size_of::<T>()
+ stack_size;
self.mem_offset = ((mem + 15) / 16) * 16;
self.push_stack();
}
#[cfg(target_arch = "aarch64")]
fn push_stack(&mut self) {
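        // `sub sp, sp, imm` takes a 12-bit unsigned immediate on AArch64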
assert!(self.mem_offset < 4096);
dynasm!(self.ops
; sub sp, sp, self.mem_offset as u32
);
}
#[cfg(target_arch = "x86_64")]
fn push_stack(&mut self) {
dynasm!(self.ops
; sub rsp, self.mem_offset as i32
);
}
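    /// Returns the stack offset (in bytes) of the given memory slot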
fn stack_pos(&self, slot: u32) -> u32 {
assert!(slot >= REGISTER_LIMIT as u32);
(slot - REGISTER_LIMIT as u32) * std::mem::size_of::<T>() as u32
}
}
#[cfg(target_arch = "x86_64")]
type Relocation = dynasmrt::x64::X64Relocation;
#[cfg(target_arch = "aarch64")]
type Relocation = dynasmrt::aarch64::Aarch64Relocation;
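/// A lightweight assembler which writes directly into an [`Mmap`]
///
/// Labels are restricted to single characters in `A`-`Z`, stored in
/// fixed-size arrays; dynamic labels and relocations are unsupported.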
struct MmapAssembler {
mmap: Mmap,
len: usize,
global_labels: [Option<AssemblyOffset>; 26],
local_labels: [Option<AssemblyOffset>; 26],
global_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 2>,
local_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 8>,
}
impl Extend<u8> for MmapAssembler {
fn extend<T>(&mut self, iter: T)
where
T: IntoIterator<Item = u8>,
{
for c in iter.into_iter() {
self.push(c);
}
}
}
impl<'a> Extend<&'a u8> for MmapAssembler {
fn extend<T>(&mut self, iter: T)
where
T: IntoIterator<Item = &'a u8>,
{
for c in iter.into_iter() {
self.push(*c);
}
}
}
impl DynasmApi for MmapAssembler {
#[inline(always)]
fn offset(&self) -> AssemblyOffset {
AssemblyOffset(self.len)
}
#[inline(always)]
fn push(&mut self, byte: u8) {
if self.len >= self.mmap.len() {
self.expand_mmap();
}
self.mmap.write(self.len, byte);
self.len += 1;
}
#[inline(always)]
fn align(&mut self, alignment: usize, with: u8) {
let offset = self.offset().0 % alignment;
if offset != 0 {
for _ in offset..alignment {
self.push(with);
}
}
}
#[inline(always)]
fn push_u32(&mut self, value: u32) {
if self.len + 3 >= self.mmap.len() {
self.expand_mmap();
}
for (i, b) in value.to_le_bytes().iter().enumerate() {
self.mmap.write(self.len + i, *b);
}
self.len += 4;
}
}
impl DynasmLabelApi for MmapAssembler {
type Relocation = Relocation;
fn local_label(&mut self, name: &'static str) {
if name.len() != 1 {
panic!("local label must be a single character");
}
let c = name.as_bytes()[0].wrapping_sub(b'A');
if c >= 26 {
panic!("Invalid label {name}, must be A-Z");
}
if self.local_labels[c as usize].is_some() {
panic!("duplicate local label {name}");
}
self.local_labels[c as usize] = Some(self.offset());
}
fn global_label(&mut self, name: &'static str) {
if name.len() != 1 {
panic!("local label must be a single character");
}
let c = name.as_bytes()[0].wrapping_sub(b'A');
if c >= 26 {
panic!("Invalid label {name}, must be A-Z");
}
if self.global_labels[c as usize].is_some() {
panic!("duplicate global label {name}");
}
self.global_labels[c as usize] = Some(self.offset());
}
fn dynamic_label(&mut self, _id: DynamicLabel) {
panic!("dynamic labels are not supported");
}
fn global_relocation(
&mut self,
name: &'static str,
target_offset: isize,
field_offset: u8,
ref_offset: u8,
kind: Relocation,
) {
let location = self.offset();
if name.len() != 1 {
panic!("local label must be a single character");
}
let c = name.as_bytes()[0].wrapping_sub(b'A');
if c >= 26 {
panic!("Invalid label {name}, must be A-Z");
}
self.global_relocs.push((
PatchLoc::new(
location,
target_offset,
field_offset,
ref_offset,
kind,
),
c,
));
}
fn dynamic_relocation(
&mut self,
_id: DynamicLabel,
_target_offset: isize,
_field_offset: u8,
_ref_offset: u8,
_kind: Relocation,
) {
panic!("dynamic relocations are not supported");
}
fn forward_relocation(
&mut self,
name: &'static str,
target_offset: isize,
field_offset: u8,
ref_offset: u8,
kind: Relocation,
) {
if name.len() != 1 {
panic!("local label must be a single character");
}
let c = name.as_bytes()[0].wrapping_sub(b'A');
if c >= 26 {
panic!("Invalid label {name}, must be A-Z");
}
if self.local_labels[c as usize].is_some() {
panic!("invalid forward relocation: {name} already exists!");
}
let location = self.offset();
self.local_relocs.push((
PatchLoc::new(
location,
target_offset,
field_offset,
ref_offset,
kind,
),
c,
));
}
fn backward_relocation(
&mut self,
name: &'static str,
target_offset: isize,
field_offset: u8,
ref_offset: u8,
kind: Relocation,
) {
if name.len() != 1 {
panic!("local label must be a single character");
}
let c = name.as_bytes()[0].wrapping_sub(b'A');
if c >= 26 {
panic!("Invalid label {name}, must be A-Z");
}
if self.local_labels[c as usize].is_none() {
panic!("invalid backward relocation: {name} does not exist");
}
let location = self.offset();
self.local_relocs.push((
PatchLoc::new(
location,
target_offset,
field_offset,
ref_offset,
kind,
),
c,
));
}
fn bare_relocation(
&mut self,
_target: usize,
_field_offset: u8,
_ref_offset: u8,
_kind: Relocation,
) {
panic!("bare relocations not implemented");
}
}
impl MmapAssembler {
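    /// Patches all pending local relocations against the recorded local
    /// labels, then clears the label table for reuse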
fn commit_local(&mut self) -> Result<(), Error> {
let baseaddr = self.mmap.as_ptr() as usize;
for (loc, label) in self.local_relocs.take() {
let target =
self.local_labels[label as usize].expect("invalid local label");
let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
if loc.patch(buf, baseaddr, target.0).is_err() {
return Err(DynasmError::ImpossibleRelocation(
TargetKind::Local("oh no"),
)
.into());
}
}
self.local_labels = [None; 26];
Ok(())
}
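    /// Patches all global relocations, then finalizes the memory-mapped
    /// region so that it may be executed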
fn finalize(mut self) -> Result<Mmap, Error> {
self.commit_local()?;
let baseaddr = self.mmap.as_ptr() as usize;
for (loc, label) in self.global_relocs.take() {
let target =
self.global_labels.get(label as usize).unwrap().unwrap();
let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
if loc.patch(buf, baseaddr, target.0).is_err() {
return Err(DynasmError::ImpossibleRelocation(
TargetKind::Global("oh no"),
)
.into());
}
}
self.mmap.finalize(self.len);
Ok(self.mmap)
}
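    /// Doubles the size of the backing `Mmap`, copying over existing bytes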
fn expand_mmap(&mut self) {
let mut next = Mmap::new(self.mmap.len() * 2).unwrap();
next.as_mut_slice()[0..self.len].copy_from_slice(self.mmap.as_slice());
std::mem::swap(&mut self.mmap, &mut next);
}
}
impl From<Mmap> for MmapAssembler {
fn from(mmap: Mmap) -> Self {
Self {
mmap,
len: 0,
global_labels: [None; 26],
local_labels: [None; 26],
global_relocs: Default::default(),
local_relocs: Default::default(),
}
}
}
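/// Builds a JIT function for the given tape, reusing `s` as storage when it
/// is large enough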
fn build_asm_fn_with_storage<A: Assembler>(
t: &VmData<REGISTER_LIMIT>,
mut s: Mmap,
) -> Mmap {
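    // On macOS, JIT memory uses per-thread write protection; switch this
    // thread into write mode while assembling.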
#[cfg(target_os = "macos")]
let _guard = Mmap::thread_mode_write();
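    // Conservatively estimate the output size; if the recycled storage is
    // far too small, allocate fresh rather than growing incrementally.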
let size_estimate = t.len() * A::bytes_per_clause();
if size_estimate > 2 * s.len() {
s = Mmap::new(size_estimate).expect("failed to build mmap")
}
let mut asm = A::init(s, t.slot_count());
for op in t.iter_asm() {
match op {
RegOp::Load(reg, mem) => {
asm.build_load(reg, mem);
}
RegOp::Store(reg, mem) => {
asm.build_store(mem, reg);
}
RegOp::Input(out, i) => {
asm.build_input(out, i);
}
RegOp::NegReg(out, arg) => {
asm.build_neg(out, arg);
}
RegOp::AbsReg(out, arg) => {
asm.build_abs(out, arg);
}
RegOp::RecipReg(out, arg) => {
asm.build_recip(out, arg);
}
RegOp::SqrtReg(out, arg) => {
asm.build_sqrt(out, arg);
}
RegOp::SinReg(out, arg) => {
asm.build_sin(out, arg);
}
RegOp::CosReg(out, arg) => {
asm.build_cos(out, arg);
}
RegOp::TanReg(out, arg) => {
asm.build_tan(out, arg);
}
RegOp::AsinReg(out, arg) => {
asm.build_asin(out, arg);
}
RegOp::AcosReg(out, arg) => {
asm.build_acos(out, arg);
}
RegOp::AtanReg(out, arg) => {
asm.build_atan(out, arg);
}
RegOp::ExpReg(out, arg) => {
asm.build_exp(out, arg);
}
RegOp::LnReg(out, arg) => {
asm.build_ln(out, arg);
}
RegOp::CopyReg(out, arg) => {
asm.build_copy(out, arg);
}
RegOp::SquareReg(out, arg) => {
asm.build_square(out, arg);
}
RegOp::FloorReg(out, arg) => {
asm.build_floor(out, arg);
}
RegOp::CeilReg(out, arg) => {
asm.build_ceil(out, arg);
}
RegOp::RoundReg(out, arg) => {
asm.build_round(out, arg);
}
RegOp::NotReg(out, arg) => {
asm.build_not(out, arg);
}
RegOp::AddRegReg(out, lhs, rhs) => {
asm.build_add(out, lhs, rhs);
}
RegOp::MulRegReg(out, lhs, rhs) => {
asm.build_mul(out, lhs, rhs);
}
RegOp::DivRegReg(out, lhs, rhs) => {
asm.build_div(out, lhs, rhs);
}
RegOp::AtanRegReg(out, lhs, rhs) => {
asm.build_atan2(out, lhs, rhs);
}
RegOp::SubRegReg(out, lhs, rhs) => {
asm.build_sub(out, lhs, rhs);
}
RegOp::MinRegReg(out, lhs, rhs) => {
asm.build_min(out, lhs, rhs);
}
RegOp::MaxRegReg(out, lhs, rhs) => {
asm.build_max(out, lhs, rhs);
}
RegOp::AddRegImm(out, arg, imm) => {
asm.build_add_imm(out, arg, imm);
}
RegOp::MulRegImm(out, arg, imm) => {
asm.build_mul_imm(out, arg, imm);
}
RegOp::DivRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_div(out, arg, reg);
}
RegOp::DivImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_div(out, reg, arg);
}
RegOp::AtanRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_atan2(out, arg, reg);
}
RegOp::AtanImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_atan2(out, reg, arg);
}
RegOp::SubImmReg(out, arg, imm) => {
asm.build_sub_imm_reg(out, arg, imm);
}
RegOp::SubRegImm(out, arg, imm) => {
asm.build_sub_reg_imm(out, arg, imm);
}
RegOp::MinRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_min(out, arg, reg);
}
RegOp::MaxRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_max(out, arg, reg);
}
RegOp::ModRegReg(out, lhs, rhs) => {
asm.build_mod(out, lhs, rhs);
}
RegOp::ModRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_mod(out, arg, reg);
}
RegOp::ModImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_mod(out, reg, arg);
}
RegOp::AndRegReg(out, lhs, rhs) => {
asm.build_and(out, lhs, rhs);
}
RegOp::AndRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_and(out, arg, reg);
}
RegOp::OrRegReg(out, lhs, rhs) => {
asm.build_or(out, lhs, rhs);
}
RegOp::OrRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_or(out, arg, reg);
}
RegOp::CopyImm(out, imm) => {
let reg = asm.load_imm(imm);
asm.build_copy(out, reg);
}
RegOp::CompareRegReg(out, lhs, rhs) => {
asm.build_compare(out, lhs, rhs);
}
RegOp::CompareRegImm(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_compare(out, arg, reg);
}
RegOp::CompareImmReg(out, arg, imm) => {
let reg = asm.load_imm(imm);
asm.build_compare(out, reg, arg);
}
}
}
asm.finalize(0).expect("failed to build JIT function")
}
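/// A [`Shape`] which uses JIT-compiled functions for evaluation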
#[derive(Clone)]
pub struct JitShape(GenericVmShape<REGISTER_LIMIT>);
impl JitShape {
fn tracing_tape<A: Assembler>(
&self,
storage: Mmap,
) -> JitTracingFn<A::Data> {
let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
let ptr = f.as_ptr();
JitTracingFn {
mmap: f,
var_count: self.0.var_count(),
choice_count: self.0.choice_count(),
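            // SAFETY: `f` holds finalized code whose entry point matches
            // the `fn_trace` signature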
fn_trace: unsafe { std::mem::transmute(ptr) },
}
}
fn bulk_tape<A: Assembler>(&self, storage: Mmap) -> JitBulkFn<A::Data> {
let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
let ptr = f.as_ptr();
JitBulkFn {
mmap: f,
var_count: self.0.data().var_count(),
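            // SAFETY: `f` holds finalized code whose entry point matches
            // the `fn_bulk` signature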
fn_bulk: unsafe { std::mem::transmute(ptr) },
}
}
}
impl Shape for JitShape {
type Trace = VmTrace;
type Storage = VmData<REGISTER_LIMIT>;
type Workspace = VmWorkspace<REGISTER_LIMIT>;
type TapeStorage = Mmap;
type IntervalEval = JitIntervalEval;
type PointEval = JitPointEval;
type FloatSliceEval = JitFloatSliceEval;
type GradSliceEval = JitGradSliceEval;
fn point_tape(&self, storage: Mmap) -> JitTracingFn<f32> {
self.tracing_tape::<point::PointAssembler>(storage)
}
fn interval_tape(&self, storage: Mmap) -> JitTracingFn<Interval> {
self.tracing_tape::<interval::IntervalAssembler>(storage)
}
fn float_slice_tape(&self, storage: Mmap) -> JitBulkFn<f32> {
self.bulk_tape::<float_slice::FloatSliceAssembler>(storage)
}
fn grad_slice_tape(&self, storage: Mmap) -> JitBulkFn<Grad> {
self.bulk_tape::<grad_slice::GradSliceAssembler>(storage)
}
fn simplify(
&self,
trace: &Self::Trace,
storage: Self::Storage,
workspace: &mut Self::Workspace,
) -> Result<Self, Error> {
self.0
.simplify_inner(trace.as_slice(), storage, workspace)
.map(JitShape)
}
fn recycle(self) -> Option<Self::Storage> {
self.0.recycle()
}
fn size(&self) -> usize {
self.0.size()
}
type TransformedShape = TransformedShape<Self>;
fn apply_transform(self, mat: Matrix4<f32>) -> Self::TransformedShape {
TransformedShape::new(self, mat)
}
}
impl RenderHints for JitShape {
fn tile_sizes_3d() -> &'static [usize] {
&[64, 16, 8]
}
fn tile_sizes_2d() -> &'static [usize] {
&[128, 16]
}
fn simplify_tree_during_meshing(d: usize) -> bool {
d % 8 == 4
}
}
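// JIT functions are called with a fixed ABI: the generated x86-64 code
// follows the System V calling convention regardless of OS, while AArch64
// uses the platform's default C convention.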
#[cfg(target_arch = "x86_64")]
macro_rules! jit_fn {
(unsafe fn($($args:tt)*) -> $($out:tt)*) => {
unsafe extern "sysv64" fn($($args)*) -> $($out)*
};
}
#[cfg(target_arch = "aarch64")]
macro_rules! jit_fn {
(unsafe fn($($args:tt)*) -> $($out:tt)*) => {
unsafe extern "C" fn($($args)*) -> $($out)*
};
}
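/// Evaluator which calls a JIT-compiled tracing function, recording a trace
/// of `Choice` values for later tape simplification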
#[derive(Default)]
struct JitTracingEval {
choices: VmTrace,
}
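/// Handle which owns a JIT-compiled tracing function and its backing `Mmap`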
pub struct JitTracingFn<T> {
#[allow(unused)]
mmap: Mmap,
choice_count: usize,
var_count: usize,
    fn_trace: jit_fn!(
        unsafe fn(
            *const T, // vars
            *mut u8,  // choices
            *mut u8,  // simplify flag (output)
        ) -> T
    ),
}
impl<T> Tape for JitTracingFn<T> {
type Storage = Mmap;
fn recycle(self) -> Self::Storage {
self.mmap
}
}
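// SAFETY: the raw function pointer targets code in the owned `Mmap`, which
// is immutable once finalized and has no thread affinity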
unsafe impl<T> Send for JitTracingFn<T> {}
unsafe impl<T> Sync for JitTracingFn<T> {}
impl JitTracingEval {
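    /// Evaluates the given tape at `(x, y, z)`, returning the output along
    /// with the choice trace if any simplification is possible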
fn eval<T: From<f32>, F: Into<T>>(
&mut self,
tape: &JitTracingFn<T>,
x: F,
y: F,
z: F,
) -> (T, Option<&VmTrace>) {
let x = x.into();
let y = y.into();
let z = z.into();
        let mut simplify = 0;
        assert!(tape.var_count <= 3);
        self.choices.resize(tape.choice_count, Choice::Unknown);
        self.choices.fill(Choice::Unknown);
let vars = [x, y, z];
let out = unsafe {
(tape.fn_trace)(
vars.as_ptr(),
self.choices.as_mut_ptr() as *mut u8,
&mut simplify,
)
};
(
out,
if simplify != 0 {
Some(&self.choices)
} else {
None
},
)
}
}
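/// JIT-based tracing evaluator for intervals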
#[derive(Default)]
pub struct JitIntervalEval(JitTracingEval);
impl TracingEvaluator for JitIntervalEval {
type Data = Interval;
type Tape = JitTracingFn<Interval>;
type Trace = VmTrace;
type TapeStorage = Mmap;
fn eval<F: Into<Self::Data>>(
&mut self,
tape: &Self::Tape,
x: F,
y: F,
z: F,
) -> Result<(Self::Data, Option<&Self::Trace>), Error> {
Ok(self.0.eval(tape, x, y, z))
}
}
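/// JIT-based tracing evaluator for single points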
#[derive(Default)]
pub struct JitPointEval(JitTracingEval);
impl TracingEvaluator for JitPointEval {
type Data = f32;
type Tape = JitTracingFn<f32>;
type Trace = VmTrace;
type TapeStorage = Mmap;
fn eval<F: Into<Self::Data>>(
&mut self,
tape: &Self::Tape,
x: F,
y: F,
z: F,
) -> Result<(Self::Data, Option<&Self::Trace>), Error> {
Ok(self.0.eval(tape, x, y, z))
}
}
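/// Handle which owns a JIT-compiled bulk-evaluation function and its
/// backing `Mmap`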
pub struct JitBulkFn<T> {
#[allow(unused)]
mmap: Mmap,
var_count: usize,
    fn_bulk: jit_fn!(
        unsafe fn(
            *const *const T, // vars
            *mut T,          // out
            u64,             // size
        ) -> T
    ),
}
impl<T> Tape for JitBulkFn<T> {
type Storage = Mmap;
fn recycle(self) -> Self::Storage {
self.mmap
}
}
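/// Bulk evaluator which calls a JIT-compiled function, reusing its output
/// buffer between calls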
struct JitBulkEval<T> {
out: Vec<T>,
}
impl<T> Default for JitBulkEval<T> {
fn default() -> Self {
Self { out: vec![] }
}
}
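// SAFETY: as with `JitTracingFn`, the function pointer targets immutable
// code in the owned `Mmap`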
unsafe impl<T> Send for JitBulkFn<T> {}
unsafe impl<T> Sync for JitBulkFn<T> {}
impl<T: From<f32> + Copy + SimdSize> JitBulkEval<T> {
fn eval(
&mut self,
tape: &JitBulkFn<T>,
xs: &[T],
ys: &[T],
zs: &[T],
) -> &[T] {
assert!(tape.var_count <= 3);
let n = xs.len();
self.out.resize(n, f32::NAN.into());
self.out.fill(f32::NAN.into());
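        // With fewer inputs than the SIMD width, pad the inputs into
        // fixed-size scratch arrays and make a single full-width call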
if n < T::SIMD_SIZE {
const MAX_SIMD_WIDTH: usize = 8;
let mut x = [T::from(0.0); MAX_SIMD_WIDTH];
let mut y = [T::from(0.0); MAX_SIMD_WIDTH];
let mut z = [T::from(0.0); MAX_SIMD_WIDTH];
assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH);
x[0..n].copy_from_slice(xs);
y[0..n].copy_from_slice(ys);
z[0..n].copy_from_slice(zs);
let mut tmp = [f32::NAN.into(); MAX_SIMD_WIDTH];
let vars = [x.as_ptr(), y.as_ptr(), z.as_ptr()];
unsafe {
(tape.fn_bulk)(
vars.as_ptr(),
tmp.as_mut_ptr(),
T::SIMD_SIZE as u64,
);
}
self.out.copy_from_slice(&tmp[0..n]);
} else {
            // Evaluate whole multiples of the SIMD width directly from the
            // input slices
            let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE;
            let vars = [xs.as_ptr(), ys.as_ptr(), zs.as_ptr()];
unsafe {
(tape.fn_bulk)(vars.as_ptr(), self.out.as_mut_ptr(), m as u64);
}
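            // Handle any leftover items by re-evaluating the last full SIMD
            // block; this recomputes a few values, but keeps every call at
            // the native SIMD width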
if n != m {
unsafe {
let vars = [
xs.as_ptr().add(n - T::SIMD_SIZE),
ys.as_ptr().add(n - T::SIMD_SIZE),
zs.as_ptr().add(n - T::SIMD_SIZE),
];
(tape.fn_bulk)(
vars.as_ptr(),
self.out.as_mut_ptr().add(n - T::SIMD_SIZE),
T::SIMD_SIZE as u64,
);
}
}
}
&self.out
}
}
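/// JIT-based bulk evaluator for arrays of points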
#[derive(Default)]
pub struct JitFloatSliceEval(JitBulkEval<f32>);
impl BulkEvaluator for JitFloatSliceEval {
type Data = f32;
type Tape = JitBulkFn<Self::Data>;
type TapeStorage = Mmap;
fn eval(
&mut self,
tape: &Self::Tape,
xs: &[f32],
ys: &[f32],
zs: &[f32],
) -> Result<&[Self::Data], Error> {
self.check_arguments(xs, ys, zs, tape.var_count)?;
Ok(self.0.eval(tape, xs, ys, zs))
}
}
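/// JIT-based bulk evaluator for arrays of gradients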
#[derive(Default)]
pub struct JitGradSliceEval(JitBulkEval<Grad>);
impl BulkEvaluator for JitGradSliceEval {
type Data = Grad;
type Tape = JitBulkFn<Self::Data>;
type TapeStorage = Mmap;
fn eval(
&mut self,
tape: &Self::Tape,
xs: &[Self::Data],
ys: &[Self::Data],
zs: &[Self::Data],
) -> Result<&[Self::Data], Error> {
self.check_arguments(xs, ys, zs, tape.var_count)?;
Ok(self.0.eval(tape, xs, ys, zs))
}
}
impl MathShape for JitShape {
fn new(ctx: &Context, node: Node) -> Result<Self, Error> {
GenericVmShape::new(ctx, node).map(JitShape)
}
}
#[cfg(test)]
mod test {
use super::*;
crate::grad_slice_tests!(JitShape);
crate::interval_tests!(JitShape);
crate::float_slice_tests!(JitShape);
crate::point_tests!(JitShape);
}