Skip to main content

rlx_ir/ops/
attention.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Attention builders: scaled dot-product attention (SDPA) with either a
//! caller-supplied mask/bias tensor or a kernel-synthesized mask (plan #53).

19use crate::op::MaskKind;
20use crate::{Graph, NodeId, Op, Shape};
21
22impl Graph {
23    /// Scaled dot-product attention with a custom (caller-supplied) mask.
24    /// Equivalent to `attention_kind(.., MaskKind::Custom, ..)`.
25    pub fn attention(
26        &mut self,
27        q: NodeId,
28        k: NodeId,
29        v: NodeId,
30        mask: NodeId,
31        num_heads: usize,
32        head_dim: usize,
33        shape: Shape,
34    ) -> NodeId {
35        self.push(
36            Op::Attention {
37                num_heads,
38                head_dim,
39                mask_kind: MaskKind::Custom,
40            },
41            vec![q, k, v, mask],
42            shape,
43            None,
44        )
45    }
46
47    /// Scaled dot-product attention with a kernel-synthesized mask
48    /// (`None` / `Causal` / `SlidingWindow`). Inputs are Q, K, V only โ€”
49    /// no mask tensor is allocated or read in the inner loop. Use
50    /// `MaskKind::None` for a single un-padded sequence.
51    pub fn attention_kind(
52        &mut self,
53        q: NodeId,
54        k: NodeId,
55        v: NodeId,
56        num_heads: usize,
57        head_dim: usize,
58        mask_kind: MaskKind,
59        shape: Shape,
60    ) -> NodeId {
61        debug_assert!(
62            !matches!(mask_kind, MaskKind::Custom | MaskKind::Bias),
63            "attention_kind() requires a non-tensor MaskKind; use attention() for Custom or attention_bias() for Bias"
64        );
65        self.push(
66            Op::Attention {
67                num_heads,
68                head_dim,
69                mask_kind,
70            },
71            vec![q, k, v],
72            shape,
73            None,
74        )
75    }
76
77    /// Scaled dot-product attention with an additive bias tensor of shape
78    /// `[batch, num_heads, query_len, key_len]` added to the
79    /// `QK^T ยท scale` scores before softmax. Lets boxRPB / per-query
80    /// position biases reuse the fast `Op::Attention` kernel path.
81    pub fn attention_bias(
82        &mut self,
83        q: NodeId,
84        k: NodeId,
85        v: NodeId,
86        bias: NodeId,
87        num_heads: usize,
88        head_dim: usize,
89        shape: Shape,
90    ) -> NodeId {
91        self.push(
92            Op::Attention {
93                num_heads,
94                head_dim,
95                mask_kind: MaskKind::Bias,
96            },
97            vec![q, k, v, bias],
98            shape,
99            None,
100        )
101    }
102}