/// Applies rotary position embedding (RoPE) to one vector, writing the
/// rotated values into `output`.
///
/// Elements are rotated in adjacent (interleaved) pairs `(x[2k], x[2k + 1])`
/// by the angle `theta_k = base^(-2k / dim) * position`:
///
/// ```text
/// output[2k]     = x[2k] * cos(theta_k) - x[2k + 1] * sin(theta_k)
/// output[2k + 1] = x[2k] * sin(theta_k) + x[2k + 1] * cos(theta_k)
/// ```
///
/// All arithmetic is performed in `f32`.
///
/// # Panics
///
/// Panics if `x.len() != dim`, if `output.len() != x.len()`, if `dim == 0`,
/// or if `dim` is odd.
pub fn rope_scalar(x: &[f32], position: u32, dim: usize, base: f32, output: &mut [f32]) {
    assert_eq!(x.len(), dim, "x length must equal dim");
    assert_eq!(x.len(), output.len(), "x/output length mismatch");
    assert!(dim > 0, "dim must be positive");
    assert_eq!(dim % 2, 0, "dim must be even for pair-wise rotation");

    let dim_f = dim as f32;
    let pos_f = position as f32;

    // `dim` is even and both slices have exactly `dim` elements, so each side
    // yields exactly `dim / 2` complete pairs and nothing is left over.
    let pairs = x.chunks_exact(2).zip(output.chunks_exact_mut(2));
    for (k, (src, dst)) in pairs.enumerate() {
        let exponent = -2.0 * k as f32 / dim_f;
        let angle = base.powf(exponent) * pos_f;
        let (sin_t, cos_t) = angle.sin_cos();

        let (a, b) = (src[0], src[1]);
        dst[0] = a * cos_t - b * sin_t;
        dst[1] = a * sin_t + b * cos_t;
    }
}
50
51#[cfg(target_arch = "x86_64")]
63#[target_feature(enable = "avx2")]
64pub unsafe fn rope_avx2(x: &[f32], position: u32, dim: usize, base: f32, output: &mut [f32]) {
65 rope_scalar(x, position, dim, base, output);
66}
67
/// Returns PTX source for `rope_kernel`, a GPU counterpart of [`rope_scalar`].
///
/// # Kernel contract
///
/// * Parameters: `input` / `output` are device pointers to `dim` `f32` values
///   each, `position` is the token position, `dim` is the vector length, and
///   `base` is the frequency base.
/// * One thread handles one rotation pair `k = ctaid.x * ntid.x + tid.x`.
///   Threads with `k >= dim / 2` return without touching memory. There is no
///   loop over pairs, so the launch must supply at least `dim / 2` threads
///   along x.
/// * `dim` is not validated: for an odd `dim` the last element of `output` is
///   never written.
///
/// # Precision
///
/// The kernel uses the approximate hardware instructions `div.approx`,
/// `lg2.approx` and `ex2.approx` to compute `base^(-2k/dim)`. It uses
/// `sin.approx` / `cos.approx` for the rotation angle. Results are therefore
/// not bit-identical to the CPU path.
///
/// The PTX ISA bounds the error of `sin.approx` / `cos.approx` only over a
/// limited input interval. Accuracy can therefore degrade for large angles,
/// i.e. large `position` values.
pub fn rope_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

.visible .entry rope_kernel(
    .param .u64 input,
    .param .u64 output,
    .param .u32 position,
    .param .u32 dim,
    .param .f32 base
) {
    // Register names are deliberately distinct from the predefined special
    // registers %tid, %ntid and %ctaid, which are read below via their .x parts.
    .reg .u32 %r_tid, %r_ntid, %r_ctaid, %idx, %idx2, %half_dim, %dim, %pos;
    .reg .u64 %in_ptr, %out_ptr, %off, %in_addr, %out_addr;
    .reg .f32 %x0, %x1, %y0, %y1;
    .reg .f32 %two_k, %dim_f, %neg_exp, %base_val, %log_base, %freq;
    .reg .f32 %pos_f, %theta, %cos_t, %sin_t;
    .reg .pred %p;

    // k = ctaid.x * ntid.x + tid.x
    mov.u32 %r_tid, %tid.x;
    mov.u32 %r_ntid, %ntid.x;
    mov.u32 %r_ctaid, %ctaid.x;
    mad.lo.u32 %idx, %r_ctaid, %r_ntid, %r_tid;

    // Threads past the last pair have nothing to do.
    ld.param.u32 %dim, [dim];
    shr.u32 %half_dim, %dim, 1;
    setp.ge.u32 %p, %idx, %half_dim;
    @%p bra DONE;

    // Pointer parameters are generic addresses; convert them to the global
    // state space before the global loads and stores below.
    ld.param.u64 %in_ptr, [input];
    ld.param.u64 %out_ptr, [output];
    cvta.to.global.u64 %in_ptr, %in_ptr;
    cvta.to.global.u64 %out_ptr, %out_ptr;
    ld.param.u32 %pos, [position];
    ld.param.f32 %base_val, [base];

    // freq = base^(-2k/dim) = exp2(log2(base) * (-2k/dim))
    shl.b32 %idx2, %idx, 1;
    cvt.rn.f32.u32 %two_k, %idx2;
    cvt.rn.f32.u32 %dim_f, %dim;
    neg.f32 %neg_exp, %two_k;
    div.approx.f32 %neg_exp, %neg_exp, %dim_f;
    lg2.approx.f32 %log_base, %base_val;
    mul.f32 %neg_exp, %log_base, %neg_exp;
    ex2.approx.f32 %freq, %neg_exp;

    // theta = freq * position
    cvt.rn.f32.u32 %pos_f, %pos;
    mul.f32 %theta, %freq, %pos_f;
    cos.approx.f32 %cos_t, %theta;
    sin.approx.f32 %sin_t, %theta;

    // Pair k occupies bytes [8k, 8k + 8) of both buffers.
    mul.wide.u32 %off, %idx, 8;
    add.u64 %in_addr, %in_ptr, %off;
    add.u64 %out_addr, %out_ptr, %off;
    ld.global.f32 %x0, [%in_addr];
    ld.global.f32 %x1, [%in_addr+4];

    // y0 = x0 * cos - x1 * sin
    mul.f32 %y0, %x1, %sin_t;
    neg.f32 %y0, %y0;
    fma.rn.f32 %y0, %x0, %cos_t, %y0;

    // y1 = x0 * sin + x1 * cos
    mul.f32 %y1, %x0, %sin_t;
    fma.rn.f32 %y1, %x1, %cos_t, %y1;

    st.global.f32 [%out_addr], %y0;
    st.global.f32 [%out_addr+4], %y1;

DONE:
    ret;
}
"#
}
172
#[cfg(test)]
mod tests {
    // Only the AVX2 parity checks (x86_64-only) use this helper.
    #[cfg(target_arch = "x86_64")]
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    /// Strategy for even-length vectors of 2..=14 elements (1..=7 rotation
    /// pairs) with elements drawn from `range`.
    ///
    /// The pair count is drawn and then doubled, so every generated length is
    /// valid. Filtering arbitrary lengths would instead throw away roughly
    /// half of the generated cases.
    fn even_len_vec(range: std::ops::Range<f32>) -> impl Strategy<Value = Vec<f32>> {
        (1usize..8).prop_flat_map(move |pairs| {
            proptest::collection::vec(range.clone(), pairs * 2)
        })
    }

    /// Runs `rope_scalar` and `rope_avx2` on the same input with base 10000
    /// and requires bit-identical (0 ULP) output. Returns early, passing,
    /// when the CPU lacks AVX2.
    #[cfg(target_arch = "x86_64")]
    fn check_avx2_parity(x: &[f32], position: u32) {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let dim = x.len();
        let mut scalar_out = vec![0.0f32; dim];
        let mut avx2_out = vec![0.0f32; dim];

        rope_scalar(x, position, dim, 10000.0, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { rope_avx2(x, position, dim, 10000.0, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 0);
    }

    #[test]
    fn test_rope_position_zero_identity() {
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; 4];
        rope_scalar(&x, 0, 4, 10000.0, &mut output);
        for i in 0..4 {
            assert!(
                (output[i] - x[i]).abs() < 1e-6,
                "RoPE at position 0 should be identity: x[{i}]={}, output[{i}]={}",
                x[i],
                output[i]
            );
        }
    }

    #[test]
    fn test_rope_preserves_norm() {
        let x = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut output = vec![0.0f32; 6];
        rope_scalar(&x, 42, 6, 10000.0, &mut output);

        let input_norm: f32 = x.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let output_norm: f32 = output.iter().map(|&v| v * v).sum::<f32>().sqrt();

        assert!(
            (input_norm - output_norm).abs() < 1e-4,
            "RoPE should preserve norm: input={input_norm}, output={output_norm}"
        );
    }

    #[test]
    fn test_rope_pair_norm_preserved() {
        let x = [3.0f32, 4.0, 1.0, 0.0];
        let mut output = vec![0.0f32; 4];
        rope_scalar(&x, 10, 4, 10000.0, &mut output);

        let pair0_in = (x[0] * x[0] + x[1] * x[1]).sqrt();
        let pair0_out = (output[0] * output[0] + output[1] * output[1]).sqrt();
        assert!(
            (pair0_in - pair0_out).abs() < 1e-5,
            "Pair 0 norm not preserved: in={pair0_in}, out={pair0_out}"
        );

        let pair1_in = (x[2] * x[2] + x[3] * x[3]).sqrt();
        let pair1_out = (output[2] * output[2] + output[3] * output[3]).sqrt();
        assert!(
            (pair1_in - pair1_out).abs() < 1e-5,
            "Pair 1 norm not preserved: in={pair1_in}, out={pair1_out}"
        );
    }

    #[test]
    fn test_rope_known_rotation() {
        // base = 1 makes every frequency 1, so theta = position = 1 rad.
        let x = [1.0f32, 0.0];
        let mut output = vec![0.0f32; 2];
        rope_scalar(&x, 1, 2, 1.0, &mut output);
        let cos1 = 1.0f32.cos();
        let sin1 = 1.0f32.sin();
        assert!(
            (output[0] - cos1).abs() < 1e-6,
            "RoPE(1,0) at pos=1: expected ({cos1}, {sin1}), got ({}, {})",
            output[0],
            output[1]
        );
        assert!(
            (output[1] - sin1).abs() < 1e-6,
            "RoPE(1,0) at pos=1: expected ({cos1}, {sin1}), got ({}, {})",
            output[0],
            output[1]
        );
    }

    #[test]
    fn test_rope_default_base() {
        let x = [1.0f32, 0.0, 0.0, 1.0];
        let mut output = vec![0.0f32; 4];
        rope_scalar(&x, 100, 4, 10000.0, &mut output);

        // dim = 4: pair 0 has frequency 10000^0 = 1, pair 1 has 10000^(-1/2).
        let theta0 = 100.0f32;
        let theta1 = 10000.0f32.powf(-0.5) * 100.0;

        let expected_0 = theta0.cos();
        let expected_1 = theta0.sin();
        assert!(
            (output[0] - expected_0).abs() < 1e-4,
            "pair 0: expected cos({theta0})={expected_0}, got {}",
            output[0]
        );
        assert!(
            (output[1] - expected_1).abs() < 1e-4,
            "pair 0: expected sin({theta0})={expected_1}, got {}",
            output[1]
        );

        let expected_2 = -(theta1.sin());
        let expected_3 = theta1.cos();
        assert!(
            (output[2] - expected_2).abs() < 1e-4,
            "pair 1: expected -sin({theta1})={expected_2}, got {}",
            output[2]
        );
        assert!(
            (output[3] - expected_3).abs() < 1e-4,
            "pair 1: expected cos({theta1})={expected_3}, got {}",
            output[3]
        );
    }

    #[test]
    #[should_panic(expected = "dim must be even")]
    fn test_rope_odd_dim_panics() {
        let x = [1.0f32, 2.0, 3.0];
        let mut output = vec![0.0f32; 3];
        rope_scalar(&x, 1, 3, 10000.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "x length must equal dim")]
    fn test_rope_length_mismatch() {
        let x = [1.0f32, 2.0];
        let mut output = vec![0.0f32; 2];
        rope_scalar(&x, 1, 4, 10000.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "x/output length mismatch")]
    fn test_rope_output_length_mismatch() {
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; 6];
        rope_scalar(&x, 1, 4, 10000.0, &mut output);
    }

    #[test]
    #[should_panic(expected = "dim must be positive")]
    fn test_rope_zero_dim_panics() {
        let x: [f32; 0] = [];
        let mut output: [f32; 0] = [];
        rope_scalar(&x, 1, 0, 10000.0, &mut output);
    }

    proptest! {
        #[test]
        fn prop_rope_preserves_norm(
            x in even_len_vec(-10.0..10.0),
            position in 0u32..1000,
        ) {
            let dim = x.len();
            let mut output = vec![0.0f32; dim];
            rope_scalar(&x, position, dim, 10000.0, &mut output);

            let input_norm: f32 = x.iter().map(|&v| v * v).sum::<f32>().sqrt();
            let output_norm: f32 = output.iter().map(|&v| v * v).sum::<f32>().sqrt();

            prop_assert!(
                (input_norm - output_norm).abs() < 1e-3,
                "Norm not preserved: input={input_norm}, output={output_norm}"
            );
        }

        #[test]
        fn prop_rope_position_zero_identity(
            x in even_len_vec(-10.0..10.0),
        ) {
            let dim = x.len();
            let mut output = vec![0.0f32; dim];
            rope_scalar(&x, 0, dim, 10000.0, &mut output);

            for (i, (&xi, &yi)) in x.iter().zip(output.iter()).enumerate() {
                prop_assert!(
                    (xi - yi).abs() < 1e-6,
                    "RoPE at position 0 should be identity: index {i}, x={xi}, output={yi}"
                );
            }
        }

        #[test]
        fn prop_rope_output_finite(
            x in even_len_vec(-100.0..100.0),
            position in 0u32..10000,
        ) {
            let dim = x.len();
            let mut output = vec![0.0f32; dim];
            rope_scalar(&x, position, dim, 10000.0, &mut output);

            for (i, &y) in output.iter().enumerate() {
                prop_assert!(
                    y.is_finite(),
                    "RoPE output must be finite at index {i}, got {y}"
                );
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_rope_avx2_parity() {
        let x: Vec<f32> = (0..16).map(|i| i as f32 * 0.5).collect();
        check_avx2_parity(&x, 42);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_rope_avx2_small_dim() {
        check_avx2_parity(&[1.0f32, 2.0], 100);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_rope_avx2_position_zero() {
        let x: Vec<f32> = (0..8).map(|i| i as f32).collect();
        check_avx2_parity(&x, 0);
    }

    #[test]
    fn test_rope_ptx_structure() {
        let ptx = rope_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(ptx.contains(".entry rope_kernel"), "missing entry point");
        assert!(ptx.contains("ret;"), "missing ret instruction");
        assert!(ptx.contains("sin.approx.f32"), "missing sin.approx for trig");
        assert!(ptx.contains("cos.approx.f32"), "missing cos.approx for trig");
        assert!(ptx.contains("ex2.approx.f32"), "missing ex2.approx for powf");
        assert!(ptx.contains("lg2.approx.f32"), "missing lg2.approx for powf");
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(open, close, "unbalanced braces: {open} open vs {close} close");
    }

    #[test]
    fn test_rope_ptx_nonempty() {
        assert!(!rope_ptx().is_empty());
    }

    #[test]
    fn test_rope_ptx_has_params() {
        let ptx = rope_ptx();
        assert!(ptx.contains(".param .u64 input"), "missing input param");
        assert!(ptx.contains(".param .u64 output"), "missing output param");
        assert!(ptx.contains(".param .u32 position"), "missing position param");
        assert!(ptx.contains(".param .u32 dim"), "missing dim param");
        assert!(ptx.contains(".param .f32 base"), "missing base param");
    }
}