// provable_contracts/kernels/alibi.rs
#[inline]
/// Returns the ALiBi slope for one attention head.
///
/// The slope is `2^(-8 * (h + 1) / num_heads)`, where `h` is the zero-based
/// head index. For `num_heads == 8` this gives `1/2, 1/4, ..., 1/256`.
///
/// The division is done in `f32`, so `num_heads == 0` does not panic: the
/// exponent becomes negative infinity and the result is `0.0`.
pub fn alibi_slope(h: usize, num_heads: usize) -> f32 {
    let head_rank = (h + 1) as f32;
    let total_heads = num_heads as f32;
    f32::powf(2.0, -8.0 * head_rank / total_heads)
}
23
24pub fn alibi_bias_scalar(scores: &mut [f32], num_heads: usize, seq_len: usize) {
32 assert_eq!(
33 scores.len(),
34 num_heads * seq_len * seq_len,
35 "scores dimension mismatch"
36 );
37
38 let head_stride = seq_len * seq_len;
39
40 for h in 0..num_heads {
41 let slope = alibi_slope(h, num_heads);
42 let base = h * head_stride;
43
44 for i in 0..seq_len {
45 for j in 0..seq_len {
46 let dist = i.abs_diff(j);
47 scores[base + i * seq_len + j] -= slope * (dist as f32);
48 }
49 }
50 }
51}
52
53#[cfg(target_arch = "x86_64")]
62#[target_feature(enable = "avx2")]
63pub unsafe fn alibi_bias_avx2(scores: &mut [f32], num_heads: usize, seq_len: usize) {
64 alibi_bias_scalar(scores, num_heads, seq_len);
65}
66
/// Returns the PTX source of the `alibi_kernel` GPU kernel.
///
/// The kernel applies the same update as the scalar CPU kernel:
/// `scores[head][i][j] -= 2^(-8 * (head + 1) / num_heads) * |i - j|`.
///
/// Launch mapping, as read by the kernel:
/// * `%ctaid.x` selects the head, so the grid needs one block per head.
/// * `%tid.x` is the flattened `i * seq_len + j` index within that head;
///   threads whose index is `>= seq_len * seq_len` exit without writing.
///
/// NOTE(review): only `%tid.x` is consulted, so a head is fully covered only
/// when the block has at least `seq_len * seq_len` threads. With CUDA's usual
/// 1024-threads-per-block limit that means `seq_len <= 32` — confirm against
/// the host-side launcher.
///
/// The thread index is copied into a register named `%flat`. A user register
/// must not be called `%tid`, because that name is already taken by the
/// predefined special register whose `.x` component the kernel reads.
pub fn alibi_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry alibi_kernel(
    .param .u64 SCORES,
    .param .u32 NUM_HEADS,
    .param .u32 SEQ_LEN
) {
    .reg .u32 %flat, %bid, %num_heads, %seq_len;
    .reg .u32 %head, %i, %j, %dist, %head_stride, %offset;
    .reg .u64 %scores_ptr, %addr, %off64;
    .reg .f32 %slope, %dist_f, %bias, %score, %exp, %neg8, %h_f, %nh_f;
    .reg .pred %p_bound, %p_ge;

    mov.u32 %flat, %tid.x;
    mov.u32 %bid, %ctaid.x;

    ld.param.u32 %num_heads, [NUM_HEADS];
    ld.param.u32 %seq_len, [SEQ_LEN];
    ld.param.u64 %scores_ptr, [SCORES];

    // bid = head, flat = %tid.x = flattened (i * seq_len + j)
    mov.u32 %head, %bid;
    mul.lo.u32 %head_stride, %seq_len, %seq_len;

    // i = flat / seq_len, j = flat % seq_len
    div.u32 %i, %flat, %seq_len;
    rem.u32 %j, %flat, %seq_len;

    setp.ge.u32 %p_bound, %flat, %head_stride;
    @%p_bound bra EXIT;

    // slope = 2^(-8 * (head+1) / num_heads)
    add.u32 %offset, %head, 1;
    cvt.rn.f32.u32 %h_f, %offset;
    cvt.rn.f32.u32 %nh_f, %num_heads;
    mov.f32 %neg8, 0fC1000000;
    mul.f32 %exp, %neg8, %h_f;
    div.rn.f32 %exp, %exp, %nh_f;
    ex2.approx.f32 %slope, %exp;

    // dist = |i - j|
    setp.ge.u32 %p_ge, %i, %j;
    @%p_ge bra CALC_DIST_FORWARD;
    sub.u32 %dist, %j, %i;
    bra APPLY_BIAS;
CALC_DIST_FORWARD:
    sub.u32 %dist, %i, %j;

APPLY_BIAS:
    cvt.rn.f32.u32 %dist_f, %dist;
    mul.f32 %bias, %slope, %dist_f;

    // scores[head * head_stride + flat] -= bias
    mad.lo.u32 %offset, %head, %head_stride, %flat;
    mul.wide.u32 %off64, %offset, 4;
    add.u64 %addr, %scores_ptr, %off64;
    ld.global.f32 %score, [%addr];
    sub.f32 %score, %score, %bias;
    st.global.f32 [%addr], %score;

EXIT:
    ret;
}
"#
}
141
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// With eight heads the slopes are 2^-1, 2^-2, ..., 2^-8 and strictly decreasing.
    #[test]
    fn test_alibi_slopes() {
        let near = |actual: f32, expected: f32| (actual - expected).abs() < 1e-6;
        let slopes: Vec<f32> = (0..8).map(|h| alibi_slope(h, 8)).collect();

        assert!(near(slopes[0], 0.5));
        assert!(near(slopes[1], 0.25));
        assert!(near(slopes[7], 1.0 / 256.0));

        for (prev_idx, pair) in slopes.windows(2).enumerate() {
            let i = prev_idx + 1;
            assert!(pair[1] < pair[0], "slopes not decreasing at {i}");
        }
    }

    /// Distance on the diagonal is zero, so those scores must stay exactly as they were.
    #[test]
    fn test_alibi_diagonal_zero() {
        let (num_heads, seq_len) = (2, 4);
        let head_stride = seq_len * seq_len;
        let mut scores = vec![1.0f32; num_heads * head_stride];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        for (h, head) in scores.chunks_exact(head_stride).enumerate() {
            // Stepping by `seq_len + 1` walks the main diagonal of a row-major square.
            for (i, value) in head.iter().step_by(seq_len + 1).enumerate() {
                assert_eq!(
                    *value, 1.0,
                    "diagonal should be unchanged at h={h} i={i}"
                );
            }
        }
    }

    /// Starting from zero scores, every off-diagonal entry ends up strictly negative.
    #[test]
    fn test_alibi_negative_bias() {
        let num_heads = 1;
        let seq_len = 3;
        let mut scores = vec![0.0f32; seq_len * seq_len];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        for (flat, &value) in scores.iter().enumerate() {
            let (i, j) = (flat / seq_len, flat % seq_len);
            if i == j {
                continue;
            }
            assert!(
                value < 0.0,
                "off-diagonal [{i},{j}] should be negative, got {}",
                value
            );
        }
    }

    /// The bias depends only on |i - j|, so each head's matrix is symmetric.
    #[test]
    fn test_alibi_symmetry() {
        let (num_heads, seq_len) = (2, 5);
        let head_stride = seq_len * seq_len;
        let mut scores = vec![0.0f32; num_heads * head_stride];
        alibi_bias_scalar(&mut scores, num_heads, seq_len);

        for (h, head) in scores.chunks_exact(head_stride).enumerate() {
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let a = head[i * seq_len + j];
                    let b = head[j * seq_len + i];
                    assert!(
                        (a - b).abs() < 1e-6,
                        "asymmetry at h={h} [{i},{j}]: {a} vs {b}"
                    );
                }
            }
        }
    }

    proptest! {
        // Every slope lies in (0, 1] for 1 to 16 heads.
        #[test]
        fn prop_alibi_slopes_positive(num_heads in 1usize..17) {
            for (h, s) in (0..num_heads).map(|h| (h, alibi_slope(h, num_heads))) {
                prop_assert!(s > 0.0, "slope must be positive, got {s} at h={h}");
                prop_assert!(s <= 1.0, "slope must be <= 1, got {s} at h={h}");
            }
        }

        // Biasing an all-zero buffer never produces NaN or infinity.
        #[test]
        fn prop_alibi_output_finite(
            num_heads in 1usize..5,
            seq_len in 1usize..8,
        ) {
            let mut scores = vec![0.0f32; num_heads * seq_len * seq_len];
            alibi_bias_scalar(&mut scores, num_heads, seq_len);

            for (idx, val) in scores.iter().copied().enumerate() {
                prop_assert!(val.is_finite(), "scores[{idx}] = {val} not finite");
            }
        }
    }

    /// Smoke-checks that the PTX text contains the kernel entry, the exp2 call and a return.
    #[test]
    fn test_alibi_ptx_structure() {
        let source = alibi_ptx();
        assert!(source.contains(".entry alibi_kernel"));
        assert!(source.contains("ex2.approx.f32"));
        assert!(source.contains("ret;"));
    }

    /// The AVX2 entry point must produce the same buffer as the scalar kernel.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_alibi_avx2_parity() {
        if is_x86_feature_detected!("avx2") {
            let (num_heads, seq_len) = (2, 4);
            let mut scalar_scores = vec![1.0f32; num_heads * seq_len * seq_len];
            let mut avx2_scores = scalar_scores.clone();

            alibi_bias_scalar(&mut scalar_scores, num_heads, seq_len);
            // SAFETY: AVX2 support was confirmed at runtime by the check above.
            unsafe { alibi_bias_avx2(&mut avx2_scores, num_heads, seq_len) };

            assert_eq!(scalar_scores, avx2_scores);
        }
    }
}