Skip to main content

rlx_ir/ops/
reduction.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Reduction builders: reduce, softmax, cumsum, argmax, argmin, sample
//! (plan #53).
18
19use crate::op::ReduceOp;
20use crate::{Graph, NodeId, Op, Shape};
21
22impl Graph {
23    /// Reduce.
24    pub fn reduce(
25        &mut self,
26        input: NodeId,
27        op: ReduceOp,
28        axes: Vec<usize>,
29        keep_dim: bool,
30        shape: Shape,
31    ) -> NodeId {
32        self.push(Op::Reduce { op, axes, keep_dim }, vec![input], shape, None)
33    }
34
35    /// Softmax.
36    pub fn softmax(&mut self, input: NodeId, axis: i32, shape: Shape) -> NodeId {
37        self.push(Op::Softmax { axis }, vec![input], shape, None)
38    }
39
40    /// Cumulative sum along an axis (output shape == input shape).
41    pub fn cumsum(&mut self, input: NodeId, axis: i32, exclusive: bool, shape: Shape) -> NodeId {
42        self.push(Op::Cumsum { axis, exclusive }, vec![input], shape, None)
43    }
44
45    /// Index of the max along `axis` (f32-encoded indices).
46    pub fn argmax(&mut self, input: NodeId, axis: usize, keep_dim: bool, shape: Shape) -> NodeId {
47        self.push(Op::ArgMax { axis, keep_dim }, vec![input], shape, None)
48    }
49
50    /// Index of the min along `axis` (f32-encoded indices).
51    pub fn argmin(&mut self, input: NodeId, axis: usize, keep_dim: bool, shape: Shape) -> NodeId {
52        self.push(Op::ArgMin { axis, keep_dim }, vec![input], shape, None)
53    }
54
55    /// Fused sample: logits → token id (one f32-encoded id per row).
56    /// `output_shape` should be `[batch]` (one id per logit row).
57    pub fn sample(
58        &mut self,
59        logits: NodeId,
60        top_k: usize,
61        top_p: f32,
62        temperature: f32,
63        seed: u64,
64        output_shape: Shape,
65    ) -> NodeId {
66        self.push(
67            Op::Sample {
68                top_k,
69                top_p,
70                temperature,
71                seed,
72            },
73            vec![logits],
74            output_shape,
75            None,
76        )
77    }
78}