use provable_contracts::traits::{
ActivationKernelV1, AdamwKernelV1, AttentionKernelV1, CrossEntropyKernelV1,
FlashAttentionV1, GqaKernelV1, LayernormKernelV1, MatmulKernelV1, RmsnormKernelV1,
RopeKernelV1, SiluKernelV1, SoftmaxKernelV1, SwigluKernelV1,
};
/// Scalar reference implementations of the kernel contract traits,
/// used as simple oracles in the property tests below.
struct ReferenceKernels;
impl SoftmaxKernelV1 for ReferenceKernels {
fn softmax(&self, x: &[f32]) -> Vec<f32> {
if x.is_empty() {
return vec![];
}
        // Subtract the max before exponentiating for numerical stability.
        let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = x.iter().map(|&xi| (xi - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|&e| e / sum).collect()
}
}
impl ActivationKernelV1 for ReferenceKernels {
    // Tanh approximation of GELU:
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    fn gelu(&self, x: f32) -> Vec<f32> {
        let inner = (2.0_f32 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
        vec![0.5 * x * (1.0 + inner.tanh())]
}
fn relu(&self, x: f32) -> Vec<f32> {
vec![x.max(0.0)]
}
fn silu(&self, x: f32) -> Vec<f32> {
vec![x / (1.0 + (-x).exp())]
}
}
impl SiluKernelV1 for ReferenceKernels {
fn sigmoid(&self, x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect()
}
fn silu(&self, x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| v / (1.0 + (-v).exp())).collect()
}
}
impl RmsnormKernelV1 for ReferenceKernels {
    fn rmsnorm(&self, x: &[f32]) -> Vec<f32> {
        if x.is_empty() {
            return vec![];
        }
        let rms = (x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32).sqrt();
        // Note: the epsilon is added to the RMS itself, not inside the square
        // root as in the more common x / sqrt(mean(x^2) + eps) formulation.
        x.iter().map(|v| v / (rms + 1e-5)).collect()
    }
}
impl LayernormKernelV1 for ReferenceKernels {
    fn layernorm(&self, x: &[f32], gamma: &[f32]) -> Vec<f32> {
        if x.is_empty() {
            return vec![];
        }
        let mean = x.iter().sum::<f32>() / x.len() as f32;
        let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
        let std = (var + 1e-5).sqrt();
        // Entries past the end of gamma default to 1.0 (no scaling).
        x.iter()
            .enumerate()
            .map(|(i, v)| ((v - mean) / std) * gamma.get(i).copied().unwrap_or(1.0))
            .collect()
    }
    // Returns [mean, population variance].
    fn statistics(&self, x: &[f32]) -> Vec<f32> {
        let mean = x.iter().sum::<f32>() / x.len() as f32;
        let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
        vec![mean, var]
    }
}
impl CrossEntropyKernelV1 for ReferenceKernels {
    // Cross-entropy against (possibly soft) target distributions:
    // -sum_i(t_i * log_softmax(logits)_i).
    fn cross_entropy(&self, targets: &[f32], logits: &[f32]) -> Vec<f32> {
let log_sm = CrossEntropyKernelV1::log_softmax(self, logits);
let loss = -targets
.iter()
.zip(log_sm.iter())
.map(|(t, l)| t * l)
.sum::<f32>();
vec![loss]
}
    fn log_softmax(&self, x: &[f32]) -> Vec<f32> {
        // Max-shifted form for stability:
        // log_softmax(x)_i = x_i - max - log(sum_j exp(x_j - max)).
        let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = x.iter().map(|v| (v - max).exp()).sum::<f32>().ln();
        x.iter().map(|v| v - max - log_sum_exp).collect()
}
}
impl SwigluKernelV1 for ReferenceKernels {
fn silu(&self, x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| v / (1.0 + (-v).exp())).collect()
}
    // Simplified elementwise SwiGLU: SiLU(x ⊙ w) ⊙ (x ⊙ v). The bias
    // parameters are accepted for signature compatibility but unused.
    fn swiglu(&self, x: &[f32], w: &[f32], v: &[f32], _b: &[f32], _c: &[f32]) -> Vec<f32> {
        let gate: Vec<f32> = x.iter().zip(w.iter()).map(|(xi, wi)| xi * wi).collect();
        let silu_gate: Vec<f32> = gate.iter().map(|&g| g / (1.0 + (-g).exp())).collect();
        let value: Vec<f32> = x.iter().zip(v.iter()).map(|(xi, vi)| xi * vi).collect();
silu_gate
.iter()
.zip(value.iter())
.map(|(s, val)| s * val)
.collect()
}
}
#[test]
fn contract_traits_compile() {
let k = ReferenceKernels;
let out = k.softmax(&[1.0, 2.0, 3.0]);
let sum: f32 = out.iter().sum();
assert!((sum - 1.0).abs() < 1e-6, "softmax must sum to 1.0");
let gelu_zero = k.gelu(0.0);
assert!(gelu_zero[0].abs() < 1e-6, "GELU(0) = 0");
let relu_neg = k.relu(-1.0);
assert_eq!(relu_neg[0], 0.0, "ReLU(-1) = 0");
let relu_pos = k.relu(2.0);
assert_eq!(relu_pos[0], 2.0, "ReLU(2) = 2");
let silu_zero = ActivationKernelV1::silu(&k, 0.0);
assert!(silu_zero[0].abs() < 1e-6, "SiLU(0) = 0");
}
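// Supplementary check (added sketch): the expected constant below is
// computed from the tanh-based GELU approximation used above, not from
// the exact erf formulation.
#[test]
fn gelu_tanh_approximation_value() {
    let k = ReferenceKernels;
    let out = k.gelu(1.0);
    assert!((out[0] - 0.8412).abs() < 1e-3, "GELU(1) ≈ 0.8412 under the tanh approximation");
}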
#[test]
fn silu_kernel_v1_properties() {
let k = ReferenceKernels;
let sig = SiluKernelV1::sigmoid(&k, &[0.0]);
assert!((sig[0] - 0.5).abs() < 1e-6, "sigmoid(0) = 0.5");
let sig_wide = SiluKernelV1::sigmoid(&k, &[-10.0, 0.0, 10.0]);
for &v in &sig_wide {
assert!(v > 0.0 && v < 1.0, "sigmoid output must be in (0,1)");
}
let silu_zero = SiluKernelV1::silu(&k, &[0.0]);
assert!(silu_zero[0].abs() < 1e-6, "SiLU(0) = 0");
}
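// Supplementary property check: the logistic sigmoid is symmetric about
// (0, 0.5), i.e. sigmoid(-x) = 1 - sigmoid(x); a sketch over a few points.
#[test]
fn sigmoid_symmetry() {
    let k = ReferenceKernels;
    for &x in &[0.5f32, 1.0, 3.0] {
        let pos = SiluKernelV1::sigmoid(&k, &[x])[0];
        let neg = SiluKernelV1::sigmoid(&k, &[-x])[0];
        assert!((pos + neg - 1.0).abs() < 1e-6, "sigmoid(-x) = 1 - sigmoid(x)");
    }
}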
#[test]
fn rmsnorm_kernel_v1_properties() {
let k = ReferenceKernels;
let out = k.rmsnorm(&[3.0, 4.0]);
assert!(out.len() == 2);
let rms_out = (out.iter().map(|v| v * v).sum::<f32>() / out.len() as f32).sqrt();
assert!((rms_out - 1.0).abs() < 1e-3, "rmsnorm output should have ~unit RMS");
}
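// Supplementary property check: because x and its RMS scale by the same
// positive factor, rmsnorm should be scale-invariant up to the epsilon term.
#[test]
fn rmsnorm_scale_invariance() {
    let k = ReferenceKernels;
    let base = k.rmsnorm(&[3.0, 4.0]);
    let scaled = k.rmsnorm(&[6.0, 8.0]);
    for (a, b) in base.iter().zip(scaled.iter()) {
        assert!((a - b).abs() < 1e-4, "rmsnorm(c * x) ≈ rmsnorm(x) for c > 0");
    }
}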
#[test]
fn layernorm_kernel_v1_properties() {
let k = ReferenceKernels;
let out = k.layernorm(&[1.0, 2.0, 3.0], &[1.0, 1.0, 1.0]);
let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
assert!(mean.abs() < 1e-5, "layernorm output should be zero-mean");
let stats = k.statistics(&[2.0, 4.0, 6.0]);
assert_eq!(stats.len(), 2);
assert!((stats[0] - 4.0).abs() < 1e-6, "mean of [2,4,6] = 4");
assert!((stats[1] - 8.0 / 3.0).abs() < 1e-5, "var of [2,4,6] = 8/3");
}
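// Supplementary check: with unit gamma, layernorm output should have
// approximately unit variance as well as zero mean (up to the epsilon),
// verified here through the statistics kernel itself.
#[test]
fn layernorm_unit_variance() {
    let k = ReferenceKernels;
    let out = k.layernorm(&[1.0, 2.0, 3.0], &[1.0, 1.0, 1.0]);
    let stats = k.statistics(&out);
    assert!((stats[1] - 1.0).abs() < 1e-3, "layernorm output should have ~unit variance");
}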
#[test]
fn cross_entropy_kernel_v1_properties() {
let k = ReferenceKernels;
let lsm = CrossEntropyKernelV1::log_softmax(&k, &[1.0, 2.0, 3.0]);
assert_eq!(lsm.len(), 3);
for &v in &lsm {
assert!(v <= 0.0, "log_softmax values must be <= 0");
}
let sum_exp: f32 = lsm.iter().map(|v| v.exp()).sum();
assert!((sum_exp - 1.0).abs() < 1e-5, "exp(log_softmax) must sum to 1");
let targets = vec![0.0, 0.0, 1.0];
let logits = vec![1.0, 2.0, 3.0];
let loss = k.cross_entropy(&targets, &logits);
assert_eq!(loss.len(), 1);
assert!(loss[0] > 0.0, "cross-entropy loss must be positive");
}
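// Supplementary check: when the logits already concentrate essentially all
// probability mass on the target class, the loss should be close to zero.
#[test]
fn cross_entropy_confident_prediction() {
    let k = ReferenceKernels;
    let loss = k.cross_entropy(&[0.0, 0.0, 1.0], &[0.0, 0.0, 20.0]);
    assert!(loss[0] >= 0.0 && loss[0] < 1e-6, "confident correct prediction gives ~0 loss");
}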
#[test]
fn swiglu_kernel_v1_properties() {
let k = ReferenceKernels;
let silu_zero = SwigluKernelV1::silu(&k, &[0.0]);
assert!(silu_zero[0].abs() < 1e-6, "SwigluKernelV1::silu(0) = 0");
let x = vec![1.0, 2.0];
    let w = vec![1.0, 1.0];
    let v = vec![1.0, 1.0];
    let b = vec![0.0, 0.0];
    let c = vec![0.0, 0.0];
    let out = k.swiglu(&x, &w, &v, &b, &c);
assert_eq!(out.len(), 2);
for (i, &xi) in x.iter().enumerate() {
let expected = (xi / (1.0 + (-xi).exp())) * xi;
assert!(
(out[i] - expected).abs() < 1e-5,
"swiglu with identity weights: element {i}"
);
}
}
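// Supplementary property check: softmax is invariant under adding a
// constant to every input, which is exactly the identity the
// max-subtraction in the implementation relies on.
#[test]
fn softmax_shift_invariance() {
    let k = ReferenceKernels;
    let a = k.softmax(&[1.0, 2.0, 3.0]);
    let b = k.softmax(&[101.0, 102.0, 103.0]);
    for (x, y) in a.iter().zip(b.iter()) {
        assert!((x - y).abs() < 1e-6, "softmax(x + c) = softmax(x)");
    }
}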
impl AttentionKernelV1 for ReferenceKernels {
    // Scaled dot-product attention. The reference assumes a square layout:
    // n rows of dimension d, flattened row-major with n = sqrt(len).
    fn attention(&self, q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
        let n = (q.len() as f32).sqrt() as usize;
        if n == 0 {
            return vec![];
        }
        let d = q.len() / n;
        let scale = (d as f32).sqrt();
        let mut out = vec![0.0f32; n * d];
        for i in 0..n {
            // Scaled scores: q_i . k_j / sqrt(d).
            let mut scores = vec![0.0f32; n];
            for j in 0..n {
                for kk in 0..d {
                    scores[j] += q[i * d + kk] * k[j * d + kk];
                }
                scores[j] /= scale;
            }
            // Numerically stable softmax: the max must be taken over the
            // *scaled* scores, the same values the exponent sees.
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            for j in 0..n {
                for kk in 0..d {
                    out[i * d + kk] += (exps[j] / sum) * v[j * d + kk];
                }
            }
        }
        out
    }
}
impl FlashAttentionV1 for ReferenceKernels {
    // The reference oracle simply delegates to plain attention; a tiled,
    // online-softmax implementation must reproduce this output.
    fn flash_attention(&self, q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
        AttentionKernelV1::attention(self, q, k, v)
    }
}
impl GqaKernelV1 for ReferenceKernels {
    // The reference likewise delegates to plain attention (the degenerate
    // single-group case of grouped-query attention).
    fn gqa(&self, q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
        AttentionKernelV1::attention(self, q, k, v)
    }
}
impl MatmulKernelV1 for ReferenceKernels {
    // Naive O(n^3) multiply of two square n x n matrices stored row-major,
    // with n inferred from the flattened length.
    fn matmul(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        let n = (a.len() as f32).sqrt() as usize;
        if n == 0 {
            return vec![];
        }
let mut c = vec![0.0f32; n * n];
for i in 0..n {
for j in 0..n {
for k in 0..n {
c[i * n + j] += a[i * n + k] * b[k * n + j];
}
}
}
c
}
    // Simplified dequantized reduction: sum the values, then apply the scale.
    fn quantized_dot(&self, b: &[f32], s_b: f32) -> Vec<f32> {
        vec![b.iter().sum::<f32>() * s_b]
}
}
impl RopeKernelV1 for ReferenceKernels {
    fn rope(&self, x: &[f32], m: &[f32]) -> Vec<f32> {
        let pos = m.first().copied().unwrap_or(0.0);
        let mut out = x.to_vec();
        // Rotate each (even, odd) pair by theta = pos * 10000^(-i/len),
        // the standard RoPE frequency schedule for pair index i/2.
        for i in (0..x.len()).step_by(2) {
            if i + 1 < x.len() {
                let theta = pos / 10000_f32.powf(i as f32 / x.len() as f32);
let (sin_t, cos_t) = theta.sin_cos();
out[i] = x[i] * cos_t - x[i + 1] * sin_t;
out[i + 1] = x[i] * sin_t + x[i + 1] * cos_t;
}
}
out
}
}
impl AdamwKernelV1 for ReferenceKernels {
    // First moment at step t = 1 with zero-initialized state:
    // m_1 = beta1 * m_0 + (1 - beta1) * g, with beta1 = 0.9 and m_0 = 0.
    fn adam_moments(&self, g_t: &[f32]) -> Vec<f32> {
        g_t.iter().map(|g| 0.9 * 0.0 + 0.1 * g).collect()
    }
    // Second moment at step t = 1: v_1 = beta2 * v_0 + (1 - beta2) * g^2,
    // with beta2 = 0.999 and v_0 = 0.
    fn adam_variance(&self, g_t: &[f32]) -> Vec<f32> {
        g_t.iter().map(|g| 0.999 * 0.0 + 0.001 * g * g).collect()
    }
    // Bias correction for the first moment at t = 1: m_hat = m / (1 - beta1).
    fn bias_correction(&self, input: &[f32]) -> Vec<f32> {
        input.iter().map(|v| v / (1.0 - 0.9)).collect()
    }
    // Decoupled weight-decay step only:
    // theta <- theta * (1 - lr * lambda), with lr * lambda = 0.001.
    fn weight_update(&self, theta: &[f32]) -> Vec<f32> {
        theta.iter().map(|t| t - 0.001 * t).collect()
    }
}
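// Supplementary check for the AdamW reference above, which models the
// first optimizer step with zero-initialized moments: m_1 = 0.1 * g, and
// bias correction at t = 1 divides by 1 - beta1, recovering g.
#[test]
fn adamw_first_step_moments() {
    let k = ReferenceKernels;
    let m = k.adam_moments(&[1.0]);
    assert!((m[0] - 0.1).abs() < 1e-6, "m_1 = (1 - 0.9) * g");
    let corrected = k.bias_correction(&m);
    assert!((corrected[0] - 1.0).abs() < 1e-6, "bias-corrected m_1 recovers g at t = 1");
}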
#[test]
fn attention_output_size() {
let k = ReferenceKernels;
let out = AttentionKernelV1::attention(&k, &[1.0; 4], &[1.0; 4], &[1.0; 4]);
assert_eq!(out.len(), 4);
}
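// Supplementary check: the FlashAttentionV1 and GqaKernelV1 references
// delegate to AttentionKernelV1, so all three must agree exactly.
#[test]
fn flash_and_gqa_match_attention() {
    let k = ReferenceKernels;
    let q = [1.0f32, 2.0, 3.0, 4.0];
    let kv = [0.5f32, 1.0, 1.5, 2.0];
    let v = [1.0f32, 0.0, 0.0, 1.0];
    let base = AttentionKernelV1::attention(&k, &q, &kv, &v);
    assert_eq!(base, FlashAttentionV1::flash_attention(&k, &q, &kv, &v));
    assert_eq!(base, GqaKernelV1::gqa(&k, &q, &kv, &v));
}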
#[test]
fn matmul_identity() {
let k = ReferenceKernels;
    let identity = vec![1.0, 0.0, 0.0, 1.0];
    let a = vec![1.0, 2.0, 3.0, 4.0];
let out = MatmulKernelV1::matmul(&k, &a, &identity);
assert!((out[0] - 1.0).abs() < 1e-5 && (out[3] - 4.0).abs() < 1e-5);
}
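// Supplementary check for the simplified quantized_dot reference, which
// sums the values and then applies the scale factor.
#[test]
fn quantized_dot_scales_sum() {
    let k = ReferenceKernels;
    let out = MatmulKernelV1::quantized_dot(&k, &[1.0, 2.0, 3.0], 0.5);
    assert!((out[0] - 3.0).abs() < 1e-6, "sum([1, 2, 3]) * 0.5 = 3.0");
}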
#[test]
fn rope_preserves_norm() {
let k = ReferenceKernels;
let x = vec![1.0, 0.0, 0.0, 1.0];
let out = RopeKernelV1::rope(&k, &x, &[1.0]);
let norm_in: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
let norm_out: f32 = out.iter().map(|v| v * v).sum::<f32>().sqrt();
assert!((norm_in - norm_out).abs() < 1e-5);
}
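// Supplementary check: at position 0 every rotation angle is zero, so
// RoPE must act as the identity on its input.
#[test]
fn rope_identity_at_position_zero() {
    let k = ReferenceKernels;
    let x = vec![1.0, 2.0, 3.0, 4.0];
    let out = RopeKernelV1::rope(&k, &x, &[0.0]);
    assert_eq!(out, x, "rope at position 0 is the identity");
}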