1use super::Op;
2use crate::{Arity, ops::op_names};
3use radiate_core::random_provider;
4
5pub(super) const MAX_VALUE: f32 = 1e+10_f32;
6pub(super) const MIN_VALUE: f32 = -1e+10_f32;
7pub(super) const ONE: f32 = 1.0_f32;
8pub(super) const ZERO: f32 = 0.0_f32;
9pub(super) const TWO: f32 = 2.0_f32;
10pub(super) const HALF: f32 = 0.5_f32;
11pub(super) const TENTH: f32 = 0.1_f32;
12
13pub(super) const fn clamp(value: f32) -> f32 {
16 if value.is_nan() {
17 return ZERO;
18 }
19
20 value.clamp(MIN_VALUE, MAX_VALUE)
21}
22
23pub(super) fn aggregate(vals: &[f32]) -> f32 {
27 let len = vals.len();
28 if len == 0 {
29 return ZERO;
30 } else if len == 1 {
31 return vals[0];
32 } else if len == 2 {
33 return vals[0] + vals[1];
34 } else if len == 3 {
35 return vals[0] + vals[1] + vals[2];
36 } else if len == 4 {
37 return vals[0] + vals[1] + vals[2] + vals[3];
38 } else if len == 5 {
39 return vals[0] + vals[1] + vals[2] + vals[3] + vals[4];
40 }
41
42 vals.iter().cloned().sum::<f32>()
43}
44
45#[inline]
46const fn add(vals: &[f32]) -> f32 {
47 clamp(vals[0] + vals[1])
48}
49
50#[inline]
51const fn sub(vals: &[f32]) -> f32 {
52 clamp(vals[0] - vals[1])
53}
54
55#[inline]
56const fn mul(vals: &[f32]) -> f32 {
57 clamp(vals[0] * vals[1])
58}
59
60#[inline]
61const fn div(vals: &[f32]) -> f32 {
62 if vals[1].abs() < MIN_VALUE {
63 clamp(vals[0] / ONE)
64 } else {
65 clamp(vals[0] / vals[1])
66 }
67}
68
69#[inline]
70const fn neg(vals: &[f32]) -> f32 {
71 clamp(-vals[0])
72}
73
74#[inline]
75const fn abs(vals: &[f32]) -> f32 {
76 clamp(vals[0].abs())
77}
78
79#[inline]
80const fn ceil(vals: &[f32]) -> f32 {
81 clamp(vals[0].ceil())
82}
83
84#[inline]
85const fn floor(vals: &[f32]) -> f32 {
86 clamp(vals[0].floor())
87}
88
/// Math reductions and element-wise functions, evaluated by [`AggregateOperations::apply`].
///
/// Every result is clamped, so NaN becomes `0.0` and magnitudes are bounded by `1e10`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateOperations {
    /// Sum of all inputs (`0.0` when empty).
    Sum,
    /// Product of all inputs (`1.0` when empty).
    Prod,
    /// `0 - x0 - x1 - ...`: every input subtracted from zero.
    Diff,
    /// `inputs[0]` raised to the power `inputs[1]`.
    Pow,
    /// Square root of `inputs[0]` (negative input gives NaN, which clamps to `0.0`).
    Sqrt,
    /// `e` raised to `inputs[0]`.
    Exp,
    /// Natural log of `inputs[0]`, or `0.0` when `inputs[0]` is not positive.
    Log,
    /// Sine of `inputs[0]`.
    Sin,
    /// Cosine of `inputs[0]`.
    Cos,
    /// Tangent of `inputs[0]`.
    Tan,
    /// Largest input, folded from the lower clamp bound (so empty input yields that bound).
    Max,
    /// Smallest input, folded from the upper clamp bound (so empty input yields that bound).
    Min,
}
103
104impl AggregateOperations {
107 pub fn apply(&self, inputs: &[f32]) -> f32 {
108 match self {
109 AggregateOperations::Sum => clamp(aggregate(inputs)),
110 AggregateOperations::Diff => clamp(inputs.iter().cloned().fold(ZERO, |acc, x| acc - x)),
111 AggregateOperations::Prod => clamp(inputs.iter().product()),
112 AggregateOperations::Pow => clamp(inputs[0].powf(inputs[1])),
113 AggregateOperations::Sqrt => clamp(inputs[0].sqrt()),
114 AggregateOperations::Exp => clamp(inputs[0].exp()),
115 AggregateOperations::Log => clamp(if inputs[0] > ZERO {
116 inputs[0].ln()
117 } else {
118 ZERO
119 }),
120 AggregateOperations::Sin => clamp(inputs[0].sin()),
121 AggregateOperations::Cos => clamp(inputs[0].cos()),
122 AggregateOperations::Tan => clamp(inputs[0].tan()),
123 AggregateOperations::Max => clamp(inputs.iter().cloned().fold(MIN_VALUE, f32::max)),
124 AggregateOperations::Min => clamp(inputs.iter().cloned().fold(MAX_VALUE, f32::min)),
125 }
126 }
127}
128
/// Activation functions applied to the sum of a node's inputs.
///
/// Evaluated by [`ActivationOperation::apply`]; every result is clamped, so NaN becomes
/// `0.0` and magnitudes are bounded by `1e10`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationOperation {
    /// `1 / (1 + e^-s)`.
    Sigmoid,
    /// `tanh(s)`.
    Tanh,
    /// `max(s, 0)`.
    ReLU,
    /// `s` when positive, otherwise `0.5 * s`.
    LeakyReLU,
    /// `s` when positive, otherwise `0.5 * (e^s - 1)`.
    ELU,
    /// `s` (identity).
    Linear,
    /// `s * tanh(ln(1 + e^s))`.
    Mish,
    /// `s / (1 + e^-s)`.
    Swish,
    /// `ln(1 + e^s)`.
    Softplus,
}
140
141impl ActivationOperation {
146 #[inline]
147 pub fn apply(&self, inputs: &[f32]) -> f32 {
148 match self {
149 ActivationOperation::Sigmoid => {
150 let total = aggregate(inputs);
151 clamp(ONE / (ONE + (-total).exp()))
152 }
153 ActivationOperation::Tanh => {
154 let total = aggregate(inputs);
155 clamp(total.tanh())
156 }
157 ActivationOperation::ReLU => clamp(inputs.iter().cloned().sum::<f32>().max(ZERO)),
158 ActivationOperation::LeakyReLU => {
159 let x = clamp(inputs.iter().cloned().sum::<f32>());
160 if x > ZERO { x } else { clamp(HALF * x) }
161 }
162 ActivationOperation::ELU => {
163 let x = clamp(inputs.iter().cloned().sum::<f32>());
164 if x > ZERO {
165 x
166 } else {
167 clamp(HALF * (x.exp() - ONE))
168 }
169 }
170 ActivationOperation::Linear => clamp(inputs.iter().cloned().sum::<f32>()),
171 ActivationOperation::Mish => {
172 let x = clamp(inputs.iter().cloned().sum::<f32>());
173 clamp(x * (x.exp().ln_1p().tanh()))
174 }
175 ActivationOperation::Swish => {
176 let x = clamp(inputs.iter().cloned().sum::<f32>());
177 clamp(x / (ONE + (-x).exp()))
178 }
179 ActivationOperation::Softplus => {
180 let x = clamp(inputs.iter().cloned().sum::<f32>());
181 clamp(x.exp().ln_1p())
182 }
183 }
184 }
185}
186
187impl Op<f32> {
188 pub fn weight() -> Self {
189 Self::weight_with(random_provider::random::<f32>() * TWO - ONE)
190 }
191
192 pub fn weight_with(value: f32) -> Self {
193 let supplier = || random_provider::random::<f32>() * TWO - ONE;
194 let operation = |inputs: &[f32], weight: &f32| clamp(inputs[0] * weight);
195 let modifier = |current: &f32| {
196 let diff = (random_provider::random::<f32>() * TWO - ONE) * TENTH;
197 clamp(current + diff)
198 };
199
200 Op::MutableConst {
201 name: op_names::WEIGHT,
202 arity: 1.into(),
203 value: clamp(value),
204 supplier,
205 modifier,
206 operation,
207 }
208 }
209
210 pub fn add() -> Self {
211 Op::Fn(op_names::ADD, 2.into(), add)
212 }
213
214 pub fn sub() -> Self {
215 Op::Fn(op_names::SUB, 2.into(), sub)
216 }
217
218 pub fn mul() -> Self {
219 Op::Fn(op_names::MUL, 2.into(), mul)
220 }
221
222 pub fn div() -> Self {
223 Op::Fn(op_names::DIV, 2.into(), div)
224 }
225
226 pub fn sum() -> Self {
227 Op::Fn(op_names::SUM, Arity::Any, |inputs: &[f32]| {
228 AggregateOperations::Sum.apply(inputs)
229 })
230 }
231
232 pub fn diff() -> Self {
233 Op::Fn(op_names::DIFF, Arity::Any, |inputs: &[f32]| {
234 AggregateOperations::Diff.apply(inputs)
235 })
236 }
237
238 pub fn prod() -> Self {
239 Op::Fn(op_names::PROD, Arity::Any, |inputs: &[f32]| {
240 AggregateOperations::Prod.apply(inputs)
241 })
242 }
243
244 pub fn neg() -> Self {
245 Op::Fn(op_names::NEG, 1.into(), neg)
246 }
247
248 pub fn pow() -> Self {
249 Op::Fn(op_names::POW, 2.into(), |inputs: &[f32]| {
250 AggregateOperations::Pow.apply(inputs)
251 })
252 }
253
254 pub fn sqrt() -> Self {
255 Op::Fn(op_names::SQRT, 1.into(), |inputs: &[f32]| {
256 AggregateOperations::Sqrt.apply(inputs)
257 })
258 }
259
260 pub fn abs() -> Self {
261 Op::Fn(op_names::ABS, 1.into(), abs)
262 }
263
264 pub fn exp() -> Self {
265 Op::Fn(op_names::EXP, 1.into(), |inputs: &[f32]| {
266 AggregateOperations::Exp.apply(inputs)
267 })
268 }
269
270 pub fn log() -> Self {
271 Op::Fn(op_names::LOG, 1.into(), |inputs: &[f32]| {
272 AggregateOperations::Log.apply(inputs)
273 })
274 }
275
276 pub fn sin() -> Self {
277 Op::Fn(op_names::SIN, 1.into(), |inputs: &[f32]| {
278 AggregateOperations::Sin.apply(inputs)
279 })
280 }
281
282 pub fn cos() -> Self {
283 Op::Fn(op_names::COS, 1.into(), |inputs: &[f32]| {
284 AggregateOperations::Cos.apply(inputs)
285 })
286 }
287
288 pub fn max() -> Self {
289 Op::Fn(op_names::MAX, Arity::Any, |inputs: &[f32]| {
290 AggregateOperations::Max.apply(inputs)
291 })
292 }
293
294 pub fn min() -> Self {
295 Op::Fn(op_names::MIN, Arity::Any, |inputs: &[f32]| {
296 AggregateOperations::Min.apply(inputs)
297 })
298 }
299
300 pub fn tan() -> Self {
301 Op::Fn(op_names::TAN, 1.into(), |inputs: &[f32]| {
302 AggregateOperations::Tan.apply(inputs)
303 })
304 }
305
306 pub fn ceil() -> Self {
307 Op::Fn(op_names::CEIL, 1.into(), ceil)
308 }
309
310 pub fn floor() -> Self {
311 Op::Fn(op_names::FLOOR, 1.into(), floor)
312 }
313
314 pub fn sigmoid() -> Self {
315 Op::Fn(op_names::SIGMOID, Arity::Any, |inputs: &[f32]| {
316 ActivationOperation::Sigmoid.apply(inputs)
317 })
318 }
319
320 pub fn tanh() -> Self {
321 Op::Fn(op_names::TANH, Arity::Any, |inputs: &[f32]| {
322 ActivationOperation::Tanh.apply(inputs)
323 })
324 }
325
326 pub fn relu() -> Self {
327 Op::Fn(op_names::RELU, Arity::Any, |inputs: &[f32]| {
328 ActivationOperation::ReLU.apply(inputs)
329 })
330 }
331
332 pub fn leaky_relu() -> Self {
333 Op::Fn(op_names::LEAKY_RELU, Arity::Any, |inputs: &[f32]| {
334 ActivationOperation::LeakyReLU.apply(inputs)
335 })
336 }
337
338 pub fn elu() -> Self {
339 Op::Fn(op_names::ELU, Arity::Any, |inputs: &[f32]| {
340 ActivationOperation::ELU.apply(inputs)
341 })
342 }
343
344 pub fn linear() -> Self {
345 Op::Fn(op_names::LINEAR, Arity::Any, |inputs: &[f32]| {
346 ActivationOperation::Linear.apply(inputs)
347 })
348 }
349
350 pub fn mish() -> Self {
351 Op::Fn(op_names::MISH, Arity::Any, |inputs: &[f32]| {
352 ActivationOperation::Mish.apply(inputs)
353 })
354 }
355
356 pub fn swish() -> Self {
357 Op::Fn(op_names::SWISH, Arity::Any, |inputs: &[f32]| {
358 ActivationOperation::Swish.apply(inputs)
359 })
360 }
361
362 pub fn softplus() -> Self {
363 Op::Fn(op_names::SOFTPLUS, Arity::Any, |inputs: &[f32]| {
364 ActivationOperation::Softplus.apply(inputs)
365 })
366 }
367}
368
369pub fn math_ops() -> Vec<Op<f32>> {
371 vec![
372 Op::add(),
373 Op::sub(),
374 Op::mul(),
375 Op::div(),
376 Op::sum(),
377 Op::prod(),
378 Op::neg(),
379 Op::diff(),
380 Op::pow(),
381 Op::sqrt(),
382 Op::abs(),
383 Op::exp(),
384 Op::log(),
385 Op::sin(),
386 Op::cos(),
387 Op::tan(),
388 Op::ceil(),
389 Op::floor(),
390 Op::max(),
391 Op::min(),
392 ]
393}
394
395pub fn activation_ops() -> Vec<Op<f32>> {
397 vec![
398 Op::sigmoid(),
399 Op::tanh(),
400 Op::relu(),
401 Op::leaky_relu(),
402 Op::elu(),
403 Op::linear(),
404 Op::mish(),
405 Op::swish(),
406 Op::softplus(),
407 ]
408}
409
410pub fn all_ops() -> Vec<Op<f32>> {
412 math_ops().into_iter().chain(activation_ops()).collect()
413}
414
#[cfg(test)]
mod tests {
    use crate::Eval;

    use super::*;
    use std::f32::consts::{E, PI};

    /// Absolute-tolerance float comparison.
    #[inline]
    fn approx(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() <= eps
    }

    #[test]
    fn clamp_behaves_as_specified() {
        assert_eq!(super::clamp(f32::NAN), ZERO);
        assert_eq!(super::clamp(1e20_f32), MAX_VALUE);
        assert_eq!(super::clamp(-1e20_f32), MIN_VALUE);
        assert_eq!(super::clamp(f32::INFINITY), MAX_VALUE);
        assert_eq!(super::clamp(123.456), 123.456);
    }

    #[test]
    fn math_div_near_zero_clamps_large_quotient() {
        let xs = [10.0, 1e-12_f32];
        let y = Op::div().eval(&xs);
        assert_eq!(
            y, MAX_VALUE,
            "huge quotient should clamp to MAX_VALUE with current code"
        );
    }

    #[test]
    fn math_sum_prod_diff_pow_sqrt_abs() {
        let xs = [2.0, 3.0, 4.0];
        assert_eq!(AggregateOperations::Sum.apply(&xs), 9.0);
        assert_eq!(AggregateOperations::Prod.apply(&xs), 24.0);
        // Diff folds from zero: 0 - 2 - 3 - 4.
        assert_eq!(AggregateOperations::Diff.apply(&xs), -9.0);

        let p = AggregateOperations::Pow.apply(&[3.0, 2.0]);
        assert_eq!(p, 9.0);

        assert_eq!(AggregateOperations::Sqrt.apply(&[9.0]), 3.0);
        // sqrt of a negative is NaN, which clamp maps to zero.
        assert_eq!(AggregateOperations::Sqrt.apply(&[-4.0]), 0.0);

        assert_eq!(super::abs(&[-2.5]), 2.5);
        assert_eq!(super::abs(&[4.0]), 4.0);
    }

    #[test]
    fn math_exp_log_trig_rounding() {
        let e = AggregateOperations::Exp.apply(&[1.0]);
        assert!(approx(e, E, 1e-5), "exp(1) ~= e");

        // Non-positive inputs to Log yield zero instead of -inf / NaN.
        assert_eq!(AggregateOperations::Log.apply(&[0.0]), 0.0);
        assert_eq!(AggregateOperations::Log.apply(&[-1.0]), 0.0);

        let s = AggregateOperations::Sin.apply(&[PI / 2.0]);
        assert!(approx(s, 1.0, 1e-5));

        let c = AggregateOperations::Cos.apply(&[0.0]);
        assert!(approx(c, 1.0, 1e-5));

        let t = AggregateOperations::Tan.apply(&[0.0]);
        assert!(approx(t, 0.0, 1e-6));

        assert_eq!(super::ceil(&[1.2]), 2.0);
        assert_eq!(super::floor(&[1.8]), 1.0);
        assert_eq!(super::floor(&[-1.2]), -2.0);
    }

    #[test]
    fn math_max_min_variadic_including_empty_behavior() {
        let xs = [1.5, -2.0, 7.25, 3.0];
        let mx = AggregateOperations::Max.apply(&xs);
        let mn = AggregateOperations::Min.apply(&xs);
        assert_eq!(mx, 7.25);
        assert_eq!(mn, -2.0);

        // Empty input returns the fold seeds.
        let empty: [f32; 0] = [];
        assert_eq!(AggregateOperations::Max.apply(&empty), MIN_VALUE);
        assert_eq!(AggregateOperations::Min.apply(&empty), MAX_VALUE);
    }

    #[test]
    fn act_sigmoid_on_sum() {
        // sigmoid(2 - 1) = sigmoid(1) ~= 0.731
        let xs = [2.0, -1.0];
        let y = ActivationOperation::Sigmoid.apply(&xs);
        assert!(y > 0.73 && y < 0.74, "got {y}");
    }

    #[test]
    fn act_tanh_on_sum() {
        // tanh(2 - 0.5) = tanh(1.5) ~= 0.905
        let xs = [2.0, -0.5];
        let y = ActivationOperation::Tanh.apply(&xs);
        assert!(y > 0.90 && y < 0.91, "got {y}");
    }

    #[test]
    fn act_relu_and_leaky_and_elu_match_current_params() {
        let xs = [-1.0, 0.25, 0.25];
        assert_eq!(ActivationOperation::ReLU.apply(&xs), 0.0);

        // LeakyReLU scales negatives by 0.5.
        let xs2 = [-0.6];
        let y2 = ActivationOperation::LeakyReLU.apply(&xs2);
        assert_eq!(y2, -0.3);

        // ELU uses alpha = 0.5.
        let xs3 = [-1.0];
        let y3 = ActivationOperation::ELU.apply(&xs3);
        assert!(approx(y3, 0.5 * (E.powf(-1.0) - 1.0), 1e-6));
    }

    #[test]
    fn act_linear_mish_swish_softplus() {
        let xs = [1.0, 2.0, 3.0];
        assert_eq!(ActivationOperation::Linear.apply(&xs), 6.0);

        let x = 1.5_f32;
        let mish_ref = x * ((x.exp().ln_1p()).tanh());
        let mish_y = ActivationOperation::Mish.apply(&[x]);
        assert!(approx(mish_y, mish_ref, 1e-6));

        let sw = ActivationOperation::Swish.apply(&[x]);
        let sw_ref = x / (1.0 + (-x).exp());
        assert!(approx(sw, sw_ref, 1e-6));

        let sp = ActivationOperation::Softplus.apply(&[x]);
        let sp_ref = x.exp().ln_1p();
        assert!(approx(sp, sp_ref, 1e-6));
    }

    #[test]
    fn weight_op_runs_and_is_clamped() {
        let w = Op::<f32>::weight();
        if let Op::MutableConst {
            operation, value, ..
        } = &w
        {
            let out = (operation)(&[0.5], value);
            assert!(out.is_finite());
            assert!(out <= MAX_VALUE && out >= MIN_VALUE);
        } else {
            panic!("weight() did not return MutableConst as expected");
        }
    }
}