//! Compilation of VM tapes down to native machine code

use crate::mmap::{Mmap, MmapWriter};
use fidget_core::{
    Error,
    compiler::RegOp,
    context::{Context, Node},
    eval::{
        BulkEvaluator, BulkOutput, Function, MathFunction, Tape,
        TracingEvaluator,
    },
    render::{RenderHints, TileSizes},
    types::{Grad, Interval},
    var::VarMap,
    vm::{Choice, GenericVmFunction, VmData, VmTrace, VmWorkspace},
};

use dynasmrt::{
    AssemblyOffset, DynamicLabel, DynasmApi, DynasmError, DynasmLabelApi,
    TargetKind, components::PatchLoc, dynasm,
};
use std::sync::Arc;

mod mmap;
mod permit;
pub(crate) use permit::WritePermit;

mod float_slice;
mod grad_slice;
mod interval;
mod point;

#[cfg(not(any(
    target_os = "linux",
    target_os = "macos",
    target_os = "windows"
)))]
compile_error!(
    "The `jit` module only builds on Linux, macOS, and Windows; \
    please disable the `jit` feature"
);

#[cfg(target_arch = "aarch64")]
mod aarch64;
#[cfg(target_arch = "aarch64")]
use aarch64 as arch;

#[cfg(target_arch = "x86_64")]
mod x86_64;
#[cfg(target_arch = "x86_64")]
use x86_64 as arch;

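/// Number of registers available to the JIT register allocator on this
/// architecture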
const REGISTER_LIMIT: usize = arch::REGISTER_LIMIT;

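/// Offset between VM register indices and hardware register numbers; see
/// [`reg`]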
const OFFSET: u8 = arch::OFFSET;

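/// Hardware register reserved for holding immediate values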
const IMM_REG: u8 = arch::IMM_REG;

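/// Register index type expected by `dynasm!`; the two backends use
/// different widths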
#[cfg(target_arch = "aarch64")]
type RegIndex = u32;

#[cfg(target_arch = "x86_64")]
type RegIndex = u8;

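/// Maps a VM register index to a hardware register index, applying
/// [`OFFSET`] and checking that the result is in bounds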
fn reg(r: u8) -> RegIndex {
    let out = r.wrapping_add(OFFSET) as RegIndex;
    assert!(out < 32);
    out
}

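// `Choice` discriminants as `u32` constants, for use by the
// per-architecture assemblers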
const CHOICE_LEFT: u32 = Choice::Left as u32;
const CHOICE_RIGHT: u32 = Choice::Right as u32;
const CHOICE_BOTH: u32 = Choice::Both as u32;

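/// Trait for an assembler that compiles a single data type on one
/// architecture; implemented by the point, interval, float-slice, and
/// grad-slice assemblers
///
/// Most methods correspond one-to-one with [`RegOp`] clauses;
/// [`build_asm_fn_with_storage`] walks the tape and calls the matching
/// `build_*` function for each operation.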
trait Assembler {
    type Data;

    fn init(m: Mmap, slot_count: usize) -> Self;

    fn bytes_per_clause() -> usize {
        // Rough estimate, used to size the output buffer
        8
    }

    fn build_load(&mut self, dst_reg: u8, src_mem: u32);

    fn build_store(&mut self, dst_mem: u32, src_reg: u8);

    fn build_input(&mut self, out_reg: u8, src_arg: u32);

    fn build_output(&mut self, arg_reg: u8, out_index: u32);

    fn build_copy(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_neg(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_abs(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_recip(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_sqrt(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_sin(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_cos(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_tan(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_asin(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_acos(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_atan(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_exp(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_ln(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_compare(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_square(&mut self, out_reg: u8, lhs_reg: u8) {
        self.build_mul(out_reg, lhs_reg, lhs_reg)
    }

    fn build_floor(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_ceil(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_round(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_not(&mut self, out_reg: u8, lhs_reg: u8);

    fn build_and(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_or(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_add(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_sub(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_mul(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_div(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_atan2(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_max(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_min(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_mod(&mut self, out_reg: u8, lhs_reg: u8, rhs_reg: u8);

    fn build_add_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_add(out_reg, lhs_reg, imm);
    }
    fn build_sub_imm_reg(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, imm, arg);
    }
    fn build_sub_reg_imm(&mut self, out_reg: u8, arg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_sub(out_reg, arg, imm);
    }
    fn build_mul_imm(&mut self, out_reg: u8, lhs_reg: u8, imm: f32) {
        let imm = self.load_imm(imm);
        self.build_mul(out_reg, lhs_reg, imm);
    }

    fn load_imm(&mut self, imm: f32) -> u8;

    fn finalize(self) -> Result<Mmap, DynasmError>;
}

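/// Trait for types with a native SIMD lane count in the JIT evaluators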
pub trait SimdSize {
    const SIMD_SIZE: usize;
}

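/// State shared by the per-architecture assemblers: the output buffer plus
/// bookkeeping for stack-allocated memory slots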
pub(crate) struct AssemblerData<T> {
    ops: MmapAssembler,

    /// Current stack offset, in bytes
    mem_offset: usize,

    /// Set once callee-saved registers have been saved by arch-specific code
    saved_callee_regs: bool,

    _p: std::marker::PhantomData<*const T>,
}

impl<T> AssemblerData<T> {
    fn new(mmap: Mmap) -> Self {
        Self {
            ops: MmapAssembler::from(mmap),
            mem_offset: 0,
            saved_callee_regs: false,
            _p: std::marker::PhantomData,
        }
    }

    fn prepare_stack(&mut self, slot_count: usize, stack_size: usize) {
        let mem = slot_count.saturating_sub(REGISTER_LIMIT)
            * std::mem::size_of::<T>()
            + stack_size;

        // Round up to the nearest multiple of 16 bytes to keep the stack
        // pointer aligned
        self.mem_offset = mem.next_multiple_of(16);
        self.push_stack();
    }

    fn stack_pos(&self, slot: u32) -> u32 {
        assert!(slot >= REGISTER_LIMIT as u32);
        (slot - REGISTER_LIMIT as u32) * std::mem::size_of::<T>() as u32
    }
}

#[cfg(target_arch = "x86_64")]
impl<T> AssemblerData<T> {
    fn push_stack(&mut self) {
        dynasm!(self.ops
            ; sub rsp, self.mem_offset as i32
        );
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        dynasm!(self.ops
            ; add rsp, self.mem_offset as i32
            ; pop rbp
            ; emms
            ; vzeroall
            ; ret
        );
        self.ops.finalize()
    }
}

#[cfg(target_arch = "aarch64")]
#[allow(clippy::unnecessary_cast)]
impl<T> AssemblerData<T> {
    fn push_stack(&mut self) {
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; sub sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w28, self.mem_offset as u32
                ; sub sp, sp, w28
            );
        } else {
            panic!("invalid mem offset: {} is too large", self.mem_offset);
        }
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        if self.mem_offset < 4096 {
            dynasm!(self.ops
                ; add sp, sp, self.mem_offset as u32
            );
        } else if self.mem_offset < 65536 {
            dynasm!(self.ops
                ; mov w9, self.mem_offset as u32
                ; add sp, sp, w9
            );
        } else {
            panic!("invalid mem offset: {} is too large", self.mem_offset);
        }

        dynasm!(self.ops
            ; ret
        );
        self.ops.finalize()
    }
}

#[cfg(target_arch = "x86_64")]
type Relocation = dynasmrt::x64::X64Relocation;

#[cfg(target_arch = "aarch64")]
type Relocation = dynasmrt::aarch64::Aarch64Relocation;

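/// A lightweight [`DynasmApi`] implementation that assembles directly into
/// an [`Mmap`]
///
/// Labels are restricted to single characters in `A`–`Z`, so they can be
/// stored in fixed-size arrays rather than hash maps.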
struct MmapAssembler {
    mmap: MmapWriter,

    /// Offsets of global labels `A`–`Z`, if defined
    global_labels: [Option<AssemblyOffset>; 26],
    /// Offsets of local labels `A`–`Z`, cleared by `commit_local`
    local_labels: [Option<AssemblyOffset>; 26],

    /// Pending relocations against global labels
    global_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 2>,
    /// Pending relocations against local labels
    local_relocs: arrayvec::ArrayVec<(PatchLoc<Relocation>, u8), 8>,
}

impl Extend<u8> for MmapAssembler {
    fn extend<T>(&mut self, iter: T)
    where
        T: IntoIterator<Item = u8>,
    {
        for c in iter.into_iter() {
            self.push(c);
        }
    }
}

impl<'a> Extend<&'a u8> for MmapAssembler {
    fn extend<T>(&mut self, iter: T)
    where
        T: IntoIterator<Item = &'a u8>,
    {
        for c in iter.into_iter() {
            self.push(*c);
        }
    }
}

impl DynasmApi for MmapAssembler {
    #[inline(always)]
    fn offset(&self) -> AssemblyOffset {
        AssemblyOffset(self.mmap.len())
    }

    #[inline(always)]
    fn push(&mut self, byte: u8) {
        self.mmap.push(byte);
    }

    #[inline(always)]
    fn align(&mut self, alignment: usize, with: u8) {
        let offset = self.offset().0 % alignment;
        if offset != 0 {
            for _ in offset..alignment {
                self.push(with);
            }
        }
    }

    #[inline(always)]
    fn push_u32(&mut self, value: u32) {
        for b in value.to_le_bytes() {
            self.mmap.push(b);
        }
    }
}

impl DynasmLabelApi for MmapAssembler {
    type Relocation = Relocation;

    fn local_label(&mut self, name: &'static str) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_some() {
            panic!("duplicate local label {name}");
        }

        self.local_labels[c as usize] = Some(self.offset());
    }
    fn global_label(&mut self, name: &'static str) {
        if name.len() != 1 {
            panic!("global label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.global_labels[c as usize].is_some() {
            panic!("duplicate global label {name}");
        }

        self.global_labels[c as usize] = Some(self.offset());
    }
    fn dynamic_label(&mut self, _id: DynamicLabel) {
        panic!("dynamic labels are not supported");
    }
    fn global_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        let location = self.offset();
        if name.len() != 1 {
            panic!("global label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        self.global_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn dynamic_relocation(
        &mut self,
        _id: DynamicLabel,
        _target_offset: isize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("dynamic relocations are not supported");
    }
    fn forward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_some() {
            panic!("invalid forward relocation: {name} already exists!");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn backward_relocation(
        &mut self,
        name: &'static str,
        target_offset: isize,
        field_offset: u8,
        ref_offset: u8,
        kind: Relocation,
    ) {
        if name.len() != 1 {
            panic!("local label must be a single character");
        }
        let c = name.as_bytes()[0].wrapping_sub(b'A');
        if c >= 26 {
            panic!("invalid label {name}, must be A-Z");
        }
        if self.local_labels[c as usize].is_none() {
            panic!("invalid backward relocation: {name} does not exist");
        }
        let location = self.offset();
        self.local_relocs.push((
            PatchLoc::new(
                location,
                target_offset,
                field_offset,
                ref_offset,
                kind,
            ),
            c,
        ));
    }
    fn bare_relocation(
        &mut self,
        _target: usize,
        _field_offset: u8,
        _ref_offset: u8,
        _kind: Relocation,
    ) {
        panic!("bare relocations not implemented");
    }
}

impl MmapAssembler {
    /// Patches all pending local relocations, then clears the local labels
    fn commit_local(&mut self) -> Result<(), DynasmError> {
        let baseaddr = self.mmap.as_ptr() as usize;

        for (loc, label) in self.local_relocs.take() {
            let target =
                self.local_labels[label as usize].expect("invalid local label");
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Local("oh no"),
                ));
            }
        }
        self.local_labels = [None; 26];
        Ok(())
    }

    fn finalize(mut self) -> Result<Mmap, DynasmError> {
        self.commit_local()?;

        let baseaddr = self.mmap.as_ptr() as usize;
        for (loc, label) in self.global_relocs.take() {
            let target =
                self.global_labels.get(label as usize).unwrap().unwrap();
            let buf = &mut self.mmap.as_mut_slice()[loc.range(0)];
            if loc.patch(buf, baseaddr, target.0).is_err() {
                return Err(DynasmError::ImpossibleRelocation(
                    TargetKind::Global("oh no"),
                ));
            }
        }

        Ok(self.mmap.finalize())
    }
}

impl From<Mmap> for MmapAssembler {
    fn from(mmap: Mmap) -> Self {
        Self {
            mmap: MmapWriter::from(mmap),
            global_labels: [None; 26],
            local_labels: [None; 26],
            global_relocs: Default::default(),
            local_relocs: Default::default(),
        }
    }
}

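/// Builds a JIT function for the given tape, reusing `s` as storage if its
/// capacity is close enough to the estimated code size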
fn build_asm_fn_with_storage<A: Assembler>(
    t: &VmData<REGISTER_LIMIT>,
    mut s: Mmap,
) -> Mmap {
    // Allocate fresh storage if the estimated code size is well beyond the
    // existing capacity
    let size_estimate = t.len() * A::bytes_per_clause();
    if size_estimate > 2 * s.capacity() {
        s = Mmap::new(size_estimate).expect("failed to build mmap")
    }

    let mut asm = A::init(s, t.slot_count());

    for op in t.iter_asm() {
        match op {
            RegOp::Load(reg, mem) => {
                asm.build_load(reg, mem);
            }
            RegOp::Store(reg, mem) => {
                asm.build_store(mem, reg);
            }
            RegOp::Input(out, i) => {
                asm.build_input(out, i);
            }
            RegOp::Output(arg, i) => {
                asm.build_output(arg, i);
            }
            RegOp::NegReg(out, arg) => {
                asm.build_neg(out, arg);
            }
            RegOp::AbsReg(out, arg) => {
                asm.build_abs(out, arg);
            }
            RegOp::RecipReg(out, arg) => {
                asm.build_recip(out, arg);
            }
            RegOp::SqrtReg(out, arg) => {
                asm.build_sqrt(out, arg);
            }
            RegOp::SinReg(out, arg) => {
                asm.build_sin(out, arg);
            }
            RegOp::CosReg(out, arg) => {
                asm.build_cos(out, arg);
            }
            RegOp::TanReg(out, arg) => {
                asm.build_tan(out, arg);
            }
            RegOp::AsinReg(out, arg) => {
                asm.build_asin(out, arg);
            }
            RegOp::AcosReg(out, arg) => {
                asm.build_acos(out, arg);
            }
            RegOp::AtanReg(out, arg) => {
                asm.build_atan(out, arg);
            }
            RegOp::ExpReg(out, arg) => {
                asm.build_exp(out, arg);
            }
            RegOp::LnReg(out, arg) => {
                asm.build_ln(out, arg);
            }
            RegOp::CopyReg(out, arg) => {
                asm.build_copy(out, arg);
            }
            RegOp::SquareReg(out, arg) => {
                asm.build_square(out, arg);
            }
            RegOp::FloorReg(out, arg) => {
                asm.build_floor(out, arg);
            }
            RegOp::CeilReg(out, arg) => {
                asm.build_ceil(out, arg);
            }
            RegOp::RoundReg(out, arg) => {
                asm.build_round(out, arg);
            }
            RegOp::NotReg(out, arg) => {
                asm.build_not(out, arg);
            }
            RegOp::AddRegReg(out, lhs, rhs) => {
                asm.build_add(out, lhs, rhs);
            }
            RegOp::MulRegReg(out, lhs, rhs) => {
                asm.build_mul(out, lhs, rhs);
            }
            RegOp::DivRegReg(out, lhs, rhs) => {
                asm.build_div(out, lhs, rhs);
            }
            RegOp::AtanRegReg(out, lhs, rhs) => {
                asm.build_atan2(out, lhs, rhs);
            }
            RegOp::SubRegReg(out, lhs, rhs) => {
                asm.build_sub(out, lhs, rhs);
            }
            RegOp::MinRegReg(out, lhs, rhs) => {
                asm.build_min(out, lhs, rhs);
            }
            RegOp::MaxRegReg(out, lhs, rhs) => {
                asm.build_max(out, lhs, rhs);
            }
            RegOp::AddRegImm(out, arg, imm) => {
                asm.build_add_imm(out, arg, imm);
            }
            RegOp::MulRegImm(out, arg, imm) => {
                asm.build_mul_imm(out, arg, imm);
            }
            RegOp::DivRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, arg, reg);
            }
            RegOp::DivImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_div(out, reg, arg);
            }
            RegOp::AtanRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, arg, reg);
            }
            RegOp::AtanImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_atan2(out, reg, arg);
            }
            RegOp::SubImmReg(out, arg, imm) => {
                asm.build_sub_imm_reg(out, arg, imm);
            }
            RegOp::SubRegImm(out, arg, imm) => {
                asm.build_sub_reg_imm(out, arg, imm);
            }
            RegOp::MinRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_min(out, arg, reg);
            }
            RegOp::MaxRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_max(out, arg, reg);
            }
            RegOp::ModRegReg(out, lhs, rhs) => {
                asm.build_mod(out, lhs, rhs);
            }
            RegOp::ModRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, arg, reg);
            }
            RegOp::ModImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_mod(out, reg, arg);
            }
            RegOp::AndRegReg(out, lhs, rhs) => {
                asm.build_and(out, lhs, rhs);
            }
            RegOp::AndRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_and(out, arg, reg);
            }
            RegOp::OrRegReg(out, lhs, rhs) => {
                asm.build_or(out, lhs, rhs);
            }
            RegOp::OrRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_or(out, arg, reg);
            }
            RegOp::CopyImm(out, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_copy(out, reg);
            }
            RegOp::CompareRegReg(out, lhs, rhs) => {
                asm.build_compare(out, lhs, rhs);
            }
            RegOp::CompareRegImm(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, arg, reg);
            }
            RegOp::CompareImmReg(out, arg, imm) => {
                let reg = asm.load_imm(imm);
                asm.build_compare(out, reg, arg);
            }
        }
    }

    asm.finalize().expect("failed to build JIT function")
}

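/// A [`Function`] which uses JIT-compiled machine code for evaluation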
#[derive(Clone)]
pub struct JitFunction(GenericVmFunction<REGISTER_LIMIT>);

impl JitFunction {
    fn tracing_tape<A: Assembler>(
        &self,
        storage: Mmap,
    ) -> JitTracingFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitTracingFn {
            mmap: f.into(),
            vars: self.0.data().vars.clone(),
            choice_count: self.0.choice_count(),
            output_count: self.0.output_count(),
            // SAFETY: the mmap contains code assembled for exactly this
            // signature, and is kept alive by the `mmap` field above
            fn_trace: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitTracingFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
    fn bulk_tape<A: Assembler>(&self, storage: Mmap) -> JitBulkFn<A::Data> {
        let f = build_asm_fn_with_storage::<A>(self.0.data(), storage);
        let ptr = f.as_ptr();
        JitBulkFn {
            mmap: f.into(),
            output_count: self.0.output_count(),
            vars: self.0.data().vars.clone(),
            // SAFETY: as in `tracing_tape` above
            fn_bulk: unsafe {
                std::mem::transmute::<
                    *const std::ffi::c_void,
                    JitBulkFnPointer<A::Data>,
                >(ptr)
            },
        }
    }
}

impl Function for JitFunction {
    type Trace = VmTrace;
    type Storage = VmData<REGISTER_LIMIT>;
    type Workspace = VmWorkspace<REGISTER_LIMIT>;

    type TapeStorage = Mmap;

    type IntervalEval = JitIntervalEval;
    type PointEval = JitPointEval;
    type FloatSliceEval = JitFloatSliceEval;
    type GradSliceEval = JitGradSliceEval;

    #[inline]
    fn point_tape(&self, storage: Mmap) -> JitTracingFn<f32> {
        self.tracing_tape::<point::PointAssembler>(storage)
    }

    #[inline]
    fn interval_tape(&self, storage: Mmap) -> JitTracingFn<Interval> {
        self.tracing_tape::<interval::IntervalAssembler>(storage)
    }

    #[inline]
    fn float_slice_tape(&self, storage: Mmap) -> JitBulkFn<f32> {
        self.bulk_tape::<float_slice::FloatSliceAssembler>(storage)
    }

    #[inline]
    fn grad_slice_tape(&self, storage: Mmap) -> JitBulkFn<Grad> {
        self.bulk_tape::<grad_slice::GradSliceAssembler>(storage)
    }

    #[inline]
    fn simplify(
        &self,
        trace: &Self::Trace,
        storage: Self::Storage,
        workspace: &mut Self::Workspace,
    ) -> Result<Self, Error> {
        self.0.simplify(trace, storage, workspace).map(JitFunction)
    }

    #[inline]
    fn recycle(self) -> Option<Self::Storage> {
        self.0.recycle()
    }

    #[inline]
    fn size(&self) -> usize {
        self.0.size()
    }

    #[inline]
    fn vars(&self) -> &VarMap {
        self.0.vars()
    }

    #[inline]
    fn can_simplify(&self) -> bool {
        self.0.choice_count() > 0
    }
}

impl RenderHints for JitFunction {
    fn tile_sizes_3d() -> TileSizes {
        TileSizes::new(&[64, 16, 8]).unwrap()
    }

    fn tile_sizes_2d() -> TileSizes {
        TileSizes::new(&[128, 16]).unwrap()
    }

    fn simplify_tree_during_meshing(d: usize) -> bool {
        // Simplify on every 8th depth level, offset by 4
        d % 8 == 4
    }
}

impl MathFunction for JitFunction {
    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
        GenericVmFunction::new(ctx, nodes).map(JitFunction)
    }
}

impl From<GenericVmFunction<REGISTER_LIMIT>> for JitFunction {
    fn from(v: GenericVmFunction<REGISTER_LIMIT>) -> Self {
        Self(v)
    }
}

impl<'a> From<&'a JitFunction> for &'a GenericVmFunction<REGISTER_LIMIT> {
    fn from(v: &'a JitFunction) -> Self {
        &v.0
    }
}

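// `jit_fn!` pins the calling convention of JIT-compiled code: `sysv64` on
// x86-64 (even when the host is Windows) and the standard `C` convention on
// AArch64.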
#[cfg(target_arch = "x86_64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "sysv64" fn($($args)*)
    };
}

#[cfg(target_arch = "aarch64")]
macro_rules! jit_fn {
    (unsafe fn($($args:tt)*)) => {
        unsafe extern "C" fn($($args)*)
    };
}

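/// Reusable scratch data for tracing evaluation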
struct JitTracingEval<T> {
    choices: VmTrace,
    out: Vec<T>,
}

impl<T> Default for JitTracingEval<T> {
    fn default() -> Self {
        Self {
            choices: VmTrace::default(),
            out: Vec::default(),
        }
    }
}

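/// Function pointer type for JIT-compiled tracing evaluation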
pub type JitTracingFnPointer<T> = jit_fn!(
    unsafe fn(
        *const T, // vars
        *mut u8,  // choices
        *mut u8,  // simplify flag (written as a single byte)
        *mut T,   // outputs
    )
);

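/// Handle to a JIT-compiled tracing function
///
/// The `Arc<Mmap>` keeps the underlying machine code alive for as long as
/// `fn_trace` may be called.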
#[derive(Clone)]
pub struct JitTracingFn<T> {
    mmap: Arc<Mmap>,
    choice_count: usize,
    output_count: usize,
    vars: Arc<VarMap>,
    fn_trace: JitTracingFnPointer<T>,
}

impl<T: Clone> Tape for JitTracingFn<T> {
    type Storage = Mmap;
    fn recycle(self) -> Option<Self::Storage> {
        Arc::into_inner(self.mmap)
    }

    fn vars(&self) -> &VarMap {
        &self.vars
    }

    fn output_count(&self) -> usize {
        self.output_count
    }
}

// SAFETY: `JitTracingFn` has no interior mutability, and its function pointer
// targets code owned by the `Arc<Mmap>` stored alongside it
unsafe impl<T> Send for JitTracingFn<T> {}
unsafe impl<T> Sync for JitTracingFn<T> {}

impl<T: From<f32> + Clone> JitTracingEval<T> {
    fn eval(
        &mut self,
        tape: &JitTracingFn<T>,
        vars: &[T],
    ) -> (&[T], Option<&VmTrace>) {
        // The JIT code writes a non-zero byte here if any choice was
        // restricted, meaning the tape can be simplified
        let mut simplify = 0;
        self.choices.resize(tape.choice_count, Choice::Unknown);
        self.choices.fill(Choice::Unknown);
        self.out.resize(tape.output_count, f32::NAN.into());
        self.out.fill(f32::NAN.into());
        unsafe {
            (tape.fn_trace)(
                vars.as_ptr(),
                self.choices.as_mut_ptr() as *mut u8,
                &mut simplify,
                self.out.as_mut_ptr(),
            )
        };

        (
            &self.out,
            if simplify != 0 {
                Some(&self.choices)
            } else {
                None
            },
        )
    }
}

#[derive(Default)]
pub struct JitIntervalEval(JitTracingEval<Interval>);
impl TracingEvaluator for JitIntervalEval {
    type Data = Interval;
    type Tape = JitTracingFn<Interval>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

#[derive(Default)]
pub struct JitPointEval(JitTracingEval<f32>);
impl TracingEvaluator for JitPointEval {
    type Data = f32;
    type Tape = JitTracingFn<f32>;
    type Trace = VmTrace;
    type TapeStorage = Mmap;

    #[inline]
    fn eval(
        &mut self,
        tape: &Self::Tape,
        vars: &[Self::Data],
    ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> {
        tape.vars().check_tracing_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

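/// Function pointer type for JIT-compiled bulk evaluation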
pub type JitBulkFnPointer<T> = jit_fn!(
    unsafe fn(
        *const *const T, // one input pointer per variable
        *const *mut T,   // one output pointer per output
        u64,             // number of items to evaluate
    )
);

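/// Handle to a JIT-compiled bulk function, with the same ownership story as
/// [`JitTracingFn`]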
#[derive(Clone)]
pub struct JitBulkFn<T> {
    mmap: Arc<Mmap>,
    vars: Arc<VarMap>,
    output_count: usize,
    fn_bulk: JitBulkFnPointer<T>,
}

impl<T: Clone> Tape for JitBulkFn<T> {
    type Storage = Mmap;
    fn recycle(self) -> Option<Self::Storage> {
        Arc::into_inner(self.mmap)
    }

    fn vars(&self) -> &VarMap {
        &self.vars
    }

    fn output_count(&self) -> usize {
        self.output_count
    }
}

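/// Upper bound on [`SimdSize::SIMD_SIZE`] across supported types and
/// architectures, used to size stack-allocated scratch arrays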
const MAX_SIMD_WIDTH: usize = 8;

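/// Reusable scratch data for bulk evaluation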
struct JitBulkEval<T> {
    /// Cached pointers into the input slices, passed to the JIT function
    input_ptrs: Vec<*const T>,

    /// Cached pointers into the output vectors, passed to the JIT function
    output_ptrs: Vec<*mut T>,

    /// Scratch buffers used when evaluating fewer items than the SIMD width
    scratch: Vec<[T; MAX_SIMD_WIDTH]>,

    /// Output buffers, one per tape output
    out: Vec<Vec<T>>,
}

// SAFETY: the raw pointers are scratch storage; they are cleared and
// repopulated from borrowed slices within each call to `eval`
unsafe impl<T> Sync for JitBulkEval<T> {}
unsafe impl<T> Send for JitBulkEval<T> {}

impl<T> Default for JitBulkEval<T> {
    fn default() -> Self {
        Self {
            out: vec![],
            scratch: vec![],
            input_ptrs: vec![],
            output_ptrs: vec![],
        }
    }
}

// SAFETY: as for `JitTracingFn`, there is no interior mutability and the
// machine code is owned by the `Arc<Mmap>`
unsafe impl<T> Send for JitBulkFn<T> {}
unsafe impl<T> Sync for JitBulkFn<T> {}

impl<T: From<f32> + Copy + SimdSize> JitBulkEval<T> {
    fn eval<V: std::ops::Deref<Target = [T]>>(
        &mut self,
        tape: &JitBulkFn<T>,
        vars: &[V],
    ) -> BulkOutput<'_, T> {
        let n = vars.first().map(|v| v.deref().len()).unwrap_or(0);

        self.out.resize_with(tape.output_count(), Vec::new);
        for o in &mut self.out {
            o.resize(n.max(T::SIMD_SIZE), f32::NAN.into());
            o.fill(f32::NAN.into());
        }

        // Special case for when we have fewer items than the native SIMD
        // width: copy the inputs into scratch buffers so that the JIT code
        // can safely read a full SIMD word
        if n < T::SIMD_SIZE {
            assert!(T::SIMD_SIZE <= MAX_SIMD_WIDTH);

            self.scratch
                .resize(vars.len(), [f32::NAN.into(); MAX_SIMD_WIDTH]);
            for (v, t) in vars.iter().zip(self.scratch.iter_mut()) {
                t[0..n].copy_from_slice(v);
            }

            self.input_ptrs.clear();
            self.input_ptrs
                .extend(self.scratch[..vars.len()].iter().map(|t| t.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|t| t.as_mut_ptr()));

            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    T::SIMD_SIZE as u64,
                );
            }
        } else {
            // Round down to the nearest multiple of the SIMD width
            let m = (n / T::SIMD_SIZE) * T::SIMD_SIZE;

            self.input_ptrs.clear();
            self.input_ptrs.extend(vars.iter().map(|v| v.as_ptr()));

            self.output_ptrs.clear();
            self.output_ptrs
                .extend(self.out.iter_mut().map(|v| v.as_mut_ptr()));
            unsafe {
                (tape.fn_bulk)(
                    self.input_ptrs.as_ptr(),
                    self.output_ptrs.as_ptr(),
                    m as u64,
                );
            }

            // Handle the ragged tail by re-evaluating the last full
            // SIMD-width window, overlapping with already-computed values
            if n != m {
                self.input_ptrs.clear();
                self.output_ptrs.clear();
                unsafe {
                    self.input_ptrs.extend(
                        vars.iter().map(|v| v.as_ptr().add(n - T::SIMD_SIZE)),
                    );
                    self.output_ptrs.extend(
                        self.out
                            .iter_mut()
                            .map(|v| v.as_mut_ptr().add(n - T::SIMD_SIZE)),
                    );
                    (tape.fn_bulk)(
                        self.input_ptrs.as_ptr(),
                        self.output_ptrs.as_ptr(),
                        T::SIMD_SIZE as u64,
                    );
                }
            }
        }
        BulkOutput::new(&self.out, n)
    }
}

#[derive(Default)]
pub struct JitFloatSliceEval(JitBulkEval<f32>);
impl BulkEvaluator for JitFloatSliceEval {
    type Data = f32;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, f32>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

#[derive(Default)]
pub struct JitGradSliceEval(JitBulkEval<Grad>);
impl BulkEvaluator for JitGradSliceEval {
    type Data = Grad;
    type Tape = JitBulkFn<Self::Data>;
    type TapeStorage = Mmap;

    #[inline]
    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
        &mut self,
        tape: &Self::Tape,
        vars: &[V],
    ) -> Result<BulkOutput<'_, Grad>, Error> {
        tape.vars().check_bulk_arguments(vars)?;
        Ok(self.0.eval(tape, vars))
    }
}

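/// A [`Shape`](fidget_core::shape::Shape) that evaluates using JIT-compiled
/// native code
///
/// A minimal usage sketch, assuming a `Shape::new` constructor that mirrors
/// [`MathFunction::new`] (illustrative only, hence `ignore`):
///
/// ```ignore
/// use fidget_core::context::Context;
///
/// let mut ctx = Context::new();
/// let root = ctx.x(); // trivial tree: f(x, y, z) = x
/// let shape = JitShape::new(&ctx, root)?;
/// ```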
pub type JitShape = fidget_core::shape::Shape<JitFunction>;

#[cfg(test)]
mod test {
    use super::*;
    fidget_core::grad_slice_tests!(JitFunction);
    fidget_core::interval_tests!(JitFunction);
    fidget_core::float_slice_tests!(JitFunction);
    fidget_core::point_tests!(JitFunction);

    #[test]
    fn test_mmap_expansion() {
        let mmap = Mmap::new(0).unwrap();

        let mut asm = MmapAssembler::from(mmap);

        // Write enough data to force the mmap to grow several times
        const COUNT: u32 = 23456;
        for i in 0..COUNT {
            asm.push_u32(i);
        }
        let mmap = asm.finalize().unwrap();
        let ptr = mmap.as_ptr() as *const u32;
        for i in 0..COUNT {
            let v = unsafe { *ptr.add(i as usize) };
            assert_eq!(v, i);
        }
    }
}