use crate::mmap::{Mmap, MmapWriter};

use fidget_core::{
    Error,
    compiler::RegOp,
    context::{Context, Node},
    eval::{
        BulkEvaluator, BulkOutput, Function, MathFunction, Tape,
        TracingEvaluator,
    },
    render::{RenderHints, TileSizes},
    types::{Grad, Interval},
    var::VarMap,
    vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace},
};

use dynasmrt::{
    AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi,
    TargetKind, components::PatchLoc, dynasm,
};
use std::sync::Arc;

mod mmap;
mod permit;
pub(crate) use permit::WritePermit;

mod float_slice;
mod grad_slice;
mod interval;
mod point;

#[cfg(not(any(
    target_os = "linux",
    target_os = "macos",
    target_os = "windows"
)))]
compile_error!(
    "The `jit` module only builds on Linux, macOS, and Windows; \
     please disable the `jit` feature"
);

#[cfg(target_arch = "aarch64")]
mod aarch64;
#[cfg(target_arch = "aarch64")]
use aarch64 as arch;

#[cfg(target_arch = "x86_64")]
mod x86_64;
#[cfg(target_arch = "x86_64")]
use x86_64 as arch;

/// Number of registers available when executing natively
const REGISTER_LIMIT: usize = arch::REGISTER_LIMIT;

/// Offset before the first usable register
const OFFSET: u8 = arch::OFFSET;

/// Register written to by `CopyImm`
///
/// Functions are responsible for avoiding writes to `IMM_REG` in any other
/// situation.
const IMM_REG: u8 = arch::IMM_REG;

/// Converts from a VM register index to a machine register index
///
/// This applies the per-architecture `OFFSET`, then checks that the result
/// fits within the 32-register machine file.
fn reg(r: u8) -> u8 {
    let out = r.wrapping_add(OFFSET);
    assert!(out < 32);
    out
}
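
// Compile-time counterpart of the assertion in `reg`: it assumes (as the
// per-architecture constants are expected to guarantee) that every VM
// register in `0..REGISTER_LIMIT` fits in the machine register file once
// `OFFSET` is applied.
const _: () = assert!(REGISTER_LIMIT + (OFFSET as usize) <= 32);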

// `Choice` discriminants as `u32` values, so they can be embedded directly
// as immediates in generated code
const CHOICE_LEFT: u32 = Choice::Left as u32;
const CHOICE_RIGHT: u32 = Choice::Right as u32;
const CHOICE_BOTH: u32 = Choice::Both as u32;

/// Trait for generating machine code for a particular evaluation flavor
trait Assembler {
    /// Data type used during evaluation
    type Data;

    /// Initializes the assembler for a tape using `slot_count` slots,
    /// writing code into the given memory-mapped region
    fn init(m: Mmap, slot_count: usize) -> Self;

    /// Rough estimate of the number of bytes generated per tape clause,
    /// used to pre-size the output buffer
    fn bytes_per_clause() -> usize {
        8
    }

    /// Builds a load from a memory (spill) slot into a register
    fn build_load(&mut self, dst_reg: u8, src_mem: u32);

    /// Builds a store from a register into a memory (spill) slot
    fn build_store(&mut self, dst_mem: u32, src_reg: u8);

    /// Copies the given input argument into a register
    fn build_input(&mut self, out_reg: u8, src_arg: u32);

    /// Writes a register to the given output slot
    fn build_output(&mut self, arg_reg: u8, out_index: u32);

    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    /// Squares a value; the default implementation multiplies the value by
    /// itself, but backends may override it with a dedicated instruction
    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        self.build_mul(out_reg, lhs_reg, lhs_reg)
    }

    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_round(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_not(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    // The `*_imm` helpers default to loading the immediate into a scratch
    // register (via `load_imm`), then deferring to the register-register op
    fn build_add_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_add(out_reg, lhs_reg, imm);
    }
    fn build_sub_imm_reg(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, imm, arg);
    }
    fn build_sub_reg_imm(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, arg, imm);
    }
    fn build_mul_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_mul(out_reg, lhs_reg, imm);
    }

    /// Loads an immediate into a register, returning that register
    fn load_imm(&mut self, imm: f32) -> u8;

    /// Finishes assembly, returning the executable memory-mapped region
    fn finalize(self) -> Result<Mmap, DynasmError>;
}

/// Trait declaring the SIMD width for a particular data type
pub trait SimdSize {
    /// Number of lanes processed in a single SIMD operation
    const SIMD_SIZE: usize;
}
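
// A concrete impl takes the form below; note that this example is
// illustrative only (the real per-type widths are declared elsewhere in
// this crate, and `8` is a hypothetical lane count):
//
//     impl SimdSize for f32 {
//         const SIMD_SIZE: usize = 8;
//     }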

pub(crate) struct AssemblerData<T> {
    /// Underlying assembler, writing into memory-mapped storage
    ops: MmapAssembler,

    /// Current offset of the stack pointer, in bytes
    mem_offset: usize,

    /// Whether callee-saved registers have been saved
    saved_callee_regs: bool,

    _p: std::marker::PhantomData<*const T>,
}

impl<T> AssemblerData<T> {
    fn new(mmap: Mmap) -> Self {
        Self {
            ops: MmapAssembler::from(mmap),
            mem_offset: 0,
            saved_callee_regs: false,
            _p: std::marker::PhantomData,
        }
    }

    fn prepare_stack(&mut self, slot_count: usize, stack_size: usize) {
        // Reserve one `T`-sized slot for every value that doesn't fit in
        // registers, plus any extra scratch space requested by the caller
        let mem = slot_count.saturating_sub(REGISTER_LIMIT)
            * std::mem::size_of::<T>()
            + stack_size;

        // The stack pointer must stay 16-byte aligned
        self.mem_offset = mem.next_multiple_of(16);
        self.push_stack();
    }

    /// Returns the byte offset of the given spill slot within the stack
    /// region reserved by `prepare_stack`
    fn stack_pos(&self, slot: u32) -> u32 {
        assert!(slot >= REGISTER_LIMIT as u32);
        (slot - REGISTER_LIMIT as u32) * std::mem::size_of::<T>() as u32
    }
}

#[cfg(target_arch = "x86_64")]
impl<T> AssemblerData<T> {
    fn push_stack(&mut self) {
        dynasm!(self.ops
            ; sub rsp, self.mem_offset as i32
        );
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        // Release stack space, restore the frame pointer, and clear the
        // upper halves of the YMM registers before returning
        dynasm!(self.ops
            ; add rsp, self.mem_offset as i32
            ; pop rbp
            ; vzeroupper
            ; ret
        );
        self.ops.finalize()
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(clippy::unnecessary_cast)]
impl<T> AssemblerData<T> {
    fn push_stack(&mut self) {
        if self.mem_offset < 4096 {
            // Small offsets fit in the immediate field of `sub`
            dynasm!(self.ops
                ; sub sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            // Larger offsets go through a scratch register
            dynasm!(self.ops
                ; mov w28, self.mem_offset as u32
                ; sub sp, sp, w28
            );
        } else {
            panic!("invalid mem offset: {} is too large", self.mem_offset);
        }
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; add sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w9, self.mem_offset as u32
                ; add sp, sp, w9
            );
        } else {
            panic!("invalid mem offset: {} is too large", self.mem_offset);
        }

        dynasm!(self.ops
            ; ret
        );
        self.ops.finalize()
    }
}

#[cfg(target_arch = "x86_64")]
type Relocation = dynasmrt::x64::X64Relocation;

#[cfg(target_arch = "aarch64")]
type Relocation = dynasmrt::aarch64::Aarch64Relocation;

struct MmapAssembler {
    mmap: MmapWriter,

    /// Label positions, indexed by letter; this assembler only supports
    /// single-character labels in the range A-Z
    global_labels: [Option<AssemblyOffset>; 26],
    local_labels: [Option<AssemblyOffset>; 26],

    /// Relocations that have not yet been patched
    global_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 2>,
    local_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 8>,
}

impl Extend<u8> for MmapAssembler {
    fn extend<T>(&mut self, iter: T)
    where
        T: IntoIterator<Item = u8>,
    {
        for c in iter.into_iter() {
            self.push(c);
        }
    }
}

impl<'a> Extend<&'a u8> for MmapAssembler {
    fn extend<T>(&mut self, iter: T)
    where
        T: IntoIterator<Item = &'a u8>,
    {
        for c in iter.into_iter() {
            self.push(*c);
        }
    }
}

impl DynasmApi for MmapAssembler {
    #[inline(always)]
    fn offset(&self) -> AssemblyOffset {
        AssemblyOffset(self.mmap.len())
    }

    #[inline(always)]
    fn push(&mut self, byte: u8) {
        self.mmap.push(byte);
    }

    #[inline(always)]
    fn align(&mut self, alignment: usize, with: u8) {
        let offset = self.offset().0 % alignment;
        if offset != 0 {
            for _ in offset..alignment {
                self.push(with);
            }
        }
    }

    #[inline(always)]
    fn push_u32(&mut self, value: u32) {
        for b in value.to_le_bytes() {
            self.mmap.push(b);
        }
    }
}

impl DynasmLabelApi for MmapAssembler {
    type Relocation = Relocation;

    fn local_label(&mut self, name: &'static str) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_some() {
            panic!("duplicate local label {name}");
        }

        self.local_labels[c as usize] = Some(self.offset());
    }
    fn global_label(&mut self, name: &'static str) {
        if name.len() != 1 {
            panic!("global label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.global_labels[c as usize].is_some() {
            panic!("duplicate global label {name}");
        }

        self.global_labels[c as usize] = Some(self.offset());
    }
    fn dynamic_label(&mut self, _id: DynamicLabel) {
        panic!("dynamic labels are not supported");
    }
    fn global_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        let location = self.offset();
        if name.len() != 1 {
            panic!("global label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        self.global_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn dynamic_relocation(
        &mut self,
        _id: DynamicLabel,
        _target_offset: isize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("dynamic relocations are not supported");
    }
    fn forward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        // A forward relocation must target a label that has not yet been
        // defined; otherwise it should have been a backward relocation
        if self.local_labels[c as usize].is_some() {
            panic!("invalid forward relocation: {name} already exists!");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn backward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_none() {
            panic!("invalid backward relocation: {name} does not exist");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn bare_relocation(
        &mut self,
        _target: usize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("bare relocations are not implemented");
    }
}

impl MmapAssembler {
    /// Patches all pending local relocations, then clears the local label
    /// table so the same names can be reused later in the tape
    fn commit_local(&mut self) -> Result<(), DynasmError> {
        let baseaddr = self.mmap.as_ptr() as usize;

        for (loc, label) in self.local_relocs.take() {
            let target =
                self.local_labels[label as usize].expect("invalid local label");
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Local("oh no"),
                ));
            }
        }
        self.local_labels = [None; 26];
        Ok(())
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        self.commit_local()?;

        let baseaddr = self.mmap.as_ptr() as usize;
        for (loc, label) in self.global_relocs.take() {
            let target =
                self.global_labels.get(label as usize).unwrap().unwrap();
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Global("oh no"),
                ));
            }
        }

        Ok(self.mmap.finalize())
    }
}

impl From<Mmap> for MmapAssembler {
    fn from(mmap: Mmap) -> Self {
        Self {
            mmap: MmapWriter::from(mmap),
            global_labels: [None; 26],
            local_labels: [None; 26],
            global_relocs: Default::default(),
            local_relocs: Default::default(),
        }
    }
}

/// Builds a JIT function for the given tape, reusing the provided storage
/// when it has sufficient capacity
fn build_asm_fn_with_storage<A: Assembler>(
    t: &VmData<REGISTER_LIMIT>,
    mut s: Mmap,
) -> Mmap {
    // If the estimate is well beyond the existing capacity, allocate a
    // fresh buffer up front; otherwise, reuse the storage (the writer can
    // still grow on demand)
    let size_estimate = t.len() * A::bytes_per_clause();
    if size_estimate > 2 * s.capacity() {
        s = Mmap::new(size_estimate).expect("failed to build mmap")
    }

    let mut asm = A::init(s, t.slot_count());

    for op in t.iter_asm() {
        match op {
            RegOp::Load(reg, mem) => {
                asm.build_load(reg, mem);
            }
            RegOp::Store(reg, mem) => {
                asm.build_store(mem, reg);
            }
            RegOp::Input(out, i) => {
                asm.build_input(out, i);
            }
            RegOp::Output(arg, i) => {
                asm.build_output(arg, i);
            }
            RegOp::NegReg(out, arg) => {
                asm.build_neg(out, arg);
            }
            RegOp::AbsReg(out, arg) => {
                asm.build_abs(out, arg);
            }
            RegOp::RecipReg(out, arg) => {
                asm.build_recip(out, arg);
            }
            RegOp::SqrtReg(out, arg) => {
                asm.build_sqrt(out, arg);
            }
            RegOp::SinReg(out, arg) => {
                asm.build_sin(out, arg);
            }
            RegOp::CosReg(out, arg) => {
                asm.build_cos(out, arg);
            }
            RegOp::TanReg(out, arg) => {
                asm.build_tan(out, arg);
            }
            RegOp::AsinReg(out, arg) => {
                asm.build_asin(out, arg);
            }
            RegOp::AcosReg(out, arg) => {
                asm.build_acos(out, arg);
            }
            RegOp::AtanReg(out, arg) => {
                asm.build_atan(out, arg);
            }
            RegOp::ExpReg(out, arg) => {
                asm.build_exp(out, arg);
            }
            RegOp::LnReg(out, arg) => {
                asm.build_ln(out, arg);
            }
            RegOp::CopyReg(out, arg) => {
                asm.build_copy(out, arg);
            }
            RegOp::SquareReg(out, arg) => {
                asm.build_square(out, arg);
            }
            RegOp::FloorReg(out, arg) => {
                asm.build_floor(out, arg);
            }
            RegOp::CeilReg(out, arg) => {
                asm.build_ceil(out, arg);
            }
            RegOp::RoundReg(out, arg) => {
                asm.build_round(out, arg);
            }
            RegOp::NotReg(out, arg) => {
                asm.build_not(out, arg);
            }
            RegOp::AddRegReg(out, lhs, rhs) => {
                asm.build_add(out, lhs, rhs);
            }
            RegOp::MulRegReg(out, lhs, rhs) => {
                asm.build_mul(out, lhs, rhs);
            }
            RegOp::DivRegReg(out, lhs, rhs) => {
                asm.build_div(out, lhs, rhs);
            }
            RegOp::AtanRegReg(out, lhs, rhs) => {
                asm.build_atan2(out, lhs, rhs);
            }
            RegOp::SubRegReg(out, lhs, rhs) => {
                asm.build_sub(out, lhs, rhs);
            }
            RegOp::MinRegReg(out, lhs, rhs) => {
                asm.build_min(out, lhs, rhs);
            }
            RegOp::MaxRegReg(out, lhs, rhs) => {
                asm.build_max(out, lhs, rhs);
            }
            RegOp::AddRegImm(out, arg, imm) => {
                asm.build_add_imm(out, arg, imm);
            }
            RegOp::MulRegImm(out, arg, imm) => {
                asm.build_mul_imm(out, arg, imm);
            }
            RegOp::DivRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, arg, reg);
            }
            RegOp::DivImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, reg, arg);
            }
            RegOp::AtanRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, arg, reg);
            }
            RegOp::AtanImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, reg, arg);
            }
            RegOp::SubImmReg(out, arg, imm) => {
                asm.build_sub_imm_reg(out, arg, imm);
            }
            RegOp::SubRegImm(out, arg, imm) => {
                asm.build_sub_reg_imm(out, arg, imm);
            }
            RegOp::MinRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_min(out, arg, reg);
            }
            RegOp::MaxRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_max(out, arg, reg);
            }
            RegOp::ModRegReg(out, lhs, rhs) => {
                asm.build_mod(out, lhs, rhs);
            }
            RegOp::ModRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, arg, reg);
            }
            RegOp::ModImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, reg, arg);
            }
            RegOp::AndRegReg(out, lhs, rhs) => {
                asm.build_and(out, lhs, rhs);
            }
            RegOp::AndRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_and(out, arg, reg);
            }
            RegOp::OrRegReg(out, lhs, rhs) => {
                asm.build_or(out, lhs, rhs);
            }
            RegOp::OrRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_or(out, arg, reg);
            }
            RegOp::CopyImm(out, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_copy(out, reg);
            }
            RegOp::CompareRegReg(out, lhs, rhs) => {
                asm.build_compare(out, lhs, rhs);
            }
            RegOp::CompareRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, arg, reg);
            }
            RegOp::CompareImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, reg, arg);
            }
        }
    }

    asm.finalize().expect("failed to build JIT function")
}

/// A [`Function`] which is evaluated using JIT-compiled code
#[derive(Clone)]
pub struct JitFunction(GenericVmFunction<REGISTER_LIMIT>);

impl JitFunction {
    fn tracing_tape<A: Assembler>(
        &self,
        storage: Mmap,
    ) -> JitTracingFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitTracingFn {
            mmap: f.into(),
            vars: self.0.data().vars.clone(),
            choice_count: self.0.choice_count(),
            output_count: self.0.output_count(),
            // SAFETY: the mmap contains code assembled for exactly this
            // signature, and it stays alive for as long as the pointer
            // (both are owned by the returned `JitTracingFn`)
            fn_trace: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitTracingFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
    fn bulk_tape<A: Assembler>(&self, storage: Mmap) -> JitBulkFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitBulkFn {
            mmap: f.into(),
            output_count: self.0.output_count(),
            vars: self.0.data().vars.clone(),
            // SAFETY: as in `tracing_tape` above
            fn_bulk: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitBulkFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
}

impl Function for JitFunction {
    type Trace = VmTrace;
    type Storage = VmData<REGISTER_LIMIT>;
    type Workspace = VmWorkspace<REGISTER_LIMIT>;

    type TapeStorage = Mmap;

    type IntervalEval = JitIntervalEval;
    type PointEval = JitPointEval;
    type FloatSliceEval = JitFloatSliceEval;
    type GradSliceEval = JitGradSliceEval;

    #[inline]
    fn point_tape(&self, storage: Mmap) -> JitTracingFn<f32> {
        self.tracing_tape::<point::PointAssembler>(storage)
    }

    #[inline]
    fn interval_tape(&self, storage: Mmap) -> JitTracingFn<Interval> {
        self.tracing_tape::<interval::IntervalAssembler>(storage)
    }

    #[inline]
    fn float_slice_tape(&self, storage: Mmap) -> JitBulkFn<f32> {
        self.bulk_tape::<float_slice::FloatSliceAssembler>(storage)
    }

    #[inline]
    fn grad_slice_tape(&self, storage: Mmap) -> JitBulkFn<Grad> {
        self.bulk_tape::<grad_slice::GradSliceAssembler>(storage)
    }

    #[inline]
    fn simplify(
        &self,
        trace: &Self::Trace,
        storage: Self::Storage,
        workspace: &mut Self::Workspace,
    ) -> Result<Self, Error> {
        self.0.simplify(trace, storage, workspace).map(JitFunction)
    }

    #[inline]
    fn recycle(self) -> Option<Self::Storage> {
        self.0.recycle()
    }

    #[inline]
    fn size(&self) -> usize {
        self.0.size()
    }

    #[inline]
    fn vars(&self) -> &VarMap {
        self.0.vars()
    }

    #[inline]
    fn can_simplify(&self) -> bool {
        self.0.choice_count() > 0
    }
}
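
// Typical tracing-then-simplify flow, sketched using only the items above
// (setup of `eval`, `inputs`, `storage`, and `workspace` is elided):
//
//     let tape = f.interval_tape(Mmap::new(0).unwrap());
//     let (out, trace) = eval.eval(&tape, &inputs)?;
//     if let Some(trace) = trace {
//         let f = f.simplify(trace, storage, &mut workspace)?;
//     }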

impl RenderHints for JitFunction {
    fn tile_sizes_3d() -> TileSizes {
        TileSizes::new(&[64, 16, 8]).unwrap()
    }

    fn tile_sizes_2d() -> TileSizes {
        TileSizes::new(&[128, 16]).unwrap()
    }

    fn simplify_tree_during_meshing(d: usize) -> bool {
        // Only simplify at octree depths of 4, 12, 20, ...
        d % 8 == 4
    }
}

impl MathFunction for JitFunction {
    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
        GenericVmFunction::new(ctx, nodes).map(JitFunction)
    }
}

impl From<GenericVmFunction<REGISTER_LIMIT>> for JitFunction {
    fn from(v: GenericVmFunction<REGISTER_LIMIT>) -> Self {
        Self(v)
    }
}

impl<'a> From<&'a JitFunction> for &'a GenericVmFunction<REGISTER_LIMIT> {
    fn from(v: &'a JitFunction) -> Self {
        &v.0
    }
}

// On x86_64, JIT'd code is called with the System V ABI on every OS
// (including Windows), so the assembly only has to implement a single
// calling convention
#[cfg(target_arch = "x86_64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "sysv64" fn($($args)*)
    };
}

#[cfg(target_arch = "aarch64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "C" fn($($args)*)
    };
}

/// Evaluator state for a JIT-compiled tracing function
///
/// The choice and output buffers are reused between evaluations.
struct JitTracingEval<T> {
    choices: VmTrace,
    out: Vec<T>,
}

impl<T> Default for JitTracingEval<T> {
    fn default() -> Self {
        Self {
            choices: VmTrace::default(),
            out: Vec::default(),
        }
    }
}

/// Raw function pointer to a JIT-compiled tracing function; the argument
/// names below match the pointers passed in `JitTracingEval::eval`
pub type JitTracingFnPointer<T> = jit_fn!(
    unsafe fn(
        *const T, // vars
        *mut u8,  // choices
        *mut u8,  // simplify (written as a boolean flag)
        *mut T,   // output
    )
);

/// Handle to a JIT-compiled tracing function
#[derive(Clone)]
pub struct JitTracingFn<T> {
    mmap: Arc<Mmap>,
    choice_count: usize,
    output_count: usize,
    vars: Arc<VarMap>,
    fn_trace: JitTracingFnPointer<T>,
}

impl<T: Clone> Tape for JitTracingFn<T> {
    type Storage = Mmap;
    fn recycle(self) -> Option<Self::Storage> {
        Arc::into_inner(self.mmap)
    }

    fn vars(&self) -> &VarMap {
        &self.vars
    }

    fn output_count(&self) -> usize {
        self.output_count
    }
}

// SAFETY: `Send`/`Sync` are not derived automatically because of the raw
// function pointer, but that pointer targets immutable JIT'd code owned by
// the `Arc<Mmap>`
unsafe impl<T> Send for JitTracingFn<T> {}
unsafe impl<T> Sync for JitTracingFn<T> {}

impl<T: From<f32> + Clone> JitTracingEval<T> {
    /// Evaluates the given tape, returning the outputs and (if any choice
    /// was restricted) the trace needed for simplification
    fn eval(
        &mut self,
        tape: &JitTracingFn<T>,
        vars: &[T],
    ) -> (&[T], Option<&VmTrace>) {
        let mut simplify = 0;
        self.choices.resize(tape.choice_count, Choice::Unknown);
        self.choices.fill(Choice::Unknown);
        self.out.resize(tape.output_count, f32::NAN.into());
        self.out.fill(f32::NAN.into());
        // SAFETY: the buffers are sized to match the tape's counts, and
        // the function pointer was built for this data type
        unsafe {
            (tape.fn_trace)(
                vars.as_ptr(),
                self.choices.as_mut_ptr() as *mut u8,
                &mut simplify,
                self.out.as_mut_ptr(),
            )
        };

        (
            &self.out,
            if simplify != 0 {
                Some(&self.choices)
            } else {
                None
            },
        )
    }
}

/// JIT-based tracing evaluator for interval values
#[derive(Default)]
pub struct JitIntervalEval(JitTracingEval<Interval>);
impl TracingEvaluator for JitIntervalEval {
    type Data = Interval;
    type Tape = JitTracingFn<Interval>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// JIT-based tracing evaluator for single points
#[derive(Default)]
pub struct JitPointEval(JitTracingEval<f32>);
impl TracingEvaluator for JitPointEval {
    type Data = f32;
    type Tape = JitTracingFn<f32>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// Raw function pointer to a JIT-compiled bulk function; the argument
/// names below match the pointers passed in `JitBulkEval::eval`
pub type JitBulkFnPointer<T> = jit_fn!(
    unsafe fn(
        *const *const T, // input slices
        *const *mut T,   // output slices
        u64,             // slice length
    )
);

/// Handle to a JIT-compiled bulk function
#[derive(Clone)]
pub struct JitBulkFn<T> {
    mmap: Arc<Mmap>,
    vars: Arc<VarMap>,
    output_count: usize,
    fn_bulk: JitBulkFnPointer<T>,
}

// SAFETY: as with `JitTracingFn`, the raw function pointer targets
// immutable JIT'd code owned by the `Arc<Mmap>`
unsafe impl<T> Send for JitBulkFn<T> {}
unsafe impl<T> Sync for JitBulkFn<T> {}

impl<T: Clone> Tape for JitBulkFn<T> {
    type Storage = Mmap;
    fn recycle(self) -> Option<Self::Storage> {
        Arc::into_inner(self.mmap)
    }

    fn vars(&self) -> &VarMap {
        &self.vars
    }

    fn output_count(&self) -> usize {
        self.output_count
    }
}

/// Maximum SIMD lane count across all supported types and architectures
const MAX_SIMD_WIDTH: usize = 8;

/// Evaluator state for a JIT-compiled bulk function
struct JitBulkEval<T> {
    /// Scratch array of input pointers, passed to the JIT function
    input_ptrs: Vec<*const T>,

    /// Scratch array of output pointers, passed to the JIT function
    output_ptrs: Vec<*mut T>,

    /// Padding buffers, used when a slice is shorter than the SIMD width
    scratch: Vec<[T; MAX_SIMD_WIDTH]>,

    /// Output buffers, one per tape output
    out: Vec<Vec<T>>,
}

// SAFETY: the raw pointers stored in `input_ptrs` / `output_ptrs` are only
// ever built and used within a single call to `eval`
unsafe impl<T> Sync for JitBulkEval<T> {}
unsafe impl<T> Send for JitBulkEval<T> {}

impl<T> Default for JitBulkEval<T> {
    fn default() -> Self {
        Self {
            out: vec![],
            scratch: vec![],
            input_ptrs: vec![],
            output_ptrs: vec![],
        }
    }
}

impl<T: From<f32> + Copy + SimdSize> JitBulkEval<T> {
    /// Evaluates the given tape across every element of the input slices
    fn eval<V: std::ops::Deref<Target = [T]>>(
        &mut self,
        tape: &JitBulkFn<T>,
        vars: &[V],
    ) -> BulkOutput<'_, T> {
        let n = vars.first().map(|v| v.deref().len()).unwrap_or(0);

        self.out.resize_with(tape.output_count(), Vec::new);
        for o in &mut self.out {
            o.resize(n.max(T::SIMD_SIZE), f32::NAN.into());
            o.fill(f32::NAN.into());
        }

        if n < T::SIMD_SIZE {
            // Short inputs are copied into NAN-padded scratch buffers, so
            // the JIT function can always run a full SIMD block
            assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH);

            self.scratch
                .resize(vars.len(), [f32::NAN.into(); MAX_SIMD_WIDTH]);
            for (v, t) in vars.iter().zip(self.scratch.iter_mut()) {
                t[0..n].copy_from_slice(v);
            }

            self.input_ptrs.clear();
            self.input_ptrs
                .extend(self.scratch[..vars.len()].iter().map(|t| t.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|t| t.as_mut_ptr()));

            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    T::SIMD_SIZE as u64,
                );
            }
        } else {
            // Evaluate the largest prefix that is a multiple of the SIMD
            // width...
            let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE;
            self.input_ptrs.clear();
            self.input_ptrs.extend(vars.iter().map(|v| v.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|v| v.as_mut_ptr()));
            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    m as u64,
                );
            }
            // ...then handle any leftover tail by re-running the last full
            // SIMD block; earlier lanes in that block are recomputed, which
            // is harmless because the results are identical
            if n != m {
                self.input_ptrs.clear();
                self.output_ptrs.clear();
                unsafe {
                    self.input_ptrs.extend(
                        vars.iter().map(|v| v.as_ptr().add(n - T::SIMD_SIZE)),
                    );
                    self.output_ptrs.extend(
                        self.out
                            .iter_mut()
                            .map(|v| v.as_mut_ptr().add(n - T::SIMD_SIZE)),
                    );
                    (tape.fn_bulk)(
                        self.input_ptrs.as_ptr(),
                        self.output_ptrs.as_ptr(),
                        T::SIMD_SIZE as u64,
                    );
                }
            }
        }
        BulkOutput::new(&self.out, n)
    }
}

/// JIT-based bulk evaluator for arrays of points, yielding point values
#[derive(Default)]
pub struct JitFloatSliceEval(JitBulkEval<f32>);
impl BulkEvaluator for JitFloatSliceEval {
    type Data = f32;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, f32>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// JIT-based bulk evaluator for arrays of points, yielding gradient values
#[derive(Default)]
pub struct JitGradSliceEval(JitBulkEval<Grad>);
impl BulkEvaluator for JitGradSliceEval {
    type Data = Grad;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, Grad>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

/// A [`Shape`](fidget_core::shape::Shape) which uses the JIT evaluator
pub type JitShape = fidget_core::shape::Shape<JitFunction>;
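
// Functions are constructed through `MathFunction::new`; as a sketch
// (assuming the usual `fidget_core::Context` API for building nodes):
//
//     let mut ctx = Context::new();
//     let x = ctx.x();
//     let f = JitFunction::new(&ctx, &[x])?;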

#[cfg(test)]
mod test {
    use super::*;
    fidget_core::grad_slice_tests!(JitFunction);
    fidget_core::interval_tests!(JitFunction);
    fidget_core::float_slice_tests!(JitFunction);
    fidget_core::point_tests!(JitFunction);

    #[test]
    fn test_mmap_expansion() {
        let mmap = Mmap::new(0).unwrap();

        let mut asm = MmapAssembler::from(mmap);
        // Large enough to force multiple rounds of mmap expansion
        const COUNT: u32 = 23456;
        for i in 0..COUNT {
            asm.push_u32(i);
        }
        let mmap = asm.finalize().unwrap();
        let ptr = mmap.as_ptr() as *const u32;
        for i in 0..COUNT {
            let v = unsafe { *ptr.add(i as usize) };
            assert_eq!(v, i);
        }
    }
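
    // Supplementary sanity checks for the register and stack bookkeeping
    // above; these rely only on items defined in this module

    #[test]
    fn test_reg_offset() {
        // Every VM register must map into the 32-register machine file;
        // this exercises the assertion inside `reg` across the full range
        for r in 0..REGISTER_LIMIT as u8 {
            assert!(reg(r) < 32);
        }
    }

    #[test]
    fn test_stack_pos() {
        // `stack_pos` measures spill-slot offsets relative to the first
        // non-register slot, so slot `REGISTER_LIMIT` sits at offset 0
        let data = AssemblerData::<f32>::new(Mmap::new(0).unwrap());
        assert_eq!(data.stack_pos(REGISTER_LIMIT as u32), 0);
        assert_eq!(
            data.stack_pos(REGISTER_LIMIT as u32 + 2),
            2 * std::mem::size_of::<f32>() as u32
        );
    }

    #[test]
    fn test_mmap_alignment() {
        // `align` pads the buffer out to the requested alignment (a
        // property of the `DynasmApi` implementation above)
        let mmap = Mmap::new(0).unwrap();
        let mut asm = MmapAssembler::from(mmap);
        asm.push(0xFF);
        asm.align(8, 0x00);
        assert_eq!(asm.offset().0 % 8, 0);
    }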
}