Skip to main content

alkahest_cas/sum/
recurrence.rs

1//! Homogeneous linear recurrences with constant coefficients (order ≤ 2).
2
3use crate::kernel::{ExprId, ExprPool};
4use crate::simplify::engine::simplify;
5use rug::Rational;
6use std::fmt;
7
8fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
9    simplify(e, pool).value
10}
11
/// Errors from [`solve_linear_recurrence_homogeneous`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LinearRecurrenceError {
    /// The recurrence was rejected by the solver; the payload is the order
    /// that was reported for it (only orders 1 and 2 are handled).
    OrderUnsupported(usize),
    /// The number of initial values does not match the recurrence order.
    InitialLength {
        /// How many initial values the solver requires.
        expected: usize,
        /// How many initial values were actually supplied.
        got: usize,
    },
}
18
19impl fmt::Display for LinearRecurrenceError {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match self {
22            LinearRecurrenceError::OrderUnsupported(o) => {
23                write!(f, "recurrence order {o} is not supported (max 2)")
24            }
25            LinearRecurrenceError::InitialLength { expected, got } => {
26                write!(f, "expected {expected} initial value(s), got {got}")
27            }
28        }
29    }
30}
31
32impl std::error::Error for LinearRecurrenceError {}
33
34impl crate::errors::AlkahestError for LinearRecurrenceError {
35    fn code(&self) -> &'static str {
36        match self {
37            LinearRecurrenceError::OrderUnsupported(_) => "E-REC-001",
38            LinearRecurrenceError::InitialLength { .. } => "E-REC-002",
39        }
40    }
41
42    fn remediation(&self) -> Option<&'static str> {
43        Some("use order 1 or 2 with rational coefficients; initials must match order")
44    }
45}
46
47#[derive(Debug, Clone)]
48pub struct RecurrenceSolution {
49    pub n: ExprId,
50    pub closed_form: ExprId,
51}
52
53/// Solve `∑_{i=0}^d c_i · f(n+i) = 0` with rational coefficients (`c_d ≠ 0`).
54pub fn solve_linear_recurrence_homogeneous(
55    pool: &ExprPool,
56    n: ExprId,
57    coeffs: &[Rational],
58    initials: &[ExprId],
59) -> Result<RecurrenceSolution, LinearRecurrenceError> {
60    if coeffs.len() < 2 {
61        return Err(LinearRecurrenceError::OrderUnsupported(0));
62    }
63    let d = coeffs.len() - 1;
64    match d {
65        1 => solve_order1(pool, n, coeffs, initials),
66        2 => solve_order2(pool, n, coeffs, initials),
67        _ => Err(LinearRecurrenceError::OrderUnsupported(d)),
68    }
69}
70
71fn rational_atom(pool: &ExprPool, r: &Rational) -> ExprId {
72    let numer = r.numer().clone();
73    let denom = r.denom().clone();
74    if denom == 1 {
75        pool.integer(numer)
76    } else {
77        pool.rational(numer, denom)
78    }
79}
80
81fn expr_div(pool: &ExprPool, num: ExprId, den: ExprId) -> ExprId {
82    pool.mul(vec![num, pool.pow(den, pool.integer(-1_i32))])
83}
84
85fn sqrt_disc_expr(pool: &ExprPool, disc: &Rational) -> ExprId {
86    let num = disc.numer().clone();
87    let den = disc.denom().clone();
88    let prod = num * den.clone();
89    let inner = pool.integer(prod);
90    let sqrt_e = pool.func("sqrt", vec![inner]);
91    let den_e = pool.integer(den);
92    expr_div(pool, sqrt_e, den_e)
93}
94
95fn solve_order1(
96    pool: &ExprPool,
97    n: ExprId,
98    coeffs: &[Rational],
99    initials: &[ExprId],
100) -> Result<RecurrenceSolution, LinearRecurrenceError> {
101    if coeffs.len() != 2 {
102        return Err(LinearRecurrenceError::OrderUnsupported(1));
103    }
104    if initials.len() != 1 {
105        return Err(LinearRecurrenceError::InitialLength {
106            expected: 1,
107            got: initials.len(),
108        });
109    }
110    let r = (Rational::from(0) - coeffs[0].clone()) / coeffs[1].clone();
111    let r_expr = rational_atom(pool, &r);
112    let closed = simp(pool, pool.mul(vec![initials[0], pool.pow(r_expr, n)]));
113    Ok(RecurrenceSolution {
114        n,
115        closed_form: closed,
116    })
117}
118
119fn solve_order2(
120    pool: &ExprPool,
121    n: ExprId,
122    coeffs: &[Rational],
123    initials: &[ExprId],
124) -> Result<RecurrenceSolution, LinearRecurrenceError> {
125    if coeffs.len() != 3 {
126        return Err(LinearRecurrenceError::OrderUnsupported(2));
127    }
128    if initials.len() != 2 {
129        return Err(LinearRecurrenceError::InitialLength {
130            expected: 2,
131            got: initials.len(),
132        });
133    }
134    let c0 = &coeffs[0];
135    let c1 = &coeffs[1];
136    let c2 = &coeffs[2];
137    if c2.is_zero() {
138        return Err(LinearRecurrenceError::OrderUnsupported(2));
139    }
140
141    let b = c1.clone() / c2.clone();
142    let c = c0.clone() / c2.clone();
143    let disc = b.clone() * b.clone() - Rational::from(4) * c.clone();
144    if disc < 0 {
145        return Err(LinearRecurrenceError::OrderUnsupported(2));
146    }
147
148    let sqrt_e = sqrt_disc_expr(pool, &disc);
149    let neg_b = rational_atom(pool, &(-b.clone()));
150    let half = rational_atom(pool, &Rational::from((1, 2)));
151
152    let inner1 = simp(pool, pool.add(vec![neg_b, sqrt_e]));
153    let r1 = simp(pool, pool.mul(vec![half, inner1]));
154    let inner2 = simp(
155        pool,
156        pool.add(vec![neg_b, pool.mul(vec![sqrt_e, pool.integer(-1_i32)])]),
157    );
158    let r2 = simp(pool, pool.mul(vec![half, inner2]));
159
160    let denom_e = simp(
161        pool,
162        pool.add(vec![r1, pool.mul(vec![r2, pool.integer(-1_i32)])]),
163    );
164
165    let r2_u0 = simp(pool, pool.mul(vec![initials[0], r2]));
166    let num_a = simp(
167        pool,
168        pool.add(vec![
169            initials[1],
170            pool.mul(vec![r2_u0, pool.integer(-1_i32)]),
171        ]),
172    );
173
174    let r1_u0 = simp(pool, pool.mul(vec![initials[0], r1]));
175    let num_b = simp(
176        pool,
177        pool.add(vec![
178            r1_u0,
179            pool.mul(vec![initials[1], pool.integer(-1_i32)]),
180        ]),
181    );
182
183    let big_a = expr_div(pool, num_a, denom_e);
184    let big_b = expr_div(pool, num_b, denom_e);
185
186    let closed = simp(
187        pool,
188        pool.add(vec![
189            simp(pool, pool.mul(vec![big_a, pool.pow(r1, n)])),
190            simp(pool, pool.mul(vec![big_b, pool.pow(r2, n)])),
191        ]),
192    );
193
194    Ok(RecurrenceSolution {
195        n,
196        closed_form: closed,
197    })
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::Domain;
    use std::collections::HashMap;

    /// The closed form of `f(n+2) = f(n+1) + f(n)`, `f(0) = 0`, `f(1) = 1`
    /// must agree numerically with the Fibonacci numbers for `n = 0..=12`.
    #[test]
    fn fibonacci_numeric_check() {
        let pool = ExprPool::new();
        let n_sym = pool.symbol("n", Domain::Real);
        let coeffs = [Rational::from(-1), Rational::from(-1), Rational::from(1)];
        let initials = [pool.integer(0_i32), pool.integer(1_i32)];
        let sol =
            solve_linear_recurrence_homogeneous(&pool, n_sym, &coeffs, &initials).expect("solve");

        // Walk the sequence with a running pair instead of a precomputed table.
        let mut current = 0_i64;
        let mut upcoming = 1_i64;
        for ni in 0..=12_u32 {
            let env = HashMap::from([(n_sym, f64::from(ni))]);
            let v = eval_interp(sol.closed_form, &env, &pool).expect("eval");
            assert!((v - current as f64).abs() < 1e-6);

            let following = current + upcoming;
            current = upcoming;
            upcoming = following;
        }
    }
}