// tang_expr/node.rs
//! Expression node types and the `ExprId` handle.

use std::fmt;

/// Handle into the expression graph. Lightweight (4 bytes), `Copy`.
///
/// Handles are totally ordered by their raw index (derived on the inner
/// `u32`), so they can be sorted or used as `BTreeMap`/`BTreeSet` keys.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ExprId(pub(crate) u32);

impl ExprId {
    // Well-known node indices, pre-populated in every graph.

    /// The constant 0.0 (index 0).
    pub const ZERO: Self = Self(0);
    /// The constant 1.0 (index 1).
    pub const ONE: Self = Self(1);
    /// The constant 2.0 (index 2).
    pub const TWO: Self = Self(2);

    /// Create an `ExprId` from a raw index.
    ///
    /// The index is wrapped as-is; it is not checked against any graph.
    #[inline]
    #[must_use]
    pub const fn from_index(index: u32) -> Self {
        Self(index)
    }

    /// The raw index of this expression in the graph.
    #[inline]
    #[must_use]
    pub const fn index(&self) -> u32 {
        self.0
    }
}

/// Formats as `e<index>`, e.g. `e42`.
impl fmt::Debug for ExprId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "e{}", self.0)
    }
}

/// Formats as `e<index>`, e.g. `e42` (same as `Debug`).
impl fmt::Display for ExprId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "e{}", self.0)
    }
}

/// Defaults to [`ExprId::ZERO`].
impl Default for ExprId {
    fn default() -> Self {
        Self::ZERO
    }
}

55/// A node in the expression graph.
56///
57/// 9 RISC primitive operations + 2 atom types. Every higher-level math
58/// operation decomposes into these primitives via the `Scalar` impl.
59#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
60pub enum Node {
61    // Atoms
62    /// Input variable by index.
63    Var(u16),
64    /// Literal f64 value stored as bits for Hash/Eq.
65    Lit(u64),
66
67    // RISC ops (9 primitives)
68    /// Addition.
69    Add(ExprId, ExprId),
70    /// Multiplication.
71    Mul(ExprId, ExprId),
72    /// Negation.
73    Neg(ExprId),
74    /// Reciprocal (1/x).
75    Recip(ExprId),
76    /// Square root.
77    Sqrt(ExprId),
78    /// Sine (only trig primitive).
79    Sin(ExprId),
80    /// Two-argument arctangent atan2(y, x).
81    Atan2(ExprId, ExprId),
82    /// Base-2 exponential (2^x).
83    Exp2(ExprId),
84    /// Base-2 logarithm.
85    Log2(ExprId),
86    /// Branchless select: returns `a` if `cond > 0`, else `b`.
87    Select(ExprId, ExprId, ExprId),
88}
89
90impl Node {
91    /// Create a `Lit` node from an f64 value.
92    #[inline]
93    pub fn lit(v: f64) -> Self {
94        Self::Lit(v.to_bits())
95    }
96
97    /// Extract f64 value from a `Lit` node, or `None`.
98    #[inline]
99    pub fn as_f64(&self) -> Option<f64> {
100        match self {
101            Self::Lit(bits) => Some(f64::from_bits(*bits)),
102            _ => None,
103        }
104    }
105}