// provable_contracts/kernels/absolute_position.rs
/// Adds absolute position embeddings to token embeddings (scalar reference).
///
/// All buffers are flat with `dim` values per row. Row `t` of `pos_embed` is
/// added element-wise to row `t` of `token_embed`:
///
/// `output[t * dim + d] = token_embed[t * dim + d] + pos_embed[t * dim + d]`
///
/// for every `t < seq_len` and `d < dim`. Rows of `pos_embed` at index
/// `seq_len` and above are never read.
///
/// # Panics
///
/// Panics if
/// - `seq_len * dim` or `max_pos * dim` overflows `usize`,
/// - `token_embed.len() != seq_len * dim`,
/// - `pos_embed.len() != max_pos * dim`,
/// - `output.len() != seq_len * dim`, or
/// - `seq_len > max_pos`.
pub fn abs_position_scalar(
    token_embed: &[f32],
    pos_embed: &[f32],
    seq_len: usize,
    max_pos: usize,
    dim: usize,
    output: &mut [f32],
) {
    // Checked multiplication: a wrapped product in release builds would let
    // mismatched buffers slip past the length assertions below.
    let active_len = seq_len
        .checked_mul(dim)
        .expect("seq_len * dim overflows usize");
    assert_eq!(
        token_embed.len(),
        active_len,
        "token_embed dimension mismatch"
    );

    let table_len = max_pos
        .checked_mul(dim)
        .expect("max_pos * dim overflows usize");
    assert_eq!(pos_embed.len(), table_len, "pos_embed dimension mismatch");

    assert_eq!(output.len(), active_len, "output dimension mismatch");
    assert!(
        seq_len <= max_pos,
        "seq_len {seq_len} exceeds max_pos {max_pos}"
    );

    // Because seq_len <= max_pos, the first `active_len` table entries are
    // exactly rows 0..seq_len, so token and position indices coincide.
    let used_pos = &pos_embed[..active_len];
    for ((out, &tok), &pos) in output.iter_mut().zip(token_embed).zip(used_pos) {
        *out = tok + pos;
    }
}
53
54#[cfg(target_arch = "x86_64")]
63#[target_feature(enable = "avx2")]
64pub unsafe fn abs_position_avx2(
65 token_embed: &[f32],
66 pos_embed: &[f32],
67 seq_len: usize,
68 max_pos: usize,
69 dim: usize,
70 output: &mut [f32],
71) {
72 abs_position_scalar(token_embed, pos_embed, seq_len, max_pos, dim, output);
73}
74
/// Returns the PTX source for the `abs_position_kernel` GPU kernel.
///
/// Each thread handles one element: `%ctaid.x` selects the position row and
/// `%tid.x` selects the dimension within that row. The kernel loads one `f32`
/// from `TOKEN_EMBED` and one from `POS_EMBED` at element offset
/// `ctaid.x * DIM + tid.x`, adds them, and stores the sum to `OUT` at the same
/// offset.
///
/// Threads with `ctaid.x >= SEQ_LEN` or `tid.x >= DIM` exit without touching
/// memory, so an over-sized launch grid cannot read or write past the buffers.
pub fn abs_position_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry abs_position_kernel(
    .param .u64 TOKEN_EMBED,
    .param .u64 POS_EMBED,
    .param .u64 OUT,
    .param .u32 SEQ_LEN,
    .param .u32 DIM
) {
    // Scratch registers use their own names (%tid_x, %bid_x) so they cannot
    // collide with the predefined %tid / %ctaid special registers.
    .reg .u32 %tid_x, %bid_x, %dim, %seq_len, %offset;
    .reg .u64 %te_ptr, %pe_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %te_val, %pe_val, %result;
    .reg .pred %p_bound;

    mov.u32 %tid_x, %tid.x;
    mov.u32 %bid_x, %ctaid.x;

    ld.param.u32 %dim, [DIM];
    ld.param.u32 %seq_len, [SEQ_LEN];
    ld.param.u64 %te_ptr, [TOKEN_EMBED];
    ld.param.u64 %pe_ptr, [POS_EMBED];
    ld.param.u64 %out_ptr, [OUT];

    // bid = position index, tid = dimension index
    // Skip blocks past the last position row
    setp.ge.u32 %p_bound, %bid_x, %seq_len;
    @%p_bound bra EXIT;

    // Skip threads past the last dimension
    setp.ge.u32 %p_bound, %tid_x, %dim;
    @%p_bound bra EXIT;

    // offset = bid * dim + tid
    mad.lo.u32 %offset, %bid_x, %dim, %tid_x;
    mul.wide.u32 %off64, %offset, 4;

    // Load token_embed[offset]
    add.u64 %addr, %te_ptr, %off64;
    ld.global.f32 %te_val, [%addr];

    // Load pos_embed[offset] (same offset since pos[t] uses same t*dim+d)
    add.u64 %addr, %pe_ptr, %off64;
    ld.global.f32 %pe_val, [%addr];

    // output = token_embed + pos_embed
    add.f32 %result, %te_val, %pe_val;

    // Store output[offset]
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %result;

EXIT:
    ret;
}
"#
}
136
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Two rows of width two: every output is the element-wise sum.
    #[test]
    fn test_abs_position_basic() {
        let token = [1.0, 2.0, 3.0, 4.0];
        let pos = [0.1, 0.2, 0.3, 0.4];
        let mut output = [0.0f32; 4];

        abs_position_scalar(&token, &pos, 2, 2, 2, &mut output);

        let expected = [1.1f32, 2.2, 3.3, 4.4];
        for (got, want) in output.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// An all-zero position table leaves the token embeddings untouched.
    #[test]
    fn test_abs_position_zero_pos_is_identity() {
        let token = [1.0, 2.0, 3.0];
        let pos = [0.0; 6];
        let mut output = [0.0f32; 3];

        abs_position_scalar(&token, &pos, 1, 2, 3, &mut output);

        assert_eq!(output, token);
    }

    /// The output buffer keeps its `seq_len * dim` length.
    #[test]
    fn test_abs_position_shape_preservation() {
        let (seq_len, dim, max_pos) = (3, 4, 5);
        let token = vec![1.0f32; seq_len * dim];
        let pos = vec![0.5f32; max_pos * dim];
        let mut output = vec![0.0f32; seq_len * dim];

        abs_position_scalar(&token, &pos, seq_len, max_pos, dim, &mut output);

        assert_eq!(output.len(), seq_len * dim);
    }

    /// A sequence longer than the position table must be rejected.
    #[test]
    #[should_panic(expected = "seq_len 5 exceeds max_pos 3")]
    fn test_abs_position_oob() {
        let token = vec![0.0f32; 10];
        let pos = vec![0.0f32; 6];
        let mut output = vec![0.0f32; 10];

        abs_position_scalar(&token, &pos, 5, 3, 2, &mut output);
    }

    proptest! {
        /// Finite inputs always yield finite outputs.
        #[test]
        fn prop_abs_position_finite(
            seq_len in 1usize..5,
            dim in 1usize..5,
        ) {
            let max_pos = seq_len + 2;
            let active = seq_len * dim;
            let token: Vec<f32> = (0..active).map(|i| (i as f32) * 0.1).collect();
            let pos: Vec<f32> = (0..max_pos * dim).map(|i| (i as f32) * 0.01).collect();
            let mut output = vec![0.0f32; active];

            abs_position_scalar(&token, &pos, seq_len, max_pos, dim, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }
    }

    /// The PTX text exposes the kernel entry, the add, and a return.
    #[test]
    fn test_abs_position_ptx_structure() {
        let ptx = abs_position_ptx();
        for needle in [".entry abs_position_kernel", "add.f32", "ret;"].iter() {
            assert!(ptx.contains(needle));
        }
    }

    /// The AVX2 entry point must agree exactly with the scalar reference.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_abs_position_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }

        let token = [1.0, 2.0, 3.0, 4.0];
        let pos = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let mut scalar_out = [0.0f32; 4];
        let mut avx2_out = [0.0f32; 4];

        abs_position_scalar(&token, &pos, 2, 4, 2, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { abs_position_avx2(&token, &pos, 2, 4, 2, &mut avx2_out) };

        assert_eq!(scalar_out, avx2_out);
    }
}