#![cfg(feature = "cpu")]
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use rlx_opt::autodiff::grad_with_loss;
use rlx_runtime::{Device, Session};
/// A complex value represented in the graph as a pair of real-valued nodes.
///
/// Every complex operation below is expanded into elementwise real graph ops
/// on the `real` / `imag` components.
#[derive(Clone, Copy, Debug)]
struct Complex {
    /// Node holding the real part.
    real: NodeId,
    /// Node holding the imaginary part.
    imag: NodeId,
}

// Several helpers (e.g. `add`, `neg`, `conj`, `abs`, `scale_real`) are not
// called by the tests in this file; they are kept so the complex-arithmetic
// surface stays complete, hence the lint allowance.
#[allow(dead_code)]
impl Complex {
    /// Pairs two existing real-valued nodes into a complex value.
    fn new(real: NodeId, imag: NodeId) -> Self {
        Complex { real, imag }
    }

    /// Componentwise sum: `(a + bi) + (c + di) = (a + c) + (b + d)i`.
    fn add(self, other: Self, g: &mut Graph) -> Self {
        let re = g.add(self.real, other.real);
        let im = g.add(self.imag, other.imag);
        Self::new(re, im)
    }

    /// Componentwise difference: `(a + bi) - (c + di) = (a - c) + (b - d)i`.
    fn sub(self, other: Self, g: &mut Graph) -> Self {
        let re = g.sub(self.real, other.real);
        let im = g.sub(self.imag, other.imag);
        Self::new(re, im)
    }

    /// Negation of both components: `-(a + bi) = -a - bi`.
    fn neg(self, g: &mut Graph) -> Self {
        let re = g.neg(self.real);
        let im = g.neg(self.imag);
        Self::new(re, im)
    }

    /// Complex product: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, other: Self, g: &mut Graph) -> Self {
        let (a, b) = (self.real, self.imag);
        let (c, d) = (other.real, other.imag);
        // Emit the four partial products first, then combine them.
        let ac = g.mul(a, c);
        let bd = g.mul(b, d);
        let ad = g.mul(a, d);
        let bc = g.mul(b, c);
        let re = g.sub(ac, bd);
        let im = g.add(ad, bc);
        Self::new(re, im)
    }

    /// Complex conjugate: `conj(a + bi) = a - bi`. The real node is reused as-is.
    fn conj(self, g: &mut Graph) -> Self {
        let im = g.neg(self.imag);
        Self { imag: im, ..self }
    }

    /// Squared modulus `|z|² = a² + b²`, returned as a single real node.
    fn abs_sq(self, g: &mut Graph) -> NodeId {
        let re_sq = g.mul(self.real, self.real);
        let im_sq = g.mul(self.imag, self.imag);
        g.add(re_sq, im_sq)
    }

    /// Modulus `|z| = sqrt(a² + b²)`, returned as a single real node.
    fn abs(self, g: &mut Graph) -> NodeId {
        let norm_sq = self.abs_sq(g);
        g.sqrt(norm_sq)
    }

    /// Multiplies both components by the same real-valued `scalar` node.
    fn scale_real(self, scalar: NodeId, g: &mut Graph) -> Self {
        let re = g.mul(self.real, scalar);
        let im = g.mul(self.imag, scalar);
        Self::new(re, im)
    }

    /// Packs real autodiff gradients `(dL/dx, dL/dy)` into a complex gradient
    /// `g_x + i·g_y`, with no sign flip on the imaginary part.
    ///
    /// For a real-valued loss this equals `2·∂L/∂z̄`, because
    /// `∂/∂z̄ = ½(∂/∂x + i·∂/∂y)`. Its negation is the steepest-descent
    /// direction in the complex plane.
    fn wirtinger_grad_from_real_grads(g_x: NodeId, g_y: NodeId) -> Self {
        Self::new(g_x, g_y)
    }
}
/// Serializes `f32` values into one contiguous little-endian byte buffer
/// (4 bytes per value, in input order).
fn f32s_to_bytes(xs: &[f32]) -> Vec<u8> {
    xs.iter().flat_map(|v| v.to_le_bytes()).collect()
}
/// Decodes a little-endian byte buffer into `f32` values (4 bytes each).
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of 4. A trailing partial element
/// means the buffer is not an `f32` tensor (wrong dtype or a truncated
/// output). `chunks_exact` would otherwise drop those bytes silently and hide
/// the mismatch.
fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    const WIDTH: usize = std::mem::size_of::<f32>();
    assert!(
        bytes.len() % WIDTH == 0,
        "byte buffer of length {} is not a whole number of f32s",
        bytes.len()
    );
    bytes
        .chunks_exact(WIDTH)
        .map(|c| f32::from_le_bytes(c.try_into().expect("chunks_exact yields 4-byte chunks")))
        .collect()
}
/// Builds `a * b` for length-3 complex vectors via [`Complex::mul`], runs it on
/// the CPU backend, and compares it elementwise with
/// `(ac - bd) + (ad + bc)i` computed on the host.
#[test]
fn complex_mul_matches_textbook_formula_through_cpu_pipeline() {
    let mut g = Graph::new("c_mul");
    let n = 3usize;
    let ar = g.input("a_re", Shape::new(&[n], DType::F32));
    let ai = g.input("a_im", Shape::new(&[n], DType::F32));
    let br = g.input("b_re", Shape::new(&[n], DType::F32));
    let bi = g.input("b_im", Shape::new(&[n], DType::F32));
    let a = Complex::new(ar, ai);
    let b = Complex::new(br, bi);
    let y = a.mul(b, &mut g);
    g.set_outputs(vec![y.real, y.imag]);
    let mut compiled = Session::new(Device::Cpu).compile(g);
    let a_re = [1.0_f32, 0.0, -2.0];
    let a_im = [2.0_f32, 1.0, 3.0];
    let b_re = [3.0_f32, 4.0, 1.0];
    let b_im = [4.0_f32, -2.0, -1.0];
    let outs = compiled.run_typed(&[
        ("a_re", &f32s_to_bytes(&a_re), DType::F32),
        ("a_im", &f32s_to_bytes(&a_im), DType::F32),
        ("b_re", &f32s_to_bytes(&b_re), DType::F32),
        ("b_im", &f32s_to_bytes(&b_im), DType::F32),
    ]);
    assert_eq!(outs.len(), 2);
    let y_re = bytes_to_f32s(&outs[0].0);
    let y_im = bytes_to_f32s(&outs[1].0);
    // The loop below only inspects the first `n` elements, so an output with
    // extra (e.g. wrongly broadcast) elements would otherwise pass unnoticed.
    assert_eq!(y_re.len(), n, "real part element count");
    assert_eq!(y_im.len(), n, "imag part element count");
    for i in 0..n {
        let exp_re = a_re[i] * b_re[i] - a_im[i] * b_im[i];
        let exp_im = a_re[i] * b_im[i] + a_im[i] * b_re[i];
        assert!(
            (y_re[i] - exp_re).abs() < 1e-5,
            "real[{i}]: {} vs {exp_re}",
            y_re[i]
        );
        assert!(
            (y_im[i] - exp_im).abs() < 1e-5,
            "imag[{i}]: {} vs {exp_im}",
            y_im[i]
        );
    }
}
/// Checks that [`Complex::abs_sq`] produces the single real output `a² + b²`
/// for each element when run on the CPU backend.
#[test]
fn complex_abs_sq_returns_real_norm_squared() {
    let mut g = Graph::new("c_abs_sq");
    let n = 4usize;
    let ar = g.input("a_re", Shape::new(&[n], DType::F32));
    let ai = g.input("a_im", Shape::new(&[n], DType::F32));
    let z = Complex::new(ar, ai);
    let m = z.abs_sq(&mut g);
    g.set_outputs(vec![m]);
    let mut compiled = Session::new(Device::Cpu).compile(g);
    let a_re = [3.0_f32, 0.0, -1.0, 2.0];
    let a_im = [4.0_f32, 5.0, 1.0, -2.0];
    let outs = compiled.run_typed(&[
        ("a_re", &f32s_to_bytes(&a_re), DType::F32),
        ("a_im", &f32s_to_bytes(&a_im), DType::F32),
    ]);
    // Exactly one graph output was registered above.
    assert_eq!(outs.len(), 1);
    let m_got = bytes_to_f32s(&outs[0].0);
    // Guard against extra elements that the indexed loop would not inspect.
    assert_eq!(m_got.len(), n, "|z|² element count");
    let exp = [25.0_f32, 25.0, 2.0, 8.0];
    for i in 0..n {
        assert!(
            (m_got[i] - exp[i]).abs() < 1e-5,
            "|z|²[{i}]: {} vs {}",
            m_got[i],
            exp[i]
        );
    }
}
/// One gradient-descent step on `L = |z - z*|²`, using real autodiff on the
/// split `(re, im)` representation.
///
/// Verifies three things:
/// - the forward loss and the real gradients match their analytic values;
/// - a step of `lr = 0.5` lands on `z*`;
/// - [`Complex::wirtinger_grad_from_real_grads`] packs the real gradients
///   without a sign flip.
#[test]
fn wirtinger_grad_descent_step_decreases_real_loss() {
    let z0 = (1.0_f32, 0.0_f32);
    let z_star = (3.0_f32, -2.0_f32);
    let mut g = Graph::new("c_wirtinger");
    let zr = g.input("z_re", Shape::new(&[1], DType::F32));
    let zi = g.input("z_im", Shape::new(&[1], DType::F32));
    let tr = g.input("t_re", Shape::new(&[1], DType::F32));
    let ti = g.input("t_im", Shape::new(&[1], DType::F32));
    let z = Complex::new(zr, zi);
    let t = Complex::new(tr, ti);
    let diff = z.sub(t, &mut g);
    let m = diff.abs_sq(&mut g);
    let loss = g.sum(m, vec![0], false);
    g.set_outputs(vec![loss]);
    // `g` is still needed below to build the backward graph, hence the clone.
    let mut compiled = Session::new(Device::Cpu).compile(g.clone());
    let outs = compiled.run_typed(&[
        ("z_re", &f32s_to_bytes(&[z0.0]), DType::F32),
        ("z_im", &f32s_to_bytes(&[z0.1]), DType::F32),
        ("t_re", &f32s_to_bytes(&[z_star.0]), DType::F32),
        ("t_im", &f32s_to_bytes(&[z_star.1]), DType::F32),
    ]);
    let loss_before = bytes_to_f32s(&outs[0].0)[0];
    let (dx0, dy0) = (z0.0 - z_star.0, z0.1 - z_star.1);
    let exp_loss_before = dx0 * dx0 + dy0 * dy0;
    assert!(
        (loss_before - exp_loss_before).abs() < 1e-5,
        "forward loss: {loss_before} vs {exp_loss_before}"
    );
    let bwd = grad_with_loss(&g, &[zr, zi]);
    assert_eq!(bwd.outputs.len(), 3, "[loss, dL/dzr, dL/dzi]");
    let mut compiled_bwd = Session::new(Device::Cpu).compile(bwd);
    let outs = compiled_bwd.run_typed(&[
        ("z_re", &f32s_to_bytes(&[z0.0]), DType::F32),
        ("z_im", &f32s_to_bytes(&[z0.1]), DType::F32),
        ("t_re", &f32s_to_bytes(&[z_star.0]), DType::F32),
        ("t_im", &f32s_to_bytes(&[z_star.1]), DType::F32),
        ("d_output", &f32s_to_bytes(&[1.0]), DType::F32),
    ]);
    let g_x = bytes_to_f32s(&outs[1].0)[0];
    let g_y = bytes_to_f32s(&outs[2].0)[0];
    // dL/dx = 2(x - t_re), dL/dy = 2(y - t_im), so the packed gradient
    // g_x + i·g_y equals 2(z - z*) = 2·∂L/∂z̄.
    let exp_gx = 2.0 * dx0;
    let exp_gy = 2.0 * dy0;
    assert!((g_x - exp_gx).abs() < 1e-5, "dL/dx: {g_x} vs {exp_gx}");
    assert!((g_y - exp_gy).abs() < 1e-5, "dL/dy: {g_y} vs {exp_gy}");
    // With gradient 2(z - z*), a step of 0.5 moves z exactly onto z*.
    let lr = 0.5;
    let z1 = (z0.0 - lr * g_x, z0.1 - lr * g_y);
    let outs = compiled.run_typed(&[
        ("z_re", &f32s_to_bytes(&[z1.0]), DType::F32),
        ("z_im", &f32s_to_bytes(&[z1.1]), DType::F32),
        ("t_re", &f32s_to_bytes(&[z_star.0]), DType::F32),
        ("t_im", &f32s_to_bytes(&[z_star.1]), DType::F32),
    ]);
    let loss_after = bytes_to_f32s(&outs[0].0)[0];
    assert!(
        loss_after < loss_before,
        "Wirtinger SGD should reduce loss: before={loss_before}, after={loss_after}"
    );
    assert!(
        loss_after < 1e-4,
        "L = |z - z*|² with lr = 0.5 should land at z*: got loss {loss_after}"
    );
    let mut g2 = Graph::new("wirtinger_helper_check");
    let gx_n = g2.input("g_x", Shape::new(&[1], DType::F32));
    let gy_n = g2.input("g_y", Shape::new(&[1], DType::F32));
    let grad = Complex::wirtinger_grad_from_real_grads(gx_n, gy_n);
    assert_eq!(
        grad.real, gx_n,
        "Wirtinger grad real == real autodiff dL/dx"
    );
    assert_eq!(
        grad.imag, gy_n,
        "Wirtinger grad imag == real autodiff dL/dy (no sign flip)"
    );
}