1use crate::diff::{diff, DiffError};
4use crate::flint::FlintPoly;
5use crate::kernel::{subs, Domain, ExprData, ExprId, ExprPool};
6use crate::poly::{RationalFunction, UniPoly};
7use crate::simplify::simplify;
8use std::collections::HashMap;
9use std::fmt;
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
17pub struct Series(pub ExprId);
18
19impl Series {
20 pub fn expr(self) -> ExprId {
21 self.0
22 }
23}
24
25#[derive(Debug)]
26pub enum SeriesError {
27 Diff(DiffError),
29 InvalidOrder,
31}
32
33impl fmt::Display for SeriesError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 SeriesError::Diff(e) => write!(f, "{e}"),
37 SeriesError::InvalidOrder => write!(f, "series order must be >= 1"),
38 }
39 }
40}
41
42impl std::error::Error for SeriesError {
43 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44 match self {
45 SeriesError::Diff(e) => Some(e),
46 SeriesError::InvalidOrder => None,
47 }
48 }
49}
50
51impl crate::errors::AlkahestError for SeriesError {
52 fn code(&self) -> &'static str {
53 match self {
54 SeriesError::Diff(_) => "E-SERIES-001",
55 SeriesError::InvalidOrder => "E-SERIES-002",
56 }
57 }
58
59 fn remediation(&self) -> Option<&'static str> {
60 match self {
61 SeriesError::Diff(_) => {
62 Some("ensure all functions are registered primitives with differentiation rules")
63 }
64 SeriesError::InvalidOrder => Some("pass order >= 1 (exclusive truncation degree in x)"),
65 }
66 }
67}
68
69impl From<DiffError> for SeriesError {
70 fn from(e: DiffError) -> Self {
71 SeriesError::Diff(e)
72 }
73}
74
75pub fn series(
91 expr: ExprId,
92 var: ExprId,
93 point: ExprId,
94 order: u32,
95 pool: &ExprPool,
96) -> Result<Series, SeriesError> {
97 let LocalExpansion {
98 valuation,
99 coeffs,
100 h_expr,
101 } = local_expansion(expr, var, point, order, pool)?;
102
103 Ok(assemble_series(&coeffs, valuation, h_expr, order, pool))
104}
105
106#[derive(Clone, Debug)]
114pub(crate) struct LocalExpansion {
115 pub valuation: i32,
116 pub coeffs: Vec<ExprId>,
117 pub h_expr: ExprId,
118}
119
120pub(crate) fn local_expansion(
121 expr: ExprId,
122 var: ExprId,
123 point: ExprId,
124 order: u32,
125 pool: &ExprPool,
126) -> Result<LocalExpansion, SeriesError> {
127 if order == 0 {
128 return Err(SeriesError::InvalidOrder);
129 }
130
131 let xi = pool.symbol("__sxp", Domain::Real);
132 let mut map = HashMap::new();
133 map.insert(var, pool.add(vec![point, xi]));
134 let shifted = subs(expr, &map, pool);
135
136 let h_expr = expansion_increment(pool, var, point);
137
138 expansion_matched_laurent(shifted, xi, h_expr, order, pool)
139}
140
141fn factorial_u32(n: u32) -> rug::Integer {
142 let mut r = rug::Integer::from(1);
143 for i in 2..=n {
144 r *= i;
145 }
146 r
147}
148
149fn expansion_increment(pool: &ExprPool, var: ExprId, point: ExprId) -> ExprId {
150 match pool.get(point) {
151 ExprData::Integer(n) if n.0 == 0 => var,
152 _ => pool.add(vec![var, pool.mul(vec![pool.integer(-1_i32), point])]),
153 }
154}
155
/// Exponent `m` of the trailing `O(h^m)` remainder of an assembled series.
///
/// The value must be the first power of `h` that is *not* represented by the
/// emitted coefficients:
/// - `valuation >= 0`: the Laurent path keeps exponents `valuation..order`
///   (the Taylor fallback keeps `0..order`), so the remainder is `O(h^order)`.
///   This also holds when `valuation >= order` and no coefficients are kept.
/// - `valuation < 0` (a pole): `order` coefficients are kept, covering
///   exponents `valuation..valuation + order`. The first omitted power is
///   therefore `h^(valuation + order)`.
fn laurent_big_o_pow(valuation: i32, order: u32) -> i64 {
    if valuation < 0 {
        // A fixed exponent here would misstate the remainder. For
        // 1/(x^3 (1 - x)) with order 2 it would yield `x^-3 + x^-2 + O(x)`,
        // although the dropped `x^-1 + 1 + ...` is not O(x).
        i64::from(valuation) + i64::from(order)
    } else {
        i64::from(order)
    }
}
163
164fn is_structural_zero(id: ExprId, pool: &ExprPool) -> bool {
165 matches!(pool.get(id), ExprData::Integer(n) if n.0 == 0)
166}
167
168fn collect_atom_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
169 match pool.get(expr) {
170 ExprData::Pow { base, exp } => {
171 let n = pool.with(exp, |d| match d {
172 ExprData::Integer(i) => Some(i.0.clone()),
173 _ => None,
174 })?;
175 if n > 0 {
176 Some((vec![expr], vec![]))
177 } else if n < 0 {
178 let mag = (-n).to_u32()?;
179 let pos_exp = pool.integer(mag as i64);
180 Some((vec![], vec![pool.pow(base, pos_exp)]))
181 } else {
182 Some((vec![pool.integer(1_i32)], vec![]))
183 }
184 }
185 ExprData::Integer(_)
186 | ExprData::Rational(_)
187 | ExprData::Float(_)
188 | ExprData::Symbol { .. }
189 | ExprData::Func { .. } => Some((vec![expr], vec![])),
190 ExprData::Add(_)
191 | ExprData::Mul(_)
192 | ExprData::Piecewise { .. }
193 | ExprData::Predicate { .. }
194 | ExprData::Forall { .. }
195 | ExprData::Exists { .. }
196 | ExprData::RootSum { .. }
197 | ExprData::BigO(_) => None,
198 }
199}
200
201fn collect_term_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
202 match pool.get(expr) {
203 ExprData::Mul(args) => {
204 let mut nums = Vec::new();
205 let mut dens = Vec::new();
206 for &a in &args {
207 let (n, d) = collect_atom_factors(a, pool)?;
208 nums.extend(n);
209 dens.extend(d);
210 }
211 Some((nums, dens))
212 }
213 _ => collect_atom_factors(expr, pool),
214 }
215}
216
217fn product_sorted(pool: &ExprPool, factors: Vec<ExprId>) -> ExprId {
218 match factors.len() {
219 0 => pool.integer(1_i32),
220 1 => factors[0],
221 _ => pool.mul(factors),
222 }
223}
224
225fn unipoly_valuation(p: &UniPoly) -> Option<u32> {
226 for (i, c) in p.coefficients().into_iter().enumerate() {
227 if c != 0 {
228 return Some(i as u32);
229 }
230 }
231 None
232}
233
234fn unipoly_strip_low(p: &UniPoly, k: u32) -> UniPoly {
235 let coeffs: Vec<rug::Integer> = p.coefficients().into_iter().skip(k as usize).collect();
236 UniPoly {
237 var: p.var,
238 coeffs: FlintPoly::from_rug_coefficients(&coeffs),
239 }
240}
241
242fn taylor_coefficients(
243 mut cur: ExprId,
244 xi: ExprId,
245 num: u32,
246 pool: &ExprPool,
247) -> Result<Vec<ExprId>, SeriesError> {
248 let mut mapping = HashMap::new();
249 mapping.insert(xi, pool.integer(0_i32));
250 let mut out = Vec::with_capacity(num as usize);
251 for k in 0..num {
252 let ev = subs(cur, &mapping, pool);
253 let simp = simplify(ev, pool).value;
254 let fc = factorial_u32(k);
255 let inv_fact = pool.rational(rug::Integer::from(1), fc);
256 let coeff = simplify(pool.mul(vec![simp, inv_fact]), pool).value;
257 out.push(coeff);
258 if k + 1 < num {
259 cur = diff(cur, xi, pool)?.value;
260 }
261 }
262 Ok(out)
263}
264
265fn assemble_series(
266 coeffs: &[ExprId],
267 valuation: i32,
268 h_expr: ExprId,
269 order: u32,
270 pool: &ExprPool,
271) -> Series {
272 let mut terms = Vec::new();
273 for (k, coeff) in coeffs.iter().enumerate() {
274 if is_structural_zero(*coeff, pool) {
275 continue;
276 }
277 let exp = valuation + k as i32;
278 let pow_term = if exp == 0 {
279 pool.integer(1_i32)
280 } else if exp == 1 {
281 h_expr
282 } else {
283 pool.pow(h_expr, pool.integer(exp as i64))
284 };
285 terms.push(pool.mul(vec![*coeff, pow_term]));
286 }
287 let big_o_pow = laurent_big_o_pow(valuation, order);
288 let o_term = pool.big_o(pool.pow(h_expr, pool.integer(big_o_pow)));
289 terms.push(o_term);
290 Series(pool.add(terms))
291}
292
293fn expansion_matched_laurent(
294 shifted: ExprId,
295 xi: ExprId,
296 h_expr: ExprId,
297 order: u32,
298 pool: &ExprPool,
299) -> Result<LocalExpansion, SeriesError> {
300 let (nums, dens) = match collect_term_factors(shifted, pool) {
301 Some(p) => p,
302 None => {
303 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
304 return Ok(LocalExpansion {
305 valuation: 0,
306 coeffs,
307 h_expr,
308 });
309 }
310 };
311
312 let n_expr = product_sorted(pool, nums);
313 let d_expr = product_sorted(pool, dens);
314
315 let rf = match RationalFunction::from_symbolic(n_expr, d_expr, vec![xi], pool) {
316 Ok(r) => r,
317 Err(_) => {
318 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
319 return Ok(LocalExpansion {
320 valuation: 0,
321 coeffs,
322 h_expr,
323 });
324 }
325 };
326
327 if rf.numer.is_zero() {
328 return Ok(LocalExpansion {
329 valuation: 0,
330 coeffs: vec![pool.integer(0_i32)],
331 h_expr,
332 });
333 }
334
335 let n_uni = match UniPoly::from_symbolic(rf.numer.to_expr(pool), xi, pool) {
336 Ok(u) => u,
337 Err(_) => {
338 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
339 return Ok(LocalExpansion {
340 valuation: 0,
341 coeffs,
342 h_expr,
343 });
344 }
345 };
346 let d_uni = match UniPoly::from_symbolic(rf.denom.to_expr(pool), xi, pool) {
347 Ok(u) => u,
348 Err(_) => {
349 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
350 return Ok(LocalExpansion {
351 valuation: 0,
352 coeffs,
353 h_expr,
354 });
355 }
356 };
357
358 let vn = match unipoly_valuation(&n_uni) {
359 Some(v) => v,
360 None => {
361 return Ok(LocalExpansion {
362 valuation: 0,
363 coeffs: vec![pool.integer(0_i32)],
364 h_expr,
365 });
366 }
367 };
368 let vd = match unipoly_valuation(&d_uni) {
369 Some(v) => v,
370 None => {
371 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
372 return Ok(LocalExpansion {
373 valuation: 0,
374 coeffs,
375 h_expr,
376 });
377 }
378 };
379
380 let valuation = vn as i32 - vd as i32;
381 let n0 = unipoly_strip_low(&n_uni, vn);
382 let d0 = unipoly_strip_low(&d_uni, vd);
383
384 let d0c = d0.coefficients();
385 if d0c.is_empty() || d0c[0] == 0 {
386 let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
387 return Ok(LocalExpansion {
388 valuation: 0,
389 coeffs,
390 h_expr,
391 });
392 }
393
394 let n0_e = n0.to_symbolic_expr(pool);
395 let d0_e = d0.to_symbolic_expr(pool);
396 let inv_d = pool.pow(d0_e, pool.integer(-1_i32));
397 let g = simplify(pool.mul(vec![n0_e, inv_d]), pool).value;
398
399 let num_taylor: u32 = if valuation < 0 {
400 order
401 } else {
402 (order as i32 - valuation).max(0) as u32
403 };
404
405 if num_taylor == 0 {
406 return Ok(LocalExpansion {
407 valuation,
408 coeffs: Vec::new(),
409 h_expr,
410 });
411 }
412
413 let coeffs = taylor_coefficients(g, xi, num_taylor, pool)?;
414 Ok(LocalExpansion {
415 valuation,
416 coeffs,
417 h_expr,
418 })
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};

    /// Recursively looks for a `BigO` node through sums, products, powers and
    /// function arguments.
    fn contains_big_o(id: ExprId, pool: &ExprPool) -> bool {
        let any_child = |children: &[ExprId]| children.iter().any(|&c| contains_big_o(c, pool));
        match pool.get(id) {
            ExprData::BigO(_) => true,
            ExprData::Add(xs) | ExprData::Mul(xs) => any_child(&xs),
            ExprData::Pow { base, exp } => contains_big_o(base, pool) || contains_big_o(exp, pool),
            ExprData::Func { args, .. } => any_child(&args),
            _ => false,
        }
    }

    #[test]
    fn series_cos_about_zero_has_big_o() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let origin = pool.integer(0);
        let expr = pool.func("cos", vec![x]);
        let expansion = series(expr, x, origin, 6, &pool).unwrap();
        assert!(contains_big_o(expansion.expr(), &pool));
    }

    #[test]
    fn series_inv_x_laurent_has_big_o() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let origin = pool.integer(0);
        let expr = pool.pow(x, pool.integer(-1));
        let expansion = series(expr, x, origin, 4, &pool).unwrap();
        assert!(contains_big_o(expansion.expr(), &pool));
    }
}