1use crate::autograd::{BackwardOp, Tensor};
4use ndarray::Array1;
5use provable_contracts_macros::contract;
6use std::cell::RefCell;
7use std::rc::Rc;
8
9pub fn relu(a: &Tensor) -> Tensor {
11 contract_pre_relu!(a.data().as_slice().unwrap_or(&[]));
12 let data = a.data().mapv(|x| x.max(0.0));
13 let requires_grad = a.requires_grad();
14
15 let mut result = Tensor::new(data, requires_grad);
16
17 if requires_grad {
18 let a_clone = a.clone();
19 let backward_op = Rc::new(ReluBackward { a: a_clone, result_grad: result.grad_cell() });
20 result.set_backward_op(backward_op);
21 }
22
23 result
24}
25
/// Backward node attached by `relu` to its result when the input requires
/// gradients.
struct ReluBackward {
    /// Clone of the tensor that was passed to `relu`; gradients are accumulated
    /// into it during `backward`.
    a: Tensor,
    /// Gradient cell taken from the `relu` result via `grad_cell()`; `None`
    /// until a gradient has been stored in it.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
30
31impl BackwardOp for ReluBackward {
32 fn backward(&self) {
33 if let Some(grad) = self.result_grad.borrow().as_ref() {
34 if self.a.requires_grad() {
35 let grad_a = grad * &self.a.data().mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
37 self.a.accumulate_grad(grad_a);
38 }
39
40 if let Some(op) = self.a.backward_op() {
41 op.backward();
42 }
43 }
44 }
45}
46
47pub fn gelu(a: &Tensor) -> Tensor {
53 contract_pre_gelu!(a.data().as_slice().unwrap_or(&[]));
54 let data = a.data().mapv(trueno::gelu_scalar);
55
56 let requires_grad = a.requires_grad();
57 let mut result = Tensor::new(data, requires_grad);
58
59 if requires_grad {
60 let a_clone = a.clone();
61 let backward_op = Rc::new(GeluBackward { a: a_clone, result_grad: result.grad_cell() });
62 result.set_backward_op(backward_op);
63 }
64
65 contract_post_gelu!(result.data().as_slice().unwrap_or(&[]));
66 result
67}
68
/// Backward node attached by `gelu` to its result when the input requires
/// gradients.
struct GeluBackward {
    /// Clone of the tensor that was passed to `gelu`; gradients are accumulated
    /// into it during `backward`.
    a: Tensor,
    /// Gradient cell taken from the `gelu` result via `grad_cell()`; `None`
    /// until a gradient has been stored in it.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
73
74impl BackwardOp for GeluBackward {
75 fn backward(&self) {
76 if let Some(grad_output) = self.result_grad.borrow().as_ref() {
77 if self.a.requires_grad() {
78 const SQRT_2_OVER_PI: f32 = 0.797_884_6;
79 const COEFF: f32 = 0.044_715;
80
81 let grad_a: Vec<f32> = self
85 .a
86 .data()
87 .iter()
88 .zip(grad_output.iter())
89 .map(|(&x, &grad)| {
90 let x2 = x * x;
91 let x3 = x2 * x;
92 let z = SQRT_2_OVER_PI * (x + COEFF * x3);
93 let tanh_z = z.tanh();
94 let sech2_z = 1.0 - tanh_z * tanh_z;
95 let dz_dx = SQRT_2_OVER_PI * (1.0 + 3.0 * COEFF * x2);
96
97 let gelu_grad = 0.5 * (1.0 + tanh_z) + 0.5 * x * sech2_z * dz_dx;
98 grad * gelu_grad
99 })
100 .collect();
101
102 self.a.accumulate_grad(Array1::from(grad_a));
103 }
104
105 if let Some(op) = self.a.backward_op() {
106 op.backward();
107 }
108 }
109 }
110}
111
112pub fn swish(a: &Tensor) -> Tensor {
118 let data = a.data().mapv(trueno::silu_scalar);
119
120 let requires_grad = a.requires_grad();
121 let mut result = Tensor::new(data, requires_grad);
122
123 if requires_grad {
124 let a_clone = a.clone();
125 let output_clone = result.clone();
126 let backward_op = Rc::new(SwishBackward {
127 a: a_clone,
128 output: output_clone,
129 result_grad: result.grad_cell(),
130 });
131 result.set_backward_op(backward_op);
132 }
133
134 result
135}
136
/// Backward node attached by `swish` to its result when the input requires
/// gradients.
struct SwishBackward {
    /// Clone of the tensor that was passed to `swish`; gradients are
    /// accumulated into it during `backward`.
    a: Tensor,
    /// Clone of the `swish` result, taken before the backward op was attached;
    /// its data is reused when computing the derivative.
    output: Tensor,
    /// Gradient cell taken from the `swish` result via `grad_cell()`; `None`
    /// until a gradient has been stored in it.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
142
143impl BackwardOp for SwishBackward {
144 fn backward(&self) {
145 if let Some(grad_output) = self.result_grad.borrow().as_ref() {
146 if self.a.requires_grad() {
147 let grad_a: Vec<f32> = self
150 .a
151 .data()
152 .iter()
153 .zip(self.output.data().iter())
154 .zip(grad_output.iter())
155 .map(|((&x, &swish_x), &grad)| {
156 let sigmoid = 1.0 / (1.0 + (-x).exp());
157 let swish_grad = swish_x + sigmoid * (1.0 - swish_x);
158 grad * swish_grad
159 })
160 .collect();
161
162 self.a.accumulate_grad(Array1::from(grad_a));
163 }
164
165 if let Some(op) = self.a.backward_op() {
166 op.backward();
167 }
168 }
169 }
170}
171
172#[contract("softmax-v1", equation = "softmax")]
174pub fn softmax(a: &Tensor) -> Tensor {
175 contract_pre_softmax!(a.data().as_slice().unwrap_or(&[]));
176 let max_val = a.data().iter().copied().fold(f32::NEG_INFINITY, f32::max);
177 let exp_vals = a.data().mapv(|x| (x - max_val).exp());
178 let sum_exp = exp_vals.sum();
179 let data = exp_vals / sum_exp;
180
181 let requires_grad = a.requires_grad();
182 let mut result = Tensor::new(data, requires_grad);
183
184 if requires_grad {
185 let a_clone = a.clone();
186 let output_clone = result.clone();
187 let backward_op = Rc::new(SoftmaxBackward {
188 a: a_clone,
189 output: output_clone,
190 result_grad: result.grad_cell(),
191 });
192 result.set_backward_op(backward_op);
193 }
194
195 contract_post_softmax!(result.data().as_slice().unwrap_or(&[]));
196 result
197}
198
/// Backward node attached by `softmax` to its result when the input requires
/// gradients.
struct SoftmaxBackward {
    /// Clone of the tensor that was passed to `softmax`; gradients are
    /// accumulated into it during `backward`.
    a: Tensor,
    /// Clone of the `softmax` result, taken before the backward op was
    /// attached; its data is the `y` used in the gradient formula.
    output: Tensor,
    /// Gradient cell taken from the `softmax` result via `grad_cell()`; `None`
    /// until a gradient has been stored in it.
    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
204
205impl BackwardOp for SoftmaxBackward {
206 fn backward(&self) {
207 if let Some(grad_output) = self.result_grad.borrow().as_ref() {
208 if self.a.requires_grad() {
209 let y = self.output.data();
211 let dot = (y * grad_output).sum();
212 let grad_a = y * &(grad_output - dot);
213 self.a.accumulate_grad(grad_a);
214 }
215
216 if let Some(op) = self.a.backward_op() {
217 op.backward();
218 }
219 }
220 }
221}
222
// Falsification tests for the scalar SiLU implementation `trueno::silu_scalar`.
// Each test tries to break one numbered property (SI-001 .. SI-006).
#[cfg(test)]
mod silu_contract_tests {
    /// SI-001: SiLU maps 0 to (approximately) 0.
    #[test]
    fn falsify_si_001_zero_preservation() {
        let y = trueno::silu_scalar(0.0);
        assert!(y.abs() < 1e-7, "FALSIFIED SI-001: SiLU(0) = {y}, expected 0");
    }

    /// SI-002: SiLU stays above -0.28 on a spread of hand-picked inputs.
    #[test]
    fn falsify_si_002_global_lower_bound() {
        const PROBES: &[f32] =
            &[-100.0, -50.0, -10.0, -5.0, -2.0, -1.278, -1.0, -0.5, 0.0, 0.5, 1.0, 5.0, 100.0];
        for &x in PROBES {
            let y = trueno::silu_scalar(x);
            assert!(y > -0.28, "FALSIFIED SI-002: SiLU({x}) = {y}, expected > -0.279");
        }
    }

    /// SI-003: SiLU is strictly increasing along an increasing list of
    /// positive inputs.
    #[test]
    fn falsify_si_003_monotonic_positive() {
        let values = [0.01_f32, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0, 100.0];
        // Compare each input with its immediate predecessor.
        for pair in values.windows(2) {
            let (prev, curr) = (pair[0], pair[1]);
            let y_prev = trueno::silu_scalar(prev);
            let y_curr = trueno::silu_scalar(curr);
            assert!(
                y_curr > y_prev,
                "FALSIFIED SI-003: SiLU({}) = {y_curr} not > SiLU({}) = {y_prev}",
                curr,
                prev
            );
        }
    }

    /// SI-005: for large positive inputs SiLU(x) is within 0.01 of x.
    #[test]
    fn falsify_si_005_asymptotic_linearity() {
        for x in [10.0_f32, 20.0, 50.0, 100.0, 500.0] {
            let y = trueno::silu_scalar(x);
            let gap = (y - x).abs();
            assert!(gap < 0.01, "FALSIFIED SI-005: |SiLU({x}) - {x}| = {} >= 0.01", gap);
        }
    }

    /// SI-006: for large negative inputs SiLU(x) is within 0.01 of 0.
    #[test]
    fn falsify_si_006_large_negative_vanishes() {
        for x in [-10.0_f32, -20.0, -50.0, -100.0] {
            let y = trueno::silu_scalar(x);
            assert!(y.abs() < 0.01, "FALSIFIED SI-006: SiLU({x}) = {y}, expected ≈ 0");
        }
    }

    /// Property-based versions of SI-002, SI-003 and SI-005.
    mod si_proptest_falsify {
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_si_002_prop_lower_bound(x in -1000.0_f32..1000.0) {
                let y = trueno::silu_scalar(x);
                prop_assert!(y > -0.28, "FALSIFIED SI-002-prop: SiLU({x}) = {y}");
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_si_003_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                // Equal draws carry no ordering information and are passed over.
                if a != b {
                    let lo = a.min(b);
                    let hi = a.max(b);
                    let silu_lo = trueno::silu_scalar(lo);
                    let silu_hi = trueno::silu_scalar(hi);
                    prop_assert!(
                        silu_hi > silu_lo,
                        "FALSIFIED SI-003-prop: SiLU({hi}) not > SiLU({lo})"
                    );
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_si_005_prop_asymptotic(x in 10.0_f32..500.0) {
                let y = trueno::silu_scalar(x);
                let gap = (y - x).abs();
                prop_assert!(
                    gap < 0.01,
                    "FALSIFIED SI-005-prop: |SiLU({x}) - {x}| = {}",
                    gap
                );
            }
        }
    }
}
342
// Falsification tests for a scalar SwiGLU reference built on
// `trueno::silu_scalar`. Each test tries to break one numbered property
// (SG-001 .. SG-006).
#[cfg(test)]
mod swiglu_contract_tests {

    /// Scalar SwiGLU reference used by every test here: `x * SiLU(gate)`.
    fn swiglu_scalar(x: f32, gate: f32) -> f32 {
        let gated = trueno::silu_scalar(gate);
        x * gated
    }

    /// SG-001: a zero `x` yields (approximately) zero for any gate.
    #[test]
    fn falsify_sg_001_zero_x_preservation() {
        for g in [-10.0_f32, -1.0, 0.0, 1.0, 10.0] {
            let y = swiglu_scalar(0.0, g);
            assert!(y.abs() < 1e-7, "FALSIFIED SG-001: SwiGLU(0, {g}) = {y}");
        }
    }

    /// SG-002: the helper agrees with the spelled-out product `x * SiLU(g)`.
    #[test]
    fn falsify_sg_002_fused_equivalence() {
        let cases = [(1.0_f32, 1.0_f32), (-2.0, 3.0), (5.0, -1.0), (0.5, 0.5), (100.0, 0.0)];
        for (x, g) in cases {
            let fused = swiglu_scalar(x, g);
            let decomposed = x * trueno::silu_scalar(g);
            let mismatch = (fused - decomposed).abs();
            assert!(
                mismatch < 1e-6,
                "FALSIFIED SG-002: swiglu({x},{g})={fused} != decomposed={decomposed}"
            );
        }
    }

    /// SG-003: the SiLU factor stays above -0.28 for the sampled gates.
    #[test]
    fn falsify_sg_003_silu_lower_bound() {
        for g in [-1000.0_f32, -1.278, -1.0, 0.0, 1.0, 1000.0] {
            let silu_g = trueno::silu_scalar(g);
            assert!(silu_g > -0.28, "FALSIFIED SG-003: SiLU({g}) = {silu_g}");
        }
    }

    /// SG-004: the output is finite on the full grid of sampled (x, gate) pairs.
    #[test]
    fn falsify_sg_004_finite_output() {
        let grid = [-100.0_f32, -10.0, -1.0, 0.0, 1.0, 10.0, 100.0];
        for &x in &grid {
            for &g in &grid {
                let y = swiglu_scalar(x, g);
                assert!(y.is_finite(), "FALSIFIED SG-004: SwiGLU({x},{g}) = {y}");
            }
        }
    }

    /// SG-005: mapping the helper over empty inputs produces an empty output.
    #[test]
    fn falsify_sg_005_empty_input() {
        let empty: Vec<f32> = Vec::new();
        let result: Vec<f32> =
            empty.iter().zip(&empty).map(|(&x, &g)| swiglu_scalar(x, g)).collect();
        assert!(result.is_empty(), "FALSIFIED SG-005: empty SwiGLU produced non-empty output");
    }

    /// Property-based versions of SG-001 and SG-004, plus SG-006 (monotonic in
    /// the gate for positive `x` and positive gates).
    mod sg_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_sg_001_prop_zero_x(gate in -100.0_f32..100.0) {
                let y = swiglu_scalar(0.0, gate);
                prop_assert!(y.abs() < 1e-6, "FALSIFIED SG-001-prop: SwiGLU(0, {gate}) = {y}");
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_sg_004_prop_finite(
                x in -100.0_f32..100.0,
                gate in -100.0_f32..100.0,
            ) {
                let y = swiglu_scalar(x, gate);
                prop_assert!(y.is_finite(), "FALSIFIED SG-004-prop: SwiGLU({x},{gate}) = {y}");
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_sg_006_prop_monotonic_gate(
                x in 1.0_f32..50.0,
                a in 0.1_f32..50.0,
                b in 0.1_f32..50.0,
            ) {
                // Equal gates carry no ordering information and are passed over.
                if a != b {
                    let lo = a.min(b);
                    let hi = a.max(b);
                    let y_lo = swiglu_scalar(x, lo);
                    let y_hi = swiglu_scalar(x, hi);
                    prop_assert!(
                        y_hi > y_lo,
                        "FALSIFIED SG-006-prop: SwiGLU({x},{hi})={y_hi} not > SwiGLU({x},{lo})={y_lo}"
                    );
                }
            }
        }
    }
}
470
// Falsification tests for the tensor-level `gelu` and for
// `trueno::gelu_scalar`. Each test tries to break one numbered property
// (GE-001 .. GE-006).
#[cfg(test)]
mod gelu_contract_tests {
    use super::*;
    use ndarray::Array1;

    /// GE-001: GELU of positive inputs is never negative.
    #[test]
    fn falsify_ge_001_non_negativity() {
        let x = Tensor::new(Array1::from(vec![0.001, 0.1, 1.0, 5.0, 10.0, 100.0]), false);
        let y = gelu(&x);
        y.data().iter().enumerate().for_each(|(i, &val)| {
            assert!(val >= 0.0, "FALSIFIED GE-001: gelu(positive)[{i}] = {val} < 0");
        });
    }

    /// GE-002: GELU is strictly increasing along an increasing list of
    /// positive inputs.
    #[test]
    fn falsify_ge_002_positive_monotonicity() {
        let x = Tensor::new(Array1::from(vec![0.1, 0.5, 1.0, 2.0, 5.0, 10.0]), false);
        let y = gelu(&x);
        let data = y.data();
        for i in 1..data.len() {
            let (prev, curr) = (data[i - 1], data[i]);
            assert!(
                curr > prev,
                "FALSIFIED GE-002: gelu not monotonic: [{i}]={} not > [{}]={}",
                curr,
                i - 1,
                prev
            );
        }
    }

    /// GE-003: GELU maps 0 to (approximately) 0.
    #[test]
    fn falsify_ge_003_zero_preservation() {
        let x = Tensor::new(Array1::from(vec![0.0]), false);
        let y = gelu(&x);
        let at_zero = y.data()[0];
        assert!(at_zero.abs() < 1e-7, "FALSIFIED GE-003: gelu(0) = {}", at_zero);
    }

    /// GE-006: large positive inputs pass through almost unchanged and large
    /// negative inputs map to almost zero (tolerance 0.01).
    #[test]
    fn falsify_ge_006_large_input_stability() {
        let x = Tensor::new(Array1::from(vec![10.0, 50.0, -10.0, -50.0]), false);
        let y = gelu(&x);
        let d = y.data();
        let (pos_10, pos_50, neg_10, neg_50) = (d[0], d[1], d[2], d[3]);
        assert!((pos_10 - 10.0).abs() < 0.01, "FALSIFIED GE-006: gelu(10) = {}", pos_10);
        assert!((pos_50 - 50.0).abs() < 0.01, "FALSIFIED GE-006: gelu(50) = {}", pos_50);
        assert!(neg_10.abs() < 0.01, "FALSIFIED GE-006: gelu(-10) = {}", neg_10);
        assert!(neg_50.abs() < 0.01, "FALSIFIED GE-006: gelu(-50) = {}", neg_50);
    }

    /// GE-005: `trueno::gelu_scalar` stays within 0.005 of the tanh formula
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` on the grid
    /// -10.0, -9.9, ..., 10.0.
    #[test]
    fn falsify_ge_005_tanh_approx_accuracy() {
        use std::f32::consts::FRAC_2_PI;
        let c = FRAC_2_PI.sqrt();
        for x in (-100..=100).map(|step| step as f32 * 0.1) {
            let approx = trueno::gelu_scalar(x);
            let inner = c * (x + 0.044_715 * x * x * x);
            let reference = 0.5 * x * (1.0 + inner.tanh());
            let gap = (approx - reference).abs();
            assert!(
                gap < 0.005,
                "FALSIFIED GE-005: |gelu_approx({x}) - gelu_exact({x})| = {}",
                gap
            );
        }
    }

    /// Property-based versions of GE-001, GE-002 and GE-006.
    mod ge_proptest_falsify {
        use super::*;
        use ndarray::Array1;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(500))]
            #[test]
            fn falsify_ge_001_prop_non_negativity(x in 0.0_f32..1000.0) {
                let t = Tensor::new(Array1::from(vec![x]), false);
                let y = gelu(&t);
                let out = y.data()[0];
                prop_assert!(out >= 0.0, "FALSIFIED GE-001-prop: gelu({x}) = {} < 0", out);
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(300))]
            #[test]
            fn falsify_ge_002_prop_monotonic_positive(
                a in 0.001_f32..100.0,
                b in 0.001_f32..100.0,
            ) {
                // Equal draws carry no ordering information and are passed over.
                if a != b {
                    let lo = a.min(b);
                    let hi = a.max(b);
                    let t = Tensor::new(Array1::from(vec![lo, hi]), false);
                    let y = gelu(&t);
                    let d = y.data();
                    let (gelu_lo, gelu_hi) = (d[0], d[1]);
                    prop_assert!(gelu_hi > gelu_lo, "FALSIFIED GE-002-prop: gelu({hi})={} not > gelu({lo})={}", gelu_hi, gelu_lo);
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(200))]
            #[test]
            fn falsify_ge_006_prop_large_positive(x in 10.0_f32..500.0) {
                let t = Tensor::new(Array1::from(vec![x]), false);
                let y = gelu(&t);
                let gap = (y.data()[0] - x).abs();
                prop_assert!(
                    gap < 0.01,
                    "FALSIFIED GE-006-prop: |gelu({x}) - {x}| = {}",
                    gap
                );
            }
        }
    }
}