fidget_core/shape/
mod.rs

1//! Data structures for shape evaluation
2//!
3//! Types in this module are typically thin (generic) wrappers around objects
4//! that implement traits in [`fidget_core::eval`](crate::eval).  The wrapper types
5//! are specialized to operate on `x, y, z` arguments, rather than taking
6//! arbitrary numbers of variables.
7//!
8//! For example, a [`Shape`] is a wrapper which makes it easier to treat a
9//! [`Function`] as an implicit surface (with X, Y, Z axes and an optional
10//! transform matrix).
11//!
12//! ```rust
13//! use fidget_core::vm::VmShape;
14//! use fidget_core::context::Context;
15//! use fidget_core::shape::EzShape;
16//!
17//! let mut ctx = Context::new();
18//! let x = ctx.x();
19//! let shape = VmShape::new(&ctx, x)?;
20//!
21//! // Let's build a single point evaluator:
22//! let mut eval = VmShape::new_point_eval();
23//! let tape = shape.ez_point_tape();
24//! let (value, _trace) = eval.eval(&tape, 0.25, 0.0, 0.0)?;
25//! assert_eq!(value, 0.25);
26//! # Ok::<(), fidget_core::Error>(())
27//! ```
28
29use crate::{
30    Error,
31    context::{Context, Node, Tree},
32    eval::{BulkEvaluator, Function, MathFunction, Tape, TracingEvaluator},
33    types::{Grad, Interval},
34    var::{Var, VarIndex, VarMap},
35};
36use nalgebra::{Matrix4, Point3};
37use std::collections::HashMap;
38
39/// A shape represents an implicit surface
40///
41/// It is mostly agnostic to _how_ that surface is represented, wrapping a
42/// [`Function`] and a set of axes.
43///
44/// Shapes are shared between threads, so they should be cheap to clone.  In
45/// most cases, they're a thin wrapper around an `Arc<..>`.
46///
47/// At construction, a shape has no associated transformation.  A transformation
48/// matrix can be applied by calling [`Shape::with_transform`].
49///
50/// The shape's transformation matrix is propagated into tapes (constructed by
51/// `*_tape` functions), which use the matrix to transform incoming coordinates
52/// during evaluation.
53///
54/// Note that `with_transform` returns a `Shape` with [`Transformed`] as the
55/// second template parameter; to preserve immutability, the marker prevents
56/// further mutation of the transform.
57pub struct Shape<F, T = ()> {
58    /// Wrapped function
59    f: F,
60
61    /// Variables representing x, y, z axes
62    axes: [Var; 3],
63
64    /// Optional transform to apply to the shape
65    ///
66    /// This may only be `Some(..)` if `T` is `Transformed` (enforced at
67    /// compilation time)
68    transform: Option<Matrix4<f32>>,
69
70    _marker: std::marker::PhantomData<T>,
71}
72
73impl<F: Clone, T> Clone for Shape<F, T> {
74    fn clone(&self) -> Self {
75        Self {
76            f: self.f.clone(),
77            axes: self.axes,
78            transform: self.transform,
79            _marker: std::marker::PhantomData,
80        }
81    }
82}
83
84impl<F: Function + Clone, T> Shape<F, T> {
85    /// Builds a new point evaluator
86    pub fn new_point_eval() -> ShapeTracingEval<F::PointEval> {
87        ShapeTracingEval {
88            eval: F::PointEval::default(),
89            scratch: vec![],
90        }
91    }
92
93    /// Builds a new interval evaluator
94    pub fn new_interval_eval() -> ShapeTracingEval<F::IntervalEval> {
95        ShapeTracingEval {
96            eval: F::IntervalEval::default(),
97            scratch: vec![],
98        }
99    }
100
101    /// Builds a new float slice evaluator
102    pub fn new_float_slice_eval() -> ShapeBulkEval<F::FloatSliceEval> {
103        ShapeBulkEval {
104            eval: F::FloatSliceEval::default(),
105            scratch: vec![],
106        }
107    }
108
109    /// Builds a new gradient slice evaluator
110    pub fn new_grad_slice_eval() -> ShapeBulkEval<F::GradSliceEval> {
111        ShapeBulkEval {
112            eval: F::GradSliceEval::default(),
113            scratch: vec![],
114        }
115    }
116
117    /// Returns an evaluation tape for a point evaluator
118    #[inline]
119    pub fn point_tape(
120        &self,
121        storage: F::TapeStorage,
122    ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
123        let tape = self.f.point_tape(storage);
124        let vars = tape.vars();
125        let axes = self.axes.map(|v| vars.get(&v));
126        ShapeTape {
127            tape,
128            axes,
129            transform: self.transform,
130        }
131    }
132
133    /// Returns an evaluation tape for a interval evaluator
134    #[inline]
135    pub fn interval_tape(
136        &self,
137        storage: F::TapeStorage,
138    ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
139        let tape = self.f.interval_tape(storage);
140        let vars = tape.vars();
141        let axes = self.axes.map(|v| vars.get(&v));
142        ShapeTape {
143            tape,
144            axes,
145            transform: self.transform,
146        }
147    }
148
149    /// Returns an evaluation tape for a float slice evaluator
150    #[inline]
151    pub fn float_slice_tape(
152        &self,
153        storage: F::TapeStorage,
154    ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
155        let tape = self.f.float_slice_tape(storage);
156        let vars = tape.vars();
157        let axes = self.axes.map(|v| vars.get(&v));
158        ShapeTape {
159            tape,
160            axes,
161            transform: self.transform,
162        }
163    }
164
165    /// Returns an evaluation tape for a gradient slice evaluator
166    #[inline]
167    pub fn grad_slice_tape(
168        &self,
169        storage: F::TapeStorage,
170    ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
171        let tape = self.f.grad_slice_tape(storage);
172        let vars = tape.vars();
173        let axes = self.axes.map(|v| vars.get(&v));
174        ShapeTape {
175            tape,
176            axes,
177            transform: self.transform,
178        }
179    }
180
181    /// Computes a simplified tape using the given trace, and reusing storage
182    #[inline]
183    pub fn simplify(
184        &self,
185        trace: &F::Trace,
186        storage: F::Storage,
187        workspace: &mut F::Workspace,
188    ) -> Result<Self, Error>
189    where
190        Self: Sized,
191    {
192        let f = self.f.simplify(trace, storage, workspace)?;
193        Ok(Self {
194            f,
195            axes: self.axes,
196            transform: self.transform,
197            _marker: std::marker::PhantomData,
198        })
199    }
200
201    /// Attempt to reclaim storage from this shape
202    ///
203    /// This may fail, because shapes are `Clone` and are often implemented
204    /// using an `Arc` around a heavier data structure.
205    #[inline]
206    pub fn recycle(self) -> Option<F::Storage> {
207        self.f.recycle()
208    }
209
210    /// Returns a size associated with this shape
211    ///
212    /// This is underspecified and only used for unit testing; for tape-based
213    /// shapes, it's typically the length of the tape,
214    #[inline]
215    pub fn size(&self) -> usize {
216        self.f.size()
217    }
218}
219
220impl<F, T> Shape<F, T> {
221    /// Borrows the inner [`Function`] object
222    pub fn inner(&self) -> &F {
223        &self.f
224    }
225
226    /// Borrows the inner axis mapping
227    pub fn axes(&self) -> &[Var; 3] {
228        &self.axes
229    }
230
231    /// Raw constructor
232    pub fn new_raw(f: F, axes: [Var; 3]) -> Self {
233        Self {
234            f,
235            axes,
236            transform: None,
237            _marker: std::marker::PhantomData,
238        }
239    }
240}
241
/// Marker type indicating that a shape has a transform applied
///
/// Used as the second type parameter of [`Shape`]; it is produced by
/// [`Shape::with_transform`].
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Transformed;
244
245impl<F: Clone> Shape<F, ()> {
246    /// Returns a shape with the given transform applied
247    pub fn with_transform(&self, mat: Matrix4<f32>) -> Shape<F, Transformed> {
248        Shape {
249            f: self.f.clone(),
250            axes: self.axes,
251            transform: Some(mat),
252            _marker: std::marker::PhantomData,
253        }
254    }
255}
256
257/// Variables bound to values for shape evaluation
258///
259/// Note that this cannot store `X`, `Y`, `Z` variables (which are passed in as
260/// first-class arguments); it only stores [`Var::V`] values (identified by
261/// their inner [`VarIndex`]).
262pub struct ShapeVars<F>(HashMap<VarIndex, F>);
263
264impl<F> Default for ShapeVars<F> {
265    fn default() -> Self {
266        Self(HashMap::default())
267    }
268}
269
270impl<F> ShapeVars<F> {
271    /// Builds a new, empty variable set
272    pub fn new() -> Self {
273        Self(HashMap::default())
274    }
275    /// Returns the number of variables stored in the set
276    pub fn len(&self) -> usize {
277        self.0.len()
278    }
279    /// Checks whether the variable set is empty
280    pub fn is_empty(&self) -> bool {
281        self.0.is_empty()
282    }
283    /// Inserts a new variable
284    ///
285    /// Returns the previous value (if present)
286    pub fn insert(&mut self, v: VarIndex, f: F) -> Option<F> {
287        self.0.insert(v, f)
288    }
289
290    /// Iterates over values
291    pub fn values(&self) -> impl Iterator<Item = &F> {
292        self.0.values()
293    }
294}
295
296impl<'a, F> IntoIterator for &'a ShapeVars<F> {
297    type Item = (&'a VarIndex, &'a F);
298    type IntoIter = std::collections::hash_map::Iter<'a, VarIndex, F>;
299    fn into_iter(self) -> Self::IntoIter {
300        self.0.iter()
301    }
302}
303
304/// Extension trait for working with a shape without thinking much about memory
305///
306/// All of the [`Shape`] functions that use significant amounts of memory
307/// pedantically require you to pass in storage for reuse.  This trait allows
308/// you to ignore that, at the cost of performance; we require that all storage
309/// types implement [`Default`], so these functions do the boilerplate for you.
310///
311/// This trait is automatically implemented for every [`Shape`], but must be
312/// imported separately as a speed-bump to using it everywhere.
313pub trait EzShape<F: Function> {
314    /// Returns an evaluation tape for a point evaluator
315    fn ez_point_tape(
316        &self,
317    ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape>;
318
319    /// Returns an evaluation tape for an interval evaluator
320    fn ez_interval_tape(
321        &self,
322    ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape>;
323
324    /// Returns an evaluation tape for a float slice evaluator
325    fn ez_float_slice_tape(
326        &self,
327    ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape>;
328
329    /// Returns an evaluation tape for a float slice evaluator
330    fn ez_grad_slice_tape(
331        &self,
332    ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape>;
333
334    /// Computes a simplified tape using the given trace
335    fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error>
336    where
337        Self: Sized;
338}
339
340impl<F: Function, T> EzShape<F> for Shape<F, T> {
341    fn ez_point_tape(
342        &self,
343    ) -> ShapeTape<<F::PointEval as TracingEvaluator>::Tape> {
344        self.point_tape(Default::default())
345    }
346
347    fn ez_interval_tape(
348        &self,
349    ) -> ShapeTape<<F::IntervalEval as TracingEvaluator>::Tape> {
350        self.interval_tape(Default::default())
351    }
352
353    fn ez_float_slice_tape(
354        &self,
355    ) -> ShapeTape<<F::FloatSliceEval as BulkEvaluator>::Tape> {
356        self.float_slice_tape(Default::default())
357    }
358
359    fn ez_grad_slice_tape(
360        &self,
361    ) -> ShapeTape<<F::GradSliceEval as BulkEvaluator>::Tape> {
362        self.grad_slice_tape(Default::default())
363    }
364
365    fn ez_simplify(&self, trace: &F::Trace) -> Result<Self, Error> {
366        let mut workspace = Default::default();
367        self.simplify(trace, Default::default(), &mut workspace)
368    }
369}
370
371impl<F: MathFunction> Shape<F> {
372    /// Builds a new shape from a math expression with the given axes
373    pub fn new_with_axes(
374        ctx: &Context,
375        node: Node,
376        axes: [Var; 3],
377    ) -> Result<Self, Error> {
378        let f = F::new(ctx, &[node])?;
379        Ok(Self {
380            f,
381            axes,
382            transform: None,
383            _marker: std::marker::PhantomData,
384        })
385    }
386
387    /// Builds a new shape from the given node with default (X, Y, Z) axes
388    pub fn new(ctx: &Context, node: Node) -> Result<Self, Error>
389    where
390        Self: Sized,
391    {
392        Self::new_with_axes(ctx, node, [Var::X, Var::Y, Var::Z])
393    }
394}
395
396/// Converts a [`Tree`] to a [`Shape`] with the default axes
397impl<F: MathFunction> From<Tree> for Shape<F> {
398    fn from(t: Tree) -> Self {
399        let mut ctx = Context::new();
400        let node = ctx.import(&t);
401        Self::new(&ctx, node).unwrap()
402    }
403}
404
405/// Wrapper around a function tape, with axes and an optional transform matrix
406#[derive(Clone)]
407pub struct ShapeTape<T> {
408    tape: T,
409
410    /// Index of the X, Y, Z axes in the variables array
411    axes: [Option<usize>; 3],
412
413    /// Optional transform
414    transform: Option<Matrix4<f32>>,
415}
416
417impl<T: Tape> ShapeTape<T> {
418    /// Recycles the inner tape's storage for reuse
419    pub fn recycle(self) -> Option<T::Storage> {
420        self.tape.recycle()
421    }
422
423    /// Returns a mapping from [`Var`] to evaluation index
424    pub fn vars(&self) -> &VarMap {
425        self.tape.vars()
426    }
427}
428
429/// Wrapper around a [`TracingEvaluator`]
430///
431/// Unlike the raw tracing evaluator, a [`ShapeTracingEval`] knows about the
432/// tape's X, Y, Z axes and optional transform matrix.
433#[derive(Debug)]
434pub struct ShapeTracingEval<E: TracingEvaluator> {
435    eval: E,
436    scratch: Vec<E::Data>,
437}
438
439impl<E: TracingEvaluator> Default for ShapeTracingEval<E> {
440    fn default() -> Self {
441        Self {
442            eval: E::default(),
443            scratch: vec![],
444        }
445    }
446}
447
448impl<E: TracingEvaluator> ShapeTracingEval<E>
449where
450    <E as TracingEvaluator>::Data: Transformable,
451{
452    /// Tracing evaluation of the given tape with X, Y, Z input arguments
453    ///
454    /// Before evaluation, the tape's transform matrix is applied (if present).
455    ///
456    /// If the tape has other variables, [`eval_v`](Self::eval_v) should be
457    /// called instead (and this function will return an error.
458    #[inline]
459    pub fn eval<F: Into<E::Data> + Copy>(
460        &mut self,
461        tape: &ShapeTape<E::Tape>,
462        x: F,
463        y: F,
464        z: F,
465    ) -> Result<(E::Data, Option<&E::Trace>), Error> {
466        let h = ShapeVars::<f32>::new();
467        self.eval_v(tape, x, y, z, &h)
468    }
469
470    /// Tracing evaluation of a single sample
471    ///
472    /// Before evaluation, the tape's transform matrix is applied (if present).
473    #[inline]
474    pub fn eval_v<F: Into<E::Data> + Copy, V: Into<E::Data> + Copy>(
475        &mut self,
476        tape: &ShapeTape<E::Tape>,
477        x: F,
478        y: F,
479        z: F,
480        vars: &ShapeVars<V>,
481    ) -> Result<(E::Data, Option<&E::Trace>), Error> {
482        assert_eq!(
483            tape.tape.output_count(),
484            1,
485            "ShapeTape has multiple outputs"
486        );
487
488        let x = x.into();
489        let y = y.into();
490        let z = z.into();
491        let (x, y, z) = if let Some(mat) = tape.transform {
492            Transformable::transform(x, y, z, mat)
493        } else {
494            (x, y, z)
495        };
496
497        let vs = tape.vars();
498        let expected_vars = vs.len()
499            - vs.get(&Var::X).is_some() as usize
500            - vs.get(&Var::Y).is_some() as usize
501            - vs.get(&Var::Z).is_some() as usize;
502        if expected_vars != vars.len() {
503            return Err(Error::BadVarSlice(vars.len(), expected_vars));
504        }
505
506        self.scratch.resize(tape.vars().len(), 0f32.into());
507        if let Some(a) = tape.axes[0] {
508            self.scratch[a] = x;
509        }
510        if let Some(b) = tape.axes[1] {
511            self.scratch[b] = y;
512        }
513        if let Some(c) = tape.axes[2] {
514            self.scratch[c] = z;
515        }
516        for (var, value) in vars {
517            if let Some(i) = vs.get(&Var::V(*var)) {
518                if i < self.scratch.len() {
519                    self.scratch[i] = (*value).into();
520                } else {
521                    return Err(Error::BadVarIndex(i, self.scratch.len()));
522                }
523            } else {
524                // Passing in Bonus Variables is allowed (for now)
525            }
526        }
527
528        let (out, trace) = self.eval.eval(&tape.tape, &self.scratch)?;
529        Ok((out[0], trace))
530    }
531}
532
533/// Wrapper around a [`BulkEvaluator`]
534///
535/// Unlike the raw bulk evaluator, a [`ShapeBulkEval`] knows about the
536/// tape's X, Y, Z axes and optional transform matrix.
537#[derive(Debug, Default)]
538pub struct ShapeBulkEval<E: BulkEvaluator> {
539    eval: E,
540    scratch: Vec<Vec<E::Data>>,
541}
542
543impl<E: BulkEvaluator> ShapeBulkEval<E>
544where
545    E::Data: From<f32> + Transformable,
546{
547    /// Bulk evaluation of many samples, without any variables
548    ///
549    /// If the shape includes variables other than `X`, `Y`, `Z`,
550    /// [`eval_v`](Self::eval_v) or [`eval_vs`](Self::eval_vs) should be used
551    /// instead (and this function will return an error).
552    ///
553    /// Before evaluation, the tape's transform matrix is applied (if present).
554    #[inline]
555    pub fn eval(
556        &mut self,
557        tape: &ShapeTape<E::Tape>,
558        x: &[E::Data],
559        y: &[E::Data],
560        z: &[E::Data],
561    ) -> Result<&[E::Data], Error> {
562        let h: ShapeVars<&[E::Data]> = ShapeVars::new();
563        self.eval_vs(tape, x, y, z, &h)
564    }
565
566    /// Helper function to do common setup
567    #[inline]
568    fn setup<V>(
569        &mut self,
570        tape: &ShapeTape<E::Tape>,
571        x: &[E::Data],
572        y: &[E::Data],
573        z: &[E::Data],
574        vars: &ShapeVars<V>,
575    ) -> Result<usize, Error> {
576        assert_eq!(
577            tape.tape.output_count(),
578            1,
579            "ShapeTape has multiple outputs"
580        );
581
582        // Make sure our scratch arrays are big enough for this evaluation
583        if x.len() != y.len() || x.len() != z.len() {
584            return Err(Error::MismatchedSlices);
585        }
586        let n = x.len();
587
588        let vs = tape.vars();
589        let expected_vars = vs.len()
590            - vs.get(&Var::X).is_some() as usize
591            - vs.get(&Var::Y).is_some() as usize
592            - vs.get(&Var::Z).is_some() as usize;
593        if expected_vars != vars.len() {
594            return Err(Error::BadVarSlice(vars.len(), expected_vars));
595        }
596
597        // We need at least one item in the scratch array to set evaluation
598        // size; otherwise, evaluating a single constant will return []
599        self.scratch.resize_with(vs.len().max(1), Vec::new);
600        for s in &mut self.scratch {
601            s.resize(n, 0.0.into());
602        }
603
604        if let Some(mat) = tape.transform {
605            for i in 0..n {
606                let (x, y, z) = Transformable::transform(x[i], y[i], z[i], mat);
607                if let Some(a) = tape.axes[0] {
608                    self.scratch[a][i] = x;
609                }
610                if let Some(b) = tape.axes[1] {
611                    self.scratch[b][i] = y;
612                }
613                if let Some(c) = tape.axes[2] {
614                    self.scratch[c][i] = z;
615                }
616            }
617        } else {
618            if let Some(a) = tape.axes[0] {
619                self.scratch[a].copy_from_slice(x);
620            }
621            if let Some(b) = tape.axes[1] {
622                self.scratch[b].copy_from_slice(y);
623            }
624            if let Some(c) = tape.axes[2] {
625                self.scratch[c].copy_from_slice(z);
626            }
627            // TODO fast path if there are no extra vars, reusing slices
628        };
629
630        Ok(n)
631    }
632    /// Bulk evaluation of many samples, with slices of variables
633    ///
634    /// Each variable is a slice (or `Vec`) of values, which must be the same
635    /// length as the `x`, `y`, `z` slices.  This is in contrast with
636    /// [`eval_vs`](Self::eval_v), where variables have a single value used for
637    /// every position in the `x`, `y,` `z` slices.
638    ///
639    ///
640    /// Before evaluation, the tape's transform matrix is applied (if present).
641    #[inline]
642    pub fn eval_vs<
643        V: std::ops::Deref<Target = [G]>,
644        G: Into<E::Data> + Copy,
645    >(
646        &mut self,
647        tape: &ShapeTape<E::Tape>,
648        x: &[E::Data],
649        y: &[E::Data],
650        z: &[E::Data],
651        vars: &ShapeVars<V>,
652    ) -> Result<&[E::Data], Error> {
653        let n = self.setup(tape, x, y, z, vars)?;
654
655        if vars.values().any(|vs| vs.len() != n) {
656            return Err(Error::MismatchedSlices);
657        }
658
659        let vs = tape.vars();
660        for (var, value) in vars {
661            if let Some(i) = vs.get(&Var::V(*var)) {
662                if i < self.scratch.len() {
663                    for (a, b) in
664                        self.scratch[i].iter_mut().zip(value.deref().iter())
665                    {
666                        *a = (*b).into();
667                    }
668                    // TODO fast path if we can use the slices directly?
669                } else {
670                    return Err(Error::BadVarIndex(i, self.scratch.len()));
671                }
672            } else {
673                // Passing in Bonus Variables is allowed (for now)
674            }
675        }
676
677        let out = self.eval.eval(&tape.tape, &self.scratch)?;
678        Ok(out.borrow(0))
679    }
680
681    /// Bulk evaluation of many samples, with fixed variables
682    ///
683    /// Each variable has a single value, which is used for every position in
684    /// the `x`, `y`, `z` slices.  This is in contrast with
685    /// [`eval_vs`](Self::eval_vs), where variables can be different for every
686    /// position in the `x`, `y,` `z` slices.
687    ///
688    /// Before evaluation, the tape's transform matrix is applied (if present).
689    #[inline]
690    pub fn eval_v<G: Into<E::Data> + Copy>(
691        &mut self,
692        tape: &ShapeTape<E::Tape>,
693        x: &[E::Data],
694        y: &[E::Data],
695        z: &[E::Data],
696        vars: &ShapeVars<G>,
697    ) -> Result<&[E::Data], Error> {
698        self.setup(tape, x, y, z, vars)?;
699        let vs = tape.vars();
700        for (var, value) in vars {
701            if let Some(i) = vs.get(&Var::V(*var)) {
702                if i < self.scratch.len() {
703                    self.scratch[i].fill((*value).into());
704                } else {
705                    return Err(Error::BadVarIndex(i, self.scratch.len()));
706                }
707            } else {
708                // Passing in Bonus Variables is allowed (for now)
709            }
710        }
711
712        let out = self.eval.eval(&tape.tape, &self.scratch)?;
713        Ok(out.borrow(0))
714    }
715}
716
717/// Trait for types that can be transformed by a 4x4 homogeneous transform matrix
718pub trait Transformable {
719    /// Apply the given transform to an `(x, y, z)` position
720    fn transform(
721        x: Self,
722        y: Self,
723        z: Self,
724        mat: Matrix4<f32>,
725    ) -> (Self, Self, Self)
726    where
727        Self: Sized;
728}
729
730impl Transformable for f32 {
731    fn transform(x: f32, y: f32, z: f32, mat: Matrix4<f32>) -> (f32, f32, f32) {
732        let out = mat.transform_point(&Point3::new(x, y, z));
733        (out.x, out.y, out.z)
734    }
735}
736
737impl Transformable for Interval {
738    fn transform(
739        x: Interval,
740        y: Interval,
741        z: Interval,
742        mat: Matrix4<f32>,
743    ) -> (Interval, Interval, Interval) {
744        let out = [0, 1, 2, 3].map(|i| {
745            let row = mat.row(i);
746            x * row[0] + y * row[1] + z * row[2] + Interval::from(row[3])
747        });
748
749        (out[0] / out[3], out[1] / out[3], out[2] / out[3])
750    }
751}
752
753impl Transformable for Grad {
754    fn transform(
755        x: Grad,
756        y: Grad,
757        z: Grad,
758        mat: Matrix4<f32>,
759    ) -> (Grad, Grad, Grad) {
760        let out = [0, 1, 2, 3].map(|i| {
761            let row = mat.row(i);
762            x * row[0] + y * row[1] + z * row[2] + Grad::from(row[3])
763        });
764
765        (out[0] / out[3], out[1] / out[3], out[2] / out[3])
766    }
767}
768
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmShape;

    /// The shape's var map holds X, Y and the extra variable (but not Z),
    /// and they occupy indices 0..3
    #[test]
    fn shape_vars() {
        let v = Var::new();
        let tree = Tree::x() + Tree::y() + v;

        let mut ctx = Context::new();
        let root = ctx.import(&tree);

        let shape = VmShape::new(&ctx, root).unwrap();
        let vars = shape.inner().vars();
        assert_eq!(vars.len(), 3);

        for present in [Var::X, Var::Y, v] {
            assert!(vars.get(&present).is_some());
        }
        assert!(vars.get(&Var::Z).is_none());

        // The three variables must use indices 0, 1, 2 (in some order)
        let mut indices: Vec<usize> =
            [Var::X, Var::Y, v].iter().map(|k| vars[k]).collect();
        indices.sort_unstable();
        assert_eq!(indices, [0, 1, 2]);
    }

    /// Evaluating a constant still produces one output per input sample
    #[test]
    fn shape_eval_bulk_size() {
        let mut ctx = Context::new();
        let root = ctx.import(&Tree::constant(1.0));

        let shape = VmShape::new(&ctx, root).unwrap();
        let tape = shape.ez_float_slice_tape();
        let mut eval = VmShape::new_float_slice_eval();

        let xs: [f32; 3] = [1.0, 2.0, 3.0];
        let ys: [f32; 3] = [4.0, 5.0, 6.0];
        let zs: [f32; 3] = [7.0, 8.0, 9.0];
        let no_vars: ShapeVars<f32> = ShapeVars::default();
        let out = eval.eval_v(&tape, &xs, &ys, &zs, &no_vars).unwrap();
        assert_eq!(out, [1.0, 1.0, 1.0]);
    }
}