fidget_jit/
lib.rs

//! Compilation down to native machine code
//!
//! Users are unlikely to use anything in this module other than [`JitFunction`],
//! which is a [`Function`] that uses JIT evaluation.
//!
//! ```
//! use fidget_core::{
//!     context::Tree,
//!     shape::EzShape,
//! };
//! use fidget_jit::JitShape;
//!
//! let tree = Tree::x() + Tree::y();
//! let shape = JitShape::from(tree);
//!
//! // Generate machine code to execute the tape
//! let tape = shape.ez_point_tape();
//! let mut eval = JitShape::new_point_eval();
//!
//! // This calls directly into that machine code!
//! let (r, _trace) = eval.eval(&tape, 0.1, 0.3, 0.0)?;
//! assert_eq!(r, 0.1 + 0.3);
//! # Ok::<(), fidget_core::Error>(())
//! ```

use crate::mmap::{Mmap, MmapWriter};
use fidget_core::{
    Error,
    compiler::RegOp,
    context::{Context, Node},
    eval::{
        BulkEvaluator, BulkOutput, Function, MathFunction, Tape,
        TracingEvaluator,
    },
    render::{RenderHints, TileSizes},
    types::{Grad, Interval},
    var::VarMap,
    vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace},
};

use dynasmrt::{
    AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi,
    TargetKind, components::PatchLoc, dynasm,
};
use std::sync::Arc;

mod mmap;
mod permit;
pub(crate) use permit::WritePermit;

// Evaluators
mod float_slice;
mod grad_slice;
mod interval;
mod point;

#[cfg(not(any(
    target_os = "linux",
    target_os = "macos",
    target_os = "windows"
)))]
compile_error!(
    "The `jit` module only builds on Linux, macOS, and Windows; \
    please disable the `jit` feature"
);

#[cfg(target_arch = "aarch64")]
mod aarch64;
#[cfg(target_arch = "aarch64")]
use aarch64 as arch;

#[cfg(target_arch = "x86_64")]
mod x86_64;
#[cfg(target_arch = "x86_64")]
use x86_64 as arch;

/// Number of registers available when executing natively
const REGISTER_LIMIT: usize = arch::REGISTER_LIMIT;
/// Offset before the first usable register
const OFFSET: u8 = arch::OFFSET;

/// Register written to by `CopyImm`
///
/// It is the responsibility of functions to avoid writing to `IMM_REG` in cases
/// where it could be one of their arguments (i.e. all functions of 2 or more
/// arguments).
const IMM_REG: u8 = arch::IMM_REG;

/// Type for a register index in `dynasm` code
#[cfg(target_arch = "aarch64")]
type RegIndex = u32;

/// Type for a register index in `dynasm` code
#[cfg(target_arch = "x86_64")]
type RegIndex = u8;

/// Converts from a tape-local register to a hardware register
///
/// Tape-local registers are in the range `0..REGISTER_LIMIT`, while ARM
/// registers have an offset (based on calling convention).
///
/// This uses `wrapping_add` to support immediates, which are loaded into a
/// register below [`OFFSET`] (which is "negative" from the perspective of this
/// function).
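///
/// A hedged illustration (the constants are per-architecture; the values
/// below are hypothetical, not the real `OFFSET` / `IMM_REG`):
///
/// ```ignore
/// // With OFFSET = 8, tape register 0 maps to hardware register 8:
/// assert_eq!(reg(0), 8);
/// // An immediate register IMM_REG = 3 is "negative" on the tape, encoded
/// // as 3u8.wrapping_sub(8) == 251; wrapping_add brings it back to 3:
/// assert_eq!(reg(3u8.wrapping_sub(8)), 3);
/// ```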
fn reg(r: u8) -> RegIndex {
    let out = r.wrapping_add(OFFSET) as RegIndex;
    assert!(out < 32);
    out
}

const CHOICE_LEFT: u32 = Choice::Left as u32;
const CHOICE_RIGHT: u32 = Choice::Right as u32;
const CHOICE_BOTH: u32 = Choice::Both as u32;

/// Trait for generating machine assembly
trait Assembler {
    /// Data type used during evaluation.
    ///
    /// This should be a `repr(C)` type, so it can be passed around directly.
    type Data;

    /// Initializes the assembler with the given slot count
    ///
    /// This will likely construct a function prelude and reserve space on the
    /// stack for slot spills.
    fn init(m: Mmap, slot_count: usize) -> Self;

    /// Returns an approximate bytes per clause value, used for preallocation
    fn bytes_per_clause() -> usize {
        8 // probably wrong!
    }

    /// Builds a load from memory to a register
    fn build_load(&mut self, dst_reg: u8, src_mem: u32);

    /// Builds a store from a register to a memory location
    fn build_store(&mut self, dst_mem: u32, src_reg: u8);

    /// Copies the given input to `out_reg`
    fn build_input(&mut self, out_reg: u8, src_arg: u32);

    /// Writes the argument register to the output
    fn build_output(&mut self, arg_reg: u8, out_index: u32);

    /// Copies a register
    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8);

    /// Unary negation
    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8);

    /// Absolute value
    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8);

    /// Reciprocal (1 / `lhs_reg`)
    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8);

    /// Square root
    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8);

    /// Sine
    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8);

    /// Cosine
    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8);

    /// Tangent
    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arcsine
    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arccosine
    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arctangent
    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8);

    /// Exponential function
    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8);

    /// Natural log
    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);

    /// Comparison of two values
    ///
    /// Writes -1, 0, or 1 to `out_reg`, depending on whether `lhs_reg` is
    /// less than, equal to, or greater than `rhs_reg` (NaN inputs produce NaN)
    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Square
    ///
    /// This has a default implementation, but can be overloaded for efficiency;
    /// for example, in interval arithmetic, we benefit from knowing that both
    /// values are the same.
    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        self.build_mul(out_reg, lhs_reg, lhs_reg)
    }

    /// Arithmetic floor
    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8);

    /// Arithmetic ceiling
    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8);

    /// Rounding
    fn build_round(&mut self, out_reg: u8, lhs_reg: u8);

    /// Logical not
    fn build_not(&mut self, out_reg: u8, lhs_reg: u8);

    /// Logical and (short-circuiting)
    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Logical or (short-circuiting)
    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Addition
    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Subtraction
    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Multiplication
    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Division
    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Four-quadrant arctangent
    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Maximum of two values
    ///
    /// In a tracing evaluator, this function must also write to the `choices`
    /// array and may set `simplify` if one branch is always taken.
    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Minimum of two values
    ///
    /// In a tracing evaluator, this function must also write to the `choices`
    /// array and may set `simplify` if one branch is always taken.
    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Modulo of two values (least non-negative remainder)
    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    // Special-case functions for immediates.  In some cases, you can be more
    // efficient if you know that an argument is an immediate (for example, both
    // values in the interval will be the same, and it will have no gradients).

    /// Builds an addition (immediate + register)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_add_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_add(out_reg, lhs_reg, imm);
    }
    /// Builds a subtraction (immediate − register)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_sub_imm_reg(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, imm, arg);
    }
    /// Builds a subtraction (register − immediate)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_sub_reg_imm(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, arg, imm);
    }
    /// Builds a multiplication (register × immediate)
    ///
    /// This has a default implementation, but can be overloaded for efficiency
    fn build_mul_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_mul(out_reg, lhs_reg, imm);
    }

    /// Loads an immediate into a register, returning that register
    fn load_imm(&mut self, imm: f32) -> u8;

    /// Finalize the assembly code, returning a memory-mapped region
    fn finalize(self) -> Result<Mmap, DynasmError>;
}

/// Trait defining SIMD width
pub trait SimdSize {
    /// Number of elements processed in a single iteration
    ///
    /// This value is used when checking array sizes, as we want to be sure to
    /// pass the JIT code an appropriately sized array.
    const SIMD_SIZE: usize;
}
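
// A minimal sketch of an implementation (the width shown is hypothetical;
// the real values are defined per data type in the per-arch modules):
//
//     impl SimdSize for f32 {
//         const SIMD_SIZE: usize = 8;
//     }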

////////////////////////////////////////////////////////////////////////////////

pub(crate) struct AssemblerData<T> {
    ops: MmapAssembler,

    /// Current offset of the stack pointer, in bytes
    mem_offset: usize,

    /// Set to true if we have saved certain callee-saved registers
    ///
    /// These registers are only modified in function calls, so normally we
    /// don't save them.
    saved_callee_regs: bool,

    _p: std::marker::PhantomData<*const T>,
}

impl<T> AssemblerData<T> {
    fn new(mmap: Mmap) -> Self {
        Self {
            ops: MmapAssembler::from(mmap),
            mem_offset: 0,
            saved_callee_regs: false,
            _p: std::marker::PhantomData,
        }
    }

    fn prepare_stack(&mut self, slot_count: usize, stack_size: usize) {
        // We always use the stack, if only to store callee-saved registers
        let mem = slot_count.saturating_sub(REGISTER_LIMIT)
            * std::mem::size_of::<T>()
            + stack_size;
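        // Worked example (all sizes hypothetical): with `T = f32` (4 bytes),
        // a register limit of 24, `slot_count = 40`, and `stack_size = 8`,
        // 16 slots spill to memory, giving `mem = 16 * 4 + 8 = 72` bytes.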

        // Round up to the nearest multiple of 16 bytes, for alignment
        self.mem_offset = mem.next_multiple_of(16);
        self.push_stack();
    }

    fn stack_pos(&self, slot: u32) -> u32 {
        assert!(slot >= REGISTER_LIMIT as u32);
        (slot - REGISTER_LIMIT as u32) * std::mem::size_of::<T>() as u32
    }
}

#[cfg(target_arch = "x86_64")]
impl<T> AssemblerData<T> {
    fn push_stack(&mut self) {
        dynasm!(self.ops
            ; sub rsp, self.mem_offset as i32
        );
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        dynasm!(self.ops
            ; add rsp, self.mem_offset as i32
            ; pop rbp
            ; emms
            ; vzeroall
            ; ret
        );
        self.ops.finalize()
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(clippy::unnecessary_cast)] // dynasm-rs#106
impl<T> AssemblerData<T> {
    fn push_stack(&mut self) {
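        // Why these cutoffs: `sub sp, sp, #imm` encodes a 12-bit immediate
        // (hence 4096), while the fallback `mov` into a scratch register
        // encodes a 16-bit immediate (hence 65536).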
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; sub sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w28, self.mem_offset as u32
                ; sub sp, sp, w28
            );
        } else {
            panic!("invalid mem offset: {} is too large", self.mem_offset);
        }
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        // Fix up the stack
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; add sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w9, self.mem_offset as u32
                ; add sp, sp, w9
            );
        } else {
            panic!("invalid mem offset: {}", self.mem_offset);
        }

        dynasm!(self.ops
            ; ret
        );
        self.ops.finalize()
    }
}

////////////////////////////////////////////////////////////////////////////////

#[cfg(target_arch = "x86_64")]
type Relocation = dynasmrt::x64::X64Relocation;

#[cfg(target_arch = "aarch64")]
type Relocation = dynasmrt::aarch64::Aarch64Relocation;

struct MmapAssembler {
    mmap: MmapWriter,

    global_labels: [Option<AssemblyOffset>; 26],
    local_labels: [Option<AssemblyOffset>; 26],

    global_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 2>,
    local_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 8>,
}

impl Extend<u8> for MmapAssembler {
    fn extend<T>(&mut self, iter: T)
    where
        T: IntoIterator<Item = u8>,
    {
        for c in iter.into_iter() {
            self.push(c);
        }
    }
}

impl<'a> Extend<&'a u8> for MmapAssembler {
    fn extend<T>(&mut self, iter: T)
    where
        T: IntoIterator<Item = &'a u8>,
    {
        for c in iter.into_iter() {
            self.push(*c);
        }
    }
}

impl DynasmApi for MmapAssembler {
    #[inline(always)]
    fn offset(&self) -> AssemblyOffset {
        AssemblyOffset(self.mmap.len())
    }

    #[inline(always)]
    fn push(&mut self, byte: u8) {
        self.mmap.push(byte);
    }

    #[inline(always)]
    fn align(&mut self, alignment: usize, with: u8) {
        let offset = self.offset().0 % alignment;
        if offset != 0 {
            for _ in offset..alignment {
                self.push(with);
            }
        }
    }

    #[inline(always)]
    fn push_u32(&mut self, value: u32) {
        for b in value.to_le_bytes() {
            self.mmap.push(b);
        }
    }
}

/// This is a very limited implementation of the labels API.  Compared to the
/// standard labels API, it has the following limitations:
///
/// - Labels must be a single character
/// - Local labels must be committed before they're reused, using `commit_local`
/// - Only 8 local jumps are available at any given time; this is reset when
///   `commit_local` is called (if this becomes problematic, it can be
///   increased by tweaking the size of `local_relocs: ArrayVec<..., 8>`)
///
/// In exchange for these limitations, it allocates no memory at runtime, and
/// all label lookups are done in constant time.
///
/// However, it still has overhead compared to computing the jumps by hand;
/// this overhead was roughly 5% in one unscientific test.
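///
/// A hedged usage sketch (the instructions and label name are illustrative,
/// not taken from the real assemblers):
///
/// ```ignore
/// dynasm!(self.ops
///     ; b.eq >E      // forward jump to the single-character label 'E'
///     ; fmov s0, s1
///     ; E:           // bind the label
/// );
/// self.ops.commit_local()?; // patch local jumps; labels may now be reused
/// ```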
impl DynasmLabelApi for MmapAssembler {
    type Relocation = Relocation;

    fn local_label(&mut self, name: &'static str) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("Invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_some() {
            panic!("duplicate local label {name}");
        }

        self.local_labels[c as usize] = Some(self.offset());
    }
    fn global_label(&mut self, name: &'static str) {
        if name.len() != 1 {
            panic!("global label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("Invalid label {name}, must be A-Z");
        }
        if self.global_labels[c as usize].is_some() {
            panic!("duplicate global label {name}");
        }

        self.global_labels[c as usize] = Some(self.offset());
    }
    fn dynamic_label(&mut self, _id: DynamicLabel) {
        panic!("dynamic labels are not supported");
    }
    fn global_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        let location = self.offset();
        if name.len() != 1 {
            panic!("global label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("Invalid label {name}, must be A-Z");
        }
        self.global_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn dynamic_relocation(
        &mut self,
        _id: DynamicLabel,
        _target_offset: isize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("dynamic relocations are not supported");
    }
    fn forward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("Invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_some() {
            panic!("invalid forward relocation: {name} already exists!");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn backward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("Invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_none() {
            panic!("invalid backward relocation: {name} does not exist");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn bare_relocation(
        &mut self,
        _target: usize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("bare relocations not implemented");
    }
}

impl MmapAssembler {
    /// Applies all local relocations, clearing the `local_relocs` array
    ///
    /// This should be called after any function which uses local labels.
    fn commit_local(&mut self) -> Result<(), DynasmError> {
        let baseaddr = self.mmap.as_ptr() as usize;

        for (loc, label) in self.local_relocs.take() {
            let target =
                self.local_labels[label as usize].expect("invalid local label");
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Local("oh no"),
                ));
            }
        }
        self.local_labels = [None; 26];
        Ok(())
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        self.commit_local()?;

        let baseaddr = self.mmap.as_ptr() as usize;
        for (loc, label) in self.global_relocs.take() {
            let target =
                self.global_labels.get(label as usize).unwrap().unwrap();
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Global("oh no"),
                ));
            }
        }

        Ok(self.mmap.finalize())
    }
}

impl From<Mmap> for MmapAssembler {
    fn from(mmap: Mmap) -> Self {
        Self {
            mmap: MmapWriter::from(mmap),
            global_labels: [None; 26],
            local_labels: [None; 26],
            global_relocs: Default::default(),
            local_relocs: Default::default(),
        }
    }
}

////////////////////////////////////////////////////////////////////////////////

fn build_asm_fn_with_storage<A: Assembler>(
    t: &VmData<REGISTER_LIMIT>,
    mut s: Mmap,
) -> Mmap {
    let size_estimate = t.len() * A::bytes_per_clause();
    if size_estimate > 2 * s.capacity() {
        s = Mmap::new(size_estimate).expect("failed to build mmap")
    }
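    // For scale: with the default estimate of 8 bytes per clause, a
    // 10_000-clause tape would request an ~80 KiB mapping here.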

    let mut asm = A::init(s, t.slot_count());

    for op in t.iter_asm() {
        match op {
            RegOp::Load(reg, mem) => {
                asm.build_load(reg, mem);
            }
            RegOp::Store(reg, mem) => {
                asm.build_store(mem, reg);
            }
            RegOp::Input(out, i) => {
                asm.build_input(out, i);
            }
            RegOp::Output(arg, i) => {
                asm.build_output(arg, i);
            }
            RegOp::NegReg(out, arg) => {
                asm.build_neg(out, arg);
            }
            RegOp::AbsReg(out, arg) => {
                asm.build_abs(out, arg);
            }
            RegOp::RecipReg(out, arg) => {
                asm.build_recip(out, arg);
            }
            RegOp::SqrtReg(out, arg) => {
                asm.build_sqrt(out, arg);
            }
            RegOp::SinReg(out, arg) => {
                asm.build_sin(out, arg);
            }
            RegOp::CosReg(out, arg) => {
                asm.build_cos(out, arg);
            }
            RegOp::TanReg(out, arg) => {
                asm.build_tan(out, arg);
            }
            RegOp::AsinReg(out, arg) => {
                asm.build_asin(out, arg);
            }
            RegOp::AcosReg(out, arg) => {
                asm.build_acos(out, arg);
            }
            RegOp::AtanReg(out, arg) => {
                asm.build_atan(out, arg);
            }
            RegOp::ExpReg(out, arg) => {
                asm.build_exp(out, arg);
            }
            RegOp::LnReg(out, arg) => {
                asm.build_ln(out, arg);
            }
            RegOp::CopyReg(out, arg) => {
                asm.build_copy(out, arg);
            }
            RegOp::SquareReg(out, arg) => {
                asm.build_square(out, arg);
            }
            RegOp::FloorReg(out, arg) => {
                asm.build_floor(out, arg);
            }
            RegOp::CeilReg(out, arg) => {
                asm.build_ceil(out, arg);
            }
            RegOp::RoundReg(out, arg) => {
                asm.build_round(out, arg);
            }
            RegOp::NotReg(out, arg) => {
                asm.build_not(out, arg);
            }
            RegOp::AddRegReg(out, lhs, rhs) => {
                asm.build_add(out, lhs, rhs);
            }
            RegOp::MulRegReg(out, lhs, rhs) => {
                asm.build_mul(out, lhs, rhs);
            }
            RegOp::DivRegReg(out, lhs, rhs) => {
                asm.build_div(out, lhs, rhs);
            }
            RegOp::AtanRegReg(out, lhs, rhs) => {
                asm.build_atan2(out, lhs, rhs);
            }
            RegOp::SubRegReg(out, lhs, rhs) => {
                asm.build_sub(out, lhs, rhs);
            }
            RegOp::MinRegReg(out, lhs, rhs) => {
                asm.build_min(out, lhs, rhs);
            }
            RegOp::MaxRegReg(out, lhs, rhs) => {
                asm.build_max(out, lhs, rhs);
            }
            RegOp::AddRegImm(out, arg, imm) => {
                asm.build_add_imm(out, arg, imm);
            }
            RegOp::MulRegImm(out, arg, imm) => {
                asm.build_mul_imm(out, arg, imm);
            }
            RegOp::DivRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, arg, reg);
            }
            RegOp::DivImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, reg, arg);
            }
            RegOp::AtanRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, arg, reg);
            }
            RegOp::AtanImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, reg, arg);
            }
            RegOp::SubImmReg(out, arg, imm) => {
                asm.build_sub_imm_reg(out, arg, imm);
            }
            RegOp::SubRegImm(out, arg, imm) => {
                asm.build_sub_reg_imm(out, arg, imm);
            }
            RegOp::MinRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_min(out, arg, reg);
            }
            RegOp::MaxRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_max(out, arg, reg);
            }
            RegOp::ModRegReg(out, lhs, rhs) => {
                asm.build_mod(out, lhs, rhs);
            }
            RegOp::ModRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, arg, reg);
            }
            RegOp::ModImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, reg, arg);
            }
            RegOp::AndRegReg(out, lhs, rhs) => {
                asm.build_and(out, lhs, rhs);
            }
            RegOp::AndRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_and(out, arg, reg);
            }
            RegOp::OrRegReg(out, lhs, rhs) => {
                asm.build_or(out, lhs, rhs);
            }
            RegOp::OrRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_or(out, arg, reg);
            }
            RegOp::CopyImm(out, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_copy(out, reg);
            }
            RegOp::CompareRegReg(out, lhs, rhs) => {
                asm.build_compare(out, lhs, rhs);
            }
            RegOp::CompareRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, arg, reg);
            }
            RegOp::CompareImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, reg, arg);
            }
        }
    }

    asm.finalize().expect("failed to build JIT function")
    // JIT execute mode is restored once the writer's WritePermit is dropped
}

/// Function for use with a JIT evaluator
#[derive(Clone)]
pub struct JitFunction(GenericVmFunction<REGISTER_LIMIT>);

impl JitFunction {
    fn tracing_tape<A: Assembler>(
        &self,
        storage: Mmap,
    ) -> JitTracingFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitTracingFn {
            mmap: f.into(),
            vars: self.0.data().vars.clone(),
            choice_count: self.0.choice_count(),
            output_count: self.0.output_count(),
            fn_trace: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitTracingFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
    fn bulk_tape<A: Assembler>(&self, storage: Mmap) -> JitBulkFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitBulkFn {
            mmap: f.into(),
            output_count: self.0.output_count(),
            vars: self.0.data().vars.clone(),
            fn_bulk: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitBulkFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
}

impl Function for JitFunction {
    type Trace = VmTrace;
    type Storage = VmData<REGISTER_LIMIT>;
    type Workspace = VmWorkspace<REGISTER_LIMIT>;

    type TapeStorage = Mmap;

    type IntervalEval = JitIntervalEval;
    type PointEval = JitPointEval;
    type FloatSliceEval = JitFloatSliceEval;
    type GradSliceEval = JitGradSliceEval;

    #[inline]
    fn point_tape(&self, storage: Mmap) -> JitTracingFn<f32> {
        self.tracing_tape::<point::PointAssembler>(storage)
    }

    #[inline]
    fn interval_tape(&self, storage: Mmap) -> JitTracingFn<Interval> {
        self.tracing_tape::<interval::IntervalAssembler>(storage)
    }

    #[inline]
    fn float_slice_tape(&self, storage: Mmap) -> JitBulkFn<f32> {
        self.bulk_tape::<float_slice::FloatSliceAssembler>(storage)
    }

    #[inline]
    fn grad_slice_tape(&self, storage: Mmap) -> JitBulkFn<Grad> {
        self.bulk_tape::<grad_slice::GradSliceAssembler>(storage)
    }

    #[inline]
    fn simplify(
        &self,
        trace: &Self::Trace,
        storage: Self::Storage,
        workspace: &mut Self::Workspace,
    ) -> Result<Self, Error> {
        self.0.simplify(trace, storage, workspace).map(JitFunction)
    }

    #[inline]
    fn recycle(self) -> Option<Self::Storage> {
        self.0.recycle()
    }

    #[inline]
    fn size(&self) -> usize {
        self.0.size()
    }

    #[inline]
    fn vars(&self) -> &VarMap {
        self.0.vars()
    }

    #[inline]
    fn can_simplify(&self) -> bool {
        self.0.choice_count() > 0
    }
}

impl RenderHints for JitFunction {
    fn tile_sizes_3d() -> TileSizes {
        TileSizes::new(&[64, 16, 8]).unwrap()
    }

    fn tile_sizes_2d() -> TileSizes {
        TileSizes::new(&[128, 16]).unwrap()
    }

    fn simplify_tree_during_meshing(d: usize) -> bool {
        // Unscientifically selected, but similar to tile_sizes_3d
        d % 8 == 4
    }
}

impl MathFunction for JitFunction {
    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
        GenericVmFunction::new(ctx, nodes).map(JitFunction)
    }
}

impl From<GenericVmFunction<REGISTER_LIMIT>> for JitFunction {
    fn from(v: GenericVmFunction<REGISTER_LIMIT>) -> Self {
        Self(v)
    }
}

impl<'a> From<&'a JitFunction> for &'a GenericVmFunction<REGISTER_LIMIT> {
    fn from(v: &'a JitFunction) -> Self {
        &v.0
    }
}

////////////////////////////////////////////////////////////////////////////////

// Selects the calling convention based on platform; this is forward-looking
// for eventual x86_64 Windows support, where we still want to use the sysv64
// calling convention.
/// Macro to build a function type with the `extern "sysv64"` calling convention
///
/// This is selected at compile time, based on `target_arch`
#[cfg(target_arch = "x86_64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "sysv64" fn($($args)*)
    };
}

/// Macro to build a function type with the `extern "C"` calling convention
///
/// This is selected at compile time, based on `target_arch`
#[cfg(target_arch = "aarch64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "C" fn($($args)*)
    };
}
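
// For example (a sketch of the expansion): on aarch64,
// `jit_fn!(unsafe fn(*const f32, *mut f32))` names the type
// `unsafe extern "C" fn(*const f32, *mut f32)`.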

////////////////////////////////////////////////////////////////////////////////

/// Evaluator for a JIT-compiled tracing function
///
/// This generic implementation is private; it backs the public
/// [`JitPointEval`] and [`JitIntervalEval`] types, which are associated types
/// on [`JitFunction`].
struct JitTracingEval<T> {
    choices: VmTrace,
    out: Vec<T>,
}

impl<T> Default for JitTracingEval<T> {
    fn default() -> Self {
        Self {
            choices: VmTrace::default(),
            out: Vec::default(),
        }
    }
}

/// Typedef for a tracing function pointer
pub type JitTracingFnPointer<T> = jit_fn!(
    unsafe fn(
        *const T, // vars
        *mut u8,  // choices
        *mut u8,  // simplify (single boolean)
        *mut T,   // output (array)
    )
);

/// Handle to an owned function pointer for tracing evaluation
#[derive(Clone)]
pub struct JitTracingFn<T> {
    mmap: Arc<Mmap>,
    choice_count: usize,
    output_count: usize,
    vars: Arc<VarMap>,
    fn_trace: JitTracingFnPointer<T>,
}

impl<T: Clone> Tape for JitTracingFn<T> {
    type Storage = Mmap;
    fn recycle(self) -> Option<Self::Storage> {
        Arc::into_inner(self.mmap)
    }

    fn vars(&self) -> &VarMap {
        &self.vars
    }

    fn output_count(&self) -> usize {
        self.output_count
    }
}

// SAFETY: there is no mutable state in a `JitTracingFn`, and the pointer
// inside of it points to its own `Mmap`, which is owned by an `Arc`
unsafe impl<T> Send for JitTracingFn<T> {}
unsafe impl<T> Sync for JitTracingFn<T> {}

impl<T: From<f32> + Clone> JitTracingEval<T> {
    /// Evaluates a single point, capturing an evaluation trace
    fn eval(
        &mut self,
        tape: &JitTracingFn<T>,
        vars: &[T],
    ) -> (&[T], Option<&VmTrace>) {
        let mut simplify = 0;
        self.choices.resize(tape.choice_count, Choice::Unknown);
        self.choices.fill(Choice::Unknown);
        self.out.resize(tape.output_count, f32::NAN.into());
        self.out.fill(f32::NAN.into());
        unsafe {
            (tape.fn_trace)(
                vars.as_ptr(),
                self.choices.as_mut_ptr() as *mut u8,
                &mut simplify,
                self.out.as_mut_ptr(),
            )
        };

        (
            &self.out,
            if simplify != 0 {
                Some(&self.choices)
            } else {
                None
            },
        )
    }
}

/// JIT-based tracing evaluator for interval values
#[derive(Default)]
pub struct JitIntervalEval(JitTracingEval<Interval>);
impl TracingEvaluator for JitIntervalEval {
    type Data = Interval;
    type Tape = JitTracingFn<Interval>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// JIT-based tracing evaluator for point values
#[derive(Default)]
pub struct JitPointEval(JitTracingEval<f32>);
impl TracingEvaluator for JitPointEval {
    type Data = f32;
    type Tape = JitTracingFn<f32>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

////////////////////////////////////////////////////////////////////////////////

/// Typedef for a bulk function pointer
pub type JitBulkFnPointer<T> = jit_fn!(
    unsafe fn(
        *const *const T, // vars
        *const *mut T,   // out
        u64,             // size
    )
);

/// Handle to an owned function pointer for bulk evaluation
#[derive(Clone)]
pub struct JitBulkFn<T> {
    mmap: Arc<Mmap>,
    vars: Arc<VarMap>,
    output_count: usize,
    fn_bulk: JitBulkFnPointer<T>,
}

impl<T: Clone> Tape for JitBulkFn<T> {
    type Storage = Mmap;
    fn recycle(self) -> Option<Self::Storage> {
        Arc::into_inner(self.mmap)
    }

    fn vars(&self) -> &VarMap {
        &self.vars
    }

    fn output_count(&self) -> usize {
        self.output_count
    }
}

/// Maximum SIMD width for any type, checked at runtime (alas)
///
/// We can't use `T::SIMD_SIZE` directly here due to Rust limitations. Instead we
/// hard-code a maximum SIMD size along with an assertion that should be
/// optimized out; we can't use a constant assertion here due to the same
/// compiler limitations.
const MAX_SIMD_WIDTH: usize = 8;

/// Bulk evaluator for JIT functions
struct JitBulkEval<T> {
    /// Array of pointers used when calling into the JIT function
    input_ptrs: Vec<*const T>,

    /// Array of pointers used when calling into the JIT function
    output_ptrs: Vec<*mut T>,

    /// Scratch array for evaluation of less-than-SIMD-size slices
    scratch: Vec<[T; MAX_SIMD_WIDTH]>,

    /// Output arrays, written to during evaluation
    out: Vec<Vec<T>>,
}

// SAFETY: the pointers in `JitBulkEval` are transient and only scoped to a
// single evaluation.
unsafe impl<T> Sync for JitBulkEval<T> {}
unsafe impl<T> Send for JitBulkEval<T> {}

impl<T> Default for JitBulkEval<T> {
    fn default() -> Self {
        Self {
            out: vec![],
            scratch: vec![],
            input_ptrs: vec![],
            output_ptrs: vec![],
        }
    }
}

// SAFETY: there is no mutable state in a `JitBulkFn`, and the pointer
// inside of it points to its own `Mmap`, which is owned by an `Arc`
unsafe impl<T> Send for JitBulkFn<T> {}
unsafe impl<T> Sync for JitBulkFn<T> {}

impl<T: From<f32> + Copy + SimdSize> JitBulkEval<T> {
    /// Evaluate multiple points
    fn eval<V: std::ops::Deref<Target = [T]>>(
        &mut self,
        tape: &JitBulkFn<T>,
        vars: &[V],
    ) -> BulkOutput<'_, T> {
        let n = vars.first().map(|v| v.deref().len()).unwrap_or(0);

        self.out.resize_with(tape.output_count(), Vec::new);
        for o in &mut self.out {
            o.resize(n.max(T::SIMD_SIZE), f32::NAN.into());
            o.fill(f32::NAN.into());
        }

        // Special case for when we have fewer items than the native SIMD size,
        // in which case the input slices can't be used as workspace (because
        // they are not valid for the entire range of values read in assembly)
        if n < T::SIMD_SIZE {
            assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH);

            self.scratch
                .resize(vars.len(), [f32::NAN.into(); MAX_SIMD_WIDTH]);
            for (v, t) in vars.iter().zip(self.scratch.iter_mut()) {
                t[0..n].copy_from_slice(v);
            }
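            // E.g. with a hypothetical SIMD width of 8 and n = 3, lanes 3..8
            // stay NaN-padded, so the full-width vector read in assembly
            // remains in bounds of the scratch buffers.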

            self.input_ptrs.clear();
            self.input_ptrs
                .extend(self.scratch[..vars.len()].iter().map(|t| t.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|t| t.as_mut_ptr()));

            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    T::SIMD_SIZE as u64,
                );
            }
        } else {
            // Our vectorized function only accepts sets of a particular width,
            // so we'll find the biggest multiple, then do an extra operation to
            // process any remainders.
            let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE; // Round down
            self.input_ptrs.clear();
            self.input_ptrs.extend(vars.iter().map(|v| v.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|v| v.as_mut_ptr()));
            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    m as u64,
                );
            }
            // If we weren't given an even multiple of vector width, then we'll
            // handle the remaining items by simply evaluating the *last* full
            // vector in the array again.
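            // (Worked example with a hypothetical SIMD width of 8: n = 21
            // gives m = 16, so this tail call re-evaluates lanes 13..21,
            // harmlessly overwriting the already-correct lanes 13..16.)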
            if n != m {
                self.input_ptrs.clear();
                self.output_ptrs.clear();
                unsafe {
                    self.input_ptrs.extend(
                        vars.iter().map(|v| v.as_ptr().add(n - T::SIMD_SIZE)),
                    );
                    self.output_ptrs.extend(
                        self.out
                            .iter_mut()
                            .map(|v| v.as_mut_ptr().add(n - T::SIMD_SIZE)),
                    );
                    (tape.fn_bulk)(
                        self.input_ptrs.as_ptr(),
                        self.output_ptrs.as_ptr(),
                        T::SIMD_SIZE as u64,
                    );
                }
            }
        }
        BulkOutput::new(&self.out, n)
    }
}

/// JIT-based bulk evaluator for arrays of points, yielding point values
#[derive(Default)]
pub struct JitFloatSliceEval(JitBulkEval<f32>);
impl BulkEvaluator for JitFloatSliceEval {
    type Data = f32;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, f32>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// JIT-based bulk evaluator for arrays of points, yielding gradient values
#[derive(Default)]
pub struct JitGradSliceEval(JitBulkEval<Grad>);
impl BulkEvaluator for JitGradSliceEval {
    type Data = Grad;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, Grad>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// A [`Shape`](fidget_core::shape::Shape) which uses the JIT evaluator
pub type JitShape = fidget_core::shape::Shape<JitFunction>;

////////////////////////////////////////////////////////////////////////////////

#[cfg(test)]
mod test {
    use super::*;
    fidget_core::grad_slice_tests!(JitFunction);
    fidget_core::interval_tests!(JitFunction);
    fidget_core::float_slice_tests!(JitFunction);
    fidget_core::point_tests!(JitFunction);

    #[test]
    fn test_mmap_expansion() {
        let mmap = Mmap::new(0).unwrap();

        let mut asm = MmapAssembler::from(mmap);
        const COUNT: u32 = 23456; // larger than 1 page (4 KiB)

        for i in 0..COUNT {
            asm.push_u32(i);
        }
        let mmap = asm.finalize().unwrap();
        let ptr = mmap.as_ptr() as *const u32;
        for i in 0..COUNT {
            let v = unsafe { *ptr.add(i as usize) };
            assert_eq!(v, i);
        }
    }
}