aprender-core 0.29.2

Next-generation machine learning library in pure Rust
//! Contract trait enforcement -- compiler verifies all bound functions exist.
//!
//! Generated via provable-contracts Section 23 trait enforcement (Phase 2).
//!
//! Each `impl` below delegates to the real aprender function. If the function
//! signature ever drifts from the contract, this file fails to compile.
//!
//! Run with: `cargo test --test contract_traits`

use provable_contracts::traits::{
    ActivationKernelV1, AdamwKernelV1, AttentionKernelV1, CrossEntropyKernelV1, FlashAttentionV1,
    GqaKernelV1, LayernormKernelV1, MatmulKernelV1, RmsnormKernelV1, RopeKernelV1, SiluKernelV1,
    SoftmaxKernelV1, SwigluKernelV1,
};

/// Marker struct: aprender's scalar/slice kernel implementations satisfy
/// the provable-contracts trait signatures.
struct AprenderKernels;

// ---------------------------------------------------------------------------
// SoftmaxKernelV1 -- delegates to nn::functional::softmax_1d
// ---------------------------------------------------------------------------
impl SoftmaxKernelV1 for AprenderKernels {
    fn softmax(&self, x: &[f32]) -> Vec<f32> {
        aprender::nn::functional::softmax_1d(x)
    }
}

// ---------------------------------------------------------------------------
// ActivationKernelV1 -- gelu, relu, silu (scalar -> Vec<f32>)
// ---------------------------------------------------------------------------
impl ActivationKernelV1 for AprenderKernels {
    fn gelu(&self, x: f32) -> Vec<f32> {
        use aprender::autograd::Tensor;
        let t = Tensor::from_vec(vec![x], &[1]);
        aprender::nn::functional::gelu(&t).data().to_vec()
    }

    fn relu(&self, x: f32) -> Vec<f32> {
        vec![aprender::nn::functional::relu_scalar(x)]
    }

    fn silu(&self, x: f32) -> Vec<f32> {
        vec![aprender::nn::functional::silu_scalar(x)]
    }
}

// ---------------------------------------------------------------------------
// SiluKernelV1 -- sigmoid, silu (element-wise via scalar functions)
// ---------------------------------------------------------------------------
impl SiluKernelV1 for AprenderKernels {
    fn sigmoid(&self, x: &[f32]) -> Vec<f32> {
        x.iter()
            .map(|&xi| aprender::nn::functional::sigmoid_scalar(xi))
            .collect()
    }

    fn silu(&self, x: &[f32]) -> Vec<f32> {
        x.iter()
            .map(|&xi| aprender::nn::functional::silu_scalar(xi))
            .collect()
    }
}

// ---------------------------------------------------------------------------
// SwigluKernelV1 -- silu + swiglu (split-input convention: first half = x,
//                   second half = gate)
// ---------------------------------------------------------------------------
impl SwigluKernelV1 for AprenderKernels {
    fn silu(&self, x: &[f32]) -> Vec<f32> {
        x.iter()
            .map(|&xi| aprender::nn::functional::silu_scalar(xi))
            .collect()
    }

    fn swiglu(&self, x: &[f32], w: &[f32], v: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
        // Simplified: treat x as packed [x, gate], ignore W/V/b/c weight matrices,
        // split x as [x_part, gate] and compute SiLU(x_part) * gate.
        let _ = (w, v, b, c);
        let half = x.len() / 2;
        let x_part = &x[..half];
        let gate = &x[half..];
        x_part
            .iter()
            .zip(gate.iter())
            .map(|(&xi, &gi)| aprender::nn::functional::swiglu_scalar(xi, gi))
            .collect()
    }
}

// ---------------------------------------------------------------------------
// CrossEntropyKernelV1 -- log_softmax (direct), cross_entropy (slice-based)
// ---------------------------------------------------------------------------
impl CrossEntropyKernelV1 for AprenderKernels {
    fn cross_entropy(&self, targets: &[f32], logits: &[f32]) -> Vec<f32> {
        // Returns single-element vec with the loss value.
        let log_probs = aprender::nn::functional::log_softmax_1d(logits);
        let loss: f32 = targets
            .iter()
            .zip(log_probs.iter())
            .filter(|(&t, _)| t > 0.0)
            .map(|(&t, &lp)| -t * lp)
            .sum();
        vec![loss]
    }

    fn log_softmax(&self, x: &[f32]) -> Vec<f32> {
        aprender::nn::functional::log_softmax_1d(x)
    }
}

// ---------------------------------------------------------------------------
// RmsnormKernelV1 -- rms_norm with unit weights and default eps
// ---------------------------------------------------------------------------
impl RmsnormKernelV1 for AprenderKernels {
    fn rmsnorm(&self, x: &[f32]) -> Vec<f32> {
        use aprender::autograd::Tensor;
        let n = x.len();
        let xt = Tensor::from_vec(x.to_vec(), &[n]);
        let weight = Tensor::from_vec(vec![1.0f32; n], &[n]);
        let eps = 1e-6_f32;
        aprender::nn::functional::rms_norm(&xt, &weight, eps)
            .data()
            .to_vec()
    }
}

// ---------------------------------------------------------------------------
// LayernormKernelV1 -- layer_norm with unit weight/zero bias, statistics
// ---------------------------------------------------------------------------
impl LayernormKernelV1 for AprenderKernels {
    fn layernorm(&self, x: &[f32], gamma: &[f32]) -> Vec<f32> {
        use aprender::autograd::Tensor;
        let n = x.len();
        let xt = Tensor::from_vec(x.to_vec(), &[n]);
        let weight = Tensor::from_vec(gamma.to_vec(), &[n]);
        let bias = Tensor::from_vec(vec![0.0f32; n], &[n]);
        let eps = 1e-5_f32;
        aprender::nn::functional::layer_norm(&xt, &weight, &bias, eps)
            .data()
            .to_vec()
    }

    fn statistics(&self, x: &[f32]) -> Vec<f32> {
        // Returns [mean, variance]
        let n = x.len() as f32;
        let mean: f32 = x.iter().sum::<f32>() / n;
        let var: f32 = x.iter().map(|&xi| (xi - mean) * (xi - mean)).sum::<f32>() / n;
        vec![mean, var]
    }
}

// ---------------------------------------------------------------------------
// RopeKernelV1 -- Rotary Position Embeddings (CPU reference impl).
// Pairs (x_{2k}, x_{2k+1}) rotated by theta_k at position m.
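// For each pair, with theta_k = base^(-2k/d):
//   out_{2k}   = x_{2k} * cos(m * theta_k) - x_{2k+1} * sin(m * theta_k)
//   out_{2k+1} = x_{2k} * sin(m * theta_k) + x_{2k+1} * cos(m * theta_k)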
// ---------------------------------------------------------------------------
impl RopeKernelV1 for AprenderKernels {
    fn rope(&self, x: &[f32], m: &[f32]) -> Vec<f32> {
        let d = x.len();
        let pos = if m.is_empty() { 0.0_f32 } else { m[0] };
        let base: f32 = 10_000.0;
        let mut output = vec![0.0f32; d];
        for k in 0..d / 2 {
            let theta = base.powf(-2.0 * k as f32 / d as f32);
            let angle = pos * theta;
            let cos_a = angle.cos();
            let sin_a = angle.sin();
            output[2 * k] = x[2 * k] * cos_a - x[2 * k + 1] * sin_a;
            output[2 * k + 1] = x[2 * k] * sin_a + x[2 * k + 1] * cos_a;
        }
        output
    }
}

// ---------------------------------------------------------------------------
// AdamwKernelV1 -- AdamW optimizer moments/variance/correction/update.
// Implements the four sub-equations of the AdamW algorithm.
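// With the hard-coded beta1 = 0.9 and beta2 = 0.999 below, the sub-equations are:
//   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat   = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
//   theta_t = theta_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps) + wd * theta_{t-1})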
// ---------------------------------------------------------------------------
impl AdamwKernelV1 for AprenderKernels {
    fn adam_moments(&self, g_t: &[f32]) -> Vec<f32> {
        // Convention: g_t contains [gradients, m_prev] packed together
        let half = g_t.len() / 2;
        let grads = &g_t[..half];
        let m_prev = &g_t[half..];
        let beta1: f32 = 0.9;
        grads
            .iter()
            .zip(m_prev.iter())
            .map(|(&gi, &mi)| beta1 * mi + (1.0 - beta1) * gi)
            .collect()
    }

    fn adam_variance(&self, g_t: &[f32]) -> Vec<f32> {
        // Convention: g_t contains [gradients, v_prev] packed together
        let half = g_t.len() / 2;
        let grads = &g_t[..half];
        let v_prev = &g_t[half..];
        let beta2: f32 = 0.999;
        grads
            .iter()
            .zip(v_prev.iter())
            .map(|(&gi, &vi)| beta2 * vi + (1.0 - beta2) * gi * gi)
            .collect()
    }

    fn bias_correction(&self, input: &[f32]) -> Vec<f32> {
        let half = input.len() / 2;
        let m = &input[..half];
        let v = &input[half..];
        let beta1: f32 = 0.9;
        let beta2: f32 = 0.999;
        let t = 1_i32;
        let bc1 = 1.0 / (1.0 - beta1.powi(t));
        let bc2 = 1.0 / (1.0 - beta2.powi(t));
        let mut result = Vec::with_capacity(input.len());
        result.extend(m.iter().map(|&mi| mi * bc1));
        result.extend(v.iter().map(|&vi| vi * bc2));
        result
    }

    fn weight_update(&self, theta: &[f32]) -> Vec<f32> {
        let third = theta.len() / 3;
        let weights = &theta[..third];
        let m_hat = &theta[third..2 * third];
        let v_hat = &theta[2 * third..];
        let lr: f32 = 0.001;
        let eps: f32 = 1e-8;
        let wd: f32 = 0.01;
        weights
            .iter()
            .zip(m_hat.iter().zip(v_hat.iter()))
            .map(|(&ti, (&mi, &vi))| ti - lr * (mi / (vi.sqrt() + eps) + wd * ti))
            .collect()
    }
}

// ---------------------------------------------------------------------------
// AttentionKernelV1 -- naive scaled dot-product attention (reference scalar)
// Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
// Assumes square matrices flattened as 1D: n*d elements each.
// ---------------------------------------------------------------------------
impl AttentionKernelV1 for AprenderKernels {
    fn attention(&self, q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
        naive_attention(q, k, v)
    }
}

// ---------------------------------------------------------------------------
// FlashAttentionV1 -- mathematically identical to standard attention;
// the "flash" part is an IO optimization, not a math difference.
// ---------------------------------------------------------------------------
impl FlashAttentionV1 for AprenderKernels {
    fn flash_attention(&self, q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
        naive_attention(q, k, v)
    }
}

// ---------------------------------------------------------------------------
// GqaKernelV1 -- GQA with num_kv_heads = num_heads is standard attention.
// Reference scalar implementation.
// ---------------------------------------------------------------------------
impl GqaKernelV1 for AprenderKernels {
    fn gqa(&self, q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
        naive_attention(q, k, v)
    }
}

// ---------------------------------------------------------------------------
// MatmulKernelV1 -- naive O(n^3) matmul on flattened square matrices +
// quantized dot product reference.
// ---------------------------------------------------------------------------
impl MatmulKernelV1 for AprenderKernels {
    fn matmul(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        naive_matmul(a, b)
    }

    fn quantized_dot(&self, b: &[f32], s_b: f32) -> Vec<f32> {
        // Single-slice convention: b already carries the a-side scale, so the
        // result is s_b * sum(b).
        let dot: f32 = b.iter().sum();
        vec![s_b * dot]
    }
}

// ===========================================================================
// Shared reference implementations
// ===========================================================================

/// Naive scaled dot-product attention on flattened square matrices.
/// Assumes q, k, v are all n*d elements (n rows, d cols).
fn naive_attention(q: &[f32], k: &[f32], v: &[f32]) -> Vec<f32> {
    let total = q.len();
    // Assume square matrices (n = d): infer n = sqrt(total), then d = total / n.
    let n = (total as f32).sqrt() as usize;
    let d = if n > 0 { total / n } else { return vec![] };

    // Q * K^T -> scores[n][n], scaled by 1/sqrt(d_k)
    let scale = 1.0 / (d as f32).sqrt();
    let mut scores = vec![0.0f32; n * n];
    for i in 0..n {
        for j in 0..n {
            let mut dot = 0.0f32;
            for kk in 0..d {
                dot += q[i * d + kk] * k[j * d + kk];
            }
            scores[i * n + j] = dot * scale;
        }
    }

    // Row-wise softmax
    for i in 0..n {
        let row = &mut scores[i * n..(i + 1) * n];
        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for v in row.iter_mut() {
            *v = (*v - max_val).exp();
            sum += *v;
        }
        for v in row.iter_mut() {
            *v /= sum;
        }
    }

    // Attn_weights * V -> output[n][d]
    let d_v = if n > 0 { v.len() / n } else { 0 };
    let mut output = vec![0.0f32; n * d_v];
    for i in 0..n {
        for j in 0..d_v {
            let mut acc = 0.0f32;
            for kk in 0..n {
                acc += scores[i * n + kk] * v[kk * d_v + j];
            }
            output[i * d_v + j] = acc;
        }
    }
    output
}

/// Naive O(n^3) matmul on flattened square matrices.
/// Assumes a = m*p elements, b = p*n elements. For square: m=p=n=sqrt(len).
fn naive_matmul(a: &[f32], b: &[f32]) -> Vec<f32> {
    let n = (a.len() as f32).sqrt() as usize;
    if n == 0 {
        return vec![];
    }
    let m = n;
    let p = a.len() / m;
    let bn = b.len() / p;
    let mut c = vec![0.0f32; m * bn];
    for i in 0..m {
        for j in 0..bn {
            let mut acc = 0.0f32;
            for kk in 0..p {
                acc += a[i * p + kk] * b[kk * bn + j];
            }
            c[i * bn + j] = acc;
        }
    }
    c
}

// ---------------------------------------------------------------------------
// Compile-time enforcement tests -- each test instantiates the trait to
// guarantee the compiler has verified all method signatures.
// ---------------------------------------------------------------------------

#[test]
fn softmax_trait_compiles() {
    let k = AprenderKernels;
    let out = SoftmaxKernelV1::softmax(&k, &[1.0, 2.0, 3.0]);
    assert_eq!(out.len(), 3);
    let sum: f32 = out.iter().sum();
    assert!((sum - 1.0).abs() < 1e-6, "softmax must sum to 1.0");
}
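
// Extra property check (not contract-generated): softmax is shift-invariant,
// softmax(x + c) == softmax(x), because the constant cancels in the normalizer.
#[test]
fn softmax_shift_invariance() {
    let k = AprenderKernels;
    let a = SoftmaxKernelV1::softmax(&k, &[1.0, 2.0, 3.0]);
    let b = SoftmaxKernelV1::softmax(&k, &[11.0, 12.0, 13.0]);
    for (x, y) in a.iter().zip(b.iter()) {
        assert!((x - y).abs() < 1e-5, "softmax(x + c) == softmax(x)");
    }
}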

#[test]
fn activation_trait_compiles() {
    let k = AprenderKernels;

    let gelu_out = ActivationKernelV1::gelu(&k, 0.0);
    assert_eq!(gelu_out.len(), 1);
    assert!(gelu_out[0].abs() < 1e-6, "GELU(0) = 0");

    let relu_out = ActivationKernelV1::relu(&k, -1.0);
    assert_eq!(relu_out[0], 0.0, "ReLU(-1) = 0");

    let relu_pos = ActivationKernelV1::relu(&k, 1.0);
    assert_eq!(relu_pos[0], 1.0, "ReLU(1) = 1");

    let silu_out = ActivationKernelV1::silu(&k, 0.0);
    assert_eq!(silu_out.len(), 1);
    assert!(silu_out[0].abs() < 1e-6, "SiLU(0) = 0");
}

#[test]
fn silu_trait_compiles() {
    let k = AprenderKernels;
    let input = &[-2.0, 0.0, 2.0];

    let sig = SiluKernelV1::sigmoid(&k, input);
    assert_eq!(sig.len(), 3);
    assert!((sig[1] - 0.5).abs() < 1e-6, "sigmoid(0) = 0.5");

    let silu = SiluKernelV1::silu(&k, input);
    assert_eq!(silu.len(), 3);
    assert!(silu[1].abs() < 1e-6, "SiLU(0) = 0");
}
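
// Extra property check (not contract-generated), assuming the standard
// definition SiLU(x) = x * sigmoid(x): the two trait methods must agree.
#[test]
fn silu_is_x_times_sigmoid() {
    let k = AprenderKernels;
    let input = &[-2.0, -0.5, 0.0, 0.5, 2.0];
    let sig = SiluKernelV1::sigmoid(&k, input);
    let silu = SiluKernelV1::silu(&k, input);
    for i in 0..input.len() {
        assert!(
            (silu[i] - input[i] * sig[i]).abs() < 1e-6,
            "SiLU(x) = x * sigmoid(x) at idx={i}"
        );
    }
}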

#[test]
fn swiglu_trait_compiles() {
    let k = AprenderKernels;

    let silu = SwigluKernelV1::silu(&k, &[0.0, 1.0]);
    assert_eq!(silu.len(), 2);

    // x is packed as [x0, x1, gate0, gate1]; the extra params are dummy weight matrices
    let swiglu = SwigluKernelV1::swiglu(&k, &[1.0, 2.0, 0.0, 1.0], &[], &[], &[], &[]);
    assert_eq!(swiglu.len(), 2);
    // swiglu(x=1, gate=0) = SiLU(1) * 0 = 0
    assert!(swiglu[0].abs() < 1e-6, "SwiGLU(x=1, gate=0) = 0");
}
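
// Worked example of the packing convention (not contract-generated): with a
// unit gate, SiLU(x) * 1 = SiLU(x), so swiglu on [x0, x1, 1, 1] must match
// silu on [x0, x1].
#[test]
fn swiglu_unit_gate_reduces_to_silu() {
    let k = AprenderKernels;
    let out = SwigluKernelV1::swiglu(&k, &[-1.0, 2.0, 1.0, 1.0], &[], &[], &[], &[]);
    let silu = SwigluKernelV1::silu(&k, &[-1.0, 2.0]);
    for i in 0..2 {
        assert!((out[i] - silu[i]).abs() < 1e-6, "unit gate reduces to SiLU");
    }
}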

#[test]
fn cross_entropy_trait_compiles() {
    let k = AprenderKernels;

    let log_sm = CrossEntropyKernelV1::log_softmax(&k, &[1.0, 2.0, 3.0]);
    assert_eq!(log_sm.len(), 3);
    assert!(log_sm.iter().all(|&v| v <= 0.0), "log_softmax <= 0");

    // targets (one-hot on class 2), logits
    let ce = CrossEntropyKernelV1::cross_entropy(&k, &[0.0, 0.0, 1.0], &[1.0, 2.0, 3.0]);
    assert_eq!(ce.len(), 1);
    assert!(ce[0] >= 0.0, "cross-entropy >= 0");
}
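
// Worked check: for one-hot targets, the loss reduces to the negative
// log-probability of the target class, i.e. -log_softmax(logits)[target].
#[test]
fn cross_entropy_one_hot_equivalence() {
    let k = AprenderKernels;
    let logits = &[1.0, 2.0, 3.0];
    let ce = CrossEntropyKernelV1::cross_entropy(&k, &[0.0, 0.0, 1.0], logits);
    let log_sm = CrossEntropyKernelV1::log_softmax(&k, logits);
    assert!(
        (ce[0] + log_sm[2]).abs() < 1e-6,
        "one-hot CE = -log_softmax[target]"
    );
}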

#[test]
fn rmsnorm_trait_compiles() {
    let k = AprenderKernels;
    let out = RmsnormKernelV1::rmsnorm(&k, &[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(out.len(), 4);
}
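
// Property check, assuming the standard RMSNorm definition
// y = x / sqrt(mean(x^2) + eps) * weight: with unit weights the output
// should itself have a root-mean-square of ~1.
#[test]
fn rmsnorm_output_has_unit_rms() {
    let k = AprenderKernels;
    let out = RmsnormKernelV1::rmsnorm(&k, &[1.0, 2.0, 3.0, 4.0]);
    let rms = (out.iter().map(|&v| v * v).sum::<f32>() / out.len() as f32).sqrt();
    assert!((rms - 1.0).abs() < 1e-3, "RMSNorm output RMS ~ 1");
}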

#[test]
fn layernorm_trait_compiles() {
    let k = AprenderKernels;
    let out = LayernormKernelV1::layernorm(&k, &[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 1.0, 1.0]);
    assert_eq!(out.len(), 4);
    // With unit weight and zero bias, output should be approximately standardized
    let mean: f32 = out.iter().sum::<f32>() / out.len() as f32;
    assert!(mean.abs() < 1e-5, "layernorm output mean ~ 0");

    let stats = LayernormKernelV1::statistics(&k, &[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(stats.len(), 2);
    assert!((stats[0] - 2.5).abs() < 1e-6, "mean of [1,2,3,4] = 2.5");
    assert!(
        (stats[1] - 1.25).abs() < 1e-6,
        "population variance of [1,2,3,4] = 1.25"
    );
}

#[test]
fn rope_trait_compiles() {
    let k = AprenderKernels;
    // At position m=0, RoPE is identity (cos(0)=1, sin(0)=0)
    let input = &[1.0, 2.0, 3.0, 4.0];
    let out = RopeKernelV1::rope(&k, input, &[0.0]);
    assert_eq!(out.len(), 4);
    for (i, (&a, &b)) in input.iter().zip(out.iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-6,
            "RoPE at m=0 should be identity, idx={i}"
        );
    }
}
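
// At m > 0 each pair is rotated: for d = 2, theta_0 = base^0 = 1, so
// position m rotates (x0, x1) by angle m; rotating (1, 0) by pi/2 gives ~(0, 1).
#[test]
fn rope_rotates_pairs_at_nonzero_position() {
    let k = AprenderKernels;
    let out = RopeKernelV1::rope(&k, &[1.0, 0.0], &[std::f32::consts::FRAC_PI_2]);
    assert!(out[0].abs() < 1e-5, "cos(pi/2) ~ 0");
    assert!((out[1] - 1.0).abs() < 1e-5, "sin(pi/2) ~ 1");
}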

#[test]
fn adamw_trait_compiles() {
    let k = AprenderKernels;

    let moments = AdamwKernelV1::adam_moments(&k, &[0.5, 0.3, 0.0, 0.0]);
    assert_eq!(moments.len(), 2);
    assert!((moments[0] - 0.05).abs() < 1e-6, "m = 0.1 * 0.5 = 0.05");

    let variance = AdamwKernelV1::adam_variance(&k, &[0.5, 0.3, 0.0, 0.0]);
    assert_eq!(variance.len(), 2);
    assert!(
        (variance[0] - 0.00025).abs() < 1e-6,
        "v = (1 - beta2) * 0.5^2 = 0.00025"
    );

    let corrected = AdamwKernelV1::bias_correction(&k, &[0.05, 0.00025]);
    assert_eq!(corrected.len(), 2);
    assert!(
        corrected[0].abs() > 0.05,
        "bias correction amplifies at t=1"
    );

    let updated = AdamwKernelV1::weight_update(&k, &[1.0, 0.5, 0.25, 1.0, 0.5, 0.25]);
    assert_eq!(updated.len(), 2);
    assert!((updated[0] - 1.0).abs() > 1e-6, "weights updated");
}
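
// Hand-computed weight_update check against the hard-coded hyperparameters
// (lr = 0.001, eps = 1e-8, wd = 0.01). Packed theta = [w, m_hat, v_hat]:
// w' = 1.0 - 0.001 * (0.1 / (sqrt(0.04) + 1e-8) + 0.01 * 1.0) = 0.99949.
#[test]
fn adamw_weight_update_hand_computed() {
    let k = AprenderKernels;
    let updated = AdamwKernelV1::weight_update(&k, &[1.0, 0.1, 0.04]);
    assert_eq!(updated.len(), 1);
    assert!(
        (updated[0] - 0.99949).abs() < 1e-6,
        "w - lr * (m_hat / (sqrt(v_hat) + eps) + wd * w)"
    );
}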

#[test]
fn attention_trait_compiles() {
    let k = AprenderKernels;
    // 2x2 identity matrices
    let q = &[1.0, 0.0, 0.0, 1.0];
    let kk = &[1.0, 0.0, 0.0, 1.0];
    let v = &[1.0, 0.0, 0.0, 1.0];
    let out = AttentionKernelV1::attention(&k, q, kk, v);
    assert_eq!(out.len(), 4);
    // Each output row is a convex combination of V's rows; with V = I that
    // makes each row a softmax row, so it must sum to 1.
    let row0_sum: f32 = out[0] + out[1];
    assert!(
        (row0_sum - 1.0).abs() < 1e-5,
        "attention output rows sum to 1 when V = I"
    );
}
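
// Degenerate single-row case: with n = 1 the softmax over one score is
// exactly 1, so attention must return V unchanged.
#[test]
fn attention_single_row_returns_v() {
    let k = AprenderKernels;
    let out = AttentionKernelV1::attention(&k, &[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]);
    assert_eq!(out, vec![5.0, 6.0], "softmax of a single score = 1, so out = V");
}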

#[test]
fn flash_attention_trait_compiles() {
    let k = AprenderKernels;
    let q = &[1.0, 0.0, 0.0, 1.0];
    let kk = &[1.0, 0.0, 0.0, 1.0];
    let v = &[1.0, 0.0, 0.0, 1.0];
    let out = FlashAttentionV1::flash_attention(&k, q, kk, v);
    assert_eq!(out.len(), 4);
}

#[test]
fn gqa_trait_compiles() {
    let k = AprenderKernels;
    let q = &[1.0, 0.0, 0.0, 1.0];
    let kk = &[1.0, 0.0, 0.0, 1.0];
    let v = &[1.0, 0.0, 0.0, 1.0];
    let out = GqaKernelV1::gqa(&k, q, kk, v);
    assert_eq!(out.len(), 4);
}

#[test]
fn matmul_trait_compiles() {
    let k = AprenderKernels;
    // 2x2 identity * [1,2,3,4] = [1,2,3,4]
    let a = &[1.0, 0.0, 0.0, 1.0];
    let b = &[1.0, 2.0, 3.0, 4.0];
    let out = MatmulKernelV1::matmul(&k, a, b);
    assert_eq!(out.len(), 4);
    assert!((out[0] - 1.0).abs() < 1e-6, "I*B = B");
    assert!((out[3] - 4.0).abs() < 1e-6, "I*B = B");

    // quantized_dot: b = s_a * [1, 2, 3] = [2, 4, 6] with the a-side scale
    // pre-applied, so the result is s_b * sum(b) = 0.5 * 12 = 6.0
    let qd = MatmulKernelV1::quantized_dot(&k, &[2.0, 4.0, 6.0], 0.5);
    assert_eq!(qd.len(), 1);
    assert!(
        (qd[0] - 6.0).abs() < 1e-6,
        "quantized_dot = s_a * s_b * dot"
    );
}