Skip to main content

alkahest_cas/sum/
product.rs

1//! Symbolic discrete products (∏) over ℚ(k) with ℤ-linear factorisation (V2-22).
2//!
3//! \(\prod_{k=m}^{n} q(k)\) for \(q\) a rational whose numerator/denominator split into
4//! linear factors over ℤ telescopes via \(\sum \Delta\logΓ(k+r) = \logΓ(n+r+1)-\logΓ(m+r)\)
5//! and integer leading-coefficient powers \(a^{e(n-m+1)}\).
6
7use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
8use crate::flint::{integer::FlintInteger, FlintPoly};
9use crate::kernel::{ExprData, ExprId, ExprPool};
10use crate::matrix::normal_form::RatUniPoly;
11use crate::poly::factor::UniPolyFactorization;
12use crate::poly::UniPoly;
13use crate::simplify::engine::simplify;
14use crate::sum::ratfunc::RatFunc;
15use rug::{Integer, Rational};
16use std::fmt;
17
18fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
19    simplify(e, pool).value
20}
21
/// Failure modes of symbolic discrete-product evaluation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProductError {
    /// The term could not be read as a supported rational function of the
    /// index; the payload describes what was rejected.
    NotRationalTerm(String),
    /// Factorisation over ℤ reported an error.
    Factorization,
    /// Some irreducible factor over ℤ has degree greater than one.
    NonLinearFactor,
    /// Substituting a bound failed (mirrors the summation error of the same shape).
    BoundSubstitution(String),
}
34
35impl fmt::Display for ProductError {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match self {
38            ProductError::NotRationalTerm(s) => write!(f, "product: unsupported term shape: {s}"),
39            ProductError::Factorization => write!(f, "product: polynomial factorisation failed"),
40            ProductError::NonLinearFactor => {
41                write!(
42                    f,
43                    "product: term has a non-linear irreducible factor over ℤ"
44                )
45            }
46            ProductError::BoundSubstitution(s) => write!(f, "product: bound substitution: {s}"),
47        }
48    }
49}
50
// Marker impl: `ProductError` uses the default `std::error::Error` methods
// (no underlying `source`), relying on its `Debug` and `Display` impls.
impl std::error::Error for ProductError {}
52
53impl crate::errors::AlkahestError for ProductError {
54    fn code(&self) -> &'static str {
55        match self {
56            ProductError::NotRationalTerm(_) => "E-PROD-001",
57            ProductError::Factorization => "E-PROD-002",
58            ProductError::NonLinearFactor => "E-PROD-003",
59            ProductError::BoundSubstitution(_) => "E-PROD-004",
60        }
61    }
62
63    fn remediation(&self) -> Option<&'static str> {
64        Some("supported: ∏ q(k) for q ∈ ℚ(k) factoring into ℤ-linear terms; no irreducible quadratics in k")
65    }
66}
67
68fn rational_to_expr(pool: &ExprPool, r: &Rational) -> ExprId {
69    let n = r.numer().clone();
70    let d = r.denom().clone();
71    if d == 1 {
72        pool.integer(n)
73    } else {
74        pool.rational(n, d)
75    }
76}
77
78fn ratuni_poly_to_univ(p: &RatUniPoly, var: ExprId) -> Result<UniPoly, ProductError> {
79    if p.is_zero() {
80        return Ok(UniPoly::zero(var));
81    }
82    let mut lcm = Integer::from(1u32);
83    for c in &p.coeffs {
84        if !c.is_zero() {
85            lcm = lcm.lcm(&c.denom().clone());
86        }
87    }
88    let scale = Rational::from(&lcm);
89    let mut max_i = p.coeffs.len().saturating_sub(1);
90    let mut rug_coeffs = vec![Integer::from(0); max_i + 1];
91    for (i, c) in p.coeffs.iter().enumerate() {
92        if c.is_zero() {
93            continue;
94        }
95        let scaled = c.clone() * scale.clone();
96        if *scaled.denom() != 1 {
97            return Err(ProductError::NotRationalTerm(
98                "could not clear denominators".into(),
99            ));
100        }
101        rug_coeffs[i] = scaled.numer().clone();
102        max_i = max_i.max(i);
103    }
104    rug_coeffs.truncate(max_i + 1);
105    let coeffs: Vec<FlintInteger> = rug_coeffs.iter().map(FlintInteger::from_rug).collect();
106    let mut fp = FlintPoly::new();
107    for (i, ci) in coeffs.iter().enumerate() {
108        if !ci.to_rug().is_zero() {
109            fp.set_coeff_flint(i, ci);
110        }
111    }
112    Ok(UniPoly { var, coeffs: fp })
113}
114
115fn expr_to_ratfunc(term: ExprId, k: ExprId, pool: &ExprPool) -> Result<RatFunc, ProductError> {
116    let term = simp(pool, term);
117    if term == k {
118        return Ok(RatFunc {
119            num: RatUniPoly::x(),
120            den: RatUniPoly::one(),
121        }
122        .normalize());
123    }
124    match pool.get(term).clone() {
125        ExprData::Integer(n) => Ok(RatFunc::scalar(Rational::from(&n.0))),
126        ExprData::Rational(br) => Ok(RatFunc::scalar(br.0.clone())),
127        ExprData::Symbol { name, .. } => {
128            if term == k {
129                Ok(RatFunc {
130                    num: RatUniPoly::x(),
131                    den: RatUniPoly::one(),
132                }
133                .normalize())
134            } else {
135                Err(ProductError::NotRationalTerm(format!(
136                    "free symbol `{name}` — term must be unary rational in k",
137                )))
138            }
139        }
140        ExprData::Add(_) => {
141            let p = UniPoly::from_symbolic_clear_denoms(term, k, pool).map_err(|e| {
142                ProductError::NotRationalTerm(format!("polynomial expected in k: {e}"))
143            })?;
144            let coeffs: Vec<Rational> = p.coefficients().into_iter().map(Rational::from).collect();
145            Ok(RatFunc::from_poly(RatUniPoly { coeffs }.trim()).normalize())
146        }
147        ExprData::Pow { base, exp } => {
148            let e_i = match pool.get(exp) {
149                ExprData::Integer(n) => n
150                    .0
151                    .to_i32()
152                    .ok_or_else(|| ProductError::NotRationalTerm("exponent out of range".into()))?,
153                _ => {
154                    return Err(ProductError::NotRationalTerm(
155                        "non-constant exponent".into(),
156                    ))
157                }
158            };
159            let base_rf = expr_to_ratfunc(base, k, pool)?;
160            if e_i >= 0 {
161                let ee = u32::try_from(e_i)
162                    .map_err(|_| ProductError::NotRationalTerm("exponent overflow".into()))?;
163                let mut acc = RatFunc::one();
164                for _ in 0..ee {
165                    acc = acc.mul_ratfunc(&base_rf);
166                }
167                Ok(acc.normalize())
168            } else {
169                let inv = base_rf
170                    .inv()
171                    .ok_or_else(|| ProductError::NotRationalTerm("invert zero".into()))?;
172                let ee =
173                    u32::try_from(-e_i).map_err(|_| ProductError::NotRationalTerm("exp".into()))?;
174                let mut acc = RatFunc::one();
175                for _ in 0..ee {
176                    acc = acc.mul_ratfunc(&inv);
177                }
178                Ok(acc.normalize())
179            }
180        }
181        ExprData::Mul(args) => {
182            let mut acc = RatFunc::one();
183            for &a in &args {
184                acc = acc.mul_ratfunc(&expr_to_ratfunc(a, k, pool)?);
185            }
186            Ok(acc.normalize())
187        }
188        _ => Err(ProductError::NotRationalTerm(
189            "expression is not a rational function of k with integer poly factors".into(),
190        )),
191    }
192}
193
194fn factor_univ(p: &UniPoly) -> Result<UniPolyFactorization, ProductError> {
195    p.factor_z().map_err(|_| ProductError::Factorization)
196}
197
198/// ∏ fac over one side of a rational (numerator or denominator).
199fn definite_side_from_factorization(
200    pool: &ExprPool,
201    fac: &UniPolyFactorization,
202    lo: ExprId,
203    hi: ExprId,
204    delta_n: ExprId,
205) -> Result<ExprId, ProductError> {
206    let mut parts: Vec<ExprId> = Vec::new();
207    let u = &fac.unit;
208    if u.to_i32() == Some(-1) {
209        parts.push(pool.pow(pool.integer(-1_i32), delta_n));
210    } else if u.to_i32() != Some(1) {
211        parts.push(pool.pow(pool.integer(u.clone()), delta_n));
212    }
213
214    for (fact, ee) in &fac.factors {
215        let expo = *ee as i64;
216        let d = fact.degree().max(0) as usize;
217        match d {
218            0 => {
219                let cz = match fact.coefficients().first() {
220                    Some(c) => c.clone(),
221                    None => Integer::from(1),
222                };
223                if cz == 1 {
224                    continue;
225                }
226                if cz == -1 {
227                    if expo.rem_euclid(2) != 0 {
228                        parts.push(pool.pow(pool.integer(-1_i32), delta_n));
229                    }
230                    continue;
231                }
232                let exp_e = pool.integer(expo);
233                parts.push(pool.pow(
234                    pool.integer(cz.clone()),
235                    simp(pool, pool.mul(vec![delta_n, exp_e])),
236                ));
237            }
238            1 => {
239                let coeffs = fact.coefficients();
240                let aa = coeffs.get(1).cloned().unwrap_or_else(|| Integer::from(0));
241                let bb = coeffs.first().cloned().unwrap_or_else(|| Integer::from(0));
242                if aa == 0 {
243                    return Err(ProductError::NotRationalTerm("degenerate linear".into()));
244                }
245                let c_rat = Rational::from((bb, aa.clone()));
246                let one = Rational::from(1);
247                let hi_shift = rational_to_expr(pool, &(one.clone() + c_rat.clone()));
248                let lo_shift = rational_to_expr(pool, &c_rat);
249                let lead_exp = simp(pool, pool.mul(vec![delta_n, pool.integer(expo)]));
250                let gh = pool.func("gamma", vec![simp(pool, pool.add(vec![hi, hi_shift]))]);
251                let gl = pool.func("gamma", vec![simp(pool, pool.add(vec![lo, lo_shift]))]);
252                let ratio = simp(pool, pool.mul(vec![gh, pool.pow(gl, pool.integer(-1_i32))]));
253                parts.push(pool.pow(pool.integer(aa.clone()), lead_exp));
254                if expo != 0 {
255                    parts.push(pool.pow(ratio, pool.integer(expo)));
256                }
257            }
258            _ => return Err(ProductError::NonLinearFactor),
259        }
260    }
261
262    match parts.len() {
263        0 => Ok(pool.integer(1_i32)),
264        1 => Ok(simp(pool, parts[0])),
265        _ => Ok(simp(pool, pool.mul(parts))),
266    }
267}
268
269/// Indefinite multiplicative antiderivative for one polynomial side.
270fn indefinite_side_from_factorization(
271    pool: &ExprPool,
272    fac: &UniPolyFactorization,
273    k: ExprId,
274) -> Result<ExprId, ProductError> {
275    let mut parts: Vec<ExprId> = Vec::new();
276    let u = &fac.unit;
277    if u.to_i32() == Some(-1) {
278        parts.push(pool.pow(pool.integer(-1_i32), k));
279    } else if u.to_i32() != Some(1) {
280        parts.push(pool.pow(pool.integer(u.clone()), k));
281    }
282
283    for (fact, ee) in &fac.factors {
284        let expo = *ee as i64;
285        let d = fact.degree().max(0) as usize;
286        match d {
287            0 => {
288                let cz = match fact.coefficients().first() {
289                    Some(c) => c.clone(),
290                    None => Integer::from(1),
291                };
292                if cz == 1 {
293                    continue;
294                }
295                if cz == -1 {
296                    if expo.rem_euclid(2) != 0 {
297                        parts.push(pool.pow(pool.integer(-1_i32), k));
298                    }
299                    continue;
300                }
301                let exp_e = pool.integer(expo);
302                parts.push(pool.pow(
303                    pool.integer(cz.clone()),
304                    simp(pool, pool.mul(vec![k, exp_e])),
305                ));
306            }
307            1 => {
308                let coeffs = fact.coefficients();
309                let aa = coeffs.get(1).cloned().unwrap_or_else(|| Integer::from(0));
310                let bb = coeffs.first().cloned().unwrap_or_else(|| Integer::from(0));
311                if aa == 0 {
312                    return Err(ProductError::NotRationalTerm("degenerate linear".into()));
313                }
314                let c_rat = Rational::from((bb, aa.clone()));
315                let lo_shift = rational_to_expr(pool, &c_rat);
316                let gamma_k = pool.func("gamma", vec![simp(pool, pool.add(vec![k, lo_shift]))]);
317                let lead_exp_k = simp(pool, pool.mul(vec![k, pool.integer(expo)]));
318                parts.push(pool.pow(pool.integer(aa), lead_exp_k));
319                parts.push(pool.pow(gamma_k, pool.integer(expo)));
320            }
321            _ => return Err(ProductError::NonLinearFactor),
322        }
323    }
324
325    match parts.len() {
326        0 => Ok(pool.integer(1_i32)),
327        1 => Ok(simp(pool, parts[0])),
328        _ => Ok(simp(pool, pool.mul(parts))),
329    }
330}
331
332/// ∏_{k=lo}^{hi} term(k); inclusive `[lo, hi]`.
333pub fn product_definite(
334    term: ExprId,
335    k: ExprId,
336    lo: ExprId,
337    hi: ExprId,
338    pool: &ExprPool,
339) -> Result<DerivedExpr<ExprId>, ProductError> {
340    let rf = expr_to_ratfunc(term, k, pool)?;
341    if rf.num.is_zero() {
342        let z = simp(pool, pool.integer(0_i32));
343        let mut log = DerivationLog::new();
344        log.push(RewriteStep::simple("product_definite_zero", term, z));
345        return Ok(DerivedExpr::with_log(z, log));
346    }
347
348    let univ_n = ratuni_poly_to_univ(&rf.num, k)?;
349    let univ_d = ratuni_poly_to_univ(&rf.den, k)?;
350    let fac_n = factor_univ(&univ_n)?;
351    let fac_d = factor_univ(&univ_d)?;
352
353    let one = pool.integer(1_i32);
354    let delta_n = simp(
355        pool,
356        pool.add(vec![hi, pool.mul(vec![lo, pool.integer(-1)]), one]),
357    );
358
359    let top = definite_side_from_factorization(pool, &fac_n, lo, hi, delta_n)?;
360    let bot = definite_side_from_factorization(pool, &fac_d, lo, hi, delta_n)?;
361    let q = simp(
362        pool,
363        pool.mul(vec![top, pool.pow(bot, pool.integer(-1_i32))]),
364    );
365
366    let mut log = DerivationLog::new();
367    log.push(RewriteStep::simple("product_definite", term, q));
368    Ok(DerivedExpr::with_log(q, log))
369}
370
371/// Witness `Z(k)` with \(Z(k+1)/Z(k)=term(k)\) (after canonical simplification).
372pub fn product_indefinite(
373    term: ExprId,
374    k: ExprId,
375    pool: &ExprPool,
376) -> Result<DerivedExpr<ExprId>, ProductError> {
377    let rf = expr_to_ratfunc(term, k, pool)?;
378    if rf.num.is_zero() {
379        return Err(ProductError::NotRationalTerm(
380            "indefinite product of zero unsupported".into(),
381        ));
382    }
383    let fac_n = factor_univ(&ratuni_poly_to_univ(&rf.num, k)?)?;
384    let fac_d = factor_univ(&ratuni_poly_to_univ(&rf.den, k)?)?;
385
386    let top = indefinite_side_from_factorization(pool, &fac_n, k)?;
387    let bot = indefinite_side_from_factorization(pool, &fac_d, k)?;
388
389    let q = simp(
390        pool,
391        pool.mul(vec![top, pool.pow(bot, pool.integer(-1_i32))]),
392    );
393
394    let mut log = DerivationLog::new();
395    log.push(RewriteStep::simple("product_indefinite", term, q));
396    Ok(DerivedExpr::with_log(q, log))
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::Domain;
    use rug::Float;
    use std::collections::HashMap;

    /// Γ(x) computed with a 53-bit `rug::Float`, returned as `f64`.
    fn gamma64(x: f64) -> f64 {
        let arg = Float::with_val(53, x);
        arg.gamma().to_f64()
    }

    /// Numeric evaluator that handles `gamma(·)` itself and recurses through
    /// sums, products and power bases; everything else (including power
    /// exponents) is delegated to `eval_interp`.
    fn eval_g(expr: ExprId, env: &HashMap<ExprId, f64>, pool: &ExprPool) -> Option<f64> {
        match pool.get(expr).clone() {
            ExprData::Func { name, args } if name == "gamma" && args.len() == 1 => {
                eval_g(args[0], env, pool).map(gamma64)
            }
            ExprData::Add(args) => args
                .iter()
                .try_fold(0.0f64, |sum, &a| Some(sum + eval_g(a, env, pool)?)),
            ExprData::Mul(args) => args
                .iter()
                .try_fold(1.0f64, |prod, &a| Some(prod * eval_g(a, env, pool)?)),
            ExprData::Pow { base, exp } => {
                let b = eval_g(base, env, pool)?;
                let e = eval_interp(exp, env, pool)?;
                Some(b.powf(e))
            }
            _ => eval_interp(expr, env, pool),
        }
    }

    /// ∏_{k=1}^{n} k should agree numerically with gamma(n + 1).
    #[test]
    fn product_linear_k_matches_factorial_gamma() {
        let pool = ExprPool::new();
        let k = pool.symbol("k", Domain::Real);
        let n = pool.symbol("n", Domain::Real);
        let lo = pool.integer(1_i32);
        let p = product_definite(k, k, lo, n, &pool).expect("prod");
        let n_plus_one = simp(&pool, pool.add(vec![n, pool.integer(1)]));
        let want = simp(&pool, pool.func("gamma", vec![n_plus_one]));
        for ni in 2..14 {
            let env = HashMap::from([(n, ni as f64)]);
            let pv = eval_g(p.value, &env, &pool).unwrap();
            let wv = eval_g(want, &env, &pool).unwrap();
            let tolerance = 1e-6 * wv.abs().max(1.0);
            assert!((pv - wv).abs() < tolerance, "n={ni}: pv={pv} wv={wv}");
        }
    }

    /// ∏_{k=2}^{n} (k-1)(k+1)/k² should equal (n+1)/(2n).
    #[test]
    fn wallis_partial_product_ratios() {
        let pool = ExprPool::new();
        let k = pool.symbol("k", Domain::Real);
        let n = pool.symbol("n", Domain::Real);
        let two = pool.integer(2_i32);
        let km1 = simp(&pool, pool.add(vec![k, pool.integer(-1)]));
        let kp1 = simp(&pool, pool.add(vec![k, pool.integer(1)]));
        let k2 = simp(&pool, pool.pow(k, pool.integer(2)));
        let numerator = simp(&pool, pool.mul(vec![km1, kp1]));
        let inverse_square = pool.pow(k2, pool.integer(-1));
        let term = simp(&pool, pool.mul(vec![numerator, inverse_square]));

        let p = product_definite(term, k, two, n, &pool).expect("wallis");
        for ni in 3..36 {
            let env = HashMap::from([(n, ni as f64)]);
            let pv = eval_g(p.value, &env, &pool).unwrap();
            let want = (ni + 1) as f64 / (2.0 * ni as f64);
            let tolerance = 1e-5 * want.max(1.0);
            assert!((pv - want).abs() < tolerance, "n={}: got {}", ni, pv);
        }
    }
}