1mod expr_ratio;
9mod gosper;
10mod poly_aux;
11mod product;
12mod ratfunc;
13mod recurrence;
14mod rsolve;
15
16pub use expr_ratio::hypergeom_ratio;
17pub use gosper::{gosper_certificate, gosper_normal_form};
18pub use product::{product_definite, product_indefinite, ProductError};
19pub use ratfunc::RatFunc;
20pub use recurrence::{
21 solve_linear_recurrence_homogeneous, LinearRecurrenceError, RecurrenceSolution,
22};
23pub use rsolve::{rsolve, RsolveError};
24
25use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
26use crate::kernel::subs::subs;
27use crate::kernel::{ExprId, ExprPool};
28use crate::matrix::normal_form::RatUniPoly;
29use crate::simplify::engine::simplify;
30use std::collections::HashMap;
31use std::fmt;
32
33fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
34 simplify(e, pool).value
35}
36
/// Failure modes of the symbolic summation entry points in this module
/// ([`sum_indefinite`], [`sum_definite`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SumError {
    /// The summand was not recognised as a hypergeometric term; the payload is
    /// a human-readable diagnostic.
    NotHypergeometric(String),
    /// The summand is hypergeometric, but Gosper's algorithm found no rational
    /// certificate (`gosper_certificate` returned `None`).
    NotGosperSummable,
    /// Substituting a summation bound failed; the payload describes why.
    /// NOTE(review): not constructed by the functions visible in this file.
    BoundSubstitution(String),
}

impl fmt::Display for SumError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotHypergeometric(reason) => write!(f, "sum: not hypergeometric: {reason}"),
            Self::NotGosperSummable => f.write_str("sum: term is not Gosper-summable"),
            Self::BoundSubstitution(reason) => write!(f, "sum: bound substitution: {reason}"),
        }
    }
}

/// `SumError` has no underlying cause, so the default `source()` (None) is used.
impl std::error::Error for SumError {}
59
60impl crate::errors::AlkahestError for SumError {
61 fn code(&self) -> &'static str {
62 match self {
63 SumError::NotHypergeometric(_) => "E-SUM-001",
64 SumError::NotGosperSummable => "E-SUM-002",
65 SumError::BoundSubstitution(_) => "E-SUM-003",
66 }
67 }
68
69 fn remediation(&self) -> Option<&'static str> {
70 Some(
71 "supported indefinite sums are hypergeometric terms built from polynomials in k, products, and gamma(linear(k)); Zeilberger automation is partial — use verify_wz_pair for certificates",
72 )
73 }
74}
75
76fn rat_poly_to_expr(pool: &ExprPool, k: ExprId, p: &RatUniPoly) -> ExprId {
77 let mut terms: Vec<ExprId> = Vec::new();
78 for (deg, coeff) in p.coeffs.iter().enumerate() {
79 if coeff.is_zero() {
80 continue;
81 }
82 let coeff_q = coeff.clone();
83 let numer = coeff_q.numer();
84 let denom = coeff_q.denom();
85 let coeff_expr = if *denom == 1 {
86 pool.integer(numer.clone())
87 } else {
88 pool.rational(numer.clone(), denom.clone())
89 };
90 let pow_id = if deg == 0 {
91 coeff_expr
92 } else if deg == 1 {
93 pool.mul(vec![coeff_expr, k])
94 } else {
95 pool.mul(vec![coeff_expr, pool.pow(k, pool.integer(deg as i64))])
96 };
97 terms.push(pow_id);
98 }
99 match terms.len() {
100 0 => pool.integer(0_i32),
101 1 => terms[0],
102 _ => pool.add(terms),
103 }
104}
105
106fn ratfunc_to_expr(pool: &ExprPool, k: ExprId, r: &RatFunc) -> ExprId {
107 let num_e = rat_poly_to_expr(pool, k, &r.num);
108 if r.den.is_zero() || r.den.degree() == 0 && r.den.coeffs.is_empty() {
109 return num_e;
110 }
111 let den_e = rat_poly_to_expr(pool, k, &r.den);
112 pool.mul(vec![num_e, pool.pow(den_e, pool.integer(-1_i32))])
113}
114
115pub fn sum_indefinite(
117 term: ExprId,
118 k: ExprId,
119 pool: &ExprPool,
120) -> Result<DerivedExpr<ExprId>, SumError> {
121 let ratio = hypergeom_ratio(term, k, pool)?;
122 let cert = gosper_certificate(&ratio).ok_or(SumError::NotGosperSummable)?;
123 let cert_e = ratfunc_to_expr(pool, k, &cert);
124 let g = simp(pool, pool.mul(vec![term, cert_e]));
125 let mut log = DerivationLog::new();
126 log.push(RewriteStep::simple("gosper_indefinite", term, g));
127 Ok(DerivedExpr::with_log(g, log))
128}
129
130pub fn sum_definite(
132 term: ExprId,
133 k: ExprId,
134 lo: ExprId,
135 hi: ExprId,
136 pool: &ExprPool,
137) -> Result<DerivedExpr<ExprId>, SumError> {
138 let ind = sum_indefinite(term, k, pool)?;
139 let g = ind.value;
140 let one = pool.integer(1_i32);
141 let hi_p1 = simp(pool, pool.add(vec![hi, one]));
142
143 let mut m_upper = HashMap::new();
144 m_upper.insert(k, hi_p1);
145 let upper = simp(pool, subs(g, &m_upper, pool));
146
147 let mut m_lower = HashMap::new();
148 m_lower.insert(k, lo);
149 let lower = simp(pool, subs(g, &m_lower, pool));
150
151 let diff = simp(
152 pool,
153 pool.add(vec![upper, pool.mul(vec![lower, pool.integer(-1_i32)])]),
154 );
155 let mut log = DerivationLog::new();
156 log.push(RewriteStep::simple("gosper_definite_telescope", term, diff));
157 Ok(DerivedExpr::with_log(diff, log))
158}
159
160#[derive(Clone, Debug)]
166pub struct WzPair {
167 pub f: ExprId,
168 pub g: ExprId,
169}
170
171pub fn verify_wz_pair(pair: &WzPair, n: ExprId, k: ExprId, pool: &ExprPool) -> bool {
172 let k1 = simp(pool, pool.add(vec![k, pool.integer(1_i32)]));
173 let n1 = simp(pool, pool.add(vec![n, pool.integer(1_i32)]));
174
175 let mut mn = HashMap::new();
176 mn.insert(n, n1);
177 let f_n1_k = simp(pool, subs(pair.f, &mn, pool));
178
179 let lhs = simp(
180 pool,
181 pool.add(vec![f_n1_k, pool.mul(vec![pair.f, pool.integer(-1_i32)])]),
182 );
183
184 let mut mk = HashMap::new();
185 mk.insert(k, k1);
186 let g_n_k1 = simp(pool, subs(pair.g, &mk, pool));
187
188 let rhs = simp(
189 pool,
190 pool.add(vec![g_n_k1, pool.mul(vec![pair.g, pool.integer(-1_i32)])]),
191 );
192
193 lhs == rhs
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::ExprId;
    use crate::kernel::{Domain, ExprData};
    use std::collections::HashMap;

    /// Numerically evaluates `expr` under `env`.
    ///
    /// `gamma(x)` is computed with rug's `Float::gamma` at 53-bit precision.
    /// `Add`, `Mul` and `Pow` nodes are evaluated recursively so that gamma
    /// calls nested inside them are reached. Every other node is delegated to
    /// `eval_interp`. Returns `None` if any sub-evaluation fails.
    fn eval_with_gamma(expr: ExprId, env: &HashMap<ExprId, f64>, pool: &ExprPool) -> Option<f64> {
        match pool.get(expr) {
            ExprData::Func { name, args } if name == "gamma" && args.len() == 1 => {
                let x = eval_with_gamma(args[0], env, pool)?;
                Some(rug::Float::with_val(53, x).gamma().to_f64())
            }
            ExprData::Add(args) => {
                let mut sum = 0.0f64;
                for &a in &args {
                    sum += eval_with_gamma(a, env, pool)?;
                }
                Some(sum)
            }
            ExprData::Mul(args) => {
                let mut prod = 1.0f64;
                for &a in &args {
                    prod *= eval_with_gamma(a, env, pool)?;
                }
                Some(prod)
            }
            ExprData::Pow { base, exp } => {
                Some(eval_with_gamma(base, env, pool)?.powf(eval_with_gamma(exp, env, pool)?))
            }
            _ => eval_interp(expr, env, pool),
        }
    }

    #[test]
    fn indefinite_k_gamma_k_plus_1() {
        let pool = ExprPool::new();
        let k = pool.symbol("k", Domain::Real);
        let gkp1 = pool.func("gamma", vec![pool.add(vec![k, pool.integer(1_i32)])]);
        let term = simp(&pool, pool.mul(vec![k, gkp1]));
        let r = sum_indefinite(term, k, &pool).expect("gosper");
        assert!(pool.with(r.value, |d| matches!(
            d,
            ExprData::Func { .. } | ExprData::Mul(_)
        )));

        // The shape check above would also pass for a wrong answer, so verify
        // the defining property of an antidifference numerically:
        // G(k + 1) - G(k) == term(k). Sample k >= 1 so that any intermediate
        // factor such as k^(-1) stays finite.
        let mut shift = HashMap::new();
        shift.insert(k, simp(&pool, pool.add(vec![k, pool.integer(1_i32)])));
        let g_shifted = simp(&pool, crate::kernel::subs::subs(r.value, &shift, &pool));
        for ki in 1..=6 {
            let mut env = HashMap::new();
            env.insert(k, ki as f64);
            let diff = eval_with_gamma(g_shifted, &env, &pool).expect("G(k+1) eval")
                - eval_with_gamma(r.value, &env, &pool).expect("G(k) eval");
            let want = eval_with_gamma(term, &env, &pool).expect("term eval");
            assert!(
                (diff - want).abs() < 1e-6 * want.abs().max(1.0),
                "k={ki}: G(k+1)-G(k)={diff}, term={want}"
            );
        }
    }

    #[test]
    fn definite_sum_kfactorial_telescope() {
        let pool = ExprPool::new();
        let k = pool.symbol("k", Domain::Real);
        let n = pool.symbol("n", Domain::Real);
        let zero = pool.integer(0_i32);
        let gkp1 = pool.func("gamma", vec![pool.add(vec![k, pool.integer(1_i32)])]);
        let term = simp(&pool, pool.mul(vec![k, gkp1]));
        let s = sum_definite(term, k, zero, n, &pool).expect("definite");
        // sum_{k=0}^{n} k * k! = (n + 1)! - 1 = gamma(n + 2) - 1
        let expected = simp(
            &pool,
            pool.add(vec![
                pool.func("gamma", vec![pool.add(vec![n, pool.integer(2_i32)])]),
                pool.integer(-1_i32),
            ]),
        );
        for ni in 0..=8 {
            let mut env = HashMap::new();
            env.insert(n, ni as f64);
            let sv = eval_with_gamma(s.value, &env, &pool).expect("sum eval");
            let ev = eval_with_gamma(expected, &env, &pool).expect("expected eval");
            assert!(
                (sv - ev).abs() < 1e-5 * ev.abs().max(1.0),
                "n={ni}: got {sv} want {ev}"
            );
        }
    }

    #[test]
    fn wz_pair_zero_is_certificate() {
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let k = pool.symbol("k", Domain::Real);
        let z = pool.integer(0_i32);
        let pair = WzPair { f: z, g: z };
        assert!(verify_wz_pair(&pair, n, k, &pool));
    }
}