Skip to main content

provable_contracts/kernels/
rope.rs

//! Rotary Position Embedding (RoPE) kernel.
//!
//! Matches `rope-kernel-v1.yaml`.
//! Rotates each pair of dimensions by a position-dependent angle:
//! theta_k = base^(-2k/d), and a 2D rotation matrix is applied to each pair
//! (`m` is the position, `d` the vector dimension).
//!
//! RoPE(x, m)_{2k}   = x_{2k} * cos(m * theta_k) - x_{2k+1} * sin(m * theta_k)
//! RoPE(x, m)_{2k+1} = x_{2k} * sin(m * theta_k) + x_{2k+1} * cos(m * theta_k)
//!
//! The kernel is provided by three backends:
//! - `fn rope_scalar(...)` — Pure Rust scalar reference (ground truth)
//! - `unsafe fn rope_avx2(...)` — AVX2 entry point (currently delegates to the scalar path)
//! - `fn rope_ptx() -> &'static str` — PTX assembly source string
14
// ────────────────────────────────────────────────────────────────────────────
// Scalar implementation
// ────────────────────────────────────────────────────────────────────────────
18
/// Rotate every adjacent pair `(x[2k], x[2k+1])` by its RoPE angle and write the
/// result into `output`. This is the scalar ground-truth backend.
///
/// The angle for pair `k` is `position * base^(-2k / dim)`, and each pair becomes:
///   output[2k]   = x[2k] * cos(angle) - x[2k+1] * sin(angle)
///   output[2k+1] = x[2k] * sin(angle) + x[2k+1] * cos(angle)
///
/// # Panics
/// Panics if:
/// - `x.len() != dim`
/// - `x.len() != output.len()`
/// - `dim` is zero
/// - `dim` is odd (pair-wise rotation needs an even length)
pub fn rope_scalar(x: &[f32], position: u32, dim: usize, base: f32, output: &mut [f32]) {
    // The order of these checks decides which message fires when several fail.
    assert_eq!(x.len(), dim, "x length must equal dim");
    assert_eq!(x.len(), output.len(), "x/output length mismatch");
    assert!(dim > 0, "dim must be positive");
    assert_eq!(dim % 2, 0, "dim must be even for pair-wise rotation");

    let src_pairs = x.chunks_exact(2);
    let dst_pairs = output.chunks_exact_mut(2);
    for (pair, (src, dst)) in src_pairs.zip(dst_pairs).enumerate() {
        let exponent = -2.0 * pair as f32 / dim as f32;
        let angle = base.powf(exponent) * position as f32;
        let (sin_a, cos_a) = angle.sin_cos();

        let (even, odd) = (src[0], src[1]);
        dst[0] = even * cos_a - odd * sin_a;
        dst[1] = even * sin_a + odd * cos_a;
    }
}
50
// ────────────────────────────────────────────────────────────────────────────
// AVX2 implementation
// ────────────────────────────────────────────────────────────────────────────
54
55/// AVX2 RoPE — delegates to scalar (no hardware `sin`/`cos` in AVX2).
56///
57/// # Safety
58/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
59///
60/// # Panics
61/// Panics if `x.len() != dim`, `x.len() != output.len()`, `dim` is odd, or `dim` is zero.
62#[cfg(target_arch = "x86_64")]
63#[target_feature(enable = "avx2")]
64pub unsafe fn rope_avx2(x: &[f32], position: u32, dim: usize, base: f32, output: &mut [f32]) {
65    rope_scalar(x, position, dim, base, output);
66}
67
// ────────────────────────────────────────────────────────────────────────────
// PTX implementation
// ────────────────────────────────────────────────────────────────────────────
71
/// PTX assembly for the RoPE kernel (one thread per dimension pair).
///
/// Entry point `rope_kernel(input, output, position, dim, base)`. Thread
/// `k = ctaid.x * ntid.x + tid.x` handles the pair `(2k, 2k+1)`:
/// - computes `freq = base^(-2k/dim)` as `ex2.approx(lg2.approx(base) * (-2k/dim))`,
///   where the division by `dim` uses round-to-nearest `div.rn.f32` (same as the
///   IEEE division in [`rope_scalar`]);
/// - computes `theta = freq * position` and rotates the pair using
///   `sin.approx.f32` / `cos.approx.f32`.
///
/// Launch with at least `dim / 2` threads in total along x (only `.x` indices are
/// read). Threads with `k >= dim / 2` exit immediately. Unlike [`rope_scalar`], the
/// kernel does not validate `dim`: for an odd `dim`, the last element is never written.
///
/// Accuracy: `lg2`, `ex2`, `sin` and `cos` use the `.approx` hardware instructions, so
/// results are not bit-identical to [`rope_scalar`]. The PTX ISA bounds the
/// `sin/cos.approx` error only over a limited input range, so the deviation may grow
/// with `|theta|` (i.e. at large positions).
///
/// Register names deliberately avoid the predefined special registers
/// (`%tid`, `%ntid`, `%ctaid`). Pointer params are converted with
/// `cvta.to.global` before the `ld/st.global` accesses.
pub fn rope_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry rope_kernel(
    .param .u64 input,
    .param .u64 output,
    .param .u32 position,
    .param .u32 dim,
    .param .f32 base
) {
    .reg .u32 %r_tid, %r_ntid, %r_ctaid, %idx, %half_dim, %dim, %pos, %idx2;
    .reg .u64 %in_ptr, %out_ptr, %byte_off, %in_addr, %out_addr;
    .reg .f32 %x0, %x1, %y0, %y1;
    .reg .f32 %k_f, %dim_f, %exponent, %log_base, %freq, %pos_f, %theta;
    .reg .f32 %base_val, %cos_t, %sin_t;
    .reg .pred %p;

    // Pair index k = ctaid.x * ntid.x + tid.x
    mov.u32 %r_tid, %tid.x;
    mov.u32 %r_ntid, %ntid.x;
    mov.u32 %r_ctaid, %ctaid.x;
    mad.lo.u32 %idx, %r_ctaid, %r_ntid, %r_tid;

    // Threads past the last pair (k >= dim / 2) have nothing to do
    ld.param.u32 %dim, [dim];
    shr.u32 %half_dim, %dim, 1;
    setp.ge.u32 %p, %idx, %half_dim;
    @%p bra DONE;

    // Convert pointer params to global-space addresses for ld/st.global
    ld.param.u64 %in_ptr, [input];
    ld.param.u64 %out_ptr, [output];
    cvta.to.global.u64 %in_ptr, %in_ptr;
    cvta.to.global.u64 %out_ptr, %out_ptr;
    ld.param.u32 %pos, [position];
    ld.param.f32 %base_val, [base];

    // freq = base^(-2k/dim) = exp2(log2(base) * (-2k/dim))
    cvt.rn.f32.u32 %k_f, %idx;
    cvt.rn.f32.u32 %dim_f, %dim;
    mul.f32 %exponent, %k_f, 0fC0000000;
    div.rn.f32 %exponent, %exponent, %dim_f;
    lg2.approx.f32 %log_base, %base_val;
    mul.f32 %exponent, %log_base, %exponent;
    ex2.approx.f32 %freq, %exponent;

    // theta = freq * position
    cvt.rn.f32.u32 %pos_f, %pos;
    mul.f32 %theta, %freq, %pos_f;
    cos.approx.f32 %cos_t, %theta;
    sin.approx.f32 %sin_t, %theta;

    // Element 2k lives at byte offset 8k; element 2k+1 is 4 bytes after it
    shl.b32 %idx2, %idx, 1;
    mul.wide.u32 %byte_off, %idx2, 4;
    add.u64 %in_addr, %in_ptr, %byte_off;
    add.u64 %out_addr, %out_ptr, %byte_off;
    ld.global.f32 %x0, [%in_addr];
    ld.global.f32 %x1, [%in_addr+4];

    // y0 = x0 * cos - x1 * sin
    mul.f32 %y0, %x1, %sin_t;
    neg.f32 %y0, %y0;
    fma.rn.f32 %y0, %x0, %cos_t, %y0;
    // y1 = x0 * sin + x1 * cos
    mul.f32 %y1, %x0, %sin_t;
    fma.rn.f32 %y1, %x1, %cos_t, %y1;

    st.global.f32 [%out_addr], %y0;
    st.global.f32 [%out_addr+4], %y1;

DONE:
    ret;
}
"#
}
172
// ────────────────────────────────────────────────────────────────────────────
// Tests
// ────────────────────────────────────────────────────────────────────────────
176
#[cfg(test)]
mod tests {
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // ── Helpers ───────────────────────────────────────────────────────────

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|&e| e * e).sum::<f32>().sqrt()
    }

    /// Strategy for RoPE inputs: even length in `2..=14`, elements drawn from `lo..hi`.
    ///
    /// The length is generated as `pairs * 2`, so every draw is a valid `dim` and no
    /// cases are rejected (a `prop_filter` on the length would discard every odd draw).
    fn even_len_vec(lo: f32, hi: f32) -> impl Strategy<Value = Vec<f32>> {
        (1..8usize).prop_flat_map(move |pairs| proptest::collection::vec(lo..hi, pairs * 2))
    }

    /// Runs both backends on `x` (base 10000) and requires 0-ULP agreement.
    /// Returns early, so the calling test passes vacuously, when the CPU lacks AVX2.
    #[cfg(target_arch = "x86_64")]
    fn assert_avx2_matches_scalar(x: &[f32], position: u32) {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let dim = x.len();
        let mut scalar_out = vec![0.0f32; dim];
        let mut avx2_out = vec![0.0f32; dim];

        rope_scalar(x, position, dim, 10000.0, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { rope_avx2(x, position, dim, 10000.0, &mut avx2_out) };

        // The AVX2 entry point delegates to scalar, so 0 ULP is expected.
        assert_ulp_eq(&scalar_out, &avx2_out, 0);
    }

    // ── Known-answer tests ────────────────────────────────────────────────

    #[test]
    fn test_rope_position_zero_identity() {
        // At position 0, all angles are 0: cos(0)=1, sin(0)=0 -> identity
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; 4];
        rope_scalar(&x, 0, 4, 10000.0, &mut output);
        for i in 0..4 {
            assert!(
                (output[i] - x[i]).abs() < 1e-6,
                "RoPE at position 0 should be identity: x[{i}]={}, output[{i}]={}",
                x[i],
                output[i]
            );
        }
    }

    #[test]
    fn test_rope_preserves_norm() {
        // Rotation preserves vector norm: |output| = |input|
        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut output = vec![0.0f32; 6];
        rope_scalar(&x, 42, 6, 10000.0, &mut output);

        let input_norm = l2_norm(&x);
        let output_norm = l2_norm(&output);

        assert!(
            (input_norm - output_norm).abs() < 1e-4,
            "RoPE should preserve norm: input={input_norm}, output={output_norm}"
        );
    }

    #[test]
    fn test_rope_pair_norm_preserved() {
        // Each pair (x[2k], x[2k+1]) should have its norm preserved independently
        let x = [3.0f32, 4.0, 1.0, 0.0];
        let mut output = vec![0.0f32; 4];
        rope_scalar(&x, 10, 4, 10000.0, &mut output);

        for (pair, (src, dst)) in x.chunks_exact(2).zip(output.chunks_exact(2)).enumerate() {
            let norm_in = l2_norm(src);
            let norm_out = l2_norm(dst);
            assert!(
                (norm_in - norm_out).abs() < 1e-5,
                "Pair {pair} norm not preserved: in={norm_in}, out={norm_out}"
            );
        }
    }

    #[test]
    fn test_rope_known_rotation() {
        // For dim=2 with base=1.0, theta = 1.0^0 * pos = pos
        // At position 1: theta = 1.0, cos(1) ~ 0.5403, sin(1) ~ 0.8415
        let x = [1.0f32, 0.0];
        let mut output = vec![0.0f32; 2];
        rope_scalar(&x, 1, 2, 1.0, &mut output);
        let cos1 = 1.0f32.cos();
        let sin1 = 1.0f32.sin();
        assert!(
            (output[0] - cos1).abs() < 1e-6,
            "RoPE(1,0) at pos=1: expected ({cos1}, {sin1}), got ({}, {})",
            output[0],
            output[1]
        );
        assert!(
            (output[1] - sin1).abs() < 1e-6,
            "RoPE(1,0) at pos=1: expected ({cos1}, {sin1}), got ({}, {})",
            output[0],
            output[1]
        );
    }

    #[test]
    fn test_rope_default_base() {
        // Standard transformer base = 10000
        let x = [1.0f32, 0.0, 0.0, 1.0];
        let mut output = vec![0.0f32; 4];
        rope_scalar(&x, 100, 4, 10000.0, &mut output);

        // theta_0 = 10000^(0/4) * 100 = 1 * 100 = 100
        // theta_1 = 10000^(-2/4) * 100 = 10000^(-0.5) * 100 = 0.01 * 100 = 1.0
        let theta0 = 100.0f32;
        let theta1 = 10000.0f32.powf(-0.5) * 100.0;

        // Pair 0 is (1, 0): rotates to (cos, sin)
        let expected_0 = theta0.cos();
        let expected_1 = theta0.sin();
        assert!(
            (output[0] - expected_0).abs() < 1e-4,
            "pair 0: expected cos({theta0})={expected_0}, got {}",
            output[0]
        );
        assert!(
            (output[1] - expected_1).abs() < 1e-4,
            "pair 0: expected sin({theta0})={expected_1}, got {}",
            output[1]
        );

        // Pair 1 is (0, 1): rotates to (-sin, cos)
        let expected_2 = -(theta1.sin());
        let expected_3 = theta1.cos();
        assert!(
            (output[2] - expected_2).abs() < 1e-4,
            "pair 1: expected -sin({theta1})={expected_2}, got {}",
            output[2]
        );
        assert!(
            (output[3] - expected_3).abs() < 1e-4,
            "pair 1: expected cos({theta1})={expected_3}, got {}",
            output[3]
        );
    }

    #[test]
    #[should_panic(expected = "dim must be even")]
    fn test_rope_odd_dim_panics() {
        let x = [1.0f32, 2.0, 3.0];
        let mut output = vec![0.0f32; 3];
        rope_scalar(&x, 1, 3, 10000.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "x length must equal dim")]
    fn test_rope_length_mismatch() {
        let x = [1.0f32, 2.0];
        let mut output = vec![0.0f32; 2];
        rope_scalar(&x, 1, 4, 10000.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "x/output length mismatch")]
    fn test_rope_output_length_mismatch() {
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; 6];
        rope_scalar(&x, 1, 4, 10000.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "dim must be positive")]
    fn test_rope_zero_dim_panics() {
        let x: [f32; 0] = [];
        let mut output: [f32; 0] = [];
        rope_scalar(&x, 1, 0, 10000.0, &mut output);
    }

    // ── Property-based tests ──────────────────────────────────────────────

    proptest! {
        #[test]
        fn prop_rope_preserves_norm(
            x in even_len_vec(-10.0, 10.0),
            position in 0u32..1000,
        ) {
            let dim = x.len();
            let mut output = vec![0.0f32; dim];
            rope_scalar(&x, position, dim, 10000.0, &mut output);

            let input_norm = l2_norm(&x);
            let output_norm = l2_norm(&output);

            prop_assert!(
                (input_norm - output_norm).abs() < 1e-3,
                "Norm not preserved: input={input_norm}, output={output_norm}"
            );
        }

        #[test]
        fn prop_rope_position_zero_identity(
            x in even_len_vec(-10.0, 10.0),
        ) {
            let dim = x.len();
            let mut output = vec![0.0f32; dim];
            rope_scalar(&x, 0, dim, 10000.0, &mut output);

            for (i, (&xi, &yi)) in x.iter().zip(output.iter()).enumerate() {
                prop_assert!(
                    (xi - yi).abs() < 1e-6,
                    "RoPE at position 0 should be identity: index {i}, x={xi}, output={yi}"
                );
            }
        }

        #[test]
        fn prop_rope_output_finite(
            x in even_len_vec(-100.0, 100.0),
            position in 0u32..10000,
        ) {
            let dim = x.len();
            let mut output = vec![0.0f32; dim];
            rope_scalar(&x, position, dim, 10000.0, &mut output);

            for (i, &y) in output.iter().enumerate() {
                prop_assert!(
                    y.is_finite(),
                    "RoPE output must be finite at index {i}, got {y}"
                );
            }
        }
    }

    // ── AVX2 parity tests ─────────────────────────────────────────────────

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_rope_avx2_parity() {
        let x: Vec<f32> = (0..16).map(|i| i as f32 * 0.5).collect();
        assert_avx2_matches_scalar(&x, 42);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_rope_avx2_small_dim() {
        assert_avx2_matches_scalar(&[1.0f32, 2.0], 100);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_rope_avx2_position_zero() {
        let x: Vec<f32> = (0..8).map(|i| i as f32).collect();
        assert_avx2_matches_scalar(&x, 0);
    }

    // ── PTX structural tests ──────────────────────────────────────────────

    #[test]
    fn test_rope_ptx_structure() {
        let ptx = rope_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(ptx.contains(".entry rope_kernel"), "missing entry point");
        assert!(ptx.contains("ret;"), "missing ret instruction");
        assert!(
            ptx.contains("sin.approx.f32"),
            "missing sin.approx for trig"
        );
        assert!(
            ptx.contains("cos.approx.f32"),
            "missing cos.approx for trig"
        );
        assert!(
            ptx.contains("ex2.approx.f32"),
            "missing ex2.approx for powf"
        );
        assert!(
            ptx.contains("lg2.approx.f32"),
            "missing lg2.approx for powf"
        );
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    #[test]
    fn test_rope_ptx_nonempty() {
        assert!(!rope_ptx().is_empty());
    }

    #[test]
    fn test_rope_ptx_has_params() {
        let ptx = rope_ptx();
        assert!(ptx.contains(".param .u64 input"), "missing input param");
        assert!(ptx.contains(".param .u64 output"), "missing output param");
        assert!(
            ptx.contains(".param .u32 position"),
            "missing position param"
        );
        assert!(ptx.contains(".param .u32 dim"), "missing dim param");
        assert!(ptx.contains(".param .f32 base"), "missing base param");
    }
}