Skip to main content

alkahest_cas/sum/
mod.rs

//! Creative telescoping / Zeilberger-style symbolic summation (V2-10).
//!
//! Gosper indefinite summation for hypergeometric terms, i.e. terms whose ratio `F(k+1)/F(k)`
//! reduces to a rational function of `k`. Also provides definite and indefinite products
//! ([`product_definite`], [`product_indefinite`]), constant-coefficient homogeneous recurrence
//! solving (order ≤ 2), explicit [`rsolve`] for linear difference equations (V2-18), and
//! WZ-pair certificate verification ([`verify_wz_pair`]).
7
8mod expr_ratio;
9mod gosper;
10mod poly_aux;
11mod product;
12mod ratfunc;
13mod recurrence;
14mod rsolve;
15
16pub use expr_ratio::hypergeom_ratio;
17pub use gosper::{gosper_certificate, gosper_normal_form};
18pub use product::{product_definite, product_indefinite, ProductError};
19pub use ratfunc::RatFunc;
20pub use recurrence::{
21    solve_linear_recurrence_homogeneous, LinearRecurrenceError, RecurrenceSolution,
22};
23pub use rsolve::{rsolve, RsolveError};
24
25use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
26use crate::kernel::subs::subs;
27use crate::kernel::{ExprId, ExprPool};
28use crate::matrix::normal_form::RatUniPoly;
29use crate::simplify::engine::simplify;
30use std::collections::HashMap;
31use std::fmt;
32
33fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
34    simplify(e, pool).value
35}
36
/// Failure modes of the symbolic summation entry points in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SumError {
    /// The summand is not hypergeometric in the summation variable, or its term ratio could
    /// not be extracted. The payload is a detail string included in the display message.
    NotHypergeometric(String),
    /// Gosper's algorithm does not apply: no rational certificate was found for the term.
    NotGosperSummable,
    /// Substituting a summation bound into the antidifference failed. The payload is a detail
    /// string included in the display message.
    BoundSubstitution(String),
}

impl fmt::Display for SumError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every message carries the `sum:` prefix; only the payload-bearing variants format args.
        match self {
            Self::NotGosperSummable => f.write_str("sum: term is not Gosper-summable"),
            Self::NotHypergeometric(s) => write!(f, "sum: not hypergeometric: {s}"),
            Self::BoundSubstitution(s) => write!(f, "sum: bound substitution: {s}"),
        }
    }
}

impl std::error::Error for SumError {}
59
60impl crate::errors::AlkahestError for SumError {
61    fn code(&self) -> &'static str {
62        match self {
63            SumError::NotHypergeometric(_) => "E-SUM-001",
64            SumError::NotGosperSummable => "E-SUM-002",
65            SumError::BoundSubstitution(_) => "E-SUM-003",
66        }
67    }
68
69    fn remediation(&self) -> Option<&'static str> {
70        Some(
71            "supported indefinite sums are hypergeometric terms built from polynomials in k, products, and gamma(linear(k)); Zeilberger automation is partial — use verify_wz_pair for certificates",
72        )
73    }
74}
75
76fn rat_poly_to_expr(pool: &ExprPool, k: ExprId, p: &RatUniPoly) -> ExprId {
77    let mut terms: Vec<ExprId> = Vec::new();
78    for (deg, coeff) in p.coeffs.iter().enumerate() {
79        if coeff.is_zero() {
80            continue;
81        }
82        let coeff_q = coeff.clone();
83        let numer = coeff_q.numer();
84        let denom = coeff_q.denom();
85        let coeff_expr = if *denom == 1 {
86            pool.integer(numer.clone())
87        } else {
88            pool.rational(numer.clone(), denom.clone())
89        };
90        let pow_id = if deg == 0 {
91            coeff_expr
92        } else if deg == 1 {
93            pool.mul(vec![coeff_expr, k])
94        } else {
95            pool.mul(vec![coeff_expr, pool.pow(k, pool.integer(deg as i64))])
96        };
97        terms.push(pow_id);
98    }
99    match terms.len() {
100        0 => pool.integer(0_i32),
101        1 => terms[0],
102        _ => pool.add(terms),
103    }
104}
105
106fn ratfunc_to_expr(pool: &ExprPool, k: ExprId, r: &RatFunc) -> ExprId {
107    let num_e = rat_poly_to_expr(pool, k, &r.num);
108    if r.den.is_zero() || r.den.degree() == 0 && r.den.coeffs.is_empty() {
109        return num_e;
110    }
111    let den_e = rat_poly_to_expr(pool, k, &r.den);
112    pool.mul(vec![num_e, pool.pow(den_e, pool.integer(-1_i32))])
113}
114
115/// Indefinite Gosper sum: find `G(k)` with `G(k+1)-G(k)=term` when `term` is hypergeometric in `k`.
116pub fn sum_indefinite(
117    term: ExprId,
118    k: ExprId,
119    pool: &ExprPool,
120) -> Result<DerivedExpr<ExprId>, SumError> {
121    let ratio = hypergeom_ratio(term, k, pool)?;
122    let cert = gosper_certificate(&ratio).ok_or(SumError::NotGosperSummable)?;
123    let cert_e = ratfunc_to_expr(pool, k, &cert);
124    let g = simp(pool, pool.mul(vec![term, cert_e]));
125    let mut log = DerivationLog::new();
126    log.push(RewriteStep::simple("gosper_indefinite", term, g));
127    Ok(DerivedExpr::with_log(g, log))
128}
129
130/// Definite sum `∑_{k=lo}^{hi} term(k)` when Gosper applies (upper bound inclusive).
131pub fn sum_definite(
132    term: ExprId,
133    k: ExprId,
134    lo: ExprId,
135    hi: ExprId,
136    pool: &ExprPool,
137) -> Result<DerivedExpr<ExprId>, SumError> {
138    let ind = sum_indefinite(term, k, pool)?;
139    let g = ind.value;
140    let one = pool.integer(1_i32);
141    let hi_p1 = simp(pool, pool.add(vec![hi, one]));
142
143    let mut m_upper = HashMap::new();
144    m_upper.insert(k, hi_p1);
145    let upper = simp(pool, subs(g, &m_upper, pool));
146
147    let mut m_lower = HashMap::new();
148    m_lower.insert(k, lo);
149    let lower = simp(pool, subs(g, &m_lower, pool));
150
151    let diff = simp(
152        pool,
153        pool.add(vec![upper, pool.mul(vec![lower, pool.integer(-1_i32)])]),
154    );
155    let mut log = DerivationLog::new();
156    log.push(RewriteStep::simple("gosper_definite_telescope", term, diff));
157    Ok(DerivedExpr::with_log(diff, log))
158}
159
160/// Witness `(F, G)` for Zeilberger/WZ-style telescoping in `k`:
161/// checks `F(n+1,k)-F(n,k) = G(n,k+1)-G(n,k)` after clearing denominators by cross-multiplication.
162///
163/// Requires `n`, `k` distinct symbols. Uses [`simplify`] and structural equality; dense normalization
164/// for general `binom`/`gamma` identities is not guaranteed without extra rewrite rules.
165#[derive(Clone, Debug)]
166pub struct WzPair {
167    pub f: ExprId,
168    pub g: ExprId,
169}
170
171pub fn verify_wz_pair(pair: &WzPair, n: ExprId, k: ExprId, pool: &ExprPool) -> bool {
172    let k1 = simp(pool, pool.add(vec![k, pool.integer(1_i32)]));
173    let n1 = simp(pool, pool.add(vec![n, pool.integer(1_i32)]));
174
175    let mut mn = HashMap::new();
176    mn.insert(n, n1);
177    let f_n1_k = simp(pool, subs(pair.f, &mn, pool));
178
179    let lhs = simp(
180        pool,
181        pool.add(vec![f_n1_k, pool.mul(vec![pair.f, pool.integer(-1_i32)])]),
182    );
183
184    let mut mk = HashMap::new();
185    mk.insert(k, k1);
186    let g_n_k1 = simp(pool, subs(pair.g, &mk, pool));
187
188    let rhs = simp(
189        pool,
190        pool.add(vec![g_n_k1, pool.mul(vec![pair.g, pool.integer(-1_i32)])]),
191    );
192
193    lhs == rhs
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::ExprId;
    use crate::kernel::{Domain, ExprData};
    use std::collections::HashMap;

    /// Numerically evaluates `expr` under `env`.
    ///
    /// `gamma(x)` is evaluated with MPFR (`rug::Float::gamma`, 53-bit precision). The function
    /// recurses through `Add`/`Mul`/`Pow` so gamma calls nested inside them are reached. Every
    /// other node is delegated to [`eval_interp`].
    fn eval_with_gamma(expr: ExprId, env: &HashMap<ExprId, f64>, pool: &ExprPool) -> Option<f64> {
        match pool.get(expr) {
            ExprData::Func { name, args } if name == "gamma" && args.len() == 1 => {
                let x = eval_with_gamma(args[0], env, pool)?;
                Some(rug::Float::with_val(53, x).gamma().to_f64())
            }
            ExprData::Add(args) => {
                let mut sum = 0.0f64;
                for &a in &args {
                    sum += eval_with_gamma(a, env, pool)?;
                }
                Some(sum)
            }
            ExprData::Mul(args) => {
                let mut prod = 1.0f64;
                for &a in &args {
                    prod *= eval_with_gamma(a, env, pool)?;
                }
                Some(prod)
            }
            ExprData::Pow { base, exp } => {
                Some(eval_with_gamma(base, env, pool)?.powf(eval_with_gamma(exp, env, pool)?))
            }
            _ => eval_interp(expr, env, pool),
        }
    }

    /// Evaluates `expr` at `var = x`, panicking with `what` if evaluation fails.
    fn eval_at(expr: ExprId, var: ExprId, x: f64, pool: &ExprPool, what: &str) -> f64 {
        let mut env = HashMap::new();
        env.insert(var, x);
        eval_with_gamma(expr, &env, pool).expect(what)
    }

    #[test]
    fn indefinite_k_gamma_k_plus_1() {
        let pool = ExprPool::new();
        let k = pool.symbol("k", Domain::Real);
        let gkp1 = pool.func("gamma", vec![pool.add(vec![k, pool.integer(1_i32)])]);
        let term = simp(&pool, pool.mul(vec![k, gkp1]));
        let r = sum_indefinite(term, k, &pool).expect("gosper");

        // Defining property of the antidifference: G(k+1) - G(k) = term(k).
        let mut shift = HashMap::new();
        shift.insert(k, pool.add(vec![k, pool.integer(1_i32)]));
        let g_next = subs(r.value, &shift, &pool);
        // Start at k = 1 so an uncancelled `k^(-1)` factor in G cannot yield 0 * inf.
        for ki in 1..=6 {
            let x = ki as f64;
            let got = eval_at(g_next, k, x, &pool, "G(k+1) eval")
                - eval_at(r.value, k, x, &pool, "G(k) eval");
            let want = eval_at(term, k, x, &pool, "term eval");
            assert!(
                (got - want).abs() < 1e-9 * want.abs().max(1.0),
                "k={ki}: got {got} want {want}"
            );
        }
    }

    #[test]
    fn definite_sum_kfactorial_telescope() {
        let pool = ExprPool::new();
        let k = pool.symbol("k", Domain::Real);
        let n = pool.symbol("n", Domain::Real);
        let zero = pool.integer(0_i32);
        let gkp1 = pool.func("gamma", vec![pool.add(vec![k, pool.integer(1_i32)])]);
        let term = simp(&pool, pool.mul(vec![k, gkp1]));
        let s = sum_definite(term, k, zero, n, &pool).expect("definite");
        let expected = simp(
            &pool,
            pool.add(vec![
                pool.func("gamma", vec![pool.add(vec![n, pool.integer(2_i32)])]),
                pool.integer(-1_i32),
            ]),
        );
        for ni in 0..=8 {
            let mut env = HashMap::new();
            env.insert(n, ni as f64);
            let sv = eval_with_gamma(s.value, &env, &pool).expect("sum eval");
            let ev = eval_with_gamma(expected, &env, &pool).expect("expected eval");
            assert!(
                (sv - ev).abs() < 1e-5 * ev.abs().max(1.0),
                "n={ni}: got {sv} want {ev}"
            );
        }
    }

    #[test]
    fn wz_pair_zero_is_certificate() {
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let k = pool.symbol("k", Domain::Real);
        let z = pool.integer(0_i32);
        let pair = WzPair { f: z, g: z };
        assert!(verify_wz_pair(&pair, n, k, &pool));
    }

    #[test]
    fn wz_pair_linear_shifts_verify() {
        // F = n, G = k: both forward differences equal 1.
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let k = pool.symbol("k", Domain::Real);
        let pair = WzPair { f: n, g: k };
        assert!(verify_wz_pair(&pair, n, k, &pool));
    }

    #[test]
    fn wz_pair_rejects_mismatch() {
        // F = n, G = 0: the n-difference is 1 but the k-difference is 0.
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let k = pool.symbol("k", Domain::Real);
        let pair = WzPair {
            f: n,
            g: pool.integer(0_i32),
        };
        assert!(!verify_wz_pair(&pair, n, k, &pool));
    }
}