1use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
2use crate::kernel::{ExprData, ExprId, ExprPool};
3use crate::poly::UniPoly;
4use crate::simplify::engine::simplify;
5use std::fmt;
6
7#[derive(Debug, Clone, PartialEq, Eq)]
12pub enum DiffError {
13 UnknownFunction(String),
15 NonIntegerExponent,
17 ForwardUnknownFunction(String),
19 ForwardNonIntegerExponent,
21}
22
23impl fmt::Display for DiffError {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 match self {
26 DiffError::UnknownFunction(name) => {
27 write!(f, "cannot differentiate unknown function '{name}'")
28 }
29 DiffError::NonIntegerExponent => {
30 write!(f, "cannot differentiate power with non-integer exponent")
31 }
32 DiffError::ForwardUnknownFunction(name) => {
33 write!(f, "diff_forward: unknown function '{name}'")
34 }
35 DiffError::ForwardNonIntegerExponent => {
36 write!(f, "diff_forward: non-integer exponent")
37 }
38 }
39 }
40}
41
42impl std::error::Error for DiffError {}
43
44impl crate::errors::AlkahestError for DiffError {
45 fn code(&self) -> &'static str {
46 match self {
47 DiffError::UnknownFunction(_) => "E-DIFF-001",
48 DiffError::NonIntegerExponent => "E-DIFF-002",
49 DiffError::ForwardUnknownFunction(_) => "E-DIFF-003",
50 DiffError::ForwardNonIntegerExponent => "E-DIFF-004",
51 }
52 }
53
54 fn remediation(&self) -> Option<&'static str> {
55 match self {
56 DiffError::UnknownFunction(_) => Some(
57 "register the function in PrimitiveRegistry, or use diff_forward with a custom rule",
58 ),
59 DiffError::NonIntegerExponent => Some(
60 "symbolic exponents require the chain rule; use diff_forward for non-integer powers",
61 ),
62 DiffError::ForwardUnknownFunction(_) => Some(
63 "register the function in PrimitiveRegistry with diff_forward implemented",
64 ),
65 DiffError::ForwardNonIntegerExponent => Some(
66 "substitute concrete values first; diff_forward requires integer exponents",
67 ),
68 }
69 }
70}
71
72pub fn diff(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DerivedExpr<ExprId>, DiffError> {
81 let result = diff_raw(expr, var, pool)?;
82 Ok(result.and_then(|v| simplify(v, pool)))
83}
84
85#[inline]
90fn diff_poly_try_univariate_fastpath(
91 expr: ExprId,
92 var: ExprId,
93 pool: &ExprPool,
94) -> Option<DerivedExpr<ExprId>> {
95 if matches!(
97 pool.get(expr),
98 ExprData::Symbol { .. } | ExprData::Integer(_) | ExprData::Rational(_) | ExprData::Float(_)
99 ) {
100 return None;
101 }
102 let poly = UniPoly::from_symbolic(expr, var, pool).ok()?;
103 let der = poly.derivative();
104 let result = der.to_symbolic_expr(pool);
105 let mut log = DerivationLog::new();
106 log.push(RewriteStep::simple("diff_univariate_poly", expr, result));
107 Some(DerivedExpr::with_log(result, log))
108}
109
110fn diff_raw(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<DerivedExpr<ExprId>, DiffError> {
111 if let Some(hit) = diff_poly_try_univariate_fastpath(expr, var, pool) {
112 return Ok(hit);
113 }
114
115 enum Node {
118 IdentVar,
119 Const,
120 Add(Vec<ExprId>),
121 Mul(Vec<ExprId>),
122 Pow {
123 base: ExprId,
124 exp: ExprId,
125 },
126 Func {
127 name: String,
128 args: Vec<ExprId>,
129 },
130 Piecewise {
131 branches: Vec<(ExprId, ExprId)>,
132 default: ExprId,
133 },
134 }
135
136 let node = pool.with(expr, |data| match data {
137 ExprData::Symbol { .. } if expr == var => Node::IdentVar,
138 ExprData::Symbol { .. }
139 | ExprData::Integer(_)
140 | ExprData::Rational(_)
141 | ExprData::Float(_) => Node::Const,
142 ExprData::Add(args) => Node::Add(args.clone()),
143 ExprData::Mul(args) => Node::Mul(args.clone()),
144 ExprData::Pow { base, exp } => Node::Pow {
145 base: *base,
146 exp: *exp,
147 },
148 ExprData::Func { name, args } => Node::Func {
149 name: name.clone(),
150 args: args.clone(),
151 },
152 ExprData::Piecewise { branches, default } => Node::Piecewise {
153 branches: branches.clone(),
154 default: *default,
155 },
156 ExprData::Predicate { .. } => Node::Const,
158 ExprData::Forall { .. } | ExprData::Exists { .. } => Node::Const,
159 ExprData::BigO(_) => Node::Const,
160 });
161
162 match node {
163 Node::IdentVar => {
165 let one = pool.integer(1_i32);
166 Ok(DerivedExpr::with_step(
167 one,
168 RewriteStep::simple("diff_identity", expr, one),
169 ))
170 }
171 Node::Const => {
173 let zero = pool.integer(0_i32);
174 Ok(DerivedExpr::with_step(
175 zero,
176 RewriteStep::simple("diff_const", expr, zero),
177 ))
178 }
179 Node::Add(args) => {
181 let mut log = DerivationLog::new();
182 let mut dargs: Vec<ExprId> = Vec::with_capacity(args.len());
183 for a in args {
184 let da = diff_raw(a, var, pool)?;
185 log = log.merge(da.log);
186 dargs.push(da.value);
187 }
188 let sum = pool.add(dargs);
189 log.push(RewriteStep::simple("sum_rule", expr, sum));
190 Ok(DerivedExpr::with_log(sum, log))
191 }
192 Node::Mul(args) => {
194 let mut log = DerivationLog::new();
195 let dargs: Vec<DerivedExpr<ExprId>> = args
196 .iter()
197 .map(|&a| diff_raw(a, var, pool))
198 .collect::<Result<_, _>>()?;
199 for da in &dargs {
200 log = log.merge(da.log.clone());
201 }
202 let mut terms: Vec<ExprId> = Vec::with_capacity(args.len());
203 for (i, da) in dargs.iter().enumerate() {
204 let di = da.value;
205 let rest: Vec<ExprId> = args
206 .iter()
207 .enumerate()
208 .filter(|&(j, _)| j != i)
209 .map(|(_, &a)| a)
210 .collect();
211 let term = if rest.is_empty() {
212 di
213 } else if rest.len() == 1 {
214 pool.mul(vec![di, rest[0]])
215 } else {
216 let prod = pool.mul(rest);
217 pool.mul(vec![di, prod])
218 };
219 terms.push(term);
220 }
221 let result = match terms.len() {
222 0 => pool.integer(0_i32),
223 1 => terms[0],
224 _ => pool.add(terms),
225 };
226 log.push(RewriteStep::simple("product_rule", expr, result));
227 Ok(DerivedExpr::with_log(result, log))
228 }
229 Node::Pow { base, exp } => {
231 let n = pool
233 .with(exp, |data| match data {
234 ExprData::Integer(n) => Some(n.0.clone()),
235 _ => None,
236 })
237 .ok_or(DiffError::NonIntegerExponent)?;
238
239 if n == 0 {
241 let zero = pool.integer(0_i32);
242 let mut log = DerivationLog::new();
243 log.push(RewriteStep::simple("power_rule_n0", expr, zero));
244 return Ok(DerivedExpr::with_log(zero, log));
245 }
246 if n == 1 {
248 let mut result = diff_raw(base, var, pool)?;
249 result
250 .log
251 .push(RewriteStep::simple("power_rule_n1", expr, result.value));
252 return Ok(result);
253 }
254
255 let mut log = DerivationLog::new();
256 let df = diff_raw(base, var, pool)?;
257 log = log.merge(df.log);
258 let n_id = pool.integer(n.clone());
259 let n_minus_1 = pool.integer(n - 1);
260 let base_pow = pool.pow(base, n_minus_1);
261 let result = pool.mul(vec![n_id, base_pow, df.value]);
262 log.push(RewriteStep::simple("power_rule", expr, result));
263 Ok(DerivedExpr::with_log(result, log))
264 }
265 Node::Func { name, args } if args.len() == 1 => {
267 let f = args[0];
268 let mut log = DerivationLog::new();
269 let df = diff_raw(f, var, pool)?;
270 log = log.merge(df.log);
271 let result = match name.as_str() {
272 "sin" => {
273 let cos_f = pool.func("cos", vec![f]);
274 let r = pool.mul(vec![cos_f, df.value]);
275 log.push(RewriteStep::simple("diff_sin", expr, r));
276 r
277 }
278 "cos" => {
279 let sin_f = pool.func("sin", vec![f]);
280 let neg_one = pool.integer(-1_i32);
281 let r = pool.mul(vec![neg_one, sin_f, df.value]);
282 log.push(RewriteStep::simple("diff_cos", expr, r));
283 r
284 }
285 "exp" => {
286 let exp_f = pool.func("exp", vec![f]);
287 let r = pool.mul(vec![exp_f, df.value]);
288 log.push(RewriteStep::simple("diff_exp", expr, r));
289 r
290 }
291 "log" => {
292 let f_inv = pool.pow(f, pool.integer(-1_i32));
293 let r = pool.mul(vec![df.value, f_inv]);
294 log.push(RewriteStep::simple("diff_log", expr, r));
295 r
296 }
297 "sqrt" => {
298 let sqrt_f = pool.func("sqrt", vec![f]);
299 let two_sqrt = pool.mul(vec![pool.integer(2_i32), sqrt_f]);
300 let denom_inv = pool.pow(two_sqrt, pool.integer(-1_i32));
301 let r = pool.mul(vec![df.value, denom_inv]);
302 log.push(RewriteStep::simple("diff_sqrt", expr, r));
303 r
304 }
305 other => {
306 let reg = crate::primitive::PrimitiveRegistry::default_registry();
308 if let Some(d) = reg.diff_forward(other, &[f], var, pool) {
309 log.push(RewriteStep::simple("diff_primitive_registry", expr, d));
310 d
311 } else {
312 return Err(DiffError::UnknownFunction(other.to_string()));
313 }
314 }
315 };
316 Ok(DerivedExpr::with_log(result, log))
317 }
318 Node::Func { name, .. } => Err(DiffError::UnknownFunction(name)),
319 Node::Piecewise { branches, default } => {
322 let mut log = DerivationLog::new();
323 let mut new_branches = Vec::with_capacity(branches.len());
324 for (cond, val) in branches {
325 let dval = diff_raw(val, var, pool)?;
326 log = log.merge(dval.log);
327 new_branches.push((cond, dval.value));
328 }
329 let ddefault = diff_raw(default, var, pool)?;
330 log = log.merge(ddefault.log);
331 let result = pool.piecewise(new_branches, ddefault.value);
332 log.push(RewriteStep::simple("diff_piecewise", expr, result));
333 Ok(DerivedExpr::with_log(result, log))
334 }
335 }
336}
337
338#[cfg(test)]
343mod tests {
344 use super::*;
345 use crate::kernel::{Domain, ExprPool};
346 use crate::poly::UniPoly;
347
348 fn p() -> ExprPool {
349 ExprPool::new()
350 }
351
352 #[test]
353 fn diff_constant() {
354 let pool = p();
355 let x = pool.symbol("x", Domain::Real);
356 let r = diff(pool.integer(5_i32), x, &pool).unwrap();
357 assert_eq!(r.value, pool.integer(0_i32));
358 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_const"));
359 }
360
361 #[test]
362 fn diff_identity() {
363 let pool = p();
364 let x = pool.symbol("x", Domain::Real);
365 let r = diff(x, x, &pool).unwrap();
366 assert_eq!(r.value, pool.integer(1_i32));
367 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_identity"));
368 }
369
370 #[test]
371 fn diff_other_variable() {
372 let pool = p();
373 let x = pool.symbol("x", Domain::Real);
374 let y = pool.symbol("y", Domain::Real);
375 let r = diff(y, x, &pool).unwrap();
376 assert_eq!(r.value, pool.integer(0_i32));
377 }
378
379 #[test]
380 fn diff_linear() {
381 let pool = p();
383 let x = pool.symbol("x", Domain::Real);
384 let expr = pool.mul(vec![pool.integer(3_i32), x]);
385 let r = diff(expr, x, &pool).unwrap();
386 assert_eq!(r.value, pool.integer(3_i32));
387 }
388
389 #[test]
390 fn diff_quadratic() {
391 let pool = p();
393 let x = pool.symbol("x", Domain::Real);
394 let r = diff(pool.pow(x, pool.integer(2_i32)), x, &pool).unwrap();
395 let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
396 assert_eq!(poly.coefficients_i64(), vec![0, 2]);
397 }
398
399 #[test]
400 fn diff_cubic() {
401 let pool = p();
403 let x = pool.symbol("x", Domain::Real);
404 let r = diff(pool.pow(x, pool.integer(3_i32)), x, &pool).unwrap();
405 let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
406 assert_eq!(poly.coefficients_i64(), vec![0, 0, 3]);
407 }
408
409 #[test]
410 fn diff_polynomial() {
411 let pool = p();
413 let x = pool.symbol("x", Domain::Real);
414 let expr = pool.add(vec![
415 pool.pow(x, pool.integer(3_i32)),
416 pool.mul(vec![pool.integer(2_i32), pool.pow(x, pool.integer(2_i32))]),
417 x,
418 pool.integer(1_i32),
419 ]);
420 let r = diff(expr, x, &pool).unwrap();
421 let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
422 assert_eq!(poly.coefficients_i64(), vec![1, 4, 3]);
423 }
424
425 #[test]
426 fn diff_sum_rule_logged() {
427 let pool = p();
428 let x = pool.symbol("x", Domain::Real);
429 let y = pool.symbol("y", Domain::Real);
430 let r = diff(pool.add(vec![x, y]), x, &pool).unwrap();
431 assert_eq!(r.value, pool.integer(1_i32));
432 assert!(r.log.steps().iter().any(|s| s.rule_name == "sum_rule"));
433 }
434
435 #[test]
436 fn diff_product_rule_logged() {
437 let pool = p();
438 let x = pool.symbol("x", Domain::Real);
439 let y = pool.symbol("y", Domain::Real);
440 let r = diff(pool.mul(vec![x, y]), x, &pool).unwrap();
441 assert_eq!(r.value, y);
442 assert!(r.log.steps().iter().any(|s| s.rule_name == "product_rule"));
443 }
444
445 #[test]
446 fn diff_sin() {
447 let pool = p();
448 let x = pool.symbol("x", Domain::Real);
449 let r = diff(pool.func("sin", vec![x]), x, &pool).unwrap();
450 assert_eq!(r.value, pool.func("cos", vec![x]));
451 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_sin"));
452 }
453
454 #[test]
455 fn diff_cos() {
456 let pool = p();
457 let x = pool.symbol("x", Domain::Real);
458 let r = diff(pool.func("cos", vec![x]), x, &pool).unwrap();
459 let sin_x = pool.func("sin", vec![x]);
461 let neg_one = pool.integer(-1_i32);
462 match pool.get(r.value) {
463 ExprData::Mul(ref args) => {
464 assert_eq!(args.len(), 2);
465 assert!(args.contains(&neg_one) && args.contains(&sin_x));
466 }
467 _ => panic!("expected Mul, got {:?}", pool.display(r.value)),
468 }
469 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_cos"));
470 }
471
472 #[test]
473 fn diff_exp() {
474 let pool = p();
475 let x = pool.symbol("x", Domain::Real);
476 let exp_x = pool.func("exp", vec![x]);
477 let r = diff(exp_x, x, &pool).unwrap();
478 assert_eq!(r.value, exp_x);
479 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_exp"));
480 }
481
482 #[test]
483 fn diff_log() {
484 let pool = p();
486 let x = pool.symbol("x", Domain::Real);
487 let r = diff(pool.func("log", vec![x]), x, &pool).unwrap();
488 assert_eq!(r.value, pool.pow(x, pool.integer(-1_i32)));
489 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_log"));
490 }
491
492 #[test]
493 fn diff_chain_rule_sin() {
494 let pool = p();
496 let x = pool.symbol("x", Domain::Real);
497 let r = diff(
498 pool.func("sin", vec![pool.pow(x, pool.integer(2_i32))]),
499 x,
500 &pool,
501 )
502 .unwrap();
503 assert_ne!(r.value, pool.integer(0_i32));
504 assert!(r.log.steps().iter().any(|s| s.rule_name == "diff_sin"));
505 assert!(r
506 .log
507 .steps()
508 .iter()
509 .any(|s| s.rule_name == "diff_univariate_poly"));
510 }
511
512 #[test]
513 fn diff_pow_n0() {
514 let pool = p();
516 let x = pool.symbol("x", Domain::Real);
517 let expr = pool.pow(x, pool.integer(0_i32));
518 let r = diff(expr, x, &pool).unwrap();
519 assert_eq!(r.value, pool.integer(0_i32));
520 assert!(r
521 .log
522 .steps()
523 .iter()
524 .any(|s| s.rule_name == "diff_univariate_poly"));
525 }
526
527 #[test]
528 fn diff_pow_n1() {
529 let pool = p();
531 let x = pool.symbol("x", Domain::Real);
532 let expr = pool.pow(x, pool.integer(1_i32));
533 let r = diff(expr, x, &pool).unwrap();
534 assert_eq!(r.value, pool.integer(1_i32));
535 assert!(r
536 .log
537 .steps()
538 .iter()
539 .any(|s| s.rule_name == "diff_univariate_poly"));
540 }
541
542 #[test]
543 fn diff_unknown_function_error() {
544 let pool = p();
545 let x = pool.symbol("x", Domain::Real);
546 let err = diff(pool.func("zeta", vec![x]), x, &pool);
547 assert!(matches!(err, Err(DiffError::UnknownFunction(_))));
548 }
549
550 #[test]
551 fn diff_non_integer_exponent_error() {
552 let pool = p();
553 let x = pool.symbol("x", Domain::Real);
554 let y = pool.symbol("y", Domain::Real);
555 let err = diff(pool.pow(x, y), x, &pool);
556 assert!(matches!(err, Err(DiffError::NonIntegerExponent)));
557 }
558
559 #[test]
560 fn diff_balanced_geom_series_univariate_fastpath() {
561 fn balanced_sum(pool: &ExprPool, terms: &[ExprId]) -> ExprId {
562 match terms.len() {
563 0 => pool.integer(0_i32),
564 1 => terms[0],
565 _ => {
566 let mid = terms.len() / 2;
567 pool.add(vec![
568 balanced_sum(pool, &terms[..mid]),
569 balanced_sum(pool, &terms[mid..]),
570 ])
571 }
572 }
573 }
574 let pool = p();
575 let x = pool.symbol("x", Domain::Real);
576 let n = 80i32;
577 let mut terms = vec![pool.integer(1_i32)];
578 for k in 1..=n {
579 terms.push(pool.pow(x, pool.integer(k)));
580 }
581 let expr = balanced_sum(&pool, &terms);
582 let r = diff(expr, x, &pool).unwrap();
583 assert!(
584 r.log
585 .steps()
586 .iter()
587 .any(|s| s.rule_name == "diff_univariate_poly"),
588 "expected dense ℤ-poly fast-path for balanced sum"
589 );
590 let poly = UniPoly::from_symbolic(r.value, x, &pool).unwrap();
591 assert_eq!(poly.degree(), i64::from(n) - 1);
592 let coeffs = poly.coefficients_i64();
593 assert_eq!(coeffs.first().copied(), Some(1));
594 assert_eq!(coeffs.last().copied(), Some(n as i64));
595 }
596
597 #[test]
598 fn diff_log_has_both_diff_and_simplify_steps() {
599 let pool = p();
600 let x = pool.symbol("x", Domain::Real);
601 let y = pool.symbol("y", Domain::Real);
602 let expr = pool.add(vec![
603 pool.pow(x, pool.integer(2_i32)),
604 y,
605 pool.integer(0_i32),
606 ]);
607 let r = diff(expr, x, &pool).unwrap();
608 let rules: Vec<&str> = r.log.steps().iter().map(|s| s.rule_name).collect();
609 assert!(
610 rules.contains(&"sum_rule"),
611 "should have sum_rule: {rules:?}"
612 );
613 assert!(
614 rules.contains(&"diff_univariate_poly"),
615 "x² term differentiates via ℤ-polynomial fast-path: {rules:?}"
616 );
617 assert!(rules.len() > 1, "log should have multiple steps: {rules:?}");
618 }
619}