Skip to main content

rlx_ir/ops/
axial_rope2d.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM2-style axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
17
18use crate::{Graph, NodeId, Op};
19
/// Apply SAM2-style axial 2-D RoPE to a token-major `[n_tokens, num_heads * head_dim]` buffer.
///
/// Element `(tok, h, d)` of `x` is read at `tok * num_heads * head_dim + h * head_dim + d`.
/// Inside every head the channels are rotated as adjacent `(2c, 2c + 1)` pairs:
///
/// * the first `head_dim / 2` channels by the angle `tx * freq[c]`,
/// * the second `head_dim / 2` channels by the angle `ty * freq[c]`,
///
/// where `freq[c] = 1 / theta^(4c / head_dim)` for `c in 0..head_dim / 4`, and the token's
/// grid coordinates are `tx = pos % end_x`, `ty = pos / end_x` with
/// `pos = tok / max(repeat_factor, 1)`.
///
/// NOTE(review): `tok / repeat` assigns each grid position to `repeat` *consecutive* tokens.
/// The SAM2 reference appears to tile the whole grid instead (`tok % (end_x * end_y)`) —
/// confirm which ordering the callers' key layout expects.
///
/// Returns a new vector with the same length as `x`. Only the first
/// `n_tokens * num_heads * head_dim` elements are rotated; any trailing elements are copied
/// through unchanged. When `head_dim < 4` there are no rotation pairs and the result is a
/// plain copy of `x`.
///
/// # Panics
///
/// Debug builds assert that `head_dim` is a multiple of 4 and that
/// `n_tokens == end_x * end_y * max(repeat_factor, 1)`. In every build, slicing panics if a
/// token maps outside the `end_x * end_y` grid or if `x` is too short for the requested
/// `n_tokens * num_heads * head_dim` elements.
pub fn apply_axial_rope2d(
    x: &[f32],
    num_heads: usize,
    n_tokens: usize,
    head_dim: usize,
    end_x: usize,
    end_y: usize,
    theta: f32,
    repeat_factor: usize,
) -> Vec<f32> {
    debug_assert!(head_dim.is_multiple_of(4));
    let half = head_dim / 2;
    let q4 = head_dim / 4;
    let spatial = end_x * end_y;
    let repeat = repeat_factor.max(1);
    debug_assert_eq!(n_tokens, spatial * repeat);

    let mut out = x.to_vec();
    if q4 == 0 {
        // No complete channel quadruple, hence nothing to rotate.
        return out;
    }

    let freqs: Vec<f32> = (0..q4)
        .map(|i| 1.0 / theta.powf((4 * i) as f32 / head_dim as f32))
        .collect();

    // `(cos, sin)` per (grid position, frequency), stored row-major by position. Tokens index
    // these tables directly, so no per-token expanded copy is needed.
    let mut rot_x: Vec<(f32, f32)> = Vec::with_capacity(spatial * q4);
    let mut rot_y: Vec<(f32, f32)> = Vec::with_capacity(spatial * q4);
    for pos in 0..spatial {
        let tx = (pos % end_x) as f32;
        let ty = (pos / end_x) as f32;
        for &freq in &freqs {
            let ax = tx * freq;
            let ay = ty * freq;
            rot_x.push((ax.cos(), ax.sin()));
            rot_y.push((ay.cos(), ay.sin()));
        }
    }

    // Token-major layout: one row of `num_heads * head_dim` values per token.
    let hs = num_heads * head_dim;
    for tok in 0..n_tokens {
        let pos = tok / repeat;
        let tab_x = &rot_x[pos * q4..(pos + 1) * q4];
        let tab_y = &rot_y[pos * q4..(pos + 1) * q4];
        for h in 0..num_heads {
            let base = tok * hs + h * head_dim;
            // First half of the head carries the x axis, second half the y axis.
            rotate_adjacent_pairs(&mut out[base..base + 2 * q4], tab_x);
            rotate_adjacent_pairs(&mut out[base + half..base + half + 2 * q4], tab_y);
        }
    }
    out
}

/// Rotate each adjacent `(even, odd)` pair of `lane` in place by the matching `(cos, sin)`.
fn rotate_adjacent_pairs(lane: &mut [f32], table: &[(f32, f32)]) {
    for (pair, &(co, si)) in lane.chunks_exact_mut(2).zip(table) {
        let (x0, x1) = (pair[0], pair[1]);
        pair[0] = x0 * co - x1 * si;
        pair[1] = x0 * si + x1 * co;
    }
}
102
103impl Graph {
104    /// `x`: `[1, seq, num_heads * head_dim]` → same shape.
105    pub fn axial_rope2d(
106        &mut self,
107        x: NodeId,
108        end_x: usize,
109        end_y: usize,
110        head_dim: usize,
111        num_heads: usize,
112        theta: f32,
113        repeat_factor: usize,
114    ) -> NodeId {
115        let s = crate::shape::unary_shape(self.shape(x));
116        self.push(
117            Op::AxialRope2d {
118                end_x,
119                end_y,
120                head_dim,
121                num_heads,
122                theta,
123                repeat_factor,
124            },
125            vec![x],
126            s,
127            None,
128        )
129    }
130}