use crate::ir::{BinOp, Expr, UnOp};
use super::AutodiffError;
/// One reverse-mode gradient contribution produced by an adjoint rule.
///
/// The rules in this module return one of these per operand of the node being
/// differentiated. `adjoint` is the symbolic gradient flowing into `child`
/// (NOTE(review): callers presumably sum contributions that target the same
/// child — confirm against the accumulator).
#[derive(Debug, Clone, PartialEq)]
pub struct AdjointContrib {
    /// The operand expression that receives the gradient.
    pub child: Expr,
    /// The symbolic adjoint propagated to `child`.
    pub adjoint: Expr,
}
pub fn binop_adjoints(
op: BinOp,
left: &Expr,
right: &Expr,
adjoint: &Expr,
) -> Result<Vec<AdjointContrib>, AutodiffError> {
match op {
BinOp::Add => Ok(vec![
AdjointContrib {
child: left.clone(),
adjoint: adjoint.clone(),
},
AdjointContrib {
child: right.clone(),
adjoint: adjoint.clone(),
},
]),
BinOp::Sub => Ok(vec![
AdjointContrib {
child: left.clone(),
adjoint: adjoint.clone(),
},
AdjointContrib {
child: right.clone(),
adjoint: Expr::UnOp {
op: UnOp::Negate,
operand: Box::new(adjoint.clone()),
},
},
]),
BinOp::Mul => Ok(vec![
AdjointContrib {
child: left.clone(),
adjoint: Expr::mul(adjoint.clone(), right.clone()),
},
AdjointContrib {
child: right.clone(),
adjoint: Expr::mul(adjoint.clone(), left.clone()),
},
]),
BinOp::Div => Ok(vec![
AdjointContrib {
child: left.clone(),
adjoint: Expr::div(adjoint.clone(), right.clone()),
},
AdjointContrib {
child: right.clone(),
adjoint: Expr::UnOp {
op: UnOp::Negate,
operand: Box::new(Expr::div(
Expr::mul(adjoint.clone(), left.clone()),
Expr::mul(right.clone(), right.clone()),
)),
},
},
]),
BinOp::Min => Ok(vec![
AdjointContrib {
child: left.clone(),
adjoint: Expr::Select {
cond: Box::new(Expr::lt(left.clone(), right.clone())),
true_val: Box::new(adjoint.clone()),
false_val: Box::new(Expr::f32(0.0)),
},
},
AdjointContrib {
child: right.clone(),
adjoint: Expr::Select {
cond: Box::new(Expr::le(left.clone(), right.clone())),
true_val: Box::new(Expr::f32(0.0)),
false_val: Box::new(adjoint.clone()),
},
},
]),
BinOp::Max => Ok(vec![
AdjointContrib {
child: left.clone(),
adjoint: Expr::Select {
cond: Box::new(Expr::gt(left.clone(), right.clone())),
true_val: Box::new(adjoint.clone()),
false_val: Box::new(Expr::f32(0.0)),
},
},
AdjointContrib {
child: right.clone(),
adjoint: Expr::Select {
cond: Box::new(Expr::gt(left.clone(), right.clone())),
true_val: Box::new(Expr::f32(0.0)),
false_val: Box::new(adjoint.clone()),
},
},
]),
BinOp::Mod
| BinOp::BitAnd
| BinOp::BitOr
| BinOp::BitXor
| BinOp::Shl
| BinOp::Shr
| BinOp::WrappingAdd
| BinOp::WrappingSub
| BinOp::Eq
| BinOp::Ne
| BinOp::Lt
| BinOp::Gt
| BinOp::Le
| BinOp::Ge
| BinOp::And
| BinOp::Or
| BinOp::AbsDiff
| BinOp::SaturatingAdd
| BinOp::SaturatingSub
| BinOp::SaturatingMul
| BinOp::Shuffle
| BinOp::Ballot
| BinOp::WaveReduce
| BinOp::WaveBroadcast
| BinOp::RotateLeft
| BinOp::RotateRight
| BinOp::Opaque(_)
| _ => Err(AutodiffError::NotDifferentiable {
op: format!("BinOp::{op:?}"),
fix: "replace with a differentiable equivalent or gate behind a stop-gradient barrier"
.into(),
}),
}
}
/// Computes the reverse-mode adjoint contribution of a unary operation.
///
/// For the node `op(operand)` with incoming `adjoint`, returns a single
/// [`AdjointContrib`] whose `child` is `operand` and whose `adjoint` is the
/// symbolic product of `adjoint` and the derivative of `op` at `operand`.
/// `Abs` uses `Sign(operand)` as its derivative.
///
/// # Errors
///
/// Returns [`AutodiffError::NotDifferentiable`] for operators that have no
/// rule here (bit manipulation, rounding, `Sign`, NaN/Inf/finite checks,
/// unpacking, `Opaque`, and any variant not matched explicitly).
pub fn unop_adjoint(
    op: &UnOp,
    operand: &Expr,
    adjoint: &Expr,
) -> Result<AdjointContrib, AutodiffError> {
    // Tiny builders so every rule below reads like its textbook formula.
    let apply = |f: UnOp, arg: Expr| Expr::UnOp {
        op: f,
        operand: Box::new(arg),
    };
    let x = || operand.clone();
    let g = || adjoint.clone();
    let one_minus_x_sq = || Expr::sub(Expr::f32(1.0), Expr::mul(x(), x()));

    let local_grad = match op {
        // d(-x) = -1
        UnOp::Negate => apply(UnOp::Negate, g()),
        // d(e^x) = e^x
        UnOp::Exp => Expr::mul(g(), apply(UnOp::Exp, x())),
        // d(ln x) = 1 / x
        UnOp::Log => Expr::div(g(), x()),
        // d(sqrt x) = 1 / (2 sqrt x)
        UnOp::Sqrt => Expr::div(g(), Expr::mul(Expr::f32(2.0), apply(UnOp::Sqrt, x()))),
        // d(tanh x) = 1 - tanh^2 x
        UnOp::Tanh => {
            let th = apply(UnOp::Tanh, x());
            Expr::mul(g(), Expr::sub(Expr::f32(1.0), Expr::mul(th.clone(), th)))
        }
        // d(sin x) = cos x
        UnOp::Sin => Expr::mul(g(), apply(UnOp::Cos, x())),
        // d(cos x) = -sin x
        UnOp::Cos => apply(UnOp::Negate, Expr::mul(g(), apply(UnOp::Sin, x()))),
        // d|x| = sign(x)
        UnOp::Abs => Expr::mul(g(), apply(UnOp::Sign, x())),
        // d(2^x) = 2^x * ln 2
        UnOp::Exp2 => Expr::mul(
            Expr::mul(g(), apply(UnOp::Exp2, x())),
            Expr::f32(core::f32::consts::LN_2),
        ),
        // d(log2 x) = 1 / (x * ln 2)
        UnOp::Log2 => Expr::div(g(), Expr::mul(x(), Expr::f32(core::f32::consts::LN_2))),
        // d(tan x) = 1 + tan^2 x
        UnOp::Tan => {
            let tn = apply(UnOp::Tan, x());
            Expr::mul(g(), Expr::add(Expr::f32(1.0), Expr::mul(tn.clone(), tn)))
        }
        // d(sinh x) = cosh x
        UnOp::Sinh => Expr::mul(g(), apply(UnOp::Cosh, x())),
        // d(cosh x) = sinh x
        UnOp::Cosh => Expr::mul(g(), apply(UnOp::Sinh, x())),
        // d(asin x) = 1 / sqrt(1 - x^2)
        UnOp::Asin => Expr::div(g(), apply(UnOp::Sqrt, one_minus_x_sq())),
        // d(acos x) = -1 / sqrt(1 - x^2)
        UnOp::Acos => apply(
            UnOp::Negate,
            Expr::div(g(), apply(UnOp::Sqrt, one_minus_x_sq())),
        ),
        // d(atan x) = 1 / (1 + x^2)
        UnOp::Atan => Expr::div(g(), Expr::add(Expr::f32(1.0), Expr::mul(x(), x()))),
        // d(x^(-1/2)) = -1 / (2 * x * sqrt x)
        UnOp::InverseSqrt => apply(
            UnOp::Negate,
            Expr::div(
                g(),
                Expr::mul(Expr::f32(2.0), Expr::mul(x(), apply(UnOp::Sqrt, x()))),
            ),
        ),
        // d(1/x) = -1 / x^2
        UnOp::Reciprocal => apply(UnOp::Negate, Expr::div(g(), Expr::mul(x(), x()))),
        // Listed for documentation; `_` also rejects future variants.
        UnOp::BitNot
        | UnOp::LogicalNot
        | UnOp::Popcount
        | UnOp::Clz
        | UnOp::Ctz
        | UnOp::ReverseBits
        | UnOp::Floor
        | UnOp::Ceil
        | UnOp::Round
        | UnOp::Trunc
        | UnOp::Sign
        | UnOp::IsNan
        | UnOp::IsInf
        | UnOp::IsFinite
        | UnOp::Unpack4Low
        | UnOp::Unpack4High
        | UnOp::Unpack8Low
        | UnOp::Unpack8High
        | UnOp::Opaque(_)
        | _ => {
            return Err(AutodiffError::NotDifferentiable {
                op: format!("UnOp::{op:?}"),
                fix: "replace with a differentiable equivalent or gate behind a stop-gradient barrier".into(),
            });
        }
    };

    Ok(AdjointContrib {
        child: x(),
        adjoint: local_grad,
    })
}
/// Computes the reverse-mode adjoint contributions of a fused multiply-add
/// `a * b + c`.
///
/// Returns three [`AdjointContrib`]s in operand order `a`, `b`, `c`:
/// `a` receives `adjoint * b`, `b` receives `adjoint * a`, and `c` receives
/// `adjoint` unchanged. Every derivative exists, so this rule is infallible.
pub fn fma_adjoints(a: &Expr, b: &Expr, c: &Expr, adjoint: &Expr) -> Vec<AdjointContrib> {
    let wrt_a = AdjointContrib {
        child: a.clone(),
        adjoint: Expr::mul(adjoint.clone(), b.clone()),
    };
    let wrt_b = AdjointContrib {
        child: b.clone(),
        adjoint: Expr::mul(adjoint.clone(), a.clone()),
    };
    // The addend passes straight through.
    let wrt_c = AdjointContrib {
        child: c.clone(),
        adjoint: adjoint.clone(),
    };
    vec![wrt_a, wrt_b, wrt_c]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BinOp, Expr, UnOp};

    /// Fresh symbolic leaves `(l, r, adj)` shared by the binary-op tests.
    fn binary_leaves() -> (Expr, Expr, Expr) {
        (Expr::var("l"), Expr::var("r"), Expr::var("adj"))
    }

    #[test]
    fn test_binop_add_adjoint() {
        let (lhs, rhs, seed) = binary_leaves();
        let got = binop_adjoints(BinOp::Add, &lhs, &rhs, &seed).unwrap();
        assert_eq!(got.len(), 2);
        // Addition forwards the upstream adjoint unchanged to both operands.
        for (contrib, expected_child) in got.iter().zip([&lhs, &rhs]) {
            assert_eq!(&contrib.child, expected_child);
            assert_eq!(contrib.adjoint, seed);
        }
    }

    #[test]
    fn test_binop_mul_adjoint() {
        let (lhs, rhs, seed) = binary_leaves();
        let got = binop_adjoints(BinOp::Mul, &lhs, &rhs, &seed).unwrap();
        assert_eq!(got.len(), 2);
        assert_eq!(got[0].adjoint, Expr::mul(seed.clone(), rhs));
        assert_eq!(got[1].adjoint, Expr::mul(seed, lhs));
    }

    #[test]
    fn test_binop_not_differentiable() {
        let (lhs, rhs, seed) = binary_leaves();
        assert!(binop_adjoints(BinOp::BitAnd, &lhs, &rhs, &seed).is_err());
    }

    #[test]
    fn test_unop_negate_adjoint() {
        let x = Expr::var("op");
        let seed = Expr::var("adj");
        let got = unop_adjoint(&UnOp::Negate, &x, &seed).unwrap();
        assert_eq!(got.child, x);
        assert!(matches!(
            got.adjoint,
            Expr::UnOp {
                op: UnOp::Negate,
                ..
            }
        ));
    }

    #[test]
    fn test_unop_exp_adjoint() {
        let x = Expr::var("op");
        let seed = Expr::var("adj");
        let got = unop_adjoint(&UnOp::Exp, &x, &seed).unwrap();
        assert!(matches!(got.adjoint, Expr::BinOp { op: BinOp::Mul, .. }));
    }

    #[test]
    fn test_unop_not_differentiable() {
        let x = Expr::var("op");
        let seed = Expr::var("adj");
        assert!(unop_adjoint(&UnOp::Floor, &x, &seed).is_err());
    }

    #[test]
    fn test_fma_adjoint() {
        let (a, b, c) = (Expr::var("a"), Expr::var("b"), Expr::var("c"));
        let seed = Expr::var("adj");
        let got = fma_adjoints(&a, &b, &c, &seed);
        assert_eq!(got.len(), 3);
        assert_eq!(got[0].adjoint, Expr::mul(seed.clone(), b));
        assert_eq!(got[1].adjoint, Expr::mul(seed.clone(), a));
        assert_eq!(got[2].adjoint, seed);
    }
}