//! radiate_gp/ops/math.rs — clamped math and activation operations for GP graphs/trees.

1use super::Op;
2use crate::{Arity, ops::op_names};
3use radiate_core::random_provider;
4
/// Upper saturation bound applied by [clamp] to every op result.
pub(super) const MAX_VALUE: f32 = 1e+10_f32;
/// Lower saturation bound applied by [clamp] to every op result.
pub(super) const MIN_VALUE: f32 = -1e+10_f32;
/// 1.0 — identity for mul/div and the "+1" in sigmoid/ELU style formulas.
pub(super) const ONE: f32 = 1.0_f32;
/// 0.0 — empty-sum result and the NaN replacement value in [clamp].
pub(super) const ZERO: f32 = 0.0_f32;
/// 2.0 — used to scale `random::<f32>()` into the [-1, 1) weight range.
pub(super) const TWO: f32 = 2.0_f32;
/// 0.5 — LeakyReLU slope and ELU alpha in [ActivationOperation::apply].
pub(super) const HALF: f32 = 0.5_f32;
/// 0.1 — step size for the weight op's mutation modifier.
pub(super) const TENTH: f32 = 0.1_f32;
12
13/// Clamp a value to the range [-1e+10, 1e+10]. Without this, values can quickly become
14/// too large or too small to be useful.
15pub(super) const fn clamp(value: f32) -> f32 {
16    if value.is_nan() {
17        return ZERO;
18    }
19
20    value.clamp(MIN_VALUE, MAX_VALUE)
21}
22
23/// Aggregate a slice of 'f32' values by summing them, then applying a function to the result.
24/// There usually arent too many inputs, so we can use an if statement to handle a few of the
25/// common cases - vals with a len <= 5.
26pub(super) fn aggregate(vals: &[f32]) -> f32 {
27    let len = vals.len();
28    if len == 0 {
29        return ZERO;
30    } else if len == 1 {
31        return vals[0];
32    } else if len == 2 {
33        return vals[0] + vals[1];
34    } else if len == 3 {
35        return vals[0] + vals[1] + vals[2];
36    } else if len == 4 {
37        return vals[0] + vals[1] + vals[2] + vals[3];
38    } else if len == 5 {
39        return vals[0] + vals[1] + vals[2] + vals[3] + vals[4];
40    }
41
42    vals.iter().cloned().sum::<f32>()
43}
44
45#[inline]
46const fn add(vals: &[f32]) -> f32 {
47    clamp(vals[0] + vals[1])
48}
49
50#[inline]
51const fn sub(vals: &[f32]) -> f32 {
52    clamp(vals[0] - vals[1])
53}
54
55#[inline]
56const fn mul(vals: &[f32]) -> f32 {
57    clamp(vals[0] * vals[1])
58}
59
60#[inline]
61const fn div(vals: &[f32]) -> f32 {
62    if vals[1].abs() < MIN_VALUE {
63        clamp(vals[0] / ONE)
64    } else {
65        clamp(vals[0] / vals[1])
66    }
67}
68
69#[inline]
70const fn neg(vals: &[f32]) -> f32 {
71    clamp(-vals[0])
72}
73
74#[inline]
75const fn abs(vals: &[f32]) -> f32 {
76    clamp(vals[0].abs())
77}
78
79#[inline]
80const fn ceil(vals: &[f32]) -> f32 {
81    clamp(vals[0].ceil())
82}
83
84#[inline]
85const fn floor(vals: &[f32]) -> f32 {
86    clamp(vals[0].floor())
87}
88
/// Math / aggregate operations over a slice of inputs. See
/// [AggregateOperations::apply] for exact semantics; every result is
/// clamped via [clamp].
pub enum AggregateOperations {
    /// Sum of all inputs (empty slice -> 0.0).
    Sum,
    /// Product of all inputs (empty slice -> 1.0, the empty product).
    Prod,
    /// Left fold `0 - x0 - x1 - ...`, i.e. the negated sum.
    Diff,
    /// `inputs[0]` raised to the power `inputs[1]`.
    Pow,
    /// Square root of the first input.
    Sqrt,
    /// `e^inputs[0]`.
    Exp,
    /// Natural log of the first input; 0.0 for non-positive inputs.
    Log,
    /// Sine of the first input.
    Sin,
    /// Cosine of the first input.
    Cos,
    /// Tangent of the first input.
    Tan,
    /// Maximum of all inputs (empty slice -> MIN_VALUE).
    Max,
    /// Minimum of all inputs (empty slice -> MAX_VALUE).
    Min,
}
103
104/// Implementations of the [MathOperation] enum. These are the basic math operations.
105/// Each operation takes a slice of `f32` values and returns a single `f32` value.
106impl AggregateOperations {
107    pub fn apply(&self, inputs: &[f32]) -> f32 {
108        match self {
109            AggregateOperations::Sum => clamp(aggregate(inputs)),
110            AggregateOperations::Diff => clamp(inputs.iter().cloned().fold(ZERO, |acc, x| acc - x)),
111            AggregateOperations::Prod => clamp(inputs.iter().product()),
112            AggregateOperations::Pow => clamp(inputs[0].powf(inputs[1])),
113            AggregateOperations::Sqrt => clamp(inputs[0].sqrt()),
114            AggregateOperations::Exp => clamp(inputs[0].exp()),
115            AggregateOperations::Log => clamp(if inputs[0] > ZERO {
116                inputs[0].ln()
117            } else {
118                ZERO
119            }),
120            AggregateOperations::Sin => clamp(inputs[0].sin()),
121            AggregateOperations::Cos => clamp(inputs[0].cos()),
122            AggregateOperations::Tan => clamp(inputs[0].tan()),
123            AggregateOperations::Max => clamp(inputs.iter().cloned().fold(MIN_VALUE, f32::max)),
124            AggregateOperations::Min => clamp(inputs.iter().cloned().fold(MAX_VALUE, f32::min)),
125        }
126    }
127}
128
/// Activation functions, each applied to the SUM of all inputs so any
/// arity works. See [ActivationOperation::apply] for exact formulas.
pub enum ActivationOperation {
    /// 1 / (1 + e^-x)
    Sigmoid,
    /// tanh(x)
    Tanh,
    /// max(0, x)
    ReLU,
    /// x if x > 0, else HALF * x (slope 0.5 in this implementation).
    LeakyReLU,
    /// x if x > 0, else HALF * (e^x - 1) (alpha 0.5 in this implementation).
    ELU,
    /// Identity on the sum of inputs.
    Linear,
    /// x * tanh(ln(1 + e^x))
    Mish,
    /// x / (1 + e^-x), i.e. x * sigmoid(x)
    Swish,
    /// ln(1 + e^x)
    Softplus,
}
140
141/// Implementations of the [ActivationOperation] enum. These are the basic activation functions used
142/// in neural networks. However, they are particularly useful in this context because they can
143/// accept any number of inputs. Thus, they act as reducers or aggregates and are a key part of
144/// being able to define complex 'Graph' and 'Tree' structures.
145impl ActivationOperation {
146    #[inline]
147    pub fn apply(&self, inputs: &[f32]) -> f32 {
148        match self {
149            ActivationOperation::Sigmoid => {
150                let total = aggregate(inputs);
151                clamp(ONE / (ONE + (-total).exp()))
152            }
153            ActivationOperation::Tanh => {
154                let total = aggregate(inputs);
155                clamp(total.tanh())
156            }
157            ActivationOperation::ReLU => clamp(inputs.iter().cloned().sum::<f32>().max(ZERO)),
158            ActivationOperation::LeakyReLU => {
159                let x = clamp(inputs.iter().cloned().sum::<f32>());
160                if x > ZERO { x } else { clamp(HALF * x) }
161            }
162            ActivationOperation::ELU => {
163                let x = clamp(inputs.iter().cloned().sum::<f32>());
164                if x > ZERO {
165                    x
166                } else {
167                    clamp(HALF * (x.exp() - ONE))
168                }
169            }
170            ActivationOperation::Linear => clamp(inputs.iter().cloned().sum::<f32>()),
171            ActivationOperation::Mish => {
172                let x = clamp(inputs.iter().cloned().sum::<f32>());
173                clamp(x * (x.exp().ln_1p().tanh()))
174            }
175            ActivationOperation::Swish => {
176                let x = clamp(inputs.iter().cloned().sum::<f32>());
177                clamp(x / (ONE + (-x).exp()))
178            }
179            ActivationOperation::Softplus => {
180                let x = clamp(inputs.iter().cloned().sum::<f32>());
181                clamp(x.exp().ln_1p())
182            }
183        }
184    }
185}
186
impl Op<f32> {
    /// A mutable weight op seeded with a random value.
    ///
    /// NOTE(review): assumes `random_provider::random::<f32>()` yields
    /// values in [0, 1), making the initial weight uniform in [-1, 1) —
    /// confirm against radiate_core.
    pub fn weight() -> Self {
        Self::weight_with(random_provider::random::<f32>() * TWO - ONE)
    }

    /// A mutable weight op (arity 1) seeded with `value` (clamped).
    ///
    /// - `supplier` produces a fresh random weight via `random * 2 - 1`.
    /// - `operation` multiplies the single input by the current weight.
    /// - `modifier` perturbs the current weight by a random step scaled by
    ///   TENTH (0.1), clamped.
    pub fn weight_with(value: f32) -> Self {
        let supplier = || random_provider::random::<f32>() * TWO - ONE;
        let operation = |inputs: &[f32], weight: &f32| clamp(inputs[0] * weight);
        let modifier = |current: &f32| {
            let diff = (random_provider::random::<f32>() * TWO - ONE) * TENTH;
            clamp(current + diff)
        };

        Op::MutableConst {
            name: op_names::WEIGHT,
            arity: 1.into(),
            value: clamp(value),
            supplier,
            modifier,
            operation,
        }
    }

    /// Clamped binary addition (arity 2).
    pub fn add() -> Self {
        Op::Fn(op_names::ADD, 2.into(), add)
    }

    /// Clamped binary subtraction (arity 2).
    pub fn sub() -> Self {
        Op::Fn(op_names::SUB, 2.into(), sub)
    }

    /// Clamped binary multiplication (arity 2).
    pub fn mul() -> Self {
        Op::Fn(op_names::MUL, 2.into(), mul)
    }

    /// Binary division (arity 2); see [div] for the exact semantics.
    pub fn div() -> Self {
        Op::Fn(op_names::DIV, 2.into(), div)
    }

    /// Variadic sum of all inputs.
    pub fn sum() -> Self {
        Op::Fn(op_names::SUM, Arity::Any, |inputs: &[f32]| {
            AggregateOperations::Sum.apply(inputs)
        })
    }

    /// Variadic negated-sum fold: 0 - x0 - x1 - ...
    pub fn diff() -> Self {
        Op::Fn(op_names::DIFF, Arity::Any, |inputs: &[f32]| {
            AggregateOperations::Diff.apply(inputs)
        })
    }

    /// Variadic product of all inputs.
    pub fn prod() -> Self {
        Op::Fn(op_names::PROD, Arity::Any, |inputs: &[f32]| {
            AggregateOperations::Prod.apply(inputs)
        })
    }

    /// Clamped unary negation (arity 1).
    pub fn neg() -> Self {
        Op::Fn(op_names::NEG, 1.into(), neg)
    }

    /// `inputs[0] ^ inputs[1]` (arity 2), clamped.
    pub fn pow() -> Self {
        Op::Fn(op_names::POW, 2.into(), |inputs: &[f32]| {
            AggregateOperations::Pow.apply(inputs)
        })
    }

    /// Square root of the single input, clamped.
    pub fn sqrt() -> Self {
        Op::Fn(op_names::SQRT, 1.into(), |inputs: &[f32]| {
            AggregateOperations::Sqrt.apply(inputs)
        })
    }

    /// Absolute value of the single input, clamped.
    pub fn abs() -> Self {
        Op::Fn(op_names::ABS, 1.into(), abs)
    }

    /// `e^x` of the single input, clamped.
    pub fn exp() -> Self {
        Op::Fn(op_names::EXP, 1.into(), |inputs: &[f32]| {
            AggregateOperations::Exp.apply(inputs)
        })
    }

    /// Natural log of the single input; 0.0 for non-positive inputs.
    pub fn log() -> Self {
        Op::Fn(op_names::LOG, 1.into(), |inputs: &[f32]| {
            AggregateOperations::Log.apply(inputs)
        })
    }

    /// Sine of the single input, clamped.
    pub fn sin() -> Self {
        Op::Fn(op_names::SIN, 1.into(), |inputs: &[f32]| {
            AggregateOperations::Sin.apply(inputs)
        })
    }

    /// Cosine of the single input, clamped.
    pub fn cos() -> Self {
        Op::Fn(op_names::COS, 1.into(), |inputs: &[f32]| {
            AggregateOperations::Cos.apply(inputs)
        })
    }

    /// Variadic maximum (empty input -> MIN_VALUE).
    pub fn max() -> Self {
        Op::Fn(op_names::MAX, Arity::Any, |inputs: &[f32]| {
            AggregateOperations::Max.apply(inputs)
        })
    }

    /// Variadic minimum (empty input -> MAX_VALUE).
    pub fn min() -> Self {
        Op::Fn(op_names::MIN, Arity::Any, |inputs: &[f32]| {
            AggregateOperations::Min.apply(inputs)
        })
    }

    /// Tangent of the single input, clamped.
    pub fn tan() -> Self {
        Op::Fn(op_names::TAN, 1.into(), |inputs: &[f32]| {
            AggregateOperations::Tan.apply(inputs)
        })
    }

    /// Ceiling of the single input, clamped.
    pub fn ceil() -> Self {
        Op::Fn(op_names::CEIL, 1.into(), ceil)
    }

    /// Floor of the single input, clamped.
    pub fn floor() -> Self {
        Op::Fn(op_names::FLOOR, 1.into(), floor)
    }

    /// Sigmoid of the sum of all inputs.
    pub fn sigmoid() -> Self {
        Op::Fn(op_names::SIGMOID, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::Sigmoid.apply(inputs)
        })
    }

    /// Tanh of the sum of all inputs.
    pub fn tanh() -> Self {
        Op::Fn(op_names::TANH, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::Tanh.apply(inputs)
        })
    }

    /// ReLU of the sum of all inputs.
    pub fn relu() -> Self {
        Op::Fn(op_names::RELU, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::ReLU.apply(inputs)
        })
    }

    /// Leaky ReLU (slope HALF) of the sum of all inputs.
    pub fn leaky_relu() -> Self {
        Op::Fn(op_names::LEAKY_RELU, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::LeakyReLU.apply(inputs)
        })
    }

    /// ELU (alpha HALF) of the sum of all inputs.
    pub fn elu() -> Self {
        Op::Fn(op_names::ELU, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::ELU.apply(inputs)
        })
    }

    /// Identity on the (clamped) sum of all inputs.
    pub fn linear() -> Self {
        Op::Fn(op_names::LINEAR, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::Linear.apply(inputs)
        })
    }

    /// Mish of the sum of all inputs.
    pub fn mish() -> Self {
        Op::Fn(op_names::MISH, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::Mish.apply(inputs)
        })
    }

    /// Swish of the sum of all inputs.
    pub fn swish() -> Self {
        Op::Fn(op_names::SWISH, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::Swish.apply(inputs)
        })
    }

    /// Softplus of the sum of all inputs.
    pub fn softplus() -> Self {
        Op::Fn(op_names::SOFTPLUS, Arity::Any, |inputs: &[f32]| {
            ActivationOperation::Softplus.apply(inputs)
        })
    }
}
368
369/// Get a list of all the math operations.
370pub fn math_ops() -> Vec<Op<f32>> {
371    vec![
372        Op::add(),
373        Op::sub(),
374        Op::mul(),
375        Op::div(),
376        Op::sum(),
377        Op::prod(),
378        Op::neg(),
379        Op::diff(),
380        Op::pow(),
381        Op::sqrt(),
382        Op::abs(),
383        Op::exp(),
384        Op::log(),
385        Op::sin(),
386        Op::cos(),
387        Op::tan(),
388        Op::ceil(),
389        Op::floor(),
390        Op::max(),
391        Op::min(),
392    ]
393}
394
395/// Get a list of all the activation operations.
396pub fn activation_ops() -> Vec<Op<f32>> {
397    vec![
398        Op::sigmoid(),
399        Op::tanh(),
400        Op::relu(),
401        Op::leaky_relu(),
402        Op::elu(),
403        Op::linear(),
404        Op::mish(),
405        Op::swish(),
406        Op::softplus(),
407    ]
408}
409
410/// Get a list of all the operations.
411pub fn all_ops() -> Vec<Op<f32>> {
412    math_ops().into_iter().chain(activation_ops()).collect()
413}
414
/// Unit tests. Several tests deliberately pin implementation-specific
/// values (LeakyReLU slope, ELU alpha, empty-slice Max/Min sentinels,
/// clamp saturation), so those constants cannot change without updating
/// the assertions here.
#[cfg(test)]
mod tests {
    use crate::Eval;

    use super::*;
    use std::f32;

    /// Absolute-difference float comparison helper.
    #[inline]
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() <= eps
    }

    #[test]
    fn clamp_behaves_as_specified() {
        // NaN maps to ZERO; out-of-range values saturate at the bounds;
        // in-range values pass through unchanged.
        assert_eq!(super::clamp(f32::NAN), ZERO);
        assert_eq!(super::clamp(1e20_f32), MAX_VALUE);
        assert_eq!(super::clamp(-1e20_f32), MIN_VALUE);
        assert_eq!(super::clamp(123.456), 123.456);
    }

    #[test]
    fn math_div_near_zero_clamps_large_quotient() {
        // Denominator is tiny but non-zero, so the quotient (1e13) exceeds
        // the clamp range and saturates at MAX_VALUE.
        let xs = [10.0, 1e-12_f32];
        let y = Op::div().eval(&xs);
        assert_eq!(
            y, MAX_VALUE,
            "huge quotient should clamp to MAX_VALUE with current code"
        );
    }

    #[test]
    fn math_sum_prod_diff_pow_sqrt_abs() {
        let xs = [2.0, 3.0, 4.0];
        assert_eq!(AggregateOperations::Sum.apply(&xs), 9.0);
        assert_eq!(AggregateOperations::Prod.apply(&xs), 24.0);
        // Diff is left fold from ZERO: (((0-2)-3)-4) = -9
        assert_eq!(AggregateOperations::Diff.apply(&xs), -9.0);

        let p = AggregateOperations::Pow.apply(&[3.0, 2.0]);
        assert_eq!(p, 9.0);

        assert_eq!(AggregateOperations::Sqrt.apply(&[9.0]), 3.0);
    }

    #[test]
    fn math_exp_log_trig_rounding() {
        let e = AggregateOperations::Exp.apply(&[1.0]);
        assert!(approx(e, f32::consts::E, 1e-5), "exp(1) ~= e");

        // log on <=0 becomes NaN, then clamp -> 0.0
        assert_eq!(AggregateOperations::Log.apply(&[0.0]), 0.0);
        assert_eq!(AggregateOperations::Log.apply(&[-1.0]), 0.0);

        let s = AggregateOperations::Sin.apply(&[f32::consts::PI / 2.0]);
        assert!(approx(s, 1.0, 1e-5));

        let c = AggregateOperations::Cos.apply(&[0.0]);
        assert!(approx(c, 1.0, 1e-5));

        let t = AggregateOperations::Tan.apply(&[0.0]);
        assert!(approx(t, 0.0, 1e-6));
    }

    #[test]
    fn math_max_min_variadic_including_empty_behavior() {
        let xs = [1.5, -2.0, 7.25, 3.0];
        let mx = AggregateOperations::Max.apply(&xs);
        let mn = AggregateOperations::Min.apply(&xs);
        assert_eq!(mx, 7.25);
        assert_eq!(mn, -2.0);

        // Empty input returns the fold sentinels, by design.
        let empty: [f32; 0] = [];
        assert_eq!(AggregateOperations::Max.apply(&empty), MIN_VALUE);
        assert_eq!(AggregateOperations::Min.apply(&empty), MAX_VALUE);
    }

    #[test]
    fn act_sigmoid_on_sum() {
        // sum = 1.0 -> sigmoid(1) ~ 0.731
        let xs = [2.0, -1.0];
        let y = ActivationOperation::Sigmoid.apply(&xs);
        assert!(y > 0.73 && y < 0.74, "got {y}");
    }

    #[test]
    fn act_tanh_on_sum() {
        let xs = [2.0, -0.5]; // sum = 1.5 -> tanh(1.5) ~ 0.9051
        let y = ActivationOperation::Tanh.apply(&xs);
        assert!(y > 0.90 && y < 0.91, "got {y}");
    }

    #[test]
    fn act_relu_and_leaky_and_elu_match_current_params() {
        // ReLU(sum)
        let xs = [-1.0, 0.25, 0.25]; // sum = -0.5
        assert_eq!(ActivationOperation::ReLU.apply(&xs), 0.0);

        // LeakyReLU uses slope HALF (=0.5) per current code
        let xs2 = [-0.6];
        let y2 = ActivationOperation::LeakyReLU.apply(&xs2);
        // Exact equality is safe: halving an f32 only decrements the
        // exponent, so 0.5 * nearest(0.6) == nearest(0.3) bit-for-bit.
        assert_eq!(y2, -0.3);

        // ELU uses alpha HALF (=0.5) currently
        let xs3 = [-1.0];
        let y3 = ActivationOperation::ELU.apply(&xs3);
        // 0.5 * (exp(-1) - 1) ~= -0.316060...
        assert!(approx(y3, 0.5 * (f32::consts::E.powf(-1.0) - 1.0), 1e-6));
    }

    #[test]
    fn act_linear_mish_swish_softplus() {
        // Linear is sum
        let xs = [1.0, 2.0, 3.0];
        assert_eq!(ActivationOperation::Linear.apply(&xs), 6.0);

        // Mish ~ x * tanh(ln(1+exp(x))) at sum(x)
        let x = 1.5_f32;
        let mish_ref = x * ((x.exp().ln_1p()).tanh());
        let mish_y = ActivationOperation::Mish.apply(&[x]);
        assert!(approx(mish_y, mish_ref, 1e-6));

        // Swish ~ x * sigmoid(x); implementation uses x / (1 + exp(-x))
        let sw = ActivationOperation::Swish.apply(&[x]);
        let sw_ref = x / (1.0 + (-x).exp());
        assert!(approx(sw, sw_ref, 1e-6));

        // Softplus = ln(1 + exp(x))
        let sp = ActivationOperation::Softplus.apply(&[x]);
        let sp_ref = x.exp().ln_1p();
        assert!(approx(sp, sp_ref, 1e-6));
    }

    #[test]
    fn weight_op_runs_and_is_clamped() {
        // weight() must produce a MutableConst whose operation yields a
        // finite, clamped value for an ordinary input.
        let w = Op::<f32>::weight();
        if let Op::MutableConst {
            operation, value, ..
        } = &w
        {
            let out = (operation)(&[0.5], value);
            assert!(out.is_finite());
            assert!(out <= MAX_VALUE && out >= MIN_VALUE);
        } else {
            panic!("weight() did not return MutableConst as expected");
        }
    }
}
561}