Skip to main content

alkahest_cas/transform/
ztransform.rs

1//! Symbolic (unilateral) Z-transform `Z{a[n]}(z)` and inverse `Z⁻¹{A(z)}(n)`.
2//!
3//! # Forward transform
4//!
5//! [`z_transform`] computes the *unilateral* Z-transform
6//!
7//! ```text
8//!   Z{a[n]}(z) = Σ_{n≥0} a[n] z^{−n}.
9//! ```
10//!
11//! It is a rule/table-based structural recursion over `a[n]` (mirroring
12//! [`crate::transform::laplace`]):
13//!
14//! | `a[n]`                 | `Z{a}(z)`                              | rule              |
15//! |------------------------|----------------------------------------|-------------------|
16//! | `c` (const)            | `c·z/(z−1)`                             | constant          |
17//! | `n`                    | `z/(z−1)²`                              | ramp              |
18//! | `n²`                   | `z(z+1)/(z−1)³`                         | quadratic ramp    |
19//! | `aⁿ`                   | `z/(z−a)`                               | geometric         |
20//! | `n·aⁿ`                 | `a z/(z−a)²`                            | scaled-diff geom. |
21//! | `sin(ω n)`             | `z sin(ω) / (z² − 2 z cos(ω) + 1)`      | sine              |
22//! | `cos(ω n)`             | `z(z − cos(ω)) / (z² − 2 z cos(ω) + 1)` | cosine            |
23//! | `α·a[n] + β·b[n]`      | `α A(z) + β B(z)`                       | linearity         |
24//! | `aⁿ·x[n]`              | `X(z/a)`                                | scaling theorem   |
25//! | `n·x[n]`               | `−z·dX/dz`                              | differentiation   |
26//!
27//! The unilateral shift theorems are exposed separately (they operate on the
28//! *symbol* `X = Z{x}` plus initial values, since `x` itself is an unknown
29//! sequence — exactly as [`crate::transform::laplace::laplace_derivative_rule`]
30//! does for the derivative rule):
31//!
32//! - [`z_shift_delay`]: `x[n−k] ↦ z^{−k} X(z)` (zero initial conditions assumed
33//!   for the "missing" samples `x[−1], …, x[−k]`).
34//! - [`z_shift_advance`]: the *unilateral* advance
35//!   `x[n+1] ↦ z·X(z) − z·x[0]`, needed to translate difference equations
36//!   `a[n+1] = a[n] + a[n−1]` (etc.) into algebraic equations in `Z{a}`.
37//!
38//! # Inverse transform
39//!
40//! [`inverse_z_transform`] inverts a **rational** `X(z)` by writing
41//! `X(z)/z` in partial fractions (via [`crate::poly::apart`]), multiplying each
42//! term back by `z`, and mapping the resulting `z/(z−a)^k` shapes through the
43//! inverse table:
44//!
45//! | term in `X(z)`           | `Z⁻¹` term (`n ≥ 0`)                  |
46//! |---------------------------|----------------------------------------|
47//! | `A·z/(z−a)`               | `A·aⁿ`                                  |
48//! | `A·z/(z−a)²`              | `A·n·aⁿ⁻¹`  (rewritten as `(A/a)·n·aⁿ`) |
49//! | `A·z/(z−1)`               | `A` (constant)                          |
50//! | `(P z² + Q z)/(z² + b z + c)` (`b² − 4c < 0`) | `rⁿ(A cos θn + B sin θn)`, **real** |
51//!
52//! The last row covers an **irreducible quadratic** denominator — a
53//! complex-conjugate pole pair `r e^{±iθ}` with `r = √c`, `θ = acos(−b/2√c)` —
54//! and emits the **real** damped sinusoid (no imaginary unit in the output;
55//! the `i² = −1` collapse happens inside the derivation, not the result).  For
56//! example `X(z) = z/(z² − z + 1)` inverts to `(2/√3)·sin(π n / 3)`, which the
57//! forward table round-trips back to `z/(z² − z + 1)`.
58//!
59//! Higher-order repeated poles `(z−a)^k`, `k ≥ 3`, *repeated* complex poles
60//! (`k ≥ 2`), and quadratic denominators with **non-negative discriminant**
61//! (real, possibly surd, roots — e.g. the Fibonacci denominator `z² − z − 1`,
62//! discriminant `5`) remain declined (outside the table — see
63//! [`ZTransformError`]).  Such surd-root cases factor only over an algebraic
64//! extension and have no rational-coefficient closed form here.
65//!
66//! # Caveats
67//!
68//! Both directions are **formal**: this is the *unilateral* transform with no
69//! region-of-convergence tracked, matching [`crate::transform::laplace`]'s
70//! `noconds=True`-style convention. Unrecognised forms return
71//! [`ZTransformError::NoRule`] (forward) / [`ZTransformError::NotInvertible`]
72//! (inverse) rather than guessing.
73//!
74//! ## Declined table entries
75//!
76//! The planning document additionally lists `binomial(n+k−1, k−1)·aⁿ` (negative
77//! binomial / generalized geometric series) and the Kronecker delta
78//! `δ[n−k] ↦ z^{−k}`. Alkahest has no `binomial(·,·)` or discrete-delta
79//! expression primitive, so both are **out of scope** for the
80//! expression-pattern table here: there is nothing in the kernel's expression
81//! algebra that would match `δ[n−k]` (it is not `DiracDelta`, which is the
82//! *continuous* impulse used by [`crate::transform::laplace`], and a discrete
83//! Kronecker delta is a different object). Adding either would require a new
84//! primitive (out of scope for an additive, non-primitive-registry change).
85
86use crate::kernel::{ExprData, ExprId, ExprPool};
87
88/// Errors from the Z-transform routines.
89#[derive(Debug, Clone, PartialEq, Eq)]
90pub enum ZTransformError {
91    /// No forward rule matched `a[n]` (E-TRANSFORM-101).
92    NoRule(String),
93    /// The inverse-transform input is not a form the table can invert
94    /// (E-TRANSFORM-102).
95    NotInvertible(String),
96    /// The frequency variable `z` and discrete-index variable `n` must be
97    /// distinct symbols (E-TRANSFORM-103).
98    SameVariable,
99}
100
101impl std::fmt::Display for ZTransformError {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        match self {
104            ZTransformError::NoRule(m) => {
105                write!(f, "z_transform: no rule for {m} [E-TRANSFORM-101]")
106            }
107            ZTransformError::NotInvertible(m) => write!(
108                f,
109                "inverse_z_transform: cannot invert {m} [E-TRANSFORM-102]"
110            ),
111            ZTransformError::SameVariable => write!(
112                f,
113                "z_transform: index and frequency variables must differ [E-TRANSFORM-103]"
114            ),
115        }
116    }
117}
118
119impl std::error::Error for ZTransformError {}
120
121// ===========================================================================
122// Small helpers (mirroring transform::laplace)
123// ===========================================================================
124
125fn is_free_of(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
126    crate::integrate::risch::poly_rde::is_free_of_var(expr, var, pool)
127}
128
129/// Extract `(a, b)` from `a·var + b` with `a, b` free of `var`. `None` if not
130/// affine in `var`.
131fn as_affine(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
132    if expr == var {
133        return Some((pool.integer(1_i32), pool.integer(0_i32)));
134    }
135    if is_free_of(expr, var, pool) {
136        return Some((pool.integer(0_i32), expr));
137    }
138    match pool.get(expr) {
139        ExprData::Mul(_) => {
140            let (a, b) = as_affine_term(expr, var, pool)?;
141            if b == pool.integer(0_i32) {
142                Some((a, pool.integer(0_i32)))
143            } else {
144                None
145            }
146        }
147        ExprData::Add(args) => {
148            let mut a_acc: Vec<ExprId> = Vec::new();
149            let mut b_acc: Vec<ExprId> = Vec::new();
150            for arg in args {
151                if is_free_of(arg, var, pool) {
152                    b_acc.push(arg);
153                } else {
154                    let (a, b) = as_affine_term(arg, var, pool)?;
155                    if b != pool.integer(0_i32) {
156                        return None;
157                    }
158                    a_acc.push(a);
159                }
160            }
161            let a = match a_acc.len() {
162                0 => pool.integer(0_i32),
163                1 => a_acc[0],
164                _ => pool.add(a_acc),
165            };
166            let b = match b_acc.len() {
167                0 => pool.integer(0_i32),
168                1 => b_acc[0],
169                _ => pool.add(b_acc),
170            };
171            Some((a, b))
172        }
173        _ => None,
174    }
175}
176
177fn as_affine_term(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
178    if expr == var {
179        return Some((pool.integer(1_i32), pool.integer(0_i32)));
180    }
181    if is_free_of(expr, var, pool) {
182        return Some((pool.integer(0_i32), expr));
183    }
184    if let ExprData::Mul(args) = pool.get(expr) {
185        let pos = args.iter().position(|&a| a == var)?;
186        let others: Vec<ExprId> = args
187            .iter()
188            .enumerate()
189            .filter(|&(i, _)| i != pos)
190            .map(|(_, &a)| a)
191            .collect();
192        if others.iter().all(|&o| is_free_of(o, var, pool)) {
193            let coeff = match others.len() {
194                0 => pool.integer(1_i32),
195                1 => others[0],
196                _ => pool.mul(others),
197            };
198            return Some((coeff, pool.integer(0_i32)));
199        }
200    }
201    None
202}
203
204fn simp(expr: ExprId, pool: &ExprPool) -> ExprId {
205    crate::simplify::simplify(expr, pool).value
206}
207
208fn neg(expr: ExprId, pool: &ExprPool) -> ExprId {
209    pool.mul(vec![pool.integer(-1_i32), expr])
210}
211
212fn recip(expr: ExprId, pool: &ExprPool) -> ExprId {
213    pool.pow(expr, pool.integer(-1_i32))
214}
215
216/// Substitute every occurrence of `from` with `to` in `expr`.
217fn subs_one(expr: ExprId, from: ExprId, to: ExprId, pool: &ExprPool) -> ExprId {
218    let mut map = std::collections::HashMap::new();
219    map.insert(from, to);
220    crate::kernel::subs(expr, &map, pool)
221}
222
223/// Remove the factor at `idx` from `factors`, returning the product of the
224/// rest (or `1` if none remain).
225fn remove_index(factors: &[ExprId], idx: usize, pool: &ExprPool) -> ExprId {
226    let rest: Vec<ExprId> = factors
227        .iter()
228        .enumerate()
229        .filter(|&(i, _)| i != idx)
230        .map(|(_, &f)| f)
231        .collect();
232    match rest.len() {
233        0 => pool.integer(1_i32),
234        1 => rest[0],
235        _ => pool.mul(rest),
236    }
237}
238
239/// Non-negative integer value of an exponent `ExprId`, if it is one.
240fn nonneg_int_exp(exp: ExprId, pool: &ExprPool) -> Option<u64> {
241    if let ExprData::Integer(n) = pool.get(exp) {
242        let n = n.0;
243        if n >= 0 {
244            return n.to_u64();
245        }
246    }
247    None
248}
249
250// ===========================================================================
251// Forward transform
252// ===========================================================================
253
254const MAX_DEPTH: usize = 32;
255
256/// Compute the unilateral Z-transform `Z{a[n]}(z) = Σ_{n≥0} a[n] z^{−n}`.
257///
258/// `n` is the discrete-index variable, `z` the transform variable; both must
259/// be distinct symbols. This is a *formal* transform — see the
260/// [module docs](self) for the rule table, caveats, and declines.
261///
262/// # Errors
263///
264/// - [`ZTransformError::SameVariable`] if `n == z`.
265/// - [`ZTransformError::NoRule`] if no table rule matches `a[n]`.
266///
267/// # Examples
268///
269/// ```
270/// use alkahest_cas::kernel::{Domain, ExprPool};
271/// use alkahest_cas::simplify::simplify;
272/// use alkahest_cas::transform::z_transform;
273///
274/// let pool = ExprPool::new();
275/// let n = pool.symbol("n", Domain::Real);
276/// let z = pool.symbol("z", Domain::Real);
277/// // Z{1}(z) = z/(z-1)
278/// let one = pool.integer(1_i32);
279/// let big_x = z_transform(one, n, z, &pool).unwrap();
280/// let expected = pool.mul(vec![
281///     z,
282///     pool.pow(pool.add(vec![z, pool.integer(-1_i32)]), pool.integer(-1_i32)),
283/// ]);
284/// assert_eq!(
285///     pool.display(big_x).to_string(),
286///     pool.display(simplify(expected, &pool).value).to_string()
287/// );
288/// ```
289pub fn z_transform(
290    a: ExprId,
291    n: ExprId,
292    z: ExprId,
293    pool: &ExprPool,
294) -> Result<ExprId, ZTransformError> {
295    if n == z {
296        return Err(ZTransformError::SameVariable);
297    }
298    let out = z_inner(a, n, z, pool, 0)?;
299    Ok(simp(out, pool))
300}
301
302fn z_inner(
303    a: ExprId,
304    n: ExprId,
305    z: ExprId,
306    pool: &ExprPool,
307    depth: usize,
308) -> Result<ExprId, ZTransformError> {
309    if depth > MAX_DEPTH {
310        return Err(ZTransformError::NoRule("recursion depth exceeded".into()));
311    }
312
313    // Constant (free of n): Z{c} = c·z/(z−1).
314    if is_free_of(a, n, pool) {
315        return Ok(pool.mul(vec![a, geometric_transform(pool.integer(1_i32), z, pool)]));
316    }
317
318    // Bare n: Z{n} = z/(z−1)².
319    if a == n {
320        return Ok(ramp_transform(z, pool));
321    }
322
323    match pool.get(a) {
324        // Linearity over sums.
325        ExprData::Add(args) => {
326            let mut terms = Vec::with_capacity(args.len());
327            for arg in args {
328                terms.push(z_inner(arg, n, z, pool, depth + 1)?);
329            }
330            Ok(pool.add(terms))
331        }
332
333        // Products: split off the n-free scalar (linearity), then dispatch the
334        // remaining n-dependent factor through the structural product rules.
335        ExprData::Mul(args) => z_mul(&args, n, z, pool, depth),
336
337        // n^2 (other integer powers of n are not in the table).
338        ExprData::Pow { base, exp } if base == n => {
339            if nonneg_int_exp(exp, pool) == Some(2) {
340                Ok(quadratic_ramp_transform(z, pool))
341            } else {
342                Err(ZTransformError::NoRule(format!(
343                    "n^e (only n and n^2 are tabulated): {}",
344                    pool.display(a)
345                )))
346            }
347        }
348
349        // a^n.
350        ExprData::Pow { base, exp } if exp == n => {
351            if is_free_of(base, n, pool) {
352                Ok(geometric_transform(base, z, pool))
353            } else {
354                Err(ZTransformError::NoRule(format!(
355                    "base^n with base depending on n: {}",
356                    pool.display(a)
357                )))
358            }
359        }
360
361        ExprData::Func { name, args } if args.len() == 1 => z_func(&name, args[0], n, z, pool),
362
363        _ => Err(ZTransformError::NoRule(pool.display(a).to_string())),
364    }
365}
366
367/// `Z{a^n}(z) = z/(z − a)`.
368fn geometric_transform(base: ExprId, z: ExprId, pool: &ExprPool) -> ExprId {
369    let denom = pool.add(vec![z, neg(base, pool)]);
370    pool.mul(vec![z, recip(denom, pool)])
371}
372
373/// `Z{n}(z) = z/(z − 1)²`.
374fn ramp_transform(z: ExprId, pool: &ExprPool) -> ExprId {
375    let denom = pool.pow(pool.add(vec![z, pool.integer(-1_i32)]), pool.integer(2_i32));
376    pool.mul(vec![z, recip(denom, pool)])
377}
378
379/// `Z{n²}(z) = z(z + 1) / (z − 1)³`.
380fn quadratic_ramp_transform(z: ExprId, pool: &ExprPool) -> ExprId {
381    let numer = pool.mul(vec![z, pool.add(vec![z, pool.integer(1_i32)])]);
382    let denom = pool.pow(pool.add(vec![z, pool.integer(-1_i32)]), pool.integer(3_i32));
383    pool.mul(vec![numer, recip(denom, pool)])
384}
385
386/// Z-transform of a product `∏ args`.
387fn z_mul(
388    args: &[ExprId],
389    n: ExprId,
390    z: ExprId,
391    pool: &ExprPool,
392    depth: usize,
393) -> Result<ExprId, ZTransformError> {
394    // Pull out the constant (n-free) scalar prefactor.
395    let (consts, rest): (Vec<ExprId>, Vec<ExprId>) =
396        args.iter().partition(|&&a| is_free_of(a, n, pool));
397    let scalar = match consts.len() {
398        0 => None,
399        1 => Some(consts[0]),
400        _ => Some(pool.mul(consts.clone())),
401    };
402
403    let inner = match rest.len() {
404        0 => {
405            let c = scalar.unwrap_or_else(|| pool.integer(1_i32));
406            return Ok(pool.mul(vec![c, geometric_transform(pool.integer(1_i32), z, pool)]));
407        }
408        1 => rest[0],
409        _ => pool.mul(rest.clone()),
410    };
411
412    let transformed = z_product_body(inner, n, z, pool, depth)?;
413    Ok(match scalar {
414        Some(c) => pool.mul(vec![c, transformed]),
415        None => transformed,
416    })
417}
418
419/// Transform an n-dependent product with no constant scalar factor, applying
420/// the structural product theorems (scaling `aⁿ·x[n]`, differentiation
421/// `n·x[n]`).
422fn z_product_body(
423    body: ExprId,
424    n: ExprId,
425    z: ExprId,
426    pool: &ExprPool,
427    depth: usize,
428) -> Result<ExprId, ZTransformError> {
429    let factors: Vec<ExprId> = match pool.get(body) {
430        ExprData::Mul(a) => a,
431        _ => vec![body],
432    };
433
434    // (1) aⁿ · x[n]  →  X(z/a)   [scaling theorem]
435    for (i, &fac) in factors.iter().enumerate() {
436        if let Some(a) = match_geometric(fac, n, pool) {
437            let rest = remove_index(&factors, i, pool);
438            let x_transform = z_inner(rest, n, z, pool, depth + 1)?;
439            let z_over_a = simp(pool.mul(vec![z, recip(a, pool)]), pool);
440            return Ok(subs_one(x_transform, z, z_over_a, pool));
441        }
442    }
443
444    // (2) n · x[n]  →  −z · dX/dz   [differentiation theorem]
445    for (i, &fac) in factors.iter().enumerate() {
446        if let Some(k) = match_n_power(fac, n, pool) {
447            let rest = remove_index(&factors, i, pool);
448            let mut x_transform = z_inner(rest, n, z, pool, depth + 1)?;
449            for _ in 0..k {
450                let dxdz = crate::diff::diff(x_transform, z, pool)
451                    .map_err(|_| ZTransformError::NoRule("differentiation theorem failed".into()))?
452                    .value;
453                x_transform = simp(pool.mul(vec![pool.integer(-1_i32), z, dxdz]), pool);
454            }
455            return Ok(x_transform);
456        }
457    }
458
459    // No product theorem applied. If `body` was not actually a product (a lone
460    // factor, e.g. a bare `cos(ω n)` whose n-free scalar was already peeled off
461    // by `z_mul`), fall back to the structural table. This cannot recurse
462    // forever: `z_inner` only re-enters `z_product_body` for a `Mul`, and here
463    // `body` is not a `Mul`.
464    if !matches!(pool.get(body), ExprData::Mul(_)) {
465        return z_inner(body, n, z, pool, depth + 1);
466    }
467
468    Err(ZTransformError::NoRule(pool.display(body).to_string()))
469}
470
471/// If `fac` is `aⁿ` (with `a` free of `n`, `a` not `±1`/trivial), return `a`.
472fn match_geometric(fac: ExprId, n: ExprId, pool: &ExprPool) -> Option<ExprId> {
473    if let ExprData::Pow { base, exp } = pool.get(fac) {
474        if exp == n && is_free_of(base, n, pool) {
475            return Some(base);
476        }
477    }
478    None
479}
480
481/// If `fac` is `n^k` with `k ∈ ℤ₊`, or bare `n`, return `k`.
482fn match_n_power(fac: ExprId, n: ExprId, pool: &ExprPool) -> Option<u64> {
483    if fac == n {
484        return Some(1);
485    }
486    if let ExprData::Pow { base, exp } = pool.get(fac) {
487        if base == n {
488            return nonneg_int_exp(exp, pool).filter(|&k| k >= 1);
489        }
490    }
491    None
492}
493
494/// Single-argument primitive functions: sin/cos.
495fn z_func(
496    name: &str,
497    arg: ExprId,
498    n: ExprId,
499    z: ExprId,
500    pool: &ExprPool,
501) -> Result<ExprId, ZTransformError> {
502    if matches!(name, "sin" | "cos") {
503        let (omega, off) = as_affine(arg, n, pool).ok_or_else(|| {
504            ZTransformError::NoRule(format!(
505                "{name} of non-affine argument: {}",
506                pool.display(arg)
507            ))
508        })?;
509        if off != pool.integer(0_i32) || omega == pool.integer(0_i32) {
510            return Err(ZTransformError::NoRule(format!(
511                "{name}(ω n): argument must be a nonzero multiple of n"
512            )));
513        }
514        let cos_w = pool.func("cos", vec![omega]);
515        let sin_w = pool.func("sin", vec![omega]);
516        let z2 = pool.pow(z, pool.integer(2_i32));
517        let two_z_cos = pool.mul(vec![pool.integer(2_i32), z, cos_w]);
518        // z² − 2z·cos(ω) + 1
519        let denom = pool.add(vec![z2, neg(two_z_cos, pool), pool.integer(1_i32)]);
520        return Ok(match name {
521            // sin(ωn) ↦ z·sin(ω) / (z² − 2z·cos(ω) + 1)
522            "sin" => {
523                let numer = pool.mul(vec![z, sin_w]);
524                pool.mul(vec![numer, recip(denom, pool)])
525            }
526            // cos(ωn) ↦ z(z − cos(ω)) / (z² − 2z·cos(ω) + 1)
527            "cos" => {
528                let z_minus_cos = pool.add(vec![z, neg(cos_w, pool)]);
529                let numer = pool.mul(vec![z, z_minus_cos]);
530                pool.mul(vec![numer, recip(denom, pool)])
531            }
532            _ => unreachable!(),
533        });
534    }
535
536    Err(ZTransformError::NoRule(format!("{name}(...)")))
537}
538
539// ===========================================================================
540// Shift theorems (for the difference-equation workflow)
541// ===========================================================================
542
543/// The Z-transform of the **delay** `x[n − k]` (`k ≥ 1`, zero initial
544/// conditions for `x[−1], …, x[−k]`) in terms of `X = Z{x}(z)`:
545///
546/// ```text
547///   Z{x[n − k]}(z) = z^{−k} X(z).
548/// ```
549///
550/// Because `x` is an *unknown* sequence, this operates on the placeholder
551/// `x_transform = X(z)` rather than a concrete `x[n]`. Mirrors
552/// [`crate::transform::laplace::laplace_derivative_rule`] for the ODE
553/// workflow.
554pub fn z_shift_delay(x_transform: ExprId, z: ExprId, k: u32, pool: &ExprPool) -> ExprId {
555    if k == 0 {
556        return x_transform;
557    }
558    let z_neg_k = pool.pow(z, pool.integer(-(k as i64)));
559    simp(pool.mul(vec![z_neg_k, x_transform]), pool)
560}
561
562/// The Z-transform of the unilateral **advance** `x[n + 1]` in terms of
563/// `X = Z{x}(z)` and the initial value `x[0]`:
564///
565/// ```text
566///   Z{x[n + 1]}(z) = z·X(z) − z·x[0].
567/// ```
568///
569/// More generally, the `order`-th advance `x[n + order]` is obtained by
570/// repeated application of this rule:
571///
572/// ```text
573///   Z{x[n + m]}(z) = z^m X(z) − Σ_{k=0}^{m−1} z^{m−k} x[k].
574/// ```
575///
576/// `initial_values[k]` must be `x[k]` for `k = 0, …, order − 1`. Missing
577/// trailing initial values default to `0`.
578pub fn z_shift_advance(
579    x_transform: ExprId,
580    z: ExprId,
581    order: u32,
582    initial_values: &[ExprId],
583    pool: &ExprPool,
584) -> ExprId {
585    // z^order X(z)
586    let z_m = pool.pow(z, pool.integer(order as i64));
587    let mut terms = vec![pool.mul(vec![z_m, x_transform])];
588    // − Σ_{k=0}^{order−1} z^{order−k} x[k]
589    for k in 0..order {
590        let xk = initial_values
591            .get(k as usize)
592            .copied()
593            .unwrap_or_else(|| pool.integer(0_i32));
594        if xk == pool.integer(0_i32) {
595            continue;
596        }
597        let power = (order - k) as i64;
598        let z_pow = pool.pow(z, pool.integer(power));
599        terms.push(pool.mul(vec![pool.integer(-1_i32), z_pow, xk]));
600    }
601    simp(pool.add(terms), pool)
602}
603
604// ===========================================================================
605// Inverse transform
606// ===========================================================================
607
608/// Compute the inverse Z-transform `Z⁻¹{X(z)}(n)` for a **rational** `X(z)`.
609///
610/// Strategy: write `X(z)/z` in partial fractions (via [`crate::poly::apart`]),
611/// multiply each term back by `z` (giving `z/(z−a)^k`-shaped terms), then map
612/// each through the inverse table. See the [module docs](self) for the table
613/// and caveats.
614///
615/// # Errors
616///
617/// - [`ZTransformError::SameVariable`] if `z == n`.
618/// - [`ZTransformError::NotInvertible`] for non-rational `X`, or a
619///   denominator factor outside the linear-pole table (repeated pole order
620///   `≥ 3`, or irreducible quadratic).
621pub fn inverse_z_transform(
622    big_x: ExprId,
623    z: ExprId,
624    n: ExprId,
625    pool: &ExprPool,
626) -> Result<ExprId, ZTransformError> {
627    if z == n {
628        return Err(ZTransformError::SameVariable);
629    }
630
631    // X(z)/z, partial-fractioned in z.
632    let x_over_z = simp(pool.mul(vec![big_x, recip(z, pool)]), pool);
633    let pf = crate::poly::apart(x_over_z, z, pool)
634        .map_err(|e| ZTransformError::NotInvertible(format!("apart failed: {e}")))?;
635
636    let pf_terms: Vec<ExprId> = match pool.get(pf) {
637        ExprData::Add(args) => args,
638        _ => vec![pf],
639    };
640
641    let mut out = Vec::with_capacity(pf_terms.len());
642    for term in pf_terms {
643        // Multiply this X(z)/z term back by z.
644        let term_z = simp(pool.mul(vec![term, z]), pool);
645        out.push(invert_term(term_z, z, n, pool)?);
646    }
647    Ok(simp(pool.add(out), pool))
648}
649
650/// Invert a single term `A·z^p·(z−a)^{−k}` (after re-multiplying the
651/// `apart(X(z)/z)` term by `z`); the table only covers `p == 1`.
652fn invert_term(
653    term: ExprId,
654    z: ExprId,
655    n: ExprId,
656    pool: &ExprPool,
657) -> Result<ExprId, ZTransformError> {
658    let (numer, base, k) = split_rational_term(term, pool)
659        .ok_or_else(|| ZTransformError::NotInvertible(pool.display(term).to_string()))?;
660
661    // A constant term (k == 0): Z⁻¹{c} would be `c·δ[n]`, which has no
662    // expression-level representation here (see module docs on the
663    // Kronecker delta) — decline rather than fabricate.
664    if k == 0 {
665        return Err(ZTransformError::NotInvertible(format!(
666            "constant term {} (Kronecker delta δ[n] — no discrete-impulse primitive)",
667            pool.display(term)
668        )));
669    }
670
671    // Irreducible quadratic denominator (complex-conjugate poles) → real
672    // damped sinusoid `rⁿ(A cos θn + B sin θn)`.  Handle this before the
673    // linear-pole numerator shape check, since here the numerator is a genuine
674    // degree-≤2 polynomial in z (e.g. `P z² + Q z`), not `A·zᵖ`.
675    if poly_degree(base, z, pool) == Some(2) {
676        return invert_quadratic_pole(numer, base, k, z, n, pool);
677    }
678
679    let (coeff, p) = split_z_power(numer, z, pool).ok_or_else(|| {
680        ZTransformError::NotInvertible(format!(
681            "linear-pole numerator not of the form A·z^p: {}",
682            pool.display(numer)
683        ))
684    })?;
685    if p != 1 {
686        return Err(ZTransformError::NotInvertible(format!(
687            "numerator power of z ({p}) not in the table (expected A·z)"
688        )));
689    }
690
691    match poly_degree(base, z, pool) {
692        Some(1) => invert_linear_pole(coeff, base, k, z, n, pool),
693        Some(d) => Err(ZTransformError::NotInvertible(format!(
694            "denominator factor of degree {d} (only linear poles are tabulated): {}",
695            pool.display(base)
696        ))),
697        None => Err(ZTransformError::NotInvertible(
698            pool.display(base).to_string(),
699        )),
700    }
701}
702
703/// Split `numer = coeff · z^p` with `coeff` free of `z` and `p ≥ 0` an
704/// integer. Returns `None` if `numer` is not of this shape.
705fn split_z_power(numer: ExprId, z: ExprId, pool: &ExprPool) -> Option<(ExprId, u64)> {
706    if numer == z {
707        return Some((pool.integer(1_i32), 1));
708    }
709    if is_free_of(numer, z, pool) {
710        return Some((numer, 0));
711    }
712    match pool.get(numer) {
713        ExprData::Pow { base, exp } if base == z => {
714            nonneg_int_exp(exp, pool).map(|p| (pool.integer(1_i32), p))
715        }
716        ExprData::Mul(args) => {
717            let mut coeff_parts: Vec<ExprId> = Vec::new();
718            let mut p = 0u64;
719            for a in args {
720                if a == z {
721                    p += 1;
722                    continue;
723                }
724                if let ExprData::Pow { base, exp } = pool.get(a) {
725                    if base == z {
726                        p += nonneg_int_exp(exp, pool)?;
727                        continue;
728                    }
729                }
730                if !is_free_of(a, z, pool) {
731                    return None;
732                }
733                coeff_parts.push(a);
734            }
735            let coeff = match coeff_parts.len() {
736                0 => pool.integer(1_i32),
737                1 => coeff_parts[0],
738                _ => pool.mul(coeff_parts),
739            };
740            Some((coeff, p))
741        }
742        _ => None,
743    }
744}
745
746/// Decompose a term into `(numerator, denom_base, k)` with
747/// `term = numerator · denom_base^{−k}` and `k ≥ 0`, `numerator` free of any
748/// negative power of `z`.
749fn split_rational_term(term: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId, u64)> {
750    let factors: Vec<ExprId> = match pool.get(term) {
751        ExprData::Mul(a) => a,
752        _ => vec![term],
753    };
754    let mut numer_parts: Vec<ExprId> = Vec::new();
755    let mut base: Option<ExprId> = None;
756    let mut k: u64 = 0;
757
758    for &fac in &factors {
759        if let ExprData::Pow { base: b, exp } = pool.get(fac) {
760            if let ExprData::Integer(e) = pool.get(exp) {
761                let ev = e.0;
762                if ev < 0 {
763                    if base.is_some() && base != Some(b) {
764                        return None;
765                    }
766                    base = Some(b);
767                    k = (-ev).to_u64()?;
768                    continue;
769                }
770            }
771        }
772        numer_parts.push(fac);
773    }
774
775    let numer = match numer_parts.len() {
776        0 => pool.integer(1_i32),
777        1 => numer_parts[0],
778        _ => pool.mul(numer_parts),
779    };
780    match base {
781        Some(b) => Some((numer, b, k)),
782        None => Some((numer, pool.integer(1_i32), 0)),
783    }
784}
785
786/// Degree of `base` as a polynomial in `z` (handles `z`, `z ± c` forms via
787/// structural inspection). Returns `None` if not obviously polynomial.
788fn poly_degree(base: ExprId, z: ExprId, pool: &ExprPool) -> Option<u64> {
789    if base == z {
790        return Some(1);
791    }
792    match pool.get(base) {
793        ExprData::Add(args) => {
794            let mut deg = 0u64;
795            for a in args {
796                deg = deg.max(monomial_degree(a, z, pool)?);
797            }
798            Some(deg)
799        }
800        ExprData::Pow { .. } | ExprData::Mul(_) => monomial_degree(base, z, pool),
801        _ if is_free_of(base, z, pool) => Some(0),
802        _ => None,
803    }
804}
805
806fn monomial_degree(term: ExprId, z: ExprId, pool: &ExprPool) -> Option<u64> {
807    if term == z {
808        return Some(1);
809    }
810    if is_free_of(term, z, pool) {
811        return Some(0);
812    }
813    match pool.get(term) {
814        ExprData::Pow { base, exp } if base == z => nonneg_int_exp(exp, pool),
815        ExprData::Mul(args) => {
816            let mut deg = 0u64;
817            for a in args {
818                deg += monomial_degree(a, z, pool)?;
819            }
820            Some(deg)
821        }
822        _ => None,
823    }
824}
825
826/// Invert `A·z·(z−a)^{−k}`:
827///
828/// - `k == 1`: `Z⁻¹{A·z/(z−a)} = A·aⁿ` (for `a == 1` this is the constant `A`).
829/// - `k == 2`: `Z⁻¹{A·z/(z−a)²} = (A/a)·n·aⁿ` (for `a ≠ 0`); for `a == 0` the
830///   term is `A·z^{-1}`, which is declined (anti-causal / improper for the
831///   unilateral table).
832fn invert_linear_pole(
833    numer: ExprId,
834    base: ExprId,
835    k: u64,
836    z: ExprId,
837    n: ExprId,
838    pool: &ExprPool,
839) -> Result<ExprId, ZTransformError> {
840    // `numer` is the coefficient `A` (free of `z` by construction of
841    // `split_z_power`). base = z − a (monic). Extract a from (coeff·z + b): a = −b/coeff,
842    // require coeff = 1.
843    let (coeff, b) = as_affine(base, z, pool)
844        .ok_or_else(|| ZTransformError::NotInvertible(pool.display(base).to_string()))?;
845    if coeff != pool.integer(1_i32) {
846        return Err(ZTransformError::NotInvertible(
847            "non-monic linear denominator".into(),
848        ));
849    }
850    let a = simp(neg(b, pool), pool); // a = −b
851
852    match k {
853        1 => {
854            if a == pool.integer(0_i32) {
855                // A·z/z = A·z^0 → constant A·δ-like term at n=0 only; decline
856                // (see module docs on Kronecker delta).
857                return Err(ZTransformError::NotInvertible(
858                    "A·z/z term reduces to a Kronecker delta (no discrete-impulse primitive)"
859                        .into(),
860                ));
861            }
862            if a == pool.integer(1_i32) {
863                // A·z/(z−1) → A (constant sequence)
864                return Ok(numer);
865            }
866            // A·aⁿ
867            let a_pow_n = pool.pow(a, n);
868            Ok(pool.mul(vec![numer, a_pow_n]))
869        }
870        2 => {
871            if a == pool.integer(0_i32) {
872                return Err(ZTransformError::NotInvertible(
873                    "A·z/z² term has no causal-sequence inverse in the table".into(),
874                ));
875            }
876            // (A/a)·n·aⁿ
877            let coeff = simp(pool.mul(vec![numer, recip(a, pool)]), pool);
878            let a_pow_n = pool.pow(a, n);
879            Ok(pool.mul(vec![coeff, n, a_pow_n]))
880        }
881        _ => Err(ZTransformError::NotInvertible(format!(
882            "repeated linear pole of order {k} (only k = 1, 2 are tabulated)"
883        ))),
884    }
885}
886
/// Invert a term `numer · (z² + b z + c)^{−k}` whose denominator is an
/// *irreducible* quadratic (complex-conjugate poles `r e^{±iθ}`).  Produces the
/// **real** damped sinusoid
///
/// ```text
///   Z⁻¹{·}(n) = rⁿ (A cos(θ n) + B sin(θ n)),
/// ```
///
/// with `r = √c`, `θ = acos(−b / 2√c)`, and no imaginary unit in the output.
///
/// The amplitudes come from matching `numer = P z² + Q z` against
/// `A·z(z − r cosθ) + B·z·r sinθ`, which gives `A = P` and
/// `B = (Q + A r cosθ) / (r sinθ)`.
///
/// The denominator must be monic with `k == 1` (a single quadratic factor); the
/// discriminant `b² − 4c` must be a literal *negative* rational (genuine
/// complex pair).  Repeated complex poles (`k ≥ 2`), non-monic denominators,
/// real-surd roots (non-negative discriminant), or a non-literal discriminant
/// are declined.
///
/// # Errors
///
/// Returns [`ZTransformError::NotInvertible`] in these cases:
/// - `k != 1`;
/// - `monic_quadratic_coeffs` cannot read `base` as `z² + b z + c`;
/// - the simplified discriminant is not a literal rational, or is `≥ 0`;
/// - `quadratic_numer_pq` cannot read `numer` as `P·z² + Q·z`.
fn invert_quadratic_pole(
    numer: ExprId,
    base: ExprId,
    k: u64,
    z: ExprId,
    n: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, ZTransformError> {
    // Only a single (non-repeated) complex-conjugate pair is tabulated.
    if k != 1 {
        return Err(ZTransformError::NotInvertible(format!(
            "repeated complex-conjugate pole of order {k} (only k = 1 is tabulated)"
        )));
    }

    // base = z² + b z + c (monic, coefficients free of z).
    let (b, c) = monic_quadratic_coeffs(base, z, pool).ok_or_else(|| {
        ZTransformError::NotInvertible(format!(
            "non-monic / non-quadratic denominator: {}",
            pool.display(base)
        ))
    })?;

    // Discriminant must be a literal negative rational (true complex pair).
    // A non-negative discriminant means real (possibly surd) roots — declined
    // (e.g. the Fibonacci denominator z² − z − 1 has discriminant 5 > 0).
    let b2 = pool.pow(b, pool.integer(2_i32));
    let four_c = pool.mul(vec![pool.integer(4_i32), c]);
    let disc = simp(pool.add(vec![b2, neg(four_c, pool)]), pool);
    match literal_rational(disc, pool) {
        Some(d) if d < 0 => {}
        Some(_) => {
            return Err(ZTransformError::NotInvertible(format!(
                "real-root quadratic denominator (discriminant ≥ 0): {}",
                pool.display(base)
            )));
        }
        None => {
            return Err(ZTransformError::NotInvertible(format!(
                "quadratic denominator with non-literal discriminant: {}",
                pool.display(base)
            )));
        }
    }

    // numer = P z² + Q z (no constant term: every apart(X/z) term is re-scaled
    // by z, so the lowest power is z¹).  Reject anything outside that shape.
    let (p_coeff, q_coeff) = quadratic_numer_pq(numer, z, pool).ok_or_else(|| {
        ZTransformError::NotInvertible(format!(
            "complex-pole numerator not of the form P·z² + Q·z: {}",
            pool.display(numer)
        ))
    })?;

    // r = √c, cosθ = −b / (2r), sinθ = √(1 − cos²θ), θ = acos(cosθ).
    // These follow from z² + b z + c = z² − 2r cosθ·z + r².
    // NOTE(review): r is real only if b is real, since then disc < 0 forces
    // c > b²/4 ≥ 0; this function does not check that b is real.
    let half = pool.rational(1_i32, 2_i32);
    let r = simp(pool.pow(c, half), pool);
    let two_r = pool.mul(vec![pool.integer(2_i32), r]);
    let cos_theta = simp(pool.mul(vec![neg(b, pool), recip(two_r, pool)]), pool);
    // sinθ = (1 − cos²θ)^{1/2}  (θ ∈ (0, π) so sinθ > 0).
    let cos2 = pool.pow(cos_theta, pool.integer(2_i32));
    let sin_theta = simp(
        pool.pow(pool.add(vec![pool.integer(1_i32), neg(cos2, pool)]), half),
        pool,
    );
    // θ is expressed symbolically as acos(cosθ); θ·n is simplified as a product.
    let theta = pool.func("acos", vec![cos_theta]);
    let theta_n = simp(pool.mul(vec![theta, n]), pool);

    // Match P z² + Q z = A·z(z − r cosθ) + B·z·r sinθ:
    //   A = P,  B = (Q + A r cosθ) / (r sinθ).
    let a_amp = p_coeff;
    let r_cos = pool.mul(vec![r, cos_theta]);
    let r_sin = pool.mul(vec![r, sin_theta]);
    let b_amp = simp(
        pool.mul(vec![
            pool.add(vec![q_coeff, pool.mul(vec![a_amp, r_cos])]),
            recip(r_sin, pool),
        ]),
        pool,
    );

    // rⁿ (A cos(θn) + B sin(θn)); the whole product is passed through `simp`.
    let r_pow_n = pool.pow(r, n);
    let cos_term = pool.mul(vec![a_amp, pool.func("cos", vec![theta_n])]);
    let sin_term = pool.mul(vec![b_amp, pool.func("sin", vec![theta_n])]);
    let combo = pool.add(vec![cos_term, sin_term]);
    Ok(simp(pool.mul(vec![r_pow_n, combo]), pool))
}
989
990/// Extract `(b, c)` from a monic quadratic `z² + b z + c` (coefficients free of
991/// `z`).  Returns `None` for a non-monic leading coefficient or a missing/extra
992/// degree.
993fn monic_quadratic_coeffs(base: ExprId, z: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
994    let z2 = pool.pow(z, pool.integer(2_i32));
995    let terms: Vec<ExprId> = match pool.get(base) {
996        ExprData::Add(a) => a,
997        _ => vec![base],
998    };
999    let mut a2: Option<ExprId> = None; // coeff of z²
1000    let mut b1_parts: Vec<ExprId> = Vec::new(); // coeffs of z
1001    let mut c0_parts: Vec<ExprId> = Vec::new(); // constant
1002    for term in terms {
1003        if is_free_of(term, z, pool) {
1004            c0_parts.push(term);
1005            continue;
1006        }
1007        if let Some(coeff) = monomial_coeff(term, z2, z, pool) {
1008            if a2.is_some() {
1009                return None;
1010            }
1011            a2 = Some(coeff);
1012            continue;
1013        }
1014        if let Some(coeff) = monomial_coeff(term, z, z, pool) {
1015            b1_parts.push(coeff);
1016            continue;
1017        }
1018        return None;
1019    }
1020    // Leading coefficient must be 1 (monic).
1021    if simp(a2?, pool) != pool.integer(1_i32) {
1022        return None;
1023    }
1024    let b = match b1_parts.len() {
1025        0 => pool.integer(0_i32),
1026        1 => b1_parts[0],
1027        _ => pool.add(b1_parts),
1028    };
1029    let c = match c0_parts.len() {
1030        0 => pool.integer(0_i32),
1031        1 => c0_parts[0],
1032        _ => pool.add(c0_parts),
1033    };
1034    Some((b, c))
1035}
1036
1037/// Coefficient of `power` (e.g. `z` or `z²`) in a single (non-Add) `term`,
1038/// requiring every other factor to be free of `var`; `1` for the bare power.
1039fn monomial_coeff(term: ExprId, power: ExprId, var: ExprId, pool: &ExprPool) -> Option<ExprId> {
1040    if term == power {
1041        return Some(pool.integer(1_i32));
1042    }
1043    if let ExprData::Mul(args) = pool.get(term) {
1044        let pos = args.iter().position(|&m| m == power)?;
1045        let others: Vec<ExprId> = args
1046            .iter()
1047            .enumerate()
1048            .filter(|&(i, _)| i != pos)
1049            .map(|(_, &m)| m)
1050            .collect();
1051        if others.iter().all(|&o| is_free_of(o, var, pool)) {
1052            return Some(match others.len() {
1053                0 => pool.integer(1_i32),
1054                1 => others[0],
1055                _ => pool.mul(others),
1056            });
1057        }
1058    }
1059    None
1060}
1061
1062/// Extract `(P, Q)` from a numerator `P·z² + Q·z` (no constant or higher term);
1063/// `None` otherwise.
1064fn quadratic_numer_pq(numer: ExprId, z: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
1065    let z2 = pool.pow(z, pool.integer(2_i32));
1066    let terms: Vec<ExprId> = match pool.get(numer) {
1067        ExprData::Add(a) => a,
1068        _ => vec![numer],
1069    };
1070    let mut p_parts: Vec<ExprId> = Vec::new();
1071    let mut q_parts: Vec<ExprId> = Vec::new();
1072    for term in terms {
1073        if let Some(coeff) = monomial_coeff(term, z2, z, pool) {
1074            p_parts.push(coeff);
1075            continue;
1076        }
1077        if let Some(coeff) = monomial_coeff(term, z, z, pool) {
1078            q_parts.push(coeff);
1079            continue;
1080        }
1081        return None; // constant or higher-degree numerator term
1082    }
1083    let p = match p_parts.len() {
1084        0 => pool.integer(0_i32),
1085        1 => p_parts[0],
1086        _ => pool.add(p_parts),
1087    };
1088    let q = match q_parts.len() {
1089        0 => pool.integer(0_i32),
1090        1 => q_parts[0],
1091        _ => pool.add(q_parts),
1092    };
1093    Some((p, q))
1094}
1095
1096/// If `expr` is a literal rational (integer or ratio), return it.
1097fn literal_rational(expr: ExprId, pool: &ExprPool) -> Option<rug::Rational> {
1098    match pool.get(expr) {
1099        ExprData::Integer(n) => Some(rug::Rational::from(n.0.clone())),
1100        ExprData::Rational(r) => Some(r.0.clone()),
1101        _ => None,
1102    }
1103}
1104
1105#[cfg(test)]
1106mod tests;