1use crate::{
3 Context, Error,
4 compiler::RegOp,
5 context::Node,
6 eval::{
7 BulkEvaluator, BulkOutput, Function, MathFunction, Tape, Trace,
8 TracingEvaluator,
9 },
10 render::{RenderHints, TileSizes},
11 shape::Shape,
12 types::{Grad, Interval},
13 var::VarMap,
14};
15use std::sync::Arc;
16
17mod choice;
18mod data;
19
20pub use choice::Choice;
21pub use data::{VmData, VmWorkspace};
22
/// Function type evaluated by the VM interpreter defined in this module.
///
/// This is [`GenericVmFunction`] with its const parameter fixed to
/// `u8::MAX as usize`.
// NOTE(review): `N` is presumably the register budget of the tape; confirm
// against `VmData`.
pub type VmFunction = GenericVmFunction<{ u8::MAX as usize }>;
33
/// A [`Shape`] whose underlying function type is [`VmFunction`].
pub type VmShape = Shape<VmFunction>;
36
/// Unit placeholder used as the tape-storage type for VM tapes.
///
/// It carries no data: [`GenericVmTape`] returns one from `recycle`, and the
/// `*_tape` constructors on [`GenericVmFunction`] ignore the one they are
/// given.
#[derive(Default)]
pub struct EmptyTapeStorage;
40
/// Tape handle used by the VM evaluators: a shared (`Arc`) reference to a
/// [`VmData`].
///
/// Cloning the tape only clones the `Arc`, not the underlying data.
#[derive(Clone)]
pub struct GenericVmTape<const N: usize>(Arc<VmData<N>>);
47
48impl<const N: usize> GenericVmTape<N> {
49 pub fn data(&self) -> &VmData<N> {
51 &self.0
52 }
53}
54
55impl<const N: usize> Tape for GenericVmTape<N> {
56 type Storage = EmptyTapeStorage;
57 fn recycle(self) -> Option<Self::Storage> {
58 Some(EmptyTapeStorage)
59 }
60
61 fn vars(&self) -> &VarMap {
62 &self.0.vars
63 }
64
65 fn output_count(&self) -> usize {
66 self.0.output_count()
67 }
68}
69
/// Flat list of [`Choice`] values.
///
/// The tracing evaluators in this module write one entry for each
/// choice-producing operation (min / max / and / or) they execute.
#[derive(Clone, Default, Eq, PartialEq)]
pub struct VmTrace(Vec<Choice>);
75
76impl VmTrace {
77 pub fn fill(&mut self, v: Choice) {
79 self.0.fill(v);
80 }
81 pub fn resize(&mut self, n: usize, v: Choice) {
83 self.0.resize(n, v);
84 }
85 pub fn as_slice(&self) -> &[Choice] {
87 self.0.as_slice()
88 }
89 pub fn as_mut_slice(&mut self) -> &mut [Choice] {
91 self.0.as_mut_slice()
92 }
93 pub fn as_mut_ptr(&mut self) -> *mut Choice {
95 self.0.as_mut_ptr()
96 }
97}
98
99impl Trace for VmTrace {
100 fn copy_from(&mut self, other: &VmTrace) {
101 self.0.resize(other.0.len(), Choice::Unknown);
102 self.0.copy_from_slice(&other.0);
103 }
104}
105
#[cfg(any(test, feature = "eval-tests"))]
impl From<Vec<Choice>> for VmTrace {
    /// Wraps an existing list of choices as a trace (test builds only).
    fn from(choices: Vec<Choice>) -> Self {
        VmTrace(choices)
    }
}
112
#[cfg(any(test, feature = "eval-tests"))]
impl AsRef<[Choice]> for VmTrace {
    /// Exposes the recorded choices as a slice (test builds only).
    fn as_ref(&self) -> &[Choice] {
        self.0.as_slice()
    }
}
119
/// Function evaluated by the VM interpreter, stored as a shared (`Arc`)
/// reference to a [`VmData`].
///
/// Cloning the function only clones the `Arc`; the tapes it hands out share
/// the same data.
#[derive(Clone)]
pub struct GenericVmFunction<const N: usize>(Arc<VmData<N>>);
126
127impl<const N: usize> From<VmData<N>> for GenericVmFunction<N> {
128 fn from(d: VmData<N>) -> Self {
129 Self(d.into())
130 }
131}
132
133impl<const N: usize> GenericVmFunction<N> {
134 pub fn size(&self) -> usize {
136 self.0.len()
137 }
138
139 pub fn recycle(self) -> Option<VmData<N>> {
141 Arc::try_unwrap(self.0).ok()
142 }
143
144 pub fn data(&self) -> &VmData<N> {
146 self.0.as_ref()
147 }
148
149 pub fn tape(&self) -> GenericVmTape<N> {
151 GenericVmTape(self.0.clone())
152 }
153
154 pub fn choice_count(&self) -> usize {
156 self.0.choice_count()
157 }
158
159 pub fn output_count(&self) -> usize {
161 self.0.output_count()
162 }
163
164 pub fn simplify_with<const M: usize>(
166 &self,
167 trace: &VmTrace,
168 storage: VmData<M>,
169 workspace: &mut VmWorkspace<M>,
170 ) -> Result<GenericVmFunction<M>, Error> {
171 let d = self.0.simplify::<M>(trace.as_slice(), workspace, storage)?;
172 Ok(GenericVmFunction(Arc::new(d)))
173 }
174}
175
176impl<const N: usize> Function for GenericVmFunction<N> {
177 type Storage = VmData<N>;
178 type Workspace = VmWorkspace<N>;
179
180 type TapeStorage = EmptyTapeStorage;
181
182 type FloatSliceEval = VmFloatSliceEval<N>;
183 type GradSliceEval = VmGradSliceEval<N>;
184 type PointEval = VmPointEval<N>;
185 type IntervalEval = VmIntervalEval<N>;
186 type Trace = VmTrace;
187
188 #[inline]
189 fn float_slice_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
190 self.tape()
191 }
192
193 #[inline]
194 fn grad_slice_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
195 self.tape()
196 }
197
198 #[inline]
199 fn point_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
200 self.tape()
201 }
202
203 #[inline]
204 fn interval_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
205 self.tape()
206 }
207
208 #[inline]
209 fn simplify(
210 &self,
211 trace: &Self::Trace,
212 storage: Self::Storage,
213 workspace: &mut Self::Workspace,
214 ) -> Result<Self, Error> {
215 self.simplify_with(trace, storage, workspace)
216 }
217
218 #[inline]
219 fn recycle(self) -> Option<Self::Storage> {
220 GenericVmFunction::recycle(self)
221 }
222
223 #[inline]
224 fn size(&self) -> usize {
225 GenericVmFunction::size(self)
226 }
227
228 #[inline]
229 fn vars(&self) -> &VarMap {
230 &self.0.vars
231 }
232
233 #[inline]
234 fn can_simplify(&self) -> bool {
235 self.0.choice_count() > 0
236 }
237}
238
239impl<const N: usize> RenderHints for GenericVmFunction<N> {
240 fn tile_sizes_3d() -> TileSizes {
241 TileSizes::new(&[128, 64, 32, 16, 8]).unwrap()
242 }
243
244 fn tile_sizes_2d() -> TileSizes {
245 TileSizes::new(&[128, 32, 8]).unwrap()
246 }
247}
248
249impl<const N: usize> MathFunction for GenericVmFunction<N> {
250 fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
251 let d = VmData::new(ctx, nodes)?;
252 Ok(Self(d.into()))
253 }
254}
255
/// Thin wrapper around a mutable slice that can be indexed directly with the
/// `u8` and `u32` operands carried by tape instructions, avoiding an explicit
/// `as usize` at every use.
struct SlotArray<'a, T>(&'a mut [T]);

/// Generates `Index` / `IndexMut` for [`SlotArray`] for each listed integer
/// type by widening the index to `usize`.
macro_rules! slot_array_index {
    ($($idx:ty),+ $(,)?) => {
        $(
            impl<T> std::ops::Index<$idx> for SlotArray<'_, T> {
                type Output = T;
                fn index(&self, i: $idx) -> &Self::Output {
                    &self.0[i as usize]
                }
            }
            impl<T> std::ops::IndexMut<$idx> for SlotArray<'_, T> {
                fn index_mut(&mut self, i: $idx) -> &mut T {
                    &mut self.0[i as usize]
                }
            }
        )+
    };
}

slot_array_index!(u8, u32);
282
/// Scratch state shared by the tracing (one-sample-at-a-time) evaluators.
struct TracingVmEval<T> {
    /// One value per tape slot (sized from `slot_count()`), indexed by the
    /// register and memory operands of each instruction.
    slots: Vec<T>,
    /// One value per tape output (sized from `output_count()`).
    out: Vec<T>,
    /// Choices recorded during evaluation (sized from `choice_count()`).
    choices: VmTrace,
}
291
292impl<T> Default for TracingVmEval<T> {
293 fn default() -> Self {
294 Self {
295 slots: Vec::default(),
296 out: Vec::default(),
297 choices: VmTrace::default(),
298 }
299 }
300}
301
302impl<T: From<f32> + Clone> TracingVmEval<T> {
303 fn resize_slots<const N: usize>(&mut self, tape: &VmData<N>) {
304 self.slots.resize(tape.slot_count(), f32::NAN.into());
305 self.choices.resize(tape.choice_count(), Choice::Unknown);
306 self.out.resize(tape.output_count(), f32::NAN.into());
307 self.choices.fill(Choice::Unknown);
308 }
309}
310
/// Tracing evaluator that runs a VM tape on [`Interval`] values.
#[derive(Default)]
pub struct VmIntervalEval<const N: usize>(TracingVmEval<Interval>);
314impl<const N: usize> TracingEvaluator for VmIntervalEval<N> {
315 type Data = Interval;
316 type Tape = GenericVmTape<N>;
317 type Trace = VmTrace;
318 type TapeStorage = EmptyTapeStorage;
319
320 #[inline]
321 fn eval(
322 &mut self,
323 tape: &Self::Tape,
324 vars: &[Interval],
325 ) -> Result<(&[Interval], Option<&VmTrace>), Error> {
326 tape.vars().check_tracing_arguments(vars)?;
327 let tape = tape.data();
328 self.0.resize_slots(tape);
329
330 let mut simplify = false;
331 let mut v = SlotArray(&mut self.0.slots);
332 let mut choices = self.0.choices.as_mut_slice().iter_mut();
333 for op in tape.iter_asm() {
334 match op {
335 RegOp::Output(arg, i) => {
336 self.0.out[i as usize] = v[arg];
337 }
338 RegOp::Input(out, i) => {
339 v[out] = vars[i as usize];
340 }
341 RegOp::NegReg(out, arg) => {
342 v[out] = -v[arg];
343 }
344 RegOp::AbsReg(out, arg) => {
345 v[out] = v[arg].abs();
346 }
347 RegOp::RecipReg(out, arg) => {
348 v[out] = v[arg].recip();
349 }
350 RegOp::SqrtReg(out, arg) => {
351 v[out] = v[arg].sqrt();
352 }
353 RegOp::SquareReg(out, arg) => {
354 v[out] = v[arg].square();
355 }
356 RegOp::FloorReg(out, arg) => {
357 v[out] = v[arg].floor();
358 }
359 RegOp::CeilReg(out, arg) => {
360 v[out] = v[arg].ceil();
361 }
362 RegOp::RoundReg(out, arg) => {
363 v[out] = v[arg].round();
364 }
365 RegOp::SinReg(out, arg) => {
366 v[out] = v[arg].sin();
367 }
368 RegOp::CosReg(out, arg) => {
369 v[out] = v[arg].cos();
370 }
371 RegOp::TanReg(out, arg) => {
372 v[out] = v[arg].tan();
373 }
374 RegOp::AsinReg(out, arg) => {
375 v[out] = v[arg].asin();
376 }
377 RegOp::AcosReg(out, arg) => {
378 v[out] = v[arg].acos();
379 }
380 RegOp::AtanReg(out, arg) => {
381 v[out] = v[arg].atan();
382 }
383 RegOp::ExpReg(out, arg) => {
384 v[out] = v[arg].exp();
385 }
386 RegOp::LnReg(out, arg) => {
387 v[out] = v[arg].ln();
388 }
389 RegOp::NotReg(out, arg) => {
390 v[out] = if !v[arg].contains(0.0) && !v[arg].has_nan() {
391 Interval::new(0.0, 0.0)
392 } else if v[arg].lower() == 0.0 && v[arg].upper() == 0.0 {
393 Interval::new(1.0, 1.0)
394 } else {
395 Interval::new(0.0, 1.0)
396 };
397 }
398 RegOp::CopyReg(out, arg) => v[out] = v[arg],
399 RegOp::AddRegImm(out, arg, imm) => {
400 v[out] = v[arg] + imm.into();
401 }
402 RegOp::MulRegImm(out, arg, imm) => {
403 v[out] = v[arg] * imm;
404 }
405 RegOp::DivRegImm(out, arg, imm) => {
406 v[out] = v[arg] / imm.into();
407 }
408 RegOp::DivImmReg(out, arg, imm) => {
409 let imm: Interval = imm.into();
410 v[out] = imm / v[arg];
411 }
412 RegOp::AtanRegImm(out, arg, imm) => {
413 v[out] = v[arg].atan2(imm.into());
414 }
415 RegOp::AtanImmReg(out, arg, imm) => {
416 let imm: Interval = imm.into();
417 v[out] = imm.atan2(v[arg]);
418 }
419 RegOp::AtanRegReg(out, lhs, rhs) => {
420 v[out] = v[lhs].atan2(v[rhs]);
421 }
422 RegOp::SubImmReg(out, arg, imm) => {
423 v[out] = Interval::from(imm) - v[arg];
424 }
425 RegOp::SubRegImm(out, arg, imm) => {
426 v[out] = v[arg] - imm.into();
427 }
428 RegOp::MinRegImm(out, arg, imm) => {
429 let (value, choice) = v[arg].min_choice(imm.into());
430 v[out] = value;
431 *choices.next().unwrap() |= choice;
432 simplify |= choice != Choice::Both;
433 }
434 RegOp::MaxRegImm(out, arg, imm) => {
435 let (value, choice) = v[arg].max_choice(imm.into());
436 v[out] = value;
437 *choices.next().unwrap() |= choice;
438 simplify |= choice != Choice::Both;
439 }
440 RegOp::AndRegReg(out, lhs, rhs) => {
441 let (value, choice) = v[lhs].and_choice(v[rhs]);
442 v[out] = value;
443 *choices.next().unwrap() |= choice;
444 simplify |= choice != Choice::Both;
445 }
446 RegOp::AndRegImm(out, arg, imm) => {
447 let (value, choice) = v[arg].and_choice(imm.into());
448 v[out] = value;
449 *choices.next().unwrap() |= choice;
450 simplify |= choice != Choice::Both;
451 }
452 RegOp::OrRegReg(out, lhs, rhs) => {
453 let (value, choice) = v[lhs].or_choice(v[rhs]);
454 v[out] = value;
455 *choices.next().unwrap() |= choice;
456 simplify |= choice != Choice::Both;
457 }
458 RegOp::OrRegImm(out, arg, imm) => {
459 let (value, choice) = v[arg].or_choice(imm.into());
460 v[out] = value;
461 *choices.next().unwrap() |= choice;
462 simplify |= choice != Choice::Both;
463 }
464 RegOp::ModRegReg(out, lhs, rhs) => {
465 v[out] = v[lhs].rem_euclid(v[rhs]);
466 }
467 RegOp::ModRegImm(out, arg, imm) => {
468 v[out] = v[arg].rem_euclid(imm.into());
469 }
470 RegOp::ModImmReg(out, arg, imm) => {
471 v[out] = Interval::from(imm).rem_euclid(v[arg]);
472 }
473 RegOp::AddRegReg(out, lhs, rhs) => v[out] = v[lhs] + v[rhs],
474 RegOp::MulRegReg(out, lhs, rhs) => v[out] = v[lhs] * v[rhs],
475 RegOp::DivRegReg(out, lhs, rhs) => v[out] = v[lhs] / v[rhs],
476 RegOp::SubRegReg(out, lhs, rhs) => v[out] = v[lhs] - v[rhs],
477 RegOp::CompareRegReg(out, lhs, rhs) => {
478 v[out] = if v[lhs].has_nan() || v[rhs].has_nan() {
479 f32::NAN.into()
480 } else if v[lhs].upper() < v[rhs].lower() {
481 Interval::from(-1.0)
482 } else if v[lhs].lower() > v[rhs].upper() {
483 Interval::from(1.0)
484 } else {
485 Interval::new(-1.0, 1.0)
486 };
487 }
488 RegOp::CompareRegImm(out, arg, imm) => {
489 v[out] = if v[arg].has_nan() || imm.is_nan() {
490 f32::NAN.into()
491 } else if v[arg].upper() < imm {
492 Interval::from(-1.0)
493 } else if v[arg].lower() > imm {
494 Interval::from(1.0)
495 } else {
496 Interval::new(-1.0, 1.0)
497 };
498 }
499 RegOp::CompareImmReg(out, arg, imm) => {
500 v[out] = if v[arg].has_nan() || imm.is_nan() {
501 f32::NAN.into()
502 } else if imm < v[arg].lower() {
503 Interval::from(-1.0)
504 } else if imm > v[arg].upper() {
505 Interval::from(1.0)
506 } else {
507 Interval::new(-1.0, 1.0)
508 };
509 }
510 RegOp::MinRegReg(out, lhs, rhs) => {
511 let (value, choice) = v[lhs].min_choice(v[rhs]);
512 v[out] = value;
513 *choices.next().unwrap() |= choice;
514 simplify |= choice != Choice::Both;
515 }
516 RegOp::MaxRegReg(out, lhs, rhs) => {
517 let (value, choice) = v[lhs].max_choice(v[rhs]);
518 v[out] = value;
519 *choices.next().unwrap() |= choice;
520 simplify |= choice != Choice::Both;
521 }
522 RegOp::CopyImm(out, imm) => {
523 v[out] = imm.into();
524 }
525 RegOp::Load(out, mem) => {
526 v[out] = v[mem];
527 }
528 RegOp::Store(out, mem) => {
529 v[mem] = v[out];
530 }
531 }
532 }
533 Ok((
534 &self.0.out,
535 if simplify {
536 Some(&self.0.choices)
537 } else {
538 None
539 },
540 ))
541 }
542}
543
/// Tracing evaluator that runs a VM tape on single `f32` values.
#[derive(Default)]
pub struct VmPointEval<const N: usize>(TracingVmEval<f32>);
547impl<const N: usize> TracingEvaluator for VmPointEval<N> {
548 type Data = f32;
549 type Tape = GenericVmTape<N>;
550 type Trace = VmTrace;
551 type TapeStorage = EmptyTapeStorage;
552
553 #[inline]
554 fn eval(
555 &mut self,
556 tape: &Self::Tape,
557 vars: &[f32],
558 ) -> Result<(&[f32], Option<&VmTrace>), Error> {
559 tape.vars().check_tracing_arguments(vars)?;
560 let tape = tape.data();
561 self.0.resize_slots(tape);
562
563 let mut choices = self.0.choices.as_mut_slice().iter_mut();
564 let mut simplify = false;
565 let mut v = SlotArray(&mut self.0.slots);
566 for op in tape.iter_asm() {
567 match op {
568 RegOp::Output(arg, i) => {
569 self.0.out[i as usize] = v[arg];
570 }
571 RegOp::Input(out, i) => {
572 v[out] = vars[i as usize];
573 }
574 RegOp::NegReg(out, arg) => {
575 v[out] = -v[arg];
576 }
577 RegOp::AbsReg(out, arg) => {
578 v[out] = v[arg].abs();
579 }
580 RegOp::RecipReg(out, arg) => {
581 v[out] = 1.0 / v[arg];
582 }
583 RegOp::SqrtReg(out, arg) => {
584 v[out] = v[arg].sqrt();
585 }
586 RegOp::SquareReg(out, arg) => {
587 let s = v[arg];
588 v[out] = s * s;
589 }
590 RegOp::FloorReg(out, arg) => {
591 v[out] = v[arg].floor();
592 }
593 RegOp::CeilReg(out, arg) => {
594 v[out] = v[arg].ceil();
595 }
596 RegOp::RoundReg(out, arg) => {
597 v[out] = v[arg].round();
598 }
599 RegOp::SinReg(out, arg) => {
600 v[out] = v[arg].sin();
601 }
602 RegOp::CosReg(out, arg) => {
603 v[out] = v[arg].cos();
604 }
605 RegOp::TanReg(out, arg) => {
606 v[out] = v[arg].tan();
607 }
608 RegOp::AsinReg(out, arg) => {
609 v[out] = v[arg].asin();
610 }
611 RegOp::AcosReg(out, arg) => {
612 v[out] = v[arg].acos();
613 }
614 RegOp::AtanReg(out, arg) => {
615 v[out] = v[arg].atan();
616 }
617 RegOp::ExpReg(out, arg) => {
618 v[out] = v[arg].exp();
619 }
620 RegOp::LnReg(out, arg) => {
621 v[out] = v[arg].ln();
622 }
623 RegOp::NotReg(out, arg) => v[out] = (v[arg] == 0.0).into(),
624 RegOp::CopyReg(out, arg) => {
625 v[out] = v[arg];
626 }
627 RegOp::AddRegImm(out, arg, imm) => {
628 v[out] = v[arg] + imm;
629 }
630 RegOp::MulRegImm(out, arg, imm) => {
631 v[out] = v[arg] * imm;
632 }
633 RegOp::DivRegImm(out, arg, imm) => {
634 v[out] = v[arg] / imm;
635 }
636 RegOp::DivImmReg(out, arg, imm) => {
637 v[out] = imm / v[arg];
638 }
639 RegOp::AtanRegImm(out, arg, imm) => {
640 v[out] = v[arg].atan2(imm);
641 }
642 RegOp::AtanImmReg(out, arg, imm) => {
643 v[out] = imm.atan2(v[arg]);
644 }
645 RegOp::AtanRegReg(out, lhs, rhs) => {
646 v[out] = v[lhs].atan2(v[rhs]);
647 }
648 RegOp::SubImmReg(out, arg, imm) => {
649 v[out] = imm - v[arg];
650 }
651 RegOp::SubRegImm(out, arg, imm) => {
652 v[out] = v[arg] - imm;
653 }
654 RegOp::MinRegImm(out, arg, imm) => {
655 let a = v[arg];
656 let (choice, value) = if a < imm {
657 (Choice::Left, a)
658 } else if imm < a {
659 (Choice::Right, imm)
660 } else {
661 (
662 Choice::Both,
663 if a.is_nan() || imm.is_nan() {
664 f32::NAN
665 } else {
666 imm
667 },
668 )
669 };
670 v[out] = value;
671 *choices.next().unwrap() |= choice;
672 simplify |= choice != Choice::Both;
673 }
674 RegOp::MaxRegImm(out, arg, imm) => {
675 let a = v[arg];
676 let (choice, value) = if a > imm {
677 (Choice::Left, a)
678 } else if imm > a {
679 (Choice::Right, imm)
680 } else {
681 (
682 Choice::Both,
683 if a.is_nan() || imm.is_nan() {
684 f32::NAN
685 } else {
686 imm
687 },
688 )
689 };
690 v[out] = value;
691 *choices.next().unwrap() |= choice;
692 simplify |= choice != Choice::Both;
693 }
694 RegOp::AndRegImm(out, arg, imm) => {
695 let a = v[arg];
696 let (choice, value) = if a == 0.0 {
697 (Choice::Left, a)
698 } else {
699 (Choice::Right, imm)
700 };
701 v[out] = value;
702 *choices.next().unwrap() |= choice;
703 simplify |= choice != Choice::Both;
704 }
705 RegOp::OrRegImm(out, arg, imm) => {
706 let a = v[arg];
707 let (choice, value) = if a != 0.0 {
708 (Choice::Left, a)
709 } else {
710 (Choice::Right, imm)
711 };
712 v[out] = value;
713 *choices.next().unwrap() |= choice;
714 simplify |= choice != Choice::Both;
715 }
716 RegOp::ModRegReg(out, lhs, rhs) => {
717 v[out] = v[lhs].rem_euclid(v[rhs]);
718 }
719 RegOp::ModRegImm(out, arg, imm) => {
720 v[out] = v[arg].rem_euclid(imm);
721 }
722 RegOp::ModImmReg(out, arg, imm) => {
723 v[out] = imm.rem_euclid(v[arg]);
724 }
725 RegOp::AddRegReg(out, lhs, rhs) => {
726 v[out] = v[lhs] + v[rhs];
727 }
728 RegOp::MulRegReg(out, lhs, rhs) => {
729 v[out] = v[lhs] * v[rhs];
730 }
731 RegOp::DivRegReg(out, lhs, rhs) => {
732 v[out] = v[lhs] / v[rhs];
733 }
734 RegOp::CompareRegReg(out, lhs, rhs) => {
735 v[out] = v[lhs]
736 .partial_cmp(&v[rhs])
737 .map(|c| c as i8 as f32)
738 .unwrap_or(f32::NAN)
739 }
740 RegOp::CompareRegImm(out, arg, imm) => {
741 v[out] = v[arg]
742 .partial_cmp(&imm)
743 .map(|c| c as i8 as f32)
744 .unwrap_or(f32::NAN)
745 }
746 RegOp::CompareImmReg(out, arg, imm) => {
747 v[out] = imm
748 .partial_cmp(&v[arg])
749 .map(|c| c as i8 as f32)
750 .unwrap_or(f32::NAN)
751 }
752 RegOp::SubRegReg(out, lhs, rhs) => {
753 v[out] = v[lhs] - v[rhs];
754 }
755 RegOp::MinRegReg(out, lhs, rhs) => {
756 let a = v[lhs];
757 let b = v[rhs];
758 let (choice, value) = if a < b {
759 (Choice::Left, a)
760 } else if b < a {
761 (Choice::Right, b)
762 } else {
763 (
764 Choice::Both,
765 if a.is_nan() || b.is_nan() {
766 f32::NAN
767 } else {
768 b
769 },
770 )
771 };
772 v[out] = value;
773 *choices.next().unwrap() |= choice;
774 simplify |= choice != Choice::Both;
775 }
776 RegOp::MaxRegReg(out, lhs, rhs) => {
777 let a = v[lhs];
778 let b = v[rhs];
779 let (choice, value) = if a > b {
780 (Choice::Left, a)
781 } else if b > a {
782 (Choice::Right, b)
783 } else {
784 (
785 Choice::Both,
786 if a.is_nan() || b.is_nan() {
787 f32::NAN
788 } else {
789 b
790 },
791 )
792 };
793 v[out] = value;
794 *choices.next().unwrap() |= choice;
795 simplify |= choice != Choice::Both;
796 }
797 RegOp::AndRegReg(out, lhs, rhs) => {
798 let a = v[lhs];
799 let b = v[rhs];
800 let (choice, value) = if a == 0.0 {
801 (Choice::Left, a)
802 } else {
803 (Choice::Right, b)
804 };
805 v[out] = value;
806 *choices.next().unwrap() |= choice;
807 simplify |= choice != Choice::Both;
808 }
809 RegOp::OrRegReg(out, lhs, rhs) => {
810 let a = v[lhs];
811 let b = v[rhs];
812 let (choice, value) = if a != 0.0 {
813 (Choice::Left, a)
814 } else {
815 (Choice::Right, b)
816 };
817 v[out] = value;
818 *choices.next().unwrap() |= choice;
819 simplify |= choice != Choice::Both;
820 }
821 RegOp::CopyImm(out, imm) => {
822 v[out] = imm;
823 }
824 RegOp::Load(out, mem) => {
825 v[out] = v[mem];
826 }
827 RegOp::Store(out, mem) => {
828 v[mem] = v[out];
829 }
830 }
831 }
832 Ok((
833 &self.0.out,
834 if simplify {
835 Some(&self.0.choices)
836 } else {
837 None
838 },
839 ))
840 }
841}
842
/// Scratch state shared by the bulk (many-samples-at-once) evaluators.
#[derive(Default)]
struct BulkVmEval<T> {
    /// One array per tape slot (sized from `slot_count()`); each inner array
    /// holds one value per sample being evaluated.
    slots: Vec<Vec<T>>,

    /// One array per tape output (sized from `output_count()`); each inner
    /// array holds one value per sample being evaluated.
    out: Vec<Vec<T>>,
}
854
855impl<T: From<f32> + Clone> BulkVmEval<T> {
856 fn resize_slots<const N: usize>(&mut self, tape: &VmData<N>, size: usize) {
858 self.slots
859 .resize_with(tape.slot_count(), || vec![f32::NAN.into(); size]);
860 for s in self.slots.iter_mut() {
861 s.resize(size, f32::NAN.into());
862 }
863
864 self.out
865 .resize_with(tape.output_count(), || vec![f32::NAN.into(); size]);
866 for o in self.out.iter_mut() {
867 o.resize(size, f32::NAN.into());
868 }
869 }
870}
871
/// Bulk evaluator that runs a VM tape on slices of `f32` values.
#[derive(Default)]
pub struct VmFloatSliceEval<const N: usize>(BulkVmEval<f32>);
875impl<const N: usize> BulkEvaluator for VmFloatSliceEval<N> {
876 type Data = f32;
877 type Tape = GenericVmTape<N>;
878 type TapeStorage = EmptyTapeStorage;
879
880 #[inline]
881 fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
882 &mut self,
883 tape: &Self::Tape,
884 vars: &[V],
885 ) -> Result<BulkOutput<'_, f32>, Error> {
886 tape.vars().check_bulk_arguments(vars)?;
887 let tape = tape.data();
888
889 let size = vars.first().map(|v| v.len()).unwrap_or(0);
890 self.0.resize_slots(tape, size);
891
892 let mut v = SlotArray(&mut self.0.slots);
893 for op in tape.iter_asm() {
894 match op {
895 RegOp::Output(arg, i) => {
896 self.0.out[i as usize][0..size]
897 .copy_from_slice(&v[arg][0..size]);
898 }
899 RegOp::Input(out, i) => {
900 v[out][0..size].copy_from_slice(&vars[i as usize]);
901 }
902 RegOp::NegReg(out, arg) => {
903 for i in 0..size {
904 v[out][i] = -v[arg][i];
905 }
906 }
907 RegOp::AbsReg(out, arg) => {
908 for i in 0..size {
909 v[out][i] = v[arg][i].abs();
910 }
911 }
912 RegOp::RecipReg(out, arg) => {
913 for i in 0..size {
914 v[out][i] = 1.0 / v[arg][i];
915 }
916 }
917 RegOp::SqrtReg(out, arg) => {
918 for i in 0..size {
919 v[out][i] = v[arg][i].sqrt();
920 }
921 }
922 RegOp::SquareReg(out, arg) => {
923 for i in 0..size {
924 let s = v[arg][i];
925 v[out][i] = s * s;
926 }
927 }
928 RegOp::FloorReg(out, arg) => {
929 for i in 0..size {
930 v[out][i] = v[arg][i].floor();
931 }
932 }
933 RegOp::CeilReg(out, arg) => {
934 for i in 0..size {
935 v[out][i] = v[arg][i].ceil();
936 }
937 }
938 RegOp::RoundReg(out, arg) => {
939 for i in 0..size {
940 v[out][i] = v[arg][i].round();
941 }
942 }
943 RegOp::SinReg(out, arg) => {
944 for i in 0..size {
945 v[out][i] = v[arg][i].sin();
946 }
947 }
948 RegOp::CosReg(out, arg) => {
949 for i in 0..size {
950 v[out][i] = v[arg][i].cos();
951 }
952 }
953 RegOp::TanReg(out, arg) => {
954 for i in 0..size {
955 v[out][i] = v[arg][i].tan();
956 }
957 }
958 RegOp::AsinReg(out, arg) => {
959 for i in 0..size {
960 v[out][i] = v[arg][i].asin();
961 }
962 }
963 RegOp::AcosReg(out, arg) => {
964 for i in 0..size {
965 v[out][i] = v[arg][i].acos();
966 }
967 }
968 RegOp::AtanReg(out, arg) => {
969 for i in 0..size {
970 v[out][i] = v[arg][i].atan();
971 }
972 }
973 RegOp::ExpReg(out, arg) => {
974 for i in 0..size {
975 v[out][i] = v[arg][i].exp();
976 }
977 }
978 RegOp::LnReg(out, arg) => {
979 for i in 0..size {
980 v[out][i] = v[arg][i].ln();
981 }
982 }
983 RegOp::NotReg(out, arg) => {
984 for i in 0..size {
985 v[out][i] = (v[arg][i] == 0.0).into();
986 }
987 }
988 RegOp::CopyReg(out, arg) => {
989 for i in 0..size {
990 v[out][i] = v[arg][i];
991 }
992 }
993 RegOp::AddRegImm(out, arg, imm) => {
994 for i in 0..size {
995 v[out][i] = v[arg][i] + imm;
996 }
997 }
998 RegOp::MulRegImm(out, arg, imm) => {
999 for i in 0..size {
1000 v[out][i] = v[arg][i] * imm;
1001 }
1002 }
1003 RegOp::DivRegImm(out, arg, imm) => {
1004 for i in 0..size {
1005 v[out][i] = v[arg][i] / imm;
1006 }
1007 }
1008 RegOp::DivImmReg(out, arg, imm) => {
1009 for i in 0..size {
1010 v[out][i] = imm / v[arg][i];
1011 }
1012 }
1013 RegOp::AtanRegImm(out, arg, imm) => {
1014 for i in 0..size {
1015 v[out][i] = v[arg][i].atan2(imm);
1016 }
1017 }
1018 RegOp::AtanImmReg(out, arg, imm) => {
1019 for i in 0..size {
1020 v[out][i] = imm.atan2(v[arg][i]);
1021 }
1022 }
1023 RegOp::AtanRegReg(out, lhs, rhs) => {
1024 for i in 0..size {
1025 v[out][i] = v[lhs][i].atan2(v[rhs][i]);
1026 }
1027 }
1028 RegOp::SubImmReg(out, arg, imm) => {
1029 for i in 0..size {
1030 v[out][i] = imm - v[arg][i];
1031 }
1032 }
1033 RegOp::SubRegImm(out, arg, imm) => {
1034 for i in 0..size {
1035 v[out][i] = v[arg][i] - imm;
1036 }
1037 }
1038 RegOp::CompareImmReg(out, arg, imm) => {
1039 for i in 0..size {
1040 v[out][i] = imm
1041 .partial_cmp(&v[arg][i])
1042 .map(|c| c as i8 as f32)
1043 .unwrap_or(f32::NAN)
1044 }
1045 }
1046 RegOp::CompareRegImm(out, arg, imm) => {
1047 for i in 0..size {
1048 v[out][i] = v[arg][i]
1049 .partial_cmp(&imm)
1050 .map(|c| c as i8 as f32)
1051 .unwrap_or(f32::NAN)
1052 }
1053 }
1054 RegOp::MinRegImm(out, arg, imm) => {
1055 for i in 0..size {
1056 v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
1057 f32::NAN
1058 } else {
1059 v[arg][i].min(imm)
1060 };
1061 }
1062 }
1063 RegOp::MaxRegImm(out, arg, imm) => {
1064 for i in 0..size {
1065 v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
1066 f32::NAN
1067 } else {
1068 v[arg][i].max(imm)
1069 };
1070 }
1071 }
1072 RegOp::AndRegImm(out, arg, imm) => {
1073 for i in 0..size {
1074 v[out][i] =
1075 if v[arg][i] == 0.0 { v[arg][i] } else { imm };
1076 }
1077 }
1078 RegOp::OrRegImm(out, arg, imm) => {
1079 for i in 0..size {
1080 v[out][i] =
1081 if v[arg][i] != 0.0 { v[arg][i] } else { imm };
1082 }
1083 }
1084 RegOp::ModRegReg(out, lhs, rhs) => {
1085 for i in 0..size {
1086 v[out][i] = v[lhs][i].rem_euclid(v[rhs][i]);
1087 }
1088 }
1089 RegOp::ModRegImm(out, arg, imm) => {
1090 for i in 0..size {
1091 v[out][i] = v[arg][i].rem_euclid(imm);
1092 }
1093 }
1094 RegOp::ModImmReg(out, arg, imm) => {
1095 for i in 0..size {
1096 v[out][i] = imm.rem_euclid(v[arg][i]);
1097 }
1098 }
1099 RegOp::AddRegReg(out, lhs, rhs) => {
1100 for i in 0..size {
1101 v[out][i] = v[lhs][i] + v[rhs][i];
1102 }
1103 }
1104 RegOp::MulRegReg(out, lhs, rhs) => {
1105 for i in 0..size {
1106 v[out][i] = v[lhs][i] * v[rhs][i];
1107 }
1108 }
1109 RegOp::DivRegReg(out, lhs, rhs) => {
1110 for i in 0..size {
1111 v[out][i] = v[lhs][i] / v[rhs][i];
1112 }
1113 }
1114 RegOp::SubRegReg(out, lhs, rhs) => {
1115 for i in 0..size {
1116 v[out][i] = v[lhs][i] - v[rhs][i];
1117 }
1118 }
1119 RegOp::CompareRegReg(out, lhs, rhs) => {
1120 for i in 0..size {
1121 v[out][i] = v[lhs][i]
1122 .partial_cmp(&v[rhs][i])
1123 .map(|c| c as i8 as f32)
1124 .unwrap_or(f32::NAN)
1125 }
1126 }
1127 RegOp::MinRegReg(out, lhs, rhs) => {
1128 for i in 0..size {
1129 v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
1130 {
1131 f32::NAN
1132 } else {
1133 v[lhs][i].min(v[rhs][i])
1134 };
1135 }
1136 }
1137 RegOp::MaxRegReg(out, lhs, rhs) => {
1138 for i in 0..size {
1139 v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
1140 {
1141 f32::NAN
1142 } else {
1143 v[lhs][i].max(v[rhs][i])
1144 };
1145 }
1146 }
1147 RegOp::AndRegReg(out, lhs, rhs) => {
1148 for i in 0..size {
1149 v[out][i] = if v[lhs][i] == 0.0 {
1150 v[lhs][i]
1151 } else {
1152 v[rhs][i]
1153 };
1154 }
1155 }
1156 RegOp::OrRegReg(out, lhs, rhs) => {
1157 for i in 0..size {
1158 v[out][i] = if v[lhs][i] != 0.0 {
1159 v[lhs][i]
1160 } else {
1161 v[rhs][i]
1162 };
1163 }
1164 }
1165 RegOp::CopyImm(out, imm) => {
1166 for i in 0..size {
1167 v[out][i] = imm;
1168 }
1169 }
1170 RegOp::Load(out, mem) => {
1171 for i in 0..size {
1172 v[out][i] = v[mem][i];
1173 }
1174 }
1175 RegOp::Store(out, mem) => {
1176 for i in 0..size {
1177 v[mem][i] = v[out][i];
1178 }
1179 }
1180 }
1181 }
1182 Ok(BulkOutput::new(&self.0.out, size))
1183 }
1184}
1185
/// Bulk evaluator that runs a VM tape on slices of [`Grad`] values.
#[derive(Default)]
pub struct VmGradSliceEval<const N: usize>(BulkVmEval<Grad>);
1189impl<const N: usize> BulkEvaluator for VmGradSliceEval<N> {
1190 type Data = Grad;
1191 type Tape = GenericVmTape<N>;
1192 type TapeStorage = EmptyTapeStorage;
1193
1194 #[inline]
1195 fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
1196 &mut self,
1197 tape: &Self::Tape,
1198 vars: &[V],
1199 ) -> Result<BulkOutput<'_, Grad>, Error> {
1200 tape.vars().check_bulk_arguments(vars)?;
1201 let tape = tape.data();
1202 let size = vars.first().map(|v| v.len()).unwrap_or(0);
1203 self.0.resize_slots(tape, size);
1204
1205 let mut v = SlotArray(&mut self.0.slots);
1206 for op in tape.iter_asm() {
1207 match op {
1208 RegOp::Output(arg, i) => {
1209 self.0.out[i as usize][0..size]
1210 .copy_from_slice(&v[arg][0..size]);
1211 }
1212 RegOp::Input(out, i) => {
1213 v[out][0..size].copy_from_slice(&vars[i as usize]);
1214 }
1215 RegOp::NegReg(out, arg) => {
1216 for i in 0..size {
1217 v[out][i] = -v[arg][i];
1218 }
1219 }
1220 RegOp::AbsReg(out, arg) => {
1221 for i in 0..size {
1222 v[out][i] = v[arg][i].abs();
1223 }
1224 }
1225 RegOp::RecipReg(out, arg) => {
1226 let one: Grad = 1.0.into();
1227 for i in 0..size {
1228 v[out][i] = one / v[arg][i];
1229 }
1230 }
1231 RegOp::SqrtReg(out, arg) => {
1232 for i in 0..size {
1233 v[out][i] = v[arg][i].sqrt();
1234 }
1235 }
1236 RegOp::SquareReg(out, arg) => {
1237 for i in 0..size {
1238 let s = v[arg][i];
1239 v[out][i] = s * s;
1240 }
1241 }
1242 RegOp::FloorReg(out, arg) => {
1243 for i in 0..size {
1244 v[out][i] = v[arg][i].floor();
1245 }
1246 }
1247 RegOp::CeilReg(out, arg) => {
1248 for i in 0..size {
1249 v[out][i] = v[arg][i].ceil();
1250 }
1251 }
1252 RegOp::RoundReg(out, arg) => {
1253 for i in 0..size {
1254 v[out][i] = v[arg][i].round();
1255 }
1256 }
1257 RegOp::SinReg(out, arg) => {
1258 for i in 0..size {
1259 v[out][i] = v[arg][i].sin();
1260 }
1261 }
1262 RegOp::CosReg(out, arg) => {
1263 for i in 0..size {
1264 v[out][i] = v[arg][i].cos();
1265 }
1266 }
1267 RegOp::TanReg(out, arg) => {
1268 for i in 0..size {
1269 v[out][i] = v[arg][i].tan();
1270 }
1271 }
1272 RegOp::AsinReg(out, arg) => {
1273 for i in 0..size {
1274 v[out][i] = v[arg][i].asin();
1275 }
1276 }
1277 RegOp::AcosReg(out, arg) => {
1278 for i in 0..size {
1279 v[out][i] = v[arg][i].acos();
1280 }
1281 }
1282 RegOp::AtanReg(out, arg) => {
1283 for i in 0..size {
1284 v[out][i] = v[arg][i].atan();
1285 }
1286 }
1287 RegOp::ExpReg(out, arg) => {
1288 for i in 0..size {
1289 v[out][i] = v[arg][i].exp();
1290 }
1291 }
1292 RegOp::LnReg(out, arg) => {
1293 for i in 0..size {
1294 v[out][i] = v[arg][i].ln();
1295 }
1296 }
1297 RegOp::NotReg(out, arg) => {
1298 for i in 0..size {
1299 v[out][i] = f32::from(v[arg][i].v == 0.0).into();
1300 }
1301 }
1302 RegOp::CopyReg(out, arg) => {
1303 for i in 0..size {
1304 v[out][i] = v[arg][i];
1305 }
1306 }
1307 RegOp::AddRegImm(out, arg, imm) => {
1308 for i in 0..size {
1309 v[out][i] = v[arg][i] + imm.into();
1310 }
1311 }
1312 RegOp::MulRegImm(out, arg, imm) => {
1313 for i in 0..size {
1314 v[out][i] = v[arg][i] * imm;
1315 }
1316 }
1317 RegOp::DivRegImm(out, arg, imm) => {
1318 for i in 0..size {
1319 v[out][i] = v[arg][i] / imm.into();
1320 }
1321 }
1322 RegOp::DivImmReg(out, arg, imm) => {
1323 let imm = Grad::from(imm);
1324 for i in 0..size {
1325 v[out][i] = imm / v[arg][i];
1326 }
1327 }
1328 RegOp::AtanRegImm(out, arg, imm) => {
1329 let imm = Grad::from(imm);
1330 for i in 0..size {
1331 v[out][i] = v[arg][i].atan2(imm);
1332 }
1333 }
1334 RegOp::AtanImmReg(out, arg, imm) => {
1335 let imm = Grad::from(imm);
1336 for i in 0..size {
1337 v[out][i] = imm.atan2(v[arg][i]);
1338 }
1339 }
1340 RegOp::AtanRegReg(out, lhs, rhs) => {
1341 for i in 0..size {
1342 v[out][i] = v[lhs][i].atan2(v[rhs][i]);
1343 }
1344 }
1345 RegOp::SubImmReg(out, arg, imm) => {
1346 let imm: Grad = imm.into();
1347 for i in 0..size {
1348 v[out][i] = imm - v[arg][i];
1349 }
1350 }
1351 RegOp::SubRegImm(out, arg, imm) => {
1352 let imm: Grad = imm.into();
1353 for i in 0..size {
1354 v[out][i] = v[arg][i] - imm;
1355 }
1356 }
1357 RegOp::CompareImmReg(out, arg, imm) => {
1358 for i in 0..size {
1359 let p = imm
1360 .partial_cmp(&v[arg][i].v)
1361 .map(|c| c as i8 as f32)
1362 .unwrap_or(f32::NAN);
1363 v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
1364 }
1365 }
1366 RegOp::CompareRegImm(out, arg, imm) => {
1367 for i in 0..size {
1368 let p = v[arg][i]
1369 .v
1370 .partial_cmp(&imm)
1371 .map(|c| c as i8 as f32)
1372 .unwrap_or(f32::NAN);
1373 v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
1374 }
1375 }
1376 RegOp::MinRegImm(out, arg, imm) => {
1377 let imm: Grad = imm.into();
1378 for i in 0..size {
1379 v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() {
1380 f32::NAN.into()
1381 } else {
1382 v[arg][i].min(imm)
1383 };
1384 }
1385 }
1386 RegOp::MaxRegImm(out, arg, imm) => {
1387 let imm: Grad = imm.into();
1388 for i in 0..size {
1389 v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() {
1390 f32::NAN.into()
1391 } else {
1392 v[arg][i].max(imm)
1393 };
1394 }
1395 }
1396 RegOp::ModRegReg(out, lhs, rhs) => {
1397 for i in 0..size {
1398 v[out][i] = v[lhs][i].rem_euclid(v[rhs][i]);
1399 }
1400 }
1401 RegOp::ModRegImm(out, arg, imm) => {
1402 for i in 0..size {
1403 v[out][i] = v[arg][i].rem_euclid(imm.into());
1404 }
1405 }
1406 RegOp::ModImmReg(out, arg, imm) => {
1407 for i in 0..size {
1408 v[out][i] = Grad::from(imm).rem_euclid(v[arg][i]);
1409 }
1410 }
1411 RegOp::AddRegReg(out, lhs, rhs) => {
1412 for i in 0..size {
1413 v[out][i] = v[lhs][i] + v[rhs][i];
1414 }
1415 }
1416 RegOp::MulRegReg(out, lhs, rhs) => {
1417 for i in 0..size {
1418 v[out][i] = v[lhs][i] * v[rhs][i];
1419 }
1420 }
1421 RegOp::AndRegReg(out, lhs, rhs) => {
1422 for i in 0..size {
1423 v[out][i] = if v[lhs][i].v == 0.0 {
1424 v[lhs][i]
1425 } else {
1426 v[rhs][i]
1427 };
1428 }
1429 }
1430 RegOp::AndRegImm(out, arg, imm) => {
1431 for i in 0..size {
1432 v[out][i] = if v[arg][i].v == 0.0 {
1433 v[arg][i]
1434 } else {
1435 imm.into()
1436 };
1437 }
1438 }
1439 RegOp::OrRegReg(out, lhs, rhs) => {
1440 for i in 0..size {
1441 v[out][i] = if v[lhs][i].v != 0.0 {
1442 v[lhs][i]
1443 } else {
1444 v[rhs][i]
1445 };
1446 }
1447 }
1448 RegOp::OrRegImm(out, arg, imm) => {
1449 for i in 0..size {
1450 v[out][i] = if v[arg][i].v != 0.0 {
1451 v[arg][i]
1452 } else {
1453 imm.into()
1454 };
1455 }
1456 }
1457 RegOp::DivRegReg(out, lhs, rhs) => {
1458 for i in 0..size {
1459 v[out][i] = v[lhs][i] / v[rhs][i];
1460 }
1461 }
1462 RegOp::SubRegReg(out, lhs, rhs) => {
1463 for i in 0..size {
1464 v[out][i] = v[lhs][i] - v[rhs][i];
1465 }
1466 }
1467 RegOp::CompareRegReg(out, lhs, rhs) => {
1468 for i in 0..size {
1469 let p = v[lhs][i]
1470 .v
1471 .partial_cmp(&v[rhs][i].v)
1472 .map(|c| c as i8 as f32)
1473 .unwrap_or(f32::NAN);
1474 v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
1475 }
1476 }
1477 RegOp::MinRegReg(out, lhs, rhs) => {
1478 for i in 0..size {
1479 v[out][i] =
1480 if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() {
1481 f32::NAN.into()
1482 } else {
1483 v[lhs][i].min(v[rhs][i])
1484 };
1485 }
1486 }
1487 RegOp::MaxRegReg(out, lhs, rhs) => {
1488 for i in 0..size {
1489 v[out][i] =
1490 if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() {
1491 f32::NAN.into()
1492 } else {
1493 v[lhs][i].max(v[rhs][i])
1494 };
1495 }
1496 }
1497 RegOp::CopyImm(out, imm) => {
1498 let imm: Grad = imm.into();
1499 for i in 0..size {
1500 v[out][i] = imm;
1501 }
1502 }
1503 RegOp::Load(out, mem) => {
1504 for i in 0..size {
1505 v[out][i] = v[mem][i];
1506 }
1507 }
1508 RegOp::Store(out, mem) => {
1509 for i in 0..size {
1510 v[mem][i] = v[out][i];
1511 }
1512 }
1513 }
1514 }
1515 Ok(BulkOutput::new(&self.0.out, size))
1516 }
1517}
1518
// Invokes the crate-level test-generating macros with `VmFunction`, one per
// evaluator flavour (grad slice, interval, float slice, point).
#[cfg(test)]
mod test {
    use super::*;
    crate::grad_slice_tests!(VmFunction);
    crate::interval_tests!(VmFunction);
    crate::float_slice_tests!(VmFunction);
    crate::point_tests!(VmFunction);
}