use crate::autograd::Tensor;
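/// Rectified Linear Unit, applied element-wise: `relu(x) = max(0, x)`.
///
/// The `activation-kernel-v1` contract's precondition is checked before
/// delegating to `Tensor::relu`.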
#[provable_contracts_macros::contract("activation-kernel-v1", equation = "relu")]
#[must_use]
pub fn relu(x: &Tensor) -> Tensor {
contract_pre_relu!(x.data());
x.relu()
}
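/// Scalar ReLU: `max(0, x)`. Thin wrapper over `trueno::relu_scalar`.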
#[inline]
#[must_use]
pub fn relu_scalar(x: f32) -> f32 {
trueno::relu_scalar(x)
}
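/// Leaky ReLU, applied element-wise: `x` if `x > 0`, otherwise
/// `negative_slope * x`. A common choice for `negative_slope` is `0.01`;
/// the value is caller-supplied and not validated here.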
#[must_use]
pub fn leaky_relu(x: &Tensor, negative_slope: f32) -> Tensor {
let data: Vec<f32> = x
.data()
.iter()
.map(|&v| if v > 0.0 { v } else { negative_slope * v })
.collect();
Tensor::from_vec(data, x.shape())
}
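/// Logistic sigmoid, applied element-wise: `sigmoid(x) = 1 / (1 + e^(-x))`.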
#[provable_contracts_macros::contract("silu-kernel-v1", equation = "sigmoid")]
#[must_use]
pub fn sigmoid(x: &Tensor) -> Tensor {
contract_pre_sigmoid!(x.data());
x.sigmoid()
}
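/// Scalar sigmoid. Thin wrapper over `trueno::sigmoid_scalar`.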
#[inline]
#[must_use]
pub fn sigmoid_scalar(x: f32) -> f32 {
trueno::sigmoid_scalar(x)
}
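/// `f64` scalar sigmoid, computed directly as `1 / (1 + e^(-x))`.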
#[inline]
#[must_use]
pub fn sigmoid_scalar_f64(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
}
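/// SiLU (swish), applied element-wise: `silu(x) = x * sigmoid(x)`.
///
/// Pre- and post-conditions of the `silu-kernel-v1` contract are checked
/// around the computation.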
#[provable_contracts_macros::contract("silu-kernel-v1", equation = "silu")]
#[must_use]
pub fn silu(x: &Tensor) -> Tensor {
contract_pre_silu!(x.data());
let data: Vec<f32> = x.data().iter().map(|&v| trueno::silu_scalar(v)).collect();
let result = Tensor::from_vec(data, x.shape());
contract_post_silu!(result.data());
result
}
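/// Scalar SiLU: `x * sigmoid(x)`. Thin wrapper over `trueno::silu_scalar`.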
#[inline]
#[must_use]
pub fn silu_scalar(x: f32) -> f32 {
trueno::silu_scalar(x)
}
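/// SwiGLU gating, applied element-wise: `swiglu(x, gate) = x * silu(gate)`.
///
/// `x` and `gate` must have the same number of elements; the result takes
/// `x`'s shape.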
#[must_use]
pub fn swiglu(x: &Tensor, gate: &Tensor) -> Tensor {
contract_pre_swiglu!(x.data());
let src_x = x.data();
let src_g = gate.data();
debug_assert_eq!(src_x.len(), src_g.len(), "swiglu: x and gate must have the same length");
let data: Vec<f32> = src_x
.iter()
.zip(src_g.iter())
.map(|(&xi, &gi)| xi * trueno::silu_scalar(gi))
.collect();
let result = Tensor::from_vec(data, x.shape());
contract_post_swiglu!(result.data());
result
}
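/// Scalar SwiGLU: `x * gate * sigmoid(gate)`, written here as
/// `x * gate / (1 + e^(-gate))`.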
#[inline]
#[must_use]
pub fn swiglu_scalar(x: f32, gate: f32) -> f32 {
x * gate / (1.0 + (-gate).exp())
}
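/// Numerically stable softmax over a 1-D slice, delegating to
/// `trueno::blis::softmax::softmax_1d_alloc`. The softmax contract is
/// checked on both input and output; e.g. `softmax_1d(&[0.0, 0.0])`
/// yields `[0.5, 0.5]`.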
#[must_use]
pub fn softmax_1d(logits: &[f32]) -> Vec<f32> {
contract_pre_softmax!(logits);
let result = trueno::blis::softmax::softmax_1d_alloc(logits);
contract_post_softmax!(&result);
result
}
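/// `f64` softmax over a 1-D slice using the max-subtraction trick:
/// `softmax(x)_i = e^(x_i - max(x)) / sum_j e^(x_j - max(x))`.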
#[must_use]
pub fn softmax_1d_f64(logits: &[f64]) -> Vec<f64> {
let n = logits.len();
let mut out = vec![0.0f64; n];
let mut max_val = f64::NEG_INFINITY;
for &v in logits {
max_val = max_val.max(v);
}
let mut sum = 0.0f64;
for i in 0..n {
let e = (logits[i] - max_val).exp();
out[i] = e;
sum += e;
}
let inv_sum = 1.0 / sum;
for i in 0..n {
out[i] *= inv_sum;
}
out
}
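/// Numerically stable log-softmax over a 1-D slice:
/// `log_softmax(x)_i = x_i - max(x) - ln(sum_j e^(x_j - max(x)))`.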
#[must_use]
pub fn log_softmax_1d(logits: &[f32]) -> Vec<f32> {
contract_pre_log_softmax!(logits);
let n = logits.len();
let mut out = vec![0.0f32; n];
let mut max_val = f32::NEG_INFINITY;
for &v in logits {
max_val = max_val.max(v);
}
let mut sum_exp = 0.0f32;
for &v in logits {
sum_exp += (v - max_val).exp();
}
let log_sum_exp = sum_exp.ln();
for i in 0..n {
out[i] = logits[i] - max_val - log_sum_exp;
}
out
}
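/// Hyperbolic tangent, applied element-wise; delegates to `Tensor::tanh_`.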
#[must_use]
pub fn tanh(x: &Tensor) -> Tensor {
x.tanh_()
}
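/// GELU (Gaussian Error Linear Unit), applied element-wise via
/// `trueno::gelu_scalar`. Conceptually `gelu(x) = x * Phi(x)` with `Phi`
/// the standard normal CDF; whether that is computed exactly (erf) or via
/// the tanh approximation is up to the `trueno` kernel.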
#[provable_contracts_macros::contract("activation-kernel-v1", equation = "gelu")]
#[must_use]
pub fn gelu(x: &Tensor) -> Tensor {
contract_pre_gelu!(x.data());
let data: Vec<f32> = x.data().iter().map(|&v| trueno::gelu_scalar(v)).collect();
let result = Tensor::from_vec(data, x.shape());
contract_post_gelu!(result.data());
result
}
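/// Softmax over the last dimension of `x`, one row at a time.
///
/// `_dim` is currently ignored: only the last dimension is supported.
/// Each row uses the max-subtraction trick for numerical stability, and
/// the softmax contract's post-condition is checked per row.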
#[provable_contracts_macros::contract("softmax-kernel-v1", equation = "softmax")]
#[must_use]
pub fn softmax(x: &Tensor, _dim: i32) -> Tensor {
contract_pre_softmax!(x.data());
let shape = x.shape();
let last_dim = shape[shape.len() - 1];
let batch_size: usize = shape[..shape.len() - 1].iter().product();
let data = x.data();
let mut output = vec![0.0f32; data.len()];
for b in 0..batch_size {
let start = b * last_dim;
let row = &data[start..start + last_dim];
let out = &mut output[start..start + last_dim];
let mut max_val = f32::NEG_INFINITY;
for &v in row {
max_val = max_val.max(v);
}
let mut sum = 0.0f32;
for i in 0..last_dim {
let e = (row[i] - max_val).exp();
out[i] = e;
sum += e;
}
let inv_sum = 1.0 / sum;
for i in 0..last_dim {
out[i] *= inv_sum;
}
contract_post_softmax!(out);
}
Tensor::from_vec(output, shape)
}
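/// Log-softmax over the last dimension of `x`, one row at a time. As with
/// [`softmax`], `_dim` is currently ignored and only the last dimension is
/// supported.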
#[provable_contracts_macros::contract("cross-entropy-kernel-v1", equation = "log_softmax")]
#[must_use]
pub fn log_softmax(x: &Tensor, _dim: i32) -> Tensor {
let shape = x.shape();
let last_dim = shape[shape.len() - 1];
let batch_size: usize = shape[..shape.len() - 1].iter().product();
let data = x.data();
let mut output = vec![0.0f32; data.len()];
for b in 0..batch_size {
let start = b * last_dim;
let row = &data[start..start + last_dim];
let max_val = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let log_sum_exp: f32 = row.iter().map(|&v| (v - max_val).exp()).sum::<f32>().ln();
for j in 0..last_dim {
output[start + j] = row[j] - max_val - log_sum_exp;
}
}
Tensor::from_vec(output, shape)
}
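/// Inverted dropout: during training each element is zeroed with
/// probability `p` and survivors are scaled by `1 / (1 - p)`, so the
/// expected activation is unchanged. Returns `x` untouched when `training`
/// is false or `p == 0.0`. Assumes `0.0 <= p < 1.0`; `p == 1.0` would make
/// the scale non-finite.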
#[must_use]
pub fn dropout(x: &Tensor, p: f32, training: bool) -> Tensor {
if !training || p == 0.0 {
return x.clone();
}
use rand::Rng;
let mut rng = rand::rng();
let scale = 1.0 / (1.0 - p);
let data: Vec<f32> = x
.data()
.iter()
.map(|&v| {
if rng.random::<f32>() < p {
0.0
} else {
v * scale
}
})
.collect();
Tensor::from_vec(data, x.shape())
}
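/// Layer normalization over the trailing `weight.len()` elements of each
/// row, delegating to `trueno::blis::norms::layer_norm`. A single row uses
/// the allocating variant; batched inputs are normalized row by row into
/// one output buffer.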
#[must_use]
pub fn layer_norm(x: &Tensor, weight: &Tensor, bias: &Tensor, eps: f32) -> Tensor {
let shape = x.shape();
let data = x.data();
let weight_data = weight.data();
let bias_data = bias.data();
let norm_dim = weight_data.len();
let batch_size = data.len() / norm_dim;
if batch_size == 1 {
let output = trueno::blis::norms::layer_norm_alloc(data, weight_data, bias_data, eps);
return Tensor::from_vec(output, shape);
}
let mut output = vec![0.0f32; data.len()];
for b in 0..batch_size {
let start = b * norm_dim;
let slice = &data[start..start + norm_dim];
let out_slice = &mut output[start..start + norm_dim];
trueno::blis::norms::layer_norm(slice, weight_data, bias_data, eps, out_slice)
.expect("layer_norm: dimension mismatch (should be impossible)");
}
Tensor::from_vec(output, shape)
}
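/// RMS normalization over the trailing `weight.len()` elements of each
/// row, delegating to `trueno::blis::norms::rms_norm`. Same batching
/// strategy as [`layer_norm`], but with no bias term.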
#[must_use]
pub fn rms_norm(x: &Tensor, weight: &Tensor, eps: f32) -> Tensor {
let shape = x.shape();
let data = x.data();
let weight_data = weight.data();
let norm_dim = weight_data.len();
let batch_size = data.len() / norm_dim;
if batch_size == 1 {
let output = trueno::blis::norms::rms_norm_alloc(data, weight_data, eps);
return Tensor::from_vec(output, shape);
}
let mut output = vec![0.0f32; data.len()];
for b in 0..batch_size {
let start = b * norm_dim;
let slice = &data[start..start + norm_dim];
let out_slice = &mut output[start..start + norm_dim];
trueno::blis::norms::rms_norm(slice, weight_data, eps, out_slice)
.expect("rms_norm: dimension mismatch (should be impossible)");
}
Tensor::from_vec(output, shape)
}
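/// Affine transform `y = x · weightᵀ + bias`, matching the common
/// `[out_features, in_features]` weight layout (hence the transpose before
/// the matmul). When `bias` is provided it is broadcast across the rows of
/// the matmul output.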
#[must_use]
pub fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
let weight_t = weight.transpose();
let output = x.matmul(&weight_t);
match bias {
Some(b) => broadcast_add_1d(&output, b),
None => output,
}
}
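/// Adds `vector` to every row of the 2-D `matrix` (row-wise broadcast of
/// the bias in [`linear`]).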
fn broadcast_add_1d(matrix: &Tensor, vector: &Tensor) -> Tensor {
let (rows, cols) = (matrix.shape()[0], matrix.shape()[1]);
let m = matrix.data();
let v = vector.data();
let mut result = vec![0.0; rows * cols];
for i in 0..rows {
for j in 0..cols {
result[i * cols + j] = m[i * cols + j] + v[j];
}
}
Tensor::new(&result, &[rows, cols])
}
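/// Euclidean (L2) distance between two equal-length slices:
/// `sqrt(sum_i (a_i - b_i)^2)`.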
#[must_use]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "euclidean_distance: length mismatch");
a.iter()
.zip(b.iter())
.map(|(&x, &y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
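/// Cosine similarity of two slices, clamped to `[-1, 1]`. Returns `0.0`
/// for empty or length-mismatched inputs and when either norm is near
/// zero.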
#[must_use]
pub fn cosine_similarity_slice(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let mut dot = 0.0_f32;
let mut norm_a = 0.0_f32;
let mut norm_b = 0.0_f32;
for (&x, &y) in a.iter().zip(b.iter()) {
dot += x * y;
norm_a += x * x;
norm_b += y * y;
}
let denom = norm_a.sqrt() * norm_b.sqrt();
if denom < 1e-10 {
0.0
} else {
(dot / denom).clamp(-1.0, 1.0)
}
}
#[cfg(test)]
mod softmax_contract_tests {
use super::*;
#[test]
fn falsify_sm_001_sums_to_one() {
let cases: Vec<Vec<f32>> = vec![
vec![1.0, 2.0, 3.0],
vec![-10.0, 0.0, 10.0],
vec![100.0, 101.0, 102.0],
(0..100).map(|i| (i as f32 * 0.37).sin() * 5.0).collect(),
];
for (idx, logits) in cases.iter().enumerate() {
let probs = softmax_1d(logits);
let sum: f32 = probs.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-5,
"FALSIFIED SM-001: case {idx} sum={sum}"
);
}
}
#[test]
fn falsify_sm_001b_f64_sums_to_one() {
let logits: Vec<f64> = vec![1.0, 2.0, 3.0, -5.0, 10.0];
let probs = softmax_1d_f64(&logits);
let sum: f64 = probs.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"FALSIFIED SM-001b: f64 sum={sum}"
);
}
#[test]
fn falsify_sm_002_strictly_positive() {
let logits: Vec<f32> = (0..50).map(|i| (i as f32 - 25.0) * 2.0).collect();
let probs = softmax_1d(&logits);
for (i, &p) in probs.iter().enumerate() {
assert!(
p > 0.0,
"FALSIFIED SM-002: probs[{i}] = {p} not strictly positive"
);
}
}
#[test]
fn falsify_sm_003_order_preservation() {
let logits = vec![1.0f32, 5.0, 3.0, 2.0];
let probs = softmax_1d(&logits);
let input_argmax = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;
let output_argmax = probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;
assert_eq!(
input_argmax, output_argmax,
"FALSIFIED SM-003: argmax changed from {input_argmax} to {output_argmax}"
);
}
#[test]
fn falsify_sm_004_bounded_zero_one() {
let logits: Vec<f32> = (0..20).map(|i| (i as f32 * 1.7).sin() * 10.0).collect();
let probs = softmax_1d(&logits);
for (i, &p) in probs.iter().enumerate() {
assert!(
p > 0.0 && p < 1.0,
"FALSIFIED SM-004: probs[{i}] = {p} not in (0, 1)"
);
}
}
#[test]
fn falsify_sm_005_numerical_stability() {
let extreme = vec![1000.0f32, 1001.0, 1002.0];
let probs = softmax_1d(&extreme);
assert!(
probs.iter().all(|p| p.is_finite()),
"FALSIFIED SM-005: extreme inputs produced non-finite"
);
let sum: f32 = probs.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-4,
"FALSIFIED SM-005: extreme sum={sum}"
);
}
#[test]
fn falsify_sm_006_tensor_softmax_shape() {
let x = Tensor::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
let result = softmax(&x, -1);
assert_eq!(result.shape(), &[2, 3], "FALSIFIED SM-006: shape changed");
let data = result.data();
let row1_sum: f32 = data[0..3].iter().sum();
let row2_sum: f32 = data[3..6].iter().sum();
assert!(
(row1_sum - 1.0).abs() < 1e-5,
"FALSIFIED SM-006: row 1 sum={row1_sum}"
);
assert!(
(row2_sum - 1.0).abs() < 1e-5,
"FALSIFIED SM-006: row 2 sum={row2_sum}"
);
}
#[test]
fn falsify_sm_009_single_element() {
for x in [0.0_f32, 1.0, -1.0, 100.0, -100.0, f32::MIN_POSITIVE] {
let t = Tensor::new(&[x], &[1, 1]);
let result = softmax(&t, -1);
assert!(
(result.data()[0] - 1.0).abs() < 1e-6,
"FALSIFIED SM-009: softmax([{x}]) = {}, expected 1.0",
result.data()[0]
);
}
}
#[test]
fn falsify_sm_007_translation_invariance() {
let base = Tensor::new(&[1.0_f32, 3.0, -2.0, 0.5], &[1, 4]);
let base_probs = softmax(&base, -1);
for c in [100.0_f32, -100.0, 0.0, 42.0, -999.0] {
let shifted = Tensor::new(&[1.0 + c, 3.0 + c, -2.0 + c, 0.5 + c], &[1, 4]);
let shifted_probs = softmax(&shifted, -1);
for (i, (&orig, &shift)) in base_probs
.data()
.iter()
.zip(shifted_probs.data().iter())
.enumerate()
{
assert!(
(orig - shift).abs() < 1e-5,
"FALSIFIED SM-007: σ(x+{c})[{i}] = {shift} != σ(x)[{i}] = {orig}"
);
}
}
}
}
#[cfg(test)]
mod softmax_proptest_falsify {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
#[test]
fn falsify_sm_001_prop_sums_to_one(
logits in proptest::collection::vec(-100.0_f32..100.0, 1..64),
) {
let probs = softmax_1d(&logits);
let sum: f32 = probs.iter().sum();
prop_assert!(
(sum - 1.0).abs() < 1e-4,
"FALSIFIED SM-001-prop: sum={} for {} elements", sum, logits.len()
);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn falsify_sm_002_prop_positive(
logits in proptest::collection::vec(-500.0_f32..500.0, 2..32),
) {
let probs = softmax_1d(&logits);
for (i, &p) in probs.iter().enumerate() {
prop_assert!(
p >= 0.0,
"FALSIFIED SM-002-prop: probs[{}] = {} negative (n={})", i, p, logits.len()
);
prop_assert!(
p.is_finite(),
"FALSIFIED SM-002-prop: probs[{}] = {} non-finite", i, p
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn falsify_sm_003_prop_order_preservation(
logits in proptest::collection::vec(-50.0_f32..50.0, 2..32),
) {
// `windows(2)` would only catch adjacent duplicates; near-ties anywhere
// in the input can flip the argmax, so check every pair.
let has_dupes = logits.iter().enumerate().any(|(i, &a)| {
logits[i + 1..].iter().any(|&b| (a - b).abs() < 1e-10)
});
if has_dupes {
return Ok(());
}
let probs = softmax_1d(&logits);
let input_argmax = logits.iter().enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()).unwrap().0;
let output_argmax = probs.iter().enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()).unwrap().0;
prop_assert_eq!(
input_argmax, output_argmax,
"FALSIFIED SM-003-prop: argmax {} -> {} for {:?}", input_argmax, output_argmax, logits
);
}
}
}
#[cfg(test)]
mod gelu_contract_tests {
use super::*;
#[test]
fn falsify_ge_001_non_negativity() {
let x = Tensor::new(&[0.001, 0.1, 1.0, 5.0, 10.0, 100.0], &[6]);
let y = gelu(&x);
for (i, &val) in y.data().iter().enumerate() {
assert!(
val >= 0.0,
"FALSIFIED GE-001: gelu(positive)[{i}] = {val} < 0"
);
}
}
#[test]
fn falsify_ge_002_positive_monotonicity() {
let x = Tensor::new(&[0.1, 0.5, 1.0, 2.0, 5.0, 10.0], &[6]);
let y = gelu(&x);
let data = y.data();
for i in 1..data.len() {
assert!(
data[i] > data[i - 1],
"FALSIFIED GE-002: GELU not monotonic: [{i}]={} not > [{}]={}",
data[i],
i - 1,
data[i - 1]
);
}
}
#[test]
fn falsify_ge_003_zero_preservation() {
let x = Tensor::new(&[0.0], &[1]);
let y = gelu(&x);
assert!(
y.data()[0].abs() < 1e-7,
"FALSIFIED GE-003: GELU(0) = {}, expected 0",
y.data()[0]
);
}
#[test]
fn falsify_ge_006_large_input_stability() {
let x = Tensor::new(&[10.0, 50.0, -10.0, -50.0], &[4]);
let y = gelu(&x);
let data = y.data();
assert!(
(data[0] - 10.0).abs() < 0.01,
"FALSIFIED GE-006: GELU(10) = {}",
data[0]
);
assert!(
(data[1] - 50.0).abs() < 0.01,
"FALSIFIED GE-006: GELU(50) = {}",
data[1]
);
assert!(
data[2].abs() < 0.01,
"FALSIFIED GE-006: GELU(-10) = {}",
data[2]
);
assert!(
data[3].abs() < 0.01,
"FALSIFIED GE-006: GELU(-50) = {}",
data[3]
);
}
}
#[cfg(test)]
#[path = "functional_tests_bias_contract.rs"]
mod functional_tests_bias_contract;
#[cfg(test)]
#[path = "functional_tests_silu_contract.rs"]
mod functional_tests_silu_contract;
#[cfg(test)]
#[path = "functional_tests_swiglu_contract.rs"]
mod functional_tests_swiglu_contract;
#[cfg(test)]
#[path = "functional_tests_relu_contract.rs"]
mod functional_tests_relu_contract;
#[cfg(test)]
#[path = "functional_tests_sigmoid_contract.rs"]
mod functional_tests_sigmoid_contract;
include!("functional_include_01.rs");