rlx_ir/ops/
axial_rope2d.rs1use crate::{Graph, NodeId, Op};
19
/// Reference CPU implementation of 2-D axial rotary position embedding (RoPE).
///
/// `x` is indexed as a row-major `[n_tokens, num_heads, head_dim]` buffer.
/// For every token and head:
/// - the first `head_dim / 2` channels are rotated by the token's column (`x`)
///   coordinate;
/// - the second half is rotated by its row (`y`) coordinate.
///
/// Within each half, channels are taken as adjacent `(2c, 2c + 1)` pairs,
/// treated as the complex number `v[2c] + i*v[2c + 1]`, and rotated by the angle
/// `coord * theta^(-4c / head_dim)` for `c in 0..head_dim / 4`.
///
/// Grid positions are numbered row-major over an `end_x`-wide, `end_y`-tall
/// grid: position `p` has `x = p % end_x` and `y = p / end_x`.
///
/// `repeat_factor` (clamped to at least 1) describes a token sequence made of
/// `repeat_factor` back-to-back copies of the grid. Token `t` therefore uses grid
/// position `t % (end_x * end_y)`. This is intended to match tiling the
/// frequency table with `torch.Tensor.repeat`, as done for multi-frame keys in
/// SAM2-style RoPE attention.
/// NOTE(review): the previous implementation used `t / repeat_factor`
/// (interleaved copies). Confirm that callers and any device kernels use the
/// block-wise layout.
///
/// Elements of `x` beyond the first `n_tokens * num_heads * head_dim` are copied
/// through unchanged.
///
/// # Panics
///
/// Panics in all builds if either:
/// - `x` holds fewer than `n_tokens * num_heads * head_dim` values, or
/// - `n_tokens > 0` while the grid is empty.
///
/// Debug builds additionally panic if either:
/// - `head_dim` is not a multiple of 4, or
/// - `n_tokens != end_x * end_y * max(repeat_factor, 1)`.
#[allow(clippy::too_many_arguments)] // mirrors the fields of `Op::AxialRope2d` plus the data
pub fn apply_axial_rope2d(
    x: &[f32],
    num_heads: usize,
    n_tokens: usize,
    head_dim: usize,
    end_x: usize,
    end_y: usize,
    theta: f32,
    repeat_factor: usize,
) -> Vec<f32> {
    debug_assert!(head_dim.is_multiple_of(4));
    let half = head_dim / 2;
    let q4 = head_dim / 4;
    let spatial = end_x * end_y;
    let repeat = repeat_factor.max(1);
    debug_assert_eq!(n_tokens, spatial * repeat);

    let hs = num_heads * head_dim;
    assert!(
        x.len() >= n_tokens * hs,
        "apply_axial_rope2d: x has {} elements, need at least {}",
        x.len(),
        n_tokens * hs
    );
    assert!(
        n_tokens == 0 || spatial > 0,
        "apply_axial_rope2d: {n_tokens} tokens but the {end_x}x{end_y} grid is empty"
    );

    let mut out = x.to_vec();
    // Nothing to rotate: no tokens, no heads, or fewer than one channel pair per half.
    if n_tokens == 0 || num_heads == 0 || q4 == 0 {
        return out;
    }

    let freqs: Vec<f32> = (0..q4)
        .map(|i| 1.0 / theta.powf((4 * i) as f32 / head_dim as f32))
        .collect();

    // Per-grid-position rotation tables, row-major `[spatial, q4]`.
    // They are shared by every repeated copy of the grid, so no per-token tables are built.
    let mut cos_x = Vec::with_capacity(spatial * q4);
    let mut sin_x = Vec::with_capacity(spatial * q4);
    let mut cos_y = Vec::with_capacity(spatial * q4);
    let mut sin_y = Vec::with_capacity(spatial * q4);
    for pos in 0..spatial {
        let tx = (pos % end_x) as f32;
        let ty = (pos / end_x) as f32;
        for &f in &freqs {
            let (ax, ay) = (tx * f, ty * f);
            cos_x.push(ax.cos());
            sin_x.push(ax.sin());
            cos_y.push(ay.cos());
            sin_y.push(ay.sin());
        }
    }

    for (tok, token) in out.chunks_exact_mut(hs).take(n_tokens).enumerate() {
        // Tokens are `repeat` back-to-back copies of the grid, so the grid
        // position cycles with period `spatial` (tiling, not interleaving).
        let row = (tok % spatial) * q4;
        let (cx, sx) = (&cos_x[row..row + q4], &sin_x[row..row + q4]);
        let (cy, sy) = (&cos_y[row..row + q4], &sin_y[row..row + q4]);
        for head in token.chunks_exact_mut(head_dim) {
            // First half of the head carries the x rotation, second half the y rotation.
            let (x_half, y_half) = head.split_at_mut(half);
            rotate_adjacent_pairs(x_half, cx, sx);
            rotate_adjacent_pairs(y_half, cy, sy);
        }
    }
    out
}

/// Rotates each adjacent `(v[2c], v[2c + 1])` pair of `v` as the complex number
/// `v[2c] + i*v[2c + 1]` multiplied by `cos[c] + i*sin[c]`.
/// Stops at whichever is shortest: the pair count, `cos`, or `sin`.
fn rotate_adjacent_pairs(v: &mut [f32], cos: &[f32], sin: &[f32]) {
    for ((pair, &co), &si) in v.chunks_exact_mut(2).zip(cos).zip(sin) {
        let (x0, x1) = (pair[0], pair[1]);
        pair[0] = x0 * co - x1 * si;
        pair[1] = x0 * si + x1 * co;
    }
}
102
103impl Graph {
104 pub fn axial_rope2d(
106 &mut self,
107 x: NodeId,
108 end_x: usize,
109 end_y: usize,
110 head_dim: usize,
111 num_heads: usize,
112 theta: f32,
113 repeat_factor: usize,
114 ) -> NodeId {
115 let s = crate::shape::unary_shape(self.shape(x));
116 self.push(
117 Op::AxialRope2d {
118 end_x,
119 end_y,
120 head_dim,
121 num_heads,
122 theta,
123 repeat_factor,
124 },
125 vec![x],
126 s,
127 None,
128 )
129 }
130}