fidget_core/context/
mod.rs

1//! Infrastructure for representing math expressions as trees and graphs
2//!
3//! There are two families of representations in this module:
4//!
5//! - A [`Tree`] is a free-floating math expression, which can be cloned
6//!   and has overloaded operators for ease of use.  It is **not** deduplicated;
7//!   two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two
8//!   different objects.
9//!   `Tree` objects are typically used when building up expressions; they
10//!   should be converted to `Node` objects (in a particular `Context`) after
11//!   they have been constructed.
12//! - A [`Context`] is an arena for unique (deduplicated) math expressions,
13//!   which are represented as [`Node`] handles.  Each `Node` is specific to a
14//!   particular context.  Only `Node` objects can be converted into
15//!   [`Function`](crate::eval::Function) objects for evaluation.
16//!
17//! In other words, the typical workflow is `Tree → (Context, Node) → Function`.
18mod indexed;
19mod op;
20mod tree;
21
22use indexed::{Index, IndexMap, IndexVec, define_index};
23pub use op::{BinaryOpcode, Op, UnaryOpcode};
24pub use tree::{Tree, TreeOp};
25
26use crate::{Error, var::Var};
27
28use std::collections::{BTreeMap, HashMap};
29use std::fmt::Write;
30use std::io::{BufRead, BufReader, Read};
31use std::sync::Arc;
32
33use nalgebra::Matrix4;
34use ordered_float::OrderedFloat;
35
36define_index!(Node, "An index in the `Context::ops` map");
37
38/// A `Context` holds a set of deduplicated constants, variables, and
39/// operations.
40///
41/// It should be used like an arena allocator: it grows over time, then frees
42/// all of its contents when dropped.  There is no reference counting within the
43/// context.
44///
45/// Items in the context are accessed with [`Node`] keys, which are simple
46/// handles into an internal map.  Inside the context, operations are
47/// represented with the [`Op`] type.
48#[derive(Debug, Default)]
49pub struct Context {
50    ops: IndexMap<Op, Node>,
51}
52
53impl Context {
54    /// Build a new empty context
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Clears the context
60    ///
61    /// All [`Node`] handles from this context are invalidated.
62    ///
63    /// ```
64    /// # use fidget_core::context::Context;
65    /// let mut ctx = Context::new();
66    /// let x = ctx.x();
67    /// ctx.clear();
68    /// assert!(ctx.eval_xyz(x, 1.0, 0.0, 0.0).is_err());
69    /// ```
70    pub fn clear(&mut self) {
71        self.ops.clear();
72    }
73
74    /// Returns the number of [`Op`] nodes in the context
75    ///
76    /// ```
77    /// # use fidget_core::context::Context;
78    /// let mut ctx = Context::new();
79    /// let x = ctx.x();
80    /// assert_eq!(ctx.len(), 1);
81    /// let y = ctx.y();
82    /// assert_eq!(ctx.len(), 2);
83    /// ctx.clear();
84    /// assert_eq!(ctx.len(), 0);
85    /// ```
86    pub fn len(&self) -> usize {
87        self.ops.len()
88    }
89
90    /// Checks whether the context is empty
91    pub fn is_empty(&self) -> bool {
92        self.ops.is_empty()
93    }
94
95    /// Checks whether the given [`Node`] is valid in this context
96    fn check_node(&self, node: Node) -> Result<(), Error> {
97        self.get_op(node).ok_or(Error::BadNode).map(|_| ())
98    }
99
100    /// Erases the most recently added node from the tree.
101    ///
102    /// A few caveats apply, so this must be used with caution:
103    /// - Existing handles to the node will be invalidated
104    /// - The most recently added node must be unique
105    ///
106    /// In practice, this is only used to delete temporary operation nodes
107    /// during constant folding.  Such nodes which have no handles (because
108    /// they are never returned) and are guaranteed to be unique (because we
109    /// never store them persistently).
110    fn pop(&mut self) -> Result<(), Error> {
111        self.ops.pop().map(|_| ())
112    }
113
114    /// Looks up the constant associated with the given node.
115    ///
116    /// If the node is invalid for this tree, returns an error; if the node is
117    /// not a constant, returns `Ok(None)`.
118    pub fn get_const(&self, n: Node) -> Result<f64, Error> {
119        match self.get_op(n) {
120            Some(Op::Const(c)) => Ok(c.0),
121            Some(_) => Err(Error::NotAConst),
122            _ => Err(Error::BadNode),
123        }
124    }
125
126    /// Looks up the [`Var`] associated with the given node.
127    ///
128    /// If the node is invalid for this tree or not an `Op::Input`, returns an
129    /// error.
130    pub fn get_var(&self, n: Node) -> Result<Var, Error> {
131        match self.get_op(n) {
132            Some(Op::Input(v)) => Ok(*v),
133            Some(..) => Err(Error::NotAVar),
134            _ => Err(Error::BadNode),
135        }
136    }
137
138    ////////////////////////////////////////////////////////////////////////////
139    // Primitives
140    /// Constructs or finds a [`Var::X`] node
141    /// ```
142    /// # use fidget_core::context::Context;
143    /// let mut ctx = Context::new();
144    /// let x = ctx.x();
145    /// let v = ctx.eval_xyz(x, 1.0, 0.0, 0.0).unwrap();
146    /// assert_eq!(v, 1.0);
147    /// ```
148    pub fn x(&mut self) -> Node {
149        self.var(Var::X)
150    }
151
152    /// Constructs or finds a [`Var::Y`] node
153    pub fn y(&mut self) -> Node {
154        self.var(Var::Y)
155    }
156
157    /// Constructs or finds a [`Var::Z`] node
158    pub fn z(&mut self) -> Node {
159        self.var(Var::Z)
160    }
161
162    /// Constructs or finds a variable input node
163    ///
164    /// To make an anonymous variable, call this function with [`Var::new()`]:
165    ///
166    /// ```
167    /// # use fidget_core::{context::Context, var::Var};
168    /// # use std::collections::HashMap;
169    /// let mut ctx = Context::new();
170    /// let v1 = ctx.var(Var::new());
171    /// let v2 = ctx.var(Var::new());
172    /// assert_ne!(v1, v2);
173    ///
174    /// let mut vars = HashMap::new();
175    /// vars.insert(ctx.get_var(v1).unwrap(), 3.0);
176    /// assert_eq!(ctx.eval(v1, &vars).unwrap(), 3.0);
177    /// assert!(ctx.eval(v2, &vars).is_err()); // v2 isn't in the map
178    /// ```
179    pub fn var(&mut self, v: Var) -> Node {
180        self.ops.insert(Op::Input(v))
181    }
182
183    /// Returns a 3-element array of `X`, `Y`, `Z` nodes
184    pub fn axes(&mut self) -> [Node; 3] {
185        [self.x(), self.y(), self.z()]
186    }
187
188    /// Returns a node representing the given constant value.
189    /// ```
190    /// # let mut ctx = fidget_core::context::Context::new();
191    /// let v = ctx.constant(3.0);
192    /// assert_eq!(ctx.eval_xyz(v, 0.0, 0.0, 0.0).unwrap(), 3.0);
193    /// ```
194    pub fn constant(&mut self, f: f64) -> Node {
195        self.ops.insert(Op::Const(OrderedFloat(f)))
196    }
197
198    ////////////////////////////////////////////////////////////////////////////
199    // Helper functions to create nodes with constant folding
200    /// Find or create a [Node] for the given unary operation, with constant
201    /// folding.
202    fn op_unary(&mut self, a: Node, op: UnaryOpcode) -> Result<Node, Error> {
203        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
204        let n = self.ops.insert(Op::Unary(op, a));
205        let out = if matches!(op_a, Op::Const(_)) {
206            let v = self.eval(n, &Default::default())?;
207            self.pop().unwrap(); // removes `n`
208            self.constant(v)
209        } else {
210            n
211        };
212        Ok(out)
213    }
214    /// Find or create a [Node] for the given binary operation, with constant
215    /// folding.
216    fn op_binary(
217        &mut self,
218        a: Node,
219        b: Node,
220        op: BinaryOpcode,
221    ) -> Result<Node, Error> {
222        self.op_binary_f(a, b, |lhs, rhs| Op::Binary(op, lhs, rhs))
223    }
224
225    /// Find or create a [Node] for a generic binary operation (represented by a
226    /// thunk), with constant folding.
227    fn op_binary_f<F>(&mut self, a: Node, b: Node, f: F) -> Result<Node, Error>
228    where
229        F: Fn(Node, Node) -> Op,
230    {
231        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
232        let op_b = *self.get_op(b).ok_or(Error::BadNode)?;
233
234        // This call to `insert` should always insert the node, because we
235        // don't permanently store operations in the tree that could be
236        // constant-folded (indeed, we pop the node right afterwards)
237        let n = self.ops.insert(f(a, b));
238        let out = if matches!((op_a, op_b), (Op::Const(_), Op::Const(_))) {
239            let v = self.eval(n, &Default::default())?;
240            self.pop().unwrap(); // removes `n`
241            self.constant(v)
242        } else {
243            n
244        };
245        Ok(out)
246    }
247
248    /// Find or create a [Node] for the given commutative operation, with
249    /// constant folding; deduplication is encouraged by sorting `a` and `b`.
250    fn op_binary_commutative(
251        &mut self,
252        a: Node,
253        b: Node,
254        op: BinaryOpcode,
255    ) -> Result<Node, Error> {
256        self.op_binary(a.min(b), a.max(b), op)
257    }
258
259    /// Builds an addition node
260    /// ```
261    /// # let mut ctx = fidget_core::context::Context::new();
262    /// let x = ctx.x();
263    /// let op = ctx.add(x, 1.0).unwrap();
264    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
265    /// assert_eq!(v, 2.0);
266    /// ```
267    pub fn add<A: IntoNode, B: IntoNode>(
268        &mut self,
269        a: A,
270        b: B,
271    ) -> Result<Node, Error> {
272        let a: Node = a.into_node(self)?;
273        let b: Node = b.into_node(self)?;
274        if a == b {
275            let two = self.constant(2.0);
276            self.mul(a, two)
277        } else {
278            match (self.get_const(a), self.get_const(b)) {
279                (Ok(0.0), _) => Ok(b),
280                (_, Ok(0.0)) => Ok(a),
281                _ => self.op_binary_commutative(a, b, BinaryOpcode::Add),
282            }
283        }
284    }
285
286    /// Builds an multiplication node
287    /// ```
288    /// # let mut ctx = fidget_core::context::Context::new();
289    /// let x = ctx.x();
290    /// let op = ctx.mul(x, 5.0).unwrap();
291    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
292    /// assert_eq!(v, 10.0);
293    /// ```
294    pub fn mul<A: IntoNode, B: IntoNode>(
295        &mut self,
296        a: A,
297        b: B,
298    ) -> Result<Node, Error> {
299        let a = a.into_node(self)?;
300        let b = b.into_node(self)?;
301        if a == b {
302            self.square(a)
303        } else {
304            match (self.get_const(a), self.get_const(b)) {
305                (Ok(1.0), _) => Ok(b),
306                (_, Ok(1.0)) => Ok(a),
307                (Ok(0.0), _) => Ok(a),
308                (_, Ok(0.0)) => Ok(b),
309                _ => self.op_binary_commutative(a, b, BinaryOpcode::Mul),
310            }
311        }
312    }
313
314    /// Builds an `min` node
315    /// ```
316    /// # let mut ctx = fidget_core::context::Context::new();
317    /// let x = ctx.x();
318    /// let op = ctx.min(x, 5.0).unwrap();
319    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
320    /// assert_eq!(v, 2.0);
321    /// ```
322    pub fn min<A: IntoNode, B: IntoNode>(
323        &mut self,
324        a: A,
325        b: B,
326    ) -> Result<Node, Error> {
327        let a = a.into_node(self)?;
328        let b = b.into_node(self)?;
329        if a == b {
330            Ok(a)
331        } else {
332            self.op_binary_commutative(a, b, BinaryOpcode::Min)
333        }
334    }
335    /// Builds an `max` node
336    /// ```
337    /// # let mut ctx = fidget_core::context::Context::new();
338    /// let x = ctx.x();
339    /// let op = ctx.max(x, 5.0).unwrap();
340    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
341    /// assert_eq!(v, 5.0);
342    /// ```
343    pub fn max<A: IntoNode, B: IntoNode>(
344        &mut self,
345        a: A,
346        b: B,
347    ) -> Result<Node, Error> {
348        let a = a.into_node(self)?;
349        let b = b.into_node(self)?;
350        if a == b {
351            Ok(a)
352        } else {
353            self.op_binary_commutative(a, b, BinaryOpcode::Max)
354        }
355    }
356
357    /// Builds an `and` node
358    ///
359    /// If both arguments are non-zero, returns the right-hand argument.
360    /// Otherwise, returns zero.
361    ///
362    /// This node can be simplified using a tracing evaluator:
363    /// - If the left-hand argument is zero, simplify to just that argument
364    /// - If the left-hand argument is non-zero, simplify to the other argument
365    /// ```
366    /// # let mut ctx = fidget_core::context::Context::new();
367    /// let x = ctx.x();
368    /// let y = ctx.y();
369    /// let op = ctx.and(x, y).unwrap();
370    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
371    /// assert_eq!(v, 0.0);
372    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
373    /// assert_eq!(v, 1.0);
374    /// let v = ctx.eval_xyz(op, 1.0, 2.0, 0.0).unwrap();
375    /// assert_eq!(v, 2.0);
376    /// ```
377    pub fn and<A: IntoNode, B: IntoNode>(
378        &mut self,
379        a: A,
380        b: B,
381    ) -> Result<Node, Error> {
382        let a = a.into_node(self)?;
383        let b = b.into_node(self)?;
384
385        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
386        if let Op::Const(v) = op_a {
387            if v.0 == 0.0 { Ok(a) } else { Ok(b) }
388        } else {
389            self.op_binary(a, b, BinaryOpcode::And)
390        }
391    }
392
393    /// Builds an `or` node
394    ///
395    /// If the left-hand argument is non-zero, it is returned.  Otherwise, the
396    /// right-hand argument is returned.
397    ///
398    /// This node can be simplified using a tracing evaluator.
399    /// ```
400    /// # let mut ctx = fidget_core::context::Context::new();
401    /// let x = ctx.x();
402    /// let y = ctx.y();
403    /// let op = ctx.or(x, y).unwrap();
404    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
405    /// assert_eq!(v, 1.0);
406    /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
407    /// assert_eq!(v, 0.0);
408    /// let v = ctx.eval_xyz(op, 0.0, 3.0, 0.0).unwrap();
409    /// assert_eq!(v, 3.0);
410    /// ```
411    pub fn or<A: IntoNode, B: IntoNode>(
412        &mut self,
413        a: A,
414        b: B,
415    ) -> Result<Node, Error> {
416        let a = a.into_node(self)?;
417        let b = b.into_node(self)?;
418
419        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
420        let op_b = *self.get_op(b).ok_or(Error::BadNode)?;
421        if let Op::Const(v) = op_a {
422            if v.0 != 0.0 {
423                return Ok(a);
424            } else {
425                return Ok(b);
426            }
427        } else if let Op::Const(v) = op_b {
428            if v.0 == 0.0 {
429                return Ok(a);
430            }
431        }
432        self.op_binary(a, b, BinaryOpcode::Or)
433    }
434
435    /// Builds a logical negation node
436    ///
437    /// The output is 1 if the argument is 0, and 0 otherwise.
438    pub fn not<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
439        let a = a.into_node(self)?;
440        self.op_unary(a, UnaryOpcode::Not)
441    }
442
443    /// Builds a unary negation node
444    /// ```
445    /// # let mut ctx = fidget_core::context::Context::new();
446    /// let x = ctx.x();
447    /// let op = ctx.neg(x).unwrap();
448    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
449    /// assert_eq!(v, -2.0);
450    /// ```
451    pub fn neg<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
452        let a = a.into_node(self)?;
453        self.op_unary(a, UnaryOpcode::Neg)
454    }
455
456    /// Builds a reciprocal node
457    /// ```
458    /// # let mut ctx = fidget_core::context::Context::new();
459    /// let x = ctx.x();
460    /// let op = ctx.recip(x).unwrap();
461    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
462    /// assert_eq!(v, 0.5);
463    /// ```
464    pub fn recip<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
465        let a = a.into_node(self)?;
466        self.op_unary(a, UnaryOpcode::Recip)
467    }
468
469    /// Builds a node which calculates the absolute value of its input
470    /// ```
471    /// # let mut ctx = fidget_core::context::Context::new();
472    /// let x = ctx.x();
473    /// let op = ctx.abs(x).unwrap();
474    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
475    /// assert_eq!(v, 2.0);
476    /// let v = ctx.eval_xyz(op, -2.0, 0.0, 0.0).unwrap();
477    /// assert_eq!(v, 2.0);
478    /// ```
479    pub fn abs<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
480        let a = a.into_node(self)?;
481        self.op_unary(a, UnaryOpcode::Abs)
482    }
483
484    /// Builds a node which calculates the square root of its input
485    /// ```
486    /// # let mut ctx = fidget_core::context::Context::new();
487    /// let x = ctx.x();
488    /// let op = ctx.sqrt(x).unwrap();
489    /// let v = ctx.eval_xyz(op, 4.0, 0.0, 0.0).unwrap();
490    /// assert_eq!(v, 2.0);
491    /// ```
492    pub fn sqrt<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
493        let a = a.into_node(self)?;
494        self.op_unary(a, UnaryOpcode::Sqrt)
495    }
496
497    /// Builds a node which calculates the sine of its input (in radians)
498    /// ```
499    /// # let mut ctx = fidget_core::context::Context::new();
500    /// let x = ctx.x();
501    /// let op = ctx.sin(x).unwrap();
502    /// let v = ctx.eval_xyz(op, std::f64::consts::PI / 2.0, 0.0, 0.0).unwrap();
503    /// assert_eq!(v, 1.0);
504    /// ```
505    pub fn sin<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
506        let a = a.into_node(self)?;
507        self.op_unary(a, UnaryOpcode::Sin)
508    }
509
510    /// Builds a node which calculates the cosine of its input (in radians)
511    pub fn cos<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
512        let a = a.into_node(self)?;
513        self.op_unary(a, UnaryOpcode::Cos)
514    }
515
516    /// Builds a node which calculates the tangent of its input (in radians)
517    pub fn tan<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
518        let a = a.into_node(self)?;
519        self.op_unary(a, UnaryOpcode::Tan)
520    }
521
522    /// Builds a node which calculates the arcsine of its input (in radians)
523    pub fn asin<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
524        let a = a.into_node(self)?;
525        self.op_unary(a, UnaryOpcode::Asin)
526    }
527
528    /// Builds a node which calculates the arccosine of its input (in radians)
529    pub fn acos<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
530        let a = a.into_node(self)?;
531        self.op_unary(a, UnaryOpcode::Acos)
532    }
533
534    /// Builds a node which calculates the arctangent of its input (in radians)
535    pub fn atan<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
536        let a = a.into_node(self)?;
537        self.op_unary(a, UnaryOpcode::Atan)
538    }
539
540    /// Builds a node which calculates the exponent of its input
541    pub fn exp<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
542        let a = a.into_node(self)?;
543        self.op_unary(a, UnaryOpcode::Exp)
544    }
545
546    /// Builds a node which calculates the natural log of its input
547    pub fn ln<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
548        let a = a.into_node(self)?;
549        self.op_unary(a, UnaryOpcode::Ln)
550    }
551
552    ////////////////////////////////////////////////////////////////////////////
553    // Derived functions
554    /// Builds a node which squares its input
555    /// ```
556    /// # let mut ctx = fidget_core::context::Context::new();
557    /// let x = ctx.x();
558    /// let op = ctx.square(x).unwrap();
559    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
560    /// assert_eq!(v, 4.0);
561    /// ```
562    pub fn square<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
563        let a = a.into_node(self)?;
564        self.op_unary(a, UnaryOpcode::Square)
565    }
566
567    /// Builds a node which takes the floor of its input
568    /// ```
569    /// # let mut ctx = fidget_core::context::Context::new();
570    /// let x = ctx.x();
571    /// let op = ctx.floor(x).unwrap();
572    /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
573    /// assert_eq!(v, 1.0);
574    /// ```
575    pub fn floor<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
576        let a = a.into_node(self)?;
577        self.op_unary(a, UnaryOpcode::Floor)
578    }
579
580    /// Builds a node which takes the ceiling of its input
581    /// ```
582    /// # let mut ctx = fidget_core::context::Context::new();
583    /// let x = ctx.x();
584    /// let op = ctx.ceil(x).unwrap();
585    /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
586    /// assert_eq!(v, 2.0);
587    /// ```
588    pub fn ceil<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
589        let a = a.into_node(self)?;
590        self.op_unary(a, UnaryOpcode::Ceil)
591    }
592
593    /// Builds a node which rounds its input to the nearest integer
594    /// ```
595    /// # let mut ctx = fidget_core::context::Context::new();
596    /// let x = ctx.x();
597    /// let op = ctx.round(x).unwrap();
598    /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
599    /// assert_eq!(v, 1.0);
600    /// let v = ctx.eval_xyz(op, 1.6, 0.0, 0.0).unwrap();
601    /// assert_eq!(v, 2.0);
602    /// let v = ctx.eval_xyz(op, 1.5, 0.0, 0.0).unwrap();
603    /// assert_eq!(v, 2.0); // rounds away from 0.0 if ambiguous
604    /// ```
605    pub fn round<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
606        let a = a.into_node(self)?;
607        self.op_unary(a, UnaryOpcode::Round)
608    }
609
610    /// Builds a node which performs subtraction.
611    /// ```
612    /// # let mut ctx = fidget_core::context::Context::new();
613    /// let x = ctx.x();
614    /// let y = ctx.y();
615    /// let op = ctx.sub(x, y).unwrap();
616    /// let v = ctx.eval_xyz(op, 3.0, 2.0, 0.0).unwrap();
617    /// assert_eq!(v, 1.0);
618    /// ```
619    pub fn sub<A: IntoNode, B: IntoNode>(
620        &mut self,
621        a: A,
622        b: B,
623    ) -> Result<Node, Error> {
624        let a = a.into_node(self)?;
625        let b = b.into_node(self)?;
626
627        match (self.get_const(a), self.get_const(b)) {
628            (Ok(0.0), _) => self.neg(b),
629            (_, Ok(0.0)) => Ok(a),
630            _ => self.op_binary(a, b, BinaryOpcode::Sub),
631        }
632    }
633
634    /// Builds a node which performs division.
635    /// ```
636    /// # let mut ctx = fidget_core::context::Context::new();
637    /// let x = ctx.x();
638    /// let y = ctx.y();
639    /// let op = ctx.div(x, y).unwrap();
640    /// let v = ctx.eval_xyz(op, 3.0, 2.0, 0.0).unwrap();
641    /// assert_eq!(v, 1.5);
642    /// ```
643    pub fn div<A: IntoNode, B: IntoNode>(
644        &mut self,
645        a: A,
646        b: B,
647    ) -> Result<Node, Error> {
648        let a = a.into_node(self)?;
649        let b = b.into_node(self)?;
650
651        match (self.get_const(a), self.get_const(b)) {
652            (Ok(0.0), _) => Ok(a),
653            (_, Ok(1.0)) => Ok(a),
654            _ => self.op_binary(a, b, BinaryOpcode::Div),
655        }
656    }
657
658    /// Builds a node which computes `atan2(y, x)`
659    /// ```
660    /// # let mut ctx = fidget_core::context::Context::new();
661    /// let x = ctx.x();
662    /// let y = ctx.y();
663    /// let op = ctx.atan2(y, x).unwrap();
664    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
665    /// assert_eq!(v, std::f64::consts::FRAC_PI_2);
666    /// ```
667    pub fn atan2<A: IntoNode, B: IntoNode>(
668        &mut self,
669        y: A,
670        x: B,
671    ) -> Result<Node, Error> {
672        let y = y.into_node(self)?;
673        let x = x.into_node(self)?;
674
675        self.op_binary(y, x, BinaryOpcode::Atan)
676    }
677
678    /// Builds a node that compares two values
679    ///
680    /// The result is -1 if `a < b`, +1 if `a > b`, 0 if `a == b`, and `NaN` if
681    /// either side is `NaN`.
682    /// ```
683    /// # let mut ctx = fidget_core::context::Context::new();
684    /// let x = ctx.x();
685    /// let op = ctx.compare(x, 1.0).unwrap();
686    /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
687    /// assert_eq!(v, -1.0);
688    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
689    /// assert_eq!(v, 1.0);
690    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
691    /// assert_eq!(v, 0.0);
692    /// ```
693    pub fn compare<A: IntoNode, B: IntoNode>(
694        &mut self,
695        a: A,
696        b: B,
697    ) -> Result<Node, Error> {
698        let a = a.into_node(self)?;
699        let b = b.into_node(self)?;
700        self.op_binary(a, b, BinaryOpcode::Compare)
701    }
702
703    /// Builds a node that is 1 if `lhs < rhs` and 0 otherwise
704    ///
705    /// ```
706    /// # let mut ctx = fidget_core::context::Context::new();
707    /// let x = ctx.x();
708    /// let y = ctx.y();
709    /// let op = ctx.less_than(x, y).unwrap();
710    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
711    /// assert_eq!(v, 1.0);
712    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
713    /// assert_eq!(v, 0.0);
714    /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
715    /// assert_eq!(v, 0.0);
716    /// ```
717    pub fn less_than<A: IntoNode, B: IntoNode>(
718        &mut self,
719        lhs: A,
720        rhs: B,
721    ) -> Result<Node, Error> {
722        let lhs = lhs.into_node(self)?;
723        let rhs = rhs.into_node(self)?;
724        let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
725        self.max(cmp, 0.0)
726    }
727
728    /// Builds a node that is 1 if `lhs <= rhs` and 0 otherwise
729    ///
730    /// ```
731    /// # let mut ctx = fidget_core::context::Context::new();
732    /// let x = ctx.x();
733    /// let y = ctx.y();
734    /// let op = ctx.less_than_or_equal(x, y).unwrap();
735    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
736    /// assert_eq!(v, 1.0);
737    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
738    /// assert_eq!(v, 1.0);
739    /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
740    /// assert_eq!(v, 0.0);
741    /// ```
742    pub fn less_than_or_equal<A: IntoNode, B: IntoNode>(
743        &mut self,
744        lhs: A,
745        rhs: B,
746    ) -> Result<Node, Error> {
747        let lhs = lhs.into_node(self)?;
748        let rhs = rhs.into_node(self)?;
749        let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
750        let shift = self.add(cmp, 1.0)?;
751        self.min(shift, 1.0)
752    }
753
754    /// Builds a node that takes the modulo (least non-negative remainder)
755    pub fn modulo<A: IntoNode, B: IntoNode>(
756        &mut self,
757        a: A,
758        b: B,
759    ) -> Result<Node, Error> {
760        let a = a.into_node(self)?;
761        let b = b.into_node(self)?;
762        self.op_binary(a, b, BinaryOpcode::Mod)
763    }
764
765    /// Builds a node that returns the first node if the condition is not
766    /// equal to zero, else returns the other node
767    ///
768    /// The result is `a` if `condition != 0`, else the result is `b`.
769    /// ```
770    /// # let mut ctx = fidget_core::context::Context::new();
771    /// let x = ctx.x();
772    /// let y = ctx.y();
773    /// let z = ctx.z();
774    ///
775    /// let if_else = ctx.if_nonzero_else(x, y, z).unwrap();
776    ///
777    /// assert_eq!(ctx.eval_xyz(if_else, 0.0, 2.0, 3.0).unwrap(), 3.0);
778    /// assert_eq!(ctx.eval_xyz(if_else, 1.0, 2.0, 3.0).unwrap(), 2.0);
779    /// assert_eq!(ctx.eval_xyz(if_else, 0.0, f64::NAN, 3.0).unwrap(), 3.0);
780    /// assert_eq!(ctx.eval_xyz(if_else, 1.0, 2.0, f64::NAN).unwrap(), 2.0);
781    /// ```
782    pub fn if_nonzero_else<Condition: IntoNode, A: IntoNode, B: IntoNode>(
783        &mut self,
784        condition: Condition,
785        a: A,
786        b: B,
787    ) -> Result<Node, Error> {
788        let condition = condition.into_node(self)?;
789        let a = a.into_node(self)?;
790        let b = b.into_node(self)?;
791
792        let lhs = self.and(condition, a)?;
793        let n_condition = self.not(condition)?;
794        let rhs = self.and(n_condition, b)?;
795        self.or(lhs, rhs)
796    }
797
798    ////////////////////////////////////////////////////////////////////////////
799    /// Evaluates the given node with the provided values for X, Y, and Z.
800    ///
801    /// This is extremely inefficient; consider converting the node into a
802    /// [`Shape`](crate::shape::Shape) and using its evaluators instead.
803    ///
804    /// ```
805    /// # let mut ctx = fidget_core::context::Context::new();
806    /// let x = ctx.x();
807    /// let y = ctx.y();
808    /// let z = ctx.z();
809    /// let op = ctx.mul(x, y).unwrap();
810    /// let op = ctx.div(op, z).unwrap();
811    /// let v = ctx.eval_xyz(op, 3.0, 5.0, 2.0).unwrap();
812    /// assert_eq!(v, 7.5); // (3.0 * 5.0) / 2.0
813    /// ```
814    pub fn eval_xyz(
815        &self,
816        root: Node,
817        x: f64,
818        y: f64,
819        z: f64,
820    ) -> Result<f64, Error> {
821        let vars = [(Var::X, x), (Var::Y, y), (Var::Z, z)]
822            .into_iter()
823            .collect();
824        self.eval(root, &vars)
825    }
826
827    /// Evaluates the given node with a generic set of variables
828    ///
829    /// This is extremely inefficient; consider converting the node into a
830    /// [`Shape`](crate::shape::Shape) and using its evaluators instead.
831    pub fn eval(
832        &self,
833        root: Node,
834        vars: &HashMap<Var, f64>,
835    ) -> Result<f64, Error> {
836        let mut cache = vec![None; self.ops.len()].into();
837        self.eval_inner(root, vars, &mut cache)
838    }
839
840    fn eval_inner(
841        &self,
842        node: Node,
843        vars: &HashMap<Var, f64>,
844        cache: &mut IndexVec<Option<f64>, Node>,
845    ) -> Result<f64, Error> {
846        if node.0 >= cache.len() {
847            return Err(Error::BadNode);
848        }
849        if let Some(v) = cache[node] {
850            return Ok(v);
851        }
852        let mut get = |n: Node| self.eval_inner(n, vars, cache);
853        let v = match self.get_op(node).ok_or(Error::BadNode)? {
854            Op::Input(v) => *vars.get(v).ok_or(Error::MissingVar(*v))?,
855            Op::Const(c) => c.0,
856
857            Op::Binary(op, a, b) => {
858                let a = get(*a)?;
859                let b = get(*b)?;
860                match op {
861                    BinaryOpcode::Add => a + b,
862                    BinaryOpcode::Sub => a - b,
863                    BinaryOpcode::Mul => a * b,
864                    BinaryOpcode::Div => a / b,
865                    BinaryOpcode::Atan => a.atan2(b),
866                    BinaryOpcode::Min => a.min(b),
867                    BinaryOpcode::Max => a.max(b),
868                    BinaryOpcode::Compare => a
869                        .partial_cmp(&b)
870                        .map(|i| i as i8 as f64)
871                        .unwrap_or(f64::NAN),
872                    BinaryOpcode::Mod => a.rem_euclid(b),
873                    BinaryOpcode::And => {
874                        if a == 0.0 {
875                            a
876                        } else {
877                            b
878                        }
879                    }
880                    BinaryOpcode::Or => {
881                        if a != 0.0 {
882                            a
883                        } else {
884                            b
885                        }
886                    }
887                }
888            }
889
890            // Unary operations
891            Op::Unary(op, a) => {
892                let a = get(*a)?;
893                match op {
894                    UnaryOpcode::Neg => -a,
895                    UnaryOpcode::Abs => a.abs(),
896                    UnaryOpcode::Recip => 1.0 / a,
897                    UnaryOpcode::Sqrt => a.sqrt(),
898                    UnaryOpcode::Square => a * a,
899                    UnaryOpcode::Floor => a.floor(),
900                    UnaryOpcode::Ceil => a.ceil(),
901                    UnaryOpcode::Round => a.round(),
902                    UnaryOpcode::Sin => a.sin(),
903                    UnaryOpcode::Cos => a.cos(),
904                    UnaryOpcode::Tan => a.tan(),
905                    UnaryOpcode::Asin => a.asin(),
906                    UnaryOpcode::Acos => a.acos(),
907                    UnaryOpcode::Atan => a.atan(),
908                    UnaryOpcode::Exp => a.exp(),
909                    UnaryOpcode::Ln => a.ln(),
910                    UnaryOpcode::Not => (a == 0.0).into(),
911                }
912            }
913        };
914
915        cache[node] = Some(v);
916        Ok(v)
917    }
918
919    /// Parses a flat text representation of a math tree. For example, the
920    /// circle `(- (+ (square x) (square y)) 1)` can be parsed from
921    /// ```
922    /// # use fidget_core::context::Context;
923    /// let txt = "
924    /// ## This is a comment!
925    /// 0x600000b90000 var-x
926    /// 0x600000b900a0 square 0x600000b90000
927    /// 0x600000b90050 var-y
928    /// 0x600000b900f0 square 0x600000b90050
929    /// 0x600000b90140 add 0x600000b900a0 0x600000b900f0
930    /// 0x600000b90190 sqrt 0x600000b90140
931    /// 0x600000b901e0 const 1
932    /// ";
933    /// let (ctx, _node) = Context::from_text(&mut txt.as_bytes()).unwrap();
934    /// assert_eq!(ctx.len(), 7);
935    /// ```
936    ///
937    /// This representation is loosely defined and only intended for use in
938    /// quick experiments.
939    pub fn from_text<R: Read>(r: R) -> Result<(Self, Node), Error> {
940        let reader = BufReader::new(r);
941        let mut ctx = Self::new();
942        let mut seen = BTreeMap::new();
943        let mut last = None;
944
945        for line in reader.lines().map(|line| line.unwrap()) {
946            if line.is_empty() || line.starts_with('#') {
947                continue;
948            }
949            let mut iter = line.split_whitespace();
950            let i: String = iter.next().unwrap().to_owned();
951            let opcode = iter.next().unwrap();
952
953            let mut pop = || {
954                let txt = iter.next().unwrap();
955                seen.get(txt)
956                    .cloned()
957                    .ok_or_else(|| Error::UnknownVariable(txt.to_string()))
958            };
959            let node = match opcode {
960                "const" => ctx.constant(iter.next().unwrap().parse().unwrap()),
961                "var-x" => ctx.x(),
962                "var-y" => ctx.y(),
963                "var-z" => ctx.z(),
964                "abs" => ctx.abs(pop()?)?,
965                "neg" => ctx.neg(pop()?)?,
966                "sqrt" => ctx.sqrt(pop()?)?,
967                "square" => ctx.square(pop()?)?,
968                "floor" => ctx.floor(pop()?)?,
969                "ceil" => ctx.ceil(pop()?)?,
970                "round" => ctx.round(pop()?)?,
971                "sin" => ctx.sin(pop()?)?,
972                "cos" => ctx.cos(pop()?)?,
973                "tan" => ctx.tan(pop()?)?,
974                "asin" => ctx.asin(pop()?)?,
975                "acos" => ctx.acos(pop()?)?,
976                "atan" => ctx.atan(pop()?)?,
977                "ln" => ctx.ln(pop()?)?,
978                "not" => ctx.not(pop()?)?,
979                "exp" => ctx.exp(pop()?)?,
980                "add" => ctx.add(pop()?, pop()?)?,
981                "mul" => ctx.mul(pop()?, pop()?)?,
982                "min" => ctx.min(pop()?, pop()?)?,
983                "max" => ctx.max(pop()?, pop()?)?,
984                "div" => ctx.div(pop()?, pop()?)?,
985                "atan2" => ctx.atan2(pop()?, pop()?)?,
986                "sub" => ctx.sub(pop()?, pop()?)?,
987                "compare" => ctx.compare(pop()?, pop()?)?,
988                "mod" => ctx.modulo(pop()?, pop()?)?,
989                "and" => ctx.and(pop()?, pop()?)?,
990                "or" => ctx.or(pop()?, pop()?)?,
991                op => return Err(Error::UnknownOpcode(op.to_owned())),
992            };
993            seen.insert(i, node);
994            last = Some(node);
995        }
996        match last {
997            Some(node) => Ok((ctx, node)),
998            None => Err(Error::EmptyFile),
999        }
1000    }
1001
1002    /// Converts the entire context into a GraphViz drawing
1003    pub fn dot(&self) -> String {
1004        let mut out = "digraph mygraph{\n".to_owned();
1005        for node in self.ops.keys() {
1006            let op = self.get_op(node).unwrap();
1007            out += &self.dot_node(node);
1008            out += &op.dot_edges(node);
1009        }
1010        out += "}\n";
1011        out
1012    }
1013
1014    /// Converts the given node into a GraphViz node
1015    ///
1016    /// (this is a local function instead of a function on `Op` because it
1017    ///  requires looking up variables by name)
1018    fn dot_node(&self, i: Node) -> String {
1019        let mut out = format!(r#"n{} [label = ""#, i.get());
1020        let op = self.get_op(i).unwrap();
1021        match op {
1022            Op::Const(c) => write!(out, "{c}").unwrap(),
1023            Op::Input(v) => {
1024                out += &v.to_string();
1025            }
1026            Op::Binary(op, ..) => match op {
1027                BinaryOpcode::Add => out += "add",
1028                BinaryOpcode::Sub => out += "sub",
1029                BinaryOpcode::Mul => out += "mul",
1030                BinaryOpcode::Div => out += "div",
1031                BinaryOpcode::Atan => out += "atan2",
1032                BinaryOpcode::Min => out += "min",
1033                BinaryOpcode::Max => out += "max",
1034                BinaryOpcode::Compare => out += "compare",
1035                BinaryOpcode::Mod => out += "mod",
1036                BinaryOpcode::And => out += "and",
1037                BinaryOpcode::Or => out += "or",
1038            },
1039            Op::Unary(op, ..) => match op {
1040                UnaryOpcode::Neg => out += "neg",
1041                UnaryOpcode::Abs => out += "abs",
1042                UnaryOpcode::Recip => out += "recip",
1043                UnaryOpcode::Sqrt => out += "sqrt",
1044                UnaryOpcode::Square => out += "square",
1045                UnaryOpcode::Floor => out += "floor",
1046                UnaryOpcode::Ceil => out += "ceil",
1047                UnaryOpcode::Round => out += "round",
1048                UnaryOpcode::Sin => out += "sin",
1049                UnaryOpcode::Cos => out += "cos",
1050                UnaryOpcode::Tan => out += "tan",
1051                UnaryOpcode::Asin => out += "asin",
1052                UnaryOpcode::Acos => out += "acos",
1053                UnaryOpcode::Atan => out += "atan",
1054                UnaryOpcode::Exp => out += "exp",
1055                UnaryOpcode::Ln => out += "ln",
1056                UnaryOpcode::Not => out += "not",
1057            },
1058        };
1059        write!(
1060            out,
1061            r#"" color="{0}1" shape="{1}" fontcolor="{0}4"]"#,
1062            op.dot_node_color(),
1063            op.dot_node_shape()
1064        )
1065        .unwrap();
1066        out
1067    }
1068
    /// Looks up an operation by `Node` handle
    ///
    /// Returns `None` when the lookup in `self.ops` finds nothing for `node`;
    /// other methods in this file treat that case as [`Error::BadNode`].
    pub fn get_op(&self, node: Node) -> Option<&Op> {
        self.ops.get_by_index(node)
    }
1073
1074    /// Imports the given tree, deduplicating and returning the root
1075    pub fn import(&mut self, tree: &Tree) -> Node {
1076        // A naive remapping implementation would use recursion.  A naive
1077        // remapping implementation would blow up the stack given any
1078        // significant tree size.
1079        //
1080        // Instead, we maintain our own pseudo-stack here in a pair of Vecs (one
1081        // stack for actions, and a second stack for return values).
1082        enum Action<'a> {
1083            /// Pushes `Up(op)` followed by `Down(c)` for each child
1084            Down(&'a Arc<TreeOp>),
1085            /// Consumes imported trees from the stack and pushes a new tree
1086            Up(&'a Arc<TreeOp>),
1087            /// Pops the latest axis frame
1088            Pop,
1089            /// Pops the latest affine frame
1090            PopAffine,
1091        }
1092        let mut axes = vec![(self.x(), self.y(), self.z())];
1093        let mut todo = vec![Action::Down(tree.arc())];
1094        let mut stack = vec![];
1095        let mut affine: Vec<Matrix4<f64>> = vec![];
1096
1097        // Cache of TreeOp -> Node mapping under a particular frame (axes)
1098        //
1099        // This isn't required for correctness, but can be a speed optimization
1100        // (because it means we don't have to walk the same tree twice).
1101        let mut seen = HashMap::new();
1102
1103        while let Some(t) = todo.pop() {
1104            match t {
1105                Action::Down(t) => {
1106                    // If we've already seen this TreeOp with these axes, then
1107                    // we can return the previous Node.
1108                    if matches!(
1109                        t.as_ref(),
1110                        TreeOp::Unary(..) | TreeOp::Binary(..)
1111                    ) {
1112                        if let Some(p) =
1113                            seen.get(&(*axes.last().unwrap(), Arc::as_ptr(t)))
1114                        {
1115                            stack.push(*p);
1116                            continue;
1117                        }
1118                    }
1119                    match t.as_ref() {
1120                        TreeOp::Const(c) => {
1121                            stack.push(self.constant(*c));
1122                        }
1123                        TreeOp::Input(s) => {
1124                            let axes = axes.last().unwrap();
1125                            stack.push(match *s {
1126                                Var::X => axes.0,
1127                                Var::Y => axes.1,
1128                                Var::Z => axes.2,
1129                                v @ Var::V(..) => self.var(v),
1130                            });
1131                        }
1132                        TreeOp::Unary(_op, arg) => {
1133                            todo.push(Action::Up(t));
1134                            todo.push(Action::Down(arg));
1135                        }
1136                        TreeOp::Binary(_op, lhs, rhs) => {
1137                            todo.push(Action::Up(t));
1138                            todo.push(Action::Down(lhs));
1139                            todo.push(Action::Down(rhs));
1140                        }
1141                        TreeOp::RemapAxes { target: _, x, y, z } => {
1142                            // Action::Up(t) does the remapping and target eval
1143                            todo.push(Action::Up(t));
1144                            todo.push(Action::Down(x));
1145                            todo.push(Action::Down(y));
1146                            todo.push(Action::Down(z));
1147                        }
1148                        TreeOp::RemapAffine { target, mat } => {
1149                            let prev = affine
1150                                .last()
1151                                .cloned()
1152                                .unwrap_or(Matrix4::identity());
1153                            let mat = prev * mat.to_homogeneous();
1154
1155                            // Push either an affine frame or an axis frame,
1156                            // depending on whether the target is also affine
1157                            if matches!(&**target, TreeOp::RemapAffine { .. }) {
1158                                affine.push(mat);
1159                                todo.push(Action::PopAffine);
1160                            } else {
1161                                let (x, y, z) = axes.last().unwrap();
1162                                let mut out = [None; 3];
1163                                for i in 0..3 {
1164                                    let a = self.mul(mat[(i, 0)], *x).unwrap();
1165                                    let b = self.mul(mat[(i, 1)], *y).unwrap();
1166                                    let c = self.mul(mat[(i, 2)], *z).unwrap();
1167                                    let d = self.constant(mat[(i, 3)]);
1168                                    let ab = self.add(a, b).unwrap();
1169                                    let cd = self.add(c, d).unwrap();
1170                                    out[i] = Some(self.add(ab, cd).unwrap());
1171                                }
1172                                let [x, y, z] = out.map(Option::unwrap);
1173                                axes.push((x, y, z));
1174                                todo.push(Action::Pop);
1175                            }
1176                            todo.push(Action::Down(target));
1177                        }
1178                    }
1179                }
1180                Action::Up(t) => {
1181                    match t.as_ref() {
1182                        TreeOp::Const(..)
1183                        | TreeOp::Input(..)
1184                        | TreeOp::RemapAffine { .. } => unreachable!(),
1185                        TreeOp::Unary(op, ..) => {
1186                            let arg = stack.pop().unwrap();
1187                            let out = self.op_unary(arg, *op).unwrap();
1188                            stack.push(out);
1189                        }
1190                        TreeOp::Binary(op, ..) => {
1191                            let lhs = stack.pop().unwrap();
1192                            let rhs = stack.pop().unwrap();
1193                            // Call individual builders to apply optimizations
1194                            let out = match op {
1195                                BinaryOpcode::Add => self.add(lhs, rhs),
1196                                BinaryOpcode::Sub => self.sub(lhs, rhs),
1197                                BinaryOpcode::Mul => self.mul(lhs, rhs),
1198                                BinaryOpcode::Div => self.div(lhs, rhs),
1199                                BinaryOpcode::Atan => self.atan2(lhs, rhs),
1200                                BinaryOpcode::Min => self.min(lhs, rhs),
1201                                BinaryOpcode::Max => self.max(lhs, rhs),
1202                                BinaryOpcode::Compare => self.compare(lhs, rhs),
1203                                BinaryOpcode::Mod => self.modulo(lhs, rhs),
1204                                BinaryOpcode::And => self.and(lhs, rhs),
1205                                BinaryOpcode::Or => self.or(lhs, rhs),
1206                            }
1207                            .unwrap();
1208                            if Arc::strong_count(t) > 1 {
1209                                seen.insert(
1210                                    (*axes.last().unwrap(), Arc::as_ptr(t)),
1211                                    out,
1212                                );
1213                            }
1214                            stack.push(out);
1215                        }
1216                        TreeOp::RemapAxes { target, .. } => {
1217                            let x = stack.pop().unwrap();
1218                            let y = stack.pop().unwrap();
1219                            let z = stack.pop().unwrap();
1220                            axes.push((x, y, z));
1221                            todo.push(Action::Pop);
1222                            todo.push(Action::Down(target));
1223                        }
1224                    }
1225                    // Update the cache with the new tree, if relevant
1226                    //
1227                    // The `strong_count` check is a rough heuristic to avoid
1228                    // caching if there's only a single owner of the tree.  This
1229                    // isn't perfect, but it doesn't need to be for correctness.
1230                    if matches!(
1231                        t.as_ref(),
1232                        TreeOp::Unary(..) | TreeOp::Binary(..)
1233                    ) && Arc::strong_count(t) > 1
1234                    {
1235                        seen.insert(
1236                            (*axes.last().unwrap(), Arc::as_ptr(t)),
1237                            *stack.last().unwrap(),
1238                        );
1239                    }
1240                }
1241                Action::Pop => {
1242                    axes.pop().unwrap();
1243                }
1244                Action::PopAffine => {
1245                    affine.pop().unwrap();
1246                }
1247            }
1248        }
1249        assert_eq!(stack.len(), 1);
1250        stack.pop().unwrap()
1251    }
1252
1253    /// Converts from a context-specific node into a standalone [`Tree`]
1254    pub fn export(&self, n: Node) -> Result<Tree, Error> {
1255        if self.get_op(n).is_none() {
1256            return Err(Error::BadNode);
1257        }
1258
1259        // Do recursion on the heap to avoid stack overflows for deep trees
1260        enum Action {
1261            /// Pushes `Up(n)` followed by `Down(n)` for each child
1262            Down(Node),
1263            /// Consumes trees from the stack and pushes a new tree
1264            Up(Node, Op),
1265        }
1266        let mut todo = vec![Action::Down(n)];
1267        let mut stack = vec![];
1268
1269        // Cache of Node -> Tree mapping, for Tree deduplication
1270        let mut seen: HashMap<Node, Tree> = HashMap::new();
1271
1272        while let Some(t) = todo.pop() {
1273            match t {
1274                Action::Down(n) => {
1275                    // If we've already seen this TreeOp with these axes, then
1276                    // we can return the previous Node.
1277                    if let Some(p) = seen.get(&n) {
1278                        stack.push(p.clone());
1279                        continue;
1280                    }
1281                    let op = self.get_op(n).unwrap();
1282                    match op {
1283                        Op::Const(c) => {
1284                            let t = Tree::from(c.0);
1285                            seen.insert(n, t.clone());
1286                            stack.push(t);
1287                        }
1288                        Op::Input(v) => {
1289                            let t = Tree::from(*v);
1290                            seen.insert(n, t.clone());
1291                            stack.push(t);
1292                        }
1293                        Op::Unary(_op, arg) => {
1294                            todo.push(Action::Up(n, *op));
1295                            todo.push(Action::Down(*arg));
1296                        }
1297                        Op::Binary(_op, lhs, rhs) => {
1298                            todo.push(Action::Up(n, *op));
1299                            todo.push(Action::Down(*lhs));
1300                            todo.push(Action::Down(*rhs));
1301                        }
1302                    }
1303                }
1304                Action::Up(n, op) => match op {
1305                    Op::Const(..) | Op::Input(..) => unreachable!(),
1306                    Op::Unary(op, ..) => {
1307                        let arg = stack.pop().unwrap();
1308                        let out =
1309                            Tree::from(TreeOp::Unary(op, arg.arc().clone()));
1310                        seen.insert(n, out.clone());
1311                        stack.push(out);
1312                    }
1313                    Op::Binary(op, ..) => {
1314                        let lhs = stack.pop().unwrap();
1315                        let rhs = stack.pop().unwrap();
1316                        let out = Tree::from(TreeOp::Binary(
1317                            op,
1318                            lhs.arc().clone(),
1319                            rhs.arc().clone(),
1320                        ));
1321                        seen.insert(n, out.clone());
1322                        stack.push(out);
1323                    }
1324                },
1325            }
1326        }
1327        assert_eq!(stack.len(), 1);
1328        Ok(stack.pop().unwrap())
1329    }
1330
    /// Takes the symbolic derivative of a node with respect to a variable
    ///
    /// The expression is walked with an explicit work stack; the derivative of
    /// each node is built from the derivatives of its arguments (chain rule),
    /// and is cached so that a shared node is only differentiated once.
    ///
    /// Operations whose output is piecewise-constant (`floor`, `ceil`,
    /// `round`, `not`, `compare`) are given a derivative of zero.
    ///
    /// # Errors
    /// Returns [`Error::BadNode`] if `n` has no operation in this context.
    pub fn deriv(&mut self, n: Node, v: Var) -> Result<Node, Error> {
        if self.get_op(n).is_none() {
            return Err(Error::BadNode);
        }

        // Do recursion on the heap to avoid stack overflows for deep trees
        enum Action {
            /// Pushes `Up(n)` followed by `Down(n)` for each child
            Down(Node),
            /// Consumes trees from the stack and pushes a new tree
            Up(Node, Op),
        }
        let mut todo = vec![Action::Down(n)];
        let mut stack = vec![];
        let zero = self.constant(0.0);

        // Cache of Node -> Node mapping, for deduplication
        let mut seen: HashMap<Node, Node> = HashMap::new();

        while let Some(t) = todo.pop() {
            match t {
                Action::Down(n) => {
                    // If we've already differentiated this node, then we can
                    // reuse the previous result.
                    if let Some(p) = seen.get(&n) {
                        stack.push(*p);
                        continue;
                    }
                    let op = *self.get_op(n).unwrap();
                    match op {
                        // d(const) = 0
                        Op::Const(_c) => {
                            seen.insert(n, zero);
                            stack.push(zero);
                        }
                        // d(u)/dv is 1 if u is the variable of interest,
                        // and 0 otherwise
                        Op::Input(u) => {
                            let z =
                                if v == u { self.constant(1.0) } else { zero };
                            seen.insert(n, z);
                            stack.push(z);
                        }
                        Op::Unary(_op, arg) => {
                            todo.push(Action::Up(n, op));
                            todo.push(Action::Down(arg));
                        }
                        Op::Binary(_op, lhs, rhs) => {
                            // `rhs` is pushed last, so it is processed first
                            // and `d_lhs` ends up on top of the value stack
                            todo.push(Action::Up(n, op));
                            todo.push(Action::Down(lhs));
                            todo.push(Action::Down(rhs));
                        }
                    }
                }
                // In the arms below, `v_*` is the value node of an argument
                // and `d_*` is the node for that argument's derivative; `n` is
                // the node being differentiated.
                Action::Up(n, op) => match op {
                    Op::Const(..) | Op::Input(..) => unreachable!(),
                    Op::Unary(op, v_arg) => {
                        let d_arg = stack.pop().unwrap();
                        let out = match op {
                            UnaryOpcode::Neg => self.neg(d_arg),
                            // d|a| = -da if a < 0, otherwise da
                            UnaryOpcode::Abs => {
                                let cond = self.less_than(v_arg, zero).unwrap();
                                let pos = d_arg;
                                let neg = self.neg(d_arg).unwrap();
                                self.if_nonzero_else(cond, neg, pos)
                            }
                            // d(1/a) = -da / a²
                            UnaryOpcode::Recip => {
                                let a = self.square(v_arg).unwrap();
                                let b = self.neg(d_arg).unwrap();
                                self.div(b, a)
                            }
                            // d(sqrt(a)) = da / (2 sqrt(a)), reusing `n` as
                            // sqrt(a)
                            UnaryOpcode::Sqrt => {
                                let v = self.mul(n, 2.0).unwrap();
                                self.div(d_arg, v)
                            }
                            // d(a²) = 2 a da
                            UnaryOpcode::Square => {
                                let v = self.mul(d_arg, v_arg).unwrap();
                                self.mul(2.0, v)
                            }
                            // Discontinuous constants don't have Dirac deltas
                            UnaryOpcode::Floor
                            | UnaryOpcode::Ceil
                            | UnaryOpcode::Round => Ok(zero),

                            // d(sin(a)) = cos(a) da
                            UnaryOpcode::Sin => {
                                let c = self.cos(v_arg).unwrap();
                                self.mul(c, d_arg)
                            }

                            // d(cos(a)) = -sin(a) da
                            UnaryOpcode::Cos => {
                                let s = self.sin(v_arg).unwrap();
                                let s = self.neg(s).unwrap();
                                self.mul(s, d_arg)
                            }

                            // d(tan(a)) = da / cos²(a)
                            UnaryOpcode::Tan => {
                                let c = self.cos(v_arg).unwrap();
                                let c = self.square(c).unwrap();
                                self.div(d_arg, c)
                            }

                            // d(asin(a)) = da / sqrt(1 - a²)
                            UnaryOpcode::Asin => {
                                let v = self.square(v_arg).unwrap();
                                let v = self.sub(1.0, v).unwrap();
                                let v = self.sqrt(v).unwrap();
                                self.div(d_arg, v)
                            }
                            // d(acos(a)) = -da / sqrt(1 - a²)
                            UnaryOpcode::Acos => {
                                let v = self.square(v_arg).unwrap();
                                let v = self.sub(1.0, v).unwrap();
                                let v = self.sqrt(v).unwrap();
                                let v = self.neg(v).unwrap();
                                self.div(d_arg, v)
                            }
                            // d(atan(a)) = da / (1 + a²)
                            UnaryOpcode::Atan => {
                                let v = self.square(v_arg).unwrap();
                                let v = self.add(1.0, v).unwrap();
                                self.div(d_arg, v)
                            }
                            // d(exp(a)) = exp(a) da, reusing `n` as exp(a)
                            UnaryOpcode::Exp => self.mul(n, d_arg),
                            // d(ln(a)) = da / a
                            UnaryOpcode::Ln => self.div(d_arg, v_arg),
                            UnaryOpcode::Not => Ok(zero),
                        }
                        .unwrap();
                        seen.insert(n, out);
                        stack.push(out);
                    }
                    Op::Binary(op, v_lhs, v_rhs) => {
                        let d_lhs = stack.pop().unwrap();
                        let d_rhs = stack.pop().unwrap();
                        let out = match op {
                            BinaryOpcode::Add => self.add(d_lhs, d_rhs),
                            BinaryOpcode::Sub => self.sub(d_lhs, d_rhs),
                            // Product rule: d(a b) = da b + a db
                            BinaryOpcode::Mul => {
                                let a = self.mul(d_lhs, v_rhs).unwrap();
                                let b = self.mul(v_lhs, d_rhs).unwrap();
                                self.add(a, b)
                            }
                            // Quotient rule: d(a / b) = (b da - a db) / b²
                            BinaryOpcode::Div => {
                                let v = self.square(v_rhs).unwrap();
                                let a = self.mul(v_rhs, d_lhs).unwrap();
                                let b = self.mul(v_lhs, d_rhs).unwrap();
                                let c = self.sub(a, b).unwrap();
                                self.div(c, v)
                            }
                            // d(atan2(a, b)) = (b da - a db) / (a² + b²)
                            BinaryOpcode::Atan => {
                                let a = self.square(v_lhs).unwrap();
                                let b = self.square(v_rhs).unwrap();
                                let d = self.add(a, b).unwrap();

                                let a = self.mul(v_rhs, d_lhs).unwrap();
                                let b = self.mul(v_lhs, d_rhs).unwrap();
                                let v = self.sub(a, b).unwrap();
                                self.div(v, d)
                            }
                            // Pick the derivative of whichever side is
                            // selected: `lhs` when lhs < rhs, else `rhs`
                            BinaryOpcode::Min => {
                                let cond =
                                    self.less_than(v_lhs, v_rhs).unwrap();
                                self.if_nonzero_else(cond, d_lhs, d_rhs)
                            }
                            // As above, but `lhs` is selected when rhs < lhs
                            BinaryOpcode::Max => {
                                let cond =
                                    self.less_than(v_rhs, v_lhs).unwrap();
                                self.if_nonzero_else(cond, d_lhs, d_rhs)
                            }
                            BinaryOpcode::Compare => Ok(zero),
                            // The result built here is `d_lhs - d_rhs * e`,
                            // where `e` is an integer quotient derived from
                            // floor(lhs / rhs)
                            BinaryOpcode::Mod => {
                                let e = self.div(v_lhs, v_rhs).unwrap();
                                let q = self.floor(e).unwrap();

                                // XXX
                                // (we don't actually have %, so hack it from
                                // `modulo`, which is actually `rem_euclid`)
                                // ???
                                let m = self.modulo(q, v_rhs).unwrap();
                                let cond = self.less_than(q, zero).unwrap();
                                let offset = self
                                    .if_nonzero_else(cond, v_rhs, zero)
                                    .unwrap();
                                let m = self.sub(m, offset).unwrap();

                                // Torn from the div_euclid implementation
                                let outer = self.less_than(m, zero).unwrap();
                                let inner =
                                    self.less_than(zero, v_rhs).unwrap();
                                let qa = self.sub(q, 1.0).unwrap();
                                let qb = self.add(q, 1.0).unwrap();
                                let inner = self
                                    .if_nonzero_else(inner, qa, qb)
                                    .unwrap();
                                let e = self
                                    .if_nonzero_else(outer, inner, q)
                                    .unwrap();

                                let v = self.mul(d_rhs, e).unwrap();
                                self.sub(d_lhs, v)
                            }
                            // Selects `d_rhs` when the comparison of `lhs`
                            // against zero is non-zero, and `d_lhs` otherwise
                            BinaryOpcode::And => {
                                let cond = self.compare(v_lhs, zero).unwrap();
                                self.if_nonzero_else(cond, d_rhs, d_lhs)
                            }
                            // Selects `d_lhs` when the comparison of `lhs`
                            // against zero is non-zero, and `d_rhs` otherwise
                            BinaryOpcode::Or => {
                                let cond = self.compare(v_lhs, zero).unwrap();
                                self.if_nonzero_else(cond, d_lhs, d_rhs)
                            }
                        }
                        .unwrap();
                        seen.insert(n, out);
                        stack.push(out);
                    }
                },
            }
        }
        assert_eq!(stack.len(), 1);
        Ok(stack.pop().unwrap())
    }
1545}
1546
1547////////////////////////////////////////////////////////////////////////////////
1548/// Helper trait for things that can be converted into a [`Node`] given a
1549/// [`Context`].
1550///
1551/// This trait allows you to write
1552/// ```
1553/// # let mut ctx = fidget_core::context::Context::new();
1554/// let x = ctx.x();
1555/// let sum = ctx.add(x, 1.0).unwrap();
1556/// ```
1557/// instead of the more verbose
1558/// ```
1559/// # let mut ctx = fidget_core::context::Context::new();
1560/// let x = ctx.x();
1561/// let num = ctx.constant(1.0);
1562/// let sum = ctx.add(x, num).unwrap();
1563/// ```
1564pub trait IntoNode {
1565    /// Converts the given values into a node
1566    fn into_node(self, ctx: &mut Context) -> Result<Node, Error>;
1567}
1568
1569impl IntoNode for Node {
1570    fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1571        ctx.check_node(self)?;
1572        Ok(self)
1573    }
1574}
1575
1576impl IntoNode for f32 {
1577    fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1578        Ok(ctx.constant(self as f64))
1579    }
1580}
1581
1582impl IntoNode for f64 {
1583    fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1584        Ok(ctx.constant(self))
1585    }
1586}
1587
1588////////////////////////////////////////////////////////////////////////////////
1589
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmData;

    /// Looking up the `x` node must yield an input operation.
    // Kept as a unit test (not a doctest) because it uses a private function
    #[test]
    fn test_get_op() {
        let mut ctx = Context::new();
        let x = ctx.x();
        assert!(matches!(ctx.get_op(x).unwrap(), Op::Input(_)));
    }

    /// Builds `max(0.25 - (x² + y²), (x² + y²) - 0.5)` and checks the size
    /// of the resulting tape.
    #[test]
    fn test_ring() {
        let mut ctx = Context::new();
        let half = ctx.constant(0.5);
        let x = ctx.x();
        let y = ctx.y();
        let x_sq = ctx.square(x).unwrap();
        let y_sq = ctx.square(y).unwrap();
        let sum_sq = ctx.add(x_sq, y_sq).unwrap();
        // (x² + y²) - 0.5
        let outer = ctx.sub(sum_sq, half).unwrap();
        let quarter = ctx.constant(0.25);
        // 0.25 - (x² + y²)
        let inner = ctx.sub(quarter, sum_sq).unwrap();
        let ring = ctx.max(inner, outer).unwrap();

        let tape = VmData::<255>::new(&ctx, &[ring]).unwrap();
        assert_eq!(tape.len(), 9);
        assert_eq!(tape.vars.len(), 2);
    }

    /// Multiplies `x` by itself and checks the tape length and variable count.
    #[test]
    fn test_dupe() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let product = ctx.mul(x, x).unwrap();

        let tape = VmData::<255>::new(&ctx, &[product]).unwrap();
        assert_eq!(tape.len(), 3); // x, square, output
        assert_eq!(tape.vars.len(), 1);
    }

    /// Exports `sin(x) + cos(x)` and checks that both branches point at the
    /// very same `x` leaf in the exported tree.
    #[test]
    fn test_export() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let sin_x = ctx.sin(x).unwrap();
        let cos_x = ctx.cos(x).unwrap();
        let total = ctx.add(sin_x, cos_x).unwrap();
        let t = ctx.export(total).unwrap();

        // The root must be the addition...
        let TreeOp::Binary(BinaryOpcode::Add, lhs, rhs) = &*t else {
            panic!("unexpected opcode {t:?}");
        };
        // ...whose operands are `sin` and `cos`, in that order...
        let (
            TreeOp::Unary(UnaryOpcode::Sin, x1),
            TreeOp::Unary(UnaryOpcode::Cos, x2),
        ) = (&**lhs, &**rhs)
        else {
            panic!("invalid lhs / rhs: {lhs:?} {rhs:?}");
        };
        // ...both referring to one shared allocation, which is the X input.
        assert_eq!(Arc::as_ptr(x1), Arc::as_ptr(x2));
        let TreeOp::Input(Var::X) = &**x1 else {
            panic!("invalid X: {x1:?}");
        };
    }

    /// Importing `x + 0` must produce the bare `x` input node.
    #[test]
    fn import_optimization() {
        let expr = Tree::x() + 0;
        let mut ctx = Context::new();
        let imported = ctx.import(&expr);
        assert_eq!(ctx.get_op(imported).unwrap(), &Op::Input(Var::X));
    }
}