Skip to main content

rlx_ir/ops/
reduction.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Reduction builders: reduce, softmax, cumsum, sample
17//! (plan #53).
18
19use crate::op::ReduceOp;
20use crate::{Graph, NodeId, Op, Shape};
21
22impl Graph {
23    /// Reduce.
24    pub fn reduce(
25        &mut self,
26        input: NodeId,
27        op: ReduceOp,
28        axes: Vec<usize>,
29        keep_dim: bool,
30        shape: Shape,
31    ) -> NodeId {
32        self.push(Op::Reduce { op, axes, keep_dim }, vec![input], shape, None)
33    }
34
35    /// Softmax.
36    pub fn softmax(&mut self, input: NodeId, axis: i32, shape: Shape) -> NodeId {
37        self.push(Op::Softmax { axis }, vec![input], shape, None)
38    }
39
40    /// Cumulative sum along an axis (output shape == input shape).
41    pub fn cumsum(&mut self, input: NodeId, axis: i32, exclusive: bool, shape: Shape) -> NodeId {
42        self.push(Op::Cumsum { axis, exclusive }, vec![input], shape, None)
43    }
44
45    /// Fused sample: logits → token id (one f32-encoded id per row).
46    /// `output_shape` should be `[batch]` (one id per logit row).
47    pub fn sample(
48        &mut self,
49        logits: NodeId,
50        top_k: usize,
51        top_p: f32,
52        temperature: f32,
53        seed: u64,
54        output_shape: Shape,
55    ) -> NodeId {
56        self.push(
57            Op::Sample {
58                top_k,
59                top_p,
60                temperature,
61                seed,
62            },
63            vec![logits],
64            output_shape,
65            None,
66        )
67    }
68}