rlx_ir/ops/elementwise.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Element-wise builders: binary ops, activations, and INT8 quantize/dequantize (plan #53).
17
18use crate::op::{Activation, BinaryOp};
19use crate::{Graph, NodeId, Op, Shape};
20
21impl Graph {
22 /// Binary element-wise operation.
23 pub fn binary(&mut self, op: BinaryOp, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId {
24 self.push(Op::Binary(op), vec![lhs, rhs], out_shape, None)
25 }
26
27 /// Unary activation.
28 pub fn activation(&mut self, act: Activation, input: NodeId, shape: Shape) -> NodeId {
29 self.push(Op::Activation(act), vec![input], shape, None)
30 }
31
32 /// Per-tensor INT8 quantization. Output dtype = `I8`, same shape
33 /// otherwise. `scale` and `zero_point` apply uniformly to every
34 /// element. Use `quantize_per_channel` when weights deserve
35 /// per-channel scales (the standard PTQ improvement).
36 pub fn quantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId {
37 let shape = self.shape(x).clone().with_dtype(crate::DType::I8);
38 self.push(
39 Op::Quantize {
40 axis: None,
41 scales: vec![scale],
42 zero_points: vec![zero_point],
43 },
44 vec![x],
45 shape,
46 None,
47 )
48 }
49
50 /// Per-channel INT8 quantization. `scales` and `zero_points` must
51 /// each have length `input.dim(axis)`; the kernel picks the i-th
52 /// pair when quantizing the i-th slice along `axis`. The most
53 /// common usage is `axis = 0` for a `[C_out, C_in, kH, kW]`
54 /// conv weight (one scale per output channel).
55 pub fn quantize_per_channel(
56 &mut self,
57 x: NodeId,
58 axis: usize,
59 scales: Vec<f32>,
60 zero_points: Vec<i32>,
61 ) -> NodeId {
62 debug_assert_eq!(scales.len(), zero_points.len());
63 let shape = self.shape(x).clone().with_dtype(crate::DType::I8);
64 debug_assert_eq!(
65 shape.dim(axis),
66 crate::shape::Dim::Static(scales.len()),
67 "quantize_per_channel: scales.len() must match input.dim(axis)"
68 );
69 self.push(
70 Op::Quantize {
71 axis: Some(axis),
72 scales,
73 zero_points,
74 },
75 vec![x],
76 shape,
77 None,
78 )
79 }
80
81 /// Per-tensor INT8 dequantization (inverse of `quantize`). Output
82 /// dtype is f32.
83 pub fn dequantize(&mut self, x: NodeId, scale: f32, zero_point: i32) -> NodeId {
84 let shape = self.shape(x).clone().with_dtype(crate::DType::F32);
85 self.push(
86 Op::Dequantize {
87 axis: None,
88 scales: vec![scale],
89 zero_points: vec![zero_point],
90 },
91 vec![x],
92 shape,
93 None,
94 )
95 }
96
97 /// Per-channel INT8 dequantization (inverse of
98 /// `quantize_per_channel`).
99 pub fn dequantize_per_channel(
100 &mut self,
101 x: NodeId,
102 axis: usize,
103 scales: Vec<f32>,
104 zero_points: Vec<i32>,
105 ) -> NodeId {
106 debug_assert_eq!(scales.len(), zero_points.len());
107 let shape = self.shape(x).clone().with_dtype(crate::DType::F32);
108 debug_assert_eq!(
109 shape.dim(axis),
110 crate::shape::Dim::Static(scales.len()),
111 "dequantize_per_channel: scales.len() must match input.dim(axis)"
112 );
113 self.push(
114 Op::Dequantize {
115 axis: Some(axis),
116 scales,
117 zero_points,
118 },
119 vec![x],
120 shape,
121 None,
122 )
123 }
124}