Skip to main content

fidget_core/context/
mod.rs

1//! Infrastructure for representing math expressions as trees and graphs
2//!
3//! There are two families of representations in this module:
4//!
5//! - A [`Tree`] is a free-floating math expression, which can be cloned
6//!   and has overloaded operators for ease of use.  It is **not** deduplicated;
7//!   two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two
8//!   different objects.
9//!   `Tree` objects are typically used when building up expressions; they
10//!   should be converted to `Node` objects (in a particular `Context`) after
11//!   they have been constructed.
12//! - A [`Context`] is an arena for unique (deduplicated) math expressions,
13//!   which are represented as [`Node`] handles.  Each `Node` is specific to a
14//!   particular context.  Only `Node` objects can be converted into
15//!   [`Function`](crate::eval::Function) objects for evaluation.
16//!
17//! In other words, the typical workflow is `Tree → (Context, Node) → Function`.
18mod indexed;
19mod op;
20mod tree;
21
22use indexed::{Index, IndexMap, IndexVec, define_index};
23pub use op::{BinaryOpcode, Op, UnaryOpcode};
24pub use tree::{Tree, TreeOp};
25
26use crate::{Error, var::Var};
27
28use std::collections::{BTreeMap, HashMap};
29use std::fmt::Write;
30use std::io::{BufRead, BufReader, Read};
31use std::sync::Arc;
32
33use nalgebra::Matrix4;
34use ordered_float::OrderedFloat;
35
// Declares `Node`, the typed handle used throughout this module to refer to
// entries of `Context::ops` (the macro comes from the `indexed` submodule).
define_index!(Node, "An index in the `Context::ops` map");
37
38/// A `Context` holds a set of deduplicated constants, variables, and
39/// operations.
40///
41/// It should be used like an arena allocator: it grows over time, then frees
42/// all of its contents when dropped.  There is no reference counting within the
43/// context.
44///
45/// Items in the context are accessed with [`Node`] keys, which are simple
46/// handles into an internal map.  Inside the context, operations are
47/// represented with the [`Op`] type.
48#[derive(Debug, Default)]
49pub struct Context {
50    ops: IndexMap<Op, Node>,
51}
52
53impl Context {
54    /// Build a new empty context
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Clears the context
60    ///
61    /// All [`Node`] handles from this context are invalidated.
62    ///
63    /// ```
64    /// # use fidget_core::context::Context;
65    /// let mut ctx = Context::new();
66    /// let x = ctx.x();
67    /// ctx.clear();
68    /// assert!(ctx.eval_xyz(x, 1.0, 0.0, 0.0).is_err());
69    /// ```
70    pub fn clear(&mut self) {
71        self.ops.clear();
72    }
73
74    /// Returns the number of [`Op`] nodes in the context
75    ///
76    /// ```
77    /// # use fidget_core::context::Context;
78    /// let mut ctx = Context::new();
79    /// let x = ctx.x();
80    /// assert_eq!(ctx.len(), 1);
81    /// let y = ctx.y();
82    /// assert_eq!(ctx.len(), 2);
83    /// ctx.clear();
84    /// assert_eq!(ctx.len(), 0);
85    /// ```
86    pub fn len(&self) -> usize {
87        self.ops.len()
88    }
89
90    /// Checks whether the context is empty
91    pub fn is_empty(&self) -> bool {
92        self.ops.is_empty()
93    }
94
95    /// Checks whether the given [`Node`] is valid in this context
96    fn check_node(&self, node: Node) -> Result<(), Error> {
97        self.get_op(node).ok_or(Error::BadNode).map(|_| ())
98    }
99
100    /// Erases the most recently added node from the tree.
101    ///
102    /// A few caveats apply, so this must be used with caution:
103    /// - Existing handles to the node will be invalidated
104    /// - The most recently added node must be unique
105    ///
106    /// In practice, this is only used to delete temporary operation nodes
107    /// during constant folding.  Such nodes which have no handles (because
108    /// they are never returned) and are guaranteed to be unique (because we
109    /// never store them persistently).
110    fn pop(&mut self) -> Result<(), Error> {
111        self.ops.pop().map(|_| ())
112    }
113
114    /// Looks up the constant associated with the given node.
115    ///
116    /// If the node is invalid for this tree, returns an error; if the node is
117    /// not a constant, returns `Ok(None)`.
118    pub fn get_const(&self, n: Node) -> Result<f64, Error> {
119        match self.get_op(n) {
120            Some(Op::Const(c)) => Ok(c.0),
121            Some(_) => Err(Error::NotAConst),
122            _ => Err(Error::BadNode),
123        }
124    }
125
126    /// Looks up the [`Var`] associated with the given node.
127    ///
128    /// If the node is invalid for this tree or not an `Op::Input`, returns an
129    /// error.
130    pub fn get_var(&self, n: Node) -> Result<Var, Error> {
131        match self.get_op(n) {
132            Some(Op::Input(v)) => Ok(*v),
133            Some(..) => Err(Error::NotAVar),
134            _ => Err(Error::BadNode),
135        }
136    }
137
138    ////////////////////////////////////////////////////////////////////////////
139    // Primitives
140    /// Constructs or finds a [`Var::X`] node
141    /// ```
142    /// # use fidget_core::context::Context;
143    /// let mut ctx = Context::new();
144    /// let x = ctx.x();
145    /// let v = ctx.eval_xyz(x, 1.0, 0.0, 0.0).unwrap();
146    /// assert_eq!(v, 1.0);
147    /// ```
148    pub fn x(&mut self) -> Node {
149        self.var(Var::X)
150    }
151
152    /// Constructs or finds a [`Var::Y`] node
153    pub fn y(&mut self) -> Node {
154        self.var(Var::Y)
155    }
156
157    /// Constructs or finds a [`Var::Z`] node
158    pub fn z(&mut self) -> Node {
159        self.var(Var::Z)
160    }
161
162    /// Constructs or finds a variable input node
163    ///
164    /// To make an anonymous variable, call this function with [`Var::new()`]:
165    ///
166    /// ```
167    /// # use fidget_core::{context::Context, var::Var};
168    /// # use std::collections::HashMap;
169    /// let mut ctx = Context::new();
170    /// let v1 = ctx.var(Var::new());
171    /// let v2 = ctx.var(Var::new());
172    /// assert_ne!(v1, v2);
173    ///
174    /// let mut vars = HashMap::new();
175    /// vars.insert(ctx.get_var(v1).unwrap(), 3.0);
176    /// assert_eq!(ctx.eval(v1, &vars).unwrap(), 3.0);
177    /// assert!(ctx.eval(v2, &vars).is_err()); // v2 isn't in the map
178    /// ```
179    pub fn var(&mut self, v: Var) -> Node {
180        self.ops.insert(Op::Input(v))
181    }
182
183    /// Returns a 3-element array of `X`, `Y`, `Z` nodes
184    pub fn axes(&mut self) -> [Node; 3] {
185        [self.x(), self.y(), self.z()]
186    }
187
188    /// Returns a node representing the given constant value.
189    /// ```
190    /// # let mut ctx = fidget_core::context::Context::new();
191    /// let v = ctx.constant(3.0);
192    /// assert_eq!(ctx.eval_xyz(v, 0.0, 0.0, 0.0).unwrap(), 3.0);
193    /// ```
194    pub fn constant(&mut self, f: f64) -> Node {
195        self.ops.insert(Op::Const(OrderedFloat(f)))
196    }
197
198    ////////////////////////////////////////////////////////////////////////////
199    // Helper functions to create nodes with constant folding
200    /// Find or create a [Node] for the given unary operation, with constant
201    /// folding.
202    fn op_unary(&mut self, a: Node, op: UnaryOpcode) -> Result<Node, Error> {
203        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
204        let n = self.ops.insert(Op::Unary(op, a));
205        let out = if matches!(op_a, Op::Const(_)) {
206            let v = self.eval(n, &Default::default())?;
207            self.pop().unwrap(); // removes `n`
208            self.constant(v)
209        } else {
210            n
211        };
212        Ok(out)
213    }
214    /// Find or create a [Node] for the given binary operation, with constant
215    /// folding.
216    fn op_binary(
217        &mut self,
218        a: Node,
219        b: Node,
220        op: BinaryOpcode,
221    ) -> Result<Node, Error> {
222        self.op_binary_f(a, b, |lhs, rhs| Op::Binary(op, lhs, rhs))
223    }
224
225    /// Find or create a [Node] for a generic binary operation (represented by a
226    /// thunk), with constant folding.
227    fn op_binary_f<F>(&mut self, a: Node, b: Node, f: F) -> Result<Node, Error>
228    where
229        F: Fn(Node, Node) -> Op,
230    {
231        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
232        let op_b = *self.get_op(b).ok_or(Error::BadNode)?;
233
234        // This call to `insert` should always insert the node, because we
235        // don't permanently store operations in the tree that could be
236        // constant-folded (indeed, we pop the node right afterwards)
237        let n = self.ops.insert(f(a, b));
238        let out = if matches!((op_a, op_b), (Op::Const(_), Op::Const(_))) {
239            let v = self.eval(n, &Default::default())?;
240            self.pop().unwrap(); // removes `n`
241            self.constant(v)
242        } else {
243            n
244        };
245        Ok(out)
246    }
247
248    /// Find or create a [Node] for the given commutative operation, with
249    /// constant folding; deduplication is encouraged by sorting `a` and `b`.
250    fn op_binary_commutative(
251        &mut self,
252        a: Node,
253        b: Node,
254        op: BinaryOpcode,
255    ) -> Result<Node, Error> {
256        self.op_binary(a.min(b), a.max(b), op)
257    }
258
259    /// Builds an addition node
260    /// ```
261    /// # let mut ctx = fidget_core::context::Context::new();
262    /// let x = ctx.x();
263    /// let op = ctx.add(x, 1.0).unwrap();
264    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
265    /// assert_eq!(v, 2.0);
266    /// ```
267    pub fn add<A: IntoNode, B: IntoNode>(
268        &mut self,
269        a: A,
270        b: B,
271    ) -> Result<Node, Error> {
272        let a: Node = a.into_node(self)?;
273        let b: Node = b.into_node(self)?;
274        if a == b {
275            let two = self.constant(2.0);
276            self.mul(a, two)
277        } else {
278            match (self.get_const(a), self.get_const(b)) {
279                (Ok(0.0), _) => Ok(b),
280                (_, Ok(0.0)) => Ok(a),
281                _ => self.op_binary_commutative(a, b, BinaryOpcode::Add),
282            }
283        }
284    }
285
286    /// Builds an multiplication node
287    /// ```
288    /// # let mut ctx = fidget_core::context::Context::new();
289    /// let x = ctx.x();
290    /// let op = ctx.mul(x, 5.0).unwrap();
291    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
292    /// assert_eq!(v, 10.0);
293    /// ```
294    pub fn mul<A: IntoNode, B: IntoNode>(
295        &mut self,
296        a: A,
297        b: B,
298    ) -> Result<Node, Error> {
299        let a = a.into_node(self)?;
300        let b = b.into_node(self)?;
301        if a == b {
302            self.square(a)
303        } else {
304            match (self.get_const(a), self.get_const(b)) {
305                (Ok(1.0), _) => Ok(b),
306                (_, Ok(1.0)) => Ok(a),
307                (Ok(0.0), _) => Ok(a),
308                (_, Ok(0.0)) => Ok(b),
309                _ => self.op_binary_commutative(a, b, BinaryOpcode::Mul),
310            }
311        }
312    }
313
314    /// Builds an `min` node
315    /// ```
316    /// # let mut ctx = fidget_core::context::Context::new();
317    /// let x = ctx.x();
318    /// let op = ctx.min(x, 5.0).unwrap();
319    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
320    /// assert_eq!(v, 2.0);
321    /// ```
322    pub fn min<A: IntoNode, B: IntoNode>(
323        &mut self,
324        a: A,
325        b: B,
326    ) -> Result<Node, Error> {
327        let a = a.into_node(self)?;
328        let b = b.into_node(self)?;
329        if a == b {
330            Ok(a)
331        } else {
332            self.op_binary_commutative(a, b, BinaryOpcode::Min)
333        }
334    }
335    /// Builds an `max` node
336    /// ```
337    /// # let mut ctx = fidget_core::context::Context::new();
338    /// let x = ctx.x();
339    /// let op = ctx.max(x, 5.0).unwrap();
340    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
341    /// assert_eq!(v, 5.0);
342    /// ```
343    pub fn max<A: IntoNode, B: IntoNode>(
344        &mut self,
345        a: A,
346        b: B,
347    ) -> Result<Node, Error> {
348        let a = a.into_node(self)?;
349        let b = b.into_node(self)?;
350        if a == b {
351            Ok(a)
352        } else {
353            self.op_binary_commutative(a, b, BinaryOpcode::Max)
354        }
355    }
356
357    /// Builds an `and` node
358    ///
359    /// If both arguments are non-zero, returns the right-hand argument.
360    /// Otherwise, returns zero.
361    ///
362    /// This node can be simplified using a tracing evaluator:
363    /// - If the left-hand argument is zero, simplify to just that argument
364    /// - If the left-hand argument is non-zero, simplify to the other argument
365    /// ```
366    /// # let mut ctx = fidget_core::context::Context::new();
367    /// let x = ctx.x();
368    /// let y = ctx.y();
369    /// let op = ctx.and(x, y).unwrap();
370    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
371    /// assert_eq!(v, 0.0);
372    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
373    /// assert_eq!(v, 1.0);
374    /// let v = ctx.eval_xyz(op, 1.0, 2.0, 0.0).unwrap();
375    /// assert_eq!(v, 2.0);
376    /// ```
377    pub fn and<A: IntoNode, B: IntoNode>(
378        &mut self,
379        a: A,
380        b: B,
381    ) -> Result<Node, Error> {
382        let a = a.into_node(self)?;
383        let b = b.into_node(self)?;
384
385        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
386        if let Op::Const(v) = op_a {
387            if v.0 == 0.0 { Ok(a) } else { Ok(b) }
388        } else {
389            self.op_binary(a, b, BinaryOpcode::And)
390        }
391    }
392
393    /// Builds an `or` node
394    ///
395    /// If the left-hand argument is non-zero, it is returned.  Otherwise, the
396    /// right-hand argument is returned.
397    ///
398    /// This node can be simplified using a tracing evaluator.
399    /// ```
400    /// # let mut ctx = fidget_core::context::Context::new();
401    /// let x = ctx.x();
402    /// let y = ctx.y();
403    /// let op = ctx.or(x, y).unwrap();
404    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
405    /// assert_eq!(v, 1.0);
406    /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
407    /// assert_eq!(v, 0.0);
408    /// let v = ctx.eval_xyz(op, 0.0, 3.0, 0.0).unwrap();
409    /// assert_eq!(v, 3.0);
410    /// ```
411    pub fn or<A: IntoNode, B: IntoNode>(
412        &mut self,
413        a: A,
414        b: B,
415    ) -> Result<Node, Error> {
416        let a = a.into_node(self)?;
417        let b = b.into_node(self)?;
418
419        let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
420        let op_b = *self.get_op(b).ok_or(Error::BadNode)?;
421        if let Op::Const(v) = op_a {
422            if v.0 != 0.0 {
423                return Ok(a);
424            } else {
425                return Ok(b);
426            }
427        } else if let Op::Const(v) = op_b
428            && v.0 == 0.0
429        {
430            return Ok(a);
431        }
432        self.op_binary(a, b, BinaryOpcode::Or)
433    }
434
435    /// Builds a logical negation node
436    ///
437    /// The output is 1 if the argument is 0, and 0 otherwise.
438    pub fn not<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
439        let a = a.into_node(self)?;
440        self.op_unary(a, UnaryOpcode::Not)
441    }
442
443    /// Builds a unary negation node
444    /// ```
445    /// # let mut ctx = fidget_core::context::Context::new();
446    /// let x = ctx.x();
447    /// let op = ctx.neg(x).unwrap();
448    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
449    /// assert_eq!(v, -2.0);
450    /// ```
451    pub fn neg<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
452        let a = a.into_node(self)?;
453        self.op_unary(a, UnaryOpcode::Neg)
454    }
455
456    /// Builds a reciprocal node
457    /// ```
458    /// # let mut ctx = fidget_core::context::Context::new();
459    /// let x = ctx.x();
460    /// let op = ctx.recip(x).unwrap();
461    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
462    /// assert_eq!(v, 0.5);
463    /// ```
464    pub fn recip<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
465        let a = a.into_node(self)?;
466        self.op_unary(a, UnaryOpcode::Recip)
467    }
468
469    /// Builds a node which calculates the absolute value of its input
470    /// ```
471    /// # let mut ctx = fidget_core::context::Context::new();
472    /// let x = ctx.x();
473    /// let op = ctx.abs(x).unwrap();
474    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
475    /// assert_eq!(v, 2.0);
476    /// let v = ctx.eval_xyz(op, -2.0, 0.0, 0.0).unwrap();
477    /// assert_eq!(v, 2.0);
478    /// ```
479    pub fn abs<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
480        let a = a.into_node(self)?;
481        self.op_unary(a, UnaryOpcode::Abs)
482    }
483
484    /// Builds a node which calculates the square root of its input
485    /// ```
486    /// # let mut ctx = fidget_core::context::Context::new();
487    /// let x = ctx.x();
488    /// let op = ctx.sqrt(x).unwrap();
489    /// let v = ctx.eval_xyz(op, 4.0, 0.0, 0.0).unwrap();
490    /// assert_eq!(v, 2.0);
491    /// ```
492    pub fn sqrt<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
493        let a = a.into_node(self)?;
494        self.op_unary(a, UnaryOpcode::Sqrt)
495    }
496
497    /// Builds a node which calculates the sine of its input (in radians)
498    /// ```
499    /// # let mut ctx = fidget_core::context::Context::new();
500    /// let x = ctx.x();
501    /// let op = ctx.sin(x).unwrap();
502    /// let v = ctx.eval_xyz(op, std::f64::consts::PI / 2.0, 0.0, 0.0).unwrap();
503    /// assert_eq!(v, 1.0);
504    /// ```
505    pub fn sin<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
506        let a = a.into_node(self)?;
507        self.op_unary(a, UnaryOpcode::Sin)
508    }
509
510    /// Builds a node which calculates the cosine of its input (in radians)
511    pub fn cos<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
512        let a = a.into_node(self)?;
513        self.op_unary(a, UnaryOpcode::Cos)
514    }
515
516    /// Builds a node which calculates the tangent of its input (in radians)
517    pub fn tan<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
518        let a = a.into_node(self)?;
519        self.op_unary(a, UnaryOpcode::Tan)
520    }
521
522    /// Builds a node which calculates the arcsine of its input (in radians)
523    pub fn asin<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
524        let a = a.into_node(self)?;
525        self.op_unary(a, UnaryOpcode::Asin)
526    }
527
528    /// Builds a node which calculates the arccosine of its input (in radians)
529    pub fn acos<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
530        let a = a.into_node(self)?;
531        self.op_unary(a, UnaryOpcode::Acos)
532    }
533
534    /// Builds a node which calculates the arctangent of its input (in radians)
535    pub fn atan<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
536        let a = a.into_node(self)?;
537        self.op_unary(a, UnaryOpcode::Atan)
538    }
539
540    /// Builds a node which calculates the exponent of its input
541    pub fn exp<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
542        let a = a.into_node(self)?;
543        self.op_unary(a, UnaryOpcode::Exp)
544    }
545
546    /// Builds a node which calculates the natural log of its input
547    pub fn ln<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
548        let a = a.into_node(self)?;
549        self.op_unary(a, UnaryOpcode::Ln)
550    }
551
552    ////////////////////////////////////////////////////////////////////////////
553    // Derived functions
554    /// Builds a node which squares its input
555    /// ```
556    /// # let mut ctx = fidget_core::context::Context::new();
557    /// let x = ctx.x();
558    /// let op = ctx.square(x).unwrap();
559    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
560    /// assert_eq!(v, 4.0);
561    /// ```
562    pub fn square<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
563        let a = a.into_node(self)?;
564        self.op_unary(a, UnaryOpcode::Square)
565    }
566
567    /// Builds a node which takes the floor of its input
568    /// ```
569    /// # let mut ctx = fidget_core::context::Context::new();
570    /// let x = ctx.x();
571    /// let op = ctx.floor(x).unwrap();
572    /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
573    /// assert_eq!(v, 1.0);
574    /// ```
575    pub fn floor<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
576        let a = a.into_node(self)?;
577        self.op_unary(a, UnaryOpcode::Floor)
578    }
579
580    /// Builds a node which takes the ceiling of its input
581    /// ```
582    /// # let mut ctx = fidget_core::context::Context::new();
583    /// let x = ctx.x();
584    /// let op = ctx.ceil(x).unwrap();
585    /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
586    /// assert_eq!(v, 2.0);
587    /// ```
588    pub fn ceil<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
589        let a = a.into_node(self)?;
590        self.op_unary(a, UnaryOpcode::Ceil)
591    }
592
593    /// Builds a node which rounds its input to the nearest integer
594    /// ```
595    /// # let mut ctx = fidget_core::context::Context::new();
596    /// let x = ctx.x();
597    /// let op = ctx.round(x).unwrap();
598    /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
599    /// assert_eq!(v, 1.0);
600    /// let v = ctx.eval_xyz(op, 1.6, 0.0, 0.0).unwrap();
601    /// assert_eq!(v, 2.0);
602    /// let v = ctx.eval_xyz(op, 1.5, 0.0, 0.0).unwrap();
603    /// assert_eq!(v, 2.0); // rounds away from 0.0 if ambiguous
604    /// ```
605    pub fn round<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
606        let a = a.into_node(self)?;
607        self.op_unary(a, UnaryOpcode::Round)
608    }
609
610    /// Builds a node which performs subtraction.
611    /// ```
612    /// # let mut ctx = fidget_core::context::Context::new();
613    /// let x = ctx.x();
614    /// let y = ctx.y();
615    /// let op = ctx.sub(x, y).unwrap();
616    /// let v = ctx.eval_xyz(op, 3.0, 2.0, 0.0).unwrap();
617    /// assert_eq!(v, 1.0);
618    /// ```
619    pub fn sub<A: IntoNode, B: IntoNode>(
620        &mut self,
621        a: A,
622        b: B,
623    ) -> Result<Node, Error> {
624        let a = a.into_node(self)?;
625        let b = b.into_node(self)?;
626
627        match (self.get_const(a), self.get_const(b)) {
628            (Ok(0.0), _) => self.neg(b),
629            (_, Ok(0.0)) => Ok(a),
630            _ => self.op_binary(a, b, BinaryOpcode::Sub),
631        }
632    }
633
634    /// Builds a node which performs division.
635    /// ```
636    /// # let mut ctx = fidget_core::context::Context::new();
637    /// let x = ctx.x();
638    /// let y = ctx.y();
639    /// let op = ctx.div(x, y).unwrap();
640    /// let v = ctx.eval_xyz(op, 3.0, 2.0, 0.0).unwrap();
641    /// assert_eq!(v, 1.5);
642    /// ```
643    pub fn div<A: IntoNode, B: IntoNode>(
644        &mut self,
645        a: A,
646        b: B,
647    ) -> Result<Node, Error> {
648        let a = a.into_node(self)?;
649        let b = b.into_node(self)?;
650
651        match (self.get_const(a), self.get_const(b)) {
652            (Ok(0.0), _) => Ok(a),
653            (_, Ok(1.0)) => Ok(a),
654            _ => self.op_binary(a, b, BinaryOpcode::Div),
655        }
656    }
657
658    /// Builds a node which computes `atan2(y, x)`
659    /// ```
660    /// # let mut ctx = fidget_core::context::Context::new();
661    /// let x = ctx.x();
662    /// let y = ctx.y();
663    /// let op = ctx.atan2(y, x).unwrap();
664    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
665    /// assert_eq!(v, std::f64::consts::FRAC_PI_2);
666    /// ```
667    pub fn atan2<A: IntoNode, B: IntoNode>(
668        &mut self,
669        y: A,
670        x: B,
671    ) -> Result<Node, Error> {
672        let y = y.into_node(self)?;
673        let x = x.into_node(self)?;
674
675        self.op_binary(y, x, BinaryOpcode::Atan)
676    }
677
678    /// Builds a node that compares two values
679    ///
680    /// The result is -1 if `a < b`, +1 if `a > b`, 0 if `a == b`, and `NaN` if
681    /// either side is `NaN`.
682    /// ```
683    /// # let mut ctx = fidget_core::context::Context::new();
684    /// let x = ctx.x();
685    /// let op = ctx.compare(x, 1.0).unwrap();
686    /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
687    /// assert_eq!(v, -1.0);
688    /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
689    /// assert_eq!(v, 1.0);
690    /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
691    /// assert_eq!(v, 0.0);
692    /// ```
693    pub fn compare<A: IntoNode, B: IntoNode>(
694        &mut self,
695        a: A,
696        b: B,
697    ) -> Result<Node, Error> {
698        let a = a.into_node(self)?;
699        let b = b.into_node(self)?;
700        self.op_binary(a, b, BinaryOpcode::Compare)
701    }
702
703    /// Builds a node that is 1 if `lhs < rhs` and 0 otherwise
704    ///
705    /// ```
706    /// # let mut ctx = fidget_core::context::Context::new();
707    /// let x = ctx.x();
708    /// let y = ctx.y();
709    /// let op = ctx.less_than(x, y).unwrap();
710    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
711    /// assert_eq!(v, 1.0);
712    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
713    /// assert_eq!(v, 0.0);
714    /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
715    /// assert_eq!(v, 0.0);
716    /// ```
717    pub fn less_than<A: IntoNode, B: IntoNode>(
718        &mut self,
719        lhs: A,
720        rhs: B,
721    ) -> Result<Node, Error> {
722        let lhs = lhs.into_node(self)?;
723        let rhs = rhs.into_node(self)?;
724        let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
725        self.max(cmp, 0.0)
726    }
727
728    /// Builds a node that is 1 if `lhs <= rhs` and 0 otherwise
729    ///
730    /// ```
731    /// # let mut ctx = fidget_core::context::Context::new();
732    /// let x = ctx.x();
733    /// let y = ctx.y();
734    /// let op = ctx.less_than_or_equal(x, y).unwrap();
735    /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
736    /// assert_eq!(v, 1.0);
737    /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
738    /// assert_eq!(v, 1.0);
739    /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
740    /// assert_eq!(v, 0.0);
741    /// ```
742    pub fn less_than_or_equal<A: IntoNode, B: IntoNode>(
743        &mut self,
744        lhs: A,
745        rhs: B,
746    ) -> Result<Node, Error> {
747        let lhs = lhs.into_node(self)?;
748        let rhs = rhs.into_node(self)?;
749        let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
750        let shift = self.add(cmp, 1.0)?;
751        self.min(shift, 1.0)
752    }
753
754    /// Builds a node that takes the modulo (least non-negative remainder)
755    pub fn modulo<A: IntoNode, B: IntoNode>(
756        &mut self,
757        a: A,
758        b: B,
759    ) -> Result<Node, Error> {
760        let a = a.into_node(self)?;
761        let b = b.into_node(self)?;
762        self.op_binary(a, b, BinaryOpcode::Mod)
763    }
764
765    /// Builds a node that returns the first node if the condition is not
766    /// equal to zero, else returns the other node
767    ///
768    /// The result is `a` if `condition != 0`, else the result is `b`.
769    /// ```
770    /// # let mut ctx = fidget_core::context::Context::new();
771    /// let x = ctx.x();
772    /// let y = ctx.y();
773    /// let z = ctx.z();
774    ///
775    /// let if_else = ctx.if_nonzero_else(x, y, z).unwrap();
776    ///
777    /// assert_eq!(ctx.eval_xyz(if_else, 0.0, 2.0, 3.0).unwrap(), 3.0);
778    /// assert_eq!(ctx.eval_xyz(if_else, 1.0, 2.0, 3.0).unwrap(), 2.0);
779    /// assert_eq!(ctx.eval_xyz(if_else, 0.0, f64::NAN, 3.0).unwrap(), 3.0);
780    /// assert_eq!(ctx.eval_xyz(if_else, 1.0, 2.0, f64::NAN).unwrap(), 2.0);
781    /// ```
782    pub fn if_nonzero_else<Condition: IntoNode, A: IntoNode, B: IntoNode>(
783        &mut self,
784        condition: Condition,
785        a: A,
786        b: B,
787    ) -> Result<Node, Error> {
788        let condition = condition.into_node(self)?;
789        let a = a.into_node(self)?;
790        let b = b.into_node(self)?;
791
792        let lhs = self.and(condition, a)?;
793        let n_condition = self.not(condition)?;
794        let rhs = self.and(n_condition, b)?;
795        self.or(lhs, rhs)
796    }
797
798    ////////////////////////////////////////////////////////////////////////////
799    /// Evaluates the given node with the provided values for X, Y, and Z.
800    ///
801    /// This is extremely inefficient; consider converting the node into a
802    /// [`Shape`](crate::shape::Shape) and using its evaluators instead.
803    ///
804    /// ```
805    /// # let mut ctx = fidget_core::context::Context::new();
806    /// let x = ctx.x();
807    /// let y = ctx.y();
808    /// let z = ctx.z();
809    /// let op = ctx.mul(x, y).unwrap();
810    /// let op = ctx.div(op, z).unwrap();
811    /// let v = ctx.eval_xyz(op, 3.0, 5.0, 2.0).unwrap();
812    /// assert_eq!(v, 7.5); // (3.0 * 5.0) / 2.0
813    /// ```
814    pub fn eval_xyz(
815        &self,
816        root: Node,
817        x: f64,
818        y: f64,
819        z: f64,
820    ) -> Result<f64, Error> {
821        let vars = [(Var::X, x), (Var::Y, y), (Var::Z, z)]
822            .into_iter()
823            .collect();
824        self.eval(root, &vars)
825    }
826
827    /// Evaluates the given node with a generic set of variables
828    ///
829    /// This is extremely inefficient; consider converting the node into a
830    /// [`Shape`](crate::shape::Shape) and using its evaluators instead.
831    pub fn eval(
832        &self,
833        root: Node,
834        vars: &HashMap<Var, f64>,
835    ) -> Result<f64, Error> {
836        let mut cache = vec![None; self.ops.len()].into();
837        self.eval_inner(root, vars, &mut cache)
838    }
839
840    fn eval_inner(
841        &self,
842        node: Node,
843        vars: &HashMap<Var, f64>,
844        cache: &mut IndexVec<Option<f64>, Node>,
845    ) -> Result<f64, Error> {
846        if node.0 >= cache.len() {
847            return Err(Error::BadNode);
848        }
849        if let Some(v) = cache[node] {
850            return Ok(v);
851        }
852        let mut get = |n: Node| self.eval_inner(n, vars, cache);
853        let v = match self.get_op(node).ok_or(Error::BadNode)? {
854            Op::Input(v) => *vars.get(v).ok_or(Error::MissingVar(*v))?,
855            Op::Const(c) => c.0,
856
857            Op::Binary(op, a, b) => {
858                let a = get(*a)?;
859                let b = get(*b)?;
860                match op {
861                    BinaryOpcode::Add => a + b,
862                    BinaryOpcode::Sub => a - b,
863                    BinaryOpcode::Mul => a * b,
864                    BinaryOpcode::Div => a / b,
865                    BinaryOpcode::Atan => a.atan2(b),
866                    BinaryOpcode::Min => a.min(b),
867                    BinaryOpcode::Max => a.max(b),
868                    BinaryOpcode::Compare => a
869                        .partial_cmp(&b)
870                        .map(|i| i as i8 as f64)
871                        .unwrap_or(f64::NAN),
872                    BinaryOpcode::Mod => a.rem_euclid(b),
873                    BinaryOpcode::And => {
874                        if a == 0.0 {
875                            a
876                        } else {
877                            b
878                        }
879                    }
880                    BinaryOpcode::Or => {
881                        if a != 0.0 {
882                            a
883                        } else {
884                            b
885                        }
886                    }
887                }
888            }
889
890            // Unary operations
891            Op::Unary(op, a) => {
892                let a = get(*a)?;
893                match op {
894                    UnaryOpcode::Neg => -a,
895                    UnaryOpcode::Abs => a.abs(),
896                    UnaryOpcode::Recip => 1.0 / a,
897                    UnaryOpcode::Sqrt => a.sqrt(),
898                    UnaryOpcode::Square => a * a,
899                    UnaryOpcode::Floor => a.floor(),
900                    UnaryOpcode::Ceil => a.ceil(),
901                    UnaryOpcode::Round => a.round(),
902                    UnaryOpcode::Sin => a.sin(),
903                    UnaryOpcode::Cos => a.cos(),
904                    UnaryOpcode::Tan => a.tan(),
905                    UnaryOpcode::Asin => a.asin(),
906                    UnaryOpcode::Acos => a.acos(),
907                    UnaryOpcode::Atan => a.atan(),
908                    UnaryOpcode::Exp => a.exp(),
909                    UnaryOpcode::Ln => a.ln(),
910                    UnaryOpcode::Not => (a == 0.0).into(),
911                }
912            }
913        };
914
915        cache[node] = Some(v);
916        Ok(v)
917    }
918
919    /// Parses a flat text representation of a math tree. For example, the
920    /// circle `(- (+ (square x) (square y)) 1)` can be parsed from
921    /// ```
922    /// # use fidget_core::context::Context;
923    /// let txt = "
924    /// ## This is a comment!
925    /// 0x600000b90000 var-x
926    /// 0x600000b900a0 square 0x600000b90000
927    /// 0x600000b90050 var-y
928    /// 0x600000b900f0 square 0x600000b90050
929    /// 0x600000b90140 add 0x600000b900a0 0x600000b900f0
930    /// 0x600000b90190 sqrt 0x600000b90140
931    /// 0x600000b901e0 const 1
932    /// ";
933    /// let (ctx, _node) = Context::from_text(&mut txt.as_bytes()).unwrap();
934    /// assert_eq!(ctx.len(), 7);
935    /// ```
936    ///
937    /// This representation is loosely defined and only intended for use in
938    /// quick experiments.
939    pub fn from_text<R: Read>(r: R) -> Result<(Self, Node), Error> {
940        let reader = BufReader::new(r);
941        let mut ctx = Self::new();
942        let mut seen = BTreeMap::new();
943        let mut last = None;
944
945        for line in reader.lines().map(|line| line.unwrap()) {
946            if line.is_empty() || line.starts_with('#') {
947                continue;
948            }
949            let mut iter = line.split_whitespace();
950            let i: String = iter.next().unwrap().to_owned();
951            let opcode = iter.next().unwrap();
952
953            let mut pop = || {
954                let txt = iter.next().unwrap();
955                seen.get(txt)
956                    .cloned()
957                    .ok_or_else(|| Error::UnknownVariable(txt.to_string()))
958            };
959            let node = match opcode {
960                "const" => ctx.constant(iter.next().unwrap().parse().unwrap()),
961                "var-x" => ctx.x(),
962                "var-y" => ctx.y(),
963                "var-z" => ctx.z(),
964                "abs" => ctx.abs(pop()?)?,
965                "neg" => ctx.neg(pop()?)?,
966                "sqrt" => ctx.sqrt(pop()?)?,
967                "square" => ctx.square(pop()?)?,
968                "floor" => ctx.floor(pop()?)?,
969                "ceil" => ctx.ceil(pop()?)?,
970                "round" => ctx.round(pop()?)?,
971                "sin" => ctx.sin(pop()?)?,
972                "cos" => ctx.cos(pop()?)?,
973                "tan" => ctx.tan(pop()?)?,
974                "asin" => ctx.asin(pop()?)?,
975                "acos" => ctx.acos(pop()?)?,
976                "atan" => ctx.atan(pop()?)?,
977                "ln" => ctx.ln(pop()?)?,
978                "not" => ctx.not(pop()?)?,
979                "exp" => ctx.exp(pop()?)?,
980                "add" => ctx.add(pop()?, pop()?)?,
981                "mul" => ctx.mul(pop()?, pop()?)?,
982                "min" => ctx.min(pop()?, pop()?)?,
983                "max" => ctx.max(pop()?, pop()?)?,
984                "div" => ctx.div(pop()?, pop()?)?,
985                "atan2" => ctx.atan2(pop()?, pop()?)?,
986                "sub" => ctx.sub(pop()?, pop()?)?,
987                "compare" => ctx.compare(pop()?, pop()?)?,
988                "mod" => ctx.modulo(pop()?, pop()?)?,
989                "and" => ctx.and(pop()?, pop()?)?,
990                "or" => ctx.or(pop()?, pop()?)?,
991                op => return Err(Error::UnknownOpcode(op.to_owned())),
992            };
993            seen.insert(i, node);
994            last = Some(node);
995        }
996        match last {
997            Some(node) => Ok((ctx, node)),
998            None => Err(Error::EmptyFile),
999        }
1000    }
1001
1002    /// Converts the entire context into a GraphViz drawing
1003    pub fn dot(&self) -> String {
1004        let mut out = "digraph mygraph{\n".to_owned();
1005        for node in self.ops.keys() {
1006            let op = self.get_op(node).unwrap();
1007            out += &self.dot_node(node);
1008            out += &op.dot_edges(node);
1009        }
1010        out += "}\n";
1011        out
1012    }
1013
1014    /// Converts the given node into a GraphViz node
1015    ///
1016    /// (this is a local function instead of a function on `Op` because it
1017    ///  requires looking up variables by name)
1018    fn dot_node(&self, i: Node) -> String {
1019        let mut out = format!(r#"n{} [label = ""#, i.get());
1020        let op = self.get_op(i).unwrap();
1021        match op {
1022            Op::Const(c) => write!(out, "{c}").unwrap(),
1023            Op::Input(v) => {
1024                out += &v.to_string();
1025            }
1026            Op::Binary(op, ..) => match op {
1027                BinaryOpcode::Add => out += "add",
1028                BinaryOpcode::Sub => out += "sub",
1029                BinaryOpcode::Mul => out += "mul",
1030                BinaryOpcode::Div => out += "div",
1031                BinaryOpcode::Atan => out += "atan2",
1032                BinaryOpcode::Min => out += "min",
1033                BinaryOpcode::Max => out += "max",
1034                BinaryOpcode::Compare => out += "compare",
1035                BinaryOpcode::Mod => out += "mod",
1036                BinaryOpcode::And => out += "and",
1037                BinaryOpcode::Or => out += "or",
1038            },
1039            Op::Unary(op, ..) => match op {
1040                UnaryOpcode::Neg => out += "neg",
1041                UnaryOpcode::Abs => out += "abs",
1042                UnaryOpcode::Recip => out += "recip",
1043                UnaryOpcode::Sqrt => out += "sqrt",
1044                UnaryOpcode::Square => out += "square",
1045                UnaryOpcode::Floor => out += "floor",
1046                UnaryOpcode::Ceil => out += "ceil",
1047                UnaryOpcode::Round => out += "round",
1048                UnaryOpcode::Sin => out += "sin",
1049                UnaryOpcode::Cos => out += "cos",
1050                UnaryOpcode::Tan => out += "tan",
1051                UnaryOpcode::Asin => out += "asin",
1052                UnaryOpcode::Acos => out += "acos",
1053                UnaryOpcode::Atan => out += "atan",
1054                UnaryOpcode::Exp => out += "exp",
1055                UnaryOpcode::Ln => out += "ln",
1056                UnaryOpcode::Not => out += "not",
1057            },
1058        };
1059        write!(
1060            out,
1061            r#"" color="{0}1" shape="{1}" fontcolor="{0}4"]"#,
1062            op.dot_node_color(),
1063            op.dot_node_shape()
1064        )
1065        .unwrap();
1066        out
1067    }
1068
1069    /// Looks up an operation by `Node` handle
1070    pub fn get_op(&self, node: Node) -> Option<&Op> {
1071        self.ops.get_by_index(node)
1072    }
1073
1074    /// Imports the given tree, deduplicating and returning the root
1075    pub fn import(&mut self, tree: &Tree) -> Node {
1076        // A naive remapping implementation would use recursion.  A naive
1077        // remapping implementation would blow up the stack given any
1078        // significant tree size.
1079        //
1080        // Instead, we maintain our own pseudo-stack here in a pair of Vecs (one
1081        // stack for actions, and a second stack for return values).
1082        enum Action<'a> {
1083            /// Pushes `Up(op)` followed by `Down(c)` for each child
1084            Down(&'a Arc<TreeOp>),
1085            /// Consumes imported trees from the stack and pushes a new tree
1086            Up(&'a Arc<TreeOp>),
1087            /// Pops the latest axis frame
1088            Pop,
1089            /// Pops the latest affine frame
1090            PopAffine,
1091        }
1092        let mut axes = vec![(self.x(), self.y(), self.z())];
1093        let mut todo = vec![Action::Down(tree.arc())];
1094        let mut stack = vec![];
1095        let mut affine: Vec<Matrix4<f64>> = vec![];
1096
1097        // Cache of TreeOp -> Node mapping under a particular frame (axes)
1098        //
1099        // This isn't required for correctness, but can be a speed optimization
1100        // (because it means we don't have to walk the same tree twice).
1101        let mut seen = HashMap::new();
1102
1103        while let Some(t) = todo.pop() {
1104            match t {
1105                Action::Down(t) => {
1106                    // If we've already seen this TreeOp with these axes, then
1107                    // we can return the previous Node.
1108                    if matches!(
1109                        t.as_ref(),
1110                        TreeOp::Unary(..) | TreeOp::Binary(..)
1111                    ) && let Some(p) =
1112                        seen.get(&(*axes.last().unwrap(), Arc::as_ptr(t)))
1113                    {
1114                        stack.push(*p);
1115                        continue;
1116                    }
1117                    match t.as_ref() {
1118                        TreeOp::Const(c) => {
1119                            stack.push(self.constant(*c));
1120                        }
1121                        TreeOp::Input(s) => {
1122                            let axes = axes.last().unwrap();
1123                            stack.push(match *s {
1124                                Var::X => axes.0,
1125                                Var::Y => axes.1,
1126                                Var::Z => axes.2,
1127                                v @ Var::V(..) => self.var(v),
1128                            });
1129                        }
1130                        TreeOp::Unary(_op, arg) => {
1131                            todo.push(Action::Up(t));
1132                            todo.push(Action::Down(arg));
1133                        }
1134                        TreeOp::Binary(_op, lhs, rhs) => {
1135                            todo.push(Action::Up(t));
1136                            todo.push(Action::Down(lhs));
1137                            todo.push(Action::Down(rhs));
1138                        }
1139                        TreeOp::RemapAxes { target: _, x, y, z } => {
1140                            // Action::Up(t) does the remapping and target eval
1141                            todo.push(Action::Up(t));
1142                            todo.push(Action::Down(x));
1143                            todo.push(Action::Down(y));
1144                            todo.push(Action::Down(z));
1145                        }
1146                        TreeOp::RemapAffine { target, mat } => {
1147                            let prev = affine
1148                                .last()
1149                                .cloned()
1150                                .unwrap_or(Matrix4::identity());
1151                            let mat = prev * mat.to_homogeneous();
1152
1153                            // Push either an affine frame or an axis frame,
1154                            // depending on whether the target is also affine
1155                            if matches!(&**target, TreeOp::RemapAffine { .. }) {
1156                                affine.push(mat);
1157                                todo.push(Action::PopAffine);
1158                            } else {
1159                                let (x, y, z) = axes.last().unwrap();
1160                                let mut out = [None; 3];
1161                                for i in 0..3 {
1162                                    let a = self.mul(mat[(i, 0)], *x).unwrap();
1163                                    let b = self.mul(mat[(i, 1)], *y).unwrap();
1164                                    let c = self.mul(mat[(i, 2)], *z).unwrap();
1165                                    let d = self.constant(mat[(i, 3)]);
1166                                    let ab = self.add(a, b).unwrap();
1167                                    let cd = self.add(c, d).unwrap();
1168                                    out[i] = Some(self.add(ab, cd).unwrap());
1169                                }
1170                                let [x, y, z] = out.map(Option::unwrap);
1171                                axes.push((x, y, z));
1172                                todo.push(Action::Pop);
1173                            }
1174                            todo.push(Action::Down(target));
1175                        }
1176                    }
1177                }
1178                Action::Up(t) => {
1179                    match t.as_ref() {
1180                        TreeOp::Const(..)
1181                        | TreeOp::Input(..)
1182                        | TreeOp::RemapAffine { .. } => unreachable!(),
1183                        TreeOp::Unary(op, ..) => {
1184                            let arg = stack.pop().unwrap();
1185                            let out = self.op_unary(arg, *op).unwrap();
1186                            stack.push(out);
1187                        }
1188                        TreeOp::Binary(op, ..) => {
1189                            let lhs = stack.pop().unwrap();
1190                            let rhs = stack.pop().unwrap();
1191                            // Call individual builders to apply optimizations
1192                            let out = match op {
1193                                BinaryOpcode::Add => self.add(lhs, rhs),
1194                                BinaryOpcode::Sub => self.sub(lhs, rhs),
1195                                BinaryOpcode::Mul => self.mul(lhs, rhs),
1196                                BinaryOpcode::Div => self.div(lhs, rhs),
1197                                BinaryOpcode::Atan => self.atan2(lhs, rhs),
1198                                BinaryOpcode::Min => self.min(lhs, rhs),
1199                                BinaryOpcode::Max => self.max(lhs, rhs),
1200                                BinaryOpcode::Compare => self.compare(lhs, rhs),
1201                                BinaryOpcode::Mod => self.modulo(lhs, rhs),
1202                                BinaryOpcode::And => self.and(lhs, rhs),
1203                                BinaryOpcode::Or => self.or(lhs, rhs),
1204                            }
1205                            .unwrap();
1206                            if Arc::strong_count(t) > 1 {
1207                                seen.insert(
1208                                    (*axes.last().unwrap(), Arc::as_ptr(t)),
1209                                    out,
1210                                );
1211                            }
1212                            stack.push(out);
1213                        }
1214                        TreeOp::RemapAxes { target, .. } => {
1215                            let x = stack.pop().unwrap();
1216                            let y = stack.pop().unwrap();
1217                            let z = stack.pop().unwrap();
1218                            axes.push((x, y, z));
1219                            todo.push(Action::Pop);
1220                            todo.push(Action::Down(target));
1221                        }
1222                    }
1223                    // Update the cache with the new tree, if relevant
1224                    //
1225                    // The `strong_count` check is a rough heuristic to avoid
1226                    // caching if there's only a single owner of the tree.  This
1227                    // isn't perfect, but it doesn't need to be for correctness.
1228                    if matches!(
1229                        t.as_ref(),
1230                        TreeOp::Unary(..) | TreeOp::Binary(..)
1231                    ) && Arc::strong_count(t) > 1
1232                    {
1233                        seen.insert(
1234                            (*axes.last().unwrap(), Arc::as_ptr(t)),
1235                            *stack.last().unwrap(),
1236                        );
1237                    }
1238                }
1239                Action::Pop => {
1240                    axes.pop().unwrap();
1241                }
1242                Action::PopAffine => {
1243                    affine.pop().unwrap();
1244                }
1245            }
1246        }
1247        assert_eq!(stack.len(), 1);
1248        stack.pop().unwrap()
1249    }
1250
1251    /// Converts from a context-specific node into a standalone [`Tree`]
1252    pub fn export(&self, n: Node) -> Result<Tree, Error> {
1253        if self.get_op(n).is_none() {
1254            return Err(Error::BadNode);
1255        }
1256
1257        // Do recursion on the heap to avoid stack overflows for deep trees
1258        enum Action {
1259            /// Pushes `Up(n)` followed by `Down(n)` for each child
1260            Down(Node),
1261            /// Consumes trees from the stack and pushes a new tree
1262            Up(Node, Op),
1263        }
1264        let mut todo = vec![Action::Down(n)];
1265        let mut stack = vec![];
1266
1267        // Cache of Node -> Tree mapping, for Tree deduplication
1268        let mut seen: HashMap<Node, Tree> = HashMap::new();
1269
1270        while let Some(t) = todo.pop() {
1271            match t {
1272                Action::Down(n) => {
1273                    // If we've already seen this TreeOp with these axes, then
1274                    // we can return the previous Node.
1275                    if let Some(p) = seen.get(&n) {
1276                        stack.push(p.clone());
1277                        continue;
1278                    }
1279                    let op = self.get_op(n).unwrap();
1280                    match op {
1281                        Op::Const(c) => {
1282                            let t = Tree::from(c.0);
1283                            seen.insert(n, t.clone());
1284                            stack.push(t);
1285                        }
1286                        Op::Input(v) => {
1287                            let t = Tree::from(*v);
1288                            seen.insert(n, t.clone());
1289                            stack.push(t);
1290                        }
1291                        Op::Unary(_op, arg) => {
1292                            todo.push(Action::Up(n, *op));
1293                            todo.push(Action::Down(*arg));
1294                        }
1295                        Op::Binary(_op, lhs, rhs) => {
1296                            todo.push(Action::Up(n, *op));
1297                            todo.push(Action::Down(*lhs));
1298                            todo.push(Action::Down(*rhs));
1299                        }
1300                    }
1301                }
1302                Action::Up(n, op) => match op {
1303                    Op::Const(..) | Op::Input(..) => unreachable!(),
1304                    Op::Unary(op, ..) => {
1305                        let arg = stack.pop().unwrap();
1306                        let out =
1307                            Tree::from(TreeOp::Unary(op, arg.arc().clone()));
1308                        seen.insert(n, out.clone());
1309                        stack.push(out);
1310                    }
1311                    Op::Binary(op, ..) => {
1312                        let lhs = stack.pop().unwrap();
1313                        let rhs = stack.pop().unwrap();
1314                        let out = Tree::from(TreeOp::Binary(
1315                            op,
1316                            lhs.arc().clone(),
1317                            rhs.arc().clone(),
1318                        ));
1319                        seen.insert(n, out.clone());
1320                        stack.push(out);
1321                    }
1322                },
1323            }
1324        }
1325        assert_eq!(stack.len(), 1);
1326        Ok(stack.pop().unwrap())
1327    }
1328
1329    /// Takes the symbolic derivative of a node with respect to a variable
1330    pub fn deriv(&mut self, n: Node, v: Var) -> Result<Node, Error> {
1331        if self.get_op(n).is_none() {
1332            return Err(Error::BadNode);
1333        }
1334
1335        // Do recursion on the heap to avoid stack overflows for deep trees
1336        enum Action {
1337            /// Pushes `Up(n)` followed by `Down(n)` for each child
1338            Down(Node),
1339            /// Consumes trees from the stack and pushes a new tree
1340            Up(Node, Op),
1341        }
1342        let mut todo = vec![Action::Down(n)];
1343        let mut stack = vec![];
1344        let zero = self.constant(0.0);
1345
1346        // Cache of Node -> Node mapping, for deduplication
1347        let mut seen: HashMap<Node, Node> = HashMap::new();
1348
1349        while let Some(t) = todo.pop() {
1350            match t {
1351                Action::Down(n) => {
1352                    // If we've already seen this TreeOp with these axes, then
1353                    // we can return the previous Node.
1354                    if let Some(p) = seen.get(&n) {
1355                        stack.push(*p);
1356                        continue;
1357                    }
1358                    let op = *self.get_op(n).unwrap();
1359                    match op {
1360                        Op::Const(_c) => {
1361                            seen.insert(n, zero);
1362                            stack.push(zero);
1363                        }
1364                        Op::Input(u) => {
1365                            let z =
1366                                if v == u { self.constant(1.0) } else { zero };
1367                            seen.insert(n, z);
1368                            stack.push(z);
1369                        }
1370                        Op::Unary(_op, arg) => {
1371                            todo.push(Action::Up(n, op));
1372                            todo.push(Action::Down(arg));
1373                        }
1374                        Op::Binary(_op, lhs, rhs) => {
1375                            todo.push(Action::Up(n, op));
1376                            todo.push(Action::Down(lhs));
1377                            todo.push(Action::Down(rhs));
1378                        }
1379                    }
1380                }
1381                Action::Up(n, op) => match op {
1382                    Op::Const(..) | Op::Input(..) => unreachable!(),
1383                    Op::Unary(op, v_arg) => {
1384                        let d_arg = stack.pop().unwrap();
1385                        let out = match op {
1386                            UnaryOpcode::Neg => self.neg(d_arg),
1387                            UnaryOpcode::Abs => {
1388                                let cond = self.less_than(v_arg, zero).unwrap();
1389                                let pos = d_arg;
1390                                let neg = self.neg(d_arg).unwrap();
1391                                self.if_nonzero_else(cond, neg, pos)
1392                            }
1393                            UnaryOpcode::Recip => {
1394                                let a = self.square(v_arg).unwrap();
1395                                let b = self.neg(d_arg).unwrap();
1396                                self.div(b, a)
1397                            }
1398                            UnaryOpcode::Sqrt => {
1399                                let v = self.mul(n, 2.0).unwrap();
1400                                self.div(d_arg, v)
1401                            }
1402                            UnaryOpcode::Square => {
1403                                let v = self.mul(d_arg, v_arg).unwrap();
1404                                self.mul(2.0, v)
1405                            }
1406                            // Discontinuous constants don't have Dirac deltas
1407                            UnaryOpcode::Floor
1408                            | UnaryOpcode::Ceil
1409                            | UnaryOpcode::Round => Ok(zero),
1410
1411                            UnaryOpcode::Sin => {
1412                                let c = self.cos(v_arg).unwrap();
1413                                self.mul(c, d_arg)
1414                            }
1415
1416                            UnaryOpcode::Cos => {
1417                                let s = self.sin(v_arg).unwrap();
1418                                let s = self.neg(s).unwrap();
1419                                self.mul(s, d_arg)
1420                            }
1421
1422                            UnaryOpcode::Tan => {
1423                                let c = self.cos(v_arg).unwrap();
1424                                let c = self.square(c).unwrap();
1425                                self.div(d_arg, c)
1426                            }
1427
1428                            UnaryOpcode::Asin => {
1429                                let v = self.square(v_arg).unwrap();
1430                                let v = self.sub(1.0, v).unwrap();
1431                                let v = self.sqrt(v).unwrap();
1432                                self.div(d_arg, v)
1433                            }
1434                            UnaryOpcode::Acos => {
1435                                let v = self.square(v_arg).unwrap();
1436                                let v = self.sub(1.0, v).unwrap();
1437                                let v = self.sqrt(v).unwrap();
1438                                let v = self.neg(v).unwrap();
1439                                self.div(d_arg, v)
1440                            }
1441                            UnaryOpcode::Atan => {
1442                                let v = self.square(v_arg).unwrap();
1443                                let v = self.add(1.0, v).unwrap();
1444                                self.div(d_arg, v)
1445                            }
1446                            UnaryOpcode::Exp => self.mul(n, d_arg),
1447                            UnaryOpcode::Ln => self.div(d_arg, v_arg),
1448                            UnaryOpcode::Not => Ok(zero),
1449                        }
1450                        .unwrap();
1451                        seen.insert(n, out);
1452                        stack.push(out);
1453                    }
1454                    Op::Binary(op, v_lhs, v_rhs) => {
1455                        let d_lhs = stack.pop().unwrap();
1456                        let d_rhs = stack.pop().unwrap();
1457                        let out = match op {
1458                            BinaryOpcode::Add => self.add(d_lhs, d_rhs),
1459                            BinaryOpcode::Sub => self.sub(d_lhs, d_rhs),
1460                            BinaryOpcode::Mul => {
1461                                let a = self.mul(d_lhs, v_rhs).unwrap();
1462                                let b = self.mul(v_lhs, d_rhs).unwrap();
1463                                self.add(a, b)
1464                            }
1465                            BinaryOpcode::Div => {
1466                                let v = self.square(v_rhs).unwrap();
1467                                let a = self.mul(v_rhs, d_lhs).unwrap();
1468                                let b = self.mul(v_lhs, d_rhs).unwrap();
1469                                let c = self.sub(a, b).unwrap();
1470                                self.div(c, v)
1471                            }
1472                            BinaryOpcode::Atan => {
1473                                let a = self.square(v_lhs).unwrap();
1474                                let b = self.square(v_rhs).unwrap();
1475                                let d = self.add(a, b).unwrap();
1476
1477                                let a = self.mul(v_rhs, d_lhs).unwrap();
1478                                let b = self.mul(v_lhs, d_rhs).unwrap();
1479                                let v = self.sub(a, b).unwrap();
1480                                self.div(v, d)
1481                            }
1482                            BinaryOpcode::Min => {
1483                                let cond =
1484                                    self.less_than(v_lhs, v_rhs).unwrap();
1485                                self.if_nonzero_else(cond, d_lhs, d_rhs)
1486                            }
1487                            BinaryOpcode::Max => {
1488                                let cond =
1489                                    self.less_than(v_rhs, v_lhs).unwrap();
1490                                self.if_nonzero_else(cond, d_lhs, d_rhs)
1491                            }
1492                            BinaryOpcode::Compare => Ok(zero),
1493                            BinaryOpcode::Mod => {
1494                                let e = self.div(v_lhs, v_rhs).unwrap();
1495                                let q = self.floor(e).unwrap();
1496
1497                                // XXX
1498                                // (we don't actually have %, so hack it from
1499                                // `modulo`, which is actually `rem_euclid`)
1500                                // ???
1501                                let m = self.modulo(q, v_rhs).unwrap();
1502                                let cond = self.less_than(q, zero).unwrap();
1503                                let offset = self
1504                                    .if_nonzero_else(cond, v_rhs, zero)
1505                                    .unwrap();
1506                                let m = self.sub(m, offset).unwrap();
1507
1508                                // Torn from the div_euclid implementation
1509                                let outer = self.less_than(m, zero).unwrap();
1510                                let inner =
1511                                    self.less_than(zero, v_rhs).unwrap();
1512                                let qa = self.sub(q, 1.0).unwrap();
1513                                let qb = self.add(q, 1.0).unwrap();
1514                                let inner = self
1515                                    .if_nonzero_else(inner, qa, qb)
1516                                    .unwrap();
1517                                let e = self
1518                                    .if_nonzero_else(outer, inner, q)
1519                                    .unwrap();
1520
1521                                let v = self.mul(d_rhs, e).unwrap();
1522                                self.sub(d_lhs, v)
1523                            }
1524                            BinaryOpcode::And => {
1525                                let cond = self.compare(v_lhs, zero).unwrap();
1526                                self.if_nonzero_else(cond, d_rhs, d_lhs)
1527                            }
1528                            BinaryOpcode::Or => {
1529                                let cond = self.compare(v_lhs, zero).unwrap();
1530                                self.if_nonzero_else(cond, d_lhs, d_rhs)
1531                            }
1532                        }
1533                        .unwrap();
1534                        seen.insert(n, out);
1535                        stack.push(out);
1536                    }
1537                },
1538            }
1539        }
1540        assert_eq!(stack.len(), 1);
1541        Ok(stack.pop().unwrap())
1542    }
1543}
1544
1545////////////////////////////////////////////////////////////////////////////////
/// Helper trait for things that can be converted into a [`Node`] given a
/// [`Context`].
///
/// Implementations exist for [`Node`] (checked against the context with
/// `Context::check_node`, then returned unchanged) as well as for `f32` and
/// `f64` (which are turned into constant nodes via [`Context::constant`]).
///
/// This trait allows you to write
/// ```
/// # let mut ctx = fidget_core::context::Context::new();
/// let x = ctx.x();
/// let sum = ctx.add(x, 1.0).unwrap();
/// ```
/// instead of the more verbose
/// ```
/// # let mut ctx = fidget_core::context::Context::new();
/// let x = ctx.x();
/// let num = ctx.constant(1.0);
/// let sum = ctx.add(x, num).unwrap();
/// ```
pub trait IntoNode {
    /// Converts `self` into a [`Node`] handle belonging to `ctx`.
    ///
    /// # Errors
    /// Returns an error if the conversion fails. For example, the [`Node`]
    /// implementation forwards any error reported by `Context::check_node`.
    /// The `f32` / `f64` implementations always succeed.
    fn into_node(self, ctx: &mut Context) -> Result<Node, Error>;
}
1566
1567impl IntoNode for Node {
1568    fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1569        ctx.check_node(self)?;
1570        Ok(self)
1571    }
1572}
1573
1574impl IntoNode for f32 {
1575    fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1576        Ok(ctx.constant(self as f64))
1577    }
1578}
1579
1580impl IntoNode for f64 {
1581    fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1582        Ok(ctx.constant(self))
1583    }
1584}
1585
1586////////////////////////////////////////////////////////////////////////////////
1587
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmData;

    /// `get_op` is private, so it is exercised here instead of in a doctest.
    #[test]
    fn test_get_op() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let op = ctx.get_op(x).unwrap();
        assert!(matches!(op, Op::Input(_)));
    }

    /// Builds `max(0.25 - (x² + y²), (x² + y²) - 0.5)` and checks the size
    /// of the resulting VM tape and its variable count.
    #[test]
    fn test_ring() {
        let mut ctx = Context::new();
        let half = ctx.constant(0.5);
        let x = ctx.x();
        let y = ctx.y();
        let xx = ctx.square(x).unwrap();
        let yy = ctx.square(y).unwrap();
        let r2 = ctx.add(xx, yy).unwrap();
        let outer = ctx.sub(r2, half).unwrap();
        let quarter = ctx.constant(0.25);
        let inner = ctx.sub(quarter, r2).unwrap();
        let ring = ctx.max(inner, outer).unwrap();

        let tape = VmData::<255>::new(&ctx, &[ring]).unwrap();
        assert_eq!(tape.len(), 9);
        assert_eq!(tape.vars.len(), 2);
    }

    /// Multiplying `x` by itself should produce a three-entry tape with a
    /// single variable.
    #[test]
    fn test_dupe() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let product = ctx.mul(x, x).unwrap();

        let tape = VmData::<255>::new(&ctx, &[product]).unwrap();
        assert_eq!(tape.len(), 3); // x, square, output
        assert_eq!(tape.vars.len(), 1);
    }

    /// Exporting `sin(x) + cos(x)` must yield a tree whose two branches
    /// share one `x` allocation.
    #[test]
    fn test_export() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let sin_x = ctx.sin(x).unwrap();
        let cos_x = ctx.cos(x).unwrap();
        let sum = ctx.add(sin_x, cos_x).unwrap();
        let t = ctx.export(sum).unwrap();

        let TreeOp::Binary(BinaryOpcode::Add, lhs, rhs) = &*t else {
            panic!("unexpected opcode {t:?}");
        };
        let (
            TreeOp::Unary(UnaryOpcode::Sin, x1),
            TreeOp::Unary(UnaryOpcode::Cos, x2),
        ) = (&**lhs, &**rhs)
        else {
            panic!("invalid lhs / rhs: {lhs:?} {rhs:?}");
        };

        // Deduplicated context nodes should export as a shared `Arc`
        assert_eq!(Arc::as_ptr(x1), Arc::as_ptr(x2));
        let TreeOp::Input(Var::X) = &**x1 else {
            panic!("invalid X: {x1:?}");
        };
    }

    /// Importing the tree `x + 0` should leave the bare `x` input as root.
    #[test]
    fn import_optimization() {
        let tree = Tree::x() + 0;
        let mut ctx = Context::new();
        let root = ctx.import(&tree);
        let root_op = ctx.get_op(root).unwrap();
        assert_eq!(root_op, &Op::Input(Var::X));
    }
}