Skip to main content

provable_contracts/kernels/
ops.rs

1//! Shared kernel primitives: dot product, softmax row, score matrix.
2//!
3//! These building blocks are used across attention, GQA, and flash attention kernels.
4//! Centralizing them eliminates duplicated DataTransformation patterns.
5
/// Dot product of two equal-length slices.
///
/// Products are accumulated left to right in `f32`, starting from `0.0`, so a
/// pair of empty slices yields `0.0`.
///
/// # Panics
///
/// In debug builds, panics if `a.len() != b.len()`. In every build, panics if
/// `b` is shorter than `a`.
#[inline]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Trim `b` to `a`'s length once: a single length check up front replaces
    // the per-element bounds checks an index loop pays on both slices.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |sum, (&x, &y)| sum + x * y)
}
16
/// Numerically stable softmax, computed in place over one contiguous row.
///
/// Works in three passes: find the row maximum, replace every element with
/// `exp(x - max)` while accumulating the total, then divide each element by
/// that total. The division is skipped unless the total is strictly positive,
/// so an empty row is a no-op and a row whose total is NaN (for example a row
/// made entirely of `-inf`) is left un-normalized.
pub fn softmax_row(row: &mut [f32]) {
    let mut peak = f32::NEG_INFINITY;
    for &x in row.iter() {
        peak = peak.max(x);
    }

    // Shift by the maximum before exponentiating so the largest exponent is 0.
    let total = row.iter_mut().fold(0.0f32, |acc, x| {
        *x = (*x - peak).exp();
        acc + *x
    });

    if total > 0.0 {
        row.iter_mut().for_each(|x| *x /= total);
    }
}
33
34/// Apply softmax to each row of a `rows x cols` matrix (in-place).
35pub fn softmax_rows(matrix: &mut [f32], rows: usize, cols: usize) {
36    debug_assert_eq!(matrix.len(), rows * cols);
37    for i in 0..rows {
38        softmax_row(&mut matrix[i * cols..(i + 1) * cols]);
39    }
40}
41
42/// Compute scaled dot-product score matrix: `scores[i,j] = Q[i] . K[j] / sqrt(d)`.
43///
44/// Q is `m x d`, K is `n x d`, scores is `m x n` (row-major).
45pub fn score_matrix(q: &[f32], k: &[f32], m: usize, n: usize, d: usize, scores: &mut [f32]) {
46    debug_assert_eq!(q.len(), m * d);
47    debug_assert_eq!(k.len(), n * d);
48    debug_assert_eq!(scores.len(), m * n);
49    let scale = 1.0 / (d as f32).sqrt();
50
51    for i in 0..m {
52        for j in 0..n {
53            scores[i * n + j] = dot(&q[i * d..(i + 1) * d], &k[j * d..(j + 1) * d]) * scale;
54        }
55    }
56}
57
/// Attention output product: `output = scores * V`.
///
/// `scores` is `rows x cols`, `v` is `cols x d_v`, and `output` is
/// `rows x d_v`, all row-major. This is the last step of attention, where the
/// softmax weights are applied to the value vectors. Every element of
/// `output` is assigned, so it does not need to be zeroed beforehand.
///
/// # Panics
///
/// In debug builds, panics if any buffer length disagrees with the given
/// dimensions.
pub fn matmul_sv(
    scores: &[f32],
    v: &[f32],
    rows: usize,
    cols: usize,
    d_v: usize,
    output: &mut [f32],
) {
    debug_assert_eq!(scores.len(), rows * cols);
    debug_assert_eq!(v.len(), cols * d_v);
    debug_assert_eq!(output.len(), rows * d_v);

    for r in 0..rows {
        let score_base = r * cols;
        let out_base = r * d_v;
        for j in 0..d_v {
            // Left-to-right accumulation from an explicit `0.0` seed, exactly
            // like a running `sum += ...` loop.
            output[out_base + j] = (0..cols)
                .map(|c| scores[score_base + c] * v[c * d_v + j])
                .fold(0.0f32, |acc, term| acc + term);
        }
    }
}
84
/// In-place scaled add, `output[i] += weight * v_row[i]`, used as the
/// accumulation step in attention.
///
/// # Panics
///
/// In debug builds, panics if `output` and `v_row` differ in length. Without
/// that check, only the first `min(output.len(), v_row.len())` elements are
/// updated.
#[inline]
pub fn weighted_accumulate(output: &mut [f32], weight: f32, v_row: &[f32]) {
    debug_assert_eq!(output.len(), v_row.len());
    output
        .iter_mut()
        .zip(v_row)
        .for_each(|(acc, &value)| *acc += weight * value);
}
93
/// Build the ramp `[0 * scale, 1 * scale, 2 * scale, ...]` with `len` entries.
///
/// Deterministic Q/K/V test data shared by the attention kernel tests.
#[cfg(test)]
pub fn sequential_floats(len: usize, scale: f32) -> Vec<f32> {
    let mut values = Vec::with_capacity(len);
    for index in 0..len {
        values.push(index as f32 * scale);
    }
    values
}
101
/// Build a repeating test pattern whose `i`-th entry is
/// `((i % modulus) - offset) * scale`.
///
/// Gives the flash attention tests more varied data than a plain ramp.
///
/// # Panics
///
/// Panics if `modulus == 0` while `len > 0` (remainder by zero).
#[cfg(test)]
pub fn patterned_floats(len: usize, modulus: usize, offset: f32, scale: f32) -> Vec<f32> {
    let mut values = vec![0.0f32; len];
    for (index, slot) in values.iter_mut().enumerate() {
        let phase = (index % modulus) as f32;
        *slot = (phase - offset) * scale;
    }
    values
}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Element-wise [`assert_close`] over two slices of equal length.
    #[track_caller]
    fn assert_all_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (&got, &want) in actual.iter().zip(expected) {
            assert_close(got, want, tol);
        }
    }

    #[test]
    fn dot_basic() {
        // 1*4 + 2*5 + 3*6 = 32
        assert_close(dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0, 1e-6);
    }

    #[test]
    fn dot_zero() {
        // Orthogonal unit vectors.
        assert_eq!(dot(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn softmax_row_uniform() {
        // Equal logits must yield equal probabilities.
        let mut row = [1.0f32; 4];
        softmax_row(&mut row);
        assert_all_close(&row, &[0.25; 4], 1e-6);
    }

    #[test]
    fn softmax_row_sums_to_one() {
        let mut row = [1.0f32, 2.0, 3.0, 4.0];
        softmax_row(&mut row);
        let total = row.iter().sum::<f32>();
        assert_close(total, 1.0, 1e-6);
    }

    #[test]
    fn score_matrix_basic() {
        // One query row and one key row of width 2 give a 1x1 score matrix.
        let query = [1.0, 0.0];
        let key = [1.0, 0.0];
        let mut scores = [0.0f32; 1];
        score_matrix(&query, &key, 1, 1, 2, &mut scores);
        // dot = 1.0, scaled by 1/sqrt(2) ≈ 0.707
        assert_close(scores[0], 1.0 / 2.0f32.sqrt(), 1e-5);
    }

    #[test]
    fn matmul_sv_basic() {
        // scores = [[0.5, 0.5]], V = [[1.0, 2.0], [3.0, 4.0]]
        // output = [[0.5*1 + 0.5*3, 0.5*2 + 0.5*4]] = [[2.0, 3.0]]
        let weights = [0.5, 0.5];
        let values = [1.0, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 2];
        matmul_sv(&weights, &values, 1, 2, 2, &mut output);
        assert_all_close(&output, &[2.0, 3.0], 1e-6);
    }

    #[test]
    fn matmul_sv_identity_weights() {
        // Identity weights must reproduce V unchanged:
        // scores = [[1, 0], [0, 1]], V = [[10, 20], [30, 40]]
        let weights = [1.0, 0.0, 0.0, 1.0];
        let values = [10.0, 20.0, 30.0, 40.0];
        let mut output = [0.0f32; 4];
        matmul_sv(&weights, &values, 2, 2, 2, &mut output);
        assert_all_close(&output, &values, 1e-6);
    }

    #[test]
    fn weighted_accumulate_basic() {
        // [1, 2] + 0.5 * [4, 6] = [3, 5]
        let mut acc = [1.0, 2.0];
        weighted_accumulate(&mut acc, 0.5, &[4.0, 6.0]);
        assert_all_close(&acc, &[3.0, 5.0], 1e-6);
    }
}