Skip to main content

rlx_ir/ops/
axial_rope2d.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! SAM2-style axial 2-D RoPE on `[batch, seq, num_heads * head_dim]`.
5
6use crate::{Graph, NodeId, Op};
7
/// Apply SAM2-style axial 2-D RoPE to a flattened `[batch, n_tokens, num_heads * head_dim]`
/// buffer (token-major, heads interleaved within each token) and return the rotated copy.
///
/// Each head vector is split in half:
/// - the first `head_dim / 2` values are rotated by the token's x (column) coordinate;
/// - the second half is rotated by its y (row) coordinate.
///
/// Within each half, pair `(2c, 2c + 1)` is treated as a complex number and multiplied by
/// `exp(i * t * freq_c)`, where `freq_c = 1 / theta^(4c / head_dim)` for `c in 0..head_dim / 4`.
///
/// Grid cells are numbered row-major over an `end_x × end_y` grid
/// (`x = pos % end_x`, `y = pos / end_x`). When `repeat_factor > 1`, the sequence holds
/// `repeat_factor` consecutive copies of that grid, e.g. several memory frames concatenated
/// along the sequence axis. Token `t` therefore uses cell `t % (end_x * end_y)`, matching
/// SAM2's `repeat_freqs_k`, which tiles the frequency table.
///
/// Every full `n_tokens * num_heads * head_dim` chunk of `x` is rotated as one batch item.
/// A trailing partial chunk, if any, is copied through unchanged.
///
/// Debug builds check that:
/// - `head_dim` is a multiple of 4;
/// - `n_tokens == end_x * end_y * max(repeat_factor, 1)`;
/// - `x.len()` is a whole number of batch items.
pub fn apply_axial_rope2d(
    x: &[f32],
    num_heads: usize,
    n_tokens: usize,
    head_dim: usize,
    end_x: usize,
    end_y: usize,
    theta: f32,
    repeat_factor: usize,
) -> Vec<f32> {
    debug_assert!(head_dim % 4 == 0, "head_dim must be a multiple of 4");
    let half = head_dim / 2;
    let q4 = head_dim / 4;
    let spatial = end_x * end_y;
    let repeat = repeat_factor.max(1);
    debug_assert_eq!(n_tokens, spatial * repeat);

    let hs = num_heads * head_dim;
    let batch_len = n_tokens * hs;
    debug_assert!(
        batch_len == 0 || x.len() % batch_len == 0,
        "x.len() must be a multiple of n_tokens * num_heads * head_dim"
    );

    let mut out = x.to_vec();
    // Nothing to rotate. This also keeps `% spatial` and `chunks_exact_mut` away from zero.
    if batch_len == 0 || spatial == 0 || q4 == 0 {
        return out;
    }

    // Per-axis inverse frequencies, shared by the x and y halves.
    let freqs: Vec<f32> = (0..q4)
        .map(|c| 1.0 / theta.powf((4 * c) as f32 / head_dim as f32))
        .collect();

    // Row-major `[spatial, q4]` tables of `(cos, sin)` for the x and y angles of every cell.
    let mut rot_x = Vec::with_capacity(spatial * q4);
    let mut rot_y = Vec::with_capacity(spatial * q4);
    for pos in 0..spatial {
        let tx = (pos % end_x) as f32;
        let ty = (pos / end_x) as f32;
        for &f in &freqs {
            let (ax, ay) = (tx * f, ty * f);
            rot_x.push((ax.cos(), ax.sin()));
            rot_y.push((ay.cos(), ay.sin()));
        }
    }

    for item in out.chunks_exact_mut(batch_len) {
        for (tok, token) in item.chunks_exact_mut(hs).enumerate() {
            // The grid is tiled across repeats (SAM2 `repeat_freqs_k`), so every copy of the
            // grid reuses the same angles.
            let row = (tok % spatial) * q4;
            let (rx, ry) = (&rot_x[row..row + q4], &rot_y[row..row + q4]);
            for head in token.chunks_exact_mut(head_dim) {
                let (x_half, y_half) = head.split_at_mut(half);
                rotate_pairs(&mut x_half[..2 * q4], rx);
                rotate_pairs(&mut y_half[..2 * q4], ry);
            }
        }
    }
    out
}

/// Rotate the consecutive `(v[2c], v[2c + 1])` pairs of `v` in place.
///
/// Each pair is treated as a complex number and multiplied by `cos + i*sin` taken from
/// `rot[c]`. Processing stops at whichever of `v`'s pairs or `rot` runs out first.
fn rotate_pairs(v: &mut [f32], rot: &[(f32, f32)]) {
    for (pair, &(co, si)) in v.chunks_exact_mut(2).zip(rot) {
        let (x0, x1) = (pair[0], pair[1]);
        pair[0] = x0 * co - x1 * si;
        pair[1] = x0 * si + x1 * co;
    }
}
90
91impl Graph {
92    /// `x`: `[1, seq, num_heads * head_dim]` → same shape.
93    pub fn axial_rope2d(
94        &mut self,
95        x: NodeId,
96        end_x: usize,
97        end_y: usize,
98        head_dim: usize,
99        num_heads: usize,
100        theta: f32,
101        repeat_factor: usize,
102    ) -> NodeId {
103        let s = crate::shape::unary_shape(self.shape(x));
104        self.push(
105            Op::AxialRope2d {
106                end_x,
107                end_y,
108                head_dim,
109                num_heads,
110                theta,
111                repeat_factor,
112            },
113            vec![x],
114            s,
115            None,
116        )
117    }
118}