rlx_ir/ops/reduction.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Reduction builders: reduce, softmax, cumsum, argmax, argmin, sample
//! (plan #53).
18
19use crate::op::ReduceOp;
20use crate::{Graph, NodeId, Op, Shape};
21
22impl Graph {
23 /// Reduce.
24 pub fn reduce(
25 &mut self,
26 input: NodeId,
27 op: ReduceOp,
28 axes: Vec<usize>,
29 keep_dim: bool,
30 shape: Shape,
31 ) -> NodeId {
32 self.push(Op::Reduce { op, axes, keep_dim }, vec![input], shape, None)
33 }
34
35 /// Softmax.
36 pub fn softmax(&mut self, input: NodeId, axis: i32, shape: Shape) -> NodeId {
37 self.push(Op::Softmax { axis }, vec![input], shape, None)
38 }
39
40 /// Cumulative sum along an axis (output shape == input shape).
41 pub fn cumsum(&mut self, input: NodeId, axis: i32, exclusive: bool, shape: Shape) -> NodeId {
42 self.push(Op::Cumsum { axis, exclusive }, vec![input], shape, None)
43 }
44
45 /// Index of the max along `axis` (f32-encoded indices).
46 pub fn argmax(&mut self, input: NodeId, axis: usize, keep_dim: bool, shape: Shape) -> NodeId {
47 self.push(Op::ArgMax { axis, keep_dim }, vec![input], shape, None)
48 }
49
50 /// Index of the min along `axis` (f32-encoded indices).
51 pub fn argmin(&mut self, input: NodeId, axis: usize, keep_dim: bool, shape: Shape) -> NodeId {
52 self.push(Op::ArgMin { axis, keep_dim }, vec![input], shape, None)
53 }
54
55 /// Fused sample: logits → token id (one f32-encoded id per row).
56 /// `output_shape` should be `[batch]` (one id per logit row).
57 pub fn sample(
58 &mut self,
59 logits: NodeId,
60 top_k: usize,
61 top_p: f32,
62 temperature: f32,
63 seed: u64,
64 output_shape: Shape,
65 ) -> NodeId {
66 self.push(
67 Op::Sample {
68 top_k,
69 top_p,
70 temperature,
71 seed,
72 },
73 vec![logits],
74 output_shape,
75 None,
76 )
77 }
78}