1use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
8use crate::flint::{integer::FlintInteger, FlintPoly};
9use crate::kernel::{ExprData, ExprId, ExprPool};
10use crate::matrix::normal_form::RatUniPoly;
11use crate::poly::factor::UniPolyFactorization;
12use crate::poly::UniPoly;
13use crate::simplify::engine::simplify;
14use crate::sum::ratfunc::RatFunc;
15use rug::{Integer, Rational};
16use std::fmt;
17
18fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
19 simplify(e, pool).value
20}
21
/// Reasons a symbolic product `∏ q(k)` cannot be evaluated in closed form.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProductError {
    /// The term is not a supported rational function of the product variable; the payload
    /// describes the offending shape.
    NotRationalTerm(String),
    /// Integer polynomial factorisation of the numerator or denominator reported failure.
    Factorization,
    /// An irreducible factor of degree two or more was found (only ℤ-linear factors are
    /// expressible through Γ here).
    NonLinearFactor,
    /// A problem substituting the product bounds; the payload carries the detail.
    /// NOTE(review): not constructed anywhere in this module.
    BoundSubstitution(String),
}

impl fmt::Display for ProductError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotRationalTerm(detail) => {
                write!(f, "product: unsupported term shape: {detail}")
            }
            Self::Factorization => f.write_str("product: polynomial factorisation failed"),
            Self::NonLinearFactor => {
                f.write_str("product: term has a non-linear irreducible factor over ℤ")
            }
            Self::BoundSubstitution(detail) => write!(f, "product: bound substitution: {detail}"),
        }
    }
}

impl std::error::Error for ProductError {}
52
53impl crate::errors::AlkahestError for ProductError {
54 fn code(&self) -> &'static str {
55 match self {
56 ProductError::NotRationalTerm(_) => "E-PROD-001",
57 ProductError::Factorization => "E-PROD-002",
58 ProductError::NonLinearFactor => "E-PROD-003",
59 ProductError::BoundSubstitution(_) => "E-PROD-004",
60 }
61 }
62
63 fn remediation(&self) -> Option<&'static str> {
64 Some("supported: ∏ q(k) for q ∈ ℚ(k) factoring into ℤ-linear terms; no irreducible quadratics in k")
65 }
66}
67
68fn rational_to_expr(pool: &ExprPool, r: &Rational) -> ExprId {
69 let n = r.numer().clone();
70 let d = r.denom().clone();
71 if d == 1 {
72 pool.integer(n)
73 } else {
74 pool.rational(n, d)
75 }
76}
77
78fn ratuni_poly_to_univ(p: &RatUniPoly, var: ExprId) -> Result<UniPoly, ProductError> {
79 if p.is_zero() {
80 return Ok(UniPoly::zero(var));
81 }
82 let mut lcm = Integer::from(1u32);
83 for c in &p.coeffs {
84 if !c.is_zero() {
85 lcm = lcm.lcm(&c.denom().clone());
86 }
87 }
88 let scale = Rational::from(&lcm);
89 let mut max_i = p.coeffs.len().saturating_sub(1);
90 let mut rug_coeffs = vec![Integer::from(0); max_i + 1];
91 for (i, c) in p.coeffs.iter().enumerate() {
92 if c.is_zero() {
93 continue;
94 }
95 let scaled = c.clone() * scale.clone();
96 if *scaled.denom() != 1 {
97 return Err(ProductError::NotRationalTerm(
98 "could not clear denominators".into(),
99 ));
100 }
101 rug_coeffs[i] = scaled.numer().clone();
102 max_i = max_i.max(i);
103 }
104 rug_coeffs.truncate(max_i + 1);
105 let coeffs: Vec<FlintInteger> = rug_coeffs.iter().map(FlintInteger::from_rug).collect();
106 let mut fp = FlintPoly::new();
107 for (i, ci) in coeffs.iter().enumerate() {
108 if !ci.to_rug().is_zero() {
109 fp.set_coeff_flint(i, ci);
110 }
111 }
112 Ok(UniPoly { var, coeffs: fp })
113}
114
115fn expr_to_ratfunc(term: ExprId, k: ExprId, pool: &ExprPool) -> Result<RatFunc, ProductError> {
116 let term = simp(pool, term);
117 if term == k {
118 return Ok(RatFunc {
119 num: RatUniPoly::x(),
120 den: RatUniPoly::one(),
121 }
122 .normalize());
123 }
124 match pool.get(term).clone() {
125 ExprData::Integer(n) => Ok(RatFunc::scalar(Rational::from(&n.0))),
126 ExprData::Rational(br) => Ok(RatFunc::scalar(br.0.clone())),
127 ExprData::Symbol { name, .. } => {
128 if term == k {
129 Ok(RatFunc {
130 num: RatUniPoly::x(),
131 den: RatUniPoly::one(),
132 }
133 .normalize())
134 } else {
135 Err(ProductError::NotRationalTerm(format!(
136 "free symbol `{name}` — term must be unary rational in k",
137 )))
138 }
139 }
140 ExprData::Add(_) => {
141 let p = UniPoly::from_symbolic_clear_denoms(term, k, pool).map_err(|e| {
142 ProductError::NotRationalTerm(format!("polynomial expected in k: {e}"))
143 })?;
144 let coeffs: Vec<Rational> = p.coefficients().into_iter().map(Rational::from).collect();
145 Ok(RatFunc::from_poly(RatUniPoly { coeffs }.trim()).normalize())
146 }
147 ExprData::Pow { base, exp } => {
148 let e_i = match pool.get(exp) {
149 ExprData::Integer(n) => n
150 .0
151 .to_i32()
152 .ok_or_else(|| ProductError::NotRationalTerm("exponent out of range".into()))?,
153 _ => {
154 return Err(ProductError::NotRationalTerm(
155 "non-constant exponent".into(),
156 ))
157 }
158 };
159 let base_rf = expr_to_ratfunc(base, k, pool)?;
160 if e_i >= 0 {
161 let ee = u32::try_from(e_i)
162 .map_err(|_| ProductError::NotRationalTerm("exponent overflow".into()))?;
163 let mut acc = RatFunc::one();
164 for _ in 0..ee {
165 acc = acc.mul_ratfunc(&base_rf);
166 }
167 Ok(acc.normalize())
168 } else {
169 let inv = base_rf
170 .inv()
171 .ok_or_else(|| ProductError::NotRationalTerm("invert zero".into()))?;
172 let ee =
173 u32::try_from(-e_i).map_err(|_| ProductError::NotRationalTerm("exp".into()))?;
174 let mut acc = RatFunc::one();
175 for _ in 0..ee {
176 acc = acc.mul_ratfunc(&inv);
177 }
178 Ok(acc.normalize())
179 }
180 }
181 ExprData::Mul(args) => {
182 let mut acc = RatFunc::one();
183 for &a in &args {
184 acc = acc.mul_ratfunc(&expr_to_ratfunc(a, k, pool)?);
185 }
186 Ok(acc.normalize())
187 }
188 _ => Err(ProductError::NotRationalTerm(
189 "expression is not a rational function of k with integer poly factors".into(),
190 )),
191 }
192}
193
194fn factor_univ(p: &UniPoly) -> Result<UniPolyFactorization, ProductError> {
195 p.factor_z().map_err(|_| ProductError::Factorization)
196}
197
198fn definite_side_from_factorization(
200 pool: &ExprPool,
201 fac: &UniPolyFactorization,
202 lo: ExprId,
203 hi: ExprId,
204 delta_n: ExprId,
205) -> Result<ExprId, ProductError> {
206 let mut parts: Vec<ExprId> = Vec::new();
207 let u = &fac.unit;
208 if u.to_i32() == Some(-1) {
209 parts.push(pool.pow(pool.integer(-1_i32), delta_n));
210 } else if u.to_i32() != Some(1) {
211 parts.push(pool.pow(pool.integer(u.clone()), delta_n));
212 }
213
214 for (fact, ee) in &fac.factors {
215 let expo = *ee as i64;
216 let d = fact.degree().max(0) as usize;
217 match d {
218 0 => {
219 let cz = match fact.coefficients().first() {
220 Some(c) => c.clone(),
221 None => Integer::from(1),
222 };
223 if cz == 1 {
224 continue;
225 }
226 if cz == -1 {
227 if expo.rem_euclid(2) != 0 {
228 parts.push(pool.pow(pool.integer(-1_i32), delta_n));
229 }
230 continue;
231 }
232 let exp_e = pool.integer(expo);
233 parts.push(pool.pow(
234 pool.integer(cz.clone()),
235 simp(pool, pool.mul(vec![delta_n, exp_e])),
236 ));
237 }
238 1 => {
239 let coeffs = fact.coefficients();
240 let aa = coeffs.get(1).cloned().unwrap_or_else(|| Integer::from(0));
241 let bb = coeffs.first().cloned().unwrap_or_else(|| Integer::from(0));
242 if aa == 0 {
243 return Err(ProductError::NotRationalTerm("degenerate linear".into()));
244 }
245 let c_rat = Rational::from((bb, aa.clone()));
246 let one = Rational::from(1);
247 let hi_shift = rational_to_expr(pool, &(one.clone() + c_rat.clone()));
248 let lo_shift = rational_to_expr(pool, &c_rat);
249 let lead_exp = simp(pool, pool.mul(vec![delta_n, pool.integer(expo)]));
250 let gh = pool.func("gamma", vec![simp(pool, pool.add(vec![hi, hi_shift]))]);
251 let gl = pool.func("gamma", vec![simp(pool, pool.add(vec![lo, lo_shift]))]);
252 let ratio = simp(pool, pool.mul(vec![gh, pool.pow(gl, pool.integer(-1_i32))]));
253 parts.push(pool.pow(pool.integer(aa.clone()), lead_exp));
254 if expo != 0 {
255 parts.push(pool.pow(ratio, pool.integer(expo)));
256 }
257 }
258 _ => return Err(ProductError::NonLinearFactor),
259 }
260 }
261
262 match parts.len() {
263 0 => Ok(pool.integer(1_i32)),
264 1 => Ok(simp(pool, parts[0])),
265 _ => Ok(simp(pool, pool.mul(parts))),
266 }
267}
268
269fn indefinite_side_from_factorization(
271 pool: &ExprPool,
272 fac: &UniPolyFactorization,
273 k: ExprId,
274) -> Result<ExprId, ProductError> {
275 let mut parts: Vec<ExprId> = Vec::new();
276 let u = &fac.unit;
277 if u.to_i32() == Some(-1) {
278 parts.push(pool.pow(pool.integer(-1_i32), k));
279 } else if u.to_i32() != Some(1) {
280 parts.push(pool.pow(pool.integer(u.clone()), k));
281 }
282
283 for (fact, ee) in &fac.factors {
284 let expo = *ee as i64;
285 let d = fact.degree().max(0) as usize;
286 match d {
287 0 => {
288 let cz = match fact.coefficients().first() {
289 Some(c) => c.clone(),
290 None => Integer::from(1),
291 };
292 if cz == 1 {
293 continue;
294 }
295 if cz == -1 {
296 if expo.rem_euclid(2) != 0 {
297 parts.push(pool.pow(pool.integer(-1_i32), k));
298 }
299 continue;
300 }
301 let exp_e = pool.integer(expo);
302 parts.push(pool.pow(
303 pool.integer(cz.clone()),
304 simp(pool, pool.mul(vec![k, exp_e])),
305 ));
306 }
307 1 => {
308 let coeffs = fact.coefficients();
309 let aa = coeffs.get(1).cloned().unwrap_or_else(|| Integer::from(0));
310 let bb = coeffs.first().cloned().unwrap_or_else(|| Integer::from(0));
311 if aa == 0 {
312 return Err(ProductError::NotRationalTerm("degenerate linear".into()));
313 }
314 let c_rat = Rational::from((bb, aa.clone()));
315 let lo_shift = rational_to_expr(pool, &c_rat);
316 let gamma_k = pool.func("gamma", vec![simp(pool, pool.add(vec![k, lo_shift]))]);
317 let lead_exp_k = simp(pool, pool.mul(vec![k, pool.integer(expo)]));
318 parts.push(pool.pow(pool.integer(aa), lead_exp_k));
319 parts.push(pool.pow(gamma_k, pool.integer(expo)));
320 }
321 _ => return Err(ProductError::NonLinearFactor),
322 }
323 }
324
325 match parts.len() {
326 0 => Ok(pool.integer(1_i32)),
327 1 => Ok(simp(pool, parts[0])),
328 _ => Ok(simp(pool, pool.mul(parts))),
329 }
330}
331
332pub fn product_definite(
334 term: ExprId,
335 k: ExprId,
336 lo: ExprId,
337 hi: ExprId,
338 pool: &ExprPool,
339) -> Result<DerivedExpr<ExprId>, ProductError> {
340 let rf = expr_to_ratfunc(term, k, pool)?;
341 if rf.num.is_zero() {
342 let z = simp(pool, pool.integer(0_i32));
343 let mut log = DerivationLog::new();
344 log.push(RewriteStep::simple("product_definite_zero", term, z));
345 return Ok(DerivedExpr::with_log(z, log));
346 }
347
348 let univ_n = ratuni_poly_to_univ(&rf.num, k)?;
349 let univ_d = ratuni_poly_to_univ(&rf.den, k)?;
350 let fac_n = factor_univ(&univ_n)?;
351 let fac_d = factor_univ(&univ_d)?;
352
353 let one = pool.integer(1_i32);
354 let delta_n = simp(
355 pool,
356 pool.add(vec![hi, pool.mul(vec![lo, pool.integer(-1)]), one]),
357 );
358
359 let top = definite_side_from_factorization(pool, &fac_n, lo, hi, delta_n)?;
360 let bot = definite_side_from_factorization(pool, &fac_d, lo, hi, delta_n)?;
361 let q = simp(
362 pool,
363 pool.mul(vec![top, pool.pow(bot, pool.integer(-1_i32))]),
364 );
365
366 let mut log = DerivationLog::new();
367 log.push(RewriteStep::simple("product_definite", term, q));
368 Ok(DerivedExpr::with_log(q, log))
369}
370
371pub fn product_indefinite(
373 term: ExprId,
374 k: ExprId,
375 pool: &ExprPool,
376) -> Result<DerivedExpr<ExprId>, ProductError> {
377 let rf = expr_to_ratfunc(term, k, pool)?;
378 if rf.num.is_zero() {
379 return Err(ProductError::NotRationalTerm(
380 "indefinite product of zero unsupported".into(),
381 ));
382 }
383 let fac_n = factor_univ(&ratuni_poly_to_univ(&rf.num, k)?)?;
384 let fac_d = factor_univ(&ratuni_poly_to_univ(&rf.den, k)?)?;
385
386 let top = indefinite_side_from_factorization(pool, &fac_n, k)?;
387 let bot = indefinite_side_from_factorization(pool, &fac_d, k)?;
388
389 let q = simp(
390 pool,
391 pool.mul(vec![top, pool.pow(bot, pool.integer(-1_i32))]),
392 );
393
394 let mut log = DerivationLog::new();
395 log.push(RewriteStep::simple("product_indefinite", term, q));
396 Ok(DerivedExpr::with_log(q, log))
397}
398
399#[cfg(test)]
400mod tests {
401 use super::*;
402 use crate::jit::eval_interp;
403 use crate::kernel::Domain;
404 use rug::Float;
405 use std::collections::HashMap;
406
407 fn gamma64(x: f64) -> f64 {
408 Float::with_val(53, x).gamma().to_f64()
409 }
410
411 fn eval_g(expr: ExprId, env: &HashMap<ExprId, f64>, pool: &ExprPool) -> Option<f64> {
412 match pool.get(expr).clone() {
413 ExprData::Func { name, args } if name == "gamma" && args.len() == 1 => {
414 Some(gamma64(eval_g(args[0], env, pool)?))
415 }
416 ExprData::Add(args) => {
417 let mut s = 0.0f64;
418 for &a in &args {
419 s += eval_g(a, env, pool)?;
420 }
421 Some(s)
422 }
423 ExprData::Mul(args) => {
424 let mut p = 1.0f64;
425 for a in args {
426 p *= eval_g(a, env, pool)?;
427 }
428 Some(p)
429 }
430 ExprData::Pow { base, exp } => {
431 Some(eval_g(base, env, pool)?.powf(eval_interp(exp, env, pool)?))
432 }
433 _ => eval_interp(expr, env, pool),
434 }
435 }
436
437 #[test]
438 fn product_linear_k_matches_factorial_gamma() {
439 let pool = ExprPool::new();
440 let k = pool.symbol("k", Domain::Real);
441 let n = pool.symbol("n", Domain::Real);
442 let lo = pool.integer(1_i32);
443 let p = product_definite(k, k, lo, n, &pool).expect("prod");
444 let want = simp(
445 &pool,
446 pool.func(
447 "gamma",
448 vec![simp(&pool, pool.add(vec![n, pool.integer(1)]))],
449 ),
450 );
451 for ni in 2..14 {
452 let mut env = HashMap::new();
453 env.insert(n, ni as f64);
454 let pv = eval_g(p.value, &env, &pool).unwrap();
455 let wv = eval_g(want, &env, &pool).unwrap();
456 assert!(
457 (pv - wv).abs() < 1e-6 * wv.abs().max(1.0),
458 "n={ni}: pv={pv} wv={wv}"
459 );
460 }
461 }
462
463 #[test]
464 fn wallis_partial_product_ratios() {
465 let pool = ExprPool::new();
466 let k = pool.symbol("k", Domain::Real);
467 let n = pool.symbol("n", Domain::Real);
468 let two = pool.integer(2_i32);
469 let km1 = simp(&pool, pool.add(vec![k, pool.integer(-1)]));
470 let kp1 = simp(&pool, pool.add(vec![k, pool.integer(1)]));
471 let k2 = simp(&pool, pool.pow(k, pool.integer(2)));
472 let term = simp(
473 &pool,
474 pool.mul(vec![
475 simp(&pool, pool.mul(vec![km1, kp1])),
476 pool.pow(k2, pool.integer(-1)),
477 ]),
478 );
479
480 let p = product_definite(term, k, two, n, &pool).expect("wallis");
481 for ni in 3..36 {
482 let mut env = HashMap::new();
483 env.insert(n, ni as f64);
484 let pv = eval_g(p.value, &env, &pool).unwrap();
485 let want = (ni + 1) as f64 / (2.0 * ni as f64);
486 assert!(
487 (pv - want).abs() < 1e-5 * want.max(1.0),
488 "n={}: got {}",
489 ni,
490 pv
491 );
492 }
493 }
494}