Skip to main content

provable_contracts/kernels/
absolute_position.rs

1//! Absolute position embeddings kernel.
2//!
3//! Matches `absolute-position-v1.yaml`.
4//! `output[t] = token_embed[t] + pos_embed[t]` — learned additive positional encoding.
5//!
//! The kernel is exposed through three backends, one function each:
//! - `fn abs_position_scalar(...)` -- Pure Rust scalar reference (ground truth)
//! - `unsafe fn abs_position_avx2(...)` -- AVX2-enabled entry point (currently delegates to the scalar reference)
//! - `fn abs_position_ptx() -> &'static str` -- PTX assembly source string
10
11// ────────────────────────────────────────────────────────────────────────────
12// Scalar implementation
13// ────────────────────────────────────────────────────────────────────────────
14
/// Add learned position embeddings to token embeddings (scalar reference).
///
/// `token_embed` is `seq_len x dim` (row-major), `pos_embed` is `max_pos x dim`,
/// `output` is `seq_len x dim`. Each position `t` gets
/// `output[t] = token_embed[t] + pos_embed[t]`. Rows `seq_len..max_pos` of
/// `pos_embed` are never read.
///
/// # Panics
/// Panics if `seq_len * dim` or `max_pos * dim` overflows `usize`, if any slice
/// length disagrees with its declared shape, or if `seq_len > max_pos`.
pub fn abs_position_scalar(
    token_embed: &[f32],
    pos_embed: &[f32],
    seq_len: usize,
    max_pos: usize,
    dim: usize,
    output: &mut [f32],
) {
    // Shape products are computed with checked arithmetic: a plain `*` wraps in
    // release builds, which would let the length assertions pass vacuously.
    let active_len = seq_len
        .checked_mul(dim)
        .unwrap_or_else(|| panic!("seq_len {seq_len} * dim {dim} overflows usize"));
    assert_eq!(
        token_embed.len(),
        active_len,
        "token_embed dimension mismatch"
    );

    let table_len = max_pos
        .checked_mul(dim)
        .unwrap_or_else(|| panic!("max_pos {max_pos} * dim {dim} overflows usize"));
    assert_eq!(pos_embed.len(), table_len, "pos_embed dimension mismatch");

    assert_eq!(output.len(), active_len, "output dimension mismatch");
    assert!(
        seq_len <= max_pos,
        "seq_len {seq_len} exceeds max_pos {max_pos}"
    );

    // Position `t` uses table row `t`, so the rows in use are exactly the first
    // `seq_len * dim` elements of the table. `seq_len <= max_pos` (asserted
    // above) guarantees this prefix is in bounds.
    let active_pos = &pos_embed[..active_len];
    for ((out, &tok), &pos) in output.iter_mut().zip(token_embed).zip(active_pos) {
        *out = tok + pos;
    }
}
53
54// ────────────────────────────────────────────────────────────────────────────
55// AVX2 implementation
56// ────────────────────────────────────────────────────────────────────────────
57
58/// AVX2 absolute position embeddings -- delegates to scalar.
59///
60/// # Safety
61/// Requires AVX2 support.
62#[cfg(target_arch = "x86_64")]
63#[target_feature(enable = "avx2")]
64pub unsafe fn abs_position_avx2(
65    token_embed: &[f32],
66    pos_embed: &[f32],
67    seq_len: usize,
68    max_pos: usize,
69    dim: usize,
70    output: &mut [f32],
71) {
72    abs_position_scalar(token_embed, pos_embed, seq_len, max_pos, dim, output);
73}
74
75// ────────────────────────────────────────────────────────────────────────────
76// PTX implementation
77// ────────────────────────────────────────────────────────────────────────────
78
/// PTX assembly for absolute position embeddings.
///
/// One thread per (position, dimension) pair: `%ctaid.x` selects the position
/// and `%tid.x` selects the dimension. Each thread adds one element from the
/// position embedding table to the matching token embedding element and stores
/// the sum to `OUT`.
///
/// Threads whose position is `>= SEQ_LEN` or whose dimension is `>= DIM` exit
/// before any load or store. There is no per-thread loop, so a block with fewer
/// than `DIM` threads leaves the remaining columns of each row unwritten.
///
/// Kernel parameters, in order: `TOKEN_EMBED`, `POS_EMBED`, `OUT` (64-bit
/// addresses accessed as global `f32`), then `SEQ_LEN` and `DIM` (`u32`).
pub fn abs_position_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry abs_position_kernel(
    .param .u64 TOKEN_EMBED,
    .param .u64 POS_EMBED,
    .param .u64 OUT,
    .param .u32 SEQ_LEN,
    .param .u32 DIM
) {
    // User registers are named %pos / %col so they do not share a name with
    // the predefined %tid / %ctaid special registers.
    .reg .u32 %pos, %col, %dim, %seq_len, %offset;
    .reg .u64 %te_ptr, %pe_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %te_val, %pe_val, %result;
    .reg .pred %p_bound;

    mov.u32 %col, %tid.x;
    mov.u32 %pos, %ctaid.x;

    ld.param.u32 %dim, [DIM];
    ld.param.u32 %seq_len, [SEQ_LEN];
    ld.param.u64 %te_ptr, [TOKEN_EMBED];
    ld.param.u64 %pe_ptr, [POS_EMBED];
    ld.param.u64 %out_ptr, [OUT];

    // pos = position index (block), col = dimension index (thread)
    // Guard both axes so surplus blocks and threads never touch memory.
    setp.ge.u32 %p_bound, %pos, %seq_len;
    @%p_bound bra EXIT;
    setp.ge.u32 %p_bound, %col, %dim;
    @%p_bound bra EXIT;

    // offset = pos * dim + col, then scaled to a byte offset (4 bytes per f32)
    mad.lo.u32 %offset, %pos, %dim, %col;
    mul.wide.u32 %off64, %offset, 4;

    // Load token_embed[offset]
    add.u64 %addr, %te_ptr, %off64;
    ld.global.f32 %te_val, [%addr];

    // Load pos_embed[offset] (same offset since pos[t] uses same t*dim+d)
    add.u64 %addr, %pe_ptr, %off64;
    ld.global.f32 %pe_val, [%addr];

    // output = token_embed + pos_embed
    add.f32 %result, %te_val, %pe_val;

    // Store output[offset]
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %result;

EXIT:
    ret;
}
"#
}
136
137// ────────────────────────────────────────────────────────────────────────────
138// Tests
139// ────────────────────────────────────────────────────────────────────────────
140
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Element-wise sum on a 2x2 input whose table has exactly `seq_len` rows.
    #[test]
    fn test_abs_position_basic() {
        let tokens = [1.0, 2.0, 3.0, 4.0]; // seq_len=2, dim=2
        let table = [0.1, 0.2, 0.3, 0.4]; // max_pos=2, dim=2
        let mut sums = [0.0f32; 4];

        abs_position_scalar(&tokens, &table, 2, 2, 2, &mut sums);

        let expected = [1.1f32, 2.2, 3.3, 4.4];
        for (got, want) in sums.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// An all-zero position table leaves the token embeddings unchanged.
    #[test]
    fn test_abs_position_zero_pos_is_identity() {
        let tokens = [1.0, 2.0, 3.0];
        let table = [0.0; 6]; // max_pos=2, dim=3
        let mut sums = [0.0f32; 3];

        abs_position_scalar(&tokens, &table, 1, 2, 3, &mut sums);

        assert_eq!(sums, tokens);
    }

    /// The call succeeds when the table has more rows than the sequence.
    #[test]
    fn test_abs_position_shape_preservation() {
        let (seq_len, dim, max_pos) = (3, 4, 5);
        let tokens = vec![1.0f32; seq_len * dim];
        let table = vec![0.5f32; max_pos * dim];
        let mut sums = vec![0.0f32; seq_len * dim];

        abs_position_scalar(&tokens, &table, seq_len, max_pos, dim, &mut sums);

        assert_eq!(sums.len(), seq_len * dim);
    }

    /// A sequence longer than the position table must be rejected.
    #[test]
    #[should_panic(expected = "seq_len 5 exceeds max_pos 3")]
    fn test_abs_position_oob() {
        let tokens = vec![0.0f32; 10]; // seq_len=5, dim=2
        let table = vec![0.0f32; 6]; // max_pos=3, dim=2
        let mut sums = vec![0.0f32; 10];

        abs_position_scalar(&tokens, &table, 5, 3, 2, &mut sums);
    }

    proptest! {
        /// Finite ramp inputs must yield finite outputs for every small shape.
        #[test]
        fn prop_abs_position_finite(
            seq_len in 1usize..5,
            dim in 1usize..5,
        ) {
            let max_pos = seq_len + 2;
            // Linear ramp `0, step, 2*step, ...` of the requested length.
            let ramp = |len: usize, step: f32| -> Vec<f32> {
                (0..len).map(|i| (i as f32) * step).collect()
            };
            let tokens = ramp(seq_len * dim, 0.1);
            let table = ramp(max_pos * dim, 0.01);
            let mut sums = vec![0.0f32; seq_len * dim];

            abs_position_scalar(&tokens, &table, seq_len, max_pos, dim, &mut sums);

            for (idx, &val) in sums.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }
    }

    /// The PTX source must contain the entry point, the add, and a return.
    #[test]
    fn test_abs_position_ptx_structure() {
        let source = abs_position_ptx();
        for needle in [".entry abs_position_kernel", "add.f32", "ret;"] {
            assert!(source.contains(needle));
        }
    }

    /// The AVX2 entry point must agree exactly with the scalar reference.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_abs_position_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let tokens = [1.0, 2.0, 3.0, 4.0]; // seq_len=2, dim=2
        let table = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]; // max_pos=4
        let mut from_scalar = [0.0f32; 4];
        let mut from_avx2 = [0.0f32; 4];

        abs_position_scalar(&tokens, &table, 2, 4, 2, &mut from_scalar);
        // SAFETY: AVX2 availability was confirmed by the runtime check above.
        unsafe { abs_position_avx2(&tokens, &table, 2, 4, 2, &mut from_avx2) };

        assert_eq!(from_scalar, from_avx2);
    }
}