fidget_core/context/tree.rs

1//! Context-free math trees
2use super::op::{BinaryOpcode, UnaryOpcode};
3use crate::{Error, var::Var};
4use std::{cmp::Ordering, sync::Arc};
5
6/// Opcode type for trees
7///
8/// This is equivalent to [`Op`](crate::context::Op), but also includes the
9/// [`RemapAxes`](TreeOp::RemapAxes) and [`TreeOp::RemapAffine`] operations for
10/// lazy remapping.
11#[derive(Debug)]
12#[allow(missing_docs)]
13pub enum TreeOp {
14    /// Input (an arbitrary [`Var`])
15    Input(Var),
16    Const(f64),
17    Binary(BinaryOpcode, Arc<TreeOp>, Arc<TreeOp>),
18    Unary(UnaryOpcode, Arc<TreeOp>),
19    /// Lazy remapping of trees
20    ///
21    /// When imported into a `Context`, all `x/y/z` clauses within `target` will
22    /// be replaced with the provided `x/y/z` trees.
23    ///
24    /// If the transform is affine, then `RemapAffine` should be preferred,
25    /// because it flattens sequences of affine transformations.
26    RemapAxes {
27        target: Arc<TreeOp>,
28        x: Arc<TreeOp>,
29        y: Arc<TreeOp>,
30        z: Arc<TreeOp>,
31    },
32    /// Lazy affine transforms
33    ///
34    /// When imported into a `Context`, the `x/y/z` clauses within `target` will
35    /// be transformed with the provided affine matrix.
36    RemapAffine {
37        target: Arc<TreeOp>,
38        mat: nalgebra::Affine3<f64>,
39    },
40}
41
42impl Drop for TreeOp {
43    fn drop(&mut self) {
44        // Early exit for TreeOps which have limited recursion
45        if self.eligible_for_fast_drop() {
46            return;
47        }
48
49        let mut todo = vec![std::mem::replace(self, TreeOp::Const(0.0))];
50        let empty = Arc::new(TreeOp::Const(0.0));
51        while let Some(mut t) = todo.pop() {
52            for t in t.iter_children_mut() {
53                let arg = std::mem::replace(t, empty.clone());
54                todo.extend(Arc::into_inner(arg));
55            }
56            drop(t);
57        }
58    }
59}
60
61impl TreeOp {
62    /// Checks whether the given tree is eligible for fast dropping
63    ///
64    /// Fast dropping uses the normal `Drop` implementation, which recurses on
65    /// the stack and can overflow for deep trees.  A recursive tree is only
66    /// eligible for fast dropping if all of its children are non-recursive.
67    fn eligible_for_fast_drop(&self) -> bool {
68        self.iter_children().all(|c| c.does_not_recurse())
69    }
70
71    /// Returns `true` if the given child does not recurse
72    fn does_not_recurse(&self) -> bool {
73        matches!(self, TreeOp::Const(..) | TreeOp::Input(..))
74    }
75
76    fn iter_children(&self) -> impl Iterator<Item = &Arc<TreeOp>> {
77        match self {
78            TreeOp::Const(..) | TreeOp::Input(..) => [None, None, None, None],
79            TreeOp::Unary(_op, arg) => [Some(arg), None, None, None],
80            TreeOp::Binary(_op, lhs, rhs) => [Some(lhs), Some(rhs), None, None],
81            TreeOp::RemapAxes { target, x, y, z } => {
82                [Some(target), Some(x), Some(y), Some(z)]
83            }
84            TreeOp::RemapAffine { target, .. } => {
85                [Some(target), None, None, None]
86            }
87        }
88        .into_iter()
89        .flatten()
90    }
91
92    fn iter_children_mut(&mut self) -> impl Iterator<Item = &mut Arc<TreeOp>> {
93        match self {
94            TreeOp::Const(..) | TreeOp::Input(..) => [None, None, None, None],
95            TreeOp::Unary(_op, arg) => [Some(arg), None, None, None],
96            TreeOp::Binary(_op, lhs, rhs) => [Some(lhs), Some(rhs), None, None],
97            TreeOp::RemapAxes { target, x, y, z } => {
98                [Some(target), Some(x), Some(y), Some(z)]
99            }
100            TreeOp::RemapAffine { target, .. } => {
101                [Some(target), None, None, None]
102            }
103        }
104        .into_iter()
105        .flatten()
106    }
107}
108
109impl From<f64> for Tree {
110    fn from(v: f64) -> Tree {
111        Tree::constant(v)
112    }
113}
114
115impl From<f32> for Tree {
116    fn from(v: f32) -> Tree {
117        Tree::constant(v as f64)
118    }
119}
120
121impl From<i32> for Tree {
122    fn from(v: i32) -> Tree {
123        Tree::constant(v as f64)
124    }
125}
126
127impl From<Var> for Tree {
128    fn from(v: Var) -> Tree {
129        Tree(Arc::new(TreeOp::Input(v)))
130    }
131}
132
133impl From<TreeOp> for Tree {
134    fn from(t: TreeOp) -> Tree {
135        Tree(Arc::new(t))
136    }
137}
138
139/// Owned handle for a standalone math tree
140#[derive(Clone, Debug, facet::Facet)]
141pub struct Tree(#[facet(opaque)] Arc<TreeOp>);
142
143impl std::ops::Deref for Tree {
144    type Target = TreeOp;
145    fn deref(&self) -> &Self::Target {
146        &self.0
147    }
148}
149
150impl PartialEq for Tree {
151    fn eq(&self, other: &Self) -> bool {
152        if self.ptr_eq(other) {
153            return true;
154        }
155        // Heap recursion using a `Vec`, to avoid blowing up the stack
156        let mut todo = vec![(&self.0, &other.0)];
157        while let Some((a, b)) = todo.pop() {
158            // Pointer equality lets us short-circuit deep checks
159            if Arc::as_ptr(a) == Arc::as_ptr(b) {
160                continue;
161            }
162            // Otherwise, we check opcodes then recurse
163            match (a.as_ref(), b.as_ref()) {
164                (TreeOp::Input(a), TreeOp::Input(b)) => {
165                    if *a != *b {
166                        return false;
167                    }
168                }
169                (TreeOp::Const(a), TreeOp::Const(b)) => {
170                    if *a != *b {
171                        return false;
172                    }
173                }
174                (TreeOp::Unary(op_a, arg_a), TreeOp::Unary(op_b, arg_b)) => {
175                    if *op_a != *op_b {
176                        return false;
177                    }
178                    todo.push((arg_a, arg_b));
179                }
180                (
181                    TreeOp::Binary(op_a, lhs_a, rhs_a),
182                    TreeOp::Binary(op_b, lhs_b, rhs_b),
183                ) => {
184                    if *op_a != *op_b {
185                        return false;
186                    }
187                    todo.push((lhs_a, lhs_b));
188                    todo.push((rhs_a, rhs_b));
189                }
190                (
191                    TreeOp::RemapAxes {
192                        target: t_a,
193                        x: x_a,
194                        y: y_a,
195                        z: z_a,
196                    },
197                    TreeOp::RemapAxes {
198                        target: t_b,
199                        x: x_b,
200                        y: y_b,
201                        z: z_b,
202                    },
203                ) => {
204                    todo.push((t_a, t_b));
205                    todo.push((x_a, x_b));
206                    todo.push((y_a, y_b));
207                    todo.push((z_a, z_b));
208                }
209                (
210                    TreeOp::RemapAffine {
211                        target: t_a,
212                        mat: mat_a,
213                    },
214                    TreeOp::RemapAffine {
215                        target: t_b,
216                        mat: mat_b,
217                    },
218                ) => {
219                    if *mat_a != *mat_b {
220                        return false;
221                    }
222                    todo.push((t_a, t_b));
223                }
224                _ => return false,
225            }
226        }
227        true
228    }
229}
230impl Eq for Tree {}
231
232impl Tree {
233    /// Returns an `(x, y, z)` tuple
234    pub fn axes() -> (Self, Self, Self) {
235        (Self::x(), Self::y(), Self::z())
236    }
237
238    /// Returns a pointer to the inner [`TreeOp`]
239    ///
240    /// This can be used as a strong (but not unique) identity.
241    pub fn as_ptr(&self) -> *const TreeOp {
242        Arc::as_ptr(&self.0)
243    }
244
245    /// Shallow (pointer) equality check
246    pub fn ptr_eq(&self, other: &Self) -> bool {
247        std::ptr::eq(self.as_ptr(), other.as_ptr())
248    }
249
250    /// Borrow the inner `Arc<TreeOp>`
251    pub(crate) fn arc(&self) -> &Arc<TreeOp> {
252        &self.0
253    }
254
255    /// Remaps the axes of the given tree
256    ///
257    /// If the mapping is affine, then [`remap_affine`](Self::remap_affine)
258    /// should be preferred.
259    ///
260    /// The remapping is lazy; it is not evaluated until the tree is imported
261    /// into a `Context`.
262    pub fn remap_xyz(&self, x: Tree, y: Tree, z: Tree) -> Tree {
263        Self(Arc::new(TreeOp::RemapAxes {
264            target: self.0.clone(),
265            x: x.0,
266            y: y.0,
267            z: z.0,
268        }))
269    }
270
271    /// Performs an affine remapping of the given tree
272    ///
273    /// The remapping is lazy; it is not evaluated until the tree is imported
274    /// into a `Context`.
275    pub fn remap_affine(&self, mat: nalgebra::Affine3<f64>) -> Tree {
276        // Flatten affine trees
277        let out = match &*self.0 {
278            TreeOp::RemapAffine { target, mat: next } => TreeOp::RemapAffine {
279                target: target.clone(),
280                mat: next * mat,
281            },
282            _ => TreeOp::RemapAffine {
283                target: self.0.clone(),
284                mat,
285            },
286        };
287        Self(out.into())
288    }
289
290    /// Returns the inner [`Var`] if this is an input tree, or `None`
291    pub fn var(&self) -> Option<Var> {
292        if let TreeOp::Input(v) = &*self.0 {
293            Some(*v)
294        } else {
295            None
296        }
297    }
298
299    /// Performs symbolic differentiation with respect to the given variable
300    pub fn deriv(&self, v: Var) -> Tree {
301        let mut ctx = crate::Context::new();
302        let node = ctx.import(self);
303        ctx.deriv(node, v).and_then(|d| ctx.export(d)).unwrap()
304    }
305
306    /// Raises this tree to the power of an integer using exponentiation by squaring
307    pub fn pow(&self, mut n: i64) -> Self {
308        // TODO should this also be in `Context`?
309        let mut x = match n.cmp(&0) {
310            Ordering::Less => {
311                n = -n;
312                self.recip()
313            }
314            Ordering::Equal => {
315                return Tree::from(1.0);
316            }
317            Ordering::Greater => self.clone(),
318        };
319        let mut y: Option<Tree> = None;
320        while n > 1 {
321            if n % 2 == 1 {
322                y = match y {
323                    Some(y) => Some(x.clone() * y),
324                    None => Some(x.clone()),
325                };
326                n -= 1;
327            }
328            x = x.square();
329            n /= 2;
330        }
331        if let Some(y) = y {
332            x *= y;
333        }
334        x
335    }
336}
337
338impl TryFrom<Tree> for Var {
339    type Error = Error;
340    fn try_from(t: Tree) -> Result<Var, Error> {
341        t.var().ok_or(Error::NotAVar)
342    }
343}
344
345/// See [`Context`](crate::Context) for documentation of these functions
346#[allow(missing_docs)]
347impl Tree {
348    pub fn x() -> Self {
349        Tree(Arc::new(TreeOp::Input(Var::X)))
350    }
351    pub fn y() -> Self {
352        Tree(Arc::new(TreeOp::Input(Var::Y)))
353    }
354    pub fn z() -> Self {
355        Tree(Arc::new(TreeOp::Input(Var::Z)))
356    }
357    pub fn constant(f: f64) -> Self {
358        Tree(Arc::new(TreeOp::Const(f)))
359    }
360    fn op_unary(a: Tree, op: UnaryOpcode) -> Self {
361        Tree(Arc::new(TreeOp::Unary(op, a.0)))
362    }
363    fn op_binary(a: Tree, b: Tree, op: BinaryOpcode) -> Self {
364        Tree(Arc::new(TreeOp::Binary(op, a.0, b.0)))
365    }
366    pub fn square(&self) -> Self {
367        Self::op_unary(self.clone(), UnaryOpcode::Square)
368    }
369    pub fn floor(&self) -> Self {
370        Self::op_unary(self.clone(), UnaryOpcode::Floor)
371    }
372    pub fn ceil(&self) -> Self {
373        Self::op_unary(self.clone(), UnaryOpcode::Ceil)
374    }
375    pub fn round(&self) -> Self {
376        Self::op_unary(self.clone(), UnaryOpcode::Round)
377    }
378    pub fn sqrt(&self) -> Self {
379        Self::op_unary(self.clone(), UnaryOpcode::Sqrt)
380    }
381    pub fn max<T: Into<Tree>>(&self, other: T) -> Self {
382        Self::op_binary(self.clone(), other.into(), BinaryOpcode::Max)
383    }
384    pub fn min<T: Into<Tree>>(&self, other: T) -> Self {
385        Self::op_binary(self.clone(), other.into(), BinaryOpcode::Min)
386    }
387    pub fn compare<T: Into<Tree>>(&self, other: T) -> Self {
388        Self::op_binary(self.clone(), other.into(), BinaryOpcode::Compare)
389    }
390    pub fn modulo<T: Into<Tree>>(&self, other: T) -> Self {
391        Self::op_binary(self.clone(), other.into(), BinaryOpcode::Mod)
392    }
393    pub fn and<T: Into<Tree>>(&self, other: T) -> Self {
394        Self::op_binary(self.clone(), other.into(), BinaryOpcode::And)
395    }
396    pub fn or<T: Into<Tree>>(&self, other: T) -> Self {
397        Self::op_binary(self.clone(), other.into(), BinaryOpcode::Or)
398    }
399    pub fn atan2<T: Into<Tree>>(&self, other: T) -> Self {
400        Self::op_binary(self.clone(), other.into(), BinaryOpcode::Atan)
401    }
402    pub fn neg(&self) -> Self {
403        Self::op_unary(self.clone(), UnaryOpcode::Neg)
404    }
405    pub fn recip(&self) -> Self {
406        Self::op_unary(self.clone(), UnaryOpcode::Recip)
407    }
408    pub fn sin(&self) -> Self {
409        Self::op_unary(self.clone(), UnaryOpcode::Sin)
410    }
411    pub fn cos(&self) -> Self {
412        Self::op_unary(self.clone(), UnaryOpcode::Cos)
413    }
414    pub fn tan(&self) -> Self {
415        Self::op_unary(self.clone(), UnaryOpcode::Tan)
416    }
417    pub fn asin(&self) -> Self {
418        Self::op_unary(self.clone(), UnaryOpcode::Asin)
419    }
420    pub fn acos(&self) -> Self {
421        Self::op_unary(self.clone(), UnaryOpcode::Acos)
422    }
423    pub fn atan(&self) -> Self {
424        Self::op_unary(self.clone(), UnaryOpcode::Atan)
425    }
426    pub fn exp(&self) -> Self {
427        Self::op_unary(self.clone(), UnaryOpcode::Exp)
428    }
429    pub fn ln(&self) -> Self {
430        Self::op_unary(self.clone(), UnaryOpcode::Ln)
431    }
432    pub fn not(&self) -> Self {
433        Self::op_unary(self.clone(), UnaryOpcode::Not)
434    }
435    pub fn abs(&self) -> Self {
436        Self::op_unary(self.clone(), UnaryOpcode::Abs)
437    }
438}
439
440macro_rules! impl_binary {
441    ($op:ident, $op_assign:ident, $base_fn:ident, $assign_fn:ident) => {
442        impl<A: Into<Tree>> std::ops::$op<A> for Tree {
443            type Output = Self;
444
445            fn $base_fn(self, other: A) -> Self {
446                Self::op_binary(self, other.into(), BinaryOpcode::$op)
447            }
448        }
449        impl<A: Into<Tree>> std::ops::$op_assign<A> for Tree {
450            fn $assign_fn(&mut self, other: A) {
451                use std::ops::$op;
452                let mut next = self.clone().$base_fn(other.into());
453                std::mem::swap(self, &mut next);
454            }
455        }
456        impl std::ops::$op<Tree> for f32 {
457            type Output = Tree;
458            fn $base_fn(self, other: Tree) -> Tree {
459                Tree::op_binary(self.into(), other, BinaryOpcode::$op)
460            }
461        }
462        impl std::ops::$op<Tree> for f64 {
463            type Output = Tree;
464            fn $base_fn(self, other: Tree) -> Tree {
465                Tree::op_binary(self.into(), other, BinaryOpcode::$op)
466            }
467        }
468    };
469}
470
471impl_binary!(Add, AddAssign, add, add_assign);
472impl_binary!(Sub, SubAssign, sub, sub_assign);
473impl_binary!(Mul, MulAssign, mul, mul_assign);
474impl_binary!(Div, DivAssign, div, div_assign);
475
476impl std::ops::Neg for Tree {
477    type Output = Tree;
478    fn neg(self) -> Self::Output {
479        Tree::op_unary(self, UnaryOpcode::Neg)
480    }
481}
482
#[cfg(test)]
mod test {
    use super::*;
    use crate::Context;

    #[test]
    fn tree_x() {
        let x1 = Tree::x();
        let x2 = Tree::x();
        assert!(!x1.ptr_eq(&x2)); // shallow equality
        assert_eq!(x1, x2); // deep equality

        let mut ctx = Context::new();
        let x1 = ctx.import(&x1);
        let x2 = ctx.import(&x2);
        assert_eq!(x1, x2);
    }

    #[test]
    fn test_remap_xyz() {
        // Remapping X
        let s = Tree::x() + 1.0;

        let v = s.remap_xyz(Tree::y(), Tree::z(), Tree::x());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 2.0);

        let v = s.remap_xyz(Tree::z(), Tree::x(), Tree::y());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 0.0, 1.0).unwrap(), 2.0);

        let v = s.remap_xyz(Tree::x(), Tree::y(), Tree::z());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);

        // Remapping Y
        let s = Tree::y() + 1.0;

        let v = s.remap_xyz(Tree::y(), Tree::z(), Tree::x());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 0.0, 1.0).unwrap(), 2.0);

        let v = s.remap_xyz(Tree::z(), Tree::x(), Tree::y());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);

        let v = s.remap_xyz(Tree::x(), Tree::y(), Tree::z());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 2.0);

        // Remapping Z
        let s = Tree::z() + 1.0;

        let v = s.remap_xyz(Tree::y(), Tree::z(), Tree::x());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);

        let v = s.remap_xyz(Tree::z(), Tree::x(), Tree::y());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 2.0);

        let v = s.remap_xyz(Tree::x(), Tree::y(), Tree::z());
        let mut ctx = Context::new();
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 0.0, 1.0).unwrap(), 2.0);

        // Test remapping to a constant
        let s = Tree::x() + 1.0;
        let one = Tree::constant(3.0);
        let v = s.remap_xyz(one, Tree::y(), Tree::z());
        let v_ = ctx.import(&v);
        assert_eq!(ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap(), 4.0);
    }

    #[test]
    fn test_remap_affine() {
        let s = Tree::x();
        // Two rotations by 45° -> 90°
        let t = nalgebra::convert(nalgebra::Rotation3::<f64>::from_axis_angle(
            &nalgebra::Vector3::<f64>::z_axis(),
            -std::f64::consts::FRAC_PI_4,
        ));
        let s = s.remap_affine(t);
        let s = s.remap_affine(t);

        let TreeOp::RemapAffine { target, .. } = &*s else {
            panic!("invalid shape");
        };
        assert!(matches!(&**target, TreeOp::Input(Var::X)));

        let mut ctx = Context::new();
        let v_ = ctx.import(&s);

        assert!((ctx.eval_xyz(v_, 0.0, 1.0, 0.0).unwrap() - 1.0).abs() < 1e-6);
        assert!(
            (ctx.eval_xyz(v_, 0.0, -2.0, 0.0).unwrap() - -2.0).abs() < 1e-6
        );
    }

    #[test]
    fn test_remap_order() {
        let translate = nalgebra::convert(nalgebra::Translation3::<f64>::new(
            3.0, 0.0, 0.0,
        ));
        let scale =
            nalgebra::convert(nalgebra::Scale3::<f64>::new(0.5, 0.5, 0.5));

        let s = Tree::x();
        let s = s.remap_affine(translate);
        let s = s.remap_affine(scale);

        // Confirm that we didn't stack up RemapAffine nodes
        let TreeOp::RemapAffine { target, .. } = &*s else {
            panic!("invalid shape");
        };
        assert!(matches!(&**target, TreeOp::Input(Var::X)));

        // Basic evaluation testing
        let mut ctx = Context::new();
        let v_ = ctx.import(&s);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 3.5);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 0.0, 0.0).unwrap(), 4.0);

        // Do the same thing but testing collapsing in `Context::import`
        let manual = TreeOp::RemapAffine {
            target: Arc::new(TreeOp::RemapAffine {
                target: TreeOp::Input(Var::X).into(),
                mat: scale,
            }),
            mat: translate,
        }
        .into();
        let mut ctx = Context::new();
        let v_ = ctx.import(&manual);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 3.5);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 0.0, 0.0).unwrap(), 4.0);

        // Swap the order and make sure it still works
        let s = Tree::x();
        let s = s.remap_affine(scale);
        let s = s.remap_affine(translate);

        let mut ctx = Context::new();
        let v_ = ctx.import(&s);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 0.0, 0.0).unwrap(), 2.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 0.0, 0.0).unwrap(), 2.5);
    }

    #[test]
    fn deep_recursion_drop() {
        let mut x = Tree::x();
        for _ in 0..1_000_000 {
            x += 1.0;
        }
        drop(x);
        // we should not panic here!
    }

    #[test]
    fn deep_recursion_eq() {
        let mut x1 = Tree::x();
        for _ in 0..1_000_000 {
            x1 += 1.0;
        }
        let mut x2 = Tree::x();
        for _ in 0..1_000_000 {
            x2 += 1.0;
        }
        assert_eq!(x1, x2);
    }

    #[test]
    fn deep_recursion_import() {
        let mut x = Tree::x();
        for _ in 0..1_000_000 {
            x += 1.0;
        }
        let mut ctx = Context::new();
        ctx.import(&x);
        // we should not panic here!
    }

    #[test]
    fn tree_remap_multi() {
        let mut ctx = Context::new();

        let out = Tree::x() + Tree::y() + Tree::z();
        let out =
            out.remap_xyz(Tree::x() * 2.0, Tree::y() * 3.0, Tree::z() * 5.0);

        let v_ = ctx.import(&out);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 1.0, 1.0).unwrap(), 10.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 1.0, 1.0).unwrap(), 12.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 1.0).unwrap(), 15.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 2.0).unwrap(), 20.0);

        let out = out.remap_xyz(Tree::y(), Tree::z(), Tree::x());
        let v_ = ctx.import(&out);
        assert_eq!(ctx.eval_xyz(v_, 1.0, 1.0, 1.0).unwrap(), 10.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 1.0, 1.0).unwrap(), 15.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 1.0).unwrap(), 17.0);
        assert_eq!(ctx.eval_xyz(v_, 2.0, 2.0, 2.0).unwrap(), 20.0);
    }

    #[test]
    fn tree_import_cache() {
        let mut x = Tree::x();
        for _ in 0..100_000 {
            x += 1.0;
        }
        let mut ctx = Context::new();
        let start = std::time::Instant::now();
        ctx.import(&x);
        let small = start.elapsed();

        // Build a new tree with 4 copies of the original
        let x = x.clone() * x.clone() * x.clone() * x;
        let mut ctx = Context::new();
        let start = std::time::Instant::now();
        ctx.import(&x);
        let large = start.elapsed();

        // Compare full-precision `Duration`s. Truncating to whole
        // milliseconds makes `small` round to zero on fast machines (or in
        // release builds), which would turn this check into `large < 0`.
        assert!(
            large < small * 2,
            "tree import cache failed: {large:?} is much larger than {small:?}"
        );
    }

    #[test]
    fn tree_import_nocache() {
        let mut x = Tree::x();
        for _ in 0..100_000 {
            x += 1.0;
        }
        let mut ctx = Context::new();
        let start = std::time::Instant::now();
        ctx.import(&x);
        let small = start.elapsed();

        // Build a new tree with 4 remapped versions of the original
        let x = x.remap_xyz(Tree::y(), Tree::z(), Tree::x())
            * x.remap_xyz(Tree::z(), Tree::x(), Tree::y())
            * x.remap_xyz(Tree::y(), Tree::x(), Tree::z())
            * x;
        let mut ctx = Context::new();
        let start = std::time::Instant::now();
        ctx.import(&x);
        let large = start.elapsed();

        // Compare full-precision `Duration`s. With millisecond truncation,
        // two sub-millisecond imports would both read as zero and fail here.
        assert!(
            large > small * 2,
            "tree import cache failed:
             {large:?} is not much larger than {small:?}"
        );
    }

    #[test]
    fn tree_from_int() {
        let a = Tree::from(3);
        let b = a * 5;

        let mut ctx = Context::new();
        let root = ctx.import(&b);
        assert_eq!(ctx.get_const(root).unwrap(), 15.0);
    }

    #[test]
    fn tree_deriv() {
        // dx/dx = 1
        let x = Tree::x();
        let vx = x.var().unwrap();
        let d = x.deriv(vx);
        let TreeOp::Const(v) = *d else {
            panic!("invalid deriv {d:?}")
        };
        assert_eq!(v, 1.0);

        // dx/dv = 0
        let d = x.deriv(Var::new());
        let TreeOp::Const(v) = *d else {
            panic!("invalid deriv {d:?}")
        };
        assert_eq!(v, 0.0);
    }

    #[test]
    fn tree_pow() {
        let a = Tree::from(3);
        let b = a.pow(3);
        let c = a.pow(-3);
        let d = a.pow(0);

        let mut ctx = Context::new();
        let root = ctx.import(&b);
        assert_eq!(ctx.get_const(root).unwrap(), 27.0);
        ctx.clear();
        let root = ctx.import(&c);
        assert_eq!(ctx.get_const(root).unwrap(), 1.0 / 27.0);
        ctx.clear();
        let root = ctx.import(&d);
        assert_eq!(ctx.get_const(root).unwrap(), 1.0);
    }

    #[test]
    fn tree_poke() {
        use facet::Facet;
        #[derive(facet::Facet)]
        struct Transform {
            tree: Tree,
            x: f64,
        }

        let mut builder =
            facet::Partial::alloc_shape(Transform::SHAPE).unwrap();
        builder
            .set_field("tree", Tree::x() + 2.0 * Tree::y())
            .unwrap()
            .set_field("x", 1.0)
            .unwrap();
        let t: Transform = builder.build().unwrap().materialize().unwrap();
        assert_eq!(t.x, 1.0);
        let mut ctx = Context::new();
        let node = ctx.import(&t.tree);
        assert_eq!(ctx.eval_xyz(node, 1.0, 2.0, 3.0).unwrap(), 5.0);
    }
}