Skip to main content

alkahest_cas/diff/
diff_impl.rs

1use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
2use crate::kernel::{ExprData, ExprId, ExprPool};
3use crate::poly::UniPoly;
4use crate::simplify::engine::simplify;
5use std::collections::HashMap;
6use std::fmt;
7
8/// Build a canonical constant node for a rational `r`: an `Integer` when `r` is
9/// integer-valued, otherwise a `Rational`.  Shared by all three diff modes for
10/// the constant-exponent power rule.
11pub(crate) fn const_node(pool: &ExprPool, r: rug::Rational) -> ExprId {
12    if *r.denom() == 1 {
13        pool.integer(r.numer().clone())
14    } else {
15        let (n, d) = r.into_numer_denom();
16        pool.rational(n, d)
17    }
18}
19
20// ---------------------------------------------------------------------------
21// Error type
22// ---------------------------------------------------------------------------
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum DiffError {
26    /// An unknown function was encountered; differentiation is not defined.
27    UnknownFunction(String),
28    /// A `Pow` node whose exponent is not a constant integer.
29    NonIntegerExponent,
30    /// Forward-mode: unknown function (folded from the former `ForwardDiffError`).
31    ForwardUnknownFunction(String),
32    /// Forward-mode: non-integer exponent (folded from the former `ForwardDiffError`).
33    ForwardNonIntegerExponent,
34}
35
36impl fmt::Display for DiffError {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        match self {
39            DiffError::UnknownFunction(name) => {
40                write!(f, "cannot differentiate unknown function '{name}'")
41            }
42            DiffError::NonIntegerExponent => {
43                write!(f, "cannot differentiate power with non-integer exponent")
44            }
45            DiffError::ForwardUnknownFunction(name) => {
46                write!(f, "diff_forward: unknown function '{name}'")
47            }
48            DiffError::ForwardNonIntegerExponent => {
49                write!(f, "diff_forward: non-integer exponent")
50            }
51        }
52    }
53}
54
55impl std::error::Error for DiffError {}
56
57impl crate::errors::AlkahestError for DiffError {
58    fn code(&self) -> &'static str {
59        match self {
60            DiffError::UnknownFunction(_) => "E-DIFF-001",
61            DiffError::NonIntegerExponent => "E-DIFF-002",
62            DiffError::ForwardUnknownFunction(_) => "E-DIFF-003",
63            DiffError::ForwardNonIntegerExponent => "E-DIFF-004",
64        }
65    }
66
67    fn remediation(&self) -> Option<&'static str> {
68        match self {
69            DiffError::UnknownFunction(_) => Some(
70                "register the function in PrimitiveRegistry, or use diff_forward with a custom rule",
71            ),
72            DiffError::NonIntegerExponent => Some(
73                "symbolic exponents require the chain rule; use diff_forward for non-integer powers",
74            ),
75            DiffError::ForwardUnknownFunction(_) => Some(
76                "register the function in PrimitiveRegistry with diff_forward implemented",
77            ),
78            DiffError::ForwardNonIntegerExponent => Some(
79                "substitute concrete values first; diff_forward requires integer exponents",
80            ),
81        }
82    }
83}
84
85// ---------------------------------------------------------------------------
86// Public entry point
87// ---------------------------------------------------------------------------
88
89/// Symbolically differentiate `expr` with respect to `var`.
90///
91/// The returned log records every rule applied, including post-differentiation
92/// simplification steps appended at the end.
93pub fn diff(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DerivedExpr<ExprId>, DiffError> {
94    // One memo table per top-level diff call.  Maps ExprId → derivative ExprId
95    // so shared subexpressions are differentiated at most once.
96    let mut memo: HashMap<ExprId, ExprId> = HashMap::new();
97    let result = diff_raw(expr, var, pool, &mut memo)?;
98    Ok(result.and_then(|v| simplify(v, pool)))
99}
100
101// ---------------------------------------------------------------------------
102// Core recursive differentiation (no simplification)
103// ---------------------------------------------------------------------------
104
105#[inline]
106fn diff_poly_try_univariate_fastpath(
107    expr: ExprId,
108    var: ExprId,
109    pool: &ExprPool,
110) -> Option<DerivedExpr<ExprId>> {
111    // Skip atoms so simple cases keep their dedicated log rules (`diff_identity`, `diff_const`, …).
112    if matches!(
113        pool.get(expr),
114        ExprData::Symbol { .. } | ExprData::Integer(_) | ExprData::Rational(_) | ExprData::Float(_)
115    ) {
116        return None;
117    }
118    let poly = UniPoly::from_symbolic(expr, var, pool).ok()?;
119    let der = poly.derivative();
120    let result = der.to_symbolic_expr(pool);
121    let mut log = DerivationLog::new();
122    log.push(RewriteStep::simple("diff_univariate_poly", expr, result));
123    Some(DerivedExpr::with_log(result, log))
124}
125
126/// Memoised differentiation worker.
127///
128/// `memo` maps `ExprId → ExprId` (derivative value).  Shared subexpressions
129/// are differentiated once; subsequent occurrences return the cached result
130/// with an empty derivation log to avoid duplicate log entries.
131fn diff_raw(
132    expr: ExprId,
133    var: ExprId,
134    pool: &ExprPool,
135    memo: &mut HashMap<ExprId, ExprId>,
136) -> Result<DerivedExpr<ExprId>, DiffError> {
137    // Return cached derivative for shared subexpressions.
138    if let Some(&cached) = memo.get(&expr) {
139        return Ok(DerivedExpr::new(cached));
140    }
141
142    if let Some(hit) = diff_poly_try_univariate_fastpath(expr, var, pool) {
143        memo.insert(expr, hit.value);
144        return Ok(hit);
145    }
146
147    // Extract only what we need from the pool in a single lock acquisition,
148    // then release the lock before any recursive diff_raw calls.
149    enum Node {
150        IdentVar,
151        Const,
152        Add(Vec<ExprId>),
153        Mul(Vec<ExprId>),
154        Pow {
155            base: ExprId,
156            exp: ExprId,
157        },
158        Func {
159            name: String,
160            args: Vec<ExprId>,
161        },
162        Piecewise {
163            branches: Vec<(ExprId, ExprId)>,
164            default: ExprId,
165        },
166        RootSum {
167            poly: ExprId,
168            rvar: ExprId,
169            body: ExprId,
170        },
171    }
172
173    let node = pool.with(expr, |data| match data {
174        ExprData::Symbol { .. } if expr == var => Node::IdentVar,
175        ExprData::Symbol { .. }
176        | ExprData::Integer(_)
177        | ExprData::Rational(_)
178        | ExprData::Float(_) => Node::Const,
179        ExprData::Add(args) => Node::Add(args.clone()),
180        ExprData::Mul(args) => Node::Mul(args.clone()),
181        ExprData::Pow { base, exp } => Node::Pow {
182            base: *base,
183            exp: *exp,
184        },
185        ExprData::Func { name, args } => Node::Func {
186            name: name.clone(),
187            args: args.clone(),
188        },
189        ExprData::Piecewise { branches, default } => Node::Piecewise {
190            branches: branches.clone(),
191            default: *default,
192        },
193        // Predicates have no algebraic derivative.
194        ExprData::Predicate { .. } => Node::Const,
195        ExprData::Forall { .. } | ExprData::Exists { .. } => Node::Const,
196        ExprData::BigO(_) => Node::Const,
197        ExprData::RootSum { poly, var, body } => Node::RootSum {
198            poly: *poly,
199            rvar: *var,
200            body: *body,
201        },
202    });
203
204    match node {
205        // d/dx x = 1
206        Node::IdentVar => {
207            let one = pool.integer(1_i32);
208            memo.insert(expr, one);
209            Ok(DerivedExpr::with_step(
210                one,
211                RewriteStep::simple("diff_identity", expr, one),
212            ))
213        }
214        // d/dx c = 0  (any atom that is not the target variable)
215        Node::Const => {
216            let zero = pool.integer(0_i32);
217            memo.insert(expr, zero);
218            Ok(DerivedExpr::with_step(
219                zero,
220                RewriteStep::simple("diff_const", expr, zero),
221            ))
222        }
223        // Sum rule: d/dx (f₁ + f₂ + …) = f₁' + f₂' + …
224        Node::Add(args) => {
225            let mut log = DerivationLog::new();
226            let mut dargs: Vec<ExprId> = Vec::with_capacity(args.len());
227            for a in args {
228                let da = diff_raw(a, var, pool, memo)?;
229                log = log.merge(da.log);
230                dargs.push(da.value);
231            }
232            let sum = pool.add(dargs);
233            log.push(RewriteStep::simple("sum_rule", expr, sum));
234            let result = DerivedExpr::with_log(sum, log);
235            memo.insert(expr, result.value);
236            Ok(result)
237        }
238        // Product rule (n-ary Leibniz): d/dx (∏ᵢ fᵢ) = Σᵢ (fᵢ' · ∏_{j≠i} fⱼ)
239        Node::Mul(args) => {
240            let mut log = DerivationLog::new();
241            let dargs: Vec<DerivedExpr<ExprId>> = args
242                .iter()
243                .map(|&a| diff_raw(a, var, pool, memo))
244                .collect::<Result<_, _>>()?;
245            for da in &dargs {
246                log = log.merge(da.log.clone());
247            }
248            let mut terms: Vec<ExprId> = Vec::with_capacity(args.len());
249            for (i, da) in dargs.iter().enumerate() {
250                let di = da.value;
251                let rest: Vec<ExprId> = args
252                    .iter()
253                    .enumerate()
254                    .filter(|&(j, _)| j != i)
255                    .map(|(_, &a)| a)
256                    .collect();
257                let term = if rest.is_empty() {
258                    di
259                } else if rest.len() == 1 {
260                    pool.mul(vec![di, rest[0]])
261                } else {
262                    let prod = pool.mul(rest);
263                    pool.mul(vec![di, prod])
264                };
265                terms.push(term);
266            }
267            let result_id = match terms.len() {
268                0 => pool.integer(0_i32),
269                1 => terms[0],
270                _ => pool.add(terms),
271            };
272            log.push(RewriteStep::simple("product_rule", expr, result_id));
273            let result = DerivedExpr::with_log(result_id, log);
274            memo.insert(expr, result.value);
275            Ok(result)
276        }
277        // Power rule, constant exponent (integer or rational):
278        //   d/dx f^r = r · f^(r-1) · f'.
279        // A var-dependent / non-constant exponent (e.g. x^y, x^x) is a different
280        // rule (logarithmic differentiation) and remains unsupported.
281        Node::Pow { base, exp } => {
282            // Read the exponent without holding the pool lock during recursion.
283            let r = pool
284                .with(exp, |data| match data {
285                    ExprData::Integer(n) => Some(rug::Rational::from(n.0.clone())),
286                    ExprData::Rational(q) => Some(q.0.clone()),
287                    _ => None,
288                })
289                .ok_or(DiffError::NonIntegerExponent)?;
290
291            // Special case r=0: d/dx f^0 = 0
292            if r == 0 {
293                let zero = pool.integer(0_i32);
294                let mut log = DerivationLog::new();
295                log.push(RewriteStep::simple("power_rule_n0", expr, zero));
296                memo.insert(expr, zero);
297                return Ok(DerivedExpr::with_log(zero, log));
298            }
299            // Special case r=1: d/dx f^1 = f'
300            if r == 1 {
301                let mut result = diff_raw(base, var, pool, memo)?;
302                result
303                    .log
304                    .push(RewriteStep::simple("power_rule_n1", expr, result.value));
305                memo.insert(expr, result.value);
306                return Ok(result);
307            }
308
309            let mut log = DerivationLog::new();
310            let df = diff_raw(base, var, pool, memo)?;
311            log = log.merge(df.log);
312            let r_id = const_node(pool, r.clone());
313            let r_minus_1 = const_node(pool, r - 1);
314            let base_pow = pool.pow(base, r_minus_1);
315            let result_id = pool.mul(vec![r_id, base_pow, df.value]);
316            log.push(RewriteStep::simple("power_rule", expr, result_id));
317            memo.insert(expr, result_id);
318            Ok(DerivedExpr::with_log(result_id, log))
319        }
320        // Chain rules for single-argument named functions
321        Node::Func { name, args } if args.len() == 1 => {
322            let f = args[0];
323            let mut log = DerivationLog::new();
324            let df = diff_raw(f, var, pool, memo)?;
325            log = log.merge(df.log);
326            let result = match name.as_str() {
327                "sin" => {
328                    let cos_f = pool.func("cos", vec![f]);
329                    let r = pool.mul(vec![cos_f, df.value]);
330                    log.push(RewriteStep::simple("diff_sin", expr, r));
331                    r
332                }
333                "cos" => {
334                    let sin_f = pool.func("sin", vec![f]);
335                    let neg_one = pool.integer(-1_i32);
336                    let r = pool.mul(vec![neg_one, sin_f, df.value]);
337                    log.push(RewriteStep::simple("diff_cos", expr, r));
338                    r
339                }
340                "exp" => {
341                    let exp_f = pool.func("exp", vec![f]);
342                    let r = pool.mul(vec![exp_f, df.value]);
343                    log.push(RewriteStep::simple("diff_exp", expr, r));
344                    r
345                }
346                "log" => {
347                    let f_inv = pool.pow(f, pool.integer(-1_i32));
348                    let r = pool.mul(vec![df.value, f_inv]);
349                    log.push(RewriteStep::simple("diff_log", expr, r));
350                    r
351                }
352                "sqrt" => {
353                    let sqrt_f = pool.func("sqrt", vec![f]);
354                    let two_sqrt = pool.mul(vec![pool.integer(2_i32), sqrt_f]);
355                    let denom_inv = pool.pow(two_sqrt, pool.integer(-1_i32));
356                    let r = pool.mul(vec![df.value, denom_inv]);
357                    log.push(RewriteStep::simple("diff_sqrt", expr, r));
358                    r
359                }
360                other => {
361                    // Fall back to PrimitiveRegistry for V1-12 primitives
362                    let reg = crate::primitive::PrimitiveRegistry::default_registry();
363                    if let Some(d) = reg.diff_forward(other, &[f], var, pool) {
364                        log.push(RewriteStep::simple("diff_primitive_registry", expr, d));
365                        d
366                    } else {
367                        return Err(DiffError::UnknownFunction(other.to_string()));
368                    }
369                }
370            };
371            memo.insert(expr, result);
372            Ok(DerivedExpr::with_log(result, log))
373        }
374        // Multi-argument named functions: route through the PrimitiveRegistry.
375        // The primitive's `diff_forward` computes each argument's derivative
376        // internally (via `crate::diff::diff`) and returns the total chain-rule
377        // derivative, so we only need to dispatch — no per-argument recursion
378        // here, which avoids double-counting.  We still touch each argument via
379        // `diff_raw` so the shared-subexpression memo stays consistent.
380        Node::Func { name, args } => {
381            let mut log = DerivationLog::new();
382            for &a in &args {
383                let da = diff_raw(a, var, pool, memo)?;
384                log = log.merge(da.log);
385            }
386            let reg = crate::primitive::PrimitiveRegistry::default_registry();
387            if let Some(d) = reg.diff_forward(&name, &args, var, pool) {
388                log.push(RewriteStep::simple("diff_primitive_registry", expr, d));
389                memo.insert(expr, d);
390                Ok(DerivedExpr::with_log(d, log))
391            } else {
392                Err(DiffError::UnknownFunction(name))
393            }
394        }
395        // PA-9: Piecewise diff distributes into branches.
396        // d/dx Piecewise([(c₁,v₁), …], d) = Piecewise([(c₁, d/dx v₁), …], d/dx d)
397        Node::Piecewise { branches, default } => {
398            let mut log = DerivationLog::new();
399            let mut new_branches = Vec::with_capacity(branches.len());
400            for (cond, val) in branches {
401                let dval = diff_raw(val, var, pool, memo)?;
402                log = log.merge(dval.log);
403                new_branches.push((cond, dval.value));
404            }
405            let ddefault = diff_raw(default, var, pool, memo)?;
406            log = log.merge(ddefault.log);
407            let result = pool.piecewise(new_branches, ddefault.value);
408            log.push(RewriteStep::simple("diff_piecewise", expr, result));
409            memo.insert(expr, result);
410            Ok(DerivedExpr::with_log(result, log))
411        }
412        // d/dx Σ_{c:P(c)=0} body(c,x) = Σ_{c:P(c)=0} ∂body/∂x.
413        // The root `c` (rvar) is constant in `x`; `poly` is free of `x`.
414        Node::RootSum { poly, rvar, body } => {
415            let dbody = diff_raw(body, var, pool, memo)?;
416            let result = pool.root_sum(poly, rvar, dbody.value);
417            let mut log = dbody.log;
418            log.push(RewriteStep::simple("diff_root_sum", expr, result));
419            memo.insert(expr, result);
420            Ok(DerivedExpr::with_log(result, log))
421        }
422    }
423}
424
425// ---------------------------------------------------------------------------
426// Unit tests
427// ---------------------------------------------------------------------------
428
429#[cfg(test)]
430mod tests {
431    use super::*;
432    use crate::kernel::{Domain, ExprPool};
433    use crate::poly::UniPoly;
434
435    fn p() -> ExprPool {
436        ExprPool::new()
437    }
438
439    #[test]
440    fn diff_constant() {
441        let pool = p();
442        let x = pool.symbol("x", Domain::Real);
443        let r = diff(pool.integer(5_i32), x, &pool).unwrap();
444        assert_eq!(r.value, pool.integer(0_i32));
445        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_const"));
446    }
447
448    #[test]
449    fn diff_identity() {
450        let pool = p();
451        let x = pool.symbol("x", Domain::Real);
452        let r = diff(x, x, &pool).unwrap();
453        assert_eq!(r.value, pool.integer(1_i32));
454        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_identity"));
455    }
456
457    #[test]
458    fn diff_other_variable() {
459        let pool = p();
460        let x = pool.symbol("x", Domain::Real);
461        let y = pool.symbol("y", Domain::Real);
462        let r = diff(y, x, &pool).unwrap();
463        assert_eq!(r.value, pool.integer(0_i32));
464    }
465
466    #[test]
467    fn diff_linear() {
468        // d/dx (3x) = 3
469        let pool = p();
470        let x = pool.symbol("x", Domain::Real);
471        let expr = pool.mul(vec![pool.integer(3_i32), x]);
472        let r = diff(expr, x, &pool).unwrap();
473        assert_eq!(r.value, pool.integer(3_i32));
474    }
475
476    #[test]
477    fn diff_quadratic() {
478        // d/dx x² = 2x
479        let pool = p();
480        let x = pool.symbol("x", Domain::Real);
481        let r = diff(pool.pow(x, pool.integer(2_i32)), x, &pool).unwrap();
482        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
483        assert_eq!(poly.coefficients_i64(), vec![0, 2]);
484    }
485
486    #[test]
487    fn diff_cubic() {
488        // d/dx x³ = 3x²
489        let pool = p();
490        let x = pool.symbol("x", Domain::Real);
491        let r = diff(pool.pow(x, pool.integer(3_i32)), x, &pool).unwrap();
492        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
493        assert_eq!(poly.coefficients_i64(), vec![0, 0, 3]);
494    }
495
496    #[test]
497    fn diff_polynomial() {
498        // d/dx (x³ + 2x² + x + 1) = 3x² + 4x + 1
499        let pool = p();
500        let x = pool.symbol("x", Domain::Real);
501        let expr = pool.add(vec![
502            pool.pow(x, pool.integer(3_i32)),
503            pool.mul(vec![pool.integer(2_i32), pool.pow(x, pool.integer(2_i32))]),
504            x,
505            pool.integer(1_i32),
506        ]);
507        let r = diff(expr, x, &pool).unwrap();
508        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
509        assert_eq!(poly.coefficients_i64(), vec![1, 4, 3]);
510    }
511
512    #[test]
513    fn diff_sum_rule_logged() {
514        let pool = p();
515        let x = pool.symbol("x", Domain::Real);
516        let y = pool.symbol("y", Domain::Real);
517        let r = diff(pool.add(vec![x, y]), x, &pool).unwrap();
518        assert_eq!(r.value, pool.integer(1_i32));
519        assert!(r.log.steps().iter().any(|s| s.rule_name == "sum_rule"));
520    }
521
522    #[test]
523    fn diff_product_rule_logged() {
524        let pool = p();
525        let x = pool.symbol("x", Domain::Real);
526        let y = pool.symbol("y", Domain::Real);
527        let r = diff(pool.mul(vec![x, y]), x, &pool).unwrap();
528        assert_eq!(r.value, y);
529        assert!(r.log.steps().iter().any(|s| s.rule_name == "product_rule"));
530    }
531
532    #[test]
533    fn diff_sin() {
534        let pool = p();
535        let x = pool.symbol("x", Domain::Real);
536        let r = diff(pool.func("sin", vec![x]), x, &pool).unwrap();
537        assert_eq!(r.value, pool.func("cos", vec![x]));
538        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_sin"));
539    }
540
541    #[test]
542    fn diff_cos() {
543        let pool = p();
544        let x = pool.symbol("x", Domain::Real);
545        let r = diff(pool.func("cos", vec![x]), x, &pool).unwrap();
546        // d/dx cos(x) = -sin(x) = Mul([-1, sin(x)]) in canonical arg order
547        let sin_x = pool.func("sin", vec![x]);
548        let neg_one = pool.integer(-1_i32);
549        match pool.get(r.value) {
550            ExprData::Mul(ref args) => {
551                assert_eq!(args.len(), 2);
552                assert!(args.contains(&neg_one) && args.contains(&sin_x));
553            }
554            _ => panic!("expected Mul, got {:?}", pool.display(r.value)),
555        }
556        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_cos"));
557    }
558
559    #[test]
560    fn diff_exp() {
561        let pool = p();
562        let x = pool.symbol("x", Domain::Real);
563        let exp_x = pool.func("exp", vec![x]);
564        let r = diff(exp_x, x, &pool).unwrap();
565        assert_eq!(r.value, exp_x);
566        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_exp"));
567    }
568
569    #[test]
570    fn diff_log() {
571        // d/dx log(x) = x^(-1)
572        let pool = p();
573        let x = pool.symbol("x", Domain::Real);
574        let r = diff(pool.func("log", vec![x]), x, &pool).unwrap();
575        assert_eq!(r.value, pool.pow(x, pool.integer(-1_i32)));
576        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_log"));
577    }
578
579    #[test]
580    fn diff_chain_rule_sin() {
581        // d/dx sin(x²): cos inner uses ℤ-polynomial fast-path for x² → 2x (not the granular power_rule).
582        let pool = p();
583        let x = pool.symbol("x", Domain::Real);
584        let r = diff(
585            pool.func("sin", vec![pool.pow(x, pool.integer(2_i32))]),
586            x,
587            &pool,
588        )
589        .unwrap();
590        assert_ne!(r.value, pool.integer(0_i32));
591        assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_sin"));
592        assert!(r
593            .log
594            .steps()
595            .iter()
596            .any(|s| s.rule_name == "diff_univariate_poly"));
597    }
598
599    #[test]
600    fn diff_pow_n0() {
601        // d/dx f^0 = 0 — x^0 ≅ 1 is read as a ℤ-poly constant, so the dense derivative path applies.
602        let pool = p();
603        let x = pool.symbol("x", Domain::Real);
604        let expr = pool.pow(x, pool.integer(0_i32));
605        let r = diff(expr, x, &pool).unwrap();
606        assert_eq!(r.value, pool.integer(0_i32));
607        assert!(r
608            .log
609            .steps()
610            .iter()
611            .any(|s| s.rule_name == "diff_univariate_poly"));
612    }
613
614    #[test]
615    fn diff_pow_n1() {
616        // d/dx x^1 — same fast-path as other pure ℤ-polynomials.
617        let pool = p();
618        let x = pool.symbol("x", Domain::Real);
619        let expr = pool.pow(x, pool.integer(1_i32));
620        let r = diff(expr, x, &pool).unwrap();
621        assert_eq!(r.value, pool.integer(1_i32));
622        assert!(r
623            .log
624            .steps()
625            .iter()
626            .any(|s| s.rule_name == "diff_univariate_poly"));
627    }
628
629    #[test]
630    fn diff_unknown_function_error() {
631        let pool = p();
632        let x = pool.symbol("x", Domain::Real);
633        let err = diff(pool.func("zeta", vec![x]), x, &pool);
634        assert!(matches!(err, Err(DiffError::UnknownFunction(_))));
635    }
636
637    #[test]
638    fn diff_non_integer_exponent_error() {
639        let pool = p();
640        let x = pool.symbol("x", Domain::Real);
641        let y = pool.symbol("y", Domain::Real);
642        // A *var-dependent* exponent still needs logarithmic differentiation and
643        // remains unsupported.
644        let err = diff(pool.pow(x, y), x, &pool);
645        assert!(matches!(err, Err(DiffError::NonIntegerExponent)));
646    }
647
648    #[test]
649    fn diff_fractional_power() {
650        let pool = p();
651        let x = pool.symbol("x", Domain::Real);
652        // d/dx x^{1/2} = (1/2) x^{-1/2}.
653        let half = pool.pow(x, pool.rational(1_i32, 2_i32));
654        let d = diff(half, x, &pool).unwrap();
655        let expected = pool.mul(vec![
656            pool.rational(1_i32, 2_i32),
657            pool.pow(x, pool.rational(-1_i32, 2_i32)),
658        ]);
659        assert_eq!(
660            simplify(d.value, &pool).value,
661            simplify(expected, &pool).value
662        );
663
664        // d/dx x^{2/3} = (2/3) x^{-1/3}.
665        let two_thirds = pool.pow(x, pool.rational(2_i32, 3_i32));
666        let d = diff(two_thirds, x, &pool).unwrap();
667        let expected = pool.mul(vec![
668            pool.rational(2_i32, 3_i32),
669            pool.pow(x, pool.rational(-1_i32, 3_i32)),
670        ]);
671        assert_eq!(
672            simplify(d.value, &pool).value,
673            simplify(expected, &pool).value
674        );
675    }
676
677    #[test]
678    fn diff_fractional_power_chain_rule() {
679        let pool = p();
680        let x = pool.symbol("x", Domain::Real);
681        // d/dx (x²+1)^{3/2} = (3/2)(x²+1)^{1/2}·2x = 3x·(x²+1)^{1/2}.
682        let base = pool.add(vec![pool.pow(x, pool.integer(2_i32)), pool.integer(1_i32)]);
683        let expr = pool.pow(base, pool.rational(3_i32, 2_i32));
684        let d = diff(expr, x, &pool).unwrap();
685        let expected = pool.mul(vec![
686            pool.integer(3_i32),
687            x,
688            pool.pow(base, pool.rational(1_i32, 2_i32)),
689        ]);
690        assert_eq!(
691            simplify(d.value, &pool).value,
692            simplify(expected, &pool).value
693        );
694    }
695
696    #[test]
697    fn diff_balanced_geom_series_univariate_fastpath() {
698        fn balanced_sum(pool: &ExprPool, terms: &[ExprId]) -> ExprId {
699            match terms.len() {
700                0 => pool.integer(0_i32),
701                1 => terms[0],
702                _ => {
703                    let mid = terms.len() / 2;
704                    pool.add(vec![
705                        balanced_sum(pool, &terms[..mid]),
706                        balanced_sum(pool, &terms[mid..]),
707                    ])
708                }
709            }
710        }
711        let pool = p();
712        let x = pool.symbol("x", Domain::Real);
713        let n = 80i32;
714        let mut terms = vec![pool.integer(1_i32)];
715        for k in 1..=n {
716            terms.push(pool.pow(x, pool.integer(k)));
717        }
718        let expr = balanced_sum(&pool, &terms);
719        let r = diff(expr, x, &pool).unwrap();
720        assert!(
721            r.log
722                .steps()
723                .iter()
724                .any(|s| s.rule_name == "diff_univariate_poly"),
725            "expected dense ℤ-poly fast-path for balanced sum"
726        );
727        let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
728        assert_eq!(poly.degree(), i64::from(n) - 1);
729        let coeffs = poly.coefficients_i64();
730        assert_eq!(coeffs.first().copied(), Some(1));
731        assert_eq!(coeffs.last().copied(), Some(n as i64));
732    }
733
734    #[test]
735    fn diff_log_has_both_diff_and_simplify_steps() {
736        let pool = p();
737        let x = pool.symbol("x", Domain::Real);
738        let y = pool.symbol("y", Domain::Real);
739        let expr = pool.add(vec![
740            pool.pow(x, pool.integer(2_i32)),
741            y,
742            pool.integer(0_i32),
743        ]);
744        let r = diff(expr, x, &pool).unwrap();
745        let rules: Vec<&str> = r.log.steps().iter().map(|s| s.rule_name).collect();
746        assert!(
747            rules.contains(&"sum_rule"),
748            "should have sum_rule: {rules:?}"
749        );
750        assert!(
751            rules.contains(&"diff_univariate_poly"),
752            "x² term differentiates via ℤ-polynomial fast-path: {rules:?}"
753        );
754        assert!(rules.len() > 1, "log should have multiple steps: {rules:?}");
755    }
756}