rlx_runtime/jacfwd.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Forward-mode Jacobian materialization.
17//!
18//! Pre-vmap convenience: given a compiled JVP graph (output of
19//! `rlx_opt::autodiff_fwd::jvp`), assemble the Jacobian by running
20//! the graph once per standard-basis unit vector and stacking the
21//! tangent outputs.
22//!
23//! Use this when the *input* dimension is small (Circulax component
24//! groups have handfuls of params) — once a `vmap` transformation
25//! lands the same shape gets vectorised into a single call.
26//!
27//! ## Layout convention
28//!
29//! For each primal output of shape `output_shape`, the Jacobian is
30//! returned as a flat byte buffer encoding a row-major
31//! `[output_size, wrt_size]` matrix where:
32//!
33//! * `output_size = product(output_shape)`
34//! * `wrt_size = product(wrt_shape)`
35//!
36//! Element `[i, j]` is `∂output[i] / ∂wrt[j]` after both shapes are
37//! flattened. Callers reshape to the natural
38//! `[output_shape..., wrt_shape...]` if convenient.
39//!
40//! ## Compose with [`crate::Session`]
41//!
42//! ```ignore
43//! use rlx_opt::autodiff_fwd::jvp;
44//! use rlx_runtime::{Session, Device, jacfwd};
45//!
46//! let jvp_graph = jvp(&forward, &[wrt_node]);
47//! let mut compiled = Session::new(Device::Cpu).compile(jvp_graph);
48//! let jacs = jacfwd(&mut compiled, &primals, "x", &[3], DType::F64);
49//! ```
50
51use crate::compiled::CompiledGraph;
52use rlx_ir::DType;
53
54/// One Jacobian per primal output of the original forward graph.
55#[derive(Debug, Clone)]
56pub struct JacobianBytes {
57 /// Flat row-major `[output_size, wrt_size]` matrix.
58 pub bytes: Vec<u8>,
59 /// Number of elements in the primal output (= rows).
60 pub output_size: usize,
61 /// Number of elements in the wrt input (= columns).
62 pub wrt_size: usize,
63 /// Element dtype of `bytes`.
64 pub dtype: DType,
65}
66
67impl JacobianBytes {
68 /// Reinterpret the byte buffer as `&[f64]` (row-major).
69 /// Panics if `dtype != F64` or the byte length isn't a multiple of 8.
70 pub fn as_f64(&self) -> &[f64] {
71 assert_eq!(
72 self.dtype,
73 DType::F64,
74 "as_f64: dtype is {:?}, not F64",
75 self.dtype
76 );
77 assert_eq!(
78 self.bytes.len(),
79 self.output_size * self.wrt_size * 8,
80 "as_f64: byte length doesn't match shape"
81 );
82 // SAFETY: bytes are 8-aligned (rlx-runtime allocates with at
83 // least 8-byte alignment) and the byte length is a multiple of 8.
84 unsafe {
85 std::slice::from_raw_parts(self.bytes.as_ptr() as *const f64, self.bytes.len() / 8)
86 }
87 }
88
89 /// Reinterpret as `&[f32]` (row-major). Mirror of `as_f64`.
90 pub fn as_f32(&self) -> &[f32] {
91 assert_eq!(
92 self.dtype,
93 DType::F32,
94 "as_f32: dtype is {:?}, not F32",
95 self.dtype
96 );
97 assert_eq!(self.bytes.len(), self.output_size * self.wrt_size * 4);
98 unsafe {
99 std::slice::from_raw_parts(self.bytes.as_ptr() as *const f32, self.bytes.len() / 4)
100 }
101 }
102}
103
104/// Materialize the Jacobian of every primal output w.r.t. the input
105/// named `wrt_name`. The compiled graph must be the result of
106/// `rlx_opt::autodiff_fwd::jvp(forward, &[wrt_node])` — it has a
107/// `tangent_<wrt_name>` Input that we drive with unit vectors, and
108/// outputs `[primals..., tangents...]`.
109///
110/// `primals` carries values for every non-tangent input (including
111/// any `Param` that was bound externally; if the model uses
112/// `set_param_typed` for params, call it before invoking `jacfwd` —
113/// params persist across the multiple runs).
114///
115/// One JVP run per element of `wrt_shape`; cost scales linearly with
116/// the wrt dimension. Use reverse-mode (`grad_with_loss`) when the
117/// output dimension is what's small instead.
118pub fn jacfwd(
119 compiled: &mut CompiledGraph,
120 primals: &[(&str, &[u8], DType)],
121 wrt_name: &str,
122 wrt_shape: &[usize],
123 dtype: DType,
124) -> Vec<JacobianBytes> {
125 let elem_size = dtype.size_bytes();
126 let wrt_size: usize = wrt_shape.iter().product();
127 if wrt_size == 0 {
128 return Vec::new();
129 }
130
131 let tangent_name = format!("tangent_{wrt_name}");
132 let mut tangent_buf = vec![0u8; wrt_size * elem_size];
133
134 // First run sets the tangent to e_0 and gives us the output
135 // shapes / sizes. After that we know how to size the Jacobian
136 // buffers and can fill them column by column.
137 set_unit(&mut tangent_buf, 0, dtype);
138 let first = run_one(compiled, primals, &tangent_name, &tangent_buf, dtype);
139 // Outputs are [primals_0..k-1, tangents_0..k-1].
140 assert!(
141 first.len().is_multiple_of(2),
142 "jacfwd: JVP graph must have even output count [primals..., tangents...], got {}",
143 first.len()
144 );
145 let n_outs = first.len() / 2;
146
147 // Allocate Jacobian buffers — `output_size` discovered per-output.
148 let mut jacs: Vec<JacobianBytes> = (0..n_outs)
149 .map(|i| {
150 let (bytes, dt) = &first[n_outs + i];
151 debug_assert_eq!(
152 *dt, dtype,
153 "jacfwd: tangent output {} has dtype {:?}, expected {:?}",
154 i, dt, dtype
155 );
156 let output_size = bytes.len() / elem_size;
157 JacobianBytes {
158 bytes: vec![0u8; output_size * wrt_size * elem_size],
159 output_size,
160 wrt_size,
161 dtype,
162 }
163 })
164 .collect();
165
166 // Write column 0 from the first run, then loop for columns 1..wrt_size.
167 write_column(&first[n_outs..], &mut jacs, 0, elem_size);
168
169 for j in 1..wrt_size {
170 // Reset previous slot, set new unit. (set_unit writes a 1 at
171 // index j; we still need to clear j-1.)
172 clear_index(&mut tangent_buf, j - 1, dtype);
173 set_unit(&mut tangent_buf, j, dtype);
174
175 let outs = run_one(compiled, primals, &tangent_name, &tangent_buf, dtype);
176 write_column(&outs[n_outs..], &mut jacs, j, elem_size);
177 }
178
179 jacs
180}
181
182/// Single-shot run with a freshly-set tangent slot.
183fn run_one(
184 compiled: &mut CompiledGraph,
185 primals: &[(&str, &[u8], DType)],
186 tangent_name: &str,
187 tangent_bytes: &[u8],
188 dtype: DType,
189) -> Vec<(Vec<u8>, DType)> {
190 let mut all = primals.to_vec();
191 all.push((tangent_name, tangent_bytes, dtype));
192 compiled.run_typed(&all)
193}
194
195/// Copy each tangent output into column `j` of its Jacobian.
196fn write_column(
197 tangent_outputs: &[(Vec<u8>, DType)],
198 jacs: &mut [JacobianBytes],
199 j: usize,
200 elem_size: usize,
201) {
202 debug_assert_eq!(tangent_outputs.len(), jacs.len());
203 for (out_idx, (bytes, _)) in tangent_outputs.iter().enumerate() {
204 let jac = &mut jacs[out_idx];
205 debug_assert_eq!(
206 bytes.len(),
207 jac.output_size * elem_size,
208 "tangent output size changed mid-jacfwd run"
209 );
210 // Row-major [output_size, wrt_size] → column j is element
211 // i*wrt_size + j for each i. Single byte-stripe write per row.
212 for i in 0..jac.output_size {
213 let dst_off = (i * jac.wrt_size + j) * elem_size;
214 let src_off = i * elem_size;
215 jac.bytes[dst_off..dst_off + elem_size]
216 .copy_from_slice(&bytes[src_off..src_off + elem_size]);
217 }
218 }
219}
220
221fn set_unit(buf: &mut [u8], idx: usize, dtype: DType) {
222 match dtype {
223 DType::F64 => {
224 let off = idx * 8;
225 buf[off..off + 8].copy_from_slice(&1.0_f64.to_le_bytes());
226 }
227 DType::F32 => {
228 let off = idx * 4;
229 buf[off..off + 4].copy_from_slice(&1.0_f32.to_le_bytes());
230 }
231 other => panic!("jacfwd: dtype {other:?} not supported (f64 / f32 only today)"),
232 }
233}
234
235fn clear_index(buf: &mut [u8], idx: usize, dtype: DType) {
236 let n = dtype.size_bytes();
237 let off = idx * n;
238 for b in &mut buf[off..off + n] {
239 *b = 0;
240 }
241}
242
#[cfg(test)]
#[cfg(feature = "cpu")]
mod tests {

    use rlx_ir::{Graph, Shape};
    use rlx_opt::autodiff_fwd::jvp;

    /// Encode a slice of `f64` as packed little-endian bytes.
    fn f64_bytes(xs: &[f64]) -> Vec<u8> {
        xs.iter().flat_map(|x| x.to_le_bytes()).collect()
    }

    /// `f(b) = 3·b` ⇒ `df/db = diag(3)`.
    ///
    /// Smallest end-to-end check of the Jacobian layout: multiply the
    /// input by a constant vector of 3s via `Mul`, differentiate with
    /// `jvp`, materialize with `jacfwd`, and expect 3 on the diagonal
    /// and 0 everywhere else.
    #[test]
    fn jacfwd_scalar_mul_gives_diagonal() {
        use rlx_ir::DType;
        use rlx_ir::op::BinaryOp;
        let n = 4usize;

        // Forward graph: y = b * [3, 3, 3, 3].
        let mut g = Graph::new("scale");
        let b = g.input("b", Shape::new(&[n], DType::F64));
        let three = g.add_node(
            rlx_ir::Op::Constant {
                data: f64_bytes(&vec![3.0; n]),
            },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let y = g.binary(BinaryOp::Mul, b, three, Shape::new(&[n], DType::F64));
        g.set_outputs(vec![y]);

        let jvp_graph = jvp(&g, &[b]);
        let mut compiled = crate::Session::new(crate::Device::Cpu).compile(jvp_graph);

        // Evaluation point: b = [10, 10, 10, 10].
        let b_bytes = f64_bytes(&vec![10.0_f64; n]);
        let jacs = super::jacfwd(
            &mut compiled,
            &[("b", &b_bytes, DType::F64)],
            "b",
            &[n],
            DType::F64,
        );

        assert_eq!(jacs.len(), 1);
        let jac = &jacs[0];
        assert_eq!(jac.output_size, n);
        assert_eq!(jac.wrt_size, n);

        // Walk the row-major matrix one row at a time.
        for (i, row) in jac.as_f64().chunks(n).enumerate() {
            for (j, &got) in row.iter().enumerate() {
                let want = if i == j { 3.0 } else { 0.0 };
                assert!(
                    (got - want).abs() < 1e-12,
                    "jac[{i},{j}] = {} (expected {want})",
                    got
                );
            }
        }
    }
}
307}