// provable_contracts/kernels/ops.rs
/// Dot product of two equal-length slices, accumulated left to right in `f32`.
///
/// Lengths are checked only by `debug_assert_eq!`. In release builds the first
/// `a.len()` elements of `b` are used, and the call panics if `b` is shorter than `a`.
#[inline]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Re-slicing `b` to `a.len()` preserves the indexing semantics (panic when `b` is
    // shorter) and lets the zipped loop run without per-element bounds checks.
    let b = &b[..a.len()];
    // Explicit fold keeps the 0.0 seed and strict left-to-right accumulation order.
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}

/// Numerically stable in-place softmax over a single row.
///
/// The row maximum is subtracted before exponentiating, so every exponent is `<= 0` and
/// cannot overflow. Each entry is then divided by the sum of the exponentials.
///
/// Edge cases:
/// - An empty row is left unchanged.
/// - A row whose entries are all `-inf` (fully masked) becomes all zeros rather than NaN.
/// - NaN or `+inf` inputs produce NaN entries, and the row is then left unnormalized.
pub fn softmax_row(row: &mut [f32]) {
    let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // If the maximum is -inf (all entries -inf, or the row is empty), subtracting it would
    // compute `-inf - -inf = NaN`. Shifting by 0 instead makes every exp() exactly 0.
    let shift = if max_val == f32::NEG_INFINITY {
        0.0
    } else {
        max_val
    };
    let mut sum = 0.0f32;
    for v in row.iter_mut() {
        *v = (*v - shift).exp();
        sum += *v;
    }
    // With a finite maximum the max entry contributes exp(0) = 1, so `sum >= 1`.
    // `sum` is 0 only for fully masked or empty rows, where dividing would yield NaN.
    // A NaN sum also fails this test and skips normalization.
    if sum > 0.0 {
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}

34pub fn softmax_rows(matrix: &mut [f32], rows: usize, cols: usize) {
36 debug_assert_eq!(matrix.len(), rows * cols);
37 for i in 0..rows {
38 softmax_row(&mut matrix[i * cols..(i + 1) * cols]);
39 }
40}
41
42pub fn score_matrix(q: &[f32], k: &[f32], m: usize, n: usize, d: usize, scores: &mut [f32]) {
46 debug_assert_eq!(q.len(), m * d);
47 debug_assert_eq!(k.len(), n * d);
48 debug_assert_eq!(scores.len(), m * n);
49 let scale = 1.0 / (d as f32).sqrt();
50
51 for i in 0..m {
52 for j in 0..n {
53 scores[i * n + j] = dot(&q[i * d..(i + 1) * d], &k[j * d..(j + 1) * d]) * scale;
54 }
55 }
56}
57
/// Multiplies the row-major `rows x cols` weight matrix `scores` by the row-major
/// `cols x d_v` matrix `v` and writes the `rows x d_v` product into `output`.
///
/// Every entry of `output` is overwritten. Each output element is accumulated over
/// `c = 0..cols` in ascending order, starting from `0.0`. The results are therefore
/// bit-for-bit identical to the naive `for i { for j { for c { .. } } }` loop.
pub fn matmul_sv(
    scores: &[f32],
    v: &[f32],
    rows: usize,
    cols: usize,
    d_v: usize,
    output: &mut [f32],
) {
    debug_assert_eq!(scores.len(), rows * cols);
    debug_assert_eq!(v.len(), cols * d_v);
    debug_assert_eq!(output.len(), rows * d_v);

    for i in 0..rows {
        let weights = &scores[i * cols..(i + 1) * cols];
        let out_row = &mut output[i * d_v..(i + 1) * d_v];
        out_row.fill(0.0);
        // i-c-j order: stream whole contiguous rows of `v` instead of striding down its
        // columns. Each out_row[j] still sees its terms in ascending `c`.
        for (c, &w) in weights.iter().enumerate() {
            let v_row = &v[c * d_v..(c + 1) * d_v];
            for (o, &x) in out_row.iter_mut().zip(v_row) {
                *o += w * x;
            }
        }
    }
}

/// In-place AXPY: adds `weight * v_row[k]` to `output[k]` for every position `k`.
///
/// Lengths are checked only by `debug_assert_eq!`. In release builds a mismatch
/// updates only the common prefix.
#[inline]
pub fn weighted_accumulate(output: &mut [f32], weight: f32, v_row: &[f32]) {
    debug_assert_eq!(output.len(), v_row.len());
    output
        .iter_mut()
        .zip(v_row)
        .for_each(|(acc, &x)| *acc += weight * x);
}

/// Test fixture: returns `len` values `[0 * scale, 1 * scale, 2 * scale, ...]`.
#[cfg(test)]
pub fn sequential_floats(len: usize, scale: f32) -> Vec<f32> {
    let mut values = Vec::with_capacity(len);
    for step in 0..len {
        values.push(step as f32 * scale);
    }
    values
}

/// Test fixture: returns `len` values where entry `i` is
/// `((i % modulus) as f32 - offset) * scale`, i.e. a pattern repeating every `modulus`.
///
/// # Panics
///
/// Panics if `modulus` is 0 and `len` is non-zero (remainder by zero).
#[cfg(test)]
pub fn patterned_floats(len: usize, modulus: usize, offset: f32, scale: f32) -> Vec<f32> {
    let mut pattern = Vec::with_capacity(len);
    for i in 0..len {
        let phase = (i % modulus) as f32;
        pattern.push((phase - offset) * scale);
    }
    pattern
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` is strictly within `tol` of `expected`. `#[track_caller]`
    /// makes a failure report the calling test line rather than this helper.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    #[test]
    fn dot_basic() {
        assert_close(dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0, 1e-6);
    }

    #[test]
    fn dot_zero() {
        assert_eq!(dot(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    #[test]
    fn softmax_row_uniform() {
        let mut row = [1.0f32; 4];
        softmax_row(&mut row);
        for &v in &row {
            assert_close(v, 0.25, 1e-6);
        }
    }

    #[test]
    fn softmax_row_sums_to_one() {
        let mut row = [1.0f32, 2.0, 3.0, 4.0];
        softmax_row(&mut row);
        assert_close(row.iter().sum::<f32>(), 1.0, 1e-6);
    }

    #[test]
    fn softmax_row_preserves_order() {
        let mut row = [1.0f32, 2.0, 3.0, 4.0];
        softmax_row(&mut row);
        assert!(row.windows(2).all(|w| w[0] < w[1]));
        assert!(row.iter().all(|&p| p > 0.0 && p < 1.0));
    }

    #[test]
    fn softmax_rows_normalizes_each_row() {
        let mut matrix = [1.0f32, 2.0, 3.0, 0.0, 0.0, 0.0];
        softmax_rows(&mut matrix, 2, 3);
        for row in matrix.chunks_exact(3) {
            assert_close(row.iter().sum::<f32>(), 1.0, 1e-6);
        }
        // A constant row must come out uniform.
        for &p in &matrix[3..] {
            assert_close(p, 1.0 / 3.0, 1e-6);
        }
    }

    #[test]
    fn score_matrix_basic() {
        let q = [1.0, 0.0];
        let k = [1.0, 0.0];
        let mut scores = [0.0f32; 1];
        score_matrix(&q, &k, 1, 1, 2, &mut scores);
        assert_close(scores[0], 1.0 / 2.0f32.sqrt(), 1e-5);
    }

    #[test]
    fn score_matrix_rectangular() {
        // q rows: e0, e1 (m = 2); k rows: e0, e1, e0 + e1 (n = 3); d = 2.
        let q = [1.0, 0.0, 0.0, 1.0];
        let k = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let mut scores = [0.0f32; 6];
        score_matrix(&q, &k, 2, 3, 2, &mut scores);
        let s = 1.0 / 2.0f32.sqrt();
        let expected = [s, 0.0, s, 0.0, s, s];
        for (&got, &want) in scores.iter().zip(&expected) {
            assert_close(got, want, 1e-6);
        }
    }

    #[test]
    fn matmul_sv_basic() {
        let scores = [0.5, 0.5];
        let v = [1.0, 2.0, 3.0, 4.0];
        let mut output = [0.0f32; 2];
        matmul_sv(&scores, &v, 1, 2, 2, &mut output);
        assert_close(output[0], 2.0, 1e-6);
        assert_close(output[1], 3.0, 1e-6);
    }

    #[test]
    fn matmul_sv_identity_weights() {
        let scores = [1.0, 0.0, 0.0, 1.0];
        let v = [10.0, 20.0, 30.0, 40.0];
        let mut output = [0.0f32; 4];
        matmul_sv(&scores, &v, 2, 2, 2, &mut output);
        for (&got, &want) in output.iter().zip(&[10.0f32, 20.0, 30.0, 40.0]) {
            assert_close(got, want, 1e-6);
        }
    }

    #[test]
    fn matmul_sv_rectangular() {
        // rows = 2, cols = 3, d_v = 1; output starts non-zero to check it is overwritten.
        let scores = [1.0, 2.0, 3.0, 0.0, 1.0, 0.0];
        let v = [1.0, 10.0, 100.0];
        let mut output = [9.0f32; 2];
        matmul_sv(&scores, &v, 2, 3, 1, &mut output);
        assert_close(output[0], 321.0, 1e-4);
        assert_close(output[1], 10.0, 1e-6);
    }

    #[test]
    fn weighted_accumulate_basic() {
        let mut out = [1.0, 2.0];
        weighted_accumulate(&mut out, 0.5, &[4.0, 6.0]);
        assert_close(out[0], 3.0, 1e-6);
        assert_close(out[1], 5.0, 1e-6);
    }

    #[test]
    fn fixture_helpers() {
        assert_eq!(sequential_floats(4, 0.5), vec![0.0f32, 0.5, 1.0, 1.5]);
        assert_eq!(
            patterned_floats(5, 3, 1.0, 2.0),
            vec![-2.0f32, 0.0, 2.0, -2.0, 0.0]
        );
    }
}
187}