1use crate::kernel::{ExprId, ExprPool};
4use crate::simplify::engine::simplify;
5use rug::Rational;
6use std::fmt;
7
8fn simp(pool: &ExprPool, e: ExprId) -> ExprId {
9 simplify(e, pool).value
10}
11
/// Reasons the homogeneous linear-recurrence solver can reject its input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LinearRecurrenceError {
    /// The recurrence cannot be solved; carries the order involved.
    ///
    /// Only orders 1 and 2 are handled. The order-specific solvers also
    /// report this variant when they reject the coefficients themselves.
    OrderUnsupported(usize),
    /// The number of initial values does not match the recurrence order.
    InitialLength {
        /// How many initial values the order requires.
        expected: usize,
        /// How many initial values were actually supplied.
        got: usize,
    },
}

impl fmt::Display for LinearRecurrenceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InitialLength { expected, got } => {
                write!(f, "expected {expected} initial value(s), got {got}")
            }
            Self::OrderUnsupported(o) => {
                write!(f, "recurrence order {o} is not supported (max 2)")
            }
        }
    }
}

impl std::error::Error for LinearRecurrenceError {}
33
34impl crate::errors::AlkahestError for LinearRecurrenceError {
35 fn code(&self) -> &'static str {
36 match self {
37 LinearRecurrenceError::OrderUnsupported(_) => "E-REC-001",
38 LinearRecurrenceError::InitialLength { .. } => "E-REC-002",
39 }
40 }
41
42 fn remediation(&self) -> Option<&'static str> {
43 Some("use order 1 or 2 with rational coefficients; initials must match order")
44 }
45}
46
/// Closed-form solution produced by [`solve_linear_recurrence_homogeneous`].
#[derive(Debug, Clone)]
pub struct RecurrenceSolution {
    /// The index symbol the closed form is expressed in (the `n` passed by the caller).
    pub n: ExprId,
    /// Expression for `u(n)` in terms of `n`, already passed through the simplifier.
    pub closed_form: ExprId,
}
52
53pub fn solve_linear_recurrence_homogeneous(
55 pool: &ExprPool,
56 n: ExprId,
57 coeffs: &[Rational],
58 initials: &[ExprId],
59) -> Result<RecurrenceSolution, LinearRecurrenceError> {
60 if coeffs.len() < 2 {
61 return Err(LinearRecurrenceError::OrderUnsupported(0));
62 }
63 let d = coeffs.len() - 1;
64 match d {
65 1 => solve_order1(pool, n, coeffs, initials),
66 2 => solve_order2(pool, n, coeffs, initials),
67 _ => Err(LinearRecurrenceError::OrderUnsupported(d)),
68 }
69}
70
71fn rational_atom(pool: &ExprPool, r: &Rational) -> ExprId {
72 let numer = r.numer().clone();
73 let denom = r.denom().clone();
74 if denom == 1 {
75 pool.integer(numer)
76 } else {
77 pool.rational(numer, denom)
78 }
79}
80
81fn expr_div(pool: &ExprPool, num: ExprId, den: ExprId) -> ExprId {
82 pool.mul(vec![num, pool.pow(den, pool.integer(-1_i32))])
83}
84
85fn sqrt_disc_expr(pool: &ExprPool, disc: &Rational) -> ExprId {
86 let num = disc.numer().clone();
87 let den = disc.denom().clone();
88 let prod = num * den.clone();
89 let inner = pool.integer(prod);
90 let sqrt_e = pool.func("sqrt", vec![inner]);
91 let den_e = pool.integer(den);
92 expr_div(pool, sqrt_e, den_e)
93}
94
95fn solve_order1(
96 pool: &ExprPool,
97 n: ExprId,
98 coeffs: &[Rational],
99 initials: &[ExprId],
100) -> Result<RecurrenceSolution, LinearRecurrenceError> {
101 if coeffs.len() != 2 {
102 return Err(LinearRecurrenceError::OrderUnsupported(1));
103 }
104 if initials.len() != 1 {
105 return Err(LinearRecurrenceError::InitialLength {
106 expected: 1,
107 got: initials.len(),
108 });
109 }
110 let r = (Rational::from(0) - coeffs[0].clone()) / coeffs[1].clone();
111 let r_expr = rational_atom(pool, &r);
112 let closed = simp(pool, pool.mul(vec![initials[0], pool.pow(r_expr, n)]));
113 Ok(RecurrenceSolution {
114 n,
115 closed_form: closed,
116 })
117}
118
119fn solve_order2(
120 pool: &ExprPool,
121 n: ExprId,
122 coeffs: &[Rational],
123 initials: &[ExprId],
124) -> Result<RecurrenceSolution, LinearRecurrenceError> {
125 if coeffs.len() != 3 {
126 return Err(LinearRecurrenceError::OrderUnsupported(2));
127 }
128 if initials.len() != 2 {
129 return Err(LinearRecurrenceError::InitialLength {
130 expected: 2,
131 got: initials.len(),
132 });
133 }
134 let c0 = &coeffs[0];
135 let c1 = &coeffs[1];
136 let c2 = &coeffs[2];
137 if c2.is_zero() {
138 return Err(LinearRecurrenceError::OrderUnsupported(2));
139 }
140
141 let b = c1.clone() / c2.clone();
142 let c = c0.clone() / c2.clone();
143 let disc = b.clone() * b.clone() - Rational::from(4) * c.clone();
144 if disc < 0 {
145 return Err(LinearRecurrenceError::OrderUnsupported(2));
146 }
147
148 let sqrt_e = sqrt_disc_expr(pool, &disc);
149 let neg_b = rational_atom(pool, &(-b.clone()));
150 let half = rational_atom(pool, &Rational::from((1, 2)));
151
152 let inner1 = simp(pool, pool.add(vec![neg_b, sqrt_e]));
153 let r1 = simp(pool, pool.mul(vec![half, inner1]));
154 let inner2 = simp(
155 pool,
156 pool.add(vec![neg_b, pool.mul(vec![sqrt_e, pool.integer(-1_i32)])]),
157 );
158 let r2 = simp(pool, pool.mul(vec![half, inner2]));
159
160 let denom_e = simp(
161 pool,
162 pool.add(vec![r1, pool.mul(vec![r2, pool.integer(-1_i32)])]),
163 );
164
165 let r2_u0 = simp(pool, pool.mul(vec![initials[0], r2]));
166 let num_a = simp(
167 pool,
168 pool.add(vec![
169 initials[1],
170 pool.mul(vec![r2_u0, pool.integer(-1_i32)]),
171 ]),
172 );
173
174 let r1_u0 = simp(pool, pool.mul(vec![initials[0], r1]));
175 let num_b = simp(
176 pool,
177 pool.add(vec![
178 r1_u0,
179 pool.mul(vec![initials[1], pool.integer(-1_i32)]),
180 ]),
181 );
182
183 let big_a = expr_div(pool, num_a, denom_e);
184 let big_b = expr_div(pool, num_b, denom_e);
185
186 let closed = simp(
187 pool,
188 pool.add(vec![
189 simp(pool, pool.mul(vec![big_a, pool.pow(r1, n)])),
190 simp(pool, pool.mul(vec![big_b, pool.pow(r2, n)])),
191 ]),
192 );
193
194 Ok(RecurrenceSolution {
195 n,
196 closed_form: closed,
197 })
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::{Domain, ExprId};
    use std::collections::HashMap;

    /// Evaluate `closed_form` at n = 0, 1, ... and compare against `expected`.
    fn assert_sequence(pool: &ExprPool, n_sym: ExprId, closed_form: ExprId, expected: &[f64]) {
        for (ni, &want) in expected.iter().enumerate() {
            let mut env = HashMap::new();
            env.insert(n_sym, ni as f64);
            let got = eval_interp(closed_form, &env, pool).expect("eval");
            assert!(
                (got - want).abs() < 1e-6,
                "n = {ni}: closed form gave {got}, expected {want}"
            );
        }
    }

    #[test]
    fn fibonacci_numeric_check() {
        let pool = ExprPool::new();
        let n_sym = pool.symbol("n", Domain::Real);
        let coeffs = vec![Rational::from(-1), Rational::from(-1), Rational::from(1)];
        let initials = vec![pool.integer(0_i32), pool.integer(1_i32)];
        let sol =
            solve_linear_recurrence_homogeneous(&pool, n_sym, &coeffs, &initials).expect("solve");

        let mut fib = vec![0.0_f64, 1.0];
        for _ in 2..=12 {
            let l = fib.len();
            fib.push(fib[l - 1] + fib[l - 2]);
        }
        assert_sequence(&pool, n_sym, sol.closed_form, &fib);
    }

    #[test]
    fn order1_geometric_check() {
        let pool = ExprPool::new();
        let n_sym = pool.symbol("n", Domain::Real);
        // -3*u(n) + u(n+1) = 0 with u(0) = 2, so u(n) = 2 * 3^n.
        let coeffs = vec![Rational::from(-3), Rational::from(1)];
        let initials = vec![pool.integer(2_i32)];
        let sol =
            solve_linear_recurrence_homogeneous(&pool, n_sym, &coeffs, &initials).expect("solve");

        let expected: Vec<f64> = (0..8).map(|k| 2.0 * 3f64.powi(k)).collect();
        assert_sequence(&pool, n_sym, sol.closed_form, &expected);
    }

    #[test]
    fn order2_rational_roots_check() {
        let pool = ExprPool::new();
        let n_sym = pool.symbol("n", Domain::Real);
        // 6*u(n) - 5*u(n+1) + u(n+2) = 0 has roots 2 and 3; with u(0) = 0 and
        // u(1) = 1 the solution is u(n) = 3^n - 2^n.
        let coeffs = vec![Rational::from(6), Rational::from(-5), Rational::from(1)];
        let initials = vec![pool.integer(0_i32), pool.integer(1_i32)];
        let sol =
            solve_linear_recurrence_homogeneous(&pool, n_sym, &coeffs, &initials).expect("solve");

        let expected: Vec<f64> = (0..8).map(|k| 3f64.powi(k) - 2f64.powi(k)).collect();
        assert_sequence(&pool, n_sym, sol.closed_form, &expected);
    }

    #[test]
    fn rejects_unsupported_shapes() {
        let pool = ExprPool::new();
        let n_sym = pool.symbol("n", Domain::Real);
        let one = pool.integer(1_i32);

        // A single coefficient is reported as order 0.
        let err = solve_linear_recurrence_homogeneous(&pool, n_sym, &[Rational::from(1)], &[])
            .unwrap_err();
        assert_eq!(err, LinearRecurrenceError::OrderUnsupported(0));

        // Order 3 is beyond the supported range.
        let cubic = vec![Rational::from(1); 4];
        let err = solve_linear_recurrence_homogeneous(&pool, n_sym, &cubic, &[one, one, one])
            .unwrap_err();
        assert_eq!(err, LinearRecurrenceError::OrderUnsupported(3));

        // Order 2 with only one initial value.
        let fib = vec![Rational::from(-1), Rational::from(-1), Rational::from(1)];
        let err = solve_linear_recurrence_homogeneous(&pool, n_sym, &fib, &[one]).unwrap_err();
        assert_eq!(
            err,
            LinearRecurrenceError::InitialLength {
                expected: 2,
                got: 1
            }
        );

        // u(n+2) = -u(n) has complex characteristic roots.
        let rotation = vec![Rational::from(1), Rational::from(0), Rational::from(1)];
        let err = solve_linear_recurrence_homogeneous(&pool, n_sym, &rotation, &[one, one])
            .unwrap_err();
        assert_eq!(err, LinearRecurrenceError::OrderUnsupported(2));
    }
}