// fidget_jit/lib.rs
//! Compilation down to native machine code
//!
//! Users are unlikely to use anything in this module other than [`JitFunction`],
//! which is a [`Function`] that uses JIT evaluation.
//!
//! ```
//! use fidget_core::{
//!     context::Tree,
//!     shape::EzShape,
//! };
//! use fidget_jit::JitShape;
//!
//! let tree = Tree::x() + Tree::y();
//! let shape = JitShape::from(tree);
//!
//! // Generate machine code to execute the tape
//! let tape = shape.ez_point_tape();
//! let mut eval = JitShape::new_point_eval();
//!
//! // This calls directly into that machine code!
//! let (r, _trace) = eval.eval(&tape, 0.1, 0.3, 0.0)?;
//! assert_eq!(r, 0.1 + 0.3);
//! # Ok::<(), fidget_core::Error>(())
//! ```
25
26use crate::mmap::{Mmap, MmapWriter};
27use fidget_core::{
28    Error,
29    compiler::RegOp,
30    context::{Context, Node},
31    eval::{
32        BulkEvaluator, BulkOutput, Function, MathFunction, Tape,
33        TracingEvaluator,
34    },
35    render::{RenderHints, TileSizes},
36    types::{Grad, Interval},
37    var::VarMap,
38    vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace},
39};
40
41use dynasmrt::{
42    AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi,
43    TargetKind, components::PatchLoc, dynasm,
44};
45use std::sync::Arc;
46
47mod mmap;
48mod permit;
49pub(crate) use permit::WritePermit;
50
51// Evaluators
52mod float_slice;
53mod grad_slice;
54mod interval;
55mod point;
56
// The JIT relies on OS-specific memory-mapping / permission APIs, so fail at
// compile time on unsupported platforms rather than mysteriously at runtime.
#[cfg(not(any(
    target_os = "linux",
    target_os = "macos",
    target_os = "windows"
)))]
compile_error!(
    "The `jit` module only builds on Linux, macOS, and Windows; \
    please disable the `jit` feature"
);
66
67#[cfg(target_arch = "aarch64")]
68mod aarch64;
69#[cfg(target_arch = "aarch64")]
70use aarch64 as arch;
71
72#[cfg(target_arch = "x86_64")]
73mod x86_64;
74#[cfg(target_arch = "x86_64")]
75use x86_64 as arch;
76
/// Number of registers available when executing natively
///
/// This is architecture-specific; see the `arch` module alias above.
const REGISTER_LIMIT: usize = arch::REGISTER_LIMIT;

/// Offset before the first useable register
///
/// Hardware registers `0..OFFSET` are reserved (see [`reg`] below).
const OFFSET: u8 = arch::OFFSET;

/// Register written to by `CopyImm`
///
/// It is the responsibility of functions to avoid writing to `IMM_REG` in cases
/// where it could be one of their arguments (i.e. all functions of 2 or more
/// arguments).
const IMM_REG: u8 = arch::IMM_REG;
89
90/// Converts from a tape-local register to a hardware register
91///
92/// Tape-local registers are in the range `0..REGISTER_LIMIT`, while ARM
93/// registers have an offset (based on calling convention).
94///
95/// This uses `wrapping_add` to support immediates, which are loaded into a
96/// register below [`OFFSET`] (which is "negative" from the perspective of this
97/// function).
98fn reg(r: u8) -> u8 {
99    let out = r.wrapping_add(OFFSET);
100    assert!(out < 32);
101    out
102}
103
// `u32` copies of the `Choice` discriminants, convenient for emitting as raw
// integer immediates.  NOTE(review): these are not referenced in this file —
// presumably consumed by the per-architecture assembler modules; confirm.
const CHOICE_LEFT: u32 = Choice::Left as u32;
const CHOICE_RIGHT: u32 = Choice::Right as u32;
const CHOICE_BOTH: u32 = Choice::Both as u32;
107
/// Trait for generating machine assembly
///
/// One implementation exists per evaluation flavor (point, interval,
/// float-slice, gradient-slice); see the sibling modules.  Each `build_*`
/// method emits the machine code for a single tape clause.
trait Assembler {
    /// Data type used during evaluation.
    ///
    /// This should be a `repr(C)` type, so it can be passed around directly.
    type Data;

    /// Initializes the assembler with the given slot count
    ///
    /// This will likely construct a function prelude and reserve space on the
    /// stack for slot spills.
    fn init(m: Mmap, slot_count: usize) -> Self;

    /// Returns an approximate bytes per clause value, used for preallocation
    fn bytes_per_clause() -> usize {
        8 // rough default; implementations should override with a measured value
    }

    /// Builds a load from memory to a register
    fn build_load(&mut self, dst_reg: u8, src_mem: u32);

    /// Builds a store from a register to a memory location
    fn build_store(&mut self, dst_mem: u32, src_reg: u8);

    /// Copies the given input to `out_reg`
    fn build_input(&mut self, out_reg: u8, src_arg: u32);

    /// Writes the argument register to the output
    fn build_output(&mut self, arg_reg: u8, out_index: u32);

    /// Copies a register
    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8);

    /// Unary negation
    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8);

    /// Absolute value
    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8);

    /// Reciprocal (1 / `lhs_reg`)
    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8);

    /// Square root
    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8);

    /// Sine
    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8);

    /// Cosine
    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8);

    /// Tangent
    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arcsine
    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arccosine
    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arctangent
    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8);

    /// Exponent
    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8);

    /// Natural log
    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);

    /// Less than
    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Square
    ///
    /// This has a default implementation, but can be overloaded for efficiency;
    /// for example, in interval arithmetic, we benefit from knowing that both
    /// values are the same.
    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        self.build_mul(out_reg, lhs_reg, lhs_reg)
    }

    /// Arithmetic floor
    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arithmetic ceiling
    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8);

    /// Rounding
    fn build_round(&mut self, out_reg: u8, lhs_reg: u8);

    /// Logical not
    fn build_not(&mut self, out_reg: u8, lhs_reg: u8);

    /// Logical and (short-circuiting)
    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Logical or (short-circuiting)
    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Addition
    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Subtraction
    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Multiplication
    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Division
    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Four-quadrant arctangent
    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Maximum of two values
    ///
    /// In a tracing evaluator, this function must also write to the `choices`
    /// array and may set `simplify` if one branch is always taken.
    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Minimum of two values
    ///
    /// In a tracing evaluator, this function must also write to the `choices`
    /// array and may set `simplify` if one branch is always taken.
    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Modulo of two values (least non-negative remainder)
    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    // Special-case functions for immediates.  In some cases, you can be more
    // efficient if you know that an argument is an immediate (for example, both
    // values in the interval will be the same, and it will have no gradients).

    /// Builds an addition (immediate + register)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_add_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_add(out_reg, lhs_reg, imm);
    }
    /// Builds a subtraction (immediate − register)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_sub_imm_reg(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, imm, arg);
    }
    /// Builds a subtraction (register − immediate)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_sub_reg_imm(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, arg, imm);
    }
    /// Builds a multiplication (register × immediate)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_mul_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_mul(out_reg, lhs_reg, imm);
    }

    /// Loads an immediate into a register, returning that register
    fn load_imm(&mut self, imm: f32) -> u8;

    /// Finalize the assembly code, returning a memory-mapped region
    fn finalize(self) -> Result<Mmap, DynasmError>;
}
276
/// Trait defining SIMD width
///
/// Implemented by bulk evaluators so that callers can size (and pad) their
/// input arrays to a multiple of the native vector width.
pub trait SimdSize {
    /// Number of elements processed in a single iteration
    ///
    /// This value is used when checking array sizes, as we want to be sure to
    /// pass the JIT code an appropriately sized array.
    const SIMD_SIZE: usize;
}
285
286/////////////////////////////////////////////////////////////////////////////////////////
287
/// Architecture-independent state shared by the concrete JIT assemblers
///
/// `T` is the per-slot value type; it is only used to size stack spill slots
/// (see `prepare_stack` / `stack_pos`), hence the `PhantomData`.
pub(crate) struct AssemblerData<T> {
    /// Underlying assembler, writing machine code into a memory-mapped region
    ops: MmapAssembler,

    /// Current offset of the stack pointer, in bytes
    mem_offset: usize,

    /// Set to true if we have saved certain callee-saved registers
    ///
    /// These registers are only modified in function calls, so normally we
    /// don't save them.
    saved_callee_regs: bool,

    // Zero-sized marker tying this state to the slot type `T`
    _p: std::marker::PhantomData<*const T>,
}
302
303impl<T> AssemblerData<T> {
304    fn new(mmap: Mmap) -> Self {
305        Self {
306            ops: MmapAssembler::from(mmap),
307            mem_offset: 0,
308            saved_callee_regs: false,
309            _p: std::marker::PhantomData,
310        }
311    }
312
313    fn prepare_stack(&mut self, slot_count: usize, stack_size: usize) {
314        // We always use the stack, if only to store callee-saved registers
315        let mem = slot_count.saturating_sub(REGISTER_LIMIT)
316            * std::mem::size_of::<T>()
317            + stack_size;
318
319        // Round up to the nearest multiple of 16 bytes, for alignment
320        self.mem_offset = mem.next_multiple_of(16);
321        self.push_stack();
322    }
323
324    fn stack_pos(&self, slot: u32) -> u32 {
325        assert!(slot >= REGISTER_LIMIT as u32);
326        (slot - REGISTER_LIMIT as u32) * std::mem::size_of::<T>() as u32
327    }
328}
329
#[cfg(target_arch = "x86_64")]
impl<T> AssemblerData<T> {
    /// Reserves `self.mem_offset` bytes of stack space for spill slots
    fn push_stack(&mut self) {
        dynasm!(self.ops
            ; sub rsp, self.mem_offset as i32
        );
    }

    /// Emits the function epilogue, then finalizes the machine code
    ///
    /// Releases the stack space reserved by `push_stack`, pops `rbp`
    /// (presumably pushed by an architecture-specific prologue not visible in
    /// this file — TODO confirm), clears upper AVX state, and returns.
    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        dynasm!(self.ops
            ; add rsp, self.mem_offset as i32
            ; pop rbp
            ; vzeroupper
            ; ret
        );
        self.ops.finalize()
    }
}
348
#[cfg(target_arch = "aarch64")]
#[allow(clippy::unnecessary_cast)] // dynasm-rs#106
impl<T> AssemblerData<T> {
    /// Reserves `self.mem_offset` bytes of stack space for spill slots
    ///
    /// Offsets below 4096 fit in the `sub` immediate; larger offsets (up to
    /// 65536) go through the `w28` scratch register.  Anything bigger panics.
    fn push_stack(&mut self) {
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; sub sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w28, self.mem_offset as u32
                ; sub sp, sp, w28
            );
        } else {
            panic!("invalid mem offset: {} is too large", self.mem_offset);
        }
    }

    /// Emits the function epilogue, then finalizes the machine code
    ///
    /// Restores the stack pointer (mirroring `push_stack`, but with `w9` as
    /// the scratch register) and emits `ret`.
    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        // Fix up the stack
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; add sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w9, self.mem_offset as u32
                ; add sp, sp, w9
            );
        } else {
            panic!("invalid mem offset: {}", self.mem_offset);
        }

        dynasm!(self.ops
            ; ret
        );
        self.ops.finalize()
    }
}
388
389////////////////////////////////////////////////////////////////////////////////
390
/// Relocation type for the current target architecture
#[cfg(target_arch = "x86_64")]
type Relocation = dynasmrt::x64::X64Relocation;

/// Relocation type for the current target architecture
#[cfg(target_arch = "aarch64")]
type Relocation = dynasmrt::aarch64::Aarch64Relocation;
396
/// Assembler writing machine code directly into a memory-mapped region
///
/// Implements [`DynasmApi`] plus a restricted subset of [`DynasmLabelApi`];
/// see the label-API `impl` below for the restrictions.
struct MmapAssembler {
    /// Output buffer for generated machine code
    mmap: MmapWriter,

    /// Offsets of committed labels, indexed by letter (`A` = 0 … `Z` = 25)
    global_labels: [Option<AssemblyOffset>; 26],
    local_labels: [Option<AssemblyOffset>; 26],

    /// Pending relocations, each paired with its label index (`0..26`);
    /// fixed-capacity, so exceeding 2 global / 8 local pending jumps panics
    global_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 2>,
    local_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 8>,
}
406
407impl Extend<u8> for MmapAssembler {
408    fn extend<T>(&mut self, iter: T)
409    where
410        T: IntoIterator<Item = u8>,
411    {
412        for c in iter.into_iter() {
413            self.push(c);
414        }
415    }
416}
417
418impl<'a> Extend<&'a u8> for MmapAssembler {
419    fn extend<T>(&mut self, iter: T)
420    where
421        T: IntoIterator<Item = &'a u8>,
422    {
423        for c in iter.into_iter() {
424            self.push(*c);
425        }
426    }
427}
428
429impl DynasmApi for MmapAssembler {
430    #[inline(always)]
431    fn offset(&self) -> AssemblyOffset {
432        AssemblyOffset(self.mmap.len())
433    }
434
435    #[inline(always)]
436    fn push(&mut self, byte: u8) {
437        self.mmap.push(byte);
438    }
439
440    #[inline(always)]
441    fn align(&mut self, alignment: usize, with: u8) {
442        let offset = self.offset().0 % alignment;
443        if offset != 0 {
444            for _ in offset..alignment {
445                self.push(with);
446            }
447        }
448    }
449
450    #[inline(always)]
451    fn push_u32(&mut self, value: u32) {
452        for b in value.to_le_bytes() {
453            self.mmap.push(b);
454        }
455    }
456}
457
458/// This is a very limited implementation of the labels API.  Compared to the
459/// standard labels API, it has the following limitations:
460///
461/// - Labels must be a single character
462/// - Local labels must be committed before they're reused, using `commit_local`
463/// - Only 8 local jumps are available at any given time; this is reset when
464///   `commit_local` is called.  (if this becomes problematic, it can be
465///   increased by tweaking the size of `local_relocs: ArrayVec<..., 8>`.
466///
467/// In exchange for these limitations, it allocates no memory at runtime, and all
468/// label lookups are done in constant time.
469///
470/// However, it still has overhead compared to computing the jumps by hand;
471/// this overhead was roughly 5% in one unscientific test.
472impl DynasmLabelApi for MmapAssembler {
473    type Relocation = Relocation;
474
475    fn local_label(&mut self, name: &'static str) {
476        if name.len() != 1 {
477            panic!("local label must be a single character");
478        }
479        let c = name.as_bytes()[0].wrapping_sub(b'A');
480        if c >= 26 {
481            panic!("Invalid label {name}, must be A-Z");
482        }
483        if self.local_labels[c as usize].is_some() {
484            panic!("duplicate local label {name}");
485        }
486
487        self.local_labels[c as usize] = Some(self.offset());
488    }
489    fn global_label(&mut self, name: &'static str) {
490        if name.len() != 1 {
491            panic!("local label must be a single character");
492        }
493        let c = name.as_bytes()[0].wrapping_sub(b'A');
494        if c >= 26 {
495            panic!("Invalid label {name}, must be A-Z");
496        }
497        if self.global_labels[c as usize].is_some() {
498            panic!("duplicate global label {name}");
499        }
500
501        self.global_labels[c as usize] = Some(self.offset());
502    }
503    fn dynamic_label(&mut self, _id: DynamicLabel) {
504        panic!("dynamic labels are not supported");
505    }
506    fn global_relocation(
507        &mut self,
508        name: &'static str,
509        target_offset: isize,
510        field_offset: u8,
511        ref_offset: u8,
512        kind: Relocation,
513    ) {
514        let location = self.offset();
515        if name.len() != 1 {
516            panic!("local label must be a single character");
517        }
518        let c = name.as_bytes()[0].wrapping_sub(b'A');
519        if c >= 26 {
520            panic!("Invalid label {name}, must be A-Z");
521        }
522        self.global_relocs.push((
523            PatchLoc::new(
524                location,
525                target_offset,
526                field_offset,
527                ref_offset,
528                kind,
529            ),
530            c,
531        ));
532    }
533    fn dynamic_relocation(
534        &mut self,
535        _id: DynamicLabel,
536        _target_offset: isize,
537        _field_offset: u8,
538        _ref_offset: u8,
539        _kind: Relocation,
540    ) {
541        panic!("dynamic relocations are not supported");
542    }
543    fn forward_relocation(
544        &mut self,
545        name: &'static str,
546        target_offset: isize,
547        field_offset: u8,
548        ref_offset: u8,
549        kind: Relocation,
550    ) {
551        if name.len() != 1 {
552            panic!("local label must be a single character");
553        }
554        let c = name.as_bytes()[0].wrapping_sub(b'A');
555        if c >= 26 {
556            panic!("Invalid label {name}, must be A-Z");
557        }
558        if self.local_labels[c as usize].is_some() {
559            panic!("invalid forward relocation: {name} already exists!");
560        }
561        let location = self.offset();
562        self.local_relocs.push((
563            PatchLoc::new(
564                location,
565                target_offset,
566                field_offset,
567                ref_offset,
568                kind,
569            ),
570            c,
571        ));
572    }
573    fn backward_relocation(
574        &mut self,
575        name: &'static str,
576        target_offset: isize,
577        field_offset: u8,
578        ref_offset: u8,
579        kind: Relocation,
580    ) {
581        if name.len() != 1 {
582            panic!("local label must be a single character");
583        }
584        let c = name.as_bytes()[0].wrapping_sub(b'A');
585        if c >= 26 {
586            panic!("Invalid label {name}, must be A-Z");
587        }
588        if self.local_labels[c as usize].is_none() {
589            panic!("invalid backward relocation: {name} does not exist");
590        }
591        let location = self.offset();
592        self.local_relocs.push((
593            PatchLoc::new(
594                location,
595                target_offset,
596                field_offset,
597                ref_offset,
598                kind,
599            ),
600            c,
601        ));
602    }
603    fn bare_relocation(
604        &mut self,
605        _target: usize,
606        _field_offset: u8,
607        _ref_offset: u8,
608        _kind: Relocation,
609    ) {
610        panic!("bare relocations not implemented");
611    }
612}
613
impl MmapAssembler {
    /// Applies all local relocations, clearing the `local_relocs` array
    ///
    /// This should be called after any function which uses local labels.
    fn commit_local(&mut self) -> Result<(), DynasmError> {
        let baseaddr = self.mmap.as_ptr() as usize;

        // `take()` drains the pending relocations, leaving the array empty
        for (loc, label) in self.local_relocs.take() {
            let target =
                self.local_labels[label as usize].expect("invalid local label");
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Local("oh no"),
                ));
            }
        }
        // Reset the labels so they may be reused by subsequent code
        self.local_labels = [None; 26];
        Ok(())
    }

    /// Patches global relocations and returns the finished memory-mapped code
    ///
    /// Any outstanding local relocations are committed first; an unresolved
    /// global label panics (via `unwrap`).
    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        self.commit_local()?;

        let baseaddr = self.mmap.as_ptr() as usize;
        for (loc, label) in self.global_relocs.take() {
            let target =
                self.global_labels.get(label as usize).unwrap().unwrap();
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Global("oh no"),
                ));
            }
        }

        Ok(self.mmap.finalize())
    }
}
653
654impl From<Mmap> for MmapAssembler {
655    fn from(mmap: Mmap) -> Self {
656        Self {
657            mmap: MmapWriter::from(mmap),
658            global_labels: [None; 26],
659            local_labels: [None; 26],
660            global_relocs: Default::default(),
661            local_relocs: Default::default(),
662        }
663    }
664}
665
666/////////////////////////////////////////////////////////////////////////////////////////
667
/// Lowers a register-allocated tape into machine code via assembler `A`
///
/// The mmap `s` is reused as output storage; if the estimated code size is
/// much larger than its capacity, a bigger mapping is allocated up front.
fn build_asm_fn_with_storage<A: Assembler>(
    t: &VmData<REGISTER_LIMIT>,
    mut s: Mmap,
) -> Mmap {
    let size_estimate = t.len() * A::bytes_per_clause();
    if size_estimate > 2 * s.capacity() {
        s = Mmap::new(size_estimate).expect("failed to build mmap")
    }

    let mut asm = A::init(s, t.slot_count());

    // Dispatch each tape clause to the matching `build_*` method.  For the
    // `*Imm*` variants without a dedicated builder, the immediate is first
    // materialized into a register with `load_imm`.
    for op in t.iter_asm() {
        match op {
            RegOp::Load(reg, mem) => {
                asm.build_load(reg, mem);
            }
            RegOp::Store(reg, mem) => {
                asm.build_store(mem, reg);
            }
            RegOp::Input(out, i) => {
                asm.build_input(out, i);
            }
            RegOp::Output(arg, i) => {
                asm.build_output(arg, i);
            }
            RegOp::NegReg(out, arg) => {
                asm.build_neg(out, arg);
            }
            RegOp::AbsReg(out, arg) => {
                asm.build_abs(out, arg);
            }
            RegOp::RecipReg(out, arg) => {
                asm.build_recip(out, arg);
            }
            RegOp::SqrtReg(out, arg) => {
                asm.build_sqrt(out, arg);
            }
            RegOp::SinReg(out, arg) => {
                asm.build_sin(out, arg);
            }
            RegOp::CosReg(out, arg) => {
                asm.build_cos(out, arg);
            }
            RegOp::TanReg(out, arg) => {
                asm.build_tan(out, arg);
            }
            RegOp::AsinReg(out, arg) => {
                asm.build_asin(out, arg);
            }
            RegOp::AcosReg(out, arg) => {
                asm.build_acos(out, arg);
            }
            RegOp::AtanReg(out, arg) => {
                asm.build_atan(out, arg);
            }
            RegOp::ExpReg(out, arg) => {
                asm.build_exp(out, arg);
            }
            RegOp::LnReg(out, arg) => {
                asm.build_ln(out, arg);
            }
            RegOp::CopyReg(out, arg) => {
                asm.build_copy(out, arg);
            }
            RegOp::SquareReg(out, arg) => {
                asm.build_square(out, arg);
            }
            RegOp::FloorReg(out, arg) => {
                asm.build_floor(out, arg);
            }
            RegOp::CeilReg(out, arg) => {
                asm.build_ceil(out, arg);
            }
            RegOp::RoundReg(out, arg) => {
                asm.build_round(out, arg);
            }
            RegOp::NotReg(out, arg) => {
                asm.build_not(out, arg);
            }
            RegOp::AddRegReg(out, lhs, rhs) => {
                asm.build_add(out, lhs, rhs);
            }
            RegOp::MulRegReg(out, lhs, rhs) => {
                asm.build_mul(out, lhs, rhs);
            }
            RegOp::DivRegReg(out, lhs, rhs) => {
                asm.build_div(out, lhs, rhs);
            }
            RegOp::AtanRegReg(out, lhs, rhs) => {
                asm.build_atan2(out, lhs, rhs);
            }
            RegOp::SubRegReg(out, lhs, rhs) => {
                asm.build_sub(out, lhs, rhs);
            }
            RegOp::MinRegReg(out, lhs, rhs) => {
                asm.build_min(out, lhs, rhs);
            }
            RegOp::MaxRegReg(out, lhs, rhs) => {
                asm.build_max(out, lhs, rhs);
            }
            RegOp::AddRegImm(out, arg, imm) => {
                asm.build_add_imm(out, arg, imm);
            }
            RegOp::MulRegImm(out, arg, imm) => {
                asm.build_mul_imm(out, arg, imm);
            }
            RegOp::DivRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, arg, reg);
            }
            RegOp::DivImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, reg, arg);
            }
            RegOp::AtanRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, arg, reg);
            }
            RegOp::AtanImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, reg, arg);
            }
            RegOp::SubImmReg(out, arg, imm) => {
                asm.build_sub_imm_reg(out, arg, imm);
            }
            RegOp::SubRegImm(out, arg, imm) => {
                asm.build_sub_reg_imm(out, arg, imm);
            }
            RegOp::MinRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_min(out, arg, reg);
            }
            RegOp::MaxRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_max(out, arg, reg);
            }
            RegOp::ModRegReg(out, lhs, rhs) => {
                asm.build_mod(out, lhs, rhs);
            }
            RegOp::ModRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, arg, reg);
            }
            RegOp::ModImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, reg, arg);
            }
            RegOp::AndRegReg(out, lhs, rhs) => {
                asm.build_and(out, lhs, rhs);
            }
            RegOp::AndRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_and(out, arg, reg);
            }
            RegOp::OrRegReg(out, lhs, rhs) => {
                asm.build_or(out, lhs, rhs);
            }
            RegOp::OrRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_or(out, arg, reg);
            }
            RegOp::CopyImm(out, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_copy(out, reg);
            }
            RegOp::CompareRegReg(out, lhs, rhs) => {
                asm.build_compare(out, lhs, rhs);
            }
            RegOp::CompareRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, arg, reg);
            }
            RegOp::CompareImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, reg, arg);
            }
        }
    }

    asm.finalize().expect("failed to build JIT function")
    // NOTE(review): the original comment said "JIT execute mode is restored
    // here when the _guard is dropped", but no `_guard` is in scope in this
    // function — presumably write/execute permissions are handled inside
    // `Mmap` / `WritePermit`; confirm.
}
850
/// Function for use with a JIT evaluator
///
/// This is a thin wrapper around a [`GenericVmFunction`]; machine code is
/// generated when a tape is requested (see `tracing_tape` / `bulk_tape`).
#[derive(Clone)]
pub struct JitFunction(GenericVmFunction<REGISTER_LIMIT>);
854
impl JitFunction {
    /// JIT-compiles the inner tape into a tracing evaluation function
    fn tracing_tape<A: Assembler>(
        &self,
        storage: Mmap,
    ) -> JitTracingFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitTracingFn {
            mmap: f.into(),
            vars: self.0.data().vars.clone(),
            choice_count: self.0.choice_count(),
            output_count: self.0.output_count(),
            // SAFETY: `ptr` is the entry point of machine code just emitted by
            // `build_asm_fn_with_storage::<A>`; the per-architecture
            // assemblers are assumed to emit code matching the
            // `JitTracingFnPointer<A::Data>` ABI — TODO confirm.
            fn_trace: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitTracingFnPointer<A::Data>,
                >(ptr)
            },
        }
    }

    /// JIT-compiles the inner tape into a bulk evaluation function
    fn bulk_tape<A: Assembler>(&self, storage: Mmap) -> JitBulkFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitBulkFn {
            mmap: f.into(),
            output_count: self.0.output_count(),
            vars: self.0.data().vars.clone(),
            // SAFETY: same argument as `tracing_tape` above, with the
            // `JitBulkFnPointer<A::Data>` ABI — TODO confirm.
            fn_bulk: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitBulkFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
}
891
impl Function for JitFunction {
    type Trace = VmTrace;
    type Storage = VmData<REGISTER_LIMIT>;
    type Workspace = VmWorkspace<REGISTER_LIMIT>;

    // Tapes are machine code, so their storage is a memory-mapped region
    type TapeStorage = Mmap;

    type IntervalEval = JitIntervalEval;
    type PointEval = JitPointEval;
    type FloatSliceEval = JitFloatSliceEval;
    type GradSliceEval = JitGradSliceEval;

    /// Compiles a single-point (tracing) evaluation tape
    #[inline]
    fn point_tape(&self, storage: Mmap) -> JitTracingFn<f32> {
        self.tracing_tape::<point::PointAssembler>(storage)
    }

    /// Compiles an interval (tracing) evaluation tape
    #[inline]
    fn interval_tape(&self, storage: Mmap) -> JitTracingFn<Interval> {
        self.tracing_tape::<interval::IntervalAssembler>(storage)
    }

    /// Compiles a bulk float-slice evaluation tape
    #[inline]
    fn float_slice_tape(&self, storage: Mmap) -> JitBulkFn<f32> {
        self.bulk_tape::<float_slice::FloatSliceAssembler>(storage)
    }

    /// Compiles a bulk gradient-slice evaluation tape
    #[inline]
    fn grad_slice_tape(&self, storage: Mmap) -> JitBulkFn<Grad> {
        self.bulk_tape::<grad_slice::GradSliceAssembler>(storage)
    }

    /// Simplifies the inner tape using a recorded trace
    #[inline]
    fn simplify(
        &self,
        trace: &Self::Trace,
        storage: Self::Storage,
        workspace: &mut Self::Workspace,
    ) -> Result<Self, Error> {
        self.0.simplify(trace, storage, workspace).map(JitFunction)
    }

    #[inline]
    fn recycle(self) -> Option<Self::Storage> {
        self.0.recycle()
    }

    #[inline]
    fn size(&self) -> usize {
        self.0.size()
    }

    #[inline]
    fn vars(&self) -> &VarMap {
        self.0.vars()
    }

    /// Simplification is only useful if the tape contains min/max choices
    #[inline]
    fn can_simplify(&self) -> bool {
        self.0.choice_count() > 0
    }
}
954
impl RenderHints for JitFunction {
    /// Tile sizes for 3D rendering (presumably tuned empirically — the
    /// rationale is not recorded here)
    fn tile_sizes_3d() -> TileSizes {
        TileSizes::new(&[64, 16, 8]).unwrap()
    }

    /// Tile sizes for 2D rendering
    fn tile_sizes_2d() -> TileSizes {
        TileSizes::new(&[128, 16]).unwrap()
    }

    fn simplify_tree_during_meshing(d: usize) -> bool {
        // Unscientifically selected, but similar to tile_sizes_3d
        d % 8 == 4
    }
}
969
970impl MathFunction for JitFunction {
971    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
972        GenericVmFunction::new(ctx, nodes).map(JitFunction)
973    }
974}
975
976impl From<GenericVmFunction<REGISTER_LIMIT>> for JitFunction {
977    fn from(v: GenericVmFunction<REGISTER_LIMIT>) -> Self {
978        Self(v)
979    }
980}
981
982impl<'a> From<&'a JitFunction> for &'a GenericVmFunction<REGISTER_LIMIT> {
983    fn from(v: &'a JitFunction) -> Self {
984        &v.0
985    }
986}
987
988////////////////////////////////////////////////////////////////////////////////
989
// Selects the calling convention based on platform; this is forward-looking for
// eventual x86 Windows support, where we still want to use the sysv64 calling
// convention (rather than the default Windows x64 convention).
/// Macro to build a function type with a `extern "sysv64"` calling convention
///
/// This is selected at compile time, based on `target_arch`
#[cfg(target_arch = "x86_64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "sysv64" fn($($args)*)
    };
}
1002
/// Macro to build a function type with the `extern "C"` calling convention
///
/// This is selected at compile time, based on `target_arch`; on `aarch64`,
/// the default C ABI is used.
#[cfg(target_arch = "aarch64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "C" fn($($args)*)
    };
}
1012
1013////////////////////////////////////////////////////////////////////////////////
1014
/// Evaluator for a JIT-compiled tracing function
///
/// This is the (private) shared implementation behind [`JitPointEval`] and
/// [`JitIntervalEval`]; it owns the scratch buffers passed to the JIT code.
struct JitTracingEval<T> {
    // Choice trace written by the JIT code during evaluation
    choices: VmTrace,
    // Output values written by the JIT code during evaluation
    out: Vec<T>,
}
1023
1024impl<T> Default for JitTracingEval<T> {
1025    fn default() -> Self {
1026        Self {
1027            choices: VmTrace::default(),
1028            out: Vec::default(),
1029        }
1030    }
1031}
1032
/// Typedef for a tracing function pointer
///
/// The calling convention is selected per-architecture by `jit_fn!`
pub type JitTracingFnPointer<T> = jit_fn!(
    unsafe fn(
        *const T, // vars
        *mut u8,  // choices
        *mut u8,  // simplify (single boolean)
        *mut T,   // output (array)
    )
);

/// Handle to an owned function pointer for tracing evaluation
#[derive(Clone)]
pub struct JitTracingFn<T> {
    // Memory-mapped region owning the compiled machine code
    mmap: Arc<Mmap>,
    // Number of choice slots written during evaluation
    choice_count: usize,
    // Number of output values written during evaluation
    output_count: usize,
    // Variable map, used to check caller-provided arguments
    vars: Arc<VarMap>,
    // Entry point into `mmap`'s machine code
    fn_trace: JitTracingFnPointer<T>,
}
1052
1053impl<T: Clone> Tape for JitTracingFn<T> {
1054    type Storage = Mmap;
1055    fn recycle(self) -> Option<Self::Storage> {
1056        Arc::into_inner(self.mmap)
1057    }
1058
1059    fn vars(&self) -> &VarMap {
1060        &self.vars
1061    }
1062
1063    fn output_count(&self) -> usize {
1064        self.output_count
1065    }
1066}
1067
// SAFETY: there is no mutable state in a `JitTracingFn`, and the function
// pointer inside of it points into its own `Mmap`, which is owned by an
// `Arc` (so the code outlives every clone of the handle)
unsafe impl<T> Send for JitTracingFn<T> {}
unsafe impl<T> Sync for JitTracingFn<T> {}
1072
1073impl<T: From<f32> + Clone> JitTracingEval<T> {
1074    /// Evaluates a single point, capturing an evaluation trace
1075    fn eval(
1076        &mut self,
1077        tape: &JitTracingFn<T>,
1078        vars: &[T],
1079    ) -> (&[T], Option<&VmTrace>) {
1080        let mut simplify = 0;
1081        self.choices.resize(tape.choice_count, Choice::Unknown);
1082        self.choices.fill(Choice::Unknown);
1083        self.out.resize(tape.output_count, f32::NAN.into());
1084        self.out.fill(f32::NAN.into());
1085        unsafe {
1086            (tape.fn_trace)(
1087                vars.as_ptr(),
1088                self.choices.as_mut_ptr() as *mut u8,
1089                &mut simplify,
1090                self.out.as_mut_ptr(),
1091            )
1092        };
1093
1094        (
1095            &self.out,
1096            if simplify != 0 {
1097                Some(&self.choices)
1098            } else {
1099                None
1100            },
1101        )
1102    }
1103}
1104
1105/// JIT-based tracing evaluator for interval values
1106#[derive(Default)]
1107pub struct JitIntervalEval(JitTracingEval<Interval>);
1108impl TracingEvaluator for JitIntervalEval {
1109    type Data = Interval;
1110    type Tape = JitTracingFn<Interval>;
1111    type Trace = VmTrace;
1112    type TapeStorage = Mmap;
1113
1114    #[inline]
1115    fn eval(
1116        &mut self,
1117        tape: &Self::Tape,
1118        vars: &[Self::Data],
1119    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
1120        tape.vars().check_tracing_arguments(vars)?;
1121        Ok(self.0.eval(tape, vars))
1122    }
1123}
1124
1125/// JIT-based tracing evaluator for point values
1126#[derive(Default)]
1127pub struct JitPointEval(JitTracingEval<f32>);
1128impl TracingEvaluator for JitPointEval {
1129    type Data = f32;
1130    type Tape = JitTracingFn<f32>;
1131    type Trace = VmTrace;
1132    type TapeStorage = Mmap;
1133
1134    #[inline]
1135    fn eval(
1136        &mut self,
1137        tape: &Self::Tape,
1138        vars: &[Self::Data],
1139    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
1140        tape.vars().check_tracing_arguments(vars)?;
1141        Ok(self.0.eval(tape, vars))
1142    }
1143}
1144
1145////////////////////////////////////////////////////////////////////////////////
1146
/// Typedef for a bulk function pointer
///
/// The calling convention is selected per-architecture by `jit_fn!`
pub type JitBulkFnPointer<T> = jit_fn!(
    unsafe fn(
        *const *const T, // vars
        *const *mut T,   // out
        u64,             // size
    )
);

/// Handle to an owned function pointer for bulk evaluation
#[derive(Clone)]
pub struct JitBulkFn<T> {
    // Memory-mapped region owning the compiled machine code
    mmap: Arc<Mmap>,
    // Variable map, used to check caller-provided arguments
    vars: Arc<VarMap>,
    // Number of output slices written during evaluation
    output_count: usize,
    // Entry point into `mmap`'s machine code
    fn_bulk: JitBulkFnPointer<T>,
}
1164
1165impl<T: Clone> Tape for JitBulkFn<T> {
1166    type Storage = Mmap;
1167    fn recycle(self) -> Option<Self::Storage> {
1168        Arc::into_inner(self.mmap)
1169    }
1170
1171    fn vars(&self) -> &VarMap {
1172        &self.vars
1173    }
1174
1175    fn output_count(&self) -> usize {
1176        self.output_count
1177    }
1178}
1179
/// Maximum SIMD width for any type, checked at runtime (alas)
///
/// We can't use `T::SIMD_SIZE` directly here due to Rust limitations. Instead we
/// hard-code a maximum SIMD size along with an assertion that should be
/// optimized out; we can't use a constant assertion here due to the same
/// compiler limitations.
const MAX_SIMD_WIDTH: usize = 8;

/// Bulk evaluator for JIT functions
struct JitBulkEval<T> {
    /// Array of input pointers (one per variable) passed to the JIT function
    input_ptrs: Vec<*const T>,

    /// Array of output pointers (one per output slice) passed to the JIT
    /// function
    output_ptrs: Vec<*mut T>,

    /// Scratch array for evaluation of less-than-SIMD-size slices
    scratch: Vec<[T; MAX_SIMD_WIDTH]>,

    /// Output arrays, written to during evaluation
    out: Vec<Vec<T>>,
}
1202
// SAFETY: the raw pointers in `JitBulkEval` are transient and only scoped to
// a single call to `eval`; between calls the vectors are just inert storage.
unsafe impl<T> Sync for JitBulkEval<T> {}
unsafe impl<T> Send for JitBulkEval<T> {}
1207
1208impl<T> Default for JitBulkEval<T> {
1209    fn default() -> Self {
1210        Self {
1211            out: vec![],
1212            scratch: vec![],
1213            input_ptrs: vec![],
1214            output_ptrs: vec![],
1215        }
1216    }
1217}
1218
// SAFETY: there is no mutable state in a `JitBulkFn`, and the function
// pointer inside of it points into its own `Mmap`, which is owned by an
// `Arc` (so the code outlives every clone of the handle)
unsafe impl<T> Send for JitBulkFn<T> {}
unsafe impl<T> Sync for JitBulkFn<T> {}
1223
impl<T: From<f32> + Copy + SimdSize> JitBulkEval<T> {
    /// Evaluate multiple points
    ///
    /// `vars` is one slice per input variable; all slices are assumed to be
    /// the same length (`n`, taken from the first slice).  Returns a
    /// [`BulkOutput`] exposing the first `n` values of each output array.
    fn eval<V: std::ops::Deref<Target = [T]>>(
        &mut self,
        tape: &JitBulkFn<T>,
        vars: &[V],
    ) -> BulkOutput<'_, T> {
        // Number of points to evaluate (zero if there are no variables)
        let n = vars.first().map(|v| v.deref().len()).unwrap_or(0);

        // Size each output buffer to at least one SIMD vector (the JIT code
        // always writes full vectors), pre-filled with NaN
        self.out.resize_with(tape.output_count(), Vec::new);
        for o in &mut self.out {
            o.resize(n.max(T::SIMD_SIZE), f32::NAN.into());
            o.fill(f32::NAN.into());
        }

        // Special case for when we have fewer items than the native SIMD size,
        // in which case the input slices can't be used as workspace (because
        // they are not valid for the entire range of values read in assembly)
        if n < T::SIMD_SIZE {
            // See the doc comment on MAX_SIMD_WIDTH for why this is checked
            // at runtime rather than at compile time
            assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH);

            // Copy each (short) input slice into a full-width scratch array,
            // padded with NaN, so the JIT code can read a whole vector
            self.scratch
                .resize(vars.len(), [f32::NAN.into(); MAX_SIMD_WIDTH]);
            for (v, t) in vars.iter().zip(self.scratch.iter_mut()) {
                t[0..n].copy_from_slice(v);
            }

            self.input_ptrs.clear();
            self.input_ptrs
                .extend(self.scratch[..vars.len()].iter().map(|t| t.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|t| t.as_mut_ptr()));

            // SAFETY: every input points at MAX_SIMD_WIDTH scratch values and
            // every output buffer holds at least T::SIMD_SIZE values, so a
            // single SIMD-width evaluation stays in bounds
            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    T::SIMD_SIZE as u64,
                );
            }
        } else {
            // Our vectorized function only accepts sets of a particular width,
            // so we'll find the biggest multiple, then do an extra operation to
            // process any remainders.
            let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE; // Round down
            self.input_ptrs.clear();
            self.input_ptrs.extend(vars.iter().map(|v| v.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|v| v.as_mut_ptr()));
            // SAFETY: `m <= n`, so both the input slices and the output
            // buffers are valid for `m` values
            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    m as u64,
                );
            }
            // If we weren't given an even multiple of vector width, then we'll
            // handle the remaining items by simply evaluating the *last* full
            // vector in the array again.
            if n != m {
                self.input_ptrs.clear();
                self.output_ptrs.clear();
                // SAFETY: `n >= T::SIMD_SIZE` in this branch, so the offset
                // `n - T::SIMD_SIZE` is in bounds and the final SIMD-width
                // read/write covers exactly the tail of each slice
                unsafe {
                    self.input_ptrs.extend(
                        vars.iter().map(|v| v.as_ptr().add(n - T::SIMD_SIZE)),
                    );
                    self.output_ptrs.extend(
                        self.out
                            .iter_mut()
                            .map(|v| v.as_mut_ptr().add(n - T::SIMD_SIZE)),
                    );
                    (tape.fn_bulk)(
                        self.input_ptrs.as_ptr(),
                        self.output_ptrs.as_ptr(),
                        T::SIMD_SIZE as u64,
                    );
                }
            }
        }
        // Only the first `n` values of each output array are meaningful
        BulkOutput::new(&self.out, n)
    }
}
1310
1311/// JIT-based bulk evaluator for arrays of points, yielding point values
1312#[derive(Default)]
1313pub struct JitFloatSliceEval(JitBulkEval<f32>);
1314impl BulkEvaluator for JitFloatSliceEval {
1315    type Data = f32;
1316    type Tape = JitBulkFn<Self::Data>;
1317    type TapeStorage = Mmap;
1318
1319    #[inline]
1320    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
1321        &mut self,
1322        tape: &Self::Tape,
1323        vars: &[V],
1324    ) -> Result<BulkOutput<'_, f32>, Error> {
1325        tape.vars().check_bulk_arguments(vars)?;
1326        Ok(self.0.eval(tape, vars))
1327    }
1328}
1329
1330/// JIT-based bulk evaluator for arrays of points, yielding gradient values
1331#[derive(Default)]
1332pub struct JitGradSliceEval(JitBulkEval<Grad>);
1333impl BulkEvaluator for JitGradSliceEval {
1334    type Data = Grad;
1335    type Tape = JitBulkFn<Self::Data>;
1336    type TapeStorage = Mmap;
1337
1338    #[inline]
1339    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
1340        &mut self,
1341        tape: &Self::Tape,
1342        vars: &[V],
1343    ) -> Result<BulkOutput<'_, Grad>, Error> {
1344        tape.vars().check_bulk_arguments(vars)?;
1345        Ok(self.0.eval(tape, vars))
1346    }
1347}
1348
/// A [`Shape`](fidget_core::shape::Shape) which uses the JIT evaluator
///
/// This is the typical entry point for users of this crate; see the
/// crate-level docs for an example.
pub type JitShape = fidget_core::shape::Shape<JitFunction>;
1351
1352////////////////////////////////////////////////////////////////////////////////
1353
#[cfg(test)]
mod test {
    use super::*;
    fidget_core::grad_slice_tests!(JitFunction);
    fidget_core::interval_tests!(JitFunction);
    fidget_core::float_slice_tests!(JitFunction);
    fidget_core::point_tests!(JitFunction);

    /// Pushing more than one page of data must transparently grow the mmap
    #[test]
    fn test_mmap_expansion() {
        const COUNT: u32 = 23456; // larger than 1 page (4 KiB)

        // Start from an empty mapping and force repeated expansion
        let mut asm = MmapAssembler::from(Mmap::new(0).unwrap());
        (0..COUNT).for_each(|i| asm.push_u32(i));

        let mmap = asm.finalize().unwrap();
        // SAFETY: the finalized mapping holds exactly COUNT u32 values
        // written above, and we read the same region that was written
        let words = unsafe {
            std::slice::from_raw_parts(
                mmap.as_ptr() as *const u32,
                COUNT as usize,
            )
        };
        for (i, &v) in words.iter().enumerate() {
            assert_eq!(v, i as u32);
        }
    }
}