1use crate::diff::{diff, DiffError};
4use crate::flint::FlintPoly;
5use crate::kernel::{subs, Domain, ExprData, ExprId, ExprPool};
6use crate::poly::{RationalFunction, UniPoly};
7use crate::simplify::simplify;
8use std::collections::HashMap;
9use std::fmt;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
17pub struct Series(pub ExprId);
18
19impl Series {
20 pub fn expr(self) -> ExprId {
21 self.0
22 }
23}
24
25#[derive(Debug)]
26pub enum SeriesError {
27 Diff(DiffError),
29 InvalidOrder,
31}
32
33impl fmt::Display for SeriesError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 SeriesError::Diff(e) => write!(f, "{e}"),
37 SeriesError::InvalidOrder => write!(f, "series order must be >= 1"),
38 }
39 }
40}
41
42impl std::error::Error for SeriesError {
43 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44 match self {
45 SeriesError::Diff(e) => Some(e),
46 SeriesError::InvalidOrder => None,
47 }
48 }
49}
50
51impl crate::errors::AlkahestError for SeriesError {
52 fn code(&self) -> &'static str {
53 match self {
54 SeriesError::Diff(_) => "E-SERIES-001",
55 SeriesError::InvalidOrder => "E-SERIES-002",
56 }
57 }
58
59 fn remediation(&self) -> Option<&'static str> {
60 match self {
61 SeriesError::Diff(_) => {
62 Some("ensure all functions are registered primitives with differentiation rules")
63 }
64 SeriesError::InvalidOrder => Some("pass order >= 1 (exclusive truncation degree in x)"),
65 }
66 }
67}
68
69impl From<DiffError> for SeriesError {
70 fn from(e: DiffError) -> Self {
71 SeriesError::Diff(e)
72 }
73}
74
75pub fn series(
91 expr: ExprId,
92 var: ExprId,
93 point: ExprId,
94 order: u32,
95 pool: &ExprPool,
96) -> Result<Series, SeriesError> {
97 let LocalExpansion {
98 valuation,
99 coeffs,
100 h_expr,
101 } = local_expansion(expr, var, point, order, pool)?;
102
103 Ok(assemble_series(&coeffs, valuation, h_expr, order, pool))
104}
105
/// Raw local expansion of an expression about a point, before it is assembled
/// into a [`Series`].
///
/// It represents `sum_k coeffs[k] * h_expr^(valuation + k)`.
#[derive(Clone, Debug)]
pub(crate) struct LocalExpansion {
    /// Exponent of `h_expr` attached to `coeffs[0]`.
    /// Negative when the denominator vanishes to higher order than the
    /// numerator (a pole); 0 on every fallback path.
    pub valuation: i32,
    /// Coefficients of consecutive powers of `h_expr`, starting at `valuation`.
    pub coeffs: Vec<ExprId>,
    /// Expansion variable: `var + (-1)*point`, or just `var` when `point` is
    /// the literal integer 0.
    pub h_expr: ExprId,
}
119
120pub(crate) fn local_expansion(
121 expr: ExprId,
122 var: ExprId,
123 point: ExprId,
124 order: u32,
125 pool: &ExprPool,
126) -> Result<LocalExpansion, SeriesError> {
127 if order == 0 {
128 return Err(SeriesError::InvalidOrder);
129 }
130
131 let xi = pool.symbol("__sxp", Domain::Real);
132 let mut map = HashMap::new();
133 map.insert(var, pool.add(vec![point, xi]));
134 let shifted = subs(expr, &map, pool);
135
136 let h_expr = expansion_increment(pool, var, point);
137
138 expansion_matched_laurent(shifted, xi, h_expr, order, pool)
139}
140
141fn factorial_u32(n: u32) -> rug::Integer {
142 let mut r = rug::Integer::from(1);
143 for i in 2..=n {
144 r *= i;
145 }
146 r
147}
148
149fn expansion_increment(pool: &ExprPool, var: ExprId, point: ExprId) -> ExprId {
150 match pool.get(point) {
151 ExprData::Integer(n) if n.0 == 0 => var,
152 _ => pool.add(vec![var, pool.mul(vec![pool.integer(-1_i32), point])]),
153 }
154}
155
/// Exponent `n` of the `O(h^n)` remainder term appended by `assemble_series`.
///
/// * `valuation >= 0`: the coefficients cover at most `h^0 ..= h^(order - 1)`
///   (none at all when `valuation >= order`). The remainder is therefore
///   `O(h^order)`, i.e. `order` is the exclusive truncation degree.
/// * `valuation < 0` (a pole): the Laurent path computes `order` coefficients
///   for `h^valuation ..= h^(valuation + order - 1)`. The first omitted power
///   is therefore `h^(valuation + order)`.
///
/// A fixed `O(h)` for poles would be wrong whenever `valuation + order < 1`.
/// For example, with valuation -2 and order 1 only `c*h^-2` is produced, and
/// `O(h)` would falsely claim the `h^-1` and `h^0` terms vanish. Otherwise it
/// is weaker than what was actually computed.
fn laurent_big_o_pow(valuation: i32, order: u32) -> i64 {
    if valuation < 0 {
        i64::from(valuation) + i64::from(order)
    } else {
        i64::from(order)
    }
}
163
164fn is_structural_zero(id: ExprId, pool: &ExprPool) -> bool {
165 matches!(pool.get(id), ExprData::Integer(n) if n.0 == 0)
166}
167
168fn collect_atom_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
169 match pool.get(expr) {
170 ExprData::Pow { base, exp } => {
171 let n = pool.with(exp, |d| match d {
172 ExprData::Integer(i) => Some(i.0.clone()),
173 _ => None,
174 })?;
175 if n > 0 {
176 Some((vec![expr], vec![]))
177 } else if n < 0 {
178 let mag = (-n).to_u32()?;
179 let pos_exp = pool.integer(mag as i64);
180 Some((vec![], vec![pool.pow(base, pos_exp)]))
181 } else {
182 Some((vec![pool.integer(1_i32)], vec![]))
183 }
184 }
185 ExprData::Integer(_)
186 | ExprData::Rational(_)
187 | ExprData::Float(_)
188 | ExprData::Symbol { .. }
189 | ExprData::Func { .. } => Some((vec![expr], vec![])),
190 ExprData::Add(_)
191 | ExprData::Mul(_)
192 | ExprData::Piecewise { .. }
193 | ExprData::Predicate { .. }
194 | ExprData::Forall { .. }
195 | ExprData::Exists { .. }
196 | ExprData::BigO(_) => None,
197 }
198}
199
200fn collect_term_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
201 match pool.get(expr) {
202 ExprData::Mul(args) => {
203 let mut nums = Vec::new();
204 let mut dens = Vec::new();
205 for &a in &args {
206 let (n, d) = collect_atom_factors(a, pool)?;
207 nums.extend(n);
208 dens.extend(d);
209 }
210 Some((nums, dens))
211 }
212 _ => collect_atom_factors(expr, pool),
213 }
214}
215
216fn product_sorted(pool: &ExprPool, factors: Vec<ExprId>) -> ExprId {
217 match factors.len() {
218 0 => pool.integer(1_i32),
219 1 => factors[0],
220 _ => pool.mul(factors),
221 }
222}
223
224fn unipoly_valuation(p: &UniPoly) -> Option<u32> {
225 for (i, c) in p.coefficients().into_iter().enumerate() {
226 if c != 0 {
227 return Some(i as u32);
228 }
229 }
230 None
231}
232
233fn unipoly_strip_low(p: &UniPoly, k: u32) -> UniPoly {
234 let coeffs: Vec<rug::Integer> = p.coefficients().into_iter().skip(k as usize).collect();
235 UniPoly {
236 var: p.var,
237 coeffs: FlintPoly::from_rug_coefficients(&coeffs),
238 }
239}
240
241fn taylor_coefficients(
242 mut cur: ExprId,
243 xi: ExprId,
244 num: u32,
245 pool: &ExprPool,
246) -> Result<Vec<ExprId>, SeriesError> {
247 let mut mapping = HashMap::new();
248 mapping.insert(xi, pool.integer(0_i32));
249 let mut out = Vec::with_capacity(num as usize);
250 for k in 0..num {
251 let ev = subs(cur, &mapping, pool);
252 let simp = simplify(ev, pool).value;
253 let fc = factorial_u32(k);
254 let inv_fact = pool.rational(rug::Integer::from(1), fc);
255 let coeff = simplify(pool.mul(vec![simp, inv_fact]), pool).value;
256 out.push(coeff);
257 if k + 1 < num {
258 cur = diff(cur, xi, pool)?.value;
259 }
260 }
261 Ok(out)
262}
263
264fn assemble_series(
265 coeffs: &[ExprId],
266 valuation: i32,
267 h_expr: ExprId,
268 order: u32,
269 pool: &ExprPool,
270) -> Series {
271 let mut terms = Vec::new();
272 for (k, coeff) in coeffs.iter().enumerate() {
273 if is_structural_zero(*coeff, pool) {
274 continue;
275 }
276 let exp = valuation + k as i32;
277 let pow_term = if exp == 0 {
278 pool.integer(1_i32)
279 } else if exp == 1 {
280 h_expr
281 } else {
282 pool.pow(h_expr, pool.integer(exp as i64))
283 };
284 terms.push(pool.mul(vec![*coeff, pow_term]));
285 }
286 let big_o_pow = laurent_big_o_pow(valuation, order);
287 let o_term = pool.big_o(pool.pow(h_expr, pool.integer(big_o_pow)));
288 terms.push(o_term);
289 Series(pool.add(terms))
290}
291
292fn expansion_matched_laurent(
293 shifted: ExprId,
294 xi: ExprId,
295 h_expr: ExprId,
296 order: u32,
297 pool: &ExprPool,
298) -> Result<LocalExpansion, SeriesError> {
299 let (nums, dens) = match collect_term_factors(shifted, pool) {
300 Some(p) => p,
301 None => {
302 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
303 return Ok(LocalExpansion {
304 valuation: 0,
305 coeffs,
306 h_expr,
307 });
308 }
309 };
310
311 let n_expr = product_sorted(pool, nums);
312 let d_expr = product_sorted(pool, dens);
313
314 let rf = match RationalFunction::from_symbolic(n_expr, d_expr, vec![xi], pool) {
315 Ok(r) => r,
316 Err(_) => {
317 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
318 return Ok(LocalExpansion {
319 valuation: 0,
320 coeffs,
321 h_expr,
322 });
323 }
324 };
325
326 if rf.numer.is_zero() {
327 return Ok(LocalExpansion {
328 valuation: 0,
329 coeffs: vec![pool.integer(0_i32)],
330 h_expr,
331 });
332 }
333
334 let n_uni = match UniPoly::from_symbolic(rf.numer.to_expr(pool), xi, pool) {
335 Ok(u) => u,
336 Err(_) => {
337 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
338 return Ok(LocalExpansion {
339 valuation: 0,
340 coeffs,
341 h_expr,
342 });
343 }
344 };
345 let d_uni = match UniPoly::from_symbolic(rf.denom.to_expr(pool), xi, pool) {
346 Ok(u) => u,
347 Err(_) => {
348 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
349 return Ok(LocalExpansion {
350 valuation: 0,
351 coeffs,
352 h_expr,
353 });
354 }
355 };
356
357 let vn = match unipoly_valuation(&n_uni) {
358 Some(v) => v,
359 None => {
360 return Ok(LocalExpansion {
361 valuation: 0,
362 coeffs: vec![pool.integer(0_i32)],
363 h_expr,
364 });
365 }
366 };
367 let vd = match unipoly_valuation(&d_uni) {
368 Some(v) => v,
369 None => {
370 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
371 return Ok(LocalExpansion {
372 valuation: 0,
373 coeffs,
374 h_expr,
375 });
376 }
377 };
378
379 let valuation = vn as i32 - vd as i32;
380 let n0 = unipoly_strip_low(&n_uni, vn);
381 let d0 = unipoly_strip_low(&d_uni, vd);
382
383 let d0c = d0.coefficients();
384 if d0c.is_empty() || d0c[0] == 0 {
385 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
386 return Ok(LocalExpansion {
387 valuation: 0,
388 coeffs,
389 h_expr,
390 });
391 }
392
393 let n0_e = n0.to_symbolic_expr(pool);
394 let d0_e = d0.to_symbolic_expr(pool);
395 let inv_d = pool.pow(d0_e, pool.integer(-1_i32));
396 let g = simplify(pool.mul(vec![n0_e, inv_d]), pool).value;
397
398 let num_taylor: u32 = if valuation < 0 {
399 order
400 } else {
401 (order as i32 - valuation).max(0) as u32
402 };
403
404 if num_taylor == 0 {
405 return Ok(LocalExpansion {
406 valuation,
407 coeffs: Vec::new(),
408 h_expr,
409 });
410 }
411
412 let coeffs = taylor_coefficients(g, xi, num_taylor, pool)?;
413 Ok(LocalExpansion {
414 valuation,
415 coeffs,
416 h_expr,
417 })
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};

    /// Recursively searches the expression tree rooted at `id` for a `BigO` node.
    fn contains_big_o(id: ExprId, pool: &ExprPool) -> bool {
        match pool.get(id) {
            ExprData::BigO(_) => true,
            ExprData::Add(xs) | ExprData::Mul(xs) => xs.iter().any(|e| contains_big_o(*e, pool)),
            ExprData::Pow { base, exp } => contains_big_o(base, pool) || contains_big_o(exp, pool),
            ExprData::Func { args, .. } => args.iter().any(|e| contains_big_o(*e, pool)),
            _ => false,
        }
    }

    #[test]
    fn series_cos_about_zero_has_big_o() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.integer(0);
        let cx = p.func("cos", vec![x]);
        let s = series(cx, x, z, 6, &p).unwrap();
        assert!(contains_big_o(s.expr(), &p));
    }

    #[test]
    fn series_inv_x_laurent_has_big_o() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.integer(0);
        let ix = p.pow(x, p.integer(-1));
        let s = series(ix, x, z, 4, &p).unwrap();
        assert!(contains_big_o(s.expr(), &p));
    }

    #[test]
    fn series_rejects_zero_order() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.integer(0);
        let err = series(x, x, z, 0, &p).unwrap_err();
        assert!(matches!(err, SeriesError::InvalidOrder));
        assert_eq!(crate::errors::AlkahestError::code(&err), "E-SERIES-002");
    }
}