Skip to main content

alkahest_cas/calculus/
series.rs

1//! Truncated Taylor / Laurent series with symbolic [`crate::kernel::ExprData::BigO`] remainder (V2-15).
2
3use crate::diff::{diff, DiffError};
4use crate::flint::FlintPoly;
5use crate::kernel::{subs, Domain, ExprData, ExprId, ExprPool};
6use crate::poly::{RationalFunction, UniPoly};
7use crate::simplify::simplify;
8use std::collections::HashMap;
9use std::fmt;
10
11// ---------------------------------------------------------------------------
12// Public types
13// ---------------------------------------------------------------------------
14
15/// Result of [`series`] — truncated expansion plus big-O bound as one [`ExprId`].
16#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
17pub struct Series(pub ExprId);
18
19impl Series {
20    pub fn expr(self) -> ExprId {
21        self.0
22    }
23}
24
25#[derive(Debug)]
26pub enum SeriesError {
27    /// Differentiation failed while forming Taylor coefficients.
28    Diff(DiffError),
29    /// `order` must be positive.
30    InvalidOrder,
31}
32
33impl fmt::Display for SeriesError {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            SeriesError::Diff(e) => write!(f, "{e}"),
37            SeriesError::InvalidOrder => write!(f, "series order must be >= 1"),
38        }
39    }
40}
41
42impl std::error::Error for SeriesError {
43    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
44        match self {
45            SeriesError::Diff(e) => Some(e),
46            SeriesError::InvalidOrder => None,
47        }
48    }
49}
50
51impl crate::errors::AlkahestError for SeriesError {
52    fn code(&self) -> &'static str {
53        match self {
54            SeriesError::Diff(_) => "E-SERIES-001",
55            SeriesError::InvalidOrder => "E-SERIES-002",
56        }
57    }
58
59    fn remediation(&self) -> Option<&'static str> {
60        match self {
61            SeriesError::Diff(_) => {
62                Some("ensure all functions are registered primitives with differentiation rules")
63            }
64            SeriesError::InvalidOrder => Some("pass order >= 1 (exclusive truncation degree in x)"),
65        }
66    }
67}
68
69impl From<DiffError> for SeriesError {
70    fn from(e: DiffError) -> Self {
71        SeriesError::Diff(e)
72    }
73}
74
75// ---------------------------------------------------------------------------
76// Entry point
77// ---------------------------------------------------------------------------
78
79/// Truncated Taylor or Laurent expansion of `expr` in `var` about `point`.
80///
81/// Let `h = var - point`. The returned expression has the shape
82/// `⋯ + O(h^k)` where `k = order` for analytic series (`valuation ≥ 0`), and
83/// `k = 1` when a polar term (`valuation < 0`) is present — matching the
84/// Laurent examples in the roadmap (`1/x` about `0` gives `x⁻¹ + O(x)`).
85///
86/// The `order` parameter matches the Taylor convention used in the roadmap:
87/// include powers `h^e` with `valuation ≤ e < order` when `valuation ≥ 0`, and
88/// when `valuation < 0` include the polar tail using `order` Taylor coefficients
89/// of the analytic factor `h^{-valuation} · f`.
90pub fn series(
91    expr: ExprId,
92    var: ExprId,
93    point: ExprId,
94    order: u32,
95    pool: &ExprPool,
96) -> Result<Series, SeriesError> {
97    let LocalExpansion {
98        valuation,
99        coeffs,
100        h_expr,
101    } = local_expansion(expr, var, point, order, pool)?;
102
103    Ok(assemble_series(&coeffs, valuation, h_expr, order, pool))
104}
105
106// ---------------------------------------------------------------------------
107// Internals
108// ---------------------------------------------------------------------------
109
110/// Local Laurent / Taylor data about `point`: `expr = ∑ᵢ coeffᵢ · h^{valuation+i}` up to truncation.
111///
112/// `h` is `var - point`, or bare `var` when `point` is the integer zero (matching [`series`]).
113#[derive(Clone, Debug)]
114pub(crate) struct LocalExpansion {
115    pub valuation: i32,
116    pub coeffs: Vec<ExprId>,
117    pub h_expr: ExprId,
118}
119
120pub(crate) fn local_expansion(
121    expr: ExprId,
122    var: ExprId,
123    point: ExprId,
124    order: u32,
125    pool: &ExprPool,
126) -> Result<LocalExpansion, SeriesError> {
127    if order == 0 {
128        return Err(SeriesError::InvalidOrder);
129    }
130
131    let xi = pool.symbol("__sxp", Domain::Real);
132    let mut map = HashMap::new();
133    map.insert(var, pool.add(vec![point, xi]));
134    let shifted = subs(expr, &map, pool);
135
136    let h_expr = expansion_increment(pool, var, point);
137
138    expansion_matched_laurent(shifted, xi, h_expr, order, pool)
139}
140
141fn factorial_u32(n: u32) -> rug::Integer {
142    let mut r = rug::Integer::from(1);
143    for i in 2..=n {
144        r *= i;
145    }
146    r
147}
148
149fn expansion_increment(pool: &ExprPool, var: ExprId, point: ExprId) -> ExprId {
150    match pool.get(point) {
151        ExprData::Integer(n) if n.0 == 0 => var,
152        _ => pool.add(vec![var, pool.mul(vec![pool.integer(-1_i32), point])]),
153    }
154}
155
/// Exponent `k` of the big-O remainder `O(h^k)`.
///
/// This is the requested `order` when there is no pole (`valuation ≥ 0`), and `1` when the
/// expansion has a pole.
fn laurent_big_o_pow(valuation: i32, order: u32) -> i64 {
    if valuation >= 0 {
        i64::from(order)
    } else {
        1
    }
}
163
164fn is_structural_zero(id: ExprId, pool: &ExprPool) -> bool {
165    matches!(pool.get(id), ExprData::Integer(n) if n.0 == 0)
166}
167
168fn collect_atom_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
169    match pool.get(expr) {
170        ExprData::Pow { base, exp } => {
171            let n = pool.with(exp, |d| match d {
172                ExprData::Integer(i) => Some(i.0.clone()),
173                _ => None,
174            })?;
175            if n > 0 {
176                Some((vec![expr], vec![]))
177            } else if n < 0 {
178                let mag = (-n).to_u32()?;
179                let pos_exp = pool.integer(mag as i64);
180                Some((vec![], vec![pool.pow(base, pos_exp)]))
181            } else {
182                Some((vec![pool.integer(1_i32)], vec![]))
183            }
184        }
185        ExprData::Integer(_)
186        | ExprData::Rational(_)
187        | ExprData::Float(_)
188        | ExprData::Symbol { .. }
189        | ExprData::Func { .. } => Some((vec![expr], vec![])),
190        ExprData::Add(_)
191        | ExprData::Mul(_)
192        | ExprData::Piecewise { .. }
193        | ExprData::Predicate { .. }
194        | ExprData::Forall { .. }
195        | ExprData::Exists { .. }
196        | ExprData::RootSum { .. }
197        | ExprData::BigO(_) => None,
198    }
199}
200
201fn collect_term_factors(expr: ExprId, pool: &ExprPool) -> Option<(Vec<ExprId>, Vec<ExprId>)> {
202    match pool.get(expr) {
203        ExprData::Mul(args) => {
204            let mut nums = Vec::new();
205            let mut dens = Vec::new();
206            for &a in &args {
207                let (n, d) = collect_atom_factors(a, pool)?;
208                nums.extend(n);
209                dens.extend(d);
210            }
211            Some((nums, dens))
212        }
213        _ => collect_atom_factors(expr, pool),
214    }
215}
216
217fn product_sorted(pool: &ExprPool, factors: Vec<ExprId>) -> ExprId {
218    match factors.len() {
219        0 => pool.integer(1_i32),
220        1 => factors[0],
221        _ => pool.mul(factors),
222    }
223}
224
225fn unipoly_valuation(p: &UniPoly) -> Option<u32> {
226    for (i, c) in p.coefficients().into_iter().enumerate() {
227        if c != 0 {
228            return Some(i as u32);
229        }
230    }
231    None
232}
233
234fn unipoly_strip_low(p: &UniPoly, k: u32) -> UniPoly {
235    let coeffs: Vec<rug::Integer> = p.coefficients().into_iter().skip(k as usize).collect();
236    UniPoly {
237        var: p.var,
238        coeffs: FlintPoly::from_rug_coefficients(&coeffs),
239    }
240}
241
242fn taylor_coefficients(
243    mut cur: ExprId,
244    xi: ExprId,
245    num: u32,
246    pool: &ExprPool,
247) -> Result<Vec<ExprId>, SeriesError> {
248    let mut mapping = HashMap::new();
249    mapping.insert(xi, pool.integer(0_i32));
250    let mut out = Vec::with_capacity(num as usize);
251    for k in 0..num {
252        let ev = subs(cur, &mapping, pool);
253        let simp = simplify(ev, pool).value;
254        let fc = factorial_u32(k);
255        let inv_fact = pool.rational(rug::Integer::from(1), fc);
256        let coeff = simplify(pool.mul(vec![simp, inv_fact]), pool).value;
257        out.push(coeff);
258        if k + 1 < num {
259            cur = diff(cur, xi, pool)?.value;
260        }
261    }
262    Ok(out)
263}
264
265fn assemble_series(
266    coeffs: &[ExprId],
267    valuation: i32,
268    h_expr: ExprId,
269    order: u32,
270    pool: &ExprPool,
271) -> Series {
272    let mut terms = Vec::new();
273    for (k, coeff) in coeffs.iter().enumerate() {
274        if is_structural_zero(*coeff, pool) {
275            continue;
276        }
277        let exp = valuation + k as i32;
278        let pow_term = if exp == 0 {
279            pool.integer(1_i32)
280        } else if exp == 1 {
281            h_expr
282        } else {
283            pool.pow(h_expr, pool.integer(exp as i64))
284        };
285        terms.push(pool.mul(vec![*coeff, pow_term]));
286    }
287    let big_o_pow = laurent_big_o_pow(valuation, order);
288    let o_term = pool.big_o(pool.pow(h_expr, pool.integer(big_o_pow)));
289    terms.push(o_term);
290    Series(pool.add(terms))
291}
292
293fn expansion_matched_laurent(
294    shifted: ExprId,
295    xi: ExprId,
296    h_expr: ExprId,
297    order: u32,
298    pool: &ExprPool,
299) -> Result<LocalExpansion, SeriesError> {
300    let (nums, dens) = match collect_term_factors(shifted, pool) {
301        Some(p) => p,
302        None => {
303            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
304            return Ok(LocalExpansion {
305                valuation: 0,
306                coeffs,
307                h_expr,
308            });
309        }
310    };
311
312    let n_expr = product_sorted(pool, nums);
313    let d_expr = product_sorted(pool, dens);
314
315    let rf = match RationalFunction::from_symbolic(n_expr, d_expr, vec![xi], pool) {
316        Ok(r) => r,
317        Err(_) => {
318            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
319            return Ok(LocalExpansion {
320                valuation: 0,
321                coeffs,
322                h_expr,
323            });
324        }
325    };
326
327    if rf.numer.is_zero() {
328        return Ok(LocalExpansion {
329            valuation: 0,
330            coeffs: vec![pool.integer(0_i32)],
331            h_expr,
332        });
333    }
334
335    let n_uni = match UniPoly::from_symbolic(rf.numer.to_expr(pool), xi, pool) {
336        Ok(u) => u,
337        Err(_) => {
338            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
339            return Ok(LocalExpansion {
340                valuation: 0,
341                coeffs,
342                h_expr,
343            });
344        }
345    };
346    let d_uni = match UniPoly::from_symbolic(rf.denom.to_expr(pool), xi, pool) {
347        Ok(u) => u,
348        Err(_) => {
349            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
350            return Ok(LocalExpansion {
351                valuation: 0,
352                coeffs,
353                h_expr,
354            });
355        }
356    };
357
358    let vn = match unipoly_valuation(&n_uni) {
359        Some(v) => v,
360        None => {
361            return Ok(LocalExpansion {
362                valuation: 0,
363                coeffs: vec![pool.integer(0_i32)],
364                h_expr,
365            });
366        }
367    };
368    let vd = match unipoly_valuation(&d_uni) {
369        Some(v) => v,
370        None => {
371            let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
372            return Ok(LocalExpansion {
373                valuation: 0,
374                coeffs,
375                h_expr,
376            });
377        }
378    };
379
380    let valuation = vn as i32 - vd as i32;
381    let n0 = unipoly_strip_low(&n_uni, vn);
382    let d0 = unipoly_strip_low(&d_uni, vd);
383
384    let d0c = d0.coefficients();
385    if d0c.is_empty() || d0c[0] == 0 {
386        let coeffs = taylor_coefficients(shifted, xi, order, pool)?;
387        return Ok(LocalExpansion {
388            valuation: 0,
389            coeffs,
390            h_expr,
391        });
392    }
393
394    let n0_e = n0.to_symbolic_expr(pool);
395    let d0_e = d0.to_symbolic_expr(pool);
396    let inv_d = pool.pow(d0_e, pool.integer(-1_i32));
397    let g = simplify(pool.mul(vec![n0_e, inv_d]), pool).value;
398
399    let num_taylor: u32 = if valuation < 0 {
400        order
401    } else {
402        (order as i32 - valuation).max(0) as u32
403    };
404
405    if num_taylor == 0 {
406        return Ok(LocalExpansion {
407            valuation,
408            coeffs: Vec::new(),
409            h_expr,
410        });
411    }
412
413    let coeffs = taylor_coefficients(g, xi, num_taylor, pool)?;
414    Ok(LocalExpansion {
415        valuation,
416        coeffs,
417        h_expr,
418    })
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprData};

    /// Whether a `BigO` node appears anywhere in `id`.
    ///
    /// The search looks through sums, products, both sides of powers, and function arguments;
    /// other node kinds are not searched.
    fn contains_big_o(id: ExprId, pool: &ExprPool) -> bool {
        let search = |child: &ExprId| contains_big_o(*child, pool);
        match pool.get(id) {
            ExprData::BigO(_) => true,
            ExprData::Add(children) | ExprData::Mul(children) => children.iter().any(search),
            ExprData::Pow { base, exp } => [base, exp].iter().any(search),
            ExprData::Func { args, .. } => args.iter().any(search),
            _ => false,
        }
    }

    #[test]
    fn series_cos_about_zero_has_big_o() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let origin = pool.integer(0);
        let cos_x = pool.func("cos", vec![x]);
        let result = series(cos_x, x, origin, 6, &pool).unwrap();
        assert!(contains_big_o(result.expr(), &pool));
    }

    #[test]
    fn series_inv_x_laurent_has_big_o() {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let origin = pool.integer(0);
        let inv_x = pool.pow(x, pool.integer(-1));
        let result = series(inv_x, x, origin, 4, &pool).unwrap();
        assert!(contains_big_o(result.expr(), &pool));
    }
}