rlx_ir/ops/
axial_rope2d.rs1use crate::{Graph, NodeId, Op};
7
/// Applies an axial 2-D rotary position embedding to a flat, token-major buffer.
///
/// `x` is indexed as `tok * num_heads * head_dim + h * head_dim + d`, i.e. `n_tokens`
/// rows of `num_heads` heads with `head_dim` values each. A new vector of the same
/// length as `x` is returned; any elements past `n_tokens * num_heads * head_dim`
/// are copied through unchanged.
///
/// Each token is mapped to a grid position `pos = tok / max(repeat_factor, 1)` with
/// column `pos % end_x` and row `pos / end_x`. Within every head:
///
/// * the first `head_dim / 2` values are treated as adjacent pairs `(2c, 2c + 1)` and
///   rotated by the angle `column * freq[c]`;
/// * the second `head_dim / 2` values are rotated the same way by `row * freq[c]`;
///
/// where `freq[c] = 1 / theta^(4c / head_dim)` for `c` in `0..head_dim / 4`.
///
/// # Panics
///
/// In debug builds, panics if `head_dim` is not a multiple of 4 or if
/// `n_tokens != end_x * end_y * max(repeat_factor, 1)`. In any build, panics with an
/// out-of-bounds index if a token maps past the `end_x * end_y` grid or if `x` is too
/// short for the elements that get rotated.
pub fn apply_axial_rope2d(
    x: &[f32],
    num_heads: usize,
    n_tokens: usize,
    head_dim: usize,
    end_x: usize,
    end_y: usize,
    theta: f32,
    repeat_factor: usize,
) -> Vec<f32> {
    debug_assert!(head_dim.is_multiple_of(4));
    let half = head_dim / 2;
    let q4 = head_dim / 4;
    let spatial = end_x * end_y;
    // A repeat factor of 0 is treated as "no repetition".
    let repeat = repeat_factor.max(1);
    debug_assert_eq!(n_tokens, spatial * repeat);

    // One frequency per rotated pair; shared by the x and y halves.
    let freqs: Vec<f32> = (0..q4)
        .map(|i| 1.0 / theta.powf((4 * i) as f32 / head_dim as f32))
        .collect();

    // Per-position tables, row-major over `pos`, `q4` entries per position.
    let table_len = spatial * q4;
    let mut cos_x = Vec::with_capacity(table_len);
    let mut sin_x = Vec::with_capacity(table_len);
    let mut cos_y = Vec::with_capacity(table_len);
    let mut sin_y = Vec::with_capacity(table_len);
    for pos in 0..spatial {
        let tx = (pos % end_x) as f32;
        let ty = (pos / end_x) as f32;
        for &freq in &freqs {
            let ax = tx * freq;
            let ay = ty * freq;
            cos_x.push(ax.cos());
            sin_x.push(ax.sin());
            cos_y.push(ay.cos());
            sin_y.push(ay.sin());
        }
    }

    let hs = num_heads * head_dim;
    let mut out = x.to_vec();
    for tok in 0..n_tokens {
        // Tokens that share a grid position share a table row, so the row is
        // looked up directly instead of materialising a per-token copy.
        let lo = (tok / repeat) * q4;
        let hi = lo + q4;
        let (cx, sx) = (&cos_x[lo..hi], &sin_x[lo..hi]);
        let (cy, sy) = (&cos_y[lo..hi], &sin_y[lo..hi]);
        for h in 0..num_heads {
            let base = tok * hs + h * head_dim;
            rotate_pairs(&mut out, base, cx, sx);
            rotate_pairs(&mut out, base + half, cy, sy);
        }
    }
    out
}

/// Rotates adjacent pairs `(start + 2c, start + 2c + 1)` of `buf` in place by the
/// angle whose cosine/sine are `cos[c]` / `sin[c]`.
fn rotate_pairs(buf: &mut [f32], start: usize, cos: &[f32], sin: &[f32]) {
    for (c, (&co, &si)) in cos.iter().zip(sin).enumerate() {
        let i0 = start + 2 * c;
        let i1 = i0 + 1;
        let (a, b) = (buf[i0], buf[i1]);
        buf[i0] = a * co - b * si;
        buf[i1] = a * si + b * co;
    }
}
90
91impl Graph {
92 pub fn axial_rope2d(
94 &mut self,
95 x: NodeId,
96 end_x: usize,
97 end_y: usize,
98 head_dim: usize,
99 num_heads: usize,
100 theta: f32,
101 repeat_factor: usize,
102 ) -> NodeId {
103 let s = crate::shape::unary_shape(self.shape(x));
104 self.push(
105 Op::AxialRope2d {
106 end_x,
107 end_y,
108 head_dim,
109 num_heads,
110 theta,
111 repeat_factor,
112 },
113 vec![x],
114 s,
115 None,
116 )
117 }
118}