rlx_ir/ops/io.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Graph I/O builders: inputs, parameters (plan #53).
17
18use crate::{Graph, NodeId, Op, Shape};
19
20impl Graph {
21 /// Graph input (runtime-provided tensor).
22 pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> NodeId {
23 let name: String = name.into();
24 self.push(Op::Input { name: name.clone() }, vec![], shape, Some(name))
25 }
26
27 /// Model parameter (weight loaded at init).
28 pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> NodeId {
29 let name: String = name.into();
30 self.push(Op::Param { name: name.clone() }, vec![], shape, Some(name))
31 }
32
33 /// Generic node constructor for custom ops.
34 pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, shape: Shape) -> NodeId {
35 self.push(op, inputs, shape, None)
36 }
37
38 /// Build an `Op::Custom` node, dispatching shape inference through
39 /// the global op registry. The named op must already be registered
40 /// via [`crate::register_op`]; `attrs` is forwarded verbatim to
41 /// the impl's `infer_shape` (and later, at execution time, to its
42 /// per-backend kernel).
43 ///
44 /// Panics if `name` is not registered or if `inputs.len()` does
45 /// not match the registered `num_inputs()` — both are programmer
46 /// errors that should fail loudly at graph-build time, not silently
47 /// at execution.
48 pub fn custom_op(
49 &mut self,
50 name: impl Into<String>,
51 attrs: Vec<u8>,
52 inputs: Vec<NodeId>,
53 ) -> NodeId {
54 let name: String = name.into();
55 let ext = crate::lookup_op(&name)
56 .unwrap_or_else(|| panic!("custom_op: '{name}' is not registered in the op registry"));
57 assert_eq!(
58 ext.num_inputs(),
59 inputs.len(),
60 "custom_op '{name}': registered op expects {} inputs, got {}",
61 ext.num_inputs(),
62 inputs.len(),
63 );
64 let in_shapes: Vec<&Shape> = inputs.iter().map(|id| self.shape(*id)).collect();
65 let out_shape = ext.infer_shape(&in_shapes, &attrs);
66 let num_inputs = ext.num_inputs() as u32;
67 self.push(
68 Op::Custom {
69 name,
70 num_inputs,
71 attrs,
72 },
73 inputs,
74 out_shape,
75 None,
76 )
77 }
78
79 /// Build an `Op::Custom` node with a caller-supplied output shape,
80 /// **bypassing** the registry's `infer_shape`. Use this for ops
81 /// whose output shape can't be determined by static input shapes
82 /// alone — most importantly, ops with multiple logical outputs
83 /// packed into one buffer.
84 ///
85 /// The canonical multi-output pattern:
86 ///
87 /// ```ignore
88 /// // Sparse-LU returns L_values + U_values packed end-to-end.
89 /// // Caller knows nnz_L and nnz_U from the symbolic factor.
90 /// let lu = g.custom_op_packed(
91 /// "sparse_lu",
92 /// attrs,
93 /// vec![A, b],
94 /// Shape::new(&[nnz_L + nnz_U], DType::F64),
95 /// );
96 /// let l_vals = g.narrow_(lu, 0, 0, nnz_L);
97 /// let u_vals = g.narrow_(lu, 0, nnz_L, nnz_U);
98 /// ```
99 ///
100 /// The op must still be registered (so `num_inputs` validation
101 /// and autodiff routing still work); only the shape is overridden.
102 pub fn custom_op_packed(
103 &mut self,
104 name: impl Into<String>,
105 attrs: Vec<u8>,
106 inputs: Vec<NodeId>,
107 out_shape: Shape,
108 ) -> NodeId {
109 let name: String = name.into();
110 let ext = crate::lookup_op(&name).unwrap_or_else(|| {
111 panic!("custom_op_packed: '{name}' is not registered in the op registry")
112 });
113 assert_eq!(
114 ext.num_inputs(),
115 inputs.len(),
116 "custom_op_packed '{name}': registered op expects {} inputs, got {}",
117 ext.num_inputs(),
118 inputs.len(),
119 );
120 let num_inputs = ext.num_inputs() as u32;
121 self.push(
122 Op::Custom {
123 name,
124 num_inputs,
125 attrs,
126 },
127 inputs,
128 out_shape,
129 None,
130 )
131 }
132
133 /// 1D FFT along the last axis of the 2N real-block complex layout.
134 /// Last axis size must be even (so the 2N real-block layout
135 /// resolves to an integer number of complex points). Output shape
136 /// == input shape.
137 ///
138 /// The CPU kernel uses radix-2 Cooley-Tukey when the complex
139 /// length `N = last/2` is a power of two, and Bluestein's
140 /// algorithm (chirp z-transform) otherwise. There is no
141 /// size restriction beyond `last` being even.
142 ///
143 /// See `Op::Fft` for the normalization convention
144 /// (unnormalized; ifft(fft(x)) = N·x).
145 pub fn fft(&mut self, x: NodeId, inverse: bool) -> NodeId {
146 let s = self.shape(x).clone();
147 assert!(s.rank() >= 1, "fft: tensor must have at least 1 axis");
148 let last = s.rank() - 1;
149 match s.dim(last) {
150 crate::shape::Dim::Static(n) => {
151 assert!(
152 n % 2 == 0,
153 "fft: last axis size {n} must be even (2N real-block layout)"
154 );
155 }
156 _ => panic!("fft: dynamic last-axis size not supported"),
157 }
158 self.push(Op::Fft { inverse }, vec![x], s, None)
159 }
160
161 /// 1D FFT along an arbitrary axis (not just the last). Lowers to
162 /// `Transpose(axis ↔ last) → Fft(last) → Transpose(last ↔ axis)`
163 /// — the 2N-real-block convention is intrinsic to whichever axis
164 /// the FFT runs along, and `Op::Transpose` is a pure permutation,
165 /// so semantics transport correctly.
166 ///
167 /// Limitation: this still describes a tensor with a *single*
168 /// complex axis. True ND `fftn` (e.g. 2D FFT of a 2D-complex
169 /// array, where two axes are independently complex) cannot be
170 /// expressed in the 2N-real-block layout — it needs native
171 /// `DType::C64` to keep the real/imag split off the axis grid.
172 /// See PLAN.md for the deferred C64 workstream.
173 ///
174 /// AD is free: Transpose and Fft both have VJP and JVP rules,
175 /// so the composition differentiates automatically.
176 pub fn fft_axis(&mut self, x: NodeId, axis: usize, inverse: bool) -> NodeId {
177 use crate::infer::GraphExt as _;
178 let rank = self.shape(x).rank();
179 assert!(
180 axis < rank,
181 "fft_axis: axis {axis} out of range for rank-{rank} tensor"
182 );
183 let last = rank - 1;
184 if axis == last {
185 // Fast path — no transpose needed.
186 return self.fft(x, inverse);
187 }
188 // perm = identity with `axis` ↔ `last` swapped. Same perm in
189 // both directions because it's a transposition (its own inverse).
190 let mut perm: Vec<usize> = (0..rank).collect();
191 perm.swap(axis, last);
192
193 let x_t = self.transpose_(x, perm.clone());
194 let y_t = self.fft(x_t, inverse);
195 self.transpose_(y_t, perm)
196 }
197}