/// SiLU (a.k.a. swish) activation: `x * sigmoid(x)`.
///
/// Evaluated as `x / (1 + e^-x)`. For finite, very negative `x` the
/// exponential overflows to `+inf`, so the quotient collapses to a signed
/// zero instead of producing an `inf * 0` NaN.
#[inline]
#[must_use]
pub fn silu_scalar(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    x / denominator
}
30
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
#[must_use]
pub fn gelu_scalar(x: f32) -> f32 {
    let sqrt_two_over_pi = f32::sqrt(2.0_f32 / std::f32::consts::PI);
    // Keep the product left-associated, matching `0.044715 * x * x * x`.
    let cubic_term = 0.044_715 * x * x * x;
    let inner = sqrt_two_over_pi * (x + cubic_term);
    0.5 * x * (1.0 + inner.tanh())
}
52
/// Logistic sigmoid: `1 / (1 + e^-x)`.
///
/// For finite inputs the result lies in `[0, 1]`. Large negative inputs
/// saturate to `0.0` and large positive inputs saturate to `1.0`.
#[inline]
#[must_use]
pub fn sigmoid_scalar(x: f32) -> f32 {
    let decay = f32::exp(-x);
    (1.0 + decay).recip()
}
70
/// Rectified linear unit.
///
/// Strictly positive inputs pass through unchanged. Everything else
/// (negatives, zero, and NaN) yields `0.0`.
#[inline]
#[must_use]
pub fn relu_scalar(x: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}
88
/// Hyperbolic tangent activation; a thin wrapper over [`f32::tanh`].
#[inline]
#[must_use]
pub fn tanh_scalar(x: f32) -> f32 {
    f32::tanh(x)
}
106
/// Decode IEEE 754 binary16 bits into an `f32`.
///
/// Every finite binary16 value (normal or subnormal) and both infinities
/// are represented exactly in `f32`. Any NaN encoding decodes to the
/// canonical [`f32::NAN`]; its sign and payload are not carried over.
#[inline]
#[must_use]
pub fn f16_to_f32(bits: u16) -> f32 {
    // Smallest positive normal binary16 value, 2^-14 (exact in f32).
    const TWO_POW_NEG_14: f32 = 6.103_515_625e-5;

    let negative = ((bits >> 15) & 0x1) == 1;
    let exp_field = u32::from((bits >> 10) & 0x1F);
    let frac_field = u32::from(bits & 0x3FF);

    let magnitude = match exp_field {
        // Zero and subnormals: value = (frac / 1024) * 2^-14. A zero
        // fraction yields 0.0, which the sign step below turns into -0.0
        // when needed.
        0 => {
            let fraction = frac_field as f32 * (1.0 / 1024.0);
            fraction * TWO_POW_NEG_14
        }
        0x1F if frac_field != 0 => return f32::NAN,
        0x1F => f32::INFINITY,
        // Normal: rebias the exponent (127 - 15 = 112) and widen the
        // 10-bit fraction into the top of the 23-bit f32 fraction.
        _ => f32::from_bits(((exp_field + 112) << 23) | (frac_field << 13)),
    };

    if negative {
        -magnitude
    } else {
        magnitude
    }
}
155
/// Encode an `f32` as IEEE 754 binary16 bits, rounding to nearest with ties
/// to even.
///
/// * NaN stays NaN. The sign and the top 10 payload bits are kept, and the
///   payload is forced non-zero so it cannot turn into an infinity encoding.
/// * Infinities stay infinite. Finite values whose rounded magnitude exceeds
///   the largest binary16 value (65504) become the signed infinity.
/// * Values in the binary16 subnormal range round to the nearest subnormal
///   (or up to the smallest normal). Values at or below 2^-25 in magnitude
///   become a signed zero (exactly 2^-25 is a tie that goes to even, i.e. 0).
#[inline]
#[must_use]
pub fn f32_to_f16(x: f32) -> u16 {
    /// Shift `value` right by `shift` bits (1..=24), rounding the discarded
    /// bits to nearest with ties to even. A carry out of the kept bits is
    /// preserved rather than masked off.
    fn shift_round_even(value: u32, shift: u32) -> u32 {
        let kept = value >> shift;
        let dropped = value & ((1_u32 << shift) - 1);
        let half = 1_u32 << (shift - 1);
        if dropped > half || (dropped == half && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exponent == 255 {
        if mantissa == 0 {
            return sign | 0x7C00;
        }
        return sign | 0x7C00 | ((mantissa >> 13) as u16).max(1);
    }

    // Rebias the exponent: f32 bias 127 minus f16 bias 15.
    let new_exp = exponent - 112;
    if new_exp >= 31 {
        return sign | 0x7C00;
    }

    if new_exp <= 0 {
        // Magnitude below 2^-25 (less than half the smallest subnormal).
        if new_exp < -10 {
            return sign;
        }
        // Subnormal target: restore the implicit leading one, then shift so
        // the least-significant kept bit weighs 2^-24. Shift is 14..=24.
        let significand = mantissa | 0x0080_0000;
        let shift = (14 - new_exp) as u32;
        // Rounding up from the largest subnormal yields 0x0400, which is
        // exactly the encoding of the smallest normal value.
        return sign | shift_round_even(significand, shift) as u16;
    }

    // Normal target: pack exponent above mantissa before rounding so the
    // carry from an all-ones mantissa bumps the exponent. At exponent 30
    // this produces 0x7C00, i.e. rounding overflows to infinity.
    let packed = ((new_exp as u32) << 23) | mantissa;
    sign | shift_round_even(packed, 13) as u16
}
201
#[cfg(test)]
mod tests {
    // Unit tests for the scalar activations and the f16 <-> f32 conversions,
    // followed by GELU "falsification" tests. Each GE-xxx test probes one
    // property of `gelu_scalar`.
    use super::*;

    // SiLU(0) = 0 * sigmoid(0) = 0.
    #[test]
    fn test_silu_zero() {
        assert!((silu_scalar(0.0)).abs() < 1e-7);
    }

    // For large positive x, sigmoid(x) -> 1, so SiLU(x) ~= x.
    #[test]
    fn test_silu_positive() {
        let x = 10.0;
        assert!((silu_scalar(x) - x).abs() < 0.01);
    }

    // For large negative x, SiLU(x) -> 0.
    #[test]
    fn test_silu_negative() {
        assert!(silu_scalar(-10.0).abs() < 0.01);
    }

    #[test]
    fn test_gelu_zero() {
        assert!((gelu_scalar(0.0)).abs() < 1e-7);
    }

    // For large positive x, GELU(x) ~= x.
    #[test]
    fn test_gelu_positive() {
        let x = 10.0;
        assert!((gelu_scalar(x) - x).abs() < 0.01);
    }

    #[test]
    fn test_sigmoid_zero() {
        assert!((sigmoid_scalar(0.0) - 0.5).abs() < 1e-7);
    }

    // Point symmetry of the logistic function: sigmoid(x) + sigmoid(-x) = 1.
    #[test]
    fn test_sigmoid_symmetry() {
        let x = 2.5;
        assert!((sigmoid_scalar(x) + sigmoid_scalar(-x) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_relu_positive() {
        assert!((relu_scalar(3.0) - 3.0).abs() < 1e-7);
    }

    #[test]
    fn test_relu_negative() {
        assert!((relu_scalar(-3.0)).abs() < 1e-7);
    }

    #[test]
    fn test_tanh_zero() {
        assert!((tanh_scalar(0.0)).abs() < 1e-7);
    }

    // tanh is odd: tanh(x) + tanh(-x) = 0.
    #[test]
    fn test_tanh_odd() {
        let x = 1.5;
        assert!((tanh_scalar(x) + tanh_scalar(-x)).abs() < 1e-6);
    }

    // 1.5 is exactly representable in binary16, so encode-then-decode is
    // expected to be lossless; the 1e-3 tolerance is generous.
    #[test]
    fn test_f16_roundtrip() {
        let val = 1.5_f32;
        let bits = f32_to_f16(val);
        let back = f16_to_f32(bits);
        assert!((val - back).abs() < 1e-3);
    }

    // All-zero bits decode to zero.
    #[test]
    fn test_f16_zero() {
        assert_eq!(f16_to_f32(0), 0.0);
    }

    // GE-001: GELU is non-negative for positive inputs (spot values up to 1e6).
    #[test]
    fn falsify_ge_001_non_negativity() {
        let test_values = [0.001, 0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0, 1e6];
        for &x in &test_values {
            let y = gelu_scalar(x);
            assert!(y >= 0.0, "FALSIFIED GE-001: GELU({x}) = {y} < 0 for positive input");
        }
    }

    // GE-002: GELU is strictly increasing across consecutive positive sample points.
    #[test]
    fn falsify_ge_002_positive_monotonicity() {
        let values: Vec<f32> = vec![0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0];
        for window in values.windows(2) {
            let (y_lo, y_hi) = (gelu_scalar(window[0]), gelu_scalar(window[1]));
            assert!(
                y_hi > y_lo,
                "FALSIFIED GE-002: GELU({}) = {} not > GELU({}) = {}",
                window[1],
                y_hi,
                window[0],
                y_lo
            );
        }
    }

    // GE-003: GELU(0) = 0.
    #[test]
    fn falsify_ge_003_zero_preservation() {
        let y = gelu_scalar(0.0);
        assert!(y.abs() < 1e-7, "FALSIFIED GE-003: GELU(0) = {y}, expected 0");
    }

    // GE-005: the tanh-form GELU stays within 0.005 of an erf-based reference
    // GELU, x * Phi(x), sampled on [-10, 10] in steps of 0.1.
    #[test]
    fn falsify_ge_005_tanh_approx_accuracy() {
        // Polynomial erf approximation (Abramowitz & Stegun 7.1.26
        // coefficients), applied to |x| and then re-signed because erf is odd.
        fn erf_approx(x: f32) -> f32 {
            let sign = x.signum();
            let x = x.abs();
            let t = 1.0 / (1.0 + 0.327_591_1 * x);
            let t2 = t * t;
            let t3 = t2 * t;
            let t4 = t3 * t;
            let t5 = t4 * t;
            let poly = 0.254_829_592 * t - 0.284_496_736 * t2 + 1.421_413_741 * t3
                - 1.453_152_027 * t4
                + 1.061_405_429 * t5;
            sign * (1.0 - poly * (-x * x).exp())
        }

        // Reference GELU: x * Phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
        fn gelu_exact(x: f32) -> f32 {
            let phi = 0.5 * (1.0 + erf_approx(x / std::f32::consts::SQRT_2));
            x * phi
        }

        let test_values: Vec<f32> = (-100..=100).map(|i| i as f32 * 0.1).collect();
        for &x in &test_values {
            let approx = gelu_scalar(x);
            let exact = gelu_exact(x);
            let diff = (approx - exact).abs();
            assert!(
                diff < 0.005,
                "FALSIFIED GE-005: |GELU_approx({x}) - GELU_exact({x})| = {diff} >= 0.005"
            );
        }
    }

    // GE-006: large positive inputs give GELU(x) ~= x and large negative
    // inputs give GELU(x) ~= 0. Magnitudes go up to 1000, where the cubic
    // term is large.
    #[test]
    fn falsify_ge_006_large_input_stability() {
        for &x in &[10.0_f32, 50.0, 100.0, 1000.0] {
            let y = gelu_scalar(x);
            assert!((y - x).abs() < 0.01, "FALSIFIED GE-006: GELU({x}) = {y}, expected ≈ {x}");
        }
        for &x in &[-10.0_f32, -50.0, -100.0, -1000.0] {
            let y = gelu_scalar(x);
            assert!(y.abs() < 0.01, "FALSIFIED GE-006: GELU({x}) = {y}, expected ≈ 0");
        }
    }

    // Property-based counterparts of the GE-xxx tests above (proptest).
    mod ge_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // GE-001-prop: 500 random inputs in [0, 1000).
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_ge_001_prop_non_negativity(x in 0.0_f32..1000.0) {
                let y = gelu_scalar(x);
                prop_assert!(y >= 0.0, "FALSIFIED GE-001-prop: gelu({x}) = {y} < 0");
            }
        }

        // GE-002-prop: 300 random pairs in [0.001, 100). Equal draws are
        // skipped because strict monotonicity says nothing about them.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_ge_002_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    let y_lo = gelu_scalar(lo);
                    let y_hi = gelu_scalar(hi);
                    prop_assert!(
                        y_hi > y_lo,
                        "FALSIFIED GE-002-prop: gelu({hi})={y_hi} not > gelu({lo})={y_lo}"
                    );
                }
            }
        }

        // GE-006-prop: 200 random inputs in [10, 500) must satisfy GELU(x) ~= x.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_ge_006_prop_large_positive(x in 10.0_f32..500.0) {
                let y = gelu_scalar(x);
                prop_assert!(
                    (y - x).abs() < 0.01,
                    "FALSIFIED GE-006-prop: |gelu({x}) - {x}| = {}",
                    (y - x).abs()
                );
            }
        }
    }
}
429
#[cfg(test)]
mod silu_contract_tests {
    // Falsification tests for the SiLU contract. Each SI-xxx test probes one
    // property of `silu_scalar`.
    use super::*;

    // Lower bound checked by SI-002, shared by the assertion and its failure
    // message so the two cannot drift apart. SiLU's true global minimum is
    // about -0.2785 (near x ~= -1.2785), so -0.28 leaves a small f32 margin.
    const SILU_LOWER_BOUND: f32 = -0.28;

    // SI-001: SiLU(0) = 0.
    #[test]
    fn falsify_si_001_zero_preservation() {
        let y = silu_scalar(0.0);
        assert!(y.abs() < 1e-7, "FALSIFIED SI-001: SiLU(0) = {y}, expected 0");
    }

    // SI-002: SiLU stays above its global lower bound. -1.278 sits next to
    // the minimiser.
    #[test]
    fn falsify_si_002_global_lower_bound() {
        let test_values = [
            -100.0_f32, -50.0, -10.0, -5.0, -2.0, -1.278, -1.0, -0.5, 0.0, 0.5, 1.0, 5.0, 100.0,
        ];
        for &x in &test_values {
            let y = silu_scalar(x);
            assert!(
                y > SILU_LOWER_BOUND,
                "FALSIFIED SI-002: SiLU({x}) = {y}, expected > {SILU_LOWER_BOUND}"
            );
        }
    }

    // SI-003: SiLU is strictly increasing across consecutive positive sample points.
    #[test]
    fn falsify_si_003_monotonic_positive() {
        let values = [0.01_f32, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0];
        for pair in values.windows(2) {
            let (lo, hi) = (pair[0], pair[1]);
            let (y_lo, y_hi) = (silu_scalar(lo), silu_scalar(hi));
            assert!(
                y_hi > y_lo,
                "FALSIFIED SI-003: SiLU({hi}) = {y_hi} not > SiLU({lo}) = {y_lo}"
            );
        }
    }

    // SI-005: for large positive x, SiLU(x) ~= x.
    #[test]
    fn falsify_si_005_asymptotic_linearity() {
        for &x in &[10.0f32, 20.0, 50.0, 100.0, 500.0] {
            let y = silu_scalar(x);
            assert!(
                (y - x).abs() < 0.01,
                "FALSIFIED SI-005: |SiLU({x}) - {x}| = {} >= 0.01",
                (y - x).abs()
            );
        }
    }

    // SI-006: for large negative x, SiLU(x) ~= 0.
    #[test]
    fn falsify_si_006_large_negative_vanishes() {
        for &x in &[-10.0f32, -20.0, -50.0, -100.0, -500.0] {
            let y = silu_scalar(x);
            assert!(y.abs() < 0.01, "FALSIFIED SI-006: SiLU({x}) = {y}, expected ≈ 0");
        }
    }

    // Property-based counterparts of the SI-xxx tests above (proptest).
    mod si_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // SI-002-prop: 500 random inputs in [-1000, 1000).
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_si_002_prop_lower_bound(x in -1000.0_f32..1000.0) {
                let y = silu_scalar(x);
                prop_assert!(
                    y > SILU_LOWER_BOUND,
                    "FALSIFIED SI-002-prop: SiLU({x}) = {y} <= {SILU_LOWER_BOUND}"
                );
            }
        }

        // SI-003-prop: 300 random pairs in [0.001, 100). Equal draws are
        // skipped because strict monotonicity says nothing about them.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_si_003_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                if a != b {
                    let (lo, hi) = if a < b { (a, b) } else { (b, a) };
                    let y_lo = silu_scalar(lo);
                    let y_hi = silu_scalar(hi);
                    prop_assert!(
                        y_hi > y_lo,
                        "FALSIFIED SI-003-prop: SiLU({hi})={y_hi} not > SiLU({lo})={y_lo}"
                    );
                }
            }
        }

        // SI-005-prop: 200 random inputs in [10, 500) must satisfy SiLU(x) ~= x.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_si_005_prop_asymptotic(x in 10.0_f32..500.0) {
                let y = silu_scalar(x);
                prop_assert!(
                    (y - x).abs() < 0.01,
                    "FALSIFIED SI-005-prop: |SiLU({x}) - {x}| = {}",
                    (y - x).abs()
                );
            }
        }
    }
}