Skip to main content

rlx_runtime/
jacfwd.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Forward-mode Jacobian materialization.
17//!
18//! Pre-vmap convenience: given a compiled JVP graph (output of
19//! `rlx_opt::autodiff_fwd::jvp`), assemble the Jacobian by running
20//! the graph once per standard-basis unit vector and stacking the
21//! tangent outputs.
22//!
//! Use this when the *input* dimension is small (Circulax component
//! groups have only a handful of params). Once a `vmap` transformation
//! lands, the per-column runs can be vectorised into a single call.
26//!
27//! ## Layout convention
28//!
29//! For each primal output of shape `output_shape`, the Jacobian is
30//! returned as a flat byte buffer encoding a row-major
31//! `[output_size, wrt_size]` matrix where:
32//!
33//! * `output_size = product(output_shape)`
34//! * `wrt_size    = product(wrt_shape)`
35//!
36//! Element `[i, j]` is `∂output[i] / ∂wrt[j]` after both shapes are
37//! flattened. Callers reshape to the natural
38//! `[output_shape..., wrt_shape...]` if convenient.
39//!
40//! ## Compose with [`crate::Session`]
41//!
42//! ```ignore
//! use rlx_ir::DType;
//! use rlx_opt::autodiff_fwd::jvp;
//! use rlx_runtime::{Session, Device, jacfwd};
45//!
46//! let jvp_graph = jvp(&forward, &[wrt_node]);
47//! let mut compiled = Session::new(Device::Cpu).compile(jvp_graph);
48//! let jacs = jacfwd(&mut compiled, &primals, "x", &[3], DType::F64);
49//! ```
50
51use crate::compiled::CompiledGraph;
52use rlx_ir::DType;
53
54/// One Jacobian per primal output of the original forward graph.
55#[derive(Debug, Clone)]
56pub struct JacobianBytes {
57    /// Flat row-major `[output_size, wrt_size]` matrix.
58    pub bytes: Vec<u8>,
59    /// Number of elements in the primal output (= rows).
60    pub output_size: usize,
61    /// Number of elements in the wrt input (= columns).
62    pub wrt_size: usize,
63    /// Element dtype of `bytes`.
64    pub dtype: DType,
65}
66
67impl JacobianBytes {
68    /// Reinterpret the byte buffer as `&[f64]` (row-major).
69    /// Panics if `dtype != F64` or the byte length isn't a multiple of 8.
70    pub fn as_f64(&self) -> &[f64] {
71        assert_eq!(
72            self.dtype,
73            DType::F64,
74            "as_f64: dtype is {:?}, not F64",
75            self.dtype
76        );
77        assert_eq!(
78            self.bytes.len(),
79            self.output_size * self.wrt_size * 8,
80            "as_f64: byte length doesn't match shape"
81        );
82        // SAFETY: bytes are 8-aligned (rlx-runtime allocates with at
83        // least 8-byte alignment) and the byte length is a multiple of 8.
84        unsafe {
85            std::slice::from_raw_parts(self.bytes.as_ptr() as *const f64, self.bytes.len() / 8)
86        }
87    }
88
89    /// Reinterpret as `&[f32]` (row-major). Mirror of `as_f64`.
90    pub fn as_f32(&self) -> &[f32] {
91        assert_eq!(
92            self.dtype,
93            DType::F32,
94            "as_f32: dtype is {:?}, not F32",
95            self.dtype
96        );
97        assert_eq!(self.bytes.len(), self.output_size * self.wrt_size * 4);
98        unsafe {
99            std::slice::from_raw_parts(self.bytes.as_ptr() as *const f32, self.bytes.len() / 4)
100        }
101    }
102}
103
104/// Materialize the Jacobian of every primal output w.r.t. the input
105/// named `wrt_name`. The compiled graph must be the result of
106/// `rlx_opt::autodiff_fwd::jvp(forward, &[wrt_node])` — it has a
107/// `tangent_<wrt_name>` Input that we drive with unit vectors, and
108/// outputs `[primals..., tangents...]`.
109///
110/// `primals` carries values for every non-tangent input (including
111/// any `Param` that was bound externally; if the model uses
112/// `set_param_typed` for params, call it before invoking `jacfwd` —
113/// params persist across the multiple runs).
114///
115/// One JVP run per element of `wrt_shape`; cost scales linearly with
116/// the wrt dimension. Use reverse-mode (`grad_with_loss`) when the
117/// output dimension is what's small instead.
118pub fn jacfwd(
119    compiled: &mut CompiledGraph,
120    primals: &[(&str, &[u8], DType)],
121    wrt_name: &str,
122    wrt_shape: &[usize],
123    dtype: DType,
124) -> Vec<JacobianBytes> {
125    let elem_size = dtype.size_bytes();
126    let wrt_size: usize = wrt_shape.iter().product();
127    if wrt_size == 0 {
128        return Vec::new();
129    }
130
131    let tangent_name = format!("tangent_{wrt_name}");
132    let mut tangent_buf = vec![0u8; wrt_size * elem_size];
133
134    // First run sets the tangent to e_0 and gives us the output
135    // shapes / sizes. After that we know how to size the Jacobian
136    // buffers and can fill them column by column.
137    set_unit(&mut tangent_buf, 0, dtype);
138    let first = run_one(compiled, primals, &tangent_name, &tangent_buf, dtype);
139    // Outputs are [primals_0..k-1, tangents_0..k-1].
140    assert!(
141        first.len().is_multiple_of(2),
142        "jacfwd: JVP graph must have even output count [primals..., tangents...], got {}",
143        first.len()
144    );
145    let n_outs = first.len() / 2;
146
147    // Allocate Jacobian buffers — `output_size` discovered per-output.
148    let mut jacs: Vec<JacobianBytes> = (0..n_outs)
149        .map(|i| {
150            let (bytes, dt) = &first[n_outs + i];
151            debug_assert_eq!(
152                *dt, dtype,
153                "jacfwd: tangent output {} has dtype {:?}, expected {:?}",
154                i, dt, dtype
155            );
156            let output_size = bytes.len() / elem_size;
157            JacobianBytes {
158                bytes: vec![0u8; output_size * wrt_size * elem_size],
159                output_size,
160                wrt_size,
161                dtype,
162            }
163        })
164        .collect();
165
166    // Write column 0 from the first run, then loop for columns 1..wrt_size.
167    write_column(&first[n_outs..], &mut jacs, 0, elem_size);
168
169    for j in 1..wrt_size {
170        // Reset previous slot, set new unit. (set_unit writes a 1 at
171        // index j; we still need to clear j-1.)
172        clear_index(&mut tangent_buf, j - 1, dtype);
173        set_unit(&mut tangent_buf, j, dtype);
174
175        let outs = run_one(compiled, primals, &tangent_name, &tangent_buf, dtype);
176        write_column(&outs[n_outs..], &mut jacs, j, elem_size);
177    }
178
179    jacs
180}
181
182/// Single-shot run with a freshly-set tangent slot.
183fn run_one(
184    compiled: &mut CompiledGraph,
185    primals: &[(&str, &[u8], DType)],
186    tangent_name: &str,
187    tangent_bytes: &[u8],
188    dtype: DType,
189) -> Vec<(Vec<u8>, DType)> {
190    let mut all = primals.to_vec();
191    all.push((tangent_name, tangent_bytes, dtype));
192    compiled.run_typed(&all)
193}
194
195/// Copy each tangent output into column `j` of its Jacobian.
196fn write_column(
197    tangent_outputs: &[(Vec<u8>, DType)],
198    jacs: &mut [JacobianBytes],
199    j: usize,
200    elem_size: usize,
201) {
202    debug_assert_eq!(tangent_outputs.len(), jacs.len());
203    for (out_idx, (bytes, _)) in tangent_outputs.iter().enumerate() {
204        let jac = &mut jacs[out_idx];
205        debug_assert_eq!(
206            bytes.len(),
207            jac.output_size * elem_size,
208            "tangent output size changed mid-jacfwd run"
209        );
210        // Row-major [output_size, wrt_size] → column j is element
211        // i*wrt_size + j for each i. Single byte-stripe write per row.
212        for i in 0..jac.output_size {
213            let dst_off = (i * jac.wrt_size + j) * elem_size;
214            let src_off = i * elem_size;
215            jac.bytes[dst_off..dst_off + elem_size]
216                .copy_from_slice(&bytes[src_off..src_off + elem_size]);
217        }
218    }
219}
220
221fn set_unit(buf: &mut [u8], idx: usize, dtype: DType) {
222    match dtype {
223        DType::F64 => {
224            let off = idx * 8;
225            buf[off..off + 8].copy_from_slice(&1.0_f64.to_le_bytes());
226        }
227        DType::F32 => {
228            let off = idx * 4;
229            buf[off..off + 4].copy_from_slice(&1.0_f32.to_le_bytes());
230        }
231        other => panic!("jacfwd: dtype {other:?} not supported (f64 / f32 only today)"),
232    }
233}
234
235fn clear_index(buf: &mut [u8], idx: usize, dtype: DType) {
236    let n = dtype.size_bytes();
237    let off = idx * n;
238    for b in &mut buf[off..off + n] {
239        *b = 0;
240    }
241}
242
#[cfg(test)]
#[cfg(feature = "cpu")]
mod tests {
    use rlx_ir::op::BinaryOp;
    use rlx_ir::{DType, Graph, Op, Shape};
    use rlx_opt::autodiff_fwd::jvp;

    /// Little-endian byte encoding of an `f64` slice.
    fn f64_bytes(xs: &[f64]) -> Vec<u8> {
        xs.iter().flat_map(|x| x.to_le_bytes()).collect()
    }

    /// This test checks `f(b) = 3·b` ⇒ `df/db = diag(3)`.
    ///
    /// It is the smallest possible jacfwd shape check that doesn't depend on
    /// any other AD machinery. The test scales `b` by a constant tensor of 3s
    /// via `Mul`, differentiates with `jvp`, materializes with `jacfwd`, and
    /// expects a diagonal of 3s.
    #[test]
    fn jacfwd_scalar_mul_gives_diagonal() {
        const N: usize = 4;
        let shape = || Shape::new(&[N], DType::F64);

        // y = b * [3, 3, ..., 3]
        let mut g = Graph::new("scale");
        let b = g.input("b", shape());
        let three = g.add_node(
            Op::Constant {
                data: f64_bytes(&[3.0; N]),
            },
            vec![],
            shape(),
        );
        let y = g.binary(BinaryOp::Mul, b, three, shape());
        g.set_outputs(vec![y]);

        let jvp_graph = jvp(&g, &[b]);
        let mut compiled = crate::Session::new(crate::Device::Cpu).compile(jvp_graph);

        let b_bytes = f64_bytes(&[10.0; N]);
        let jacs = super::jacfwd(
            &mut compiled,
            &[("b", b_bytes.as_slice(), DType::F64)],
            "b",
            &[N],
            DType::F64,
        );
        assert_eq!(jacs.len(), 1);
        let jac = &jacs[0];
        assert_eq!(jac.output_size, N);
        assert_eq!(jac.wrt_size, N);

        // Walk the row-major matrix; flat index k is (row k / N, col k % N).
        for (k, &got) in jac.as_f64().iter().enumerate() {
            let (i, j) = (k / N, k % N);
            let want = if i == j { 3.0 } else { 0.0 };
            assert!(
                (got - want).abs() < 1e-12,
                "jac[{i},{j}] = {got} (expected {want})"
            );
        }
    }
}