Skip to main content

rlx_ir/ops/
attention.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Attention builders: scaled dot-product attention (SDPA) with a custom
//! mask tensor, an additive bias tensor, or a kernel-synthesized mask
//! (plan #53).
18
19use crate::op::MaskKind;
20use crate::{Graph, NodeId, Op, Shape};
21
22/// Build an [`Op::Attention`] with optional score scale and logit softcap.
23pub fn attention_kind_op(
24    num_heads: usize,
25    head_dim: usize,
26    mask_kind: MaskKind,
27    score_scale: Option<f32>,
28    attn_logit_softcap: Option<f32>,
29) -> Op {
30    Op::Attention {
31        num_heads,
32        head_dim,
33        mask_kind,
34        score_scale,
35        attn_logit_softcap,
36    }
37}
38
39impl Graph {
40    /// Scaled dot-product attention with a custom (caller-supplied) mask.
41    /// Equivalent to `attention_kind(.., MaskKind::Custom, ..)`.
42    pub fn attention(
43        &mut self,
44        q: NodeId,
45        k: NodeId,
46        v: NodeId,
47        mask: NodeId,
48        num_heads: usize,
49        head_dim: usize,
50        shape: Shape,
51    ) -> NodeId {
52        self.attention_opts(q, k, v, mask, num_heads, head_dim, shape, None, None)
53    }
54
55    /// Like [`Self::attention`] with optional score scale and logit softcap.
56    pub fn attention_opts(
57        &mut self,
58        q: NodeId,
59        k: NodeId,
60        v: NodeId,
61        mask: NodeId,
62        num_heads: usize,
63        head_dim: usize,
64        shape: Shape,
65        score_scale: Option<f32>,
66        attn_logit_softcap: Option<f32>,
67    ) -> NodeId {
68        self.push(
69            attention_kind_op(
70                num_heads,
71                head_dim,
72                MaskKind::Custom,
73                score_scale,
74                attn_logit_softcap,
75            ),
76            vec![q, k, v, mask],
77            shape,
78            None,
79        )
80    }
81
82    /// Scaled dot-product attention with a kernel-synthesized mask
83    /// (`None` / `Causal` / `SlidingWindow`). Inputs are Q, K, V only โ€”
84    /// no mask tensor is allocated or read in the inner loop. Use
85    /// `MaskKind::None` for a single un-padded sequence.
86    pub fn attention_kind(
87        &mut self,
88        q: NodeId,
89        k: NodeId,
90        v: NodeId,
91        num_heads: usize,
92        head_dim: usize,
93        mask_kind: MaskKind,
94        shape: Shape,
95    ) -> NodeId {
96        self.attention_kind_opts(q, k, v, num_heads, head_dim, mask_kind, shape, None, None)
97    }
98
99    /// Like [`Self::attention_kind`] with optional score scale and logit softcap.
100    pub fn attention_kind_opts(
101        &mut self,
102        q: NodeId,
103        k: NodeId,
104        v: NodeId,
105        num_heads: usize,
106        head_dim: usize,
107        mask_kind: MaskKind,
108        shape: Shape,
109        score_scale: Option<f32>,
110        attn_logit_softcap: Option<f32>,
111    ) -> NodeId {
112        debug_assert!(
113            !matches!(mask_kind, MaskKind::Custom | MaskKind::Bias),
114            "attention_kind() requires a non-tensor MaskKind; use attention() for Custom or attention_bias() for Bias"
115        );
116        self.push(
117            attention_kind_op(
118                num_heads,
119                head_dim,
120                mask_kind,
121                score_scale,
122                attn_logit_softcap,
123            ),
124            vec![q, k, v],
125            shape,
126            None,
127        )
128    }
129
130    /// Scaled dot-product attention with an additive bias tensor of shape
131    /// `[batch, num_heads, query_len, key_len]` added to the
132    /// `QK^T ยท scale` scores before softmax. Lets boxRPB / per-query
133    /// position biases reuse the fast `Op::Attention` kernel path.
134    pub fn attention_bias(
135        &mut self,
136        q: NodeId,
137        k: NodeId,
138        v: NodeId,
139        bias: NodeId,
140        num_heads: usize,
141        head_dim: usize,
142        shape: Shape,
143    ) -> NodeId {
144        self.push(
145            attention_kind_op(num_heads, head_dim, MaskKind::Bias, None, None),
146            vec![q, k, v, bias],
147            shape,
148            None,
149        )
150    }
151}