fidget_core/vm/
mod.rs

1//! Simple virtual machine for shape evaluation
2use crate::{
3    Context, Error,
4    compiler::RegOp,
5    context::Node,
6    eval::{
7        BulkEvaluator, BulkOutput, Function, MathFunction, Tape, Trace,
8        TracingEvaluator,
9    },
10    render::{RenderHints, TileSizes},
11    shape::Shape,
12    types::{Grad, Interval},
13    var::VarMap,
14};
15use std::sync::Arc;
16
17mod choice;
18mod data;
19
20pub use choice::Choice;
21pub use data::{VmData, VmWorkspace};
22
23////////////////////////////////////////////////////////////////////////////////
24
25/// Function which uses the VM backend for evaluation
26///
27/// Internally, the [`VmFunction`] stores an [`Arc<VmData>`](VmData), and
28/// iterates over a [`Vec<RegOp>`](RegOp) to perform evaluation.
29///
30/// All of the associated [`Tape`] types simply clone the internal `Arc`;
31/// there's no separate planning required to generate a tape.
32pub type VmFunction = GenericVmFunction<{ u8::MAX as usize }>;
33
34/// Shape that uses the [`VmFunction`] backend for evaluation
35pub type VmShape = Shape<VmFunction>;
36
/// Tape storage type which indicates that there's no actual backing storage
///
/// VM tapes share their data through an `Arc`. Recycling one therefore hands
/// back this zero-sized marker rather than any reusable allocation.
// `Debug` makes the public type printable. `Clone`/`Copy`/`PartialEq`/`Eq`
// are free for a unit struct and let callers pass the marker around by value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmptyTapeStorage;
40
41/// Tape which uses the VM backend for evaluation
42///
43/// This tape type is equivalent to a [`GenericVmFunction`], but implements
44/// different traits ([`Tape`] instead of [`Function`]).
45#[derive(Clone)]
46pub struct GenericVmTape<const N: usize>(Arc<VmData<N>>);
47
48impl<const N: usize> GenericVmTape<N> {
49    /// Returns a handle to the inner [`VmData`] used by the tape
50    pub fn data(&self) -> &VmData<N> {
51        &self.0
52    }
53}
54
55impl<const N: usize> Tape for GenericVmTape<N> {
56    type Storage = EmptyTapeStorage;
57    fn recycle(self) -> Option<Self::Storage> {
58        Some(EmptyTapeStorage)
59    }
60
61    fn vars(&self) -> &VarMap {
62        &self.0.vars
63    }
64
65    fn output_count(&self) -> usize {
66        self.0.output_count()
67    }
68}
69
70/// A trace captured by a VM evaluation
71///
72/// This is a thin wrapper around a [`Vec<Choice>`](Choice).
73#[derive(Clone, Default, Eq, PartialEq)]
74pub struct VmTrace(Vec<Choice>);
75
76impl VmTrace {
77    /// Fills the trace with the given value
78    pub fn fill(&mut self, v: Choice) {
79        self.0.fill(v);
80    }
81    /// Resizes the trace, using the new value if it needs to be extended
82    pub fn resize(&mut self, n: usize, v: Choice) {
83        self.0.resize(n, v);
84    }
85    /// Returns the inner choice slice
86    pub fn as_slice(&self) -> &[Choice] {
87        self.0.as_slice()
88    }
89    /// Returns the inner choice slice as a mutable reference
90    pub fn as_mut_slice(&mut self) -> &mut [Choice] {
91        self.0.as_mut_slice()
92    }
93    /// Returns a pointer to the allocated choice array
94    pub fn as_mut_ptr(&mut self) -> *mut Choice {
95        self.0.as_mut_ptr()
96    }
97}
98
99impl Trace for VmTrace {
100    fn copy_from(&mut self, other: &VmTrace) {
101        self.0.resize(other.0.len(), Choice::Unknown);
102        self.0.copy_from_slice(&other.0);
103    }
104}
105
#[cfg(any(test, feature = "eval-tests"))]
impl From<Vec<Choice>> for VmTrace {
    /// Wraps an existing list of choices as a trace (test-only helper)
    fn from(choices: Vec<Choice>) -> Self {
        VmTrace(choices)
    }
}

#[cfg(any(test, feature = "eval-tests"))]
impl AsRef<[Choice]> for VmTrace {
    /// Views the trace as a slice of choices (test-only helper)
    fn as_ref(&self) -> &[Choice] {
        &self.0[..]
    }
}
119
120/// VM-backed shape with a configurable number of registers
121///
122/// You are unlikely to use this directly; [`VmShape`] should be used for
123/// VM-based evaluation.
124#[derive(Clone)]
125pub struct GenericVmFunction<const N: usize>(Arc<VmData<N>>);
126
127impl<const N: usize> From<VmData<N>> for GenericVmFunction<N> {
128    fn from(d: VmData<N>) -> Self {
129        Self(d.into())
130    }
131}
132
133impl<const N: usize> GenericVmFunction<N> {
134    /// Returns a characteristic size (the length of the inner assembly tape)
135    pub fn size(&self) -> usize {
136        self.0.len()
137    }
138
139    /// Reclaim the inner `VmData` if there's only a single reference
140    pub fn recycle(self) -> Option<VmData<N>> {
141        Arc::try_unwrap(self.0).ok()
142    }
143
144    /// Borrows the inner [`VmData`]
145    pub fn data(&self) -> &VmData<N> {
146        self.0.as_ref()
147    }
148
149    /// Returns a [`GenericVmTape`] for the given function
150    pub fn tape(&self) -> GenericVmTape<N> {
151        GenericVmTape(self.0.clone())
152    }
153
154    /// Returns the number of choices (i.e. `min` and `max` nodes) in the tape
155    pub fn choice_count(&self) -> usize {
156        self.0.choice_count()
157    }
158
159    /// Returns the number of outputs in the tape
160    pub fn output_count(&self) -> usize {
161        self.0.output_count()
162    }
163
164    /// Simplifies the function with the given trace and a new register count
165    pub fn simplify_with<const M: usize>(
166        &self,
167        trace: &VmTrace,
168        storage: VmData<M>,
169        workspace: &mut VmWorkspace<M>,
170    ) -> Result<GenericVmFunction<M>, Error> {
171        let d = self.0.simplify::<M>(trace.as_slice(), workspace, storage)?;
172        Ok(GenericVmFunction(Arc::new(d)))
173    }
174}
175
176impl<const N: usize> Function for GenericVmFunction<N> {
177    type Storage = VmData<N>;
178    type Workspace = VmWorkspace<N>;
179
180    type TapeStorage = EmptyTapeStorage;
181
182    type FloatSliceEval = VmFloatSliceEval<N>;
183    type GradSliceEval = VmGradSliceEval<N>;
184    type PointEval = VmPointEval<N>;
185    type IntervalEval = VmIntervalEval<N>;
186    type Trace = VmTrace;
187
188    #[inline]
189    fn float_slice_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
190        self.tape()
191    }
192
193    #[inline]
194    fn grad_slice_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
195        self.tape()
196    }
197
198    #[inline]
199    fn point_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
200        self.tape()
201    }
202
203    #[inline]
204    fn interval_tape(&self, _storage: EmptyTapeStorage) -> GenericVmTape<N> {
205        self.tape()
206    }
207
208    #[inline]
209    fn simplify(
210        &self,
211        trace: &Self::Trace,
212        storage: Self::Storage,
213        workspace: &mut Self::Workspace,
214    ) -> Result<Self, Error> {
215        self.simplify_with(trace, storage, workspace)
216    }
217
218    #[inline]
219    fn recycle(self) -> Option<Self::Storage> {
220        GenericVmFunction::recycle(self)
221    }
222
223    #[inline]
224    fn size(&self) -> usize {
225        GenericVmFunction::size(self)
226    }
227
228    #[inline]
229    fn vars(&self) -> &VarMap {
230        &self.0.vars
231    }
232
233    #[inline]
234    fn can_simplify(&self) -> bool {
235        self.0.choice_count() > 0
236    }
237}
238
239impl<const N: usize> RenderHints for GenericVmFunction<N> {
240    fn tile_sizes_3d() -> TileSizes {
241        TileSizes::new(&[128, 64, 32, 16, 8]).unwrap()
242    }
243
244    fn tile_sizes_2d() -> TileSizes {
245        TileSizes::new(&[128, 32, 8]).unwrap()
246    }
247}
248
249impl<const N: usize> MathFunction for GenericVmFunction<N> {
250    fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error> {
251        let d = VmData::new(ctx, nodes)?;
252        Ok(Self(d.into()))
253    }
254}
255
256////////////////////////////////////////////////////////////////////////////////
257
/// Helper struct to reduce boilerplate conversions
///
/// Wraps a mutable slice so that it can be indexed directly with `u8` or
/// `u32` values (the index types that appear in [`RegOp`]) instead of casting
/// each one to `usize` at the call site.
struct SlotArray<'a, T>(&'a mut [T]);

// Generates matching `Index` / `IndexMut` impls for each listed index type.
macro_rules! slot_array_index_impls {
    ($($idx:ty),* $(,)?) => {$(
        impl<T> std::ops::Index<$idx> for SlotArray<'_, T> {
            type Output = T;
            fn index(&self, i: $idx) -> &T {
                &self.0[i as usize]
            }
        }
        impl<T> std::ops::IndexMut<$idx> for SlotArray<'_, T> {
            fn index_mut(&mut self, i: $idx) -> &mut T {
                &mut self.0[i as usize]
            }
        }
    )*};
}
slot_array_index_impls!(u8, u32);
282
283////////////////////////////////////////////////////////////////////////////////
284
285/// Generic VM evaluator for tracing evaluation
286struct TracingVmEval<T> {
287    slots: Vec<T>,
288    out: Vec<T>,
289    choices: VmTrace,
290}
291
292impl<T> Default for TracingVmEval<T> {
293    fn default() -> Self {
294        Self {
295            slots: Vec::default(),
296            out: Vec::default(),
297            choices: VmTrace::default(),
298        }
299    }
300}
301
302impl<T: From<f32> + Clone> TracingVmEval<T> {
303    fn resize_slots<const N: usize>(&mut self, tape: &VmData<N>) {
304        self.slots.resize(tape.slot_count(), f32::NAN.into());
305        self.choices.resize(tape.choice_count(), Choice::Unknown);
306        self.out.resize(tape.output_count(), f32::NAN.into());
307        self.choices.fill(Choice::Unknown);
308    }
309}
310
311/// VM-based tracing evaluator for intervals
312#[derive(Default)]
313pub struct VmIntervalEval<const N: usize>(TracingVmEval<Interval>);
314impl<const N: usize> TracingEvaluator for VmIntervalEval<N> {
315    type Data = Interval;
316    type Tape = GenericVmTape<N>;
317    type Trace = VmTrace;
318    type TapeStorage = EmptyTapeStorage;
319
320    #[inline]
321    fn eval(
322        &mut self,
323        tape: &Self::Tape,
324        vars: &[Interval],
325    ) -> Result<(&[Interval], Option<&VmTrace>), Error> {
326        tape.vars().check_tracing_arguments(vars)?;
327        let tape = tape.data();
328        self.0.resize_slots(tape);
329
330        let mut simplify = false;
331        let mut v = SlotArray(&mut self.0.slots);
332        let mut choices = self.0.choices.as_mut_slice().iter_mut();
333        for op in tape.iter_asm() {
334            match op {
335                RegOp::Output(arg, i) => {
336                    self.0.out[i as usize] = v[arg];
337                }
338                RegOp::Input(out, i) => {
339                    v[out] = vars[i as usize];
340                }
341                RegOp::NegReg(out, arg) => {
342                    v[out] = -v[arg];
343                }
344                RegOp::AbsReg(out, arg) => {
345                    v[out] = v[arg].abs();
346                }
347                RegOp::RecipReg(out, arg) => {
348                    v[out] = v[arg].recip();
349                }
350                RegOp::SqrtReg(out, arg) => {
351                    v[out] = v[arg].sqrt();
352                }
353                RegOp::SquareReg(out, arg) => {
354                    v[out] = v[arg].square();
355                }
356                RegOp::FloorReg(out, arg) => {
357                    v[out] = v[arg].floor();
358                }
359                RegOp::CeilReg(out, arg) => {
360                    v[out] = v[arg].ceil();
361                }
362                RegOp::RoundReg(out, arg) => {
363                    v[out] = v[arg].round();
364                }
365                RegOp::SinReg(out, arg) => {
366                    v[out] = v[arg].sin();
367                }
368                RegOp::CosReg(out, arg) => {
369                    v[out] = v[arg].cos();
370                }
371                RegOp::TanReg(out, arg) => {
372                    v[out] = v[arg].tan();
373                }
374                RegOp::AsinReg(out, arg) => {
375                    v[out] = v[arg].asin();
376                }
377                RegOp::AcosReg(out, arg) => {
378                    v[out] = v[arg].acos();
379                }
380                RegOp::AtanReg(out, arg) => {
381                    v[out] = v[arg].atan();
382                }
383                RegOp::ExpReg(out, arg) => {
384                    v[out] = v[arg].exp();
385                }
386                RegOp::LnReg(out, arg) => {
387                    v[out] = v[arg].ln();
388                }
389                RegOp::NotReg(out, arg) => {
390                    v[out] = if !v[arg].contains(0.0) && !v[arg].has_nan() {
391                        Interval::new(0.0, 0.0)
392                    } else if v[arg].lower() == 0.0 && v[arg].upper() == 0.0 {
393                        Interval::new(1.0, 1.0)
394                    } else {
395                        Interval::new(0.0, 1.0)
396                    };
397                }
398                RegOp::CopyReg(out, arg) => v[out] = v[arg],
399                RegOp::AddRegImm(out, arg, imm) => {
400                    v[out] = v[arg] + imm.into();
401                }
402                RegOp::MulRegImm(out, arg, imm) => {
403                    v[out] = v[arg] * imm;
404                }
405                RegOp::DivRegImm(out, arg, imm) => {
406                    v[out] = v[arg] / imm.into();
407                }
408                RegOp::DivImmReg(out, arg, imm) => {
409                    let imm: Interval = imm.into();
410                    v[out] = imm / v[arg];
411                }
412                RegOp::AtanRegImm(out, arg, imm) => {
413                    v[out] = v[arg].atan2(imm.into());
414                }
415                RegOp::AtanImmReg(out, arg, imm) => {
416                    let imm: Interval = imm.into();
417                    v[out] = imm.atan2(v[arg]);
418                }
419                RegOp::AtanRegReg(out, lhs, rhs) => {
420                    v[out] = v[lhs].atan2(v[rhs]);
421                }
422                RegOp::SubImmReg(out, arg, imm) => {
423                    v[out] = Interval::from(imm) - v[arg];
424                }
425                RegOp::SubRegImm(out, arg, imm) => {
426                    v[out] = v[arg] - imm.into();
427                }
428                RegOp::MinRegImm(out, arg, imm) => {
429                    let (value, choice) = v[arg].min_choice(imm.into());
430                    v[out] = value;
431                    *choices.next().unwrap() |= choice;
432                    simplify |= choice != Choice::Both;
433                }
434                RegOp::MaxRegImm(out, arg, imm) => {
435                    let (value, choice) = v[arg].max_choice(imm.into());
436                    v[out] = value;
437                    *choices.next().unwrap() |= choice;
438                    simplify |= choice != Choice::Both;
439                }
440                RegOp::AndRegReg(out, lhs, rhs) => {
441                    let (value, choice) = v[lhs].and_choice(v[rhs]);
442                    v[out] = value;
443                    *choices.next().unwrap() |= choice;
444                    simplify |= choice != Choice::Both;
445                }
446                RegOp::AndRegImm(out, arg, imm) => {
447                    let (value, choice) = v[arg].and_choice(imm.into());
448                    v[out] = value;
449                    *choices.next().unwrap() |= choice;
450                    simplify |= choice != Choice::Both;
451                }
452                RegOp::OrRegReg(out, lhs, rhs) => {
453                    let (value, choice) = v[lhs].or_choice(v[rhs]);
454                    v[out] = value;
455                    *choices.next().unwrap() |= choice;
456                    simplify |= choice != Choice::Both;
457                }
458                RegOp::OrRegImm(out, arg, imm) => {
459                    let (value, choice) = v[arg].or_choice(imm.into());
460                    v[out] = value;
461                    *choices.next().unwrap() |= choice;
462                    simplify |= choice != Choice::Both;
463                }
464                RegOp::ModRegReg(out, lhs, rhs) => {
465                    v[out] = v[lhs].rem_euclid(v[rhs]);
466                }
467                RegOp::ModRegImm(out, arg, imm) => {
468                    v[out] = v[arg].rem_euclid(imm.into());
469                }
470                RegOp::ModImmReg(out, arg, imm) => {
471                    v[out] = Interval::from(imm).rem_euclid(v[arg]);
472                }
473                RegOp::AddRegReg(out, lhs, rhs) => v[out] = v[lhs] + v[rhs],
474                RegOp::MulRegReg(out, lhs, rhs) => v[out] = v[lhs] * v[rhs],
475                RegOp::DivRegReg(out, lhs, rhs) => v[out] = v[lhs] / v[rhs],
476                RegOp::SubRegReg(out, lhs, rhs) => v[out] = v[lhs] - v[rhs],
477                RegOp::CompareRegReg(out, lhs, rhs) => {
478                    v[out] = if v[lhs].has_nan() || v[rhs].has_nan() {
479                        f32::NAN.into()
480                    } else if v[lhs].upper() < v[rhs].lower() {
481                        Interval::from(-1.0)
482                    } else if v[lhs].lower() > v[rhs].upper() {
483                        Interval::from(1.0)
484                    } else {
485                        Interval::new(-1.0, 1.0)
486                    };
487                }
488                RegOp::CompareRegImm(out, arg, imm) => {
489                    v[out] = if v[arg].has_nan() || imm.is_nan() {
490                        f32::NAN.into()
491                    } else if v[arg].upper() < imm {
492                        Interval::from(-1.0)
493                    } else if v[arg].lower() > imm {
494                        Interval::from(1.0)
495                    } else {
496                        Interval::new(-1.0, 1.0)
497                    };
498                }
499                RegOp::CompareImmReg(out, arg, imm) => {
500                    v[out] = if v[arg].has_nan() || imm.is_nan() {
501                        f32::NAN.into()
502                    } else if imm < v[arg].lower() {
503                        Interval::from(-1.0)
504                    } else if imm > v[arg].upper() {
505                        Interval::from(1.0)
506                    } else {
507                        Interval::new(-1.0, 1.0)
508                    };
509                }
510                RegOp::MinRegReg(out, lhs, rhs) => {
511                    let (value, choice) = v[lhs].min_choice(v[rhs]);
512                    v[out] = value;
513                    *choices.next().unwrap() |= choice;
514                    simplify |= choice != Choice::Both;
515                }
516                RegOp::MaxRegReg(out, lhs, rhs) => {
517                    let (value, choice) = v[lhs].max_choice(v[rhs]);
518                    v[out] = value;
519                    *choices.next().unwrap() |= choice;
520                    simplify |= choice != Choice::Both;
521                }
522                RegOp::CopyImm(out, imm) => {
523                    v[out] = imm.into();
524                }
525                RegOp::Load(out, mem) => {
526                    v[out] = v[mem];
527                }
528                RegOp::Store(out, mem) => {
529                    v[mem] = v[out];
530                }
531            }
532        }
533        Ok((
534            &self.0.out,
535            if simplify {
536                Some(&self.0.choices)
537            } else {
538                None
539            },
540        ))
541    }
542}
543
544/// VM-based tracing evaluator for single points
545#[derive(Default)]
546pub struct VmPointEval<const N: usize>(TracingVmEval<f32>);
547impl<const N: usize> TracingEvaluator for VmPointEval<N> {
548    type Data = f32;
549    type Tape = GenericVmTape<N>;
550    type Trace = VmTrace;
551    type TapeStorage = EmptyTapeStorage;
552
553    #[inline]
554    fn eval(
555        &mut self,
556        tape: &Self::Tape,
557        vars: &[f32],
558    ) -> Result<(&[f32], Option<&VmTrace>), Error> {
559        tape.vars().check_tracing_arguments(vars)?;
560        let tape = tape.data();
561        self.0.resize_slots(tape);
562
563        let mut choices = self.0.choices.as_mut_slice().iter_mut();
564        let mut simplify = false;
565        let mut v = SlotArray(&mut self.0.slots);
566        for op in tape.iter_asm() {
567            match op {
568                RegOp::Output(arg, i) => {
569                    self.0.out[i as usize] = v[arg];
570                }
571                RegOp::Input(out, i) => {
572                    v[out] = vars[i as usize];
573                }
574                RegOp::NegReg(out, arg) => {
575                    v[out] = -v[arg];
576                }
577                RegOp::AbsReg(out, arg) => {
578                    v[out] = v[arg].abs();
579                }
580                RegOp::RecipReg(out, arg) => {
581                    v[out] = 1.0 / v[arg];
582                }
583                RegOp::SqrtReg(out, arg) => {
584                    v[out] = v[arg].sqrt();
585                }
586                RegOp::SquareReg(out, arg) => {
587                    let s = v[arg];
588                    v[out] = s * s;
589                }
590                RegOp::FloorReg(out, arg) => {
591                    v[out] = v[arg].floor();
592                }
593                RegOp::CeilReg(out, arg) => {
594                    v[out] = v[arg].ceil();
595                }
596                RegOp::RoundReg(out, arg) => {
597                    v[out] = v[arg].round();
598                }
599                RegOp::SinReg(out, arg) => {
600                    v[out] = v[arg].sin();
601                }
602                RegOp::CosReg(out, arg) => {
603                    v[out] = v[arg].cos();
604                }
605                RegOp::TanReg(out, arg) => {
606                    v[out] = v[arg].tan();
607                }
608                RegOp::AsinReg(out, arg) => {
609                    v[out] = v[arg].asin();
610                }
611                RegOp::AcosReg(out, arg) => {
612                    v[out] = v[arg].acos();
613                }
614                RegOp::AtanReg(out, arg) => {
615                    v[out] = v[arg].atan();
616                }
617                RegOp::ExpReg(out, arg) => {
618                    v[out] = v[arg].exp();
619                }
620                RegOp::LnReg(out, arg) => {
621                    v[out] = v[arg].ln();
622                }
623                RegOp::NotReg(out, arg) => v[out] = (v[arg] == 0.0).into(),
624                RegOp::CopyReg(out, arg) => {
625                    v[out] = v[arg];
626                }
627                RegOp::AddRegImm(out, arg, imm) => {
628                    v[out] = v[arg] + imm;
629                }
630                RegOp::MulRegImm(out, arg, imm) => {
631                    v[out] = v[arg] * imm;
632                }
633                RegOp::DivRegImm(out, arg, imm) => {
634                    v[out] = v[arg] / imm;
635                }
636                RegOp::DivImmReg(out, arg, imm) => {
637                    v[out] = imm / v[arg];
638                }
639                RegOp::AtanRegImm(out, arg, imm) => {
640                    v[out] = v[arg].atan2(imm);
641                }
642                RegOp::AtanImmReg(out, arg, imm) => {
643                    v[out] = imm.atan2(v[arg]);
644                }
645                RegOp::AtanRegReg(out, lhs, rhs) => {
646                    v[out] = v[lhs].atan2(v[rhs]);
647                }
648                RegOp::SubImmReg(out, arg, imm) => {
649                    v[out] = imm - v[arg];
650                }
651                RegOp::SubRegImm(out, arg, imm) => {
652                    v[out] = v[arg] - imm;
653                }
654                RegOp::MinRegImm(out, arg, imm) => {
655                    let a = v[arg];
656                    let (choice, value) = if a < imm {
657                        (Choice::Left, a)
658                    } else if imm < a {
659                        (Choice::Right, imm)
660                    } else {
661                        (
662                            Choice::Both,
663                            if a.is_nan() || imm.is_nan() {
664                                f32::NAN
665                            } else {
666                                imm
667                            },
668                        )
669                    };
670                    v[out] = value;
671                    *choices.next().unwrap() |= choice;
672                    simplify |= choice != Choice::Both;
673                }
674                RegOp::MaxRegImm(out, arg, imm) => {
675                    let a = v[arg];
676                    let (choice, value) = if a > imm {
677                        (Choice::Left, a)
678                    } else if imm > a {
679                        (Choice::Right, imm)
680                    } else {
681                        (
682                            Choice::Both,
683                            if a.is_nan() || imm.is_nan() {
684                                f32::NAN
685                            } else {
686                                imm
687                            },
688                        )
689                    };
690                    v[out] = value;
691                    *choices.next().unwrap() |= choice;
692                    simplify |= choice != Choice::Both;
693                }
694                RegOp::AndRegImm(out, arg, imm) => {
695                    let a = v[arg];
696                    let (choice, value) = if a == 0.0 {
697                        (Choice::Left, a)
698                    } else {
699                        (Choice::Right, imm)
700                    };
701                    v[out] = value;
702                    *choices.next().unwrap() |= choice;
703                    simplify |= choice != Choice::Both;
704                }
705                RegOp::OrRegImm(out, arg, imm) => {
706                    let a = v[arg];
707                    let (choice, value) = if a != 0.0 {
708                        (Choice::Left, a)
709                    } else {
710                        (Choice::Right, imm)
711                    };
712                    v[out] = value;
713                    *choices.next().unwrap() |= choice;
714                    simplify |= choice != Choice::Both;
715                }
716                RegOp::ModRegReg(out, lhs, rhs) => {
717                    v[out] = v[lhs].rem_euclid(v[rhs]);
718                }
719                RegOp::ModRegImm(out, arg, imm) => {
720                    v[out] = v[arg].rem_euclid(imm);
721                }
722                RegOp::ModImmReg(out, arg, imm) => {
723                    v[out] = imm.rem_euclid(v[arg]);
724                }
725                RegOp::AddRegReg(out, lhs, rhs) => {
726                    v[out] = v[lhs] + v[rhs];
727                }
728                RegOp::MulRegReg(out, lhs, rhs) => {
729                    v[out] = v[lhs] * v[rhs];
730                }
731                RegOp::DivRegReg(out, lhs, rhs) => {
732                    v[out] = v[lhs] / v[rhs];
733                }
734                RegOp::CompareRegReg(out, lhs, rhs) => {
735                    v[out] = v[lhs]
736                        .partial_cmp(&v[rhs])
737                        .map(|c| c as i8 as f32)
738                        .unwrap_or(f32::NAN)
739                }
740                RegOp::CompareRegImm(out, arg, imm) => {
741                    v[out] = v[arg]
742                        .partial_cmp(&imm)
743                        .map(|c| c as i8 as f32)
744                        .unwrap_or(f32::NAN)
745                }
746                RegOp::CompareImmReg(out, arg, imm) => {
747                    v[out] = imm
748                        .partial_cmp(&v[arg])
749                        .map(|c| c as i8 as f32)
750                        .unwrap_or(f32::NAN)
751                }
752                RegOp::SubRegReg(out, lhs, rhs) => {
753                    v[out] = v[lhs] - v[rhs];
754                }
755                RegOp::MinRegReg(out, lhs, rhs) => {
756                    let a = v[lhs];
757                    let b = v[rhs];
758                    let (choice, value) = if a < b {
759                        (Choice::Left, a)
760                    } else if b < a {
761                        (Choice::Right, b)
762                    } else {
763                        (
764                            Choice::Both,
765                            if a.is_nan() || b.is_nan() {
766                                f32::NAN
767                            } else {
768                                b
769                            },
770                        )
771                    };
772                    v[out] = value;
773                    *choices.next().unwrap() |= choice;
774                    simplify |= choice != Choice::Both;
775                }
776                RegOp::MaxRegReg(out, lhs, rhs) => {
777                    let a = v[lhs];
778                    let b = v[rhs];
779                    let (choice, value) = if a > b {
780                        (Choice::Left, a)
781                    } else if b > a {
782                        (Choice::Right, b)
783                    } else {
784                        (
785                            Choice::Both,
786                            if a.is_nan() || b.is_nan() {
787                                f32::NAN
788                            } else {
789                                b
790                            },
791                        )
792                    };
793                    v[out] = value;
794                    *choices.next().unwrap() |= choice;
795                    simplify |= choice != Choice::Both;
796                }
797                RegOp::AndRegReg(out, lhs, rhs) => {
798                    let a = v[lhs];
799                    let b = v[rhs];
800                    let (choice, value) = if a == 0.0 {
801                        (Choice::Left, a)
802                    } else {
803                        (Choice::Right, b)
804                    };
805                    v[out] = value;
806                    *choices.next().unwrap() |= choice;
807                    simplify |= choice != Choice::Both;
808                }
809                RegOp::OrRegReg(out, lhs, rhs) => {
810                    let a = v[lhs];
811                    let b = v[rhs];
812                    let (choice, value) = if a != 0.0 {
813                        (Choice::Left, a)
814                    } else {
815                        (Choice::Right, b)
816                    };
817                    v[out] = value;
818                    *choices.next().unwrap() |= choice;
819                    simplify |= choice != Choice::Both;
820                }
821                RegOp::CopyImm(out, imm) => {
822                    v[out] = imm;
823                }
824                RegOp::Load(out, mem) => {
825                    v[out] = v[mem];
826                }
827                RegOp::Store(out, mem) => {
828                    v[mem] = v[out];
829                }
830            }
831        }
832        Ok((
833            &self.0.out,
834            if simplify {
835                Some(&self.0.choices)
836            } else {
837                None
838            },
839        ))
840    }
841}
842
843////////////////////////////////////////////////////////////////////////////////
844
/// Bulk evaluator for VM tapes
///
/// Scratch storage for the slice evaluators. It holds one buffer per tape
/// slot and one per output, each sized to the number of points being
/// evaluated.
#[derive(Default)]
struct BulkVmEval<T> {
    /// Per-slot working buffers
    slots: Vec<Vec<T>>,

    /// Per-output result buffers
    out: Vec<Vec<T>>,
}
854
855impl<T: From<f32> + Clone> BulkVmEval<T> {
856    /// Reserves slots for the given tape and slice size
857    fn resize_slots<const N: usize>(&mut self, tape: &VmData<N>, size: usize) {
858        self.slots
859            .resize_with(tape.slot_count(), || vec![f32::NAN.into(); size]);
860        for s in self.slots.iter_mut() {
861            s.resize(size, f32::NAN.into());
862        }
863
864        self.out
865            .resize_with(tape.output_count(), || vec![f32::NAN.into(); size]);
866        for o in self.out.iter_mut() {
867            o.resize(size, f32::NAN.into());
868        }
869    }
870}
871
872/// VM-based bulk evaluator for arrays of points, yielding point values
873#[derive(Default)]
874pub struct VmFloatSliceEval<const N: usize>(BulkVmEval<f32>);
875impl<const N: usize> BulkEvaluator for VmFloatSliceEval<N> {
876    type Data = f32;
877    type Tape = GenericVmTape<N>;
878    type TapeStorage = EmptyTapeStorage;
879
880    #[inline]
881    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
882        &mut self,
883        tape: &Self::Tape,
884        vars: &[V],
885    ) -> Result<BulkOutput<'_, f32>, Error> {
886        tape.vars().check_bulk_arguments(vars)?;
887        let tape = tape.data();
888
889        let size = vars.first().map(|v| v.len()).unwrap_or(0);
890        self.0.resize_slots(tape, size);
891
892        let mut v = SlotArray(&mut self.0.slots);
893        for op in tape.iter_asm() {
894            match op {
895                RegOp::Output(arg, i) => {
896                    self.0.out[i as usize][0..size]
897                        .copy_from_slice(&v[arg][0..size]);
898                }
899                RegOp::Input(out, i) => {
900                    v[out][0..size].copy_from_slice(&vars[i as usize]);
901                }
902                RegOp::NegReg(out, arg) => {
903                    for i in 0..size {
904                        v[out][i] = -v[arg][i];
905                    }
906                }
907                RegOp::AbsReg(out, arg) => {
908                    for i in 0..size {
909                        v[out][i] = v[arg][i].abs();
910                    }
911                }
912                RegOp::RecipReg(out, arg) => {
913                    for i in 0..size {
914                        v[out][i] = 1.0 / v[arg][i];
915                    }
916                }
917                RegOp::SqrtReg(out, arg) => {
918                    for i in 0..size {
919                        v[out][i] = v[arg][i].sqrt();
920                    }
921                }
922                RegOp::SquareReg(out, arg) => {
923                    for i in 0..size {
924                        let s = v[arg][i];
925                        v[out][i] = s * s;
926                    }
927                }
928                RegOp::FloorReg(out, arg) => {
929                    for i in 0..size {
930                        v[out][i] = v[arg][i].floor();
931                    }
932                }
933                RegOp::CeilReg(out, arg) => {
934                    for i in 0..size {
935                        v[out][i] = v[arg][i].ceil();
936                    }
937                }
938                RegOp::RoundReg(out, arg) => {
939                    for i in 0..size {
940                        v[out][i] = v[arg][i].round();
941                    }
942                }
943                RegOp::SinReg(out, arg) => {
944                    for i in 0..size {
945                        v[out][i] = v[arg][i].sin();
946                    }
947                }
948                RegOp::CosReg(out, arg) => {
949                    for i in 0..size {
950                        v[out][i] = v[arg][i].cos();
951                    }
952                }
953                RegOp::TanReg(out, arg) => {
954                    for i in 0..size {
955                        v[out][i] = v[arg][i].tan();
956                    }
957                }
958                RegOp::AsinReg(out, arg) => {
959                    for i in 0..size {
960                        v[out][i] = v[arg][i].asin();
961                    }
962                }
963                RegOp::AcosReg(out, arg) => {
964                    for i in 0..size {
965                        v[out][i] = v[arg][i].acos();
966                    }
967                }
968                RegOp::AtanReg(out, arg) => {
969                    for i in 0..size {
970                        v[out][i] = v[arg][i].atan();
971                    }
972                }
973                RegOp::ExpReg(out, arg) => {
974                    for i in 0..size {
975                        v[out][i] = v[arg][i].exp();
976                    }
977                }
978                RegOp::LnReg(out, arg) => {
979                    for i in 0..size {
980                        v[out][i] = v[arg][i].ln();
981                    }
982                }
983                RegOp::NotReg(out, arg) => {
984                    for i in 0..size {
985                        v[out][i] = (v[arg][i] == 0.0).into();
986                    }
987                }
988                RegOp::CopyReg(out, arg) => {
989                    for i in 0..size {
990                        v[out][i] = v[arg][i];
991                    }
992                }
993                RegOp::AddRegImm(out, arg, imm) => {
994                    for i in 0..size {
995                        v[out][i] = v[arg][i] + imm;
996                    }
997                }
998                RegOp::MulRegImm(out, arg, imm) => {
999                    for i in 0..size {
1000                        v[out][i] = v[arg][i] * imm;
1001                    }
1002                }
1003                RegOp::DivRegImm(out, arg, imm) => {
1004                    for i in 0..size {
1005                        v[out][i] = v[arg][i] / imm;
1006                    }
1007                }
1008                RegOp::DivImmReg(out, arg, imm) => {
1009                    for i in 0..size {
1010                        v[out][i] = imm / v[arg][i];
1011                    }
1012                }
1013                RegOp::AtanRegImm(out, arg, imm) => {
1014                    for i in 0..size {
1015                        v[out][i] = v[arg][i].atan2(imm);
1016                    }
1017                }
1018                RegOp::AtanImmReg(out, arg, imm) => {
1019                    for i in 0..size {
1020                        v[out][i] = imm.atan2(v[arg][i]);
1021                    }
1022                }
1023                RegOp::AtanRegReg(out, lhs, rhs) => {
1024                    for i in 0..size {
1025                        v[out][i] = v[lhs][i].atan2(v[rhs][i]);
1026                    }
1027                }
1028                RegOp::SubImmReg(out, arg, imm) => {
1029                    for i in 0..size {
1030                        v[out][i] = imm - v[arg][i];
1031                    }
1032                }
1033                RegOp::SubRegImm(out, arg, imm) => {
1034                    for i in 0..size {
1035                        v[out][i] = v[arg][i] - imm;
1036                    }
1037                }
1038                RegOp::CompareImmReg(out, arg, imm) => {
1039                    for i in 0..size {
1040                        v[out][i] = imm
1041                            .partial_cmp(&v[arg][i])
1042                            .map(|c| c as i8 as f32)
1043                            .unwrap_or(f32::NAN)
1044                    }
1045                }
1046                RegOp::CompareRegImm(out, arg, imm) => {
1047                    for i in 0..size {
1048                        v[out][i] = v[arg][i]
1049                            .partial_cmp(&imm)
1050                            .map(|c| c as i8 as f32)
1051                            .unwrap_or(f32::NAN)
1052                    }
1053                }
1054                RegOp::MinRegImm(out, arg, imm) => {
1055                    for i in 0..size {
1056                        v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
1057                            f32::NAN
1058                        } else {
1059                            v[arg][i].min(imm)
1060                        };
1061                    }
1062                }
1063                RegOp::MaxRegImm(out, arg, imm) => {
1064                    for i in 0..size {
1065                        v[out][i] = if v[arg][i].is_nan() || imm.is_nan() {
1066                            f32::NAN
1067                        } else {
1068                            v[arg][i].max(imm)
1069                        };
1070                    }
1071                }
1072                RegOp::AndRegImm(out, arg, imm) => {
1073                    for i in 0..size {
1074                        v[out][i] =
1075                            if v[arg][i] == 0.0 { v[arg][i] } else { imm };
1076                    }
1077                }
1078                RegOp::OrRegImm(out, arg, imm) => {
1079                    for i in 0..size {
1080                        v[out][i] =
1081                            if v[arg][i] != 0.0 { v[arg][i] } else { imm };
1082                    }
1083                }
1084                RegOp::ModRegReg(out, lhs, rhs) => {
1085                    for i in 0..size {
1086                        v[out][i] = v[lhs][i].rem_euclid(v[rhs][i]);
1087                    }
1088                }
1089                RegOp::ModRegImm(out, arg, imm) => {
1090                    for i in 0..size {
1091                        v[out][i] = v[arg][i].rem_euclid(imm);
1092                    }
1093                }
1094                RegOp::ModImmReg(out, arg, imm) => {
1095                    for i in 0..size {
1096                        v[out][i] = imm.rem_euclid(v[arg][i]);
1097                    }
1098                }
1099                RegOp::AddRegReg(out, lhs, rhs) => {
1100                    for i in 0..size {
1101                        v[out][i] = v[lhs][i] + v[rhs][i];
1102                    }
1103                }
1104                RegOp::MulRegReg(out, lhs, rhs) => {
1105                    for i in 0..size {
1106                        v[out][i] = v[lhs][i] * v[rhs][i];
1107                    }
1108                }
1109                RegOp::DivRegReg(out, lhs, rhs) => {
1110                    for i in 0..size {
1111                        v[out][i] = v[lhs][i] / v[rhs][i];
1112                    }
1113                }
1114                RegOp::SubRegReg(out, lhs, rhs) => {
1115                    for i in 0..size {
1116                        v[out][i] = v[lhs][i] - v[rhs][i];
1117                    }
1118                }
1119                RegOp::CompareRegReg(out, lhs, rhs) => {
1120                    for i in 0..size {
1121                        v[out][i] = v[lhs][i]
1122                            .partial_cmp(&v[rhs][i])
1123                            .map(|c| c as i8 as f32)
1124                            .unwrap_or(f32::NAN)
1125                    }
1126                }
1127                RegOp::MinRegReg(out, lhs, rhs) => {
1128                    for i in 0..size {
1129                        v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
1130                        {
1131                            f32::NAN
1132                        } else {
1133                            v[lhs][i].min(v[rhs][i])
1134                        };
1135                    }
1136                }
1137                RegOp::MaxRegReg(out, lhs, rhs) => {
1138                    for i in 0..size {
1139                        v[out][i] = if v[lhs][i].is_nan() || v[rhs][i].is_nan()
1140                        {
1141                            f32::NAN
1142                        } else {
1143                            v[lhs][i].max(v[rhs][i])
1144                        };
1145                    }
1146                }
1147                RegOp::AndRegReg(out, lhs, rhs) => {
1148                    for i in 0..size {
1149                        v[out][i] = if v[lhs][i] == 0.0 {
1150                            v[lhs][i]
1151                        } else {
1152                            v[rhs][i]
1153                        };
1154                    }
1155                }
1156                RegOp::OrRegReg(out, lhs, rhs) => {
1157                    for i in 0..size {
1158                        v[out][i] = if v[lhs][i] != 0.0 {
1159                            v[lhs][i]
1160                        } else {
1161                            v[rhs][i]
1162                        };
1163                    }
1164                }
1165                RegOp::CopyImm(out, imm) => {
1166                    for i in 0..size {
1167                        v[out][i] = imm;
1168                    }
1169                }
1170                RegOp::Load(out, mem) => {
1171                    for i in 0..size {
1172                        v[out][i] = v[mem][i];
1173                    }
1174                }
1175                RegOp::Store(out, mem) => {
1176                    for i in 0..size {
1177                        v[mem][i] = v[out][i];
1178                    }
1179                }
1180            }
1181        }
1182        Ok(BulkOutput::new(&self.0.out, size))
1183    }
1184}
1185
1186/// VM-based bulk evaluator for arrays of points, yielding gradient values
1187#[derive(Default)]
1188pub struct VmGradSliceEval<const N: usize>(BulkVmEval<Grad>);
1189impl<const N: usize> BulkEvaluator for VmGradSliceEval<N> {
1190    type Data = Grad;
1191    type Tape = GenericVmTape<N>;
1192    type TapeStorage = EmptyTapeStorage;
1193
1194    #[inline]
1195    fn eval<V: std::ops::Deref<Target = [Self::Data]>>(
1196        &mut self,
1197        tape: &Self::Tape,
1198        vars: &[V],
1199    ) -> Result<BulkOutput<'_, Grad>, Error> {
1200        tape.vars().check_bulk_arguments(vars)?;
1201        let tape = tape.data();
1202        let size = vars.first().map(|v| v.len()).unwrap_or(0);
1203        self.0.resize_slots(tape, size);
1204
1205        let mut v = SlotArray(&mut self.0.slots);
1206        for op in tape.iter_asm() {
1207            match op {
1208                RegOp::Output(arg, i) => {
1209                    self.0.out[i as usize][0..size]
1210                        .copy_from_slice(&v[arg][0..size]);
1211                }
1212                RegOp::Input(out, i) => {
1213                    v[out][0..size].copy_from_slice(&vars[i as usize]);
1214                }
1215                RegOp::NegReg(out, arg) => {
1216                    for i in 0..size {
1217                        v[out][i] = -v[arg][i];
1218                    }
1219                }
1220                RegOp::AbsReg(out, arg) => {
1221                    for i in 0..size {
1222                        v[out][i] = v[arg][i].abs();
1223                    }
1224                }
1225                RegOp::RecipReg(out, arg) => {
1226                    let one: Grad = 1.0.into();
1227                    for i in 0..size {
1228                        v[out][i] = one / v[arg][i];
1229                    }
1230                }
1231                RegOp::SqrtReg(out, arg) => {
1232                    for i in 0..size {
1233                        v[out][i] = v[arg][i].sqrt();
1234                    }
1235                }
1236                RegOp::SquareReg(out, arg) => {
1237                    for i in 0..size {
1238                        let s = v[arg][i];
1239                        v[out][i] = s * s;
1240                    }
1241                }
1242                RegOp::FloorReg(out, arg) => {
1243                    for i in 0..size {
1244                        v[out][i] = v[arg][i].floor();
1245                    }
1246                }
1247                RegOp::CeilReg(out, arg) => {
1248                    for i in 0..size {
1249                        v[out][i] = v[arg][i].ceil();
1250                    }
1251                }
1252                RegOp::RoundReg(out, arg) => {
1253                    for i in 0..size {
1254                        v[out][i] = v[arg][i].round();
1255                    }
1256                }
1257                RegOp::SinReg(out, arg) => {
1258                    for i in 0..size {
1259                        v[out][i] = v[arg][i].sin();
1260                    }
1261                }
1262                RegOp::CosReg(out, arg) => {
1263                    for i in 0..size {
1264                        v[out][i] = v[arg][i].cos();
1265                    }
1266                }
1267                RegOp::TanReg(out, arg) => {
1268                    for i in 0..size {
1269                        v[out][i] = v[arg][i].tan();
1270                    }
1271                }
1272                RegOp::AsinReg(out, arg) => {
1273                    for i in 0..size {
1274                        v[out][i] = v[arg][i].asin();
1275                    }
1276                }
1277                RegOp::AcosReg(out, arg) => {
1278                    for i in 0..size {
1279                        v[out][i] = v[arg][i].acos();
1280                    }
1281                }
1282                RegOp::AtanReg(out, arg) => {
1283                    for i in 0..size {
1284                        v[out][i] = v[arg][i].atan();
1285                    }
1286                }
1287                RegOp::ExpReg(out, arg) => {
1288                    for i in 0..size {
1289                        v[out][i] = v[arg][i].exp();
1290                    }
1291                }
1292                RegOp::LnReg(out, arg) => {
1293                    for i in 0..size {
1294                        v[out][i] = v[arg][i].ln();
1295                    }
1296                }
1297                RegOp::NotReg(out, arg) => {
1298                    for i in 0..size {
1299                        v[out][i] = f32::from(v[arg][i].v == 0.0).into();
1300                    }
1301                }
1302                RegOp::CopyReg(out, arg) => {
1303                    for i in 0..size {
1304                        v[out][i] = v[arg][i];
1305                    }
1306                }
1307                RegOp::AddRegImm(out, arg, imm) => {
1308                    for i in 0..size {
1309                        v[out][i] = v[arg][i] + imm.into();
1310                    }
1311                }
1312                RegOp::MulRegImm(out, arg, imm) => {
1313                    for i in 0..size {
1314                        v[out][i] = v[arg][i] * imm;
1315                    }
1316                }
1317                RegOp::DivRegImm(out, arg, imm) => {
1318                    for i in 0..size {
1319                        v[out][i] = v[arg][i] / imm.into();
1320                    }
1321                }
1322                RegOp::DivImmReg(out, arg, imm) => {
1323                    let imm = Grad::from(imm);
1324                    for i in 0..size {
1325                        v[out][i] = imm / v[arg][i];
1326                    }
1327                }
1328                RegOp::AtanRegImm(out, arg, imm) => {
1329                    let imm = Grad::from(imm);
1330                    for i in 0..size {
1331                        v[out][i] = v[arg][i].atan2(imm);
1332                    }
1333                }
1334                RegOp::AtanImmReg(out, arg, imm) => {
1335                    let imm = Grad::from(imm);
1336                    for i in 0..size {
1337                        v[out][i] = imm.atan2(v[arg][i]);
1338                    }
1339                }
1340                RegOp::AtanRegReg(out, lhs, rhs) => {
1341                    for i in 0..size {
1342                        v[out][i] = v[lhs][i].atan2(v[rhs][i]);
1343                    }
1344                }
1345                RegOp::SubImmReg(out, arg, imm) => {
1346                    let imm: Grad = imm.into();
1347                    for i in 0..size {
1348                        v[out][i] = imm - v[arg][i];
1349                    }
1350                }
1351                RegOp::SubRegImm(out, arg, imm) => {
1352                    let imm: Grad = imm.into();
1353                    for i in 0..size {
1354                        v[out][i] = v[arg][i] - imm;
1355                    }
1356                }
1357                RegOp::CompareImmReg(out, arg, imm) => {
1358                    for i in 0..size {
1359                        let p = imm
1360                            .partial_cmp(&v[arg][i].v)
1361                            .map(|c| c as i8 as f32)
1362                            .unwrap_or(f32::NAN);
1363                        v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
1364                    }
1365                }
1366                RegOp::CompareRegImm(out, arg, imm) => {
1367                    for i in 0..size {
1368                        let p = v[arg][i]
1369                            .v
1370                            .partial_cmp(&imm)
1371                            .map(|c| c as i8 as f32)
1372                            .unwrap_or(f32::NAN);
1373                        v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
1374                    }
1375                }
1376                RegOp::MinRegImm(out, arg, imm) => {
1377                    let imm: Grad = imm.into();
1378                    for i in 0..size {
1379                        v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() {
1380                            f32::NAN.into()
1381                        } else {
1382                            v[arg][i].min(imm)
1383                        };
1384                    }
1385                }
1386                RegOp::MaxRegImm(out, arg, imm) => {
1387                    let imm: Grad = imm.into();
1388                    for i in 0..size {
1389                        v[out][i] = if v[arg][i].v.is_nan() || imm.v.is_nan() {
1390                            f32::NAN.into()
1391                        } else {
1392                            v[arg][i].max(imm)
1393                        };
1394                    }
1395                }
1396                RegOp::ModRegReg(out, lhs, rhs) => {
1397                    for i in 0..size {
1398                        v[out][i] = v[lhs][i].rem_euclid(v[rhs][i]);
1399                    }
1400                }
1401                RegOp::ModRegImm(out, arg, imm) => {
1402                    for i in 0..size {
1403                        v[out][i] = v[arg][i].rem_euclid(imm.into());
1404                    }
1405                }
1406                RegOp::ModImmReg(out, arg, imm) => {
1407                    for i in 0..size {
1408                        v[out][i] = Grad::from(imm).rem_euclid(v[arg][i]);
1409                    }
1410                }
1411                RegOp::AddRegReg(out, lhs, rhs) => {
1412                    for i in 0..size {
1413                        v[out][i] = v[lhs][i] + v[rhs][i];
1414                    }
1415                }
1416                RegOp::MulRegReg(out, lhs, rhs) => {
1417                    for i in 0..size {
1418                        v[out][i] = v[lhs][i] * v[rhs][i];
1419                    }
1420                }
1421                RegOp::AndRegReg(out, lhs, rhs) => {
1422                    for i in 0..size {
1423                        v[out][i] = if v[lhs][i].v == 0.0 {
1424                            v[lhs][i]
1425                        } else {
1426                            v[rhs][i]
1427                        };
1428                    }
1429                }
1430                RegOp::AndRegImm(out, arg, imm) => {
1431                    for i in 0..size {
1432                        v[out][i] = if v[arg][i].v == 0.0 {
1433                            v[arg][i]
1434                        } else {
1435                            imm.into()
1436                        };
1437                    }
1438                }
1439                RegOp::OrRegReg(out, lhs, rhs) => {
1440                    for i in 0..size {
1441                        v[out][i] = if v[lhs][i].v != 0.0 {
1442                            v[lhs][i]
1443                        } else {
1444                            v[rhs][i]
1445                        };
1446                    }
1447                }
1448                RegOp::OrRegImm(out, arg, imm) => {
1449                    for i in 0..size {
1450                        v[out][i] = if v[arg][i].v != 0.0 {
1451                            v[arg][i]
1452                        } else {
1453                            imm.into()
1454                        };
1455                    }
1456                }
1457                RegOp::DivRegReg(out, lhs, rhs) => {
1458                    for i in 0..size {
1459                        v[out][i] = v[lhs][i] / v[rhs][i];
1460                    }
1461                }
1462                RegOp::SubRegReg(out, lhs, rhs) => {
1463                    for i in 0..size {
1464                        v[out][i] = v[lhs][i] - v[rhs][i];
1465                    }
1466                }
1467                RegOp::CompareRegReg(out, lhs, rhs) => {
1468                    for i in 0..size {
1469                        let p = v[lhs][i]
1470                            .v
1471                            .partial_cmp(&v[rhs][i].v)
1472                            .map(|c| c as i8 as f32)
1473                            .unwrap_or(f32::NAN);
1474                        v[out][i] = Grad::new(p, 0.0, 0.0, 0.0);
1475                    }
1476                }
1477                RegOp::MinRegReg(out, lhs, rhs) => {
1478                    for i in 0..size {
1479                        v[out][i] =
1480                            if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() {
1481                                f32::NAN.into()
1482                            } else {
1483                                v[lhs][i].min(v[rhs][i])
1484                            };
1485                    }
1486                }
1487                RegOp::MaxRegReg(out, lhs, rhs) => {
1488                    for i in 0..size {
1489                        v[out][i] =
1490                            if v[lhs][i].v.is_nan() || v[rhs][i].v.is_nan() {
1491                                f32::NAN.into()
1492                            } else {
1493                                v[lhs][i].max(v[rhs][i])
1494                            };
1495                    }
1496                }
1497                RegOp::CopyImm(out, imm) => {
1498                    let imm: Grad = imm.into();
1499                    for i in 0..size {
1500                        v[out][i] = imm;
1501                    }
1502                }
1503                RegOp::Load(out, mem) => {
1504                    for i in 0..size {
1505                        v[out][i] = v[mem][i];
1506                    }
1507                }
1508                RegOp::Store(out, mem) => {
1509                    for i in 0..size {
1510                        v[mem][i] = v[out][i];
1511                    }
1512                }
1513            }
1514        }
1515        Ok(BulkOutput::new(&self.0.out, size))
1516    }
1517}
1518
#[cfg(test)]
mod test {
    use super::*;

    // Instantiate the crate's evaluator test macros for the VM backend,
    // one per evaluator flavor (point, interval, float slice, grad slice).
    crate::point_tests!(VmFunction);
    crate::interval_tests!(VmFunction);
    crate::float_slice_tests!(VmFunction);
    crate::grad_slice_tests!(VmFunction);
}