use rustc_hash::FxHashMap;
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use super::error::AutodiffError;
use super::rules::{binop_adjoints, fma_adjoints, unop_adjoint};
/// Builds the reverse-mode (backward) program of `program`.
///
/// The returned program binds, in order:
/// 1. every forward buffer, re-declared `ReadOnly` with its original element
///    type and count (bindings `0..n`), so adjoint code can read primal data;
/// 2. one `ReadWrite` `f32` buffer `grad_<name>` per distinct name in
///    `outputs`, then per distinct name in `inputs` not already covered
///    (a name that is both output and input shares one gradient buffer).
///
/// Its body stores `1.0` into `grad_<output>[invocation_id.x]` for every
/// output, then emits adjoint code for the forward entry nodes in reverse
/// order. Gradients of loads from `inputs` buffers accumulate into the
/// matching `grad_<input>` buffer. Repeated names in `outputs` / `inputs` are
/// ignored after their first occurrence.
///
/// # Errors
///
/// * [`AutodiffError::BufferNotFound`] if a name in `outputs` or `inputs` is
///   not a buffer of `program` (outputs are checked first).
/// * Any error raised while differentiating the forward nodes (unsupported
///   nodes, non-differentiable operations).
pub fn grad(
    program: &Program,
    outputs: &[&str],
    inputs: &[&str],
) -> Result<Program, AutodiffError> {
    for name in outputs.iter().chain(inputs) {
        if !has_buffer(program, name) {
            return Err(AutodiffError::BufferNotFound {
                name: (*name).to_string(),
            });
        }
    }
    // Deduplicate so a repeated name cannot declare `grad_<name>` twice.
    let output_set = dedup_names(outputs);
    let input_set = dedup_names(inputs);

    let mut back_buffers: Vec<BufferDecl> = Vec::new();
    let mut next_binding = 0u32;
    for fwd_buf in program.buffers() {
        back_buffers.push(
            BufferDecl::storage(
                fwd_buf.name(),
                next_binding,
                BufferAccess::ReadOnly,
                fwd_buf.element(),
            )
            .with_count(fwd_buf.count()),
        );
        next_binding += 1;
    }

    // Outputs first, then inputs; a name that is both gets a single buffer.
    let mut grad_buf_binding: FxHashMap<String, u32> = FxHashMap::default();
    for name in output_set.iter().chain(&input_set) {
        let grad_name = format!("grad_{name}");
        if grad_buf_binding.contains_key(&grad_name) {
            continue;
        }
        let Some(fwd_buf) = program
            .buffers()
            .iter()
            .find(|b| b.name() == name.as_str())
        else {
            continue;
        };
        back_buffers.push(
            BufferDecl::storage(
                &grad_name,
                next_binding,
                BufferAccess::ReadWrite,
                DataType::F32,
            )
            .with_count(fwd_buf.count()),
        );
        grad_buf_binding.insert(grad_name, next_binding);
        next_binding += 1;
    }

    // NOTE(review): the seed overwrites whatever the caller placed in
    // `grad_<output>` and only writes the element at invocation_id.x.
    let i_expr = Expr::InvocationId { axis: 0 };
    let mut body: Vec<Node> = output_set
        .iter()
        .map(|out_name| Node::Store {
            buffer: format!("grad_{out_name}").into(),
            index: i_expr.clone(),
            value: Expr::f32(1.0),
        })
        .collect();

    let mut adjoint_env = AdjointEnv::new(&input_set);
    for node in program.entry().iter().rev() {
        emit_adjoint_node(node, &mut body, &mut adjoint_env, &output_set)?;
    }

    // NOTE(review): `get_accumulator` looks up *variable* adjoints by the
    // input *buffer* name. This store therefore only fires when a local
    // variable shares the buffer's name. In that case it overwrites the
    // gradients already accumulated into `grad_<input>` by loads.
    // Confirm intent.
    for inp_name in &input_set {
        if let Some(accum_var) = adjoint_env.get_accumulator(inp_name) {
            body.push(Node::Store {
                buffer: format!("grad_{inp_name}").into(),
                index: i_expr.clone(),
                value: Expr::Var(accum_var.into()),
            });
        }
    }

    Ok(Program::wrapped(
        back_buffers,
        program.workgroup_size(),
        body,
    ))
}

/// Returns `true` when `program` declares a buffer called `name`.
fn has_buffer(program: &Program, name: &str) -> bool {
    program.buffers().iter().any(|b| b.name() == name)
}

/// Copies `names` into owned strings, keeping only the first occurrence of
/// each name and preserving the original order.
fn dedup_names(names: &[&str]) -> Vec<String> {
    let mut unique: Vec<String> = Vec::with_capacity(names.len());
    for &name in names {
        if !unique.iter().any(|seen| seen == name) {
            unique.push(name.to_string());
        }
    }
    unique
}
/// Bookkeeping shared by the reverse-mode emitters.
///
/// Maps each forward variable name to the name of its generated adjoint
/// variable and remembers which buffers are differentiated inputs.
struct AdjointEnv {
    /// Forward variable name -> generated adjoint variable name.
    var_adjoints: FxHashMap<String, String>,
    /// Monotonic counter appended to generated names so they stay unique.
    fresh_counter: u32,
    /// Names of the buffers whose gradients are requested.
    input_buffers: Vec<String>,
}
impl AdjointEnv {
    /// Creates an empty environment tracking gradients for `inputs`.
    fn new(inputs: &[String]) -> Self {
        let input_buffers = inputs.to_owned();
        Self {
            input_buffers,
            fresh_counter: 0,
            var_adjoints: FxHashMap::default(),
        }
    }
    /// Returns the adjoint variable name for `var_name`.
    ///
    /// On first request, this allocates `_adj_<var_name>_<n>` with a fresh
    /// counter value. Later requests return that same name.
    fn ensure_adjoint_var(&mut self, var_name: &str) -> String {
        if !self.var_adjoints.contains_key(var_name) {
            let fresh = format!("_adj_{var_name}_{}", self.fresh_counter);
            self.fresh_counter += 1;
            self.var_adjoints.insert(var_name.to_owned(), fresh);
        }
        self.var_adjoints[var_name].clone()
    }
    /// Looks up the adjoint variable previously allocated under `buf_name`.
    fn get_accumulator(&self, buf_name: &str) -> Option<String> {
        self.var_adjoints.get(buf_name).map(String::clone)
    }
    /// Returns `true` when `buf_name` is one of the tracked input buffers.
    fn is_tracked_input(&self, buf_name: &str) -> bool {
        self.input_buffers
            .iter()
            .any(|tracked| tracked.as_str() == buf_name)
    }
}
fn emit_adjoint_node(
node: &Node,
body: &mut Vec<Node>,
env: &mut AdjointEnv,
output_set: &[String],
) -> Result<(), AutodiffError> {
match node {
Node::Let { name, value } => {
let var_name = name.as_str();
let adj_var = env.ensure_adjoint_var(var_name);
body.push(Node::Let {
name: adj_var.clone().into(),
value: Expr::f32(0.0),
});
emit_adjoint_expr(value, &Expr::Var(adj_var.into()), body, env)?;
}
Node::Store {
buffer,
index,
value,
} => {
let buf_name = buffer.as_str();
let grad_buf = format!("grad_{buf_name}");
let adj_expr =
if output_set.iter().any(|o| o == buf_name) || env.is_tracked_input(buf_name) {
Expr::Load {
buffer: grad_buf.into(),
index: Box::new(index.clone()),
}
} else {
Expr::f32(0.0)
};
emit_adjoint_expr(value, &adj_expr, body, env)?;
}
Node::Assign { name, value } => {
let adj_var = env.ensure_adjoint_var(name.as_str());
emit_adjoint_expr(value, &Expr::Var(adj_var.into()), body, env)?;
}
Node::If {
cond,
then,
otherwise,
} => {
let mut then_body = Vec::new();
for n in then.iter().rev() {
emit_adjoint_node(n, &mut then_body, env, output_set)?;
}
let mut else_body = Vec::new();
for n in otherwise.iter().rev() {
emit_adjoint_node(n, &mut else_body, env, output_set)?;
}
body.push(Node::If {
cond: cond.clone(),
then: then_body,
otherwise: else_body,
});
}
Node::Loop {
var,
from,
to,
body: loop_body,
} => {
let mut adj_body = Vec::new();
for n in loop_body.iter().rev() {
emit_adjoint_node(n, &mut adj_body, env, output_set)?;
}
body.push(Node::Loop {
var: var.clone(),
from: from.clone(),
to: to.clone(),
body: adj_body,
});
}
Node::Barrier { ordering } => {
body.push(Node::barrier_with_ordering(*ordering));
}
Node::Block(nodes) => {
for n in nodes.iter().rev() {
emit_adjoint_node(n, body, env, output_set)?;
}
}
Node::Region {
generator,
source_region,
body: region_body,
} => {
let mut adj_region_body = Vec::new();
for n in region_body.iter().rev() {
emit_adjoint_node(n, &mut adj_region_body, env, output_set)?;
}
body.push(Node::Region {
generator: generator.clone(),
source_region: source_region.clone(),
body: std::sync::Arc::new(adj_region_body),
});
}
Node::Return
| Node::IndirectDispatch { .. }
| Node::AsyncLoad { .. }
| Node::AsyncStore { .. }
| Node::AsyncWait { .. }
| Node::Trap { .. }
| Node::Resume { .. } => {
return Err(AutodiffError::UnsupportedNode {
kind: format!("{node:?}").chars().take(60).collect(),
});
}
Node::Opaque(_) => {
return Err(AutodiffError::UnsupportedNode {
kind: "Node::Opaque".to_string(),
});
}
}
Ok(())
}
/// Propagates `adjoint` (the gradient flowing into `expr`) down to the leaves
/// of `expr`, appending the accumulation statements to `body`.
///
/// * `Var` leaves: emit `adj = adj + adjoint` on the variable's adjoint var.
/// * `Load` leaves from tracked input buffers: emit
///   `grad_<buffer>[index] = grad_<buffer>[index] + adjoint`. Loads from any
///   other buffer contribute nothing.
/// * Literals and built-in ids/sizes/buffer lengths contribute nothing.
/// * `BinOp`, `UnOp` and `Fma` distribute `adjoint` to their operands via the
///   rules in `super::rules`. `Select` routes it through a `Select` on the
///   same condition, with `0.0` for the branch not taken.
///
/// # Errors
///
/// `AutodiffError::NotDifferentiable` for calls, atomics, subgroup operations
/// and opaque expressions. Errors from `binop_adjoints` / `unop_adjoint` are
/// propagated.
fn emit_adjoint_expr(
    expr: &Expr,
    adjoint: &Expr,
    body: &mut Vec<Node>,
    env: &mut AdjointEnv,
) -> Result<(), AutodiffError> {
    match expr {
        Expr::Var(name) => {
            let adj_var = env.ensure_adjoint_var(name.as_str());
            body.push(Node::Assign {
                name: adj_var.clone().into(),
                value: Expr::add(Expr::Var(adj_var.into()), adjoint.clone()),
            });
        }
        Expr::Load { buffer, index } => {
            let buf_name = buffer.as_str();
            // Only buffers requested as `inputs` receive gradient storage.
            if env.is_tracked_input(buf_name) {
                let grad_buf = format!("grad_{buf_name}");
                body.push(Node::Store {
                    buffer: grad_buf.clone().into(),
                    index: (**index).clone(),
                    value: Expr::add(
                        Expr::Load {
                            buffer: grad_buf.into(),
                            index: index.clone(),
                        },
                        adjoint.clone(),
                    ),
                });
            }
        }
        Expr::LitF32(_) | Expr::LitU32(_) | Expr::LitI32(_) | Expr::LitBool(_) => {}
        Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::BufLen { .. } => {}
        Expr::BinOp { op, left, right } => {
            let contribs = binop_adjoints(*op, left, right, adjoint)?;
            for contrib in contribs {
                emit_adjoint_expr(&contrib.child, &contrib.adjoint, body, env)?;
            }
        }
        Expr::UnOp { op, operand } => {
            let contrib = unop_adjoint(op, operand, adjoint)?;
            emit_adjoint_expr(&contrib.child, &contrib.adjoint, body, env)?;
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            let true_adj = Expr::Select {
                cond: cond.clone(),
                true_val: Box::new(adjoint.clone()),
                false_val: Box::new(Expr::f32(0.0)),
            };
            let false_adj = Expr::Select {
                cond: cond.clone(),
                true_val: Box::new(Expr::f32(0.0)),
                false_val: Box::new(adjoint.clone()),
            };
            emit_adjoint_expr(true_val, &true_adj, body, env)?;
            emit_adjoint_expr(false_val, &false_adj, body, env)?;
        }
        Expr::Fma { a, b, c } => {
            let contribs = fma_adjoints(a, b, c, adjoint);
            for contrib in contribs {
                emit_adjoint_expr(&contrib.child, &contrib.adjoint, body, env)?;
            }
        }
        Expr::Cast { value, .. } => {
            // NOTE(review): the adjoint passes through unchanged for every
            // cast, including casts to integer types. Confirm that is intended.
            emit_adjoint_expr(value, adjoint, body, env)?;
        }
        Expr::Call { op_id, .. } => {
            return Err(AutodiffError::NotDifferentiable {
                op: format!("Expr::Call({op_id})"),
                fix:
                    "inline the call before running autodiff, or register a derivative for this op"
                        .into(),
            });
        }
        Expr::Atomic { .. } => {
            return Err(AutodiffError::NotDifferentiable {
                op: "Expr::Atomic".into(),
                fix: "atomics are not differentiable; restructure to use non-atomic accumulation"
                    .into(),
            });
        }
        Expr::SubgroupBallot { .. } | Expr::SubgroupShuffle { .. } | Expr::SubgroupAdd { .. } => {
            return Err(AutodiffError::NotDifferentiable {
                op: format!("{expr:?}").chars().take(40).collect(),
                fix: "subgroup ops are not differentiable in the general case".into(),
            });
        }
        Expr::Opaque(_) => {
            return Err(AutodiffError::NotDifferentiable {
                op: "Expr::Opaque".into(),
                fix: "register a derivative rule for this opaque expression".into(),
            });
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::BinOp;

    /// `buffer[index]` as an IR load expression.
    fn load(buffer: &'static str, index: Expr) -> Expr {
        Expr::Load {
            buffer: buffer.into(),
            index: Box::new(index),
        }
    }

    /// The invocation id along the x axis.
    fn gid_x() -> Expr {
        Expr::InvocationId { axis: 0 }
    }

    #[test]
    fn grad_simple_square() {
        let x_buf =
            BufferDecl::storage("x", 0, BufferAccess::ReadOnly, DataType::F32).with_count(4);
        let out_buf = BufferDecl::output("out", 1, DataType::F32).with_count(4);
        // out[i] = x[i] * x[i]
        let square = Node::Store {
            buffer: "out".into(),
            index: gid_x(),
            value: Expr::mul(load("x", gid_x()), load("x", gid_x())),
        };
        let program = Program::wrapped(vec![x_buf, out_buf], [64, 1, 1], vec![square]);

        let result = grad(&program, &["out"], &["x"]);
        assert!(result.is_ok(), "grad should succeed: {:?}", result.err());
        let backward = result.unwrap();
        let names: Vec<&str> = backward.buffers().iter().map(|b| b.name()).collect();
        assert!(names.contains(&"grad_out"), "should have grad_out buffer");
        assert!(names.contains(&"grad_x"), "should have grad_x buffer");
    }

    #[test]
    fn grad_bitwise_errors() {
        // out[0] = x[0] & 0xFF has no derivative.
        let masked = Expr::BinOp {
            op: BinOp::BitAnd,
            left: Box::new(load("x", Expr::u32(0))),
            right: Box::new(Expr::u32(0xFF)),
        };
        let program = Program::wrapped(
            vec![
                BufferDecl::storage("x", 0, BufferAccess::ReadOnly, DataType::U32).with_count(1),
                BufferDecl::output("out", 1, DataType::U32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::Store {
                buffer: "out".into(),
                index: Expr::u32(0),
                value: masked,
            }],
        );

        let result = grad(&program, &["out"], &["x"]);
        assert!(result.is_err());
        match result.unwrap_err() {
            AutodiffError::NotDifferentiable { op, .. } => assert!(op.contains("BitAnd")),
            e => panic!("expected NotDifferentiable, got: {e}"),
        }
    }

    #[test]
    fn grad_missing_buffer() {
        let only_out = BufferDecl::output("out", 0, DataType::F32).with_count(1);
        let program = Program::wrapped(vec![only_out], [1, 1, 1], vec![]);
        let result = grad(&program, &["nonexistent"], &[]);
        assert!(matches!(result, Err(AutodiffError::BufferNotFound { .. })));
    }

    #[test]
    fn grad_exp() {
        // out[0] = exp(x[0])
        let exp_x = Expr::UnOp {
            op: crate::ir::UnOp::Exp,
            operand: Box::new(load("x", Expr::u32(0))),
        };
        let program = Program::wrapped(
            vec![
                BufferDecl::storage("x", 0, BufferAccess::ReadOnly, DataType::F32).with_count(1),
                BufferDecl::output("out", 1, DataType::F32).with_count(1),
            ],
            [1, 1, 1],
            vec![Node::Store {
                buffer: "out".into(),
                index: Expr::u32(0),
                value: exp_x,
            }],
        );

        let result = grad(&program, &["out"], &["x"]);
        assert!(
            result.is_ok(),
            "exp should be differentiable: {:?}",
            result.err()
        );
    }
}