Skip to main content

rlx_ir/ops/
elementwise.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Element-wise builders: binary ops, activations, and per-tensor /
//! per-channel INT8 quantize/dequantize (plan #53).
17
18use crate::op::{Activation, BinaryOp};
19use crate::{Graph, NodeId, Op, Shape};
20
21impl Graph {
22    /// Binary element-wise operation.
23    pub fn binary(&mut self, op: BinaryOp, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId {
24        self.push(Op::Binary(op), vec![lhs, rhs], out_shape, None)
25    }
26
27    /// Unary activation.
28    pub fn activation(&mut self, act: Activation, input: NodeId, shape: Shape) -> NodeId {
29        self.push(Op::Activation(act), vec![input], shape, None)
30    }
31
32    /// Per-tensor INT8 quantization. Output dtype = `I8`, same shape
33    /// otherwise. `scale` and `zero_point` apply uniformly to every
34    /// element. Use `quantize_per_channel` when weights deserve
35    /// per-channel scales (the standard PTQ improvement).
36    pub fn quantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId {
37        let shape = self.shape(x).clone().with_dtype(crate::DType::I8);
38        self.push(
39            Op::Quantize {
40                axis: None,
41                scales: vec![scale],
42                zero_points: vec![zero_point],
43            },
44            vec![x],
45            shape,
46            None,
47        )
48    }
49
50    /// Per-channel INT8 quantization. `scales` and `zero_points` must
51    /// each have length `input.dim(axis)`; the kernel picks the i-th
52    /// pair when quantizing the i-th slice along `axis`. The most
53    /// common usage is `axis = 0` for a `[C_out, C_in, kH, kW]`
54    /// conv weight (one scale per output channel).
55    pub fn quantize_per_channel(
56        &mut self,
57        x: NodeId,
58        axis: usize,
59        scales: Vec<f32>,
60        zero_points: Vec<i32>,
61    ) -> NodeId {
62        debug_assert_eq!(scales.len(), zero_points.len());
63        let shape = self.shape(x).clone().with_dtype(crate::DType::I8);
64        debug_assert_eq!(
65            shape.dim(axis),
66            crate::shape::Dim::Static(scales.len()),
67            "quantize_per_channel: scales.len() must match input.dim(axis)"
68        );
69        self.push(
70            Op::Quantize {
71                axis: Some(axis),
72                scales,
73                zero_points,
74            },
75            vec![x],
76            shape,
77            None,
78        )
79    }
80
81    /// Per-tensor INT8 dequantization (inverse of `quantize`). Output
82    /// dtype is f32.
83    pub fn dequantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId {
84        let shape = self.shape(x).clone().with_dtype(crate::DType::F32);
85        self.push(
86            Op::Dequantize {
87                axis: None,
88                scales: vec![scale],
89                zero_points: vec![zero_point],
90            },
91            vec![x],
92            shape,
93            None,
94        )
95    }
96
97    /// Per-channel INT8 dequantization (inverse of
98    /// `quantize_per_channel`).
99    pub fn dequantize_per_channel(
100        &mut self,
101        x: NodeId,
102        axis: usize,
103        scales: Vec<f32>,
104        zero_points: Vec<i32>,
105    ) -> NodeId {
106        debug_assert_eq!(scales.len(), zero_points.len());
107        let shape = self.shape(x).clone().with_dtype(crate::DType::F32);
108        debug_assert_eq!(
109            shape.dim(axis),
110            crate::shape::Dim::Static(scales.len()),
111            "dequantize_per_channel: scales.len() must match input.dim(axis)"
112        );
113        self.push(
114            Op::Dequantize {
115                axis: Some(axis),
116                scales,
117                zero_points,
118            },
119            vec![x],
120            shape,
121            None,
122        )
123    }
124}