fidget_core/context/mod.rs
1//! Infrastructure for representing math expressions as trees and graphs
2//!
3//! There are two families of representations in this module:
4//!
5//! - A [`Tree`] is a free-floating math expression, which can be cloned
6//! and has overloaded operators for ease of use. It is **not** deduplicated;
7//! two calls to [`Tree::constant(1.0)`](Tree::constant) will allocate two
8//! different objects.
9//! `Tree` objects are typically used when building up expressions; they
10//! should be converted to `Node` objects (in a particular `Context`) after
11//! they have been constructed.
12//! - A [`Context`] is an arena for unique (deduplicated) math expressions,
13//! which are represented as [`Node`] handles. Each `Node` is specific to a
14//! particular context. Only `Node` objects can be converted into
15//! [`Function`](crate::eval::Function) objects for evaluation.
16//!
17//! In other words, the typical workflow is `Tree → (Context, Node) → Function`.
18mod indexed;
19mod op;
20mod tree;
21
22use indexed::{Index, IndexMap, IndexVec, define_index};
23pub use op::{BinaryOpcode, Op, UnaryOpcode};
24pub use tree::{Tree, TreeOp};
25
26use crate::{Error, var::Var};
27
28use std::collections::{BTreeMap, HashMap};
29use std::fmt::Write;
30use std::io::{BufRead, BufReader, Read};
31use std::sync::Arc;
32
33use nalgebra::Matrix4;
34use ordered_float::OrderedFloat;
35
// `Node` is the handle type used throughout this module to refer to entries in
// a `Context`; the `define_index!` macro (from the private `indexed` module)
// generates it.
define_index!(Node, "An index in the `Context::ops` map");
37
38/// A `Context` holds a set of deduplicated constants, variables, and
39/// operations.
40///
41/// It should be used like an arena allocator: it grows over time, then frees
42/// all of its contents when dropped. There is no reference counting within the
43/// context.
44///
45/// Items in the context are accessed with [`Node`] keys, which are simple
46/// handles into an internal map. Inside the context, operations are
47/// represented with the [`Op`] type.
48#[derive(Debug, Default)]
49pub struct Context {
50 ops: IndexMap<Op, Node>,
51}
52
53impl Context {
54 /// Build a new empty context
55 pub fn new() -> Self {
56 Self::default()
57 }
58
59 /// Clears the context
60 ///
61 /// All [`Node`] handles from this context are invalidated.
62 ///
63 /// ```
64 /// # use fidget_core::context::Context;
65 /// let mut ctx = Context::new();
66 /// let x = ctx.x();
67 /// ctx.clear();
68 /// assert!(ctx.eval_xyz(x, 1.0, 0.0, 0.0).is_err());
69 /// ```
70 pub fn clear(&mut self) {
71 self.ops.clear();
72 }
73
74 /// Returns the number of [`Op`] nodes in the context
75 ///
76 /// ```
77 /// # use fidget_core::context::Context;
78 /// let mut ctx = Context::new();
79 /// let x = ctx.x();
80 /// assert_eq!(ctx.len(), 1);
81 /// let y = ctx.y();
82 /// assert_eq!(ctx.len(), 2);
83 /// ctx.clear();
84 /// assert_eq!(ctx.len(), 0);
85 /// ```
86 pub fn len(&self) -> usize {
87 self.ops.len()
88 }
89
90 /// Checks whether the context is empty
91 pub fn is_empty(&self) -> bool {
92 self.ops.is_empty()
93 }
94
95 /// Checks whether the given [`Node`] is valid in this context
96 fn check_node(&self, node: Node) -> Result<(), Error> {
97 self.get_op(node).ok_or(Error::BadNode).map(|_| ())
98 }
99
100 /// Erases the most recently added node from the tree.
101 ///
102 /// A few caveats apply, so this must be used with caution:
103 /// - Existing handles to the node will be invalidated
104 /// - The most recently added node must be unique
105 ///
106 /// In practice, this is only used to delete temporary operation nodes
107 /// during constant folding. Such nodes which have no handles (because
108 /// they are never returned) and are guaranteed to be unique (because we
109 /// never store them persistently).
110 fn pop(&mut self) -> Result<(), Error> {
111 self.ops.pop().map(|_| ())
112 }
113
114 /// Looks up the constant associated with the given node.
115 ///
116 /// If the node is invalid for this tree, returns an error; if the node is
117 /// not a constant, returns `Ok(None)`.
118 pub fn get_const(&self, n: Node) -> Result<f64, Error> {
119 match self.get_op(n) {
120 Some(Op::Const(c)) => Ok(c.0),
121 Some(_) => Err(Error::NotAConst),
122 _ => Err(Error::BadNode),
123 }
124 }
125
126 /// Looks up the [`Var`] associated with the given node.
127 ///
128 /// If the node is invalid for this tree or not an `Op::Input`, returns an
129 /// error.
130 pub fn get_var(&self, n: Node) -> Result<Var, Error> {
131 match self.get_op(n) {
132 Some(Op::Input(v)) => Ok(*v),
133 Some(..) => Err(Error::NotAVar),
134 _ => Err(Error::BadNode),
135 }
136 }
137
138 ////////////////////////////////////////////////////////////////////////////
139 // Primitives
140 /// Constructs or finds a [`Var::X`] node
141 /// ```
142 /// # use fidget_core::context::Context;
143 /// let mut ctx = Context::new();
144 /// let x = ctx.x();
145 /// let v = ctx.eval_xyz(x, 1.0, 0.0, 0.0).unwrap();
146 /// assert_eq!(v, 1.0);
147 /// ```
148 pub fn x(&mut self) -> Node {
149 self.var(Var::X)
150 }
151
152 /// Constructs or finds a [`Var::Y`] node
153 pub fn y(&mut self) -> Node {
154 self.var(Var::Y)
155 }
156
157 /// Constructs or finds a [`Var::Z`] node
158 pub fn z(&mut self) -> Node {
159 self.var(Var::Z)
160 }
161
162 /// Constructs or finds a variable input node
163 ///
164 /// To make an anonymous variable, call this function with [`Var::new()`]:
165 ///
166 /// ```
167 /// # use fidget_core::{context::Context, var::Var};
168 /// # use std::collections::HashMap;
169 /// let mut ctx = Context::new();
170 /// let v1 = ctx.var(Var::new());
171 /// let v2 = ctx.var(Var::new());
172 /// assert_ne!(v1, v2);
173 ///
174 /// let mut vars = HashMap::new();
175 /// vars.insert(ctx.get_var(v1).unwrap(), 3.0);
176 /// assert_eq!(ctx.eval(v1, &vars).unwrap(), 3.0);
177 /// assert!(ctx.eval(v2, &vars).is_err()); // v2 isn't in the map
178 /// ```
179 pub fn var(&mut self, v: Var) -> Node {
180 self.ops.insert(Op::Input(v))
181 }
182
183 /// Returns a 3-element array of `X`, `Y`, `Z` nodes
184 pub fn axes(&mut self) -> [Node; 3] {
185 [self.x(), self.y(), self.z()]
186 }
187
188 /// Returns a node representing the given constant value.
189 /// ```
190 /// # let mut ctx = fidget_core::context::Context::new();
191 /// let v = ctx.constant(3.0);
192 /// assert_eq!(ctx.eval_xyz(v, 0.0, 0.0, 0.0).unwrap(), 3.0);
193 /// ```
194 pub fn constant(&mut self, f: f64) -> Node {
195 self.ops.insert(Op::Const(OrderedFloat(f)))
196 }
197
198 ////////////////////////////////////////////////////////////////////////////
199 // Helper functions to create nodes with constant folding
200 /// Find or create a [Node] for the given unary operation, with constant
201 /// folding.
202 fn op_unary(&mut self, a: Node, op: UnaryOpcode) -> Result<Node, Error> {
203 let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
204 let n = self.ops.insert(Op::Unary(op, a));
205 let out = if matches!(op_a, Op::Const(_)) {
206 let v = self.eval(n, &Default::default())?;
207 self.pop().unwrap(); // removes `n`
208 self.constant(v)
209 } else {
210 n
211 };
212 Ok(out)
213 }
214 /// Find or create a [Node] for the given binary operation, with constant
215 /// folding.
216 fn op_binary(
217 &mut self,
218 a: Node,
219 b: Node,
220 op: BinaryOpcode,
221 ) -> Result<Node, Error> {
222 self.op_binary_f(a, b, |lhs, rhs| Op::Binary(op, lhs, rhs))
223 }
224
225 /// Find or create a [Node] for a generic binary operation (represented by a
226 /// thunk), with constant folding.
227 fn op_binary_f<F>(&mut self, a: Node, b: Node, f: F) -> Result<Node, Error>
228 where
229 F: Fn(Node, Node) -> Op,
230 {
231 let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
232 let op_b = *self.get_op(b).ok_or(Error::BadNode)?;
233
234 // This call to `insert` should always insert the node, because we
235 // don't permanently store operations in the tree that could be
236 // constant-folded (indeed, we pop the node right afterwards)
237 let n = self.ops.insert(f(a, b));
238 let out = if matches!((op_a, op_b), (Op::Const(_), Op::Const(_))) {
239 let v = self.eval(n, &Default::default())?;
240 self.pop().unwrap(); // removes `n`
241 self.constant(v)
242 } else {
243 n
244 };
245 Ok(out)
246 }
247
248 /// Find or create a [Node] for the given commutative operation, with
249 /// constant folding; deduplication is encouraged by sorting `a` and `b`.
250 fn op_binary_commutative(
251 &mut self,
252 a: Node,
253 b: Node,
254 op: BinaryOpcode,
255 ) -> Result<Node, Error> {
256 self.op_binary(a.min(b), a.max(b), op)
257 }
258
259 /// Builds an addition node
260 /// ```
261 /// # let mut ctx = fidget_core::context::Context::new();
262 /// let x = ctx.x();
263 /// let op = ctx.add(x, 1.0).unwrap();
264 /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
265 /// assert_eq!(v, 2.0);
266 /// ```
267 pub fn add<A: IntoNode, B: IntoNode>(
268 &mut self,
269 a: A,
270 b: B,
271 ) -> Result<Node, Error> {
272 let a: Node = a.into_node(self)?;
273 let b: Node = b.into_node(self)?;
274 if a == b {
275 let two = self.constant(2.0);
276 self.mul(a, two)
277 } else {
278 match (self.get_const(a), self.get_const(b)) {
279 (Ok(0.0), _) => Ok(b),
280 (_, Ok(0.0)) => Ok(a),
281 _ => self.op_binary_commutative(a, b, BinaryOpcode::Add),
282 }
283 }
284 }
285
286 /// Builds an multiplication node
287 /// ```
288 /// # let mut ctx = fidget_core::context::Context::new();
289 /// let x = ctx.x();
290 /// let op = ctx.mul(x, 5.0).unwrap();
291 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
292 /// assert_eq!(v, 10.0);
293 /// ```
294 pub fn mul<A: IntoNode, B: IntoNode>(
295 &mut self,
296 a: A,
297 b: B,
298 ) -> Result<Node, Error> {
299 let a = a.into_node(self)?;
300 let b = b.into_node(self)?;
301 if a == b {
302 self.square(a)
303 } else {
304 match (self.get_const(a), self.get_const(b)) {
305 (Ok(1.0), _) => Ok(b),
306 (_, Ok(1.0)) => Ok(a),
307 (Ok(0.0), _) => Ok(a),
308 (_, Ok(0.0)) => Ok(b),
309 _ => self.op_binary_commutative(a, b, BinaryOpcode::Mul),
310 }
311 }
312 }
313
314 /// Builds an `min` node
315 /// ```
316 /// # let mut ctx = fidget_core::context::Context::new();
317 /// let x = ctx.x();
318 /// let op = ctx.min(x, 5.0).unwrap();
319 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
320 /// assert_eq!(v, 2.0);
321 /// ```
322 pub fn min<A: IntoNode, B: IntoNode>(
323 &mut self,
324 a: A,
325 b: B,
326 ) -> Result<Node, Error> {
327 let a = a.into_node(self)?;
328 let b = b.into_node(self)?;
329 if a == b {
330 Ok(a)
331 } else {
332 self.op_binary_commutative(a, b, BinaryOpcode::Min)
333 }
334 }
335 /// Builds an `max` node
336 /// ```
337 /// # let mut ctx = fidget_core::context::Context::new();
338 /// let x = ctx.x();
339 /// let op = ctx.max(x, 5.0).unwrap();
340 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
341 /// assert_eq!(v, 5.0);
342 /// ```
343 pub fn max<A: IntoNode, B: IntoNode>(
344 &mut self,
345 a: A,
346 b: B,
347 ) -> Result<Node, Error> {
348 let a = a.into_node(self)?;
349 let b = b.into_node(self)?;
350 if a == b {
351 Ok(a)
352 } else {
353 self.op_binary_commutative(a, b, BinaryOpcode::Max)
354 }
355 }
356
357 /// Builds an `and` node
358 ///
359 /// If both arguments are non-zero, returns the right-hand argument.
360 /// Otherwise, returns zero.
361 ///
362 /// This node can be simplified using a tracing evaluator:
363 /// - If the left-hand argument is zero, simplify to just that argument
364 /// - If the left-hand argument is non-zero, simplify to the other argument
365 /// ```
366 /// # let mut ctx = fidget_core::context::Context::new();
367 /// let x = ctx.x();
368 /// let y = ctx.y();
369 /// let op = ctx.and(x, y).unwrap();
370 /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
371 /// assert_eq!(v, 0.0);
372 /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
373 /// assert_eq!(v, 1.0);
374 /// let v = ctx.eval_xyz(op, 1.0, 2.0, 0.0).unwrap();
375 /// assert_eq!(v, 2.0);
376 /// ```
377 pub fn and<A: IntoNode, B: IntoNode>(
378 &mut self,
379 a: A,
380 b: B,
381 ) -> Result<Node, Error> {
382 let a = a.into_node(self)?;
383 let b = b.into_node(self)?;
384
385 let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
386 if let Op::Const(v) = op_a {
387 if v.0 == 0.0 { Ok(a) } else { Ok(b) }
388 } else {
389 self.op_binary(a, b, BinaryOpcode::And)
390 }
391 }
392
393 /// Builds an `or` node
394 ///
395 /// If the left-hand argument is non-zero, it is returned. Otherwise, the
396 /// right-hand argument is returned.
397 ///
398 /// This node can be simplified using a tracing evaluator.
399 /// ```
400 /// # let mut ctx = fidget_core::context::Context::new();
401 /// let x = ctx.x();
402 /// let y = ctx.y();
403 /// let op = ctx.or(x, y).unwrap();
404 /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
405 /// assert_eq!(v, 1.0);
406 /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
407 /// assert_eq!(v, 0.0);
408 /// let v = ctx.eval_xyz(op, 0.0, 3.0, 0.0).unwrap();
409 /// assert_eq!(v, 3.0);
410 /// ```
411 pub fn or<A: IntoNode, B: IntoNode>(
412 &mut self,
413 a: A,
414 b: B,
415 ) -> Result<Node, Error> {
416 let a = a.into_node(self)?;
417 let b = b.into_node(self)?;
418
419 let op_a = *self.get_op(a).ok_or(Error::BadNode)?;
420 let op_b = *self.get_op(b).ok_or(Error::BadNode)?;
421 if let Op::Const(v) = op_a {
422 if v.0 != 0.0 {
423 return Ok(a);
424 } else {
425 return Ok(b);
426 }
427 } else if let Op::Const(v) = op_b {
428 if v.0 == 0.0 {
429 return Ok(a);
430 }
431 }
432 self.op_binary(a, b, BinaryOpcode::Or)
433 }
434
435 /// Builds a logical negation node
436 ///
437 /// The output is 1 if the argument is 0, and 0 otherwise.
438 pub fn not<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
439 let a = a.into_node(self)?;
440 self.op_unary(a, UnaryOpcode::Not)
441 }
442
443 /// Builds a unary negation node
444 /// ```
445 /// # let mut ctx = fidget_core::context::Context::new();
446 /// let x = ctx.x();
447 /// let op = ctx.neg(x).unwrap();
448 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
449 /// assert_eq!(v, -2.0);
450 /// ```
451 pub fn neg<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
452 let a = a.into_node(self)?;
453 self.op_unary(a, UnaryOpcode::Neg)
454 }
455
456 /// Builds a reciprocal node
457 /// ```
458 /// # let mut ctx = fidget_core::context::Context::new();
459 /// let x = ctx.x();
460 /// let op = ctx.recip(x).unwrap();
461 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
462 /// assert_eq!(v, 0.5);
463 /// ```
464 pub fn recip<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
465 let a = a.into_node(self)?;
466 self.op_unary(a, UnaryOpcode::Recip)
467 }
468
469 /// Builds a node which calculates the absolute value of its input
470 /// ```
471 /// # let mut ctx = fidget_core::context::Context::new();
472 /// let x = ctx.x();
473 /// let op = ctx.abs(x).unwrap();
474 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
475 /// assert_eq!(v, 2.0);
476 /// let v = ctx.eval_xyz(op, -2.0, 0.0, 0.0).unwrap();
477 /// assert_eq!(v, 2.0);
478 /// ```
479 pub fn abs<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
480 let a = a.into_node(self)?;
481 self.op_unary(a, UnaryOpcode::Abs)
482 }
483
484 /// Builds a node which calculates the square root of its input
485 /// ```
486 /// # let mut ctx = fidget_core::context::Context::new();
487 /// let x = ctx.x();
488 /// let op = ctx.sqrt(x).unwrap();
489 /// let v = ctx.eval_xyz(op, 4.0, 0.0, 0.0).unwrap();
490 /// assert_eq!(v, 2.0);
491 /// ```
492 pub fn sqrt<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
493 let a = a.into_node(self)?;
494 self.op_unary(a, UnaryOpcode::Sqrt)
495 }
496
497 /// Builds a node which calculates the sine of its input (in radians)
498 /// ```
499 /// # let mut ctx = fidget_core::context::Context::new();
500 /// let x = ctx.x();
501 /// let op = ctx.sin(x).unwrap();
502 /// let v = ctx.eval_xyz(op, std::f64::consts::PI / 2.0, 0.0, 0.0).unwrap();
503 /// assert_eq!(v, 1.0);
504 /// ```
505 pub fn sin<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
506 let a = a.into_node(self)?;
507 self.op_unary(a, UnaryOpcode::Sin)
508 }
509
510 /// Builds a node which calculates the cosine of its input (in radians)
511 pub fn cos<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
512 let a = a.into_node(self)?;
513 self.op_unary(a, UnaryOpcode::Cos)
514 }
515
516 /// Builds a node which calculates the tangent of its input (in radians)
517 pub fn tan<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
518 let a = a.into_node(self)?;
519 self.op_unary(a, UnaryOpcode::Tan)
520 }
521
522 /// Builds a node which calculates the arcsine of its input (in radians)
523 pub fn asin<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
524 let a = a.into_node(self)?;
525 self.op_unary(a, UnaryOpcode::Asin)
526 }
527
528 /// Builds a node which calculates the arccosine of its input (in radians)
529 pub fn acos<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
530 let a = a.into_node(self)?;
531 self.op_unary(a, UnaryOpcode::Acos)
532 }
533
534 /// Builds a node which calculates the arctangent of its input (in radians)
535 pub fn atan<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
536 let a = a.into_node(self)?;
537 self.op_unary(a, UnaryOpcode::Atan)
538 }
539
540 /// Builds a node which calculates the exponent of its input
541 pub fn exp<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
542 let a = a.into_node(self)?;
543 self.op_unary(a, UnaryOpcode::Exp)
544 }
545
546 /// Builds a node which calculates the natural log of its input
547 pub fn ln<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
548 let a = a.into_node(self)?;
549 self.op_unary(a, UnaryOpcode::Ln)
550 }
551
552 ////////////////////////////////////////////////////////////////////////////
553 // Derived functions
554 /// Builds a node which squares its input
555 /// ```
556 /// # let mut ctx = fidget_core::context::Context::new();
557 /// let x = ctx.x();
558 /// let op = ctx.square(x).unwrap();
559 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
560 /// assert_eq!(v, 4.0);
561 /// ```
562 pub fn square<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
563 let a = a.into_node(self)?;
564 self.op_unary(a, UnaryOpcode::Square)
565 }
566
567 /// Builds a node which takes the floor of its input
568 /// ```
569 /// # let mut ctx = fidget_core::context::Context::new();
570 /// let x = ctx.x();
571 /// let op = ctx.floor(x).unwrap();
572 /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
573 /// assert_eq!(v, 1.0);
574 /// ```
575 pub fn floor<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
576 let a = a.into_node(self)?;
577 self.op_unary(a, UnaryOpcode::Floor)
578 }
579
580 /// Builds a node which takes the ceiling of its input
581 /// ```
582 /// # let mut ctx = fidget_core::context::Context::new();
583 /// let x = ctx.x();
584 /// let op = ctx.ceil(x).unwrap();
585 /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
586 /// assert_eq!(v, 2.0);
587 /// ```
588 pub fn ceil<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
589 let a = a.into_node(self)?;
590 self.op_unary(a, UnaryOpcode::Ceil)
591 }
592
593 /// Builds a node which rounds its input to the nearest integer
594 /// ```
595 /// # let mut ctx = fidget_core::context::Context::new();
596 /// let x = ctx.x();
597 /// let op = ctx.round(x).unwrap();
598 /// let v = ctx.eval_xyz(op, 1.2, 0.0, 0.0).unwrap();
599 /// assert_eq!(v, 1.0);
600 /// let v = ctx.eval_xyz(op, 1.6, 0.0, 0.0).unwrap();
601 /// assert_eq!(v, 2.0);
602 /// let v = ctx.eval_xyz(op, 1.5, 0.0, 0.0).unwrap();
603 /// assert_eq!(v, 2.0); // rounds away from 0.0 if ambiguous
604 /// ```
605 pub fn round<A: IntoNode>(&mut self, a: A) -> Result<Node, Error> {
606 let a = a.into_node(self)?;
607 self.op_unary(a, UnaryOpcode::Round)
608 }
609
610 /// Builds a node which performs subtraction.
611 /// ```
612 /// # let mut ctx = fidget_core::context::Context::new();
613 /// let x = ctx.x();
614 /// let y = ctx.y();
615 /// let op = ctx.sub(x, y).unwrap();
616 /// let v = ctx.eval_xyz(op, 3.0, 2.0, 0.0).unwrap();
617 /// assert_eq!(v, 1.0);
618 /// ```
619 pub fn sub<A: IntoNode, B: IntoNode>(
620 &mut self,
621 a: A,
622 b: B,
623 ) -> Result<Node, Error> {
624 let a = a.into_node(self)?;
625 let b = b.into_node(self)?;
626
627 match (self.get_const(a), self.get_const(b)) {
628 (Ok(0.0), _) => self.neg(b),
629 (_, Ok(0.0)) => Ok(a),
630 _ => self.op_binary(a, b, BinaryOpcode::Sub),
631 }
632 }
633
634 /// Builds a node which performs division.
635 /// ```
636 /// # let mut ctx = fidget_core::context::Context::new();
637 /// let x = ctx.x();
638 /// let y = ctx.y();
639 /// let op = ctx.div(x, y).unwrap();
640 /// let v = ctx.eval_xyz(op, 3.0, 2.0, 0.0).unwrap();
641 /// assert_eq!(v, 1.5);
642 /// ```
643 pub fn div<A: IntoNode, B: IntoNode>(
644 &mut self,
645 a: A,
646 b: B,
647 ) -> Result<Node, Error> {
648 let a = a.into_node(self)?;
649 let b = b.into_node(self)?;
650
651 match (self.get_const(a), self.get_const(b)) {
652 (Ok(0.0), _) => Ok(a),
653 (_, Ok(1.0)) => Ok(a),
654 _ => self.op_binary(a, b, BinaryOpcode::Div),
655 }
656 }
657
658 /// Builds a node which computes `atan2(y, x)`
659 /// ```
660 /// # let mut ctx = fidget_core::context::Context::new();
661 /// let x = ctx.x();
662 /// let y = ctx.y();
663 /// let op = ctx.atan2(y, x).unwrap();
664 /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
665 /// assert_eq!(v, std::f64::consts::FRAC_PI_2);
666 /// ```
667 pub fn atan2<A: IntoNode, B: IntoNode>(
668 &mut self,
669 y: A,
670 x: B,
671 ) -> Result<Node, Error> {
672 let y = y.into_node(self)?;
673 let x = x.into_node(self)?;
674
675 self.op_binary(y, x, BinaryOpcode::Atan)
676 }
677
678 /// Builds a node that compares two values
679 ///
680 /// The result is -1 if `a < b`, +1 if `a > b`, 0 if `a == b`, and `NaN` if
681 /// either side is `NaN`.
682 /// ```
683 /// # let mut ctx = fidget_core::context::Context::new();
684 /// let x = ctx.x();
685 /// let op = ctx.compare(x, 1.0).unwrap();
686 /// let v = ctx.eval_xyz(op, 0.0, 0.0, 0.0).unwrap();
687 /// assert_eq!(v, -1.0);
688 /// let v = ctx.eval_xyz(op, 2.0, 0.0, 0.0).unwrap();
689 /// assert_eq!(v, 1.0);
690 /// let v = ctx.eval_xyz(op, 1.0, 0.0, 0.0).unwrap();
691 /// assert_eq!(v, 0.0);
692 /// ```
693 pub fn compare<A: IntoNode, B: IntoNode>(
694 &mut self,
695 a: A,
696 b: B,
697 ) -> Result<Node, Error> {
698 let a = a.into_node(self)?;
699 let b = b.into_node(self)?;
700 self.op_binary(a, b, BinaryOpcode::Compare)
701 }
702
703 /// Builds a node that is 1 if `lhs < rhs` and 0 otherwise
704 ///
705 /// ```
706 /// # let mut ctx = fidget_core::context::Context::new();
707 /// let x = ctx.x();
708 /// let y = ctx.y();
709 /// let op = ctx.less_than(x, y).unwrap();
710 /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
711 /// assert_eq!(v, 1.0);
712 /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
713 /// assert_eq!(v, 0.0);
714 /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
715 /// assert_eq!(v, 0.0);
716 /// ```
717 pub fn less_than<A: IntoNode, B: IntoNode>(
718 &mut self,
719 lhs: A,
720 rhs: B,
721 ) -> Result<Node, Error> {
722 let lhs = lhs.into_node(self)?;
723 let rhs = rhs.into_node(self)?;
724 let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
725 self.max(cmp, 0.0)
726 }
727
728 /// Builds a node that is 1 if `lhs <= rhs` and 0 otherwise
729 ///
730 /// ```
731 /// # let mut ctx = fidget_core::context::Context::new();
732 /// let x = ctx.x();
733 /// let y = ctx.y();
734 /// let op = ctx.less_than_or_equal(x, y).unwrap();
735 /// let v = ctx.eval_xyz(op, 0.0, 1.0, 0.0).unwrap();
736 /// assert_eq!(v, 1.0);
737 /// let v = ctx.eval_xyz(op, 1.0, 1.0, 0.0).unwrap();
738 /// assert_eq!(v, 1.0);
739 /// let v = ctx.eval_xyz(op, 2.0, 1.0, 0.0).unwrap();
740 /// assert_eq!(v, 0.0);
741 /// ```
742 pub fn less_than_or_equal<A: IntoNode, B: IntoNode>(
743 &mut self,
744 lhs: A,
745 rhs: B,
746 ) -> Result<Node, Error> {
747 let lhs = lhs.into_node(self)?;
748 let rhs = rhs.into_node(self)?;
749 let cmp = self.op_binary(rhs, lhs, BinaryOpcode::Compare)?;
750 let shift = self.add(cmp, 1.0)?;
751 self.min(shift, 1.0)
752 }
753
754 /// Builds a node that takes the modulo (least non-negative remainder)
755 pub fn modulo<A: IntoNode, B: IntoNode>(
756 &mut self,
757 a: A,
758 b: B,
759 ) -> Result<Node, Error> {
760 let a = a.into_node(self)?;
761 let b = b.into_node(self)?;
762 self.op_binary(a, b, BinaryOpcode::Mod)
763 }
764
765 /// Builds a node that returns the first node if the condition is not
766 /// equal to zero, else returns the other node
767 ///
768 /// The result is `a` if `condition != 0`, else the result is `b`.
769 /// ```
770 /// # let mut ctx = fidget_core::context::Context::new();
771 /// let x = ctx.x();
772 /// let y = ctx.y();
773 /// let z = ctx.z();
774 ///
775 /// let if_else = ctx.if_nonzero_else(x, y, z).unwrap();
776 ///
777 /// assert_eq!(ctx.eval_xyz(if_else, 0.0, 2.0, 3.0).unwrap(), 3.0);
778 /// assert_eq!(ctx.eval_xyz(if_else, 1.0, 2.0, 3.0).unwrap(), 2.0);
779 /// assert_eq!(ctx.eval_xyz(if_else, 0.0, f64::NAN, 3.0).unwrap(), 3.0);
780 /// assert_eq!(ctx.eval_xyz(if_else, 1.0, 2.0, f64::NAN).unwrap(), 2.0);
781 /// ```
782 pub fn if_nonzero_else<Condition: IntoNode, A: IntoNode, B: IntoNode>(
783 &mut self,
784 condition: Condition,
785 a: A,
786 b: B,
787 ) -> Result<Node, Error> {
788 let condition = condition.into_node(self)?;
789 let a = a.into_node(self)?;
790 let b = b.into_node(self)?;
791
792 let lhs = self.and(condition, a)?;
793 let n_condition = self.not(condition)?;
794 let rhs = self.and(n_condition, b)?;
795 self.or(lhs, rhs)
796 }
797
798 ////////////////////////////////////////////////////////////////////////////
799 /// Evaluates the given node with the provided values for X, Y, and Z.
800 ///
801 /// This is extremely inefficient; consider converting the node into a
802 /// [`Shape`](crate::shape::Shape) and using its evaluators instead.
803 ///
804 /// ```
805 /// # let mut ctx = fidget_core::context::Context::new();
806 /// let x = ctx.x();
807 /// let y = ctx.y();
808 /// let z = ctx.z();
809 /// let op = ctx.mul(x, y).unwrap();
810 /// let op = ctx.div(op, z).unwrap();
811 /// let v = ctx.eval_xyz(op, 3.0, 5.0, 2.0).unwrap();
812 /// assert_eq!(v, 7.5); // (3.0 * 5.0) / 2.0
813 /// ```
814 pub fn eval_xyz(
815 &self,
816 root: Node,
817 x: f64,
818 y: f64,
819 z: f64,
820 ) -> Result<f64, Error> {
821 let vars = [(Var::X, x), (Var::Y, y), (Var::Z, z)]
822 .into_iter()
823 .collect();
824 self.eval(root, &vars)
825 }
826
827 /// Evaluates the given node with a generic set of variables
828 ///
829 /// This is extremely inefficient; consider converting the node into a
830 /// [`Shape`](crate::shape::Shape) and using its evaluators instead.
831 pub fn eval(
832 &self,
833 root: Node,
834 vars: &HashMap<Var, f64>,
835 ) -> Result<f64, Error> {
836 let mut cache = vec![None; self.ops.len()].into();
837 self.eval_inner(root, vars, &mut cache)
838 }
839
840 fn eval_inner(
841 &self,
842 node: Node,
843 vars: &HashMap<Var, f64>,
844 cache: &mut IndexVec<Option<f64>, Node>,
845 ) -> Result<f64, Error> {
846 if node.0 >= cache.len() {
847 return Err(Error::BadNode);
848 }
849 if let Some(v) = cache[node] {
850 return Ok(v);
851 }
852 let mut get = |n: Node| self.eval_inner(n, vars, cache);
853 let v = match self.get_op(node).ok_or(Error::BadNode)? {
854 Op::Input(v) => *vars.get(v).ok_or(Error::MissingVar(*v))?,
855 Op::Const(c) => c.0,
856
857 Op::Binary(op, a, b) => {
858 let a = get(*a)?;
859 let b = get(*b)?;
860 match op {
861 BinaryOpcode::Add => a + b,
862 BinaryOpcode::Sub => a - b,
863 BinaryOpcode::Mul => a * b,
864 BinaryOpcode::Div => a / b,
865 BinaryOpcode::Atan => a.atan2(b),
866 BinaryOpcode::Min => a.min(b),
867 BinaryOpcode::Max => a.max(b),
868 BinaryOpcode::Compare => a
869 .partial_cmp(&b)
870 .map(|i| i as i8 as f64)
871 .unwrap_or(f64::NAN),
872 BinaryOpcode::Mod => a.rem_euclid(b),
873 BinaryOpcode::And => {
874 if a == 0.0 {
875 a
876 } else {
877 b
878 }
879 }
880 BinaryOpcode::Or => {
881 if a != 0.0 {
882 a
883 } else {
884 b
885 }
886 }
887 }
888 }
889
890 // Unary operations
891 Op::Unary(op, a) => {
892 let a = get(*a)?;
893 match op {
894 UnaryOpcode::Neg => -a,
895 UnaryOpcode::Abs => a.abs(),
896 UnaryOpcode::Recip => 1.0 / a,
897 UnaryOpcode::Sqrt => a.sqrt(),
898 UnaryOpcode::Square => a * a,
899 UnaryOpcode::Floor => a.floor(),
900 UnaryOpcode::Ceil => a.ceil(),
901 UnaryOpcode::Round => a.round(),
902 UnaryOpcode::Sin => a.sin(),
903 UnaryOpcode::Cos => a.cos(),
904 UnaryOpcode::Tan => a.tan(),
905 UnaryOpcode::Asin => a.asin(),
906 UnaryOpcode::Acos => a.acos(),
907 UnaryOpcode::Atan => a.atan(),
908 UnaryOpcode::Exp => a.exp(),
909 UnaryOpcode::Ln => a.ln(),
910 UnaryOpcode::Not => (a == 0.0).into(),
911 }
912 }
913 };
914
915 cache[node] = Some(v);
916 Ok(v)
917 }
918
919 /// Parses a flat text representation of a math tree. For example, the
920 /// circle `(- (+ (square x) (square y)) 1)` can be parsed from
921 /// ```
922 /// # use fidget_core::context::Context;
923 /// let txt = "
924 /// ## This is a comment!
925 /// 0x600000b90000 var-x
926 /// 0x600000b900a0 square 0x600000b90000
927 /// 0x600000b90050 var-y
928 /// 0x600000b900f0 square 0x600000b90050
929 /// 0x600000b90140 add 0x600000b900a0 0x600000b900f0
930 /// 0x600000b90190 sqrt 0x600000b90140
931 /// 0x600000b901e0 const 1
932 /// ";
933 /// let (ctx, _node) = Context::from_text(&mut txt.as_bytes()).unwrap();
934 /// assert_eq!(ctx.len(), 7);
935 /// ```
936 ///
937 /// This representation is loosely defined and only intended for use in
938 /// quick experiments.
939 pub fn from_text<R: Read>(r: R) -> Result<(Self, Node), Error> {
940 let reader = BufReader::new(r);
941 let mut ctx = Self::new();
942 let mut seen = BTreeMap::new();
943 let mut last = None;
944
945 for line in reader.lines().map(|line| line.unwrap()) {
946 if line.is_empty() || line.starts_with('#') {
947 continue;
948 }
949 let mut iter = line.split_whitespace();
950 let i: String = iter.next().unwrap().to_owned();
951 let opcode = iter.next().unwrap();
952
953 let mut pop = || {
954 let txt = iter.next().unwrap();
955 seen.get(txt)
956 .cloned()
957 .ok_or_else(|| Error::UnknownVariable(txt.to_string()))
958 };
959 let node = match opcode {
960 "const" => ctx.constant(iter.next().unwrap().parse().unwrap()),
961 "var-x" => ctx.x(),
962 "var-y" => ctx.y(),
963 "var-z" => ctx.z(),
964 "abs" => ctx.abs(pop()?)?,
965 "neg" => ctx.neg(pop()?)?,
966 "sqrt" => ctx.sqrt(pop()?)?,
967 "square" => ctx.square(pop()?)?,
968 "floor" => ctx.floor(pop()?)?,
969 "ceil" => ctx.ceil(pop()?)?,
970 "round" => ctx.round(pop()?)?,
971 "sin" => ctx.sin(pop()?)?,
972 "cos" => ctx.cos(pop()?)?,
973 "tan" => ctx.tan(pop()?)?,
974 "asin" => ctx.asin(pop()?)?,
975 "acos" => ctx.acos(pop()?)?,
976 "atan" => ctx.atan(pop()?)?,
977 "ln" => ctx.ln(pop()?)?,
978 "not" => ctx.not(pop()?)?,
979 "exp" => ctx.exp(pop()?)?,
980 "add" => ctx.add(pop()?, pop()?)?,
981 "mul" => ctx.mul(pop()?, pop()?)?,
982 "min" => ctx.min(pop()?, pop()?)?,
983 "max" => ctx.max(pop()?, pop()?)?,
984 "div" => ctx.div(pop()?, pop()?)?,
985 "atan2" => ctx.atan2(pop()?, pop()?)?,
986 "sub" => ctx.sub(pop()?, pop()?)?,
987 "compare" => ctx.compare(pop()?, pop()?)?,
988 "mod" => ctx.modulo(pop()?, pop()?)?,
989 "and" => ctx.and(pop()?, pop()?)?,
990 "or" => ctx.or(pop()?, pop()?)?,
991 op => return Err(Error::UnknownOpcode(op.to_owned())),
992 };
993 seen.insert(i, node);
994 last = Some(node);
995 }
996 match last {
997 Some(node) => Ok((ctx, node)),
998 None => Err(Error::EmptyFile),
999 }
1000 }
1001
1002 /// Converts the entire context into a GraphViz drawing
1003 pub fn dot(&self) -> String {
1004 let mut out = "digraph mygraph{\n".to_owned();
1005 for node in self.ops.keys() {
1006 let op = self.get_op(node).unwrap();
1007 out += &self.dot_node(node);
1008 out += &op.dot_edges(node);
1009 }
1010 out += "}\n";
1011 out
1012 }
1013
1014 /// Converts the given node into a GraphViz node
1015 ///
1016 /// (this is a local function instead of a function on `Op` because it
1017 /// requires looking up variables by name)
1018 fn dot_node(&self, i: Node) -> String {
1019 let mut out = format!(r#"n{} [label = ""#, i.get());
1020 let op = self.get_op(i).unwrap();
1021 match op {
1022 Op::Const(c) => write!(out, "{c}").unwrap(),
1023 Op::Input(v) => {
1024 out += &v.to_string();
1025 }
1026 Op::Binary(op, ..) => match op {
1027 BinaryOpcode::Add => out += "add",
1028 BinaryOpcode::Sub => out += "sub",
1029 BinaryOpcode::Mul => out += "mul",
1030 BinaryOpcode::Div => out += "div",
1031 BinaryOpcode::Atan => out += "atan2",
1032 BinaryOpcode::Min => out += "min",
1033 BinaryOpcode::Max => out += "max",
1034 BinaryOpcode::Compare => out += "compare",
1035 BinaryOpcode::Mod => out += "mod",
1036 BinaryOpcode::And => out += "and",
1037 BinaryOpcode::Or => out += "or",
1038 },
1039 Op::Unary(op, ..) => match op {
1040 UnaryOpcode::Neg => out += "neg",
1041 UnaryOpcode::Abs => out += "abs",
1042 UnaryOpcode::Recip => out += "recip",
1043 UnaryOpcode::Sqrt => out += "sqrt",
1044 UnaryOpcode::Square => out += "square",
1045 UnaryOpcode::Floor => out += "floor",
1046 UnaryOpcode::Ceil => out += "ceil",
1047 UnaryOpcode::Round => out += "round",
1048 UnaryOpcode::Sin => out += "sin",
1049 UnaryOpcode::Cos => out += "cos",
1050 UnaryOpcode::Tan => out += "tan",
1051 UnaryOpcode::Asin => out += "asin",
1052 UnaryOpcode::Acos => out += "acos",
1053 UnaryOpcode::Atan => out += "atan",
1054 UnaryOpcode::Exp => out += "exp",
1055 UnaryOpcode::Ln => out += "ln",
1056 UnaryOpcode::Not => out += "not",
1057 },
1058 };
1059 write!(
1060 out,
1061 r#"" color="{0}1" shape="{1}" fontcolor="{0}4"]"#,
1062 op.dot_node_color(),
1063 op.dot_node_shape()
1064 )
1065 .unwrap();
1066 out
1067 }
1068
1069 /// Looks up an operation by `Node` handle
1070 pub fn get_op(&self, node: Node) -> Option<&Op> {
1071 self.ops.get_by_index(node)
1072 }
1073
1074 /// Imports the given tree, deduplicating and returning the root
1075 pub fn import(&mut self, tree: &Tree) -> Node {
1076 // A naive remapping implementation would use recursion. A naive
1077 // remapping implementation would blow up the stack given any
1078 // significant tree size.
1079 //
1080 // Instead, we maintain our own pseudo-stack here in a pair of Vecs (one
1081 // stack for actions, and a second stack for return values).
1082 enum Action<'a> {
1083 /// Pushes `Up(op)` followed by `Down(c)` for each child
1084 Down(&'a Arc<TreeOp>),
1085 /// Consumes imported trees from the stack and pushes a new tree
1086 Up(&'a Arc<TreeOp>),
1087 /// Pops the latest axis frame
1088 Pop,
1089 /// Pops the latest affine frame
1090 PopAffine,
1091 }
1092 let mut axes = vec![(self.x(), self.y(), self.z())];
1093 let mut todo = vec![Action::Down(tree.arc())];
1094 let mut stack = vec![];
1095 let mut affine: Vec<Matrix4<f64>> = vec![];
1096
1097 // Cache of TreeOp -> Node mapping under a particular frame (axes)
1098 //
1099 // This isn't required for correctness, but can be a speed optimization
1100 // (because it means we don't have to walk the same tree twice).
1101 let mut seen = HashMap::new();
1102
1103 while let Some(t) = todo.pop() {
1104 match t {
1105 Action::Down(t) => {
1106 // If we've already seen this TreeOp with these axes, then
1107 // we can return the previous Node.
1108 if matches!(
1109 t.as_ref(),
1110 TreeOp::Unary(..) | TreeOp::Binary(..)
1111 ) {
1112 if let Some(p) =
1113 seen.get(&(*axes.last().unwrap(), Arc::as_ptr(t)))
1114 {
1115 stack.push(*p);
1116 continue;
1117 }
1118 }
1119 match t.as_ref() {
1120 TreeOp::Const(c) => {
1121 stack.push(self.constant(*c));
1122 }
1123 TreeOp::Input(s) => {
1124 let axes = axes.last().unwrap();
1125 stack.push(match *s {
1126 Var::X => axes.0,
1127 Var::Y => axes.1,
1128 Var::Z => axes.2,
1129 v @ Var::V(..) => self.var(v),
1130 });
1131 }
1132 TreeOp::Unary(_op, arg) => {
1133 todo.push(Action::Up(t));
1134 todo.push(Action::Down(arg));
1135 }
1136 TreeOp::Binary(_op, lhs, rhs) => {
1137 todo.push(Action::Up(t));
1138 todo.push(Action::Down(lhs));
1139 todo.push(Action::Down(rhs));
1140 }
1141 TreeOp::RemapAxes { target: _, x, y, z } => {
1142 // Action::Up(t) does the remapping and target eval
1143 todo.push(Action::Up(t));
1144 todo.push(Action::Down(x));
1145 todo.push(Action::Down(y));
1146 todo.push(Action::Down(z));
1147 }
1148 TreeOp::RemapAffine { target, mat } => {
1149 let prev = affine
1150 .last()
1151 .cloned()
1152 .unwrap_or(Matrix4::identity());
1153 let mat = prev * mat.to_homogeneous();
1154
1155 // Push either an affine frame or an axis frame,
1156 // depending on whether the target is also affine
1157 if matches!(&**target, TreeOp::RemapAffine { .. }) {
1158 affine.push(mat);
1159 todo.push(Action::PopAffine);
1160 } else {
1161 let (x, y, z) = axes.last().unwrap();
1162 let mut out = [None; 3];
1163 for i in 0..3 {
1164 let a = self.mul(mat[(i, 0)], *x).unwrap();
1165 let b = self.mul(mat[(i, 1)], *y).unwrap();
1166 let c = self.mul(mat[(i, 2)], *z).unwrap();
1167 let d = self.constant(mat[(i, 3)]);
1168 let ab = self.add(a, b).unwrap();
1169 let cd = self.add(c, d).unwrap();
1170 out[i] = Some(self.add(ab, cd).unwrap());
1171 }
1172 let [x, y, z] = out.map(Option::unwrap);
1173 axes.push((x, y, z));
1174 todo.push(Action::Pop);
1175 }
1176 todo.push(Action::Down(target));
1177 }
1178 }
1179 }
1180 Action::Up(t) => {
1181 match t.as_ref() {
1182 TreeOp::Const(..)
1183 | TreeOp::Input(..)
1184 | TreeOp::RemapAffine { .. } => unreachable!(),
1185 TreeOp::Unary(op, ..) => {
1186 let arg = stack.pop().unwrap();
1187 let out = self.op_unary(arg, *op).unwrap();
1188 stack.push(out);
1189 }
1190 TreeOp::Binary(op, ..) => {
1191 let lhs = stack.pop().unwrap();
1192 let rhs = stack.pop().unwrap();
1193 // Call individual builders to apply optimizations
1194 let out = match op {
1195 BinaryOpcode::Add => self.add(lhs, rhs),
1196 BinaryOpcode::Sub => self.sub(lhs, rhs),
1197 BinaryOpcode::Mul => self.mul(lhs, rhs),
1198 BinaryOpcode::Div => self.div(lhs, rhs),
1199 BinaryOpcode::Atan => self.atan2(lhs, rhs),
1200 BinaryOpcode::Min => self.min(lhs, rhs),
1201 BinaryOpcode::Max => self.max(lhs, rhs),
1202 BinaryOpcode::Compare => self.compare(lhs, rhs),
1203 BinaryOpcode::Mod => self.modulo(lhs, rhs),
1204 BinaryOpcode::And => self.and(lhs, rhs),
1205 BinaryOpcode::Or => self.or(lhs, rhs),
1206 }
1207 .unwrap();
1208 if Arc::strong_count(t) > 1 {
1209 seen.insert(
1210 (*axes.last().unwrap(), Arc::as_ptr(t)),
1211 out,
1212 );
1213 }
1214 stack.push(out);
1215 }
1216 TreeOp::RemapAxes { target, .. } => {
1217 let x = stack.pop().unwrap();
1218 let y = stack.pop().unwrap();
1219 let z = stack.pop().unwrap();
1220 axes.push((x, y, z));
1221 todo.push(Action::Pop);
1222 todo.push(Action::Down(target));
1223 }
1224 }
1225 // Update the cache with the new tree, if relevant
1226 //
1227 // The `strong_count` check is a rough heuristic to avoid
1228 // caching if there's only a single owner of the tree. This
1229 // isn't perfect, but it doesn't need to be for correctness.
1230 if matches!(
1231 t.as_ref(),
1232 TreeOp::Unary(..) | TreeOp::Binary(..)
1233 ) && Arc::strong_count(t) > 1
1234 {
1235 seen.insert(
1236 (*axes.last().unwrap(), Arc::as_ptr(t)),
1237 *stack.last().unwrap(),
1238 );
1239 }
1240 }
1241 Action::Pop => {
1242 axes.pop().unwrap();
1243 }
1244 Action::PopAffine => {
1245 affine.pop().unwrap();
1246 }
1247 }
1248 }
1249 assert_eq!(stack.len(), 1);
1250 stack.pop().unwrap()
1251 }
1252
1253 /// Converts from a context-specific node into a standalone [`Tree`]
1254 pub fn export(&self, n: Node) -> Result<Tree, Error> {
1255 if self.get_op(n).is_none() {
1256 return Err(Error::BadNode);
1257 }
1258
1259 // Do recursion on the heap to avoid stack overflows for deep trees
1260 enum Action {
1261 /// Pushes `Up(n)` followed by `Down(n)` for each child
1262 Down(Node),
1263 /// Consumes trees from the stack and pushes a new tree
1264 Up(Node, Op),
1265 }
1266 let mut todo = vec![Action::Down(n)];
1267 let mut stack = vec![];
1268
1269 // Cache of Node -> Tree mapping, for Tree deduplication
1270 let mut seen: HashMap<Node, Tree> = HashMap::new();
1271
1272 while let Some(t) = todo.pop() {
1273 match t {
1274 Action::Down(n) => {
1275 // If we've already seen this TreeOp with these axes, then
1276 // we can return the previous Node.
1277 if let Some(p) = seen.get(&n) {
1278 stack.push(p.clone());
1279 continue;
1280 }
1281 let op = self.get_op(n).unwrap();
1282 match op {
1283 Op::Const(c) => {
1284 let t = Tree::from(c.0);
1285 seen.insert(n, t.clone());
1286 stack.push(t);
1287 }
1288 Op::Input(v) => {
1289 let t = Tree::from(*v);
1290 seen.insert(n, t.clone());
1291 stack.push(t);
1292 }
1293 Op::Unary(_op, arg) => {
1294 todo.push(Action::Up(n, *op));
1295 todo.push(Action::Down(*arg));
1296 }
1297 Op::Binary(_op, lhs, rhs) => {
1298 todo.push(Action::Up(n, *op));
1299 todo.push(Action::Down(*lhs));
1300 todo.push(Action::Down(*rhs));
1301 }
1302 }
1303 }
1304 Action::Up(n, op) => match op {
1305 Op::Const(..) | Op::Input(..) => unreachable!(),
1306 Op::Unary(op, ..) => {
1307 let arg = stack.pop().unwrap();
1308 let out =
1309 Tree::from(TreeOp::Unary(op, arg.arc().clone()));
1310 seen.insert(n, out.clone());
1311 stack.push(out);
1312 }
1313 Op::Binary(op, ..) => {
1314 let lhs = stack.pop().unwrap();
1315 let rhs = stack.pop().unwrap();
1316 let out = Tree::from(TreeOp::Binary(
1317 op,
1318 lhs.arc().clone(),
1319 rhs.arc().clone(),
1320 ));
1321 seen.insert(n, out.clone());
1322 stack.push(out);
1323 }
1324 },
1325 }
1326 }
1327 assert_eq!(stack.len(), 1);
1328 Ok(stack.pop().unwrap())
1329 }
1330
1331 /// Takes the symbolic derivative of a node with respect to a variable
1332 pub fn deriv(&mut self, n: Node, v: Var) -> Result<Node, Error> {
1333 if self.get_op(n).is_none() {
1334 return Err(Error::BadNode);
1335 }
1336
1337 // Do recursion on the heap to avoid stack overflows for deep trees
1338 enum Action {
1339 /// Pushes `Up(n)` followed by `Down(n)` for each child
1340 Down(Node),
1341 /// Consumes trees from the stack and pushes a new tree
1342 Up(Node, Op),
1343 }
1344 let mut todo = vec![Action::Down(n)];
1345 let mut stack = vec![];
1346 let zero = self.constant(0.0);
1347
1348 // Cache of Node -> Node mapping, for deduplication
1349 let mut seen: HashMap<Node, Node> = HashMap::new();
1350
1351 while let Some(t) = todo.pop() {
1352 match t {
1353 Action::Down(n) => {
1354 // If we've already seen this TreeOp with these axes, then
1355 // we can return the previous Node.
1356 if let Some(p) = seen.get(&n) {
1357 stack.push(*p);
1358 continue;
1359 }
1360 let op = *self.get_op(n).unwrap();
1361 match op {
1362 Op::Const(_c) => {
1363 seen.insert(n, zero);
1364 stack.push(zero);
1365 }
1366 Op::Input(u) => {
1367 let z =
1368 if v == u { self.constant(1.0) } else { zero };
1369 seen.insert(n, z);
1370 stack.push(z);
1371 }
1372 Op::Unary(_op, arg) => {
1373 todo.push(Action::Up(n, op));
1374 todo.push(Action::Down(arg));
1375 }
1376 Op::Binary(_op, lhs, rhs) => {
1377 todo.push(Action::Up(n, op));
1378 todo.push(Action::Down(lhs));
1379 todo.push(Action::Down(rhs));
1380 }
1381 }
1382 }
1383 Action::Up(n, op) => match op {
1384 Op::Const(..) | Op::Input(..) => unreachable!(),
1385 Op::Unary(op, v_arg) => {
1386 let d_arg = stack.pop().unwrap();
1387 let out = match op {
1388 UnaryOpcode::Neg => self.neg(d_arg),
1389 UnaryOpcode::Abs => {
1390 let cond = self.less_than(v_arg, zero).unwrap();
1391 let pos = d_arg;
1392 let neg = self.neg(d_arg).unwrap();
1393 self.if_nonzero_else(cond, neg, pos)
1394 }
1395 UnaryOpcode::Recip => {
1396 let a = self.square(v_arg).unwrap();
1397 let b = self.neg(d_arg).unwrap();
1398 self.div(b, a)
1399 }
1400 UnaryOpcode::Sqrt => {
1401 let v = self.mul(n, 2.0).unwrap();
1402 self.div(d_arg, v)
1403 }
1404 UnaryOpcode::Square => {
1405 let v = self.mul(d_arg, v_arg).unwrap();
1406 self.mul(2.0, v)
1407 }
1408 // Discontinuous constants don't have Dirac deltas
1409 UnaryOpcode::Floor
1410 | UnaryOpcode::Ceil
1411 | UnaryOpcode::Round => Ok(zero),
1412
1413 UnaryOpcode::Sin => {
1414 let c = self.cos(v_arg).unwrap();
1415 self.mul(c, d_arg)
1416 }
1417
1418 UnaryOpcode::Cos => {
1419 let s = self.sin(v_arg).unwrap();
1420 let s = self.neg(s).unwrap();
1421 self.mul(s, d_arg)
1422 }
1423
1424 UnaryOpcode::Tan => {
1425 let c = self.cos(v_arg).unwrap();
1426 let c = self.square(c).unwrap();
1427 self.div(d_arg, c)
1428 }
1429
1430 UnaryOpcode::Asin => {
1431 let v = self.square(v_arg).unwrap();
1432 let v = self.sub(1.0, v).unwrap();
1433 let v = self.sqrt(v).unwrap();
1434 self.div(d_arg, v)
1435 }
1436 UnaryOpcode::Acos => {
1437 let v = self.square(v_arg).unwrap();
1438 let v = self.sub(1.0, v).unwrap();
1439 let v = self.sqrt(v).unwrap();
1440 let v = self.neg(v).unwrap();
1441 self.div(d_arg, v)
1442 }
1443 UnaryOpcode::Atan => {
1444 let v = self.square(v_arg).unwrap();
1445 let v = self.add(1.0, v).unwrap();
1446 self.div(d_arg, v)
1447 }
1448 UnaryOpcode::Exp => self.mul(n, d_arg),
1449 UnaryOpcode::Ln => self.div(d_arg, v_arg),
1450 UnaryOpcode::Not => Ok(zero),
1451 }
1452 .unwrap();
1453 seen.insert(n, out);
1454 stack.push(out);
1455 }
1456 Op::Binary(op, v_lhs, v_rhs) => {
1457 let d_lhs = stack.pop().unwrap();
1458 let d_rhs = stack.pop().unwrap();
1459 let out = match op {
1460 BinaryOpcode::Add => self.add(d_lhs, d_rhs),
1461 BinaryOpcode::Sub => self.sub(d_lhs, d_rhs),
1462 BinaryOpcode::Mul => {
1463 let a = self.mul(d_lhs, v_rhs).unwrap();
1464 let b = self.mul(v_lhs, d_rhs).unwrap();
1465 self.add(a, b)
1466 }
1467 BinaryOpcode::Div => {
1468 let v = self.square(v_rhs).unwrap();
1469 let a = self.mul(v_rhs, d_lhs).unwrap();
1470 let b = self.mul(v_lhs, d_rhs).unwrap();
1471 let c = self.sub(a, b).unwrap();
1472 self.div(c, v)
1473 }
1474 BinaryOpcode::Atan => {
1475 let a = self.square(v_lhs).unwrap();
1476 let b = self.square(v_rhs).unwrap();
1477 let d = self.add(a, b).unwrap();
1478
1479 let a = self.mul(v_rhs, d_lhs).unwrap();
1480 let b = self.mul(v_lhs, d_rhs).unwrap();
1481 let v = self.sub(a, b).unwrap();
1482 self.div(v, d)
1483 }
1484 BinaryOpcode::Min => {
1485 let cond =
1486 self.less_than(v_lhs, v_rhs).unwrap();
1487 self.if_nonzero_else(cond, d_lhs, d_rhs)
1488 }
1489 BinaryOpcode::Max => {
1490 let cond =
1491 self.less_than(v_rhs, v_lhs).unwrap();
1492 self.if_nonzero_else(cond, d_lhs, d_rhs)
1493 }
1494 BinaryOpcode::Compare => Ok(zero),
1495 BinaryOpcode::Mod => {
1496 let e = self.div(v_lhs, v_rhs).unwrap();
1497 let q = self.floor(e).unwrap();
1498
1499 // XXX
1500 // (we don't actually have %, so hack it from
1501 // `modulo`, which is actually `rem_euclid`)
1502 // ???
1503 let m = self.modulo(q, v_rhs).unwrap();
1504 let cond = self.less_than(q, zero).unwrap();
1505 let offset = self
1506 .if_nonzero_else(cond, v_rhs, zero)
1507 .unwrap();
1508 let m = self.sub(m, offset).unwrap();
1509
1510 // Torn from the div_euclid implementation
1511 let outer = self.less_than(m, zero).unwrap();
1512 let inner =
1513 self.less_than(zero, v_rhs).unwrap();
1514 let qa = self.sub(q, 1.0).unwrap();
1515 let qb = self.add(q, 1.0).unwrap();
1516 let inner = self
1517 .if_nonzero_else(inner, qa, qb)
1518 .unwrap();
1519 let e = self
1520 .if_nonzero_else(outer, inner, q)
1521 .unwrap();
1522
1523 let v = self.mul(d_rhs, e).unwrap();
1524 self.sub(d_lhs, v)
1525 }
1526 BinaryOpcode::And => {
1527 let cond = self.compare(v_lhs, zero).unwrap();
1528 self.if_nonzero_else(cond, d_rhs, d_lhs)
1529 }
1530 BinaryOpcode::Or => {
1531 let cond = self.compare(v_lhs, zero).unwrap();
1532 self.if_nonzero_else(cond, d_lhs, d_rhs)
1533 }
1534 }
1535 .unwrap();
1536 seen.insert(n, out);
1537 stack.push(out);
1538 }
1539 },
1540 }
1541 }
1542 assert_eq!(stack.len(), 1);
1543 Ok(stack.pop().unwrap())
1544 }
1545}
1546
1547////////////////////////////////////////////////////////////////////////////////
/// Helper trait for things that can be converted into a [`Node`] given a
/// [`Context`].
///
/// This trait allows you to write
/// ```
/// # let mut ctx = fidget_core::context::Context::new();
/// let x = ctx.x();
/// let sum = ctx.add(x, 1.0).unwrap();
/// ```
/// instead of the more verbose
/// ```
/// # let mut ctx = fidget_core::context::Context::new();
/// let x = ctx.x();
/// let num = ctx.constant(1.0);
/// let sum = ctx.add(x, num).unwrap();
/// ```
///
/// In this module it is implemented for [`Node`] (validated with
/// `Context::check_node`) and for `f32` / `f64` (added as constants via
/// `Context::constant`).
pub trait IntoNode {
    /// Converts the given value into a node in `ctx`
    ///
    /// # Errors
    /// Implementations return an error when the value cannot be turned into
    /// a node of `ctx`; for a [`Node`], this is whatever error
    /// `Context::check_node` reports.
    fn into_node(self, ctx: &mut Context) -> Result<Node, Error>;
}
1568
1569impl IntoNode for Node {
1570 fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1571 ctx.check_node(self)?;
1572 Ok(self)
1573 }
1574}
1575
1576impl IntoNode for f32 {
1577 fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1578 Ok(ctx.constant(self as f64))
1579 }
1580}
1581
1582impl IntoNode for f64 {
1583 fn into_node(self, ctx: &mut Context) -> Result<Node, Error> {
1584 Ok(ctx.constant(self))
1585 }
1586}
1587
1588////////////////////////////////////////////////////////////////////////////////
1589
#[cfg(test)]
mod test {
    use super::*;
    use crate::vm::VmData;

    // Checks that `get_op` returns the operation behind a handle
    #[test]
    fn test_get_op() {
        let mut ctx = Context::new();
        let x = ctx.x();
        assert!(matches!(ctx.get_op(x).unwrap(), Op::Input(_)));
    }

    // Builds max(0.25 - (x² + y²), (x² + y²) - 0.5) and checks tape size
    #[test]
    fn test_ring() {
        let mut ctx = Context::new();
        let half = ctx.constant(0.5);
        let x = ctx.x();
        let y = ctx.y();
        let x_sq = ctx.square(x).unwrap();
        let y_sq = ctx.square(y).unwrap();
        let r_sq = ctx.add(x_sq, y_sq).unwrap();
        let outer = ctx.sub(r_sq, half).unwrap();
        let quarter = ctx.constant(0.25);
        let inner = ctx.sub(quarter, r_sq).unwrap();
        let ring = ctx.max(inner, outer).unwrap();

        let tape = VmData::<255>::new(&ctx, &[ring]).unwrap();
        assert_eq!(tape.len(), 9);
        assert_eq!(tape.vars.len(), 2);
    }

    // `x * x` should collapse to a single op on `x`
    #[test]
    fn test_dupe() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let product = ctx.mul(x, x).unwrap();

        let tape = VmData::<255>::new(&ctx, &[product]).unwrap();
        assert_eq!(tape.len(), 3); // x, square, output
        assert_eq!(tape.vars.len(), 1);
    }

    // Exported trees should share the subtree for a node used twice
    #[test]
    fn test_export() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let sin_x = ctx.sin(x).unwrap();
        let cos_x = ctx.cos(x).unwrap();
        let sum = ctx.add(sin_x, cos_x).unwrap();
        let t = ctx.export(sum).unwrap();

        let TreeOp::Binary(BinaryOpcode::Add, lhs, rhs) = &*t else {
            panic!("unexpected opcode {t:?}");
        };
        let (
            TreeOp::Unary(UnaryOpcode::Sin, x1),
            TreeOp::Unary(UnaryOpcode::Cos, x2),
        ) = (&**lhs, &**rhs)
        else {
            panic!("invalid lhs / rhs: {lhs:?} {rhs:?}");
        };
        assert_eq!(Arc::as_ptr(x1), Arc::as_ptr(x2));
        let TreeOp::Input(Var::X) = &**x1 else {
            panic!("invalid X: {x1:?}");
        };
    }

    // Importing `x + 0` should simplify down to plain `x`
    #[test]
    fn import_optimization() {
        let tree = Tree::x() + 0;
        let mut ctx = Context::new();
        let root = ctx.import(&tree);
        assert_eq!(ctx.get_op(root).unwrap(), &Op::Input(Var::X));
    }
}