Skip to main content

alkahest_cas/sum/
rsolve.rs

1//! Linear difference equations with constant coefficients (V2-18).
2//!
3//! `rsolve` accepts an equation linear in `seq_name(n + offset)` applications with
4//! integer offsets, rational multipliers, and a right-hand side polynomial in `n`.
5
6#![allow(clippy::needless_range_loop)]
7
8use crate::kernel::subs::subs;
9use crate::kernel::{ExprData, ExprId, ExprPool};
10use crate::matrix::normal_form::RatUniPoly;
11use crate::poly::unipoly::UniPoly;
12use crate::simplify::engine::simplify;
13use rug::{Integer, Rational};
14use std::collections::{BTreeMap, HashMap};
15use std::fmt;
16
17fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
18    simplify(e, pool).value
19}
20
21/// True when ``r`` is multiplicative unity (checks canonical `numer/denom`; also matches `±k/±k`).
22#[inline]
23fn rational_eq_one(r: &Rational) -> bool {
24    !r.is_zero() && r.numer() == r.denom()
25}
26
/// Failure modes reported by [`rsolve`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RsolveError {
    /// The equation is not a linear recurrence in the sequence; the message names the failed check.
    NotLinearRecurrence(String),
    /// A sequence term carries a non-constant factor (e.g. `n*f(n)`) or multiplies another sequence term.
    NonlinearTerm,
    /// The right-hand side is not a polynomial in `n`; carries the conversion error text.
    NonPolynomialRhs(String),
    /// The order or the characteristic factorization lies outside the implemented fragment.
    Unsupported(String),
    /// The initial values cannot fix the constants (wrong count or singular system).
    InitialMismatch(String),
}

impl fmt::Display for RsolveError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NonlinearTerm => f.write_str("rsolve: nonlinear term in sequence variable"),
            Self::NotLinearRecurrence(s) => write!(f, "rsolve: {s}"),
            Self::NonPolynomialRhs(s) => write!(f, "rsolve: non-polynomial rhs: {s}"),
            Self::Unsupported(s) => write!(f, "rsolve: unsupported: {s}"),
            Self::InitialMismatch(s) => write!(f, "rsolve: initial values: {s}"),
        }
    }
}

impl std::error::Error for RsolveError {}
55
56impl crate::errors::AlkahestError for RsolveError {
57    fn code(&self) -> &'static str {
58        match self {
59            RsolveError::NotLinearRecurrence(_) => "E-RSOLVE-001",
60            RsolveError::NonlinearTerm => "E-RSOLVE-002",
61            RsolveError::NonPolynomialRhs(_) => "E-RSOLVE-003",
62            RsolveError::Unsupported(_) => "E-RSOLVE-004",
63            RsolveError::InitialMismatch(_) => "E-RSOLVE-005",
64        }
65    }
66
67    fn remediation(&self) -> Option<&'static str> {
68        Some(
69            "use pool.func(name, [n + integer]) for shifts; keep coefficients rational and rhs polynomial in n",
70        )
71    }
72}
73
74fn rational_atom(pool: &ExprPool, r: &Rational) -> ExprId {
75    let numer = r.numer().clone();
76    let denom = r.denom().clone();
77    if denom == 1 {
78        pool.integer(numer)
79    } else {
80        pool.rational(numer, denom)
81    }
82}
83
84fn expr_div(pool: &ExprPool, num: ExprId, den: ExprId) -> ExprId {
85    pool.mul(vec![num, pool.pow(den, pool.integer(-1_i32))])
86}
87
88fn flatten_add(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
89    match pool.get(expr) {
90        ExprData::Add(args) => args.iter().flat_map(|&x| flatten_add(x, pool)).collect(),
91        _ => vec![expr],
92    }
93}
94
95fn flatten_mul(expr: ExprId, pool: &ExprPool) -> Vec<ExprId> {
96    match pool.get(expr) {
97        ExprData::Mul(args) => args.iter().flat_map(|&x| flatten_mul(x, pool)).collect(),
98        _ => vec![expr],
99    }
100}
101
102fn linear_in_sym(expr: ExprId, sym: ExprId, pool: &ExprPool) -> Option<(Rational, Rational)> {
103    let e = simp(pool, expr);
104    if e == sym {
105        return Some((Rational::from(1), Rational::from(0)));
106    }
107    match pool.get(e) {
108        ExprData::Integer(n) => Some((Rational::from(0), Rational::from(n.0.clone()))),
109        ExprData::Rational(r) => Some((Rational::from(0), r.0.clone())),
110        ExprData::Add(args) => {
111            let mut a = Rational::from(0);
112            let mut b = Rational::from(0);
113            for t in args {
114                if t == sym {
115                    a += Rational::from(1);
116                } else if let Some((a0, b0)) = linear_in_sym(t, sym, pool) {
117                    a += a0;
118                    b += b0;
119                } else {
120                    return None;
121                }
122            }
123            Some((a, b))
124        }
125        ExprData::Mul(args) => {
126            if args.len() == 2 && args[0] == sym {
127                match pool.get(args[1]) {
128                    ExprData::Integer(n) => Some((Rational::from(n.0.clone()), Rational::from(0))),
129                    ExprData::Rational(r) => Some((r.0.clone(), Rational::from(0))),
130                    _ => None,
131                }
132            } else if args.len() == 2 && args[1] == sym {
133                match pool.get(args[0]) {
134                    ExprData::Integer(n) => Some((Rational::from(n.0.clone()), Rational::from(0))),
135                    ExprData::Rational(r) => Some((r.0.clone(), Rational::from(0))),
136                    _ => None,
137                }
138            } else {
139                None
140            }
141        }
142        ExprData::Pow { base, exp } => {
143            if base == sym {
144                match pool.get(exp) {
145                    ExprData::Integer(n) if n.0 == 1 => {
146                        Some((Rational::from(1), Rational::from(0)))
147                    }
148                    _ => None,
149                }
150            } else {
151                None
152            }
153        }
154        _ => None,
155    }
156}
157
158fn offset_in_n(arg: ExprId, n: ExprId, pool: &ExprPool) -> Result<i64, RsolveError> {
159    let (coef, c) = linear_in_sym(arg, n, pool).ok_or_else(|| {
160        RsolveError::NotLinearRecurrence(
161            "sequence index must be an affine integer shift of the recurrence variable".into(),
162        )
163    })?;
164    if coef != 1 {
165        return Err(RsolveError::NotLinearRecurrence(
166            "recurrence variable must appear with coefficient 1 in each index".into(),
167        ));
168    }
169    let num = c.numer();
170    let den = c.denom();
171    if num.clone() % den.clone() == 0 {
172        let q = Integer::from(num / den);
173        Ok(q.to_i64().unwrap_or(i64::MIN))
174    } else {
175        Err(RsolveError::NotLinearRecurrence(
176            "index shift must be an integer".into(),
177        ))
178    }
179}
180
181fn contains_seq(expr: ExprId, seq_name: &str, pool: &ExprPool) -> bool {
182    match pool.get(expr) {
183        ExprData::Func { name, args } => {
184            if name == seq_name {
185                return true;
186            }
187            args.iter().any(|&a| contains_seq(a, seq_name, pool))
188        }
189        ExprData::Add(xs) => xs.iter().any(|&a| contains_seq(a, seq_name, pool)),
190        ExprData::Mul(xs) => xs.iter().any(|&a| contains_seq(a, seq_name, pool)),
191        ExprData::Pow { base, exp } => {
192            contains_seq(base, seq_name, pool) || contains_seq(exp, seq_name, pool)
193        }
194        _ => false,
195    }
196}
197
198enum Peeled {
199    Seq { coeff: Rational, offset: i64 },
200    Other(ExprId),
201}
202
203fn peel_addend(
204    term: ExprId,
205    seq_name: &str,
206    n: ExprId,
207    pool: &ExprPool,
208) -> Result<Peeled, RsolveError> {
209    let factors = flatten_mul(term, pool);
210    let mut rat = Rational::from(1);
211    let mut seq_off: Option<i64> = None;
212    let mut rest: Vec<ExprId> = Vec::new();
213
214    for g in factors {
215        match pool.get(g) {
216            ExprData::Integer(nn) => {
217                rat *= Rational::from(nn.0.clone());
218            }
219            ExprData::Rational(rr) => {
220                rat *= rr.0.clone();
221            }
222            ExprData::Func { name, args } if name == seq_name => {
223                if args.len() != 1 {
224                    return Err(RsolveError::NotLinearRecurrence(
225                        "sequence applications must have exactly one index argument".into(),
226                    ));
227                }
228                if seq_off.is_some() {
229                    return Err(RsolveError::NonlinearTerm);
230                }
231                seq_off = Some(offset_in_n(args[0], n, pool)?);
232            }
233            _ => rest.push(g),
234        }
235    }
236
237    match (seq_off, rest.is_empty()) {
238        (Some(o), true) => Ok(Peeled::Seq {
239            coeff: rat,
240            offset: o,
241        }),
242        (None, _) => {
243            let rhs = if rest.is_empty() {
244                rational_atom(pool, &rat)
245            } else if rest.len() == 1 {
246                if rat == 1 {
247                    rest[0]
248                } else {
249                    simp(pool, pool.mul(vec![rational_atom(pool, &rat), rest[0]]))
250                }
251            } else {
252                let mut v = rest;
253                if rat != 1 {
254                    v.insert(0, rational_atom(pool, &rat));
255                }
256                simp(pool, pool.mul(v))
257            };
258            Ok(Peeled::Other(rhs))
259        }
260        (Some(_), false) => Err(RsolveError::NonlinearTerm),
261    }
262}
263
264/// Returns `a[k]` for `∑_{k=0}^d a[k] f(n-k) = rhs` and polynomial `rhs`.
265fn extract_recurrence(
266    equation: ExprId,
267    seq_name: &str,
268    n: ExprId,
269    pool: &ExprPool,
270) -> Result<(Vec<Rational>, RatUniPoly), RsolveError> {
271    let zero = simp(pool, equation);
272    let parts = flatten_add(zero, pool);
273    let mut by_shift: BTreeMap<i64, Rational> = BTreeMap::new();
274    let mut rhs_terms: Vec<ExprId> = Vec::new();
275
276    for p in parts {
277        match peel_addend(p, seq_name, n, pool)? {
278            Peeled::Seq { coeff, offset } => {
279                *by_shift.entry(offset).or_insert(Rational::from(0)) += coeff;
280            }
281            Peeled::Other(e) => rhs_terms.push(e),
282        }
283    }
284
285    if by_shift.is_empty() {
286        return Err(RsolveError::NotLinearRecurrence(
287            "no sequence term in equation".into(),
288        ));
289    }
290
291    let max_o = *by_shift.keys().max().unwrap();
292    let mut shifts: BTreeMap<i64, Rational> = BTreeMap::new();
293    for (&o, c) in &by_shift {
294        let lag = max_o - o;
295        *shifts.entry(lag).or_insert(Rational::from(0)) += c;
296    }
297
298    let d = *shifts.keys().max().unwrap() as usize;
299    let mut a = vec![Rational::from(0); d + 1];
300    for (&k, v) in &shifts {
301        a[k as usize] = v.clone();
302    }
303
304    if a[0] == 0 {
305        return Err(RsolveError::NotLinearRecurrence(
306            "leading coefficient of f(n) vanishes after normalization".into(),
307        ));
308    }
309
310    let rhs_expr = if rhs_terms.is_empty() {
311        pool.integer(0_i32)
312    } else {
313        let s = simp(pool, pool.add(rhs_terms));
314        simp(pool, pool.mul(vec![s, pool.integer(-1_i32)]))
315    };
316
317    if contains_seq(rhs_expr, seq_name, pool) {
318        return Err(RsolveError::NotLinearRecurrence(
319            "right-hand side still references the sequence".into(),
320        ));
321    }
322
323    let rhs_poly = match UniPoly::from_symbolic_clear_denoms(rhs_expr, n, pool) {
324        Ok(p) => {
325            let cs: Vec<Rational> = p.coefficients().into_iter().map(Rational::from).collect();
326            RatUniPoly { coeffs: cs }.trim()
327        }
328        Err(e) => {
329            return Err(RsolveError::NonPolynomialRhs(e.to_string()));
330        }
331    };
332
333    Ok((a, rhs_poly))
334}
335
336fn binom(n: u32, k: u32) -> Integer {
337    if k > n {
338        return Integer::from(0);
339    }
340    let mut acc = Integer::from(1);
341    for i in 0..k {
342        acc *= Integer::from(n - i);
343        acc /= Integer::from(i + 1);
344    }
345    acc
346}
347
348/// `(n - m)^deg` as a polynomial in `n` (ascending coefficients).
349fn shift_x_sub_m(deg: u32, m: i64) -> RatUniPoly {
350    if deg == 0 {
351        return RatUniPoly::one();
352    }
353    let mut coeffs = vec![Rational::from(0); (deg + 1) as usize];
354    let mm = Rational::from(m);
355    for k in 0..=deg {
356        let mut term = Rational::from(binom(deg, k));
357        if (deg - k) % 2 == 1 {
358            term = -term;
359        }
360        for _ in 0..(deg - k) {
361            term *= mm.clone();
362        }
363        coeffs[k as usize] = term;
364    }
365    RatUniPoly { coeffs }.trim()
366}
367
368/// `L[p] = p(n) - r·p(n-1)` on polynomials in `n`, with `minus_r = -r` matching
369/// `f(n) + r f(n-1)` normalization... Here `r_forward` is the root of `x - r = 0` in
370/// `f(n) = r f(n-1) + h` i.e. apply `p - r * sub(p,n-1)`.
371fn poly_apply_order1_shift(r: &Rational, p: &RatUniPoly) -> RatUniPoly {
372    let mut out = RatUniPoly::zero();
373    for (deg, c) in p.coeffs.iter().enumerate() {
374        if c.is_zero() {
375            continue;
376        }
377        let mut mon = vec![Rational::from(0); deg + 1];
378        mon[deg] = c.clone();
379        let n_poly = RatUniPoly { coeffs: mon }.trim();
380        let shifted = shift_x_sub_m(deg as u32, 1);
381        let sub = &RatUniPoly::constant(r.clone()) * &shifted;
382        out = &out + &(&n_poly - &sub);
383    }
384    out.trim()
385}
386
387fn poly_apply_order2(a0: &Rational, a1: &Rational, a2: &Rational, p: &RatUniPoly) -> RatUniPoly {
388    let mut out = RatUniPoly::zero();
389    for (deg, coeff) in p.coeffs.iter().enumerate() {
390        if coeff.is_zero() {
391            continue;
392        }
393        let mut mon = vec![Rational::from(0); deg + 1];
394        mon[deg] = coeff.clone();
395        let n_poly = RatUniPoly { coeffs: mon }.trim();
396        let p1 = shift_x_sub_m(deg as u32, 1);
397        let p2 = shift_x_sub_m(deg as u32, 2);
398        let term = &(&(&RatUniPoly::constant(a0.clone()) * &n_poly)
399            + &(&RatUniPoly::constant(a1.clone()) * &p1))
400            + &(&RatUniPoly::constant(a2.clone()) * &p2);
401        out = &out + &term;
402    }
403    out.trim()
404}
405
406fn mono_n(j: usize) -> RatUniPoly {
407    let mut c = vec![Rational::from(0); j + 1];
408    c[j] = Rational::from(1);
409    RatUniPoly { coeffs: c }.trim()
410}
411
412fn solve_rational_linear_system(
413    mut a: Vec<Vec<Rational>>,
414    mut b: Vec<Rational>,
415) -> Option<Vec<Rational>> {
416    let n = b.len();
417    debug_assert_eq!(a.len(), n);
418    for col in 0..n {
419        let mut pivot = None;
420        for row in col..n {
421            if !a[row][col].is_zero() {
422                pivot = Some(row);
423                break;
424            }
425        }
426        let pr = pivot?;
427        if pr != col {
428            a.swap(col, pr);
429            b.swap(col, pr);
430        }
431        let div = a[col][col].clone();
432        if div.is_zero() {
433            return None;
434        }
435        let inv = Rational::from(1) / div.clone();
436        for j in col..n {
437            a[col][j] *= inv.clone();
438        }
439        b[col] *= inv;
440        for row in 0..n {
441            if row == col {
442                continue;
443            }
444            let f = a[row][col].clone();
445            if f.is_zero() {
446                continue;
447            }
448            for j in col..n {
449                let sub = f.clone() * a[col][j].clone();
450                a[row][j] -= sub;
451            }
452            let bcol = b[col].clone();
453            b[row] -= f * bcol;
454        }
455    }
456    Some(b)
457}
458
459fn undetermined_order1(r: &Rational, h: &RatUniPoly) -> Option<RatUniPoly> {
460    let dh = h.degree().max(0) as usize;
461    let start_deg = if rational_eq_one(r) { 1 } else { 0 };
462    for bump in 0..24 {
463        let hi_deg = (dh + bump + usize::from(rational_eq_one(r))).max(start_deg);
464        if hi_deg > 40 {
465            break;
466        }
467        let u = hi_deg.saturating_sub(start_deg) + 1;
468        let mut mat = vec![vec![Rational::from(0); u]; u];
469        let mut rhs = vec![Rational::from(0); u];
470        for row in 0..u {
471            for j in 0..u {
472                let pow = start_deg + j;
473                let basis = mono_n(pow);
474                let applied = poly_apply_order1_shift(r, &basis);
475                mat[row][j] = applied
476                    .coeffs
477                    .get(row)
478                    .cloned()
479                    .unwrap_or_else(|| Rational::from(0));
480            }
481            rhs[row] = h
482                .coeffs
483                .get(row)
484                .cloned()
485                .unwrap_or_else(|| Rational::from(0));
486        }
487        if let Some(x) = solve_rational_linear_system(mat, rhs) {
488            let mut coeffs = vec![Rational::from(0); hi_deg + 1];
489            for (j, coeff) in x.into_iter().enumerate() {
490                coeffs[start_deg + j] = coeff;
491            }
492            return Some(RatUniPoly { coeffs }.trim());
493        }
494    }
495    None
496}
497
498fn undetermined_order2(
499    a0: &Rational,
500    a1: &Rational,
501    a2: &Rational,
502    h: &RatUniPoly,
503) -> Option<RatUniPoly> {
504    let dh = h.degree().max(0) as usize;
505    for bump in 0..24 {
506        let trial_deg = (dh + 4 + bump).min(42);
507        let u = trial_deg + 1;
508        let mut mat = vec![vec![Rational::from(0); u]; u];
509        let mut rhs = vec![Rational::from(0); u];
510        for row in 0..u {
511            for j in 0..u {
512                let basis = mono_n(j);
513                let applied = poly_apply_order2(a0, a1, a2, &basis);
514                mat[row][j] = applied
515                    .coeffs
516                    .get(row)
517                    .cloned()
518                    .unwrap_or_else(|| Rational::from(0));
519            }
520            rhs[row] = h
521                .coeffs
522                .get(row)
523                .cloned()
524                .unwrap_or_else(|| Rational::from(0));
525        }
526        if let Some(x) = solve_rational_linear_system(mat, rhs) {
527            return Some(RatUniPoly { coeffs: x }.trim());
528        }
529    }
530    None
531}
532
533fn rat_poly_to_expr(pool: &ExprPool, n_sym: ExprId, p: &RatUniPoly) -> ExprId {
534    let mut terms: Vec<ExprId> = Vec::new();
535    for (deg, coeff) in p.coeffs.iter().enumerate() {
536        if coeff.is_zero() {
537            continue;
538        }
539        let coeff_q = coeff.clone();
540        let numer = coeff_q.numer();
541        let denom = coeff_q.denom();
542        let coeff_expr = if *denom == 1 {
543            pool.integer(numer.clone())
544        } else {
545            pool.rational(numer.clone(), denom.clone())
546        };
547        let pow_id = if deg == 0 {
548            coeff_expr
549        } else if deg == 1 {
550            pool.mul(vec![coeff_expr, n_sym])
551        } else {
552            pool.mul(vec![coeff_expr, pool.pow(n_sym, pool.integer(deg as i64))])
553        };
554        terms.push(pow_id);
555    }
556    match terms.len() {
557        0 => pool.integer(0_i32),
558        1 => terms[0],
559        _ => pool.add(terms),
560    }
561}
562
563fn sqrt_disc_expr(pool: &ExprPool, disc: &Rational) -> ExprId {
564    let num = disc.numer().clone();
565    let den = disc.denom().clone();
566    let prod = num * den.clone();
567    let inner = pool.integer(prod);
568    let sqrt_e = pool.func("sqrt", vec![inner]);
569    let den_e = pool.integer(den);
570    expr_div(pool, sqrt_e, den_e)
571}
572
573/// `a[0] r^d + … + a[d]` in ascending powers of `r`.
574fn char_poly_asc(a: &[Rational]) -> RatUniPoly {
575    let d = a.len() - 1;
576    let mut v = vec![Rational::from(0); d + 1];
577    for i in 0..=d {
578        v[i] = a[d - i].clone();
579    }
580    RatUniPoly { coeffs: v }.trim()
581}
582
583fn horner_rat(p: &RatUniPoly, x: &Rational) -> Rational {
584    let mut acc = Rational::from(0);
585    for c in p.coeffs.iter().rev() {
586        acc = acc * x.clone() + c.clone();
587    }
588    acc
589}
590
591fn divisors_int(mut n: Integer) -> Vec<Integer> {
592    if n < 0 {
593        n = -n;
594    }
595    if n == 0 {
596        return vec![Integer::from(1)];
597    }
598    let mut out = vec![Integer::from(1)];
599    let mut i = Integer::from(2);
600    while i.clone() * i.clone() <= n {
601        if n.clone() % i.clone() == 0 {
602            let mut pws = vec![Integer::from(1)];
603            let mut nn = n.clone();
604            while nn.clone() % i.clone() == 0 {
605                let last = pws.last().unwrap().clone();
606                pws.push(last * i.clone());
607                nn /= i.clone();
608            }
609            n = nn;
610            let old = out.clone();
611            out.clear();
612            for base in old {
613                for pw in &pws {
614                    out.push(base.clone() * pw);
615                }
616            }
617        }
618        i += 1;
619    }
620    if n > 1 {
621        let old = out.clone();
622        out.clear();
623        for base in old {
624            out.push(base.clone());
625            out.push(base * n.clone());
626        }
627    }
628    out.sort();
629    out.dedup();
630    out
631}
632
633fn peel_rational_root(p: &RatUniPoly) -> Option<Rational> {
634    if p.is_zero() {
635        return None;
636    }
637    let mut z: Vec<Integer> = Vec::new();
638    let mut lcm_den = Integer::from(1);
639    for c in &p.coeffs {
640        lcm_den = lcm_den.lcm(&c.denom().clone());
641    }
642    for c in &p.coeffs {
643        let d = c.denom().clone();
644        let scale = lcm_den.clone() / d;
645        z.push(scale * c.numer().clone());
646    }
647    let lc = z.last().cloned().unwrap_or_else(|| Integer::from(0));
648    let c0 = z.first().cloned().unwrap_or_else(|| Integer::from(0));
649    if lc.is_zero() {
650        return Some(Rational::from(0));
651    }
652    let mut try_vals: Vec<Rational> = Vec::new();
653    for pd in divisors_int(lc.clone()) {
654        for qd in divisors_int(c0.clone()) {
655            try_vals.push(Rational::from((pd.clone(), qd.clone())));
656            try_vals.push(-Rational::from((pd.clone(), qd)));
657        }
658    }
659    try_vals.sort_by(|x, y| x.partial_cmp(y).unwrap());
660    try_vals.dedup();
661    try_vals
662        .into_iter()
663        .find(|r| !p.coeffs.is_empty() && horner_rat(p, r).is_zero())
664}
665
666fn div_linear_factor(p: RatUniPoly, root: &Rational) -> RatUniPoly {
667    let r = root.clone();
668    let lin = RatUniPoly {
669        coeffs: vec![-r, Rational::from(1)],
670    }
671    .trim();
672    let (q, rem) = RatUniPoly::div_rem(&p, &lin);
673    debug_assert!(rem.is_zero());
674    q
675}
676
677fn factor_char_polynomial(mut p: RatUniPoly) -> Result<Vec<(Rational, usize)>, RsolveError> {
678    let mut roots: Vec<(Rational, usize)> = Vec::new();
679    let mut guard = 0usize;
680    while p.degree() > 0 && guard < 64 {
681        guard += 1;
682        let Some(r0) = peel_rational_root(&p) else {
683            break;
684        };
685        let mut m = 0usize;
686        while p.degree() > 0 && horner_rat(&p, &r0).is_zero() {
687            p = div_linear_factor(p, &r0);
688            m += 1;
689        }
690        roots.push((r0, m));
691    }
692    match p.degree() {
693        -1 | 0 => Ok(roots),
694        1 => {
695            let c0 = p.coeffs[0].clone();
696            let c1 = p.coeffs[1].clone();
697            if c1 == 0 {
698                return Err(RsolveError::Unsupported("degenerate characteristic".into()));
699            }
700            roots.push((-c0 / c1, 1));
701            Ok(roots)
702        }
703        2 => {
704            let c0 = p.coeffs[0].clone();
705            let c1 = p.coeffs[1].clone();
706            let c2 = p.coeffs[2].clone();
707            if c2 == 0 {
708                return Err(RsolveError::Unsupported(
709                    "characteristic degree mismatch".into(),
710                ));
711            }
712            let disc = c1.clone() * c1.clone() - Rational::from(4) * c2.clone() * c0.clone();
713            if disc == 0 {
714                let r = -c1 / (Rational::from(2) * c2.clone());
715                roots.push((r, 2));
716            } else if disc > 0 {
717                let disc_numer = disc.numer().clone();
718                let disc_denom = disc.denom().clone();
719                let (sn, rem_n) = disc_numer.sqrt_rem(Integer::new());
720                let (sd, rem_d) = disc_denom.sqrt_rem(Integer::new());
721                if rem_n != 0 || rem_d != 0 {
722                    return Err(RsolveError::Unsupported(
723                        "order-3+ with irreducible quadratic characteristic tail".into(),
724                    ));
725                }
726                let sqrt_d = Rational::from((sn, sd));
727                let r1 = (-c1.clone() + sqrt_d.clone()) / (Rational::from(2) * c2.clone());
728                let r2 = (-c1 - sqrt_d) / (Rational::from(2) * c2.clone());
729                roots.push((r1, 1));
730                roots.push((r2, 1));
731            } else {
732                return Err(RsolveError::Unsupported(
733                    "complex characteristic roots".into(),
734                ));
735            }
736            Ok(roots)
737        }
738        d => Err(RsolveError::Unsupported(format!(
739            "characteristic leftover degree {d}"
740        ))),
741    }
742}
743
744fn hom_solution_from_roots(
745    pool: &ExprPool,
746    n_sym: ExprId,
747    root_facts: &[(Rational, usize)],
748    c_syms: &[ExprId],
749) -> Result<ExprId, RsolveError> {
750    let mut terms: Vec<ExprId> = Vec::new();
751    let mut idx = 0;
752    for (r, mult) in root_facts {
753        let re = rational_atom(pool, r);
754        for p in 0..*mult {
755            if idx >= c_syms.len() {
756                return Err(RsolveError::Unsupported(
757                    "internal: not enough constant symbols".into(),
758                ));
759            }
760            let basis = if p == 0 {
761                simp(pool, pool.pow(re, n_sym))
762            } else {
763                let np = pool.pow(n_sym, pool.integer(p as i64));
764                simp(pool, pool.mul(vec![np, pool.pow(re, n_sym)]))
765            };
766            terms.push(simp(pool, pool.mul(vec![c_syms[idx], basis])));
767            idx += 1;
768        }
769    }
770    if idx != c_syms.len() {
771        return Err(RsolveError::Unsupported(
772            "internal: constant count mismatch".into(),
773        ));
774    }
775    match terms.len() {
776        0 => Ok(pool.integer(0_i32)),
777        1 => Ok(terms[0]),
778        _ => Ok(simp(pool, pool.add(terms))),
779    }
780}
781
782fn order2_r_exprs(pool: &ExprPool, a_rec: &[Rational]) -> Result<(ExprId, ExprId), RsolveError> {
783    let p = char_poly_asc(a_rec);
784    if p.degree() != 2 {
785        return Err(RsolveError::Unsupported(
786            "expected order-2 characteristic".into(),
787        ));
788    }
789    let p0 = p.coeffs[0].clone();
790    let p1 = p.coeffs[1].clone();
791    let p2 = p.coeffs[2].clone();
792    if p2 == 0 {
793        return Err(RsolveError::Unsupported("degenerate characteristic".into()));
794    }
795    let b = p1 / p2.clone();
796    let c = p0 / p2.clone();
797    let disc = b.clone() * b.clone() - Rational::from(4) * c.clone();
798    if disc < 0 {
799        return Err(RsolveError::Unsupported("complex roots".into()));
800    }
801    let sqrt_e = sqrt_disc_expr(pool, &disc);
802    let neg_b = rational_atom(pool, &(-b.clone()));
803    let half = rational_atom(pool, &Rational::from((1, 2)));
804    let inner1 = simp(pool, pool.add(vec![neg_b, sqrt_e]));
805    let r1 = simp(pool, pool.mul(vec![half, inner1]));
806    let inner2 = simp(
807        pool,
808        pool.add(vec![neg_b, pool.mul(vec![sqrt_e, pool.integer(-1_i32)])]),
809    );
810    let r2 = simp(pool, pool.mul(vec![half, inner2]));
811    Ok((r1, r2))
812}
813
814fn fresh_constants(pool: &ExprPool, k: usize) -> Vec<ExprId> {
815    (0..k)
816        .map(|i: usize| pool.symbol(format!("C{}", i), crate::kernel::Domain::Real))
817        .collect()
818}
819
820fn subs_n_int(pool: &ExprPool, expr: ExprId, n_sym: ExprId, ni: i64) -> ExprId {
821    let mut m = HashMap::new();
822    m.insert(n_sym, pool.integer(ni));
823    simp(pool, subs(expr, &m, pool))
824}
825
826#[allow(clippy::too_many_arguments)]
827fn apply_init(
828    pool: &ExprPool,
829    general: ExprId,
830    n_sym: ExprId,
831    c_syms: &[ExprId],
832    initials: &BTreeMap<i64, ExprId>,
833    d: usize,
834    a: &[Rational],
835    particular: ExprId,
836) -> Result<ExprId, RsolveError> {
837    if initials.len() != d {
838        return Err(RsolveError::InitialMismatch(format!(
839            "need {d} initial values for order {d}, got {}",
840            initials.len()
841        )));
842    }
843
844    if d == 1 {
845        let (&n0, v0) = initials.first_key_value().unwrap();
846        let r = (-a[1].clone()) / a[0].clone();
847        let r_e = rational_atom(pool, &r);
848        let p0 = subs_n_int(pool, particular, n_sym, n0);
849        let rpow = simp(pool, pool.pow(r_e, pool.integer(n0)));
850        let rhs = simp(
851            pool,
852            pool.add(vec![*v0, pool.mul(vec![p0, pool.integer(-1_i32)])]),
853        );
854        let c0v = expr_div(pool, rhs, rpow);
855        let mut m = HashMap::new();
856        m.insert(c_syms[0], c0v);
857        return Ok(simp(pool, subs(general, &m, pool)));
858    }
859
860    if d == 2 {
861        let keys: Vec<i64> = initials.keys().copied().collect();
862        if keys.len() != 2 {
863            return Err(RsolveError::InitialMismatch("need two integers".into()));
864        }
865        let (n0, n1) = (keys[0], keys[1]);
866        let (r1_e, r2_e) = order2_r_exprs(pool, a)?;
867        let v0 = *initials.get(&n0).unwrap();
868        let v1 = *initials.get(&n1).unwrap();
869        let p0 = subs_n_int(pool, particular, n_sym, n0);
870        let p1 = subs_n_int(pool, particular, n_sym, n1);
871        let v0p = simp(
872            pool,
873            pool.add(vec![v0, pool.mul(vec![p0, pool.integer(-1_i32)])]),
874        );
875        let v1p = simp(
876            pool,
877            pool.add(vec![v1, pool.mul(vec![p1, pool.integer(-1_i32)])]),
878        );
879        let a00 = simp(pool, pool.pow(r1_e, pool.integer(n0)));
880        let b00 = simp(pool, pool.pow(r2_e, pool.integer(n0)));
881        let a10 = simp(pool, pool.pow(r1_e, pool.integer(n1)));
882        let b10 = simp(pool, pool.pow(r2_e, pool.integer(n1)));
883        let det = simp(
884            pool,
885            pool.add(vec![
886                pool.mul(vec![a00, b10]),
887                pool.mul(vec![a10, b00, pool.integer(-1_i32)]),
888            ]),
889        );
890        let num_c0 = simp(
891            pool,
892            pool.add(vec![
893                pool.mul(vec![v0p, b10]),
894                pool.mul(vec![v1p, b00, pool.integer(-1_i32)]),
895            ]),
896        );
897        let num_c1 = simp(
898            pool,
899            pool.add(vec![
900                pool.mul(vec![a00, v1p]),
901                pool.mul(vec![a10, v0p, pool.integer(-1_i32)]),
902            ]),
903        );
904        let c0v = expr_div(pool, num_c0, det);
905        let c1v = expr_div(pool, num_c1, det);
906        let mut m = HashMap::new();
907        m.insert(c_syms[0], c0v);
908        m.insert(c_syms[1], c1v);
909        return Ok(simp(pool, subs(general, &m, pool)));
910    }
911
912    Err(RsolveError::InitialMismatch(
913        "initial values for order > 2 not implemented".into(),
914    ))
915}
916
917/// Solve a linear recurrence coded as `equation == 0` in `n`:
918///
919/// Each sequence term is `pool.func(seq_name, [n + integer])`.  The general
920/// solution introduces symbols `C0`, `C1`, …; pass `initials` to eliminate them.
921pub fn rsolve(
922    pool: &ExprPool,
923    equation: ExprId,
924    n: ExprId,
925    seq_name: &str,
926    initials: Option<&BTreeMap<i64, ExprId>>,
927) -> Result<ExprId, RsolveError> {
928    let (a, rhs_p) = extract_recurrence(equation, seq_name, n, pool)?;
929    let d = a.len() - 1;
930
931    let a0_lead = a[0].clone();
932    let hom_norm: Vec<Rational> = a.iter().map(|x| x.clone() / a0_lead.clone()).collect();
933    let rhs_norm = {
934        let inv = Rational::from(1) / a0_lead.clone();
935        RatUniPoly {
936            coeffs: rhs_p
937                .coeffs
938                .iter()
939                .map(|c| c.clone() * inv.clone())
940                .collect(),
941        }
942        .trim()
943    };
944
945    let particular_p = if rhs_norm.is_zero() {
946        RatUniPoly::zero()
947    } else if d == 1 {
948        let r = -hom_norm[1].clone();
949        undetermined_order1(&r, &rhs_norm).ok_or_else(|| {
950            RsolveError::Unsupported("particular solution (order 1) failed".into())
951        })?
952    } else if d == 2 {
953        undetermined_order2(&hom_norm[0], &hom_norm[1], &hom_norm[2], &rhs_norm).ok_or_else(
954            || RsolveError::Unsupported("particular solution (order 2) failed".into()),
955        )?
956    } else {
957        if !rhs_norm.is_zero() {
958            return Err(RsolveError::Unsupported(
959                "non-homogeneous order > 2 is not implemented".into(),
960            ));
961        }
962        RatUniPoly::zero()
963    };
964
965    let particular_e = if particular_p.is_zero() {
966        pool.integer(0_i32)
967    } else {
968        rat_poly_to_expr(pool, n, &particular_p)
969    };
970
971    let (hom_e, c_syms): (ExprId, Vec<ExprId>) = match d {
972        1 => {
973            let r = -hom_norm[1].clone();
974            let re = rational_atom(pool, &r);
975            let c0 = pool.symbol("C0", crate::kernel::Domain::Real);
976            let h = simp(pool, pool.mul(vec![c0, pool.pow(re, n)]));
977            (h, vec![c0])
978        }
979        2 => {
980            let c0 = pool.symbol("C0", crate::kernel::Domain::Real);
981            let c1 = pool.symbol("C1", crate::kernel::Domain::Real);
982            let (r1, r2) = order2_r_exprs(pool, &a)?;
983            let h = simp(
984                pool,
985                pool.add(vec![
986                    simp(pool, pool.mul(vec![c0, pool.pow(r1, n)])),
987                    simp(pool, pool.mul(vec![c1, pool.pow(r2, n)])),
988                ]),
989            );
990            (h, vec![c0, c1])
991        }
992        _ => {
993            let facts = factor_char_polynomial(char_poly_asc(&a))?;
994            let nconst: usize = facts.iter().map(|(_, m)| *m).sum();
995            let cs = fresh_constants(pool, nconst);
996            let h = hom_solution_from_roots(pool, n, &facts, &cs)?;
997            (h, cs)
998        }
999    };
1000
1001    let general = simp(pool, pool.add(vec![hom_e, particular_e]));
1002
1003    if let Some(init) = initials {
1004        apply_init(pool, general, n, &c_syms, init, d, &a, particular_e)
1005    } else {
1006        Ok(general)
1007    }
1008}
1009
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::Domain;
    use rug::Rational;
    use std::collections::HashMap;

    /// Reports whether a symbol called `name` appears anywhere inside `expr`.
    fn has_sym(expr: ExprId, name: &str, pool: &ExprPool) -> bool {
        // Shared recursion over a list of child expressions.
        let any_of =
            |children: &[ExprId]| children.iter().any(|&child| has_sym(child, name, pool));
        match pool.get(expr) {
            ExprData::Symbol { name: sym_name, .. } => sym_name == name,
            ExprData::Add(terms) => any_of(&terms),
            ExprData::Mul(factors) => any_of(&factors),
            ExprData::Pow { base, exp } => any_of(&[base, exp]),
            ExprData::Func { args, .. } => any_of(&args),
            _ => false,
        }
    }

    #[test]
    fn arithmetic_progression_general() {
        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        let minus_one = || pool.integer(-1_i32);

        // f(n) - f(n - 1) - 1 == 0
        let current = pool.func("f", vec![n]);
        let previous = pool.func("f", vec![pool.add(vec![n, minus_one()])]);
        let scaled_previous = pool.mul(vec![previous, minus_one()]);
        let eq = simp(
            &pool,
            pool.add(vec![current, scaled_previous, minus_one()]),
        );

        let sol = rsolve(&pool, eq, n, "f", None).expect("rsolve");
        assert!(has_sym(sol, "C0", &pool));
    }

    #[test]
    fn fibonacci_numeric_with_init() {
        use crate::sum::recurrence::solve_linear_recurrence_homogeneous;

        let pool = ExprPool::new();
        let n = pool.symbol("n", Domain::Real);
        // f(n - k)
        let lag = |k: i32| pool.func("f", vec![pool.add(vec![n, pool.integer(-k)])]);
        // -e
        let negated = |e: ExprId| pool.mul(vec![e, pool.integer(-1_i32)]);

        // f(n) - f(n - 1) - f(n - 2) == 0
        let current = pool.func("f", vec![n]);
        let first_lag = negated(lag(1));
        let second_lag = negated(lag(2));
        let eq = simp(&pool, pool.add(vec![current, first_lag, second_lag]));

        let init = BTreeMap::from([(0, pool.integer(0)), (1, pool.integer(1))]);
        let sol = rsolve(&pool, eq, n, "f", Some(&init)).expect("rsolve");

        let reference = solve_linear_recurrence_homogeneous(
            &pool,
            n,
            &[Rational::from(-1), Rational::from(-1), Rational::from(1)],
            &[pool.integer(0), pool.integer(1)],
        )
        .expect("ref");

        for ni in 0..=12 {
            let env: HashMap<ExprId, f64> = HashMap::from([(n, f64::from(ni))]);
            let v = eval_interp(sol, &env, &pool).expect("eval");
            let vr = eval_interp(reference.closed_form, &env, &pool).expect("eref");
            assert!((v - vr).abs() < 1e-4, "n={ni} rsolve={v} ref={vr}");
        }
    }
}