// node_engine/nodes/math.rs

1use anyhow::Result;
2
3use crate::*;
4
/// Generate a node module `$mod_name` holding `$ty_name`, a node with three
/// `DynamicVector` inputs and one `DynamicVector` output named `out`.
///
/// Two forms are accepted:
///
/// * `(mod_name, TyName, docs, op)` names the inputs `a`, `b` and `c` and gives
///   them default field docs.
/// * `(mod_name, TyName, docs, a, a_doc, b, b_doc, c, c_doc, op)` lets the caller
///   pick each input's field name and field doc.
///
/// `$op` is a format string with three `{}` placeholders; it is filled with the
/// three values returned by `resolve_inputs` to build the output expression.
/// The output takes the data type of the first input.
#[macro_export]
macro_rules! impl_dyn_vec_trinary_node {
  ( $mod_name:ident, $ty_name:ident, $docs:expr, $op:expr ) => {
    // Short form: forward to the full form with default input names and docs.
    $crate::impl_dyn_vec_trinary_node!(
      $mod_name, $ty_name, $docs,
      a, "Input `A`.",
      b, "Input `B`.",
      c, "Input `C`.",
      $op
    );
  };
  (
    $mod_name:ident, $ty_name:ident, $docs:expr,
    $a:ident, $a_doc:expr,
    $b:ident, $b_doc:expr,
    $c:ident, $c_doc:expr,
    $op:expr
  ) => {
    $crate::impl_node! {
      mod $mod_name {
        NodeInfo {
          name: $ty_name,
          category: ["Math", "Basic"],
        }

        #[doc = $docs]
        #[derive(Default)]
        pub struct $ty_name {
          #[doc = $a_doc]
          pub $a: Input<DynamicVector>,
          #[doc = $b_doc]
          pub $b: Input<DynamicVector>,
          #[doc = $c_doc]
          pub $c: Input<DynamicVector>,
          /// Output `out`.
          pub out: Output<DynamicVector>,
        }

        impl $ty_name {
          // Construct the node with default inputs and output.
          pub fn new() -> Self {
            Self::default()
          }
        }

        impl NodeImpl for $ty_name {
          fn compile(
            &self,
            graph: &NodeGraph,
            compile: &mut NodeGraphCompile,
            id: NodeId,
          ) -> Result<()> {
            let (first, second, third) = self.resolve_inputs(graph, compile)?;
            let expr = format!($op, first, second, third);
            // The generated output reuses the data type of the first input.
            self.out.compile(compile, id, stringify!($mod_name), expr, first.dt)
          }
        }
      }
    }
  };
}
64impl_dyn_vec_trinary_node!(
65  lerp_node,
66  LerpNode,
67  "Linearly interpolating between inputs A and B by input T.",
68  a,
69  "Input A",
70  b,
71  "Input B",
72  t,
73  "Input T",
74  "mix({}, {}, {})"
75);
76impl_dyn_vec_trinary_node!(
77  clamp_node,
78  ClampNode,
79  "Clamp input between `min` and `max`.",
80  input,
81  "Unclamped input value",
82  min,
83  "Minimum value",
84  max,
85  "Maximum value",
86  "clamp({}, {}, {})"
87);
88
/// Generate a node module `$mod_name` holding `$ty_name`, a node with two
/// `DynamicVector` inputs (`a` and `b`) and one `DynamicVector` output (`out`).
///
/// * `$name` - human-readable node name, used as the `NodeInfo` name.
/// * `$docs` - rustdoc for the node type, also used as the `NodeInfo` description.
/// * `$op` - format string with two `{}` placeholders, filled with the resolved
///   `a` and `b` inputs (in that order) to build the output expression.
///
/// The output takes the data type of input `a`.
#[macro_export]
macro_rules! impl_dyn_vec_binary_node {
  ( $mod_name:ident, $ty_name:ident, $name:expr, $docs:expr, $op:expr ) => {
    $crate::impl_node! {
      mod $mod_name {
        // Use the caller-supplied name and description, the same way
        // `impl_dyn_vec_unary_node` does; `$name` used to be accepted but ignored.
        NodeInfo {
          name: $name,
          description: $docs,
          category: ["Math", "Basic"],
        }

        #[doc = $docs]
        #[derive(Default)]
        pub struct $ty_name {
          /// Input `A`.
          pub a: Input<DynamicVector>,
          /// Input `B`.
          pub b: Input<DynamicVector>,
          /// Output `out`.
          pub out: Output<DynamicVector>,
        }

        impl $ty_name {
          pub fn new() -> Self {
            Self::default()
          }
        }

        impl NodeImpl for $ty_name {
          fn compile(
            &self,
            graph: &NodeGraph,
            compile: &mut NodeGraphCompile,
            id: NodeId,
          ) -> Result<()> {
            let (a, b) = self.resolve_inputs(graph, compile)?;
            let code = format!($op, a, b);
            // The generated output reuses the data type of input `a`.
            self.out.compile(compile, id, stringify!($mod_name), code, a.dt)
          }
        }
      }
    }
  };
}

133impl_dyn_vec_binary_node!(add_node, AddNode, "Add", "Add two vectors.", "({} + {})");
134impl_dyn_vec_binary_node!(
135  subtract_node,
136  SubtractNode,
137  "Subtract",
138  "Subtract two vectors.",
139  "({} - {})"
140);
141impl_dyn_vec_binary_node!(
142  divide_node,
143  DivideNode,
144  "Divide",
145  "Divide two vectors.",
146  "({} / {})"
147);
148impl_dyn_vec_binary_node!(
149  power_node,
150  PowerNode,
151  "Power",
152  "Output input `a` to the power of input `b`.",
153  "pow({}, {})"
154);
155impl_dyn_vec_binary_node!(
156  min_node,
157  MinNode,
158  "Minimum",
159  "Output the smallest of two inputs.",
160  "min({}, {})"
161);
162impl_dyn_vec_binary_node!(
163  max_node,
164  MaxNode,
165  "Maximum",
166  "Output the largest of two inputs.",
167  "max({}, {})"
168);
169
/// Generate a node module `$mod_name` holding `$ty_name`, a node with a single
/// `DynamicVector` input (`a`) and one `DynamicVector` output (`out`).
///
/// * `$name` - human-readable node name, used as the `NodeInfo` name.
/// * `$desc` - description, used both as rustdoc for the node type and as the
///   `NodeInfo` description.
/// * `$op` - format string with one `{}` placeholder, filled with the resolved
///   input `a` to build the output expression.
///
/// The output takes the data type of input `a`.
#[macro_export]
macro_rules! impl_dyn_vec_unary_node {
  ( $mod_name:ident, $ty_name:ident, $name:expr, $desc:expr, $op:expr ) => {
    $crate::impl_node! {
      mod $mod_name {
        NodeInfo {
          name: $name,
          description: $desc,
          category: ["Math", "Basic"],
        }

        #[doc = $desc]
        #[derive(Default)]
        pub struct $ty_name {
          /// Input `A`.
          pub a: Input<DynamicVector>,
          /// Output `out`.
          pub out: Output<DynamicVector>,
        }

        impl $ty_name {
          // Construct the node with default input and output.
          pub fn new() -> Self {
            Self::default()
          }
        }

        impl NodeImpl for $ty_name {
          fn compile(
            &self,
            graph: &NodeGraph,
            compile: &mut NodeGraphCompile,
            id: NodeId,
          ) -> Result<()> {
            // With one input, `resolve_inputs` returns a single value (not a tuple).
            let input = self.resolve_inputs(graph, compile)?;
            let expr = format!($op, input);
            self.out.compile(compile, id, stringify!($mod_name), expr, input.dt)
          }
        }
      }
    }
  };
}
212impl_dyn_vec_unary_node!(
213  sqrt_node,
214  SquareRootNode,
215  "Square Root",
216  "Output the square root of input `a`.",
217  "sqrt({})"
218);
219impl_dyn_vec_unary_node!(
220  round_node,
221  RoundNode,
222  "Round",
223  "Round input `a` to the nearest integer.",
224  "round({})"
225);
226impl_dyn_vec_unary_node!(
227  floor_node,
228  FloorNode,
229  "Floor",
230  "Floor input `a`.",
231  "floor({})"
232);
233impl_dyn_vec_unary_node!(
234  fract_node,
235  FractionNode,
236  "Fraction",
237  "Fraction input `a`.",
238  "fract({})"
239);
240impl_dyn_vec_unary_node!(
241  ceiling_node,
242  CeilingNode,
243  "Ceiling",
244  "Ceiling input `a`.",
245  "ceil({})"
246);
247impl_dyn_vec_unary_node!(
248  truncate_node,
249  TruncateNode,
250  "Truncate",
251  "Truncate input `a`.",
252  "trunc({})"
253);
254impl_dyn_vec_unary_node!(
255  absolute_node,
256  AbsoluteNode,
257  "Absolute",
258  "Absolute input `a`.",
259  "abs({})"
260);
261
262impl_node! {
263  mod multiply_node {
264    NodeInfo {
265      name: "Multiply",
266      category: ["Math", "Basic"],
267    }
268
269    /// Multiply vectors and matrixes.
270    #[derive(Default)]
271    pub struct MultiplyNode {
272      /// Input `A`.
273      pub a: Input<Dynamic>,
274      /// Input `B`.
275      pub b: Input<Dynamic>,
276      /// Output.
277      pub out: Output<Dynamic>,
278    }
279
280    impl MultiplyNode {
281      pub fn new() -> Self {
282        Default::default()
283      }
284    }
285
286    impl NodeImpl for MultiplyNode {
287      fn compile(&self, graph: &NodeGraph, compile: &mut NodeGraphCompile, id: NodeId) -> Result<()> {
288        let (a, b) = self.resolve_inputs(graph, compile)?;
289        let (code, out_dt) = match (a.dt.class(), b.dt.class()) {
290          // Re-order so the vector is first.
291          (DataTypeClass::Matrix, DataTypeClass::Vector) =>
292            (format!("({b} * {a})"), b.dt),
293          // default to using the type of `a`.
294          _ => (format!("({a} * {b})"), a.dt),
295        };
296        self.out.compile(compile, id, "multiply_node", code, out_dt)
297      }
298    }
299  }
300}