Skip to main content

alkahest_cas/diff/
forward.rs

//! Forward-mode automatic differentiation via dual numbers.
//!
//! A dual number `DualValue { value, tangent }` tracks both the primal value
//! and its derivative simultaneously; in this module both components are
//! symbolic `ExprId`s.  Evaluating an expression on dual inputs — setting the
//! tangent of the variable of differentiation to `1` and all others to `0` —
//! propagates the derivative through every operation automatically.
//!
//! The result agrees with the symbolic differentiator on all expressions
//! whose derivative is defined; property tests cross-validate both.
11use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
12use crate::diff::diff_impl::DiffError;
13use crate::kernel::{ExprData, ExprId, ExprPool};
14use crate::simplify::engine::simplify;
15
16// ---------------------------------------------------------------------------
17// Deprecated type alias — `ForwardDiffError` is now folded into `DiffError`
18// (variants `ForwardUnknownFunction` / `ForwardNonIntegerExponent`, codes
19// E-DIFF-003 / E-DIFF-004).  This alias keeps old `ForwardDiffError` names
20// compiling; it will be removed in the next major version.
21// ---------------------------------------------------------------------------
22
23#[deprecated(
24    since = "2.0.0",
25    note = "use DiffError::ForwardUnknownFunction / ForwardNonIntegerExponent instead"
26)]
27pub type ForwardDiffError = DiffError;
28
29// ---------------------------------------------------------------------------
30// DualValue
31// ---------------------------------------------------------------------------
32
33/// A dual number carrying a primal `value` and a first-order `tangent`.
34///
35/// Arithmetic on `DualValue` follows the dual-number algebra:
36/// - `(a + ε·da) + (b + ε·db) = (a+b) + ε·(da+db)`
37/// - `(a + ε·da) * (b + ε·db) = a·b + ε·(a·db + b·da)`
38#[derive(Clone, Debug)]
39pub struct DualValue {
40    pub value: ExprId,
41    pub tangent: ExprId,
42}
43
44impl DualValue {
45    fn new(value: ExprId, tangent: ExprId) -> Self {
46        DualValue { value, tangent }
47    }
48
49    fn constant(value: ExprId, pool: &ExprPool) -> Self {
50        let zero = pool.integer(0_i32);
51        DualValue::new(value, zero)
52    }
53
54    fn seed(value: ExprId, pool: &ExprPool) -> Self {
55        let one = pool.integer(1_i32);
56        DualValue::new(value, one)
57    }
58
59    fn add(self, rhs: Self, pool: &ExprPool) -> Self {
60        let value = pool.add(vec![self.value, rhs.value]);
61        let tangent = pool.add(vec![self.tangent, rhs.tangent]);
62        DualValue::new(value, tangent)
63    }
64
65    fn mul(self, rhs: Self, pool: &ExprPool) -> Self {
66        // (a·db + b·da)
67        let value = pool.mul(vec![self.value, rhs.value]);
68        let term1 = pool.mul(vec![self.value, rhs.tangent]);
69        let term2 = pool.mul(vec![rhs.value, self.tangent]);
70        let tangent = pool.add(vec![term1, term2]);
71        DualValue::new(value, tangent)
72    }
73
74    #[allow(dead_code)]
75    fn neg(self, pool: &ExprPool) -> Self {
76        let neg_one = pool.integer(-1_i32);
77        let value = pool.mul(vec![neg_one, self.value]);
78        let tangent = pool.mul(vec![neg_one, self.tangent]);
79        DualValue::new(value, tangent)
80    }
81
82    #[allow(dead_code)]
83    fn sub(self, rhs: Self, pool: &ExprPool) -> Self {
84        self.add(rhs.neg(pool), pool)
85    }
86
87    /// Division: d(a/b) = (b·da - a·db) / b²
88    #[allow(dead_code)]
89    fn div(self, rhs: Self, pool: &ExprPool) -> Self {
90        let value = pool.mul(vec![self.value, pool.pow(rhs.value, pool.integer(-1_i32))]);
91        let bda = pool.mul(vec![rhs.value, self.tangent]);
92        let adb = pool.mul(vec![self.value, rhs.tangent]);
93        let neg_one = pool.integer(-1_i32);
94        let numerator = pool.add(vec![bda, pool.mul(vec![neg_one, adb])]);
95        let b_sq = pool.pow(rhs.value, pool.integer(2_i32));
96        let tangent = pool.mul(vec![numerator, pool.pow(b_sq, pool.integer(-1_i32))]);
97        DualValue::new(value, tangent)
98    }
99
100    /// Power rule for integer exponent n: d(f^n) = n * f^(n-1) * f'
101    fn pow_int(self, n: rug::Integer, pool: &ExprPool) -> Self {
102        if n == 0 {
103            let one = pool.integer(1_i32);
104            return DualValue::new(one, pool.integer(0_i32));
105        }
106        if n == 1 {
107            return self;
108        }
109        let n_id = pool.integer(n.clone());
110        let n_minus_1 = pool.integer(n - 1);
111        let value = pool.pow(self.value, n_id);
112        let base_pow = pool.pow(self.value, n_minus_1);
113        let tangent = pool.mul(vec![n_id, base_pow, self.tangent]);
114        DualValue::new(value, tangent)
115    }
116
117    fn sin(self, pool: &ExprPool) -> Self {
118        // d/dx sin(f) = cos(f) * f'
119        let value = pool.func("sin", vec![self.value]);
120        let cos_f = pool.func("cos", vec![self.value]);
121        let tangent = pool.mul(vec![cos_f, self.tangent]);
122        DualValue::new(value, tangent)
123    }
124
125    fn cos(self, pool: &ExprPool) -> Self {
126        // d/dx cos(f) = -sin(f) * f'
127        let value = pool.func("cos", vec![self.value]);
128        let sin_f = pool.func("sin", vec![self.value]);
129        let neg_one = pool.integer(-1_i32);
130        let tangent = pool.mul(vec![neg_one, sin_f, self.tangent]);
131        DualValue::new(value, tangent)
132    }
133
134    fn exp(self, pool: &ExprPool) -> Self {
135        // d/dx exp(f) = exp(f) * f'
136        let value = pool.func("exp", vec![self.value]);
137        let tangent = pool.mul(vec![value, self.tangent]);
138        DualValue::new(value, tangent)
139    }
140
141    fn log(self, pool: &ExprPool) -> Self {
142        // d/dx log(f) = f' / f = f' * f^(-1)
143        let value = pool.func("log", vec![self.value]);
144        let f_inv = pool.pow(self.value, pool.integer(-1_i32));
145        let tangent = pool.mul(vec![self.tangent, f_inv]);
146        DualValue::new(value, tangent)
147    }
148
149    fn sqrt(self, pool: &ExprPool) -> Self {
150        // d/dx sqrt(f) = f' / (2 * sqrt(f))
151        let value = pool.func("sqrt", vec![self.value]);
152        let two_sqrt = pool.mul(vec![pool.integer(2_i32), value]);
153        let tangent = pool.mul(vec![self.tangent, pool.pow(two_sqrt, pool.integer(-1_i32))]);
154        DualValue::new(value, tangent)
155    }
156}
157
158// ---------------------------------------------------------------------------
159// Core evaluation
160// ---------------------------------------------------------------------------
161
162fn eval_dual(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DualValue, DiffError> {
163    enum Node {
164        IsVar,
165        IsConst,
166        Add(Vec<ExprId>),
167        Mul(Vec<ExprId>),
168        Pow { base: ExprId, exp: ExprId },
169        Func { name: String, arg: ExprId },
170    }
171
172    let node = pool.with(expr, |data| match data {
173        ExprData::Symbol { .. } if expr == var => Node::IsVar,
174        ExprData::Symbol { .. }
175        | ExprData::Integer(_)
176        | ExprData::Rational(_)
177        | ExprData::Float(_) => Node::IsConst,
178        ExprData::Add(args) => Node::Add(args.clone()),
179        ExprData::Mul(args) => Node::Mul(args.clone()),
180        ExprData::Pow { base, exp } => Node::Pow {
181            base: *base,
182            exp: *exp,
183        },
184        ExprData::Func { name, args } if args.len() == 1 => Node::Func {
185            name: name.clone(),
186            arg: args[0],
187        },
188        ExprData::Func { name, .. } => Node::Func {
189            name: name.clone(),
190            arg: expr,
191        },
192        // PA-9: Piecewise and Predicate are treated as constants w.r.t. the
193        // variable being differentiated (predicates don't depend on x algebraically).
194        ExprData::Piecewise { .. } | ExprData::Predicate { .. } => Node::IsConst,
195        ExprData::Forall { .. } | ExprData::Exists { .. } => Node::IsConst,
196        ExprData::BigO(_) => Node::IsConst,
197    });
198
199    match node {
200        Node::IsVar => Ok(DualValue::seed(expr, pool)),
201        Node::IsConst => Ok(DualValue::constant(expr, pool)),
202        Node::Add(args) => {
203            let mut acc = DualValue::constant(pool.integer(0_i32), pool);
204            for a in args {
205                acc = acc.add(eval_dual(a, var, pool)?, pool);
206            }
207            Ok(acc)
208        }
209        Node::Mul(args) => {
210            let mut acc = DualValue::constant(pool.integer(1_i32), pool);
211            for a in args {
212                acc = acc.mul(eval_dual(a, var, pool)?, pool);
213            }
214            Ok(acc)
215        }
216        Node::Pow { base, exp } => {
217            let n = pool
218                .with(exp, |data| match data {
219                    ExprData::Integer(n) => Some(n.0.clone()),
220                    _ => None,
221                })
222                .ok_or(DiffError::ForwardNonIntegerExponent)?;
223            let b = eval_dual(base, var, pool)?;
224            Ok(b.pow_int(n, pool))
225        }
226        Node::Func { name, arg } => {
227            // Protect against the dummy self-referential node from multi-arg fns
228            if arg == expr {
229                return Err(DiffError::ForwardUnknownFunction(name));
230            }
231            let inner = eval_dual(arg, var, pool)?;
232            match name.as_str() {
233                "sin" => Ok(inner.sin(pool)),
234                "cos" => Ok(inner.cos(pool)),
235                "exp" => Ok(inner.exp(pool)),
236                "log" => Ok(inner.log(pool)),
237                "sqrt" => Ok(inner.sqrt(pool)),
238                other => Err(DiffError::ForwardUnknownFunction(other.to_string())),
239            }
240        }
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Public API
246// ---------------------------------------------------------------------------
247
248/// Differentiate `expr` with respect to `var` using forward-mode (dual-number)
249/// automatic differentiation.
250///
251/// Returns the derivative expression after applying the rule-based simplifier.
252/// The derivation log records a single `diff_forward` step.
253///
254/// # Agreement with symbolic diff
255///
256/// For any polynomial or rational-function expression, `diff_forward` and
257/// `diff` (symbolic) produce structurally equal results after simplification.
258/// Property tests in this module verify this on random polynomials.
259pub fn diff_forward(
260    expr: ExprId,
261    var: ExprId,
262    pool: &ExprPool,
263) -> Result<DerivedExpr<ExprId>, DiffError> {
264    let dual = eval_dual(expr, var, pool)?;
265    let tangent_raw = dual.tangent;
266
267    // Simplify the raw tangent
268    let simplified = simplify(tangent_raw, pool);
269
270    // Wrap in a derivation log
271    let mut log = DerivationLog::new();
272    log.push(RewriteStep::simple("diff_forward", expr, simplified.value));
273    let full_log = log.merge(simplified.log);
274    Ok(DerivedExpr::with_log(simplified.value, full_log))
275}
276
277// ---------------------------------------------------------------------------
278// Tests
279// ---------------------------------------------------------------------------
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diff::diff as sym_diff;
    use crate::kernel::{Domain, ExprId, ExprPool};
    use crate::poly::UniPoly;

    /// Fresh pool together with a real-valued symbol `x` created in it.
    fn setup() -> (ExprPool, ExprId) {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        (pool, x)
    }

    /// Asserts that forward-mode and symbolic differentiation of `expr`
    /// w.r.t. `x` yield polynomials with identical coefficients.
    fn assert_agrees_with_symbolic(expr: ExprId, x: ExprId, pool: &ExprPool) {
        let fwd = diff_forward(expr, x, pool).unwrap().value;
        let sym = sym_diff(expr, x, pool).unwrap().value;
        let fwd_poly = UniPoly::from_symbolic(fwd, x, pool).unwrap();
        let sym_poly = UniPoly::from_symbolic(sym, x, pool).unwrap();
        assert_eq!(fwd_poly.coefficients_i64(), sym_poly.coefficients_i64());
    }

    #[test]
    fn forward_diff_constant() {
        // d/dx 5 = 0
        let (pool, x) = setup();
        let five = pool.integer(5_i32);
        let derivative = diff_forward(five, x, &pool).unwrap().value;
        assert_eq!(derivative, pool.integer(0_i32));
    }

    #[test]
    fn forward_diff_identity() {
        // d/dx x = 1
        let (pool, x) = setup();
        let derivative = diff_forward(x, x, &pool).unwrap().value;
        assert_eq!(derivative, pool.integer(1_i32));
    }

    #[test]
    fn forward_diff_other_var() {
        // d/dx y = 0
        let (pool, x) = setup();
        let y = pool.symbol("y", Domain::Real);
        let derivative = diff_forward(y, x, &pool).unwrap().value;
        assert_eq!(derivative, pool.integer(0_i32));
    }

    #[test]
    fn forward_diff_linear() {
        // d/dx (3x) = 3
        let (pool, x) = setup();
        let three_x = pool.mul(vec![pool.integer(3_i32), x]);
        let derivative = diff_forward(three_x, x, &pool).unwrap().value;
        assert_eq!(derivative, pool.integer(3_i32));
    }

    #[test]
    fn forward_diff_quadratic_agrees_with_symbolic() {
        // d/dx x² via forward vs symbolic — both should give 2x
        let (pool, x) = setup();
        let square = pool.pow(x, pool.integer(2_i32));
        assert_agrees_with_symbolic(square, x, &pool);
    }

    #[test]
    fn forward_diff_cubic_agrees_with_symbolic() {
        // d/dx x³ via forward vs symbolic
        let (pool, x) = setup();
        let cube = pool.pow(x, pool.integer(3_i32));
        assert_agrees_with_symbolic(cube, x, &pool);
    }

    #[test]
    fn forward_diff_sin() {
        // d/dx sin(x) = cos(x)
        let (pool, x) = setup();
        let sin_x = pool.func("sin", vec![x]);
        let derivative = diff_forward(sin_x, x, &pool).unwrap().value;
        assert_eq!(derivative, pool.func("cos", vec![x]));
    }

    #[test]
    fn forward_diff_exp() {
        // d/dx exp(x) = exp(x)
        let (pool, x) = setup();
        let exp_x = pool.func("exp", vec![x]);
        let derivative = diff_forward(exp_x, x, &pool).unwrap().value;
        assert_eq!(derivative, exp_x);
    }

    #[test]
    fn forward_diff_log() {
        // d/dx log(x) = x^{-1}
        let (pool, x) = setup();
        let log_x = pool.func("log", vec![x]);
        let derivative = diff_forward(log_x, x, &pool).unwrap().value;
        assert_eq!(derivative, pool.pow(x, pool.integer(-1_i32)));
    }

    #[test]
    fn forward_diff_step_logged() {
        // The derivation log must contain the `diff_forward` step.
        let (pool, x) = setup();
        let result = diff_forward(x, x, &pool).unwrap();
        let logged = result
            .log
            .steps()
            .iter()
            .any(|step| step.rule_name == "diff_forward");
        assert!(logged);
    }
}