rlx_ir/ops/attention.rs
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Attention builders: scaled dot-product attention (SDPA) with a
//! caller-supplied mask, an additive bias tensor, or a kernel-synthesized
//! mask (plan #53).
18
19use crate::op::MaskKind;
20use crate::{Graph, NodeId, Op, Shape};
21
22/// Build an [`Op::Attention`] with optional score scale and logit softcap.
23pub fn attention_kind_op(
24 num_heads: usize,
25 head_dim: usize,
26 mask_kind: MaskKind,
27 score_scale: Option<f32>,
28 attn_logit_softcap: Option<f32>,
29) -> Op {
30 Op::Attention {
31 num_heads,
32 head_dim,
33 mask_kind,
34 score_scale,
35 attn_logit_softcap,
36 }
37}
38
39impl Graph {
40 /// Scaled dot-product attention with a custom (caller-supplied) mask.
41 /// Equivalent to `attention_kind(.., MaskKind::Custom, ..)`.
42 pub fn attention(
43 &mut self,
44 q: NodeId,
45 k: NodeId,
46 v: NodeId,
47 mask: NodeId,
48 num_heads: usize,
49 head_dim: usize,
50 shape: Shape,
51 ) -> NodeId {
52 self.attention_opts(q, k, v, mask, num_heads, head_dim, shape, None, None)
53 }
54
55 /// Like [`Self::attention`] with optional score scale and logit softcap.
56 pub fn attention_opts(
57 &mut self,
58 q: NodeId,
59 k: NodeId,
60 v: NodeId,
61 mask: NodeId,
62 num_heads: usize,
63 head_dim: usize,
64 shape: Shape,
65 score_scale: Option<f32>,
66 attn_logit_softcap: Option<f32>,
67 ) -> NodeId {
68 self.push(
69 attention_kind_op(
70 num_heads,
71 head_dim,
72 MaskKind::Custom,
73 score_scale,
74 attn_logit_softcap,
75 ),
76 vec![q, k, v, mask],
77 shape,
78 None,
79 )
80 }
81
82 /// Scaled dot-product attention with a kernel-synthesized mask
83 /// (`None` / `Causal` / `SlidingWindow`). Inputs are Q, K, V only โ
84 /// no mask tensor is allocated or read in the inner loop. Use
85 /// `MaskKind::None` for a single un-padded sequence.
86 pub fn attention_kind(
87 &mut self,
88 q: NodeId,
89 k: NodeId,
90 v: NodeId,
91 num_heads: usize,
92 head_dim: usize,
93 mask_kind: MaskKind,
94 shape: Shape,
95 ) -> NodeId {
96 self.attention_kind_opts(q, k, v, num_heads, head_dim, mask_kind, shape, None, None)
97 }
98
99 /// Like [`Self::attention_kind`] with optional score scale and logit softcap.
100 pub fn attention_kind_opts(
101 &mut self,
102 q: NodeId,
103 k: NodeId,
104 v: NodeId,
105 num_heads: usize,
106 head_dim: usize,
107 mask_kind: MaskKind,
108 shape: Shape,
109 score_scale: Option<f32>,
110 attn_logit_softcap: Option<f32>,
111 ) -> NodeId {
112 debug_assert!(
113 !matches!(mask_kind, MaskKind::Custom | MaskKind::Bias),
114 "attention_kind() requires a non-tensor MaskKind; use attention() for Custom or attention_bias() for Bias"
115 );
116 self.push(
117 attention_kind_op(
118 num_heads,
119 head_dim,
120 mask_kind,
121 score_scale,
122 attn_logit_softcap,
123 ),
124 vec![q, k, v],
125 shape,
126 None,
127 )
128 }
129
130 /// Scaled dot-product attention with an additive bias tensor of shape
131 /// `[batch, num_heads, query_len, key_len]` added to the
132 /// `QK^T ยท scale` scores before softmax. Lets boxRPB / per-query
133 /// position biases reuse the fast `Op::Attention` kernel path.
134 pub fn attention_bias(
135 &mut self,
136 q: NodeId,
137 k: NodeId,
138 v: NodeId,
139 bias: NodeId,
140 num_heads: usize,
141 head_dim: usize,
142 shape: Shape,
143 ) -> NodeId {
144 self.push(
145 attention_kind_op(num_heads, head_dim, MaskKind::Bias, None, None),
146 vec![q, k, v, bias],
147 shape,
148 None,
149 )
150 }
151}