rlx_ir/ops/attention.rs
2// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Attention builders: scaled dot-product attention (SDPA) with a custom
17//! mask tensor, an additive bias tensor, or a kernel-synthesized mask (plan #53).
18
19use crate::op::MaskKind;
20use crate::{Graph, NodeId, Op, Shape};
21
22impl Graph {
23 /// Scaled dot-product attention with a custom (caller-supplied) mask.
24 /// Equivalent to `attention_kind(.., MaskKind::Custom, ..)`.
25 pub fn attention(
26 &mut self,
27 q: NodeId,
28 k: NodeId,
29 v: NodeId,
30 mask: NodeId,
31 num_heads: usize,
32 head_dim: usize,
33 shape: Shape,
34 ) -> NodeId {
35 self.push(
36 Op::Attention {
37 num_heads,
38 head_dim,
39 mask_kind: MaskKind::Custom,
40 },
41 vec![q, k, v, mask],
42 shape,
43 None,
44 )
45 }
46
47 /// Scaled dot-product attention with a kernel-synthesized mask
48 /// (`None` / `Causal` / `SlidingWindow`). Inputs are Q, K, V only โ
49 /// no mask tensor is allocated or read in the inner loop. Use
50 /// `MaskKind::None` for a single un-padded sequence.
51 pub fn attention_kind(
52 &mut self,
53 q: NodeId,
54 k: NodeId,
55 v: NodeId,
56 num_heads: usize,
57 head_dim: usize,
58 mask_kind: MaskKind,
59 shape: Shape,
60 ) -> NodeId {
61 debug_assert!(
62 !matches!(mask_kind, MaskKind::Custom | MaskKind::Bias),
63 "attention_kind() requires a non-tensor MaskKind; use attention() for Custom or attention_bias() for Bias"
64 );
65 self.push(
66 Op::Attention {
67 num_heads,
68 head_dim,
69 mask_kind,
70 },
71 vec![q, k, v],
72 shape,
73 None,
74 )
75 }
76
77 /// Scaled dot-product attention with an additive bias tensor of shape
78 /// `[batch, num_heads, query_len, key_len]` added to the
79 /// `QK^T ยท scale` scores before softmax. Lets boxRPB / per-query
80 /// position biases reuse the fast `Op::Attention` kernel path.
81 pub fn attention_bias(
82 &mut self,
83 q: NodeId,
84 k: NodeId,
85 v: NodeId,
86 bias: NodeId,
87 num_heads: usize,
88 head_dim: usize,
89 shape: Shape,
90 ) -> NodeId {
91 self.push(
92 Op::Attention {
93 num_heads,
94 head_dim,
95 mask_kind: MaskKind::Bias,
96 },
97 vec![q, k, v, bias],
98 shape,
99 None,
100 )
101 }
102}