Skip to main content

rlx_ir/ops/
io.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Graph I/O builders: inputs, parameters (plan #53), plus custom-op and FFT node builders.
17
18use crate::{Graph, NodeId, Op, Shape};
19
20impl Graph {
21    /// Graph input (runtime-provided tensor).
22    pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> NodeId {
23        let name: String = name.into();
24        self.push(Op::Input { name: name.clone() }, vec![], shape, Some(name))
25    }
26
27    /// Model parameter (weight loaded at init).
28    pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> NodeId {
29        let name: String = name.into();
30        self.push(Op::Param { name: name.clone() }, vec![], shape, Some(name))
31    }
32
33    /// Generic node constructor for custom ops.
34    pub fn add_node(&mut self, op: Op, inputs: Vec<NodeId>, shape: Shape) -> NodeId {
35        self.push(op, inputs, shape, None)
36    }
37
38    /// Build an `Op::Custom` node, dispatching shape inference through
39    /// the global op registry. The named op must already be registered
40    /// via [`crate::register_op`]; `attrs` is forwarded verbatim to
41    /// the impl's `infer_shape` (and later, at execution time, to its
42    /// per-backend kernel).
43    ///
44    /// Panics if `name` is not registered or if `inputs.len()` does
45    /// not match the registered `num_inputs()` — both are programmer
46    /// errors that should fail loudly at graph-build time, not silently
47    /// at execution.
48    pub fn custom_op(
49        &mut self,
50        name: impl Into<String>,
51        attrs: Vec<u8>,
52        inputs: Vec<NodeId>,
53    ) -> NodeId {
54        let name: String = name.into();
55        let ext = crate::lookup_op(&name)
56            .unwrap_or_else(|| panic!("custom_op: '{name}' is not registered in the op registry"));
57        assert_eq!(
58            ext.num_inputs(),
59            inputs.len(),
60            "custom_op '{name}': registered op expects {} inputs, got {}",
61            ext.num_inputs(),
62            inputs.len(),
63        );
64        let in_shapes: Vec<&Shape> = inputs.iter().map(|id| self.shape(*id)).collect();
65        let out_shape = ext.infer_shape(&in_shapes, &attrs);
66        let num_inputs = ext.num_inputs() as u32;
67        self.push(
68            Op::Custom {
69                name,
70                num_inputs,
71                attrs,
72            },
73            inputs,
74            out_shape,
75            None,
76        )
77    }
78
79    /// Build an `Op::Custom` node with a caller-supplied output shape,
80    /// **bypassing** the registry's `infer_shape`. Use this for ops
81    /// whose output shape can't be determined by static input shapes
82    /// alone — most importantly, ops with multiple logical outputs
83    /// packed into one buffer.
84    ///
85    /// The canonical multi-output pattern:
86    ///
87    /// ```ignore
88    /// // Sparse-LU returns L_values + U_values packed end-to-end.
89    /// // Caller knows nnz_L and nnz_U from the symbolic factor.
90    /// let lu = g.custom_op_packed(
91    ///     "sparse_lu",
92    ///     attrs,
93    ///     vec![A, b],
94    ///     Shape::new(&[nnz_L + nnz_U], DType::F64),
95    /// );
96    /// let l_vals = g.narrow_(lu, 0, 0, nnz_L);
97    /// let u_vals = g.narrow_(lu, 0, nnz_L, nnz_U);
98    /// ```
99    ///
100    /// The op must still be registered (so `num_inputs` validation
101    /// and autodiff routing still work); only the shape is overridden.
102    pub fn custom_op_packed(
103        &mut self,
104        name: impl Into<String>,
105        attrs: Vec<u8>,
106        inputs: Vec<NodeId>,
107        out_shape: Shape,
108    ) -> NodeId {
109        let name: String = name.into();
110        let ext = crate::lookup_op(&name).unwrap_or_else(|| {
111            panic!("custom_op_packed: '{name}' is not registered in the op registry")
112        });
113        assert_eq!(
114            ext.num_inputs(),
115            inputs.len(),
116            "custom_op_packed '{name}': registered op expects {} inputs, got {}",
117            ext.num_inputs(),
118            inputs.len(),
119        );
120        let num_inputs = ext.num_inputs() as u32;
121        self.push(
122            Op::Custom {
123                name,
124                num_inputs,
125                attrs,
126            },
127            inputs,
128            out_shape,
129            None,
130        )
131    }
132
133    /// 1D FFT along the last axis.
134    ///
135    /// * **F32 / F64** — 2N real-block layout: last axis is `[re…, im…]`.
136    /// * **C64** — interleaved `[re, im]` pairs per complex element.
137    ///
138    /// Output shape matches input. Radix-2 when `N` is a power of two,
139    /// Bluestein otherwise. Default normalization is unnormalized
140    /// (`FftNorm::Backward`; `ifft(fft(x)) = N·x`).
141    pub fn fft(&mut self, x: NodeId, inverse: bool) -> NodeId {
142        self.fft_norm(x, inverse, crate::fft::FftNorm::Backward)
143    }
144
145    /// 1D FFT with explicit normalization mode.
146    pub fn fft_norm(&mut self, x: NodeId, inverse: bool, norm: crate::fft::FftNorm) -> NodeId {
147        let s = self.shape(x).clone();
148        crate::fft::fft_meta(&s);
149        self.push(Op::Fft { inverse, norm }, vec![x], s, None)
150    }
151
152    /// 1D FFT along an arbitrary axis. Lowers to
153    /// `Transpose(axis ↔ last) → Fft(last) → Transpose(last ↔ axis)`.
154    ///
155    /// AD is free: both `Op::Transpose` and `Op::Fft` have VJP/JVP rules.
156    pub fn fft_axis(&mut self, x: NodeId, axis: usize, inverse: bool) -> NodeId {
157        use crate::infer::GraphExt as _;
158        let rank = self.shape(x).rank();
159        assert!(
160            axis < rank,
161            "fft_axis: axis {axis} out of range for rank-{rank} tensor"
162        );
163        let last = rank - 1;
164        if axis == last {
165            return self.fft(x, inverse);
166        }
167        let mut perm: Vec<usize> = (0..rank).collect();
168        perm.swap(axis, last);
169
170        let x_t = self.transpose_(x, perm.clone());
171        let y_t = self.fft(x_t, inverse);
172        self.transpose_(y_t, perm)
173    }
174
175    /// N-dimensional FFT along `axes` (NumPy `fftn` semantics).
176    ///
177    /// Applies a 1D FFT along each listed axis in ascending order.
178    /// Empty `axes` is a no-op. For multi-axis transforms on tensors
179    /// with more than one spatial dimension, use `DType::C64`; the
180    /// F32/F64 2N-block layout only describes a single complex axis.
181    pub fn fftn(&mut self, x: NodeId, axes: &[usize], inverse: bool) -> NodeId {
182        let rank = self.shape(x).rank();
183        let axes = crate::fft::normalize_fftn_axes(rank, axes);
184        if axes.is_empty() {
185            return x;
186        }
187        if axes.len() > 1 && !self.shape(x).dtype().is_complex() {
188            panic!(
189                "fftn: multi-axis FFT on {:?} requires DType::C64; \
190                 the F32/F64 2N real-block layout supports only one complex axis — \
191                 call fft_axis for a single transform",
192                self.shape(x).dtype()
193            );
194        }
195        let mut y = x;
196        for axis in axes {
197            y = self.fft_axis(y, axis, inverse);
198        }
199        y
200    }
201
202    /// Inverse N-dimensional FFT — alias for `fftn(..., inverse: true)`.
203    pub fn ifftn(&mut self, x: NodeId, axes: &[usize]) -> NodeId {
204        self.fftn(x, axes, true)
205    }
206}