1#![allow(clippy::needless_range_loop)]
7
8use crate::kernel::subs::subs;
9use crate::kernel::{ExprData, ExprId, ExprPool};
10use crate::matrix::normal_form::RatUniPoly;
11use crate::poly::unipoly::UniPoly;
12use crate::simplify::engine::simplify;
13use rug::{Integer, Rational};
14use std::collections::{BTreeMap, HashMap};
15use std::fmt;
16
17fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
18 simplify(e, pool).value
19}
20
21#[inline]
23fn rational_eq_one(r: &Rational) -> bool {
24 !r.is_zero() && r.numer() == r.denom()
25}
26
/// Errors reported by the recurrence solver in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RsolveError {
    /// The equation could not be read as a linear recurrence; the payload is
    /// the human-readable reason.
    NotLinearRecurrence(String),
    /// A term contains the sequence more than once, or a sequence application
    /// multiplied by a non-numeric factor.
    NonlinearTerm,
    /// The non-sequence part of the equation could not be converted to a
    /// polynomial; the payload is the converter's error text.
    NonPolynomialRhs(String),
    /// The recurrence was recognised but this solver has no method for it.
    Unsupported(String),
    /// The supplied initial values cannot be applied; the payload explains why.
    InitialMismatch(String),
}
41
42impl fmt::Display for RsolveError {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 RsolveError::NotLinearRecurrence(s) => write!(f, "rsolve: {s}"),
46 RsolveError::NonlinearTerm => write!(f, "rsolve: nonlinear term in sequence variable"),
47 RsolveError::NonPolynomialRhs(s) => write!(f, "rsolve: non-polynomial rhs: {s}"),
48 RsolveError::Unsupported(s) => write!(f, "rsolve: unsupported: {s}"),
49 RsolveError::InitialMismatch(s) => write!(f, "rsolve: initial values: {s}"),
50 }
51 }
52}
53
// Marker impl: `RsolveError` relies on the default `std::error::Error` methods.
impl std::error::Error for RsolveError {}
55
56impl crate::errors::AlkahestError for RsolveError {
57 fn code(&self) -> &'static str {
58 match self {
59 RsolveError::NotLinearRecurrence(_) => "E-RSOLVE-001",
60 RsolveError::NonlinearTerm => "E-RSOLVE-002",
61 RsolveError::NonPolynomialRhs(_) => "E-RSOLVE-003",
62 RsolveError::Unsupported(_) => "E-RSOLVE-004",
63 RsolveError::InitialMismatch(_) => "E-RSOLVE-005",
64 }
65 }
66
67 fn remediation(&self) -> Option<&'static str> {
68 Some(
69 "use pool.func(name, [n + integer]) for shifts; keep coefficients rational and rhs polynomial in n",
70 )
71 }
72}
73
74fn rational_atom(pool: &ExprPool, r: &Rational) -> ExprId {
75 let numer = r.numer().clone();
76 let denom = r.denom().clone();
77 if denom == 1 {
78 pool.integer(numer)
79 } else {
80 pool.rational(numer, denom)
81 }
82}
83
84fn expr_div(pool: &ExprPool, num: ExprId, den: ExprId) -> ExprId {
85 pool.mul(vec![num, pool.pow(den, pool.integer(-1_i32))])
86}
87
88fn flatten_add(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
89 match pool.get(expr) {
90 ExprData::Add(args) => args.iter().flat_map(|&x| flatten_add(x, pool)).collect(),
91 _ => vec![expr],
92 }
93}
94
95fn flatten_mul(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
96 match pool.get(expr) {
97 ExprData::Mul(args) => args.iter().flat_map(|&x| flatten_mul(x, pool)).collect(),
98 _ => vec![expr],
99 }
100}
101
102fn linear_in_sym(expr: ExprId, sym: ExprId, pool: &ExprPool) -> Option<(Rational, Rational)> {
103 let e = simp(pool, expr);
104 if e == sym {
105 return Some((Rational::from(1), Rational::from(0)));
106 }
107 match pool.get(e) {
108 ExprData::Integer(n) => Some((Rational::from(0), Rational::from(n.0.clone()))),
109 ExprData::Rational(r) => Some((Rational::from(0), r.0.clone())),
110 ExprData::Add(args) => {
111 let mut a = Rational::from(0);
112 let mut b = Rational::from(0);
113 for t in args {
114 if t == sym {
115 a += Rational::from(1);
116 } else if let Some((a0, b0)) = linear_in_sym(t, sym, pool) {
117 a += a0;
118 b += b0;
119 } else {
120 return None;
121 }
122 }
123 Some((a, b))
124 }
125 ExprData::Mul(args) => {
126 if args.len() == 2 && args[0] == sym {
127 match pool.get(args[1]) {
128 ExprData::Integer(n) => Some((Rational::from(n.0.clone()), Rational::from(0))),
129 ExprData::Rational(r) => Some((r.0.clone(), Rational::from(0))),
130 _ => None,
131 }
132 } else if args.len() == 2 && args[1] == sym {
133 match pool.get(args[0]) {
134 ExprData::Integer(n) => Some((Rational::from(n.0.clone()), Rational::from(0))),
135 ExprData::Rational(r) => Some((r.0.clone(), Rational::from(0))),
136 _ => None,
137 }
138 } else {
139 None
140 }
141 }
142 ExprData::Pow { base, exp } => {
143 if base == sym {
144 match pool.get(exp) {
145 ExprData::Integer(n) if n.0 == 1 => {
146 Some((Rational::from(1), Rational::from(0)))
147 }
148 _ => None,
149 }
150 } else {
151 None
152 }
153 }
154 _ => None,
155 }
156}
157
158fn offset_in_n(arg: ExprId, n: ExprId, pool: &ExprPool) -> Result<i64, RsolveError> {
159 let (coef, c) = linear_in_sym(arg, n, pool).ok_or_else(|| {
160 RsolveError::NotLinearRecurrence(
161 "sequence index must be an affine integer shift of the recurrence variable".into(),
162 )
163 })?;
164 if coef != 1 {
165 return Err(RsolveError::NotLinearRecurrence(
166 "recurrence variable must appear with coefficient 1 in each index".into(),
167 ));
168 }
169 let num = c.numer();
170 let den = c.denom();
171 if num.clone() % den.clone() == 0 {
172 let q = Integer::from(num / den);
173 Ok(q.to_i64().unwrap_or(i64::MIN))
174 } else {
175 Err(RsolveError::NotLinearRecurrence(
176 "index shift must be an integer".into(),
177 ))
178 }
179}
180
181fn contains_seq(expr: ExprId, seq_name: &str, pool: &ExprPool) -> bool {
182 match pool.get(expr) {
183 ExprData::Func { name, args } => {
184 if name == seq_name {
185 return true;
186 }
187 args.iter().any(|&a| contains_seq(a, seq_name, pool))
188 }
189 ExprData::Add(xs) => xs.iter().any(|&a| contains_seq(a, seq_name, pool)),
190 ExprData::Mul(xs) => xs.iter().any(|&a| contains_seq(a, seq_name, pool)),
191 ExprData::Pow { base, exp } => {
192 contains_seq(base, seq_name, pool) || contains_seq(exp, seq_name, pool)
193 }
194 _ => false,
195 }
196}
197
/// Classification of one additive term of the equation, as produced by
/// `peel_addend`.
enum Peeled {
    /// A rational multiple of a single sequence application `f(n + offset)`.
    Seq { coeff: Rational, offset: i64 },
    /// A term that contains no top-level sequence factor.
    Other(ExprId),
}
202
203fn peel_addend(
204 term: ExprId,
205 seq_name: &str,
206 n: ExprId,
207 pool: &ExprPool,
208) -> Result<Peeled, RsolveError> {
209 let factors = flatten_mul(term, pool);
210 let mut rat = Rational::from(1);
211 let mut seq_off: Option<i64> = None;
212 let mut rest: Vec<ExprId> = Vec::new();
213
214 for g in factors {
215 match pool.get(g) {
216 ExprData::Integer(nn) => {
217 rat *= Rational::from(nn.0.clone());
218 }
219 ExprData::Rational(rr) => {
220 rat *= rr.0.clone();
221 }
222 ExprData::Func { name, args } if name == seq_name => {
223 if args.len() != 1 {
224 return Err(RsolveError::NotLinearRecurrence(
225 "sequence applications must have exactly one index argument".into(),
226 ));
227 }
228 if seq_off.is_some() {
229 return Err(RsolveError::NonlinearTerm);
230 }
231 seq_off = Some(offset_in_n(args[0], n, pool)?);
232 }
233 _ => rest.push(g),
234 }
235 }
236
237 match (seq_off, rest.is_empty()) {
238 (Some(o), true) => Ok(Peeled::Seq {
239 coeff: rat,
240 offset: o,
241 }),
242 (None, _) => {
243 let rhs = if rest.is_empty() {
244 rational_atom(pool, &rat)
245 } else if rest.len() == 1 {
246 if rat == 1 {
247 rest[0]
248 } else {
249 simp(pool, pool.mul(vec![rational_atom(pool, &rat), rest[0]]))
250 }
251 } else {
252 let mut v = rest;
253 if rat != 1 {
254 v.insert(0, rational_atom(pool, &rat));
255 }
256 simp(pool, pool.mul(v))
257 };
258 Ok(Peeled::Other(rhs))
259 }
260 (Some(_), false) => Err(RsolveError::NonlinearTerm),
261 }
262}
263
264fn extract_recurrence(
266 equation: ExprId,
267 seq_name: &str,
268 n: ExprId,
269 pool: &ExprPool,
270) -> Result<(Vec<Rational>, RatUniPoly), RsolveError> {
271 let zero = simp(pool, equation);
272 let parts = flatten_add(zero, pool);
273 let mut by_shift: BTreeMap<i64, Rational> = BTreeMap::new();
274 let mut rhs_terms: Vec<ExprId> = Vec::new();
275
276 for p in parts {
277 match peel_addend(p, seq_name, n, pool)? {
278 Peeled::Seq { coeff, offset } => {
279 *by_shift.entry(offset).or_insert(Rational::from(0)) += coeff;
280 }
281 Peeled::Other(e) => rhs_terms.push(e),
282 }
283 }
284
285 if by_shift.is_empty() {
286 return Err(RsolveError::NotLinearRecurrence(
287 "no sequence term in equation".into(),
288 ));
289 }
290
291 let max_o = *by_shift.keys().max().unwrap();
292 let mut shifts: BTreeMap<i64, Rational> = BTreeMap::new();
293 for (&o, c) in &by_shift {
294 let lag = max_o - o;
295 *shifts.entry(lag).or_insert(Rational::from(0)) += c;
296 }
297
298 let d = *shifts.keys().max().unwrap() as usize;
299 let mut a = vec![Rational::from(0); d + 1];
300 for (&k, v) in &shifts {
301 a[k as usize] = v.clone();
302 }
303
304 if a[0] == 0 {
305 return Err(RsolveError::NotLinearRecurrence(
306 "leading coefficient of f(n) vanishes after normalization".into(),
307 ));
308 }
309
310 let rhs_expr = if rhs_terms.is_empty() {
311 pool.integer(0_i32)
312 } else {
313 let s = simp(pool, pool.add(rhs_terms));
314 simp(pool, pool.mul(vec![s, pool.integer(-1_i32)]))
315 };
316
317 if contains_seq(rhs_expr, seq_name, pool) {
318 return Err(RsolveError::NotLinearRecurrence(
319 "right-hand side still references the sequence".into(),
320 ));
321 }
322
323 let rhs_poly = match UniPoly::from_symbolic_clear_denoms(rhs_expr, n, pool) {
324 Ok(p) => {
325 let cs: Vec<Rational> = p.coefficients().into_iter().map(Rational::from).collect();
326 RatUniPoly { coeffs: cs }.trim()
327 }
328 Err(e) => {
329 return Err(RsolveError::NonPolynomialRhs(e.to_string()));
330 }
331 };
332
333 Ok((a, rhs_poly))
334}
335
336fn binom(n: u32, k: u32) -> Integer {
337 if k > n {
338 return Integer::from(0);
339 }
340 let mut acc = Integer::from(1);
341 for i in 0..k {
342 acc *= Integer::from(n - i);
343 acc /= Integer::from(i + 1);
344 }
345 acc
346}
347
348fn shift_x_sub_m(deg: u32, m: i64) -> RatUniPoly {
350 if deg == 0 {
351 return RatUniPoly::one();
352 }
353 let mut coeffs = vec![Rational::from(0); (deg + 1) as usize];
354 let mm = Rational::from(m);
355 for k in 0..=deg {
356 let mut term = Rational::from(binom(deg, k));
357 if (deg - k) % 2 == 1 {
358 term = -term;
359 }
360 for _ in 0..(deg - k) {
361 term *= mm.clone();
362 }
363 coeffs[k as usize] = term;
364 }
365 RatUniPoly { coeffs }.trim()
366}
367
368fn poly_apply_order1_shift(r: &Rational, p: &RatUniPoly) -> RatUniPoly {
372 let mut out = RatUniPoly::zero();
373 for (deg, c) in p.coeffs.iter().enumerate() {
374 if c.is_zero() {
375 continue;
376 }
377 let mut mon = vec![Rational::from(0); deg + 1];
378 mon[deg] = c.clone();
379 let n_poly = RatUniPoly { coeffs: mon }.trim();
380 let shifted = shift_x_sub_m(deg as u32, 1);
381 let sub = &RatUniPoly::constant(r.clone()) * &shifted;
382 out = &out + &(&n_poly - &sub);
383 }
384 out.trim()
385}
386
387fn poly_apply_order2(a0: &Rational, a1: &Rational, a2: &Rational, p: &RatUniPoly) -> RatUniPoly {
388 let mut out = RatUniPoly::zero();
389 for (deg, coeff) in p.coeffs.iter().enumerate() {
390 if coeff.is_zero() {
391 continue;
392 }
393 let mut mon = vec![Rational::from(0); deg + 1];
394 mon[deg] = coeff.clone();
395 let n_poly = RatUniPoly { coeffs: mon }.trim();
396 let p1 = shift_x_sub_m(deg as u32, 1);
397 let p2 = shift_x_sub_m(deg as u32, 2);
398 let term = &(&(&RatUniPoly::constant(a0.clone()) * &n_poly)
399 + &(&RatUniPoly::constant(a1.clone()) * &p1))
400 + &(&RatUniPoly::constant(a2.clone()) * &p2);
401 out = &out + &term;
402 }
403 out.trim()
404}
405
406fn mono_n(j: usize) -> RatUniPoly {
407 let mut c = vec![Rational::from(0); j + 1];
408 c[j] = Rational::from(1);
409 RatUniPoly { coeffs: c }.trim()
410}
411
412fn solve_rational_linear_system(
413 mut a: Vec<Vec<Rational>>,
414 mut b: Vec<Rational>,
415) -> Option<Vec<Rational>> {
416 let n = b.len();
417 debug_assert_eq!(a.len(), n);
418 for col in 0..n {
419 let mut pivot = None;
420 for row in col..n {
421 if !a[row][col].is_zero() {
422 pivot = Some(row);
423 break;
424 }
425 }
426 let pr = pivot?;
427 if pr != col {
428 a.swap(col, pr);
429 b.swap(col, pr);
430 }
431 let div = a[col][col].clone();
432 if div.is_zero() {
433 return None;
434 }
435 let inv = Rational::from(1) / div.clone();
436 for j in col..n {
437 a[col][j] *= inv.clone();
438 }
439 b[col] *= inv;
440 for row in 0..n {
441 if row == col {
442 continue;
443 }
444 let f = a[row][col].clone();
445 if f.is_zero() {
446 continue;
447 }
448 for j in col..n {
449 let sub = f.clone() * a[col][j].clone();
450 a[row][j] -= sub;
451 }
452 let bcol = b[col].clone();
453 b[row] -= f * bcol;
454 }
455 }
456 Some(b)
457}
458
459fn undetermined_order1(r: &Rational, h: &RatUniPoly) -> Option<RatUniPoly> {
460 let dh = h.degree().max(0) as usize;
461 let start_deg = if rational_eq_one(r) { 1 } else { 0 };
462 for bump in 0..24 {
463 let hi_deg = (dh + bump + usize::from(rational_eq_one(r))).max(start_deg);
464 if hi_deg > 40 {
465 break;
466 }
467 let u = hi_deg.saturating_sub(start_deg) + 1;
468 let mut mat = vec![vec![Rational::from(0); u]; u];
469 let mut rhs = vec![Rational::from(0); u];
470 for row in 0..u {
471 for j in 0..u {
472 let pow = start_deg + j;
473 let basis = mono_n(pow);
474 let applied = poly_apply_order1_shift(r, &basis);
475 mat[row][j] = applied
476 .coeffs
477 .get(row)
478 .cloned()
479 .unwrap_or_else(|| Rational::from(0));
480 }
481 rhs[row] = h
482 .coeffs
483 .get(row)
484 .cloned()
485 .unwrap_or_else(|| Rational::from(0));
486 }
487 if let Some(x) = solve_rational_linear_system(mat, rhs) {
488 let mut coeffs = vec![Rational::from(0); hi_deg + 1];
489 for (j, coeff) in x.into_iter().enumerate() {
490 coeffs[start_deg + j] = coeff;
491 }
492 return Some(RatUniPoly { coeffs }.trim());
493 }
494 }
495 None
496}
497
498fn undetermined_order2(
499 a0: &Rational,
500 a1: &Rational,
501 a2: &Rational,
502 h: &RatUniPoly,
503) -> Option<RatUniPoly> {
504 let dh = h.degree().max(0) as usize;
505 for bump in 0..24 {
506 let trial_deg = (dh + 4 + bump).min(42);
507 let u = trial_deg + 1;
508 let mut mat = vec![vec![Rational::from(0); u]; u];
509 let mut rhs = vec![Rational::from(0); u];
510 for row in 0..u {
511 for j in 0..u {
512 let basis = mono_n(j);
513 let applied = poly_apply_order2(a0, a1, a2, &basis);
514 mat[row][j] = applied
515 .coeffs
516 .get(row)
517 .cloned()
518 .unwrap_or_else(|| Rational::from(0));
519 }
520 rhs[row] = h
521 .coeffs
522 .get(row)
523 .cloned()
524 .unwrap_or_else(|| Rational::from(0));
525 }
526 if let Some(x) = solve_rational_linear_system(mat, rhs) {
527 return Some(RatUniPoly { coeffs: x }.trim());
528 }
529 }
530 None
531}
532
533fn rat_poly_to_expr(pool: &ExprPool, n_sym: ExprId, p: &RatUniPoly) -> ExprId {
534 let mut terms: Vec<ExprId> = Vec::new();
535 for (deg, coeff) in p.coeffs.iter().enumerate() {
536 if coeff.is_zero() {
537 continue;
538 }
539 let coeff_q = coeff.clone();
540 let numer = coeff_q.numer();
541 let denom = coeff_q.denom();
542 let coeff_expr = if *denom == 1 {
543 pool.integer(numer.clone())
544 } else {
545 pool.rational(numer.clone(), denom.clone())
546 };
547 let pow_id = if deg == 0 {
548 coeff_expr
549 } else if deg == 1 {
550 pool.mul(vec![coeff_expr, n_sym])
551 } else {
552 pool.mul(vec![coeff_expr, pool.pow(n_sym, pool.integer(deg as i64))])
553 };
554 terms.push(pow_id);
555 }
556 match terms.len() {
557 0 => pool.integer(0_i32),
558 1 => terms[0],
559 _ => pool.add(terms),
560 }
561}
562
563fn sqrt_disc_expr(pool: &ExprPool, disc: &Rational) -> ExprId {
564 let num = disc.numer().clone();
565 let den = disc.denom().clone();
566 let prod = num * den.clone();
567 let inner = pool.integer(prod);
568 let sqrt_e = pool.func("sqrt", vec![inner]);
569 let den_e = pool.integer(den);
570 expr_div(pool, sqrt_e, den_e)
571}
572
573fn char_poly_asc(a: &[Rational]) -> RatUniPoly {
575 let d = a.len() - 1;
576 let mut v = vec![Rational::from(0); d + 1];
577 for i in 0..=d {
578 v[i] = a[d - i].clone();
579 }
580 RatUniPoly { coeffs: v }.trim()
581}
582
583fn horner_rat(p: &RatUniPoly, x: &Rational) -> Rational {
584 let mut acc = Rational::from(0);
585 for c in p.coeffs.iter().rev() {
586 acc = acc * x.clone() + c.clone();
587 }
588 acc
589}
590
591fn divisors_int(mut n: Integer) -> Vec<Integer> {
592 if n < 0 {
593 n = -n;
594 }
595 if n == 0 {
596 return vec![Integer::from(1)];
597 }
598 let mut out = vec![Integer::from(1)];
599 let mut i = Integer::from(2);
600 while i.clone() * i.clone() <= n {
601 if n.clone() % i.clone() == 0 {
602 let mut pws = vec![Integer::from(1)];
603 let mut nn = n.clone();
604 while nn.clone() % i.clone() == 0 {
605 let last = pws.last().unwrap().clone();
606 pws.push(last * i.clone());
607 nn /= i.clone();
608 }
609 n = nn;
610 let old = out.clone();
611 out.clear();
612 for base in old {
613 for pw in &pws {
614 out.push(base.clone() * pw);
615 }
616 }
617 }
618 i += 1;
619 }
620 if n > 1 {
621 let old = out.clone();
622 out.clear();
623 for base in old {
624 out.push(base.clone());
625 out.push(base * n.clone());
626 }
627 }
628 out.sort();
629 out.dedup();
630 out
631}
632
633fn peel_rational_root(p: &RatUniPoly) -> Option<Rational> {
634 if p.is_zero() {
635 return None;
636 }
637 let mut z: Vec<Integer> = Vec::new();
638 let mut lcm_den = Integer::from(1);
639 for c in &p.coeffs {
640 lcm_den = lcm_den.lcm(&c.denom().clone());
641 }
642 for c in &p.coeffs {
643 let d = c.denom().clone();
644 let scale = lcm_den.clone() / d;
645 z.push(scale * c.numer().clone());
646 }
647 let lc = z.last().cloned().unwrap_or_else(|| Integer::from(0));
648 let c0 = z.first().cloned().unwrap_or_else(|| Integer::from(0));
649 if lc.is_zero() {
650 return Some(Rational::from(0));
651 }
652 let mut try_vals: Vec<Rational> = Vec::new();
653 for pd in divisors_int(lc.clone()) {
654 for qd in divisors_int(c0.clone()) {
655 try_vals.push(Rational::from((pd.clone(), qd.clone())));
656 try_vals.push(-Rational::from((pd.clone(), qd)));
657 }
658 }
659 try_vals.sort_by(|x, y| x.partial_cmp(y).unwrap());
660 try_vals.dedup();
661 try_vals
662 .into_iter()
663 .find(|r| !p.coeffs.is_empty() && horner_rat(p, r).is_zero())
664}
665
666fn div_linear_factor(p: RatUniPoly, root: &Rational) -> RatUniPoly {
667 let r = root.clone();
668 let lin = RatUniPoly {
669 coeffs: vec![-r, Rational::from(1)],
670 }
671 .trim();
672 let (q, rem) = RatUniPoly::div_rem(&p, &lin);
673 debug_assert!(rem.is_zero());
674 q
675}
676
/// Factors the characteristic polynomial `p` into rational roots with
/// multiplicities.
///
/// Rational roots are peeled off one at a time (at most 64 rounds); whatever
/// is left is accepted only if it is constant, linear, or a quadratic whose
/// roots are rational.
///
/// # Errors
/// Returns `RsolveError::Unsupported` when the leftover quadratic has a
/// non-square or negative discriminant, or when the leftover has degree
/// above 2.
fn factor_char_polynomial(mut p: RatUniPoly) -> Result<Vec<(Rational, usize)>, RsolveError> {
    let mut roots: Vec<(Rational, usize)> = Vec::new();
    // Bounds the number of peeling rounds.
    let mut guard = 0usize;
    while p.degree() > 0 && guard < 64 {
        guard += 1;
        let Some(r0) = peel_rational_root(&p) else {
            break;
        };
        // Divide out (x - r0) for as long as r0 remains a root, counting the
        // multiplicity.
        let mut m = 0usize;
        while p.degree() > 0 && horner_rat(&p, &r0).is_zero() {
            p = div_linear_factor(p, &r0);
            m += 1;
        }
        roots.push((r0, m));
    }
    // Handle whatever the peeling loop could not reduce.
    match p.degree() {
        // Zero or constant polynomial: nothing left to factor.
        -1 | 0 => Ok(roots),
        1 => {
            let c0 = p.coeffs[0].clone();
            let c1 = p.coeffs[1].clone();
            if c1 == 0 {
                return Err(RsolveError::Unsupported("degenerate characteristic".into()));
            }
            // Root of c1 * x + c0.
            roots.push((-c0 / c1, 1));
            Ok(roots)
        }
        2 => {
            let c0 = p.coeffs[0].clone();
            let c1 = p.coeffs[1].clone();
            let c2 = p.coeffs[2].clone();
            if c2 == 0 {
                return Err(RsolveError::Unsupported(
                    "characteristic degree mismatch".into(),
                ));
            }
            // Quadratic formula on c2 * x^2 + c1 * x + c0.
            let disc = c1.clone() * c1.clone() - Rational::from(4) * c2.clone() * c0.clone();
            if disc == 0 {
                // Double root.
                let r = -c1 / (Rational::from(2) * c2.clone());
                roots.push((r, 2));
            } else if disc > 0 {
                // The roots are rational only if both the numerator and the
                // denominator of the discriminant are perfect squares.
                let disc_numer = disc.numer().clone();
                let disc_denom = disc.denom().clone();
                let (sn, rem_n) = disc_numer.sqrt_rem(Integer::new());
                let (sd, rem_d) = disc_denom.sqrt_rem(Integer::new());
                if rem_n != 0 || rem_d != 0 {
                    return Err(RsolveError::Unsupported(
                        "order-3+ with irreducible quadratic characteristic tail".into(),
                    ));
                }
                let sqrt_d = Rational::from((sn, sd));
                let r1 = (-c1.clone() + sqrt_d.clone()) / (Rational::from(2) * c2.clone());
                let r2 = (-c1 - sqrt_d) / (Rational::from(2) * c2.clone());
                roots.push((r1, 1));
                roots.push((r2, 1));
            } else {
                return Err(RsolveError::Unsupported(
                    "complex characteristic roots".into(),
                ));
            }
            Ok(roots)
        }
        d => Err(RsolveError::Unsupported(format!(
            "characteristic leftover degree {d}"
        ))),
    }
}
743
744fn hom_solution_from_roots(
745 pool: &ExprPool,
746 n_sym: ExprId,
747 root_facts: &[(Rational, usize)],
748 c_syms: &[ExprId],
749) -> Result<ExprId, RsolveError> {
750 let mut terms: Vec<ExprId> = Vec::new();
751 let mut idx = 0;
752 for (r, mult) in root_facts {
753 let re = rational_atom(pool, r);
754 for p in 0..*mult {
755 if idx >= c_syms.len() {
756 return Err(RsolveError::Unsupported(
757 "internal: not enough constant symbols".into(),
758 ));
759 }
760 let basis = if p == 0 {
761 simp(pool, pool.pow(re, n_sym))
762 } else {
763 let np = pool.pow(n_sym, pool.integer(p as i64));
764 simp(pool, pool.mul(vec![np, pool.pow(re, n_sym)]))
765 };
766 terms.push(simp(pool, pool.mul(vec![c_syms[idx], basis])));
767 idx += 1;
768 }
769 }
770 if idx != c_syms.len() {
771 return Err(RsolveError::Unsupported(
772 "internal: constant count mismatch".into(),
773 ));
774 }
775 match terms.len() {
776 0 => Ok(pool.integer(0_i32)),
777 1 => Ok(terms[0]),
778 _ => Ok(simp(pool, pool.add(terms))),
779 }
780}
781
782fn order2_r_exprs(pool: &ExprPool, a_rec: &[Rational]) -> Result<(ExprId, ExprId), RsolveError> {
783 let p = char_poly_asc(a_rec);
784 if p.degree() != 2 {
785 return Err(RsolveError::Unsupported(
786 "expected order-2 characteristic".into(),
787 ));
788 }
789 let p0 = p.coeffs[0].clone();
790 let p1 = p.coeffs[1].clone();
791 let p2 = p.coeffs[2].clone();
792 if p2 == 0 {
793 return Err(RsolveError::Unsupported("degenerate characteristic".into()));
794 }
795 let b = p1 / p2.clone();
796 let c = p0 / p2.clone();
797 let disc = b.clone() * b.clone() - Rational::from(4) * c.clone();
798 if disc < 0 {
799 return Err(RsolveError::Unsupported("complex roots".into()));
800 }
801 let sqrt_e = sqrt_disc_expr(pool, &disc);
802 let neg_b = rational_atom(pool, &(-b.clone()));
803 let half = rational_atom(pool, &Rational::from((1, 2)));
804 let inner1 = simp(pool, pool.add(vec![neg_b, sqrt_e]));
805 let r1 = simp(pool, pool.mul(vec![half, inner1]));
806 let inner2 = simp(
807 pool,
808 pool.add(vec![neg_b, pool.mul(vec![sqrt_e, pool.integer(-1_i32)])]),
809 );
810 let r2 = simp(pool, pool.mul(vec![half, inner2]));
811 Ok((r1, r2))
812}
813
814fn fresh_constants(pool: &ExprPool, k: usize) -> Vec<ExprId> {
815 (0..k)
816 .map(|i: usize| pool.symbol(format!("C{}", i), crate::kernel::Domain::Real))
817 .collect()
818}
819
820fn subs_n_int(pool: &ExprPool, expr: ExprId, n_sym: ExprId, ni: i64) -> ExprId {
821 let mut m = HashMap::new();
822 m.insert(n_sym, pool.integer(ni));
823 simp(pool, subs(expr, &m, pool))
824}
825
826#[allow(clippy::too_many_arguments)]
827fn apply_init(
828 pool: &ExprPool,
829 general: ExprId,
830 n_sym: ExprId,
831 c_syms: &[ExprId],
832 initials: &BTreeMap<i64, ExprId>,
833 d: usize,
834 a: &[Rational],
835 particular: ExprId,
836) -> Result<ExprId, RsolveError> {
837 if initials.len() != d {
838 return Err(RsolveError::InitialMismatch(format!(
839 "need {d} initial values for order {d}, got {}",
840 initials.len()
841 )));
842 }
843
844 if d == 1 {
845 let (&n0, v0) = initials.first_key_value().unwrap();
846 let r = (-a[1].clone()) / a[0].clone();
847 let r_e = rational_atom(pool, &r);
848 let p0 = subs_n_int(pool, particular, n_sym, n0);
849 let rpow = simp(pool, pool.pow(r_e, pool.integer(n0)));
850 let rhs = simp(
851 pool,
852 pool.add(vec![*v0, pool.mul(vec![p0, pool.integer(-1_i32)])]),
853 );
854 let c0v = expr_div(pool, rhs, rpow);
855 let mut m = HashMap::new();
856 m.insert(c_syms[0], c0v);
857 return Ok(simp(pool, subs(general, &m, pool)));
858 }
859
860 if d == 2 {
861 let keys: Vec<i64> = initials.keys().copied().collect();
862 if keys.len() != 2 {
863 return Err(RsolveError::InitialMismatch("need two integers".into()));
864 }
865 let (n0, n1) = (keys[0], keys[1]);
866 let (r1_e, r2_e) = order2_r_exprs(pool, a)?;
867 let v0 = *initials.get(&n0).unwrap();
868 let v1 = *initials.get(&n1).unwrap();
869 let p0 = subs_n_int(pool, particular, n_sym, n0);
870 let p1 = subs_n_int(pool, particular, n_sym, n1);
871 let v0p = simp(
872 pool,
873 pool.add(vec![v0, pool.mul(vec![p0, pool.integer(-1_i32)])]),
874 );
875 let v1p = simp(
876 pool,
877 pool.add(vec![v1, pool.mul(vec![p1, pool.integer(-1_i32)])]),
878 );
879 let a00 = simp(pool, pool.pow(r1_e, pool.integer(n0)));
880 let b00 = simp(pool, pool.pow(r2_e, pool.integer(n0)));
881 let a10 = simp(pool, pool.pow(r1_e, pool.integer(n1)));
882 let b10 = simp(pool, pool.pow(r2_e, pool.integer(n1)));
883 let det = simp(
884 pool,
885 pool.add(vec![
886 pool.mul(vec![a00, b10]),
887 pool.mul(vec![a10, b00, pool.integer(-1_i32)]),
888 ]),
889 );
890 let num_c0 = simp(
891 pool,
892 pool.add(vec![
893 pool.mul(vec![v0p, b10]),
894 pool.mul(vec![v1p, b00, pool.integer(-1_i32)]),
895 ]),
896 );
897 let num_c1 = simp(
898 pool,
899 pool.add(vec![
900 pool.mul(vec![a00, v1p]),
901 pool.mul(vec![a10, v0p, pool.integer(-1_i32)]),
902 ]),
903 );
904 let c0v = expr_div(pool, num_c0, det);
905 let c1v = expr_div(pool, num_c1, det);
906 let mut m = HashMap::new();
907 m.insert(c_syms[0], c0v);
908 m.insert(c_syms[1], c1v);
909 return Ok(simp(pool, subs(general, &m, pool)));
910 }
911
912 Err(RsolveError::InitialMismatch(
913 "initial values for order > 2 not implemented".into(),
914 ))
915}
916
917pub fn rsolve(
922 pool: &ExprPool,
923 equation: ExprId,
924 n: ExprId,
925 seq_name: &str,
926 initials: Option<&BTreeMap<i64, ExprId>>,
927) -> Result<ExprId, RsolveError> {
928 let (a, rhs_p) = extract_recurrence(equation, seq_name, n, pool)?;
929 let d = a.len() - 1;
930
931 let a0_lead = a[0].clone();
932 let hom_norm: Vec<Rational> = a.iter().map(|x| x.clone() / a0_lead.clone()).collect();
933 let rhs_norm = {
934 let inv = Rational::from(1) / a0_lead.clone();
935 RatUniPoly {
936 coeffs: rhs_p
937 .coeffs
938 .iter()
939 .map(|c| c.clone() * inv.clone())
940 .collect(),
941 }
942 .trim()
943 };
944
945 let particular_p = if rhs_norm.is_zero() {
946 RatUniPoly::zero()
947 } else if d == 1 {
948 let r = -hom_norm[1].clone();
949 undetermined_order1(&r, &rhs_norm).ok_or_else(|| {
950 RsolveError::Unsupported("particular solution (order 1) failed".into())
951 })?
952 } else if d == 2 {
953 undetermined_order2(&hom_norm[0], &hom_norm[1], &hom_norm[2], &rhs_norm).ok_or_else(
954 || RsolveError::Unsupported("particular solution (order 2) failed".into()),
955 )?
956 } else {
957 if !rhs_norm.is_zero() {
958 return Err(RsolveError::Unsupported(
959 "non-homogeneous order > 2 is not implemented".into(),
960 ));
961 }
962 RatUniPoly::zero()
963 };
964
965 let particular_e = if particular_p.is_zero() {
966 pool.integer(0_i32)
967 } else {
968 rat_poly_to_expr(pool, n, &particular_p)
969 };
970
971 let (hom_e, c_syms): (ExprId, Vec<ExprId>) = match d {
972 1 => {
973 let r = -hom_norm[1].clone();
974 let re = rational_atom(pool, &r);
975 let c0 = pool.symbol("C0", crate::kernel::Domain::Real);
976 let h = simp(pool, pool.mul(vec![c0, pool.pow(re, n)]));
977 (h, vec![c0])
978 }
979 2 => {
980 let c0 = pool.symbol("C0", crate::kernel::Domain::Real);
981 let c1 = pool.symbol("C1", crate::kernel::Domain::Real);
982 let (r1, r2) = order2_r_exprs(pool, &a)?;
983 let h = simp(
984 pool,
985 pool.add(vec![
986 simp(pool, pool.mul(vec![c0, pool.pow(r1, n)])),
987 simp(pool, pool.mul(vec![c1, pool.pow(r2, n)])),
988 ]),
989 );
990 (h, vec![c0, c1])
991 }
992 _ => {
993 let facts = factor_char_polynomial(char_poly_asc(&a))?;
994 let nconst: usize = facts.iter().map(|(_, m)| *m).sum();
995 let cs = fresh_constants(pool, nconst);
996 let h = hom_solution_from_roots(pool, n, &facts, &cs)?;
997 (h, cs)
998 }
999 };
1000
1001 let general = simp(pool, pool.add(vec![hom_e, particular_e]));
1002
1003 if let Some(init) = initials {
1004 apply_init(pool, general, n, &c_syms, init, d, &a, particular_e)
1005 } else {
1006 Ok(general)
1007 }
1008}
1009
// Unit tests for `rsolve`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::Domain;
    use rug::Rational;
    use std::collections::HashMap;

    /// Returns `true` when a symbol called `name` occurs anywhere in `expr`.
    fn has_sym(expr: ExprId, name: &str, pool: &ExprPool) -> bool {
        match pool.get(expr) {
            ExprData::Symbol { name: n, .. } => n == name,
            ExprData::Add(xs) => xs.iter().any(|&x| has_sym(x, name, pool)),
            ExprData::Mul(xs) => xs.iter().any(|&x| has_sym(x, name, pool)),
            ExprData::Pow { base, exp } => has_sym(base, name, pool) || has_sym(exp, name, pool),
            ExprData::Func { args, .. } => args.iter().any(|&a| has_sym(a, name, pool)),
            _ => false,
        }
    }

    /// `f(n) - f(n-1) - 1 = 0` without initial values must keep the free
    /// constant `C0` in the general solution.
    #[test]
    fn arithmetic_progression_general() {
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let f = |args: Vec<ExprId>| pool.func("f", args);
        // f(n) - f(n - 1) - 1
        let eq = simp(
            &pool,
            pool.add(vec![
                f(vec![n]),
                pool.mul(vec![
                    f(vec![pool.add(vec![n, pool.integer(-1_i32)])]),
                    pool.integer(-1_i32),
                ]),
                pool.integer(-1_i32),
            ]),
        );
        let sol = rsolve(&pool, eq, n, "f", None).expect("rsolve");
        assert!(has_sym(sol, "C0", &pool));
    }

    /// Fibonacci with `f(0) = 0`, `f(1) = 1`: the closed form must agree
    /// numerically with the reference solver for `n = 0..=12`.
    #[test]
    fn fibonacci_numeric_with_init() {
        use crate::sum::recurrence::solve_linear_recurrence_homogeneous;
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let f = |args: Vec<ExprId>| pool.func("f", args);
        // f(n) - f(n - 1) - f(n - 2)
        let eq = simp(
            &pool,
            pool.add(vec![
                f(vec![n]),
                pool.mul(vec![
                    f(vec![pool.add(vec![n, pool.integer(-1_i32)])]),
                    pool.integer(-1_i32),
                ]),
                pool.mul(vec![
                    f(vec![pool.add(vec![n, pool.integer(-2_i32)])]),
                    pool.integer(-1_i32),
                ]),
            ]),
        );
        let mut init = BTreeMap::new();
        init.insert(0, pool.integer(0));
        init.insert(1, pool.integer(1));
        let sol = rsolve(&pool, eq, n, "f", Some(&init)).expect("rsolve");

        // Reference closed form from the existing homogeneous solver.
        let ref_sol = solve_linear_recurrence_homogeneous(
            &pool,
            n,
            &[Rational::from(-1), Rational::from(-1), Rational::from(1)],
            &[pool.integer(0), pool.integer(1)],
        )
        .expect("ref");

        // Compare both closed forms numerically at each index.
        for ni in 0..=12 {
            let mut env = HashMap::new();
            env.insert(n, ni as f64);
            let v = eval_interp(sol, &env, &pool).expect("eval");
            let vr = eval_interp(ref_sol.closed_form, &env, &pool).expect("eref");
            assert!((v - vr).abs() < 1e-4, "n={ni} rsolve={v} ref={vr}",);
        }
    }
}