use crate::compiled::CompiledGraph;
use rlx_ir::DType;
/// Dense Jacobian of one graph output with respect to one input, stored as raw bytes.
///
/// As produced by `jacfwd`, the matrix is row-major with shape
/// `[output_size, wrt_size]`: element `(i, j)` (the derivative of output element `i`
/// with respect to input element `j`) starts at byte
/// `(i * wrt_size + j) * dtype.size_bytes()`.
/// Use `as_f64` / `as_f32` to obtain a typed view.
#[derive(Debug, Clone)]
pub struct JacobianBytes {
    /// Raw element bytes: `output_size * wrt_size` elements of `dtype`.
    pub bytes: Vec<u8>,
    /// Number of elements in the differentiated output (rows).
    pub output_size: usize,
    /// Number of elements in the input being differentiated against (columns).
    pub wrt_size: usize,
    /// Element type of `bytes`.
    pub dtype: DType,
}
impl JacobianBytes {
    /// Borrows the Jacobian as a row-major `[output_size, wrt_size]` slice of `f64`.
    ///
    /// Bytes are reinterpreted in host byte order.
    /// NOTE(review): `jacfwd` seeds tangents with little-endian bytes, so this
    /// presumably assumes a little-endian host — confirm if big-endian targets matter.
    ///
    /// # Panics
    /// If `dtype` is not `F64`, if `bytes.len() != output_size * wrt_size * 8`, or if
    /// a non-empty byte buffer is not aligned for `f64`. A `Vec<u8>` carries no
    /// alignment guarantee, and reinterpreting a misaligned buffer would be undefined
    /// behaviour, so it is rejected instead.
    pub fn as_f64(&self) -> &[f64] {
        const WIDTH: usize = std::mem::size_of::<f64>();
        assert_eq!(
            self.dtype,
            DType::F64,
            "as_f64: dtype is {:?}, not F64",
            self.dtype
        );
        assert_eq!(
            self.bytes.len(),
            self.output_size * self.wrt_size * WIDTH,
            "as_f64: byte length doesn't match shape"
        );
        if self.bytes.is_empty() {
            // An empty Vec<u8> has a dangling pointer with alignment 1, which
            // `from_raw_parts` forbids even for zero-length slices.
            return &[];
        }
        let ptr = self.bytes.as_ptr().cast::<f64>();
        assert!(ptr.is_aligned(), "as_f64: byte buffer is not aligned for f64");
        // SAFETY: `ptr` is non-null (non-empty Vec) and aligned for f64 (checked above).
        // It points to `bytes.len()` initialised bytes owned by `self.bytes`, which
        // outlives the returned borrow. Exactly `bytes.len() / WIDTH` f64s fit (length
        // checked above), and every bit pattern is a valid f64.
        unsafe { std::slice::from_raw_parts(ptr, self.bytes.len() / WIDTH) }
    }

    /// Borrows the Jacobian as a row-major `[output_size, wrt_size]` slice of `f32`.
    ///
    /// Bytes are reinterpreted in host byte order.
    /// NOTE(review): presumably assumes a little-endian host, as for `as_f64`.
    ///
    /// # Panics
    /// If `dtype` is not `F32`, if `bytes.len() != output_size * wrt_size * 4`, or if
    /// a non-empty byte buffer is not aligned for `f32`.
    pub fn as_f32(&self) -> &[f32] {
        const WIDTH: usize = std::mem::size_of::<f32>();
        assert_eq!(
            self.dtype,
            DType::F32,
            "as_f32: dtype is {:?}, not F32",
            self.dtype
        );
        assert_eq!(
            self.bytes.len(),
            self.output_size * self.wrt_size * WIDTH,
            "as_f32: byte length doesn't match shape"
        );
        if self.bytes.is_empty() {
            // See `as_f64`: an empty Vec's dangling pointer is not f32-aligned.
            return &[];
        }
        let ptr = self.bytes.as_ptr().cast::<f32>();
        assert!(ptr.is_aligned(), "as_f32: byte buffer is not aligned for f32");
        // SAFETY: `ptr` is non-null and aligned for f32 (checked above). It points to
        // `bytes.len()` initialised bytes owned by `self.bytes`, which outlives the
        // returned borrow. Exactly `bytes.len() / WIDTH` f32s fit (length checked above),
        // and every bit pattern is a valid f32.
        unsafe { std::slice::from_raw_parts(ptr, self.bytes.len() / WIDTH) }
    }
}
/// Computes dense forward-mode Jacobians from a compiled JVP graph.
///
/// `compiled` is expected to accept the `primals` plus one extra input named
/// `tangent_{wrt_name}` (a buffer of `wrt_size` elements of `dtype`). It must return
/// its outputs as `[primal_0, .., primal_{k-1}, tangent_0, .., tangent_{k-1}]`.
/// The graph is run once per element of the `wrt` input, seeding the tangent with the
/// `j`-th unit vector. The tangent outputs of run `j` become column `j` of the
/// Jacobians.
///
/// Returns one [`JacobianBytes`] per output, each row-major with shape
/// `[output_size, wrt_size]`, where `wrt_size` is the product of `wrt_shape`.
/// Returns an empty vector when `wrt_shape` has zero elements (no graph runs) or when
/// the graph has no outputs (a single graph run).
///
/// # Panics
/// - `dtype` is neither `F64` nor `F32`.
/// - The graph returns an odd number of outputs.
/// - A tangent output's dtype differs from `dtype`, or its byte length is not a
///   multiple of `dtype.size_bytes()`.
/// - The number or byte size of tangent outputs changes between runs.
pub fn jacfwd(
    compiled: &mut CompiledGraph,
    primals: &[(&str, &[u8], DType)],
    wrt_name: &str,
    wrt_shape: &[usize],
    dtype: DType,
) -> Vec<JacobianBytes> {
    let elem_size = dtype.size_bytes();
    let wrt_size: usize = wrt_shape.iter().product();
    if wrt_size == 0 {
        return Vec::new();
    }
    let tangent_name = format!("tangent_{wrt_name}");
    // One-hot tangent seed; exactly one element is 1.0 at any time.
    let mut tangent_buf = vec![0u8; wrt_size * elem_size];
    set_unit(&mut tangent_buf, 0, dtype);
    let first = run_one(compiled, primals, &tangent_name, &tangent_buf, dtype);
    assert!(
        first.len().is_multiple_of(2),
        "jacfwd: JVP graph must have even output count [primals..., tangents...], got {}",
        first.len()
    );
    let n_outs = first.len() / 2;
    if n_outs == 0 {
        // Nothing to differentiate: skip the remaining wrt_size - 1 graph runs.
        return Vec::new();
    }
    let mut jacs: Vec<JacobianBytes> = first[n_outs..]
        .iter()
        .enumerate()
        .map(|(i, (bytes, dt))| {
            // Hard checks: with a mismatched dtype or ragged byte length, `output_size`
            // below would be silently wrong in release builds, producing a garbage
            // Jacobian rather than an error.
            assert_eq!(
                *dt, dtype,
                "jacfwd: tangent output {} has dtype {:?}, expected {:?}",
                i, dt, dtype
            );
            assert!(
                bytes.len().is_multiple_of(elem_size),
                "jacfwd: tangent output {i} has {} bytes, not a multiple of element size {elem_size}",
                bytes.len()
            );
            let output_size = bytes.len() / elem_size;
            JacobianBytes {
                bytes: vec![0u8; output_size * wrt_size * elem_size],
                output_size,
                wrt_size,
                dtype,
            }
        })
        .collect();
    write_column(&first[n_outs..], &mut jacs, 0, elem_size);
    for j in 1..wrt_size {
        // Move the one-hot seed from element j-1 to element j.
        clear_index(&mut tangent_buf, j - 1, dtype);
        set_unit(&mut tangent_buf, j, dtype);
        let outs = run_one(compiled, primals, &tangent_name, &tangent_buf, dtype);
        write_column(&outs[n_outs..], &mut jacs, j, elem_size);
    }
    jacs
}
/// Executes `compiled` a single time.
///
/// The input list is every entry of `primals`, in order, followed by the tangent
/// input `(tangent_name, tangent_bytes, dtype)`. Returns whatever `run_typed` yields.
fn run_one(
    compiled: &mut CompiledGraph,
    primals: &[(&str, &[u8], DType)],
    tangent_name: &str,
    tangent_bytes: &[u8],
    dtype: DType,
) -> Vec<(Vec<u8>, DType)> {
    // Reserve room for the tangent up front so the push below never reallocates.
    let mut inputs = Vec::with_capacity(primals.len() + 1);
    inputs.extend_from_slice(primals);
    inputs.push((tangent_name, tangent_bytes, dtype));
    compiled.run_typed(&inputs)
}
/// Scatters one Jacobian column.
///
/// Element `i` of tangent output `k` is copied into `jacs[k]` at row `i`, column `j`
/// (row-major, `wrt_size` columns, `elem_size` bytes per element). The dtype half of
/// each tangent-output tuple is ignored here; `jacfwd` validates it.
///
/// # Panics
/// If the number of tangent outputs differs from `jacs.len()`, or if any tangent
/// output's byte length differs from `output_size * elem_size`. These are hard asserts
/// because in release builds a mismatch would otherwise silently drop data (longer
/// output) or fail with an opaque out-of-bounds index (shorter output).
fn write_column(
    tangent_outputs: &[(Vec<u8>, DType)],
    jacs: &mut [JacobianBytes],
    j: usize,
    elem_size: usize,
) {
    assert_eq!(
        tangent_outputs.len(),
        jacs.len(),
        "tangent output count changed mid-jacfwd run"
    );
    let col_off = j * elem_size;
    for ((bytes, _), jac) in tangent_outputs.iter().zip(jacs.iter_mut()) {
        assert_eq!(
            bytes.len(),
            jac.output_size * elem_size,
            "tangent output size changed mid-jacfwd run"
        );
        let row_stride = jac.wrt_size * elem_size;
        for i in 0..jac.output_size {
            let dst_off = i * row_stride + col_off;
            let src_off = i * elem_size;
            jac.bytes[dst_off..dst_off + elem_size]
                .copy_from_slice(&bytes[src_off..src_off + elem_size]);
        }
    }
}
/// Stores the value `1.0`, encoded little-endian as `dtype`, into element `idx` of `buf`.
///
/// # Panics
/// If `dtype` is neither `F64` nor `F32`, or if element `idx` does not fit in `buf`.
fn set_unit(buf: &mut [u8], idx: usize, dtype: DType) {
    // Element width is implied by the encoded constant's byte length.
    let mut put = |one: &[u8]| {
        let start = idx * one.len();
        buf[start..start + one.len()].copy_from_slice(one);
    };
    match dtype {
        DType::F64 => put(&1.0_f64.to_le_bytes()),
        DType::F32 => put(&1.0_f32.to_le_bytes()),
        other => panic!("jacfwd: dtype {other:?} not supported (f64 / f32 only today)"),
    }
}
/// Zeroes element `idx` of `buf`, where every element is `dtype.size_bytes()` bytes wide.
///
/// # Panics
/// If element `idx` does not fit in `buf`.
fn clear_index(buf: &mut [u8], idx: usize, dtype: DType) {
    let width = dtype.size_bytes();
    let start = idx * width;
    buf[start..start + width].fill(0);
}
#[cfg(test)]
#[cfg(feature = "cpu")]
mod tests {
    use rlx_ir::{Graph, Shape};
    use rlx_opt::autodiff_fwd::jvp;

    /// Serialises `xs` as consecutive little-endian f64 bytes.
    fn f64_bytes(xs: &[f64]) -> Vec<u8> {
        xs.iter().flat_map(|x| x.to_le_bytes()).collect()
    }

    /// For y = 3·b, dy/db is 3·I, so jacfwd must return an n×n diagonal of threes.
    #[test]
    fn jacfwd_scalar_mul_gives_diagonal() {
        use rlx_ir::op::BinaryOp;
        use rlx_ir::DType;

        let n = 4usize;

        // Build y = b * [3, 3, ..., 3] and differentiate it forward w.r.t. b.
        let mut g = Graph::new("scale");
        let b = g.input("b", Shape::new(&[n], DType::F64));
        let three_bytes = f64_bytes(&vec![3.0; n]);
        let three = g.add_node(
            rlx_ir::Op::Constant { data: three_bytes },
            vec![],
            Shape::new(&[n], DType::F64),
        );
        let y = g.binary(BinaryOp::Mul, b, three, Shape::new(&[n], DType::F64));
        g.set_outputs(vec![y]);
        let jg = jvp(&g, &[b]);
        let mut compiled = crate::Session::new(crate::Device::Cpu).compile(jg);

        let b_data = vec![10.0_f64; n];
        let b_bytes = f64_bytes(&b_data);
        let jacs = super::jacfwd(&mut compiled, &[("b", &b_bytes, DType::F64)], "b", &[n], DType::F64);

        assert_eq!(jacs.len(), 1);
        let jac = &jacs[0];
        assert_eq!(jac.output_size, n);
        assert_eq!(jac.wrt_size, n);

        // Walk the row-major matrix once, recovering (row, col) from the flat index.
        for (flat, &got) in jac.as_f64().iter().enumerate() {
            let (i, j) = (flat / n, flat % n);
            let want = if i == j { 3.0 } else { 0.0 };
            assert!(
                (got - want).abs() < 1e-12,
                "jac[{i},{j}] = {} (expected {want})",
                got
            );
        }
    }
}