rlx_ir/ops/reduction.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Reduction builders: reduce, softmax, cumsum, sample
17//! (plan #53).
18
19use crate::op::ReduceOp;
20use crate::{Graph, NodeId, Op, Shape};
21
22impl Graph {
23 /// Reduce.
24 pub fn reduce(
25 &mut self,
26 input: NodeId,
27 op: ReduceOp,
28 axes: Vec<usize>,
29 keep_dim: bool,
30 shape: Shape,
31 ) -> NodeId {
32 self.push(Op::Reduce { op, axes, keep_dim }, vec![input], shape, None)
33 }
34
35 /// Softmax.
36 pub fn softmax(&mut self, input: NodeId, axis: i32, shape: Shape) -> NodeId {
37 self.push(Op::Softmax { axis }, vec![input], shape, None)
38 }
39
40 /// Cumulative sum along an axis (output shape == input shape).
41 pub fn cumsum(&mut self, input: NodeId, axis: i32, exclusive: bool, shape: Shape) -> NodeId {
42 self.push(Op::Cumsum { axis, exclusive }, vec![input], shape, None)
43 }
44
45 /// Fused sample: logits → token id (one f32-encoded id per row).
46 /// `output_shape` should be `[batch]` (one id per logit row).
47 pub fn sample(
48 &mut self,
49 logits: NodeId,
50 top_k: usize,
51 top_p: f32,
52 temperature: f32,
53 seed: u64,
54 output_shape: Shape,
55 ) -> NodeId {
56 self.push(
57 Op::Sample {
58 top_k,
59 top_p,
60 temperature,
61 seed,
62 },
63 vec![logits],
64 output_shape,
65 None,
66 )
67 }
68}