Skip to main content

alkahest_cas/calculus/
series.rs

1//! Truncated Taylor / Laurent series with symbolic [`crate::kernel::ExprData::BigO`] remainder (V2-15).
2
3use crate::diff::{diff, DiffError};
4use crate::flint::FlintPoly;
5use crate::kernel::{subs, Domain, ExprData, ExprId, ExprPool};
6use crate::poly::{RationalFunction, UniPoly};
7use crate::simplify::simplify;
8use std::collections::HashMap;
9use std::fmt;
10
11// ---------------------------------------------------------------------------
12// Public types
13// ---------------------------------------------------------------------------
14
15/// Result of [`series`] — truncated expansion plus big-O bound as one [`ExprId`].
16#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
17pub struct Series(pub ExprId);
18
19impl Series {
20    pub fn expr(self) -> ExprId {
21        self.0
22    }
23}
24
25#[derive(Debug)]
26pub enum SeriesError {
27    /// Differentiation failed while forming Taylor coefficients.
28    Diff(DiffError),
29    /// `order` must be positive.
30    InvalidOrder,
31}
32
33impl fmt::Display for SeriesError {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            SeriesError::Diff(e) => write!(f, "{e}"),
37            SeriesError::InvalidOrder => write!(f, "series order must be >= 1"),
38        }
39    }
40}
41
42impl std::error::Error for SeriesError {
43    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44        match self {
45            SeriesError::Diff(e) => Some(e),
46            SeriesError::InvalidOrder => None,
47        }
48    }
49}
50
51impl crate::errors::AlkahestError for SeriesError {
52    fn code(&self) -> &'static str {
53        match self {
54            SeriesError::Diff(_) => "E-SERIES-001",
55            SeriesError::InvalidOrder => "E-SERIES-002",
56        }
57    }
58
59    fn remediation(&self) -> Option<&'static str> {
60        match self {
61            SeriesError::Diff(_) => {
62                Some("ensure all functions are registered primitives with differentiation rules")
63            }
64            SeriesError::InvalidOrder => Some("pass order >= 1 (exclusive truncation degree in x)"),
65        }
66    }
67}
68
69impl From<DiffError> for SeriesError {
70    fn from(e: DiffError) -> Self {
71        SeriesError::Diff(e)
72    }
73}
74
75// ---------------------------------------------------------------------------
76// Entry point
77// ---------------------------------------------------------------------------
78
79/// Truncated Taylor or Laurent expansion of `expr` in `var` about `point`.
80///
81/// Let `h = var - point`. The returned expression has the shape
82/// `⋯ + O(h^k)` where `k = order` for analytic series (`valuation ≥ 0`), and
83/// `k = 1` when a polar term (`valuation < 0`) is present — matching the
84/// Laurent examples in the roadmap (`1/x` about `0` gives `x⁻¹ + O(x)`).
85///
86/// The `order` parameter matches the Taylor convention used in the roadmap:
87/// include powers `h^e` with `valuation ≤ e < order` when `valuation ≥ 0`, and
88/// when `valuation < 0` include the polar tail using `order` Taylor coefficients
89/// of the analytic factor `h^{-valuation} · f`.
90pub fn series(
91    expr: ExprId,
92    var: ExprId,
93    point: ExprId,
94    order: u32,
95    pool: &ExprPool,
96) -> Result<Series, SeriesError> {
97    let LocalExpansion {
98        valuation,
99        coeffs,
100        h_expr,
101    } = local_expansion(expr, var, point, order, pool)?;
102
103    Ok(assemble_series(&coeffs, valuation, h_expr, order, pool))
104}
105
106// ---------------------------------------------------------------------------
107// Internals
108// ---------------------------------------------------------------------------
109
110/// Local Laurent / Taylor data about `point`: `expr = ∑ᵢ coeffᵢ · h^{valuation+i}` up to truncation.
111///
112/// `h` is `var - point`, or bare `var` when `point` is the integer zero (matching [`series`]).
113#[derive(Clone, Debug)]
114pub(crate) struct LocalExpansion {
115    pub valuation: i32,
116    pub coeffs: Vec<ExprId>,
117    pub h_expr: ExprId,
118}
119
120pub(crate) fn local_expansion(
121    expr: ExprId,
122    var: ExprId,
123    point: ExprId,
124    order: u32,
125    pool: &ExprPool,
126) -> Result<LocalExpansion, SeriesError> {
127    if order == 0 {
128        return Err(SeriesError::InvalidOrder);
129    }
130
131    let xi = pool.symbol("__sxp", Domain::Real);
132    let mut map = HashMap::new();
133    map.insert(var, pool.add(vec![point, xi]));
134    let shifted = subs(expr, &map, pool);
135
136    let h_expr = expansion_increment(pool, var, point);
137
138    expansion_matched_laurent(shifted, xi, h_expr, order, pool)
139}
140
141fn factorial_u32(n: u32) -> rug::Integer {
142    let mut r = rug::Integer::from(1);
143    for i in 2..=n {
144        r *= i;
145    }
146    r
147}
148
149fn expansion_increment(pool: &ExprPool, var: ExprId, point: ExprId) -> ExprId {
150    match pool.get(point) {
151        ExprData::Integer(n) if n.0 == 0 => var,
152        _ => pool.add(vec![var, pool.mul(vec![pool.integer(-1_i32), point])]),
153    }
154}
155
/// Exponent `k` of the remainder term `O(h^k)`.
///
/// Analytic expansions (`valuation >= 0`) are truncated at `h^order`. Any
/// expansion with a pole (`valuation < 0`) always reports `O(h)`.
fn laurent_big_o_pow(valuation: i32, order: u32) -> i64 {
    if valuation >= 0 {
        i64::from(order)
    } else {
        1
    }
}
163
164fn is_structural_zero(id: ExprId, pool: &ExprPool) -> bool {
165    matches!(pool.get(id), ExprData::Integer(n) if n.0 == 0)
166}
167
168fn collect_atom_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
169    match pool.get(expr) {
170        ExprData::Pow { base, exp } => {
171            let n = pool.with(exp, |d| match d {
172                ExprData::Integer(i) => Some(i.0.clone()),
173                _ => None,
174            })?;
175            if n > 0 {
176                Some((vec![expr], vec![]))
177            } else if n < 0 {
178                let mag = (-n).to_u32()?;
179                let pos_exp = pool.integer(mag as i64);
180                Some((vec![], vec![pool.pow(base, pos_exp)]))
181            } else {
182                Some((vec![pool.integer(1_i32)], vec![]))
183            }
184        }
185        ExprData::Integer(_)
186        | ExprData::Rational(_)
187        | ExprData::Float(_)
188        | ExprData::Symbol { .. }
189        | ExprData::Func { .. } => Some((vec![expr], vec![])),
190        ExprData::Add(_)
191        | ExprData::Mul(_)
192        | ExprData::Piecewise { .. }
193        | ExprData::Predicate { .. }
194        | ExprData::Forall { .. }
195        | ExprData::Exists { .. }
196        | ExprData::BigO(_) => None,
197    }
198}
199
200fn collect_term_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
201    match pool.get(expr) {
202        ExprData::Mul(args) => {
203            let mut nums = Vec::new();
204            let mut dens = Vec::new();
205            for &a in &args {
206                let (n, d) = collect_atom_factors(a, pool)?;
207                nums.extend(n);
208                dens.extend(d);
209            }
210            Some((nums, dens))
211        }
212        _ => collect_atom_factors(expr, pool),
213    }
214}
215
216fn product_sorted(pool: &ExprPool, factors: Vec<ExprId>) -> ExprId {
217    match factors.len() {
218        0 => pool.integer(1_i32),
219        1 => factors[0],
220        _ => pool.mul(factors),
221    }
222}
223
224fn unipoly_valuation(p: &UniPoly) -> Option<u32> {
225    for (i, c) in p.coefficients().into_iter().enumerate() {
226        if c != 0 {
227            return Some(i as u32);
228        }
229    }
230    None
231}
232
233fn unipoly_strip_low(p: &UniPoly, k: u32) -> UniPoly {
234    let coeffs: Vec<rug::Integer> = p.coefficients().into_iter().skip(k as usize).collect();
235    UniPoly {
236        var: p.var,
237        coeffs: FlintPoly::from_rug_coefficients(&coeffs),
238    }
239}
240
241fn taylor_coefficients(
242    mut cur: ExprId,
243    xi: ExprId,
244    num: u32,
245    pool: &ExprPool,
246) -> Result<Vec<ExprId>, SeriesError> {
247    let mut mapping = HashMap::new();
248    mapping.insert(xi, pool.integer(0_i32));
249    let mut out = Vec::with_capacity(num as usize);
250    for k in 0..num {
251        let ev = subs(cur, &mapping, pool);
252        let simp = simplify(ev, pool).value;
253        let fc = factorial_u32(k);
254        let inv_fact = pool.rational(rug::Integer::from(1), fc);
255        let coeff = simplify(pool.mul(vec![simp, inv_fact]), pool).value;
256        out.push(coeff);
257        if k + 1 < num {
258            cur = diff(cur, xi, pool)?.value;
259        }
260    }
261    Ok(out)
262}
263
264fn assemble_series(
265    coeffs: &[ExprId],
266    valuation: i32,
267    h_expr: ExprId,
268    order: u32,
269    pool: &ExprPool,
270) -> Series {
271    let mut terms = Vec::new();
272    for (k, coeff) in coeffs.iter().enumerate() {
273        if is_structural_zero(*coeff, pool) {
274            continue;
275        }
276        let exp = valuation + k as i32;
277        let pow_term = if exp == 0 {
278            pool.integer(1_i32)
279        } else if exp == 1 {
280            h_expr
281        } else {
282            pool.pow(h_expr, pool.integer(exp as i64))
283        };
284        terms.push(pool.mul(vec![*coeff, pow_term]));
285    }
286    let big_o_pow = laurent_big_o_pow(valuation, order);
287    let o_term = pool.big_o(pool.pow(h_expr, pool.integer(big_o_pow)));
288    terms.push(o_term);
289    Series(pool.add(terms))
290}
291
292fn expansion_matched_laurent(
293    shifted: ExprId,
294    xi: ExprId,
295    h_expr: ExprId,
296    order: u32,
297    pool: &ExprPool,
298) -> Result<LocalExpansion, SeriesError> {
299    let (nums, dens) = match collect_term_factors(shifted, pool) {
300        Some(p) => p,
301        None => {
302            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
303            return Ok(LocalExpansion {
304                valuation: 0,
305                coeffs,
306                h_expr,
307            });
308        }
309    };
310
311    let n_expr = product_sorted(pool, nums);
312    let d_expr = product_sorted(pool, dens);
313
314    let rf = match RationalFunction::from_symbolic(n_expr, d_expr, vec![xi], pool) {
315        Ok(r) => r,
316        Err(_) => {
317            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
318            return Ok(LocalExpansion {
319                valuation: 0,
320                coeffs,
321                h_expr,
322            });
323        }
324    };
325
326    if rf.numer.is_zero() {
327        return Ok(LocalExpansion {
328            valuation: 0,
329            coeffs: vec![pool.integer(0_i32)],
330            h_expr,
331        });
332    }
333
334    let n_uni = match UniPoly::from_symbolic(rf.numer.to_expr(pool), xi, pool) {
335        Ok(u) => u,
336        Err(_) => {
337            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
338            return Ok(LocalExpansion {
339                valuation: 0,
340                coeffs,
341                h_expr,
342            });
343        }
344    };
345    let d_uni = match UniPoly::from_symbolic(rf.denom.to_expr(pool), xi, pool) {
346        Ok(u) => u,
347        Err(_) => {
348            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
349            return Ok(LocalExpansion {
350                valuation: 0,
351                coeffs,
352                h_expr,
353            });
354        }
355    };
356
357    let vn = match unipoly_valuation(&n_uni) {
358        Some(v) => v,
359        None => {
360            return Ok(LocalExpansion {
361                valuation: 0,
362                coeffs: vec![pool.integer(0_i32)],
363                h_expr,
364            });
365        }
366    };
367    let vd = match unipoly_valuation(&d_uni) {
368        Some(v) => v,
369        None => {
370            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
371            return Ok(LocalExpansion {
372                valuation: 0,
373                coeffs,
374                h_expr,
375            });
376        }
377    };
378
379    let valuation = vn as i32 - vd as i32;
380    let n0 = unipoly_strip_low(&n_uni, vn);
381    let d0 = unipoly_strip_low(&d_uni, vd);
382
383    let d0c = d0.coefficients();
384    if d0c.is_empty() || d0c[0] == 0 {
385        let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
386        return Ok(LocalExpansion {
387            valuation: 0,
388            coeffs,
389            h_expr,
390        });
391    }
392
393    let n0_e = n0.to_symbolic_expr(pool);
394    let d0_e = d0.to_symbolic_expr(pool);
395    let inv_d = pool.pow(d0_e, pool.integer(-1_i32));
396    let g = simplify(pool.mul(vec![n0_e, inv_d]), pool).value;
397
398    let num_taylor: u32 = if valuation < 0 {
399        order
400    } else {
401        (order as i32 - valuation).max(0) as u32
402    };
403
404    if num_taylor == 0 {
405        return Ok(LocalExpansion {
406            valuation,
407            coeffs: Vec::new(),
408            h_expr,
409        });
410    }
411
412    let coeffs = taylor_coefficients(g, xi, num_taylor, pool)?;
413    Ok(LocalExpansion {
414        valuation,
415        coeffs,
416        h_expr,
417    })
418}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};

    /// Searches `id` recursively for a `BigO` node through sums, products,
    /// powers and function arguments.
    fn contains_big_o(id: ExprId, pool: &ExprPool) -> bool {
        match pool.get(id) {
            ExprData::BigO(_) => true,
            ExprData::Add(xs) | ExprData::Mul(xs) => xs.iter().any(|e| contains_big_o(*e, pool)),
            ExprData::Pow { base, exp } => contains_big_o(base, pool) || contains_big_o(exp, pool),
            ExprData::Func { args, .. } => args.iter().any(|e| contains_big_o(*e, pool)),
            _ => false,
        }
    }

    #[test]
    fn series_cos_about_zero_has_big_o() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.integer(0);
        let cx = p.func("cos", vec![x]);
        let s = series(cx, x, z, 6, &p).unwrap();
        assert!(contains_big_o(s.expr(), &p));
    }

    #[test]
    fn series_inv_x_laurent_has_big_o() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.integer(0);
        let ix = p.pow(x, p.integer(-1));
        let s = series(ix, x, z, 4, &p).unwrap();
        assert!(contains_big_o(s.expr(), &p));
    }

    /// `order == 0` must be rejected up front rather than producing an empty series.
    #[test]
    fn series_rejects_zero_order() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.integer(0);
        let cx = p.func("cos", vec![x]);
        assert!(matches!(
            series(cx, x, z, 0, &p),
            Err(SeriesError::InvalidOrder)
        ));
    }

    /// Remainder exponent: `order` for analytic expansions, `1` once a pole is present.
    #[test]
    fn big_o_exponent_follows_valuation_sign() {
        assert_eq!(laurent_big_o_pow(-1, 4), 1);
        assert_eq!(laurent_big_o_pow(0, 4), 4);
        assert_eq!(laurent_big_o_pow(2, 6), 6);
    }

    #[test]
    fn factorial_small_values() {
        assert_eq!(factorial_u32(0), 1);
        assert_eq!(factorial_u32(1), 1);
        assert_eq!(factorial_u32(5), 120);
        assert_eq!(factorial_u32(10), 3_628_800);
    }
}