Skip to main content

alkahest_cas/diff/
diff_impl.rs

1use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
2use crate::kernel::{ExprData, ExprId, ExprPool};
3use crate::poly::UniPoly;
4use crate::simplify::engine::simplify;
5use std::fmt;
6
7// ---------------------------------------------------------------------------
8// Error type
9// ---------------------------------------------------------------------------
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub enum DiffError {
13    /// An unknown function was encountered; differentiation is not defined.
14    UnknownFunction(String),
15    /// A `Pow` node whose exponent is not a constant integer.
16    NonIntegerExponent,
17    /// Forward-mode: unknown function (folded from the former `ForwardDiffError`).
18    ForwardUnknownFunction(String),
19    /// Forward-mode: non-integer exponent (folded from the former `ForwardDiffError`).
20    ForwardNonIntegerExponent,
21}
22
23impl fmt::Display for DiffError {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            DiffError::UnknownFunction(name) => {
27                write!(f, "cannot differentiate unknown function '{name}'")
28            }
29            DiffError::NonIntegerExponent => {
30                write!(f, "cannot differentiate power with non-integer exponent")
31            }
32            DiffError::ForwardUnknownFunction(name) => {
33                write!(f, "diff_forward: unknown function '{name}'")
34            }
35            DiffError::ForwardNonIntegerExponent => {
36                write!(f, "diff_forward: non-integer exponent")
37            }
38        }
39    }
40}
41
42impl std::error::Error for DiffError {}
43
44impl crate::errors::AlkahestError for DiffError {
45    fn code(&self) -> &'static str {
46        match self {
47            DiffError::UnknownFunction(_) => "E-DIFF-001",
48            DiffError::NonIntegerExponent => "E-DIFF-002",
49            DiffError::ForwardUnknownFunction(_) => "E-DIFF-003",
50            DiffError::ForwardNonIntegerExponent => "E-DIFF-004",
51        }
52    }
53
54    fn remediation(&self) -> Option<&'static str> {
55        match self {
56            DiffError::UnknownFunction(_) => Some(
57                "register the function in PrimitiveRegistry, or use diff_forward with a custom rule",
58            ),
59            DiffError::NonIntegerExponent => Some(
60                "symbolic exponents require the chain rule; use diff_forward for non-integer powers",
61            ),
62            DiffError::ForwardUnknownFunction(_) => Some(
63                "register the function in PrimitiveRegistry with diff_forward implemented",
64            ),
65            DiffError::ForwardNonIntegerExponent => Some(
66                "substitute concrete values first; diff_forward requires integer exponents",
67            ),
68        }
69    }
70}
71
72// ---------------------------------------------------------------------------
73// Public entry point
74// ---------------------------------------------------------------------------
75
76/// Symbolically differentiate `expr` with respect to `var`.
77///
78/// The returned log records every rule applied, including post-differentiation
79/// simplification steps appended at the end.
80pub fn diff(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DerivedExpr<ExprId>, DiffError> {
81    let result = diff_raw(expr, var, pool)?;
82    Ok(result.and_then(|v| simplify(v, pool)))
83}
84
85// ---------------------------------------------------------------------------
86// Core recursive differentiation (no simplification)
87// ---------------------------------------------------------------------------
88
89#[inline]
90fn diff_poly_try_univariate_fastpath(
91    expr: ExprId,
92    var: ExprId,
93    pool: &ExprPool,
94) -> Option<DerivedExpr<ExprId>> {
95    // Skip atoms so simple cases keep their dedicated log rules (`diff_identity`, `diff_const`, …).
96    if matches!(
97        pool.get(expr),
98        ExprData::Symbol { .. } | ExprData::Integer(_) | ExprData::Rational(_) | ExprData::Float(_)
99    ) {
100        return None;
101    }
102    let poly = UniPoly::from_symbolic(expr, var, pool).ok()?;
103    let der = poly.derivative();
104    let result = der.to_symbolic_expr(pool);
105    let mut log = DerivationLog::new();
106    log.push(RewriteStep::simple("diff_univariate_poly", expr, result));
107    Some(DerivedExpr::with_log(result, log))
108}
109
110fn diff_raw(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DerivedExpr<ExprId>, DiffError> {
111    if let Some(hit) = diff_poly_try_univariate_fastpath(expr, var, pool) {
112        return Ok(hit);
113    }
114
115    // Extract only what we need from the pool in a single lock acquisition,
116    // then release the lock before any recursive diff_raw calls.
117    enum Node {
118        IdentVar,
119        Const,
120        Add(Vec<ExprId>),
121        Mul(Vec<ExprId>),
122        Pow {
123            base: ExprId,
124            exp: ExprId,
125        },
126        Func {
127            name: String,
128            args: Vec<ExprId>,
129        },
130        Piecewise {
131            branches: Vec<(ExprId, ExprId)>,
132            default: ExprId,
133        },
134    }
135
136    let node = pool.with(expr, |data| match data {
137        ExprData::Symbol { .. } if expr == var => Node::IdentVar,
138        ExprData::Symbol { .. }
139        | ExprData::Integer(_)
140        | ExprData::Rational(_)
141        | ExprData::Float(_) => Node::Const,
142        ExprData::Add(args) => Node::Add(args.clone()),
143        ExprData::Mul(args) => Node::Mul(args.clone()),
144        ExprData::Pow { base, exp } => Node::Pow {
145            base: *base,
146            exp: *exp,
147        },
148        ExprData::Func { name, args } => Node::Func {
149            name: name.clone(),
150            args: args.clone(),
151        },
152        ExprData::Piecewise { branches, default } => Node::Piecewise {
153            branches: branches.clone(),
154            default: *default,
155        },
156        // Predicates have no algebraic derivative.
157        ExprData::Predicate { .. } => Node::Const,
158        ExprData::Forall { .. } | ExprData::Exists { .. } => Node::Const,
159        ExprData::BigO(_) => Node::Const,
160    });
161
162    match node {
163        // d/dx x = 1
164        Node::IdentVar => {
165            let one = pool.integer(1_i32);
166            Ok(DerivedExpr::with_step(
167                one,
168                RewriteStep::simple("diff_identity", expr, one),
169            ))
170        }
171        // d/dx c = 0  (any atom that is not the target variable)
172        Node::Const => {
173            let zero = pool.integer(0_i32);
174            Ok(DerivedExpr::with_step(
175                zero,
176                RewriteStep::simple("diff_const", expr, zero),
177            ))
178        }
179        // Sum rule: d/dx (f₁ + f₂ + …) = f₁' + f₂' + …
180        Node::Add(args) => {
181            let mut log = DerivationLog::new();
182            let mut dargs: Vec<ExprId> = Vec::with_capacity(args.len());
183            for a in args {
184                let da = diff_raw(a, var, pool)?;
185                log = log.merge(da.log);
186                dargs.push(da.value);
187            }
188            let sum = pool.add(dargs);
189            log.push(RewriteStep::simple("sum_rule", expr, sum));
190            Ok(DerivedExpr::with_log(sum, log))
191        }
192        // Product rule (n-ary Leibniz): d/dx (∏ᵢ fᵢ) = Σᵢ (fᵢ' · ∏_{j≠i} fⱼ)
193        Node::Mul(args) => {
194            let mut log = DerivationLog::new();
195            let dargs: Vec<DerivedExpr<ExprId>> = args
196                .iter()
197                .map(|&a| diff_raw(a, var, pool))
198                .collect::<Result<_, _>>()?;
199            for da in &dargs {
200                log = log.merge(da.log.clone());
201            }
202            let mut terms: Vec<ExprId> = Vec::with_capacity(args.len());
203            for (i, da) in dargs.iter().enumerate() {
204                let di = da.value;
205                let rest: Vec<ExprId> = args
206                    .iter()
207                    .enumerate()
208                    .filter(|&(j, _)| j != i)
209                    .map(|(_, &a)| a)
210                    .collect();
211                let term = if rest.is_empty() {
212                    di
213                } else if rest.len() == 1 {
214                    pool.mul(vec![di, rest[0]])
215                } else {
216                    let prod = pool.mul(rest);
217                    pool.mul(vec![di, prod])
218                };
219                terms.push(term);
220            }
221            let result = match terms.len() {
222                0 => pool.integer(0_i32),
223                1 => terms[0],
224                _ => pool.add(terms),
225            };
226            log.push(RewriteStep::simple("product_rule", expr, result));
227            Ok(DerivedExpr::with_log(result, log))
228        }
229        // Power rule (integer exponent): d/dx f^n = n · f^(n-1) · f'
230        Node::Pow { base, exp } => {
231            // Read the exponent without holding the pool lock during recursion.
232            let n = pool
233                .with(exp, |data| match data {
234                    ExprData::Integer(n) => Some(n.0.clone()),
235                    _ => None,
236                })
237                .ok_or(DiffError::NonIntegerExponent)?;
238
239            // Special case n=0: d/dx f^0 = 0
240            if n == 0 {
241                let zero = pool.integer(0_i32);
242                let mut log = DerivationLog::new();
243                log.push(RewriteStep::simple("power_rule_n0", expr, zero));
244                return Ok(DerivedExpr::with_log(zero, log));
245            }
246            // Special case n=1: d/dx f^1 = f'
247            if n == 1 {
248                let mut result = diff_raw(base, var, pool)?;
249                result
250                    .log
251                    .push(RewriteStep::simple("power_rule_n1", expr, result.value));
252                return Ok(result);
253            }
254
255            let mut log = DerivationLog::new();
256            let df = diff_raw(base, var, pool)?;
257            log = log.merge(df.log);
258            let n_id = pool.integer(n.clone());
259            let n_minus_1 = pool.integer(n - 1);
260            let base_pow = pool.pow(base, n_minus_1);
261            let result = pool.mul(vec![n_id, base_pow, df.value]);
262            log.push(RewriteStep::simple("power_rule", expr, result));
263            Ok(DerivedExpr::with_log(result, log))
264        }
265        // Chain rules for single-argument named functions
266        Node::Func { name, args } if args.len() == 1 => {
267            let f = args[0];
268            let mut log = DerivationLog::new();
269            let df = diff_raw(f, var, pool)?;
270            log = log.merge(df.log);
271            let result = match name.as_str() {
272                "sin" => {
273                    let cos_f = pool.func("cos", vec![f]);
274                    let r = pool.mul(vec![cos_f, df.value]);
275                    log.push(RewriteStep::simple("diff_sin", expr, r));
276                    r
277                }
278                "cos" => {
279                    let sin_f = pool.func("sin", vec![f]);
280                    let neg_one = pool.integer(-1_i32);
281                    let r = pool.mul(vec![neg_one, sin_f, df.value]);
282                    log.push(RewriteStep::simple("diff_cos", expr, r));
283                    r
284                }
285                "exp" => {
286                    let exp_f = pool.func("exp", vec![f]);
287                    let r = pool.mul(vec![exp_f, df.value]);
288                    log.push(RewriteStep::simple("diff_exp", expr, r));
289                    r
290                }
291                "log" => {
292                    let f_inv = pool.pow(f, pool.integer(-1_i32));
293                    let r = pool.mul(vec![df.value, f_inv]);
294                    log.push(RewriteStep::simple("diff_log", expr, r));
295                    r
296                }
297                "sqrt" => {
298                    let sqrt_f = pool.func("sqrt", vec![f]);
299                    let two_sqrt = pool.mul(vec![pool.integer(2_i32), sqrt_f]);
300                    let denom_inv = pool.pow(two_sqrt, pool.integer(-1_i32));
301                    let r = pool.mul(vec![df.value, denom_inv]);
302                    log.push(RewriteStep::simple("diff_sqrt", expr, r));
303                    r
304                }
305                other => {
306                    // Fall back to PrimitiveRegistry for V1-12 primitives
307                    let reg = crate::primitive::PrimitiveRegistry::default_registry();
308                    if let Some(d) = reg.diff_forward(other, &[f], var, pool) {
309                        log.push(RewriteStep::simple("diff_primitive_registry", expr, d));
310                        d
311                    } else {
312                        return Err(DiffError::UnknownFunction(other.to_string()));
313                    }
314                }
315            };
316            Ok(DerivedExpr::with_log(result, log))
317        }
318        Node::Func { name, .. } => Err(DiffError::UnknownFunction(name)),
319        // PA-9: Piecewise diff distributes into branches.
320        // d/dx Piecewise([(c₁,v₁), …], d) = Piecewise([(c₁, d/dx v₁), …], d/dx d)
321        Node::Piecewise { branches, default } => {
322            let mut log = DerivationLog::new();
323            let mut new_branches = Vec::with_capacity(branches.len());
324            for (cond, val) in branches {
325                let dval = diff_raw(val, var, pool)?;
326                log = log.merge(dval.log);
327                new_branches.push((cond, dval.value));
328            }
329            let ddefault = diff_raw(default, var, pool)?;
330            log = log.merge(ddefault.log);
331            let result = pool.piecewise(new_branches, ddefault.value);
332            log.push(RewriteStep::simple("diff_piecewise", expr, result));
333            Ok(DerivedExpr::with_log(result, log))
334        }
335    }
336}
337
338// ---------------------------------------------------------------------------
339// Unit tests
340// ---------------------------------------------------------------------------
341
342#[cfg(test)]
343mod tests {
344    use super::*;
345    use crate::kernel::{Domain, ExprPool};
346    use crate::poly::UniPoly;
347
348    fn p() -> ExprPool {
349        ExprPool::new()
350    }
351
352    #[test]
353    fn diff_constant() {
354        let pool = p();
355        let x = pool.symbol("x", Domain::Real);
356        let r = diff(pool.integer(5_i32), x, &pool).unwrap();
357        assert_eq!(r.value, pool.integer(0_i32));
358        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_const"));
359    }
360
361    #[test]
362    fn diff_identity() {
363        let pool = p();
364        let x = pool.symbol("x", Domain::Real);
365        let r = diff(x, x, &pool).unwrap();
366        assert_eq!(r.value, pool.integer(1_i32));
367        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_identity"));
368    }
369
370    #[test]
371    fn diff_other_variable() {
372        let pool = p();
373        let x = pool.symbol("x", Domain::Real);
374        let y = pool.symbol("y", Domain::Real);
375        let r = diff(y, x, &pool).unwrap();
376        assert_eq!(r.value, pool.integer(0_i32));
377    }
378
379    #[test]
380    fn diff_linear() {
381        // d/dx (3x) = 3
382        let pool = p();
383        let x = pool.symbol("x", Domain::Real);
384        let expr = pool.mul(vec![pool.integer(3_i32), x]);
385        let r = diff(expr, x, &pool).unwrap();
386        assert_eq!(r.value, pool.integer(3_i32));
387    }
388
389    #[test]
390    fn diff_quadratic() {
391        // d/dx x² = 2x
392        let pool = p();
393        let x = pool.symbol("x", Domain::Real);
394        let r = diff(pool.pow(x, pool.integer(2_i32)), x, &pool).unwrap();
395        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
396        assert_eq!(poly.coefficients_i64(), vec![0, 2]);
397    }
398
399    #[test]
400    fn diff_cubic() {
401        // d/dx x³ = 3x²
402        let pool = p();
403        let x = pool.symbol("x", Domain::Real);
404        let r = diff(pool.pow(x, pool.integer(3_i32)), x, &pool).unwrap();
405        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
406        assert_eq!(poly.coefficients_i64(), vec![0, 0, 3]);
407    }
408
409    #[test]
410    fn diff_polynomial() {
411        // d/dx (x³ + 2x² + x + 1) = 3x² + 4x + 1
412        let pool = p();
413        let x = pool.symbol("x", Domain::Real);
414        let expr = pool.add(vec![
415            pool.pow(x, pool.integer(3_i32)),
416            pool.mul(vec![pool.integer(2_i32), pool.pow(x, pool.integer(2_i32))]),
417            x,
418            pool.integer(1_i32),
419        ]);
420        let r = diff(expr, x, &pool).unwrap();
421        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
422        assert_eq!(poly.coefficients_i64(), vec![1, 4, 3]);
423    }
424
425    #[test]
426    fn diff_sum_rule_logged() {
427        let pool = p();
428        let x = pool.symbol("x", Domain::Real);
429        let y = pool.symbol("y", Domain::Real);
430        let r = diff(pool.add(vec![x, y]), x, &pool).unwrap();
431        assert_eq!(r.value, pool.integer(1_i32));
432        assert!(r.log.steps().iter().any(|s| s.rule_name == "sum_rule"));
433    }
434
435    #[test]
436    fn diff_product_rule_logged() {
437        let pool = p();
438        let x = pool.symbol("x", Domain::Real);
439        let y = pool.symbol("y", Domain::Real);
440        let r = diff(pool.mul(vec![x, y]), x, &pool).unwrap();
441        assert_eq!(r.value, y);
442        assert!(r.log.steps().iter().any(|s| s.rule_name == "product_rule"));
443    }
444
445    #[test]
446    fn diff_sin() {
447        let pool = p();
448        let x = pool.symbol("x", Domain::Real);
449        let r = diff(pool.func("sin", vec![x]), x, &pool).unwrap();
450        assert_eq!(r.value, pool.func("cos", vec![x]));
451        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_sin"));
452    }
453
454    #[test]
455    fn diff_cos() {
456        let pool = p();
457        let x = pool.symbol("x", Domain::Real);
458        let r = diff(pool.func("cos", vec![x]), x, &pool).unwrap();
459        // d/dx cos(x) = -sin(x) = Mul([-1, sin(x)]) in canonical arg order
460        let sin_x = pool.func("sin", vec![x]);
461        let neg_one = pool.integer(-1_i32);
462        match pool.get(r.value) {
463            ExprData::Mul(ref args) => {
464                assert_eq!(args.len(), 2);
465                assert!(args.contains(&neg_one) && args.contains(&sin_x));
466            }
467            _ => panic!("expected Mul, got {:?}", pool.display(r.value)),
468        }
469        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_cos"));
470    }
471
472    #[test]
473    fn diff_exp() {
474        let pool = p();
475        let x = pool.symbol("x", Domain::Real);
476        let exp_x = pool.func("exp", vec![x]);
477        let r = diff(exp_x, x, &pool).unwrap();
478        assert_eq!(r.value, exp_x);
479        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_exp"));
480    }
481
482    #[test]
483    fn diff_log() {
484        // d/dx log(x) = x^(-1)
485        let pool = p();
486        let x = pool.symbol("x", Domain::Real);
487        let r = diff(pool.func("log", vec![x]), x, &pool).unwrap();
488        assert_eq!(r.value, pool.pow(x, pool.integer(-1_i32)));
489        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_log"));
490    }
491
492    #[test]
493    fn diff_chain_rule_sin() {
494        // d/dx sin(x²): cos inner uses ℤ-polynomial fast-path for x² → 2x (not the granular power_rule).
495        let pool = p();
496        let x = pool.symbol("x", Domain::Real);
497        let r = diff(
498            pool.func("sin", vec![pool.pow(x, pool.integer(2_i32))]),
499            x,
500            &pool,
501        )
502        .unwrap();
503        assert_ne!(r.value, pool.integer(0_i32));
504        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_sin"));
505        assert!(r
506            .log
507            .steps()
508            .iter()
509            .any(|s| s.rule_name == "diff_univariate_poly"));
510    }
511
512    #[test]
513    fn diff_pow_n0() {
514        // d/dx f^0 = 0 — x^0 ≅ 1 is read as a ℤ-poly constant, so the dense derivative path applies.
515        let pool = p();
516        let x = pool.symbol("x", Domain::Real);
517        let expr = pool.pow(x, pool.integer(0_i32));
518        let r = diff(expr, x, &pool).unwrap();
519        assert_eq!(r.value, pool.integer(0_i32));
520        assert!(r
521            .log
522            .steps()
523            .iter()
524            .any(|s| s.rule_name == "diff_univariate_poly"));
525    }
526
527    #[test]
528    fn diff_pow_n1() {
529        // d/dx x^1 — same fast-path as other pure ℤ-polynomials.
530        let pool = p();
531        let x = pool.symbol("x", Domain::Real);
532        let expr = pool.pow(x, pool.integer(1_i32));
533        let r = diff(expr, x, &pool).unwrap();
534        assert_eq!(r.value, pool.integer(1_i32));
535        assert!(r
536            .log
537            .steps()
538            .iter()
539            .any(|s| s.rule_name == "diff_univariate_poly"));
540    }
541
542    #[test]
543    fn diff_unknown_function_error() {
544        let pool = p();
545        let x = pool.symbol("x", Domain::Real);
546        let err = diff(pool.func("zeta", vec![x]), x, &pool);
547        assert!(matches!(err, Err(DiffError::UnknownFunction(_))));
548    }
549
550    #[test]
551    fn diff_non_integer_exponent_error() {
552        let pool = p();
553        let x = pool.symbol("x", Domain::Real);
554        let y = pool.symbol("y", Domain::Real);
555        let err = diff(pool.pow(x, y), x, &pool);
556        assert!(matches!(err, Err(DiffError::NonIntegerExponent)));
557    }
558
559    #[test]
560    fn diff_balanced_geom_series_univariate_fastpath() {
561        fn balanced_sum(pool: &ExprPool, terms: &[ExprId]) -> ExprId {
562            match terms.len() {
563                0 => pool.integer(0_i32),
564                1 => terms[0],
565                _ => {
566                    let mid = terms.len() / 2;
567                    pool.add(vec![
568                        balanced_sum(pool, &terms[..mid]),
569                        balanced_sum(pool, &terms[mid..]),
570                    ])
571                }
572            }
573        }
574        let pool = p();
575        let x = pool.symbol("x", Domain::Real);
576        let n = 80i32;
577        let mut terms = vec![pool.integer(1_i32)];
578        for k in 1..=n {
579            terms.push(pool.pow(x, pool.integer(k)));
580        }
581        let expr = balanced_sum(&pool, &terms);
582        let r = diff(expr, x, &pool).unwrap();
583        assert!(
584            r.log
585                .steps()
586                .iter()
587                .any(|s| s.rule_name == "diff_univariate_poly"),
588            "expected dense ℤ-poly fast-path for balanced sum"
589        );
590        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
591        assert_eq!(poly.degree(), i64::from(n) - 1);
592        let coeffs = poly.coefficients_i64();
593        assert_eq!(coeffs.first().copied(), Some(1));
594        assert_eq!(coeffs.last().copied(), Some(n as i64));
595    }
596
597    #[test]
598    fn diff_log_has_both_diff_and_simplify_steps() {
599        let pool = p();
600        let x = pool.symbol("x", Domain::Real);
601        let y = pool.symbol("y", Domain::Real);
602        let expr = pool.add(vec![
603            pool.pow(x, pool.integer(2_i32)),
604            y,
605            pool.integer(0_i32),
606        ]);
607        let r = diff(expr, x, &pool).unwrap();
608        let rules: Vec<&str> = r.log.steps().iter().map(|s| s.rule_name).collect();
609        assert!(
610            rules.contains(&"sum_rule"),
611            "should have sum_rule: {rules:?}"
612        );
613        assert!(
614            rules.contains(&"diff_univariate_poly"),
615            "x² term differentiates via ℤ-polynomial fast-path: {rules:?}"
616        );
617        assert!(rules.len() > 1, "log should have multiple steps: {rules:?}");
618    }
619}