use super::*;
use crate::jit::eval_interp;
use crate::kernel::{Domain, ExprPool};
use std::collections::HashMap;
/// Creates a fresh expression pool together with the two real symbols used
/// throughout these tests: the sequence index `n` and the transform variable `z`.
fn setup() -> (ExprPool, ExprId, ExprId) {
    let pool = ExprPool::new();
    let (n, z) = (
        pool.symbol("n", Domain::Real),
        pool.symbol("z", Domain::Real),
    );
    (pool, n, z)
}
/// Numerically evaluates `expr` with the single binding `var = val`.
///
/// Returns whatever `eval_interp` returns; `None` means the expression could
/// not be evaluated under that binding.
fn eval_at(expr: ExprId, var: ExprId, val: f64, pool: &ExprPool) -> Option<f64> {
    let bindings = HashMap::from([(var, val)]);
    eval_interp(expr, &bindings, pool)
}
/// Panics unless `a` and `b` both evaluate under `env` and agree to within a
/// relative tolerance of `1e-6 * (1 + |a| + |b|)`.
///
/// `x` is the sample point being checked and only appears in failure
/// messages. If either side cannot be evaluated, the message includes both
/// `Option` results so it is clear which expression was the culprit.
fn assert_close_in_env(
    a: ExprId,
    b: ExprId,
    x: f64,
    env: &HashMap<ExprId, f64>,
    pool: &ExprPool,
) {
    let va = eval_interp(a, env, pool);
    let vb = eval_interp(b, env, pool);
    match (va, vb) {
        (Some(va), Some(vb)) => assert!(
            (va - vb).abs() < 1e-6 * (1.0 + va.abs() + vb.abs()),
            "mismatch at {x}: {} = {va} vs {} = {vb}",
            pool.display(a),
            pool.display(b),
        ),
        // `va`/`vb` here are the outer `Option<f64>`s (Copy), so both sides
        // can be reported.
        _ => panic!(
            "could not numerically evaluate at {x}: {} / {} (got {va:?} / {vb:?})",
            pool.display(a),
            pool.display(b)
        ),
    }
}

/// Asserts that `a` and `b` agree numerically at every sample `var = x`.
fn assert_numeric_eq(a: ExprId, b: ExprId, var: ExprId, samples: &[f64], pool: &ExprPool) {
    for &x in samples {
        let env = HashMap::from([(var, x)]);
        assert_close_in_env(a, b, x, &env, pool);
    }
}

/// Like `assert_numeric_eq`, but also binds `extra = extra_val` at every
/// sample, for expressions that contain a second free symbol (such as an
/// opaque transform `X`).
fn assert_numeric_eq_with(
    a: ExprId,
    b: ExprId,
    var: ExprId,
    samples: &[f64],
    extra: ExprId,
    extra_val: f64,
    pool: &ExprPool,
) {
    for &x in samples {
        // Inserted in the same order as before: if `extra == var`, `extra_val` wins.
        let env = HashMap::from([(var, x), (extra, extra_val)]);
        assert_close_in_env(a, b, x, &env, pool);
    }
}
/// Unit step: Z{1} = z / (z - 1).
#[test]
fn forward_constant() {
    let (pool, n, z) = setup();
    let one = pool.integer(1_i32);
    let transformed = z_transform(one, n, z, &pool).unwrap();
    let z_minus_one = pool.add(vec![z, pool.integer(-1_i32)]);
    let expected = pool.mul(vec![z, pool.pow(z_minus_one, pool.integer(-1_i32))]);
    assert_numeric_eq(transformed, expected, z, &[2.0, 3.0, 5.0], &pool);
}
/// Scaled constant: Z{5} = 5 z / (z - 1).
#[test]
fn forward_constant_scaled() {
    let (pool, n, z) = setup();
    let five = pool.integer(5_i32);
    let transformed = z_transform(five, n, z, &pool).unwrap();
    let z_minus_one = pool.add(vec![z, pool.integer(-1_i32)]);
    let reciprocal = pool.pow(z_minus_one, pool.integer(-1_i32));
    let expected = pool.mul(vec![pool.integer(5_i32), z, reciprocal]);
    assert_numeric_eq(transformed, expected, z, &[2.0, 3.0, 5.0], &pool);
}
/// Ramp: Z{n} = z / (z - 1)^2.
#[test]
fn forward_ramp_n() {
    let (pool, n, z) = setup();
    let transformed = z_transform(n, n, z, &pool).unwrap();
    let z_minus_one = pool.add(vec![z, pool.integer(-1_i32)]);
    let expected = pool.mul(vec![z, pool.pow(z_minus_one, pool.integer(-2_i32))]);
    assert_numeric_eq(transformed, expected, z, &[2.0, 3.0, 5.0], &pool);
}
/// Quadratic: Z{n^2} = z (z + 1) / (z - 1)^3.
#[test]
fn forward_n_squared() {
    let (pool, n, z) = setup();
    let n_squared = pool.pow(n, pool.integer(2_i32));
    let transformed = z_transform(n_squared, n, z, &pool).unwrap();
    let z_plus_one = pool.add(vec![z, pool.integer(1_i32)]);
    let z_minus_one = pool.add(vec![z, pool.integer(-1_i32)]);
    let expected = pool.mul(vec![
        z,
        z_plus_one,
        pool.pow(z_minus_one, pool.integer(-3_i32)),
    ]);
    assert_numeric_eq(transformed, expected, z, &[2.0, 3.0, 5.0], &pool);
}
/// Geometric sequence: Z{a^n} = z / (z - a) with a = 1/2.
///
/// All samples lie well outside the pole at z = 1/2.
#[test]
fn forward_geometric() {
    let (pool, n, z) = setup();
    let half = pool.rational(1_i32, 2_i32);
    let transformed = z_transform(pool.pow(half, n), n, z, &pool).unwrap();
    let z_minus_half = pool.add(vec![z, neg(half, &pool)]);
    let expected = pool.mul(vec![z, pool.pow(z_minus_half, pool.integer(-1_i32))]);
    assert_numeric_eq(transformed, expected, z, &[3.0, 4.0, 5.0], &pool);
}
/// Ramp times geometric: Z{n a^n} = a z / (z - a)^2 with a = 1/2.
#[test]
fn forward_n_times_geometric() {
    let (pool, n, z) = setup();
    let half = pool.rational(1_i32, 2_i32);
    let half_pow_n = pool.pow(half, n);
    let transformed = z_transform(pool.mul(vec![n, half_pow_n]), n, z, &pool).unwrap();
    let z_minus_half = pool.add(vec![z, neg(half, &pool)]);
    let expected = pool.mul(vec![
        half,
        z,
        pool.pow(z_minus_half, pool.integer(-2_i32)),
    ]);
    assert_numeric_eq(transformed, expected, z, &[3.0, 4.0, 5.0], &pool);
}
/// Sinusoids with ω = 1/3, sharing the denominator D = z^2 - 2 z cos ω + 1:
/// Z{sin ωn} = z sin ω / D and Z{cos ωn} = z (z - cos ω) / D.
#[test]
fn forward_sin_cos() {
    let (pool, n, z) = setup();
    let omega = pool.rational(1_i32, 3_i32);
    let omega_n = pool.mul(vec![omega, n]);
    let samples = [3.0, 5.0, 7.0];

    let sin_transformed = z_transform(pool.func("sin", vec![omega_n]), n, z, &pool).unwrap();

    let cos_w = pool.func("cos", vec![omega]);
    let sin_w = pool.func("sin", vec![omega]);
    let two_z_cos = pool.mul(vec![pool.integer(2_i32), z, cos_w]);
    let denom = pool.add(vec![
        pool.pow(z, pool.integer(2_i32)),
        neg(two_z_cos, &pool),
        pool.integer(1_i32),
    ]);
    let inv_denom = pool.pow(denom, pool.integer(-1_i32));

    let sin_expected = pool.mul(vec![z, sin_w, inv_denom]);
    assert_numeric_eq(sin_transformed, sin_expected, z, &samples, &pool);

    let cos_transformed = z_transform(pool.func("cos", vec![omega_n]), n, z, &pool).unwrap();
    let z_minus_cos = pool.add(vec![z, neg(cos_w, &pool)]);
    let cos_expected = pool.mul(vec![z, z_minus_cos, inv_denom]);
    assert_numeric_eq(cos_transformed, cos_expected, z, &samples, &pool);
}
/// Linearity: Z{2 + 3n} = 2 Z{1} + 3 Z{n}, with both sides computed by
/// `z_transform`.
#[test]
fn linearity() {
    let (pool, n, z) = setup();
    let three_n = pool.mul(vec![pool.integer(3_i32), n]);
    let combination = pool.add(vec![pool.integer(2_i32), three_n]);
    let transformed = z_transform(combination, n, z, &pool).unwrap();

    let unit_step = z_transform(pool.integer(1_i32), n, z, &pool).unwrap();
    let ramp = z_transform(n, n, z, &pool).unwrap();
    let expected = pool.add(vec![
        pool.mul(vec![pool.integer(2_i32), unit_step]),
        pool.mul(vec![pool.integer(3_i32), ramp]),
    ]);
    assert_numeric_eq(transformed, expected, z, &[3.0, 5.0, 7.0], &pool);
}
/// Scaling by a^n (a = 2) applied to the unit step: Z{2^n · 1} = z / (z - 2).
///
/// Samples are taken beyond the pole at z = 2.
// NOTE(review): if `pool.mul` folds the factor of 1 away, this reduces to the
// plain geometric rule; a non-trivial x[n] would exercise scaling more directly.
#[test]
fn scaling_theorem() {
    let (pool, n, z) = setup();
    let two = pool.rational(2_i32, 1_i32);
    let scaled = pool.mul(vec![pool.pow(two, n), pool.integer(1_i32)]);
    let transformed = z_transform(scaled, n, z, &pool).unwrap();
    let z_minus_two = pool.add(vec![z, neg(two, &pool)]);
    let expected = pool.mul(vec![z, pool.pow(z_minus_two, pool.integer(-1_i32))]);
    assert_numeric_eq(transformed, expected, z, &[5.0, 7.0], &pool);
}
/// Delay by two samples: `z_shift_delay(X, z, 2)` equals z^-2 · X, checked
/// with the opaque symbol X bound to 11.
#[test]
fn shift_delay_theorem() {
    let (pool, _n, z) = setup();
    let big_x = pool.symbol("X", Domain::Real);
    let delayed = z_shift_delay(big_x, z, 2, &pool);
    let z_to_minus_two = pool.pow(z, pool.integer(-2_i32));
    let expected = pool.mul(vec![z_to_minus_two, big_x]);
    assert_numeric_eq_with(delayed, expected, z, &[3.0, 5.0], big_x, 11.0, &pool);
}
/// Advance by one sample with x[0] = 7: `z_shift_advance` equals
/// z · X - z · x[0], checked with X bound to 11.
#[test]
fn shift_advance_theorem() {
    let (pool, _n, z) = setup();
    let big_x = pool.symbol("X", Domain::Real);
    let x0 = pool.integer(7_i32);
    let advanced = z_shift_advance(big_x, z, 1, &[x0], &pool);
    let z_times_x = pool.mul(vec![z, big_x]);
    let minus_z_x0 = pool.mul(vec![pool.integer(-1_i32), z, x0]);
    let expected = pool.add(vec![z_times_x, minus_z_x0]);
    assert_numeric_eq_with(advanced, expected, z, &[3.0, 5.0], big_x, 11.0, &pool);
}
/// Differentiation in z: Z{n · x[n]} = -z · dX/dz. Uses x[n] = 1, so the
/// left side is the ramp transform.
#[test]
fn differentiation_theorem() {
    let (pool, n, z) = setup();
    let ramp = z_transform(n, n, z, &pool).unwrap();
    let unit_step = z_transform(pool.integer(1_i32), n, z, &pool).unwrap();
    let d_unit_step = crate::diff::diff(unit_step, z, &pool).unwrap().value;
    let minus_z_deriv = pool.mul(vec![pool.integer(-1_i32), z, d_unit_step]);
    let expected = simp(minus_z_deriv, &pool);
    assert_numeric_eq(ramp, expected, z, &[3.0, 5.0], &pool);
}
/// Inverting z / (z - 1/2) yields (1/2)^n, checked at several n including 0.
#[test]
fn inverse_geometric() {
    let (pool, n, z) = setup();
    let half = pool.rational(1_i32, 2_i32);
    let big_x = {
        let z_minus_half = pool.add(vec![z, neg(half, &pool)]);
        pool.mul(vec![z, pool.pow(z_minus_half, pool.integer(-1_i32))])
    };
    let sequence = inverse_z_transform(big_x, z, n, &pool).unwrap();
    let expected = pool.pow(half, n);
    assert_numeric_eq(sequence, expected, n, &[0.0, 1.0, 2.0, 5.0, 10.0], &pool);
}
/// Inverting 5z / (z - 1) yields the constant sequence 5.
#[test]
fn inverse_constant() {
    let (pool, n, z) = setup();
    let big_x = {
        let z_minus_one = pool.add(vec![z, pool.integer(-1_i32)]);
        let reciprocal = pool.pow(z_minus_one, pool.integer(-1_i32));
        pool.mul(vec![pool.integer(5_i32), z, reciprocal])
    };
    let sequence = inverse_z_transform(big_x, z, n, &pool).unwrap();
    let five = pool.integer(5_i32);
    assert_numeric_eq(sequence, five, n, &[0.0, 1.0, 2.0, 5.0], &pool);
}
/// Double pole: inverting a z / (z - a)^2 with a = 1/2 yields n · a^n.
#[test]
fn inverse_repeated_geometric() {
    let (pool, n, z) = setup();
    let half = pool.rational(1_i32, 2_i32);
    let big_x = {
        let z_minus_half = pool.add(vec![z, neg(half, &pool)]);
        let inv_square = pool.pow(z_minus_half, pool.integer(-2_i32));
        pool.mul(vec![half, z, inv_square])
    };
    let sequence = inverse_z_transform(big_x, z, n, &pool).unwrap();
    let expected = pool.mul(vec![n, pool.pow(half, n)]);
    assert_numeric_eq(sequence, expected, n, &[0.0, 1.0, 2.0, 5.0, 8.0], &pool);
}
/// Round trip: inverting the forward transform of 2 + 3 · (1/2)^n recovers
/// the original sequence at the sampled n.
#[test]
fn inverse_round_trip_sum() {
    let (pool, n, z) = setup();
    let half = pool.rational(1_i32, 2_i32);
    let decaying = pool.mul(vec![pool.integer(3_i32), pool.pow(half, n)]);
    let sequence = pool.add(vec![pool.integer(2_i32), decaying]);
    let spectrum = z_transform(sequence, n, z, &pool).unwrap();
    let recovered = inverse_z_transform(spectrum, z, n, &pool).unwrap();
    assert_numeric_eq(recovered, sequence, n, &[0.0, 1.0, 2.0, 4.0, 7.0], &pool);
}
/// The forward transform declines `tan(n)` with `ZTransformError::NoRule`.
#[test]
fn decline_non_table_function() {
    let (pool, n, z) = setup();
    let tan_n = pool.func("tan", vec![n]);
    let err = z_transform(tan_n, n, z, &pool).unwrap_err();
    assert!(matches!(err, ZTransformError::NoRule(_)));
}
/// Passing the same symbol as both the sequence variable and the transform
/// variable is rejected with `SameVariable`, in both directions.
#[test]
fn decline_same_variable() {
    let (pool, n, _z) = setup();
    let one = pool.integer(1_i32);
    assert_eq!(
        z_transform(one, n, n, &pool).unwrap_err(),
        ZTransformError::SameVariable
    );
    assert_eq!(
        inverse_z_transform(one, n, n, &pool).unwrap_err(),
        ZTransformError::SameVariable
    );
}
/// A third-order pole, z / (z - 1)^3, is declined with `NotInvertible`.
#[test]
fn decline_high_order_pole() {
    let (pool, n, z) = setup();
    let z_minus_one = pool.add(vec![z, pool.integer(-1_i32)]);
    let inv_cube = pool.pow(z_minus_one, pool.integer(-3_i32));
    let big_x = pool.mul(vec![z, inv_cube]);
    let err = inverse_z_transform(big_x, z, n, &pool).unwrap_err();
    assert!(matches!(err, ZTransformError::NotInvertible(_)));
}
/// z / (z^2 - z - 1) has real but irrational poles (1 ± √5) / 2; the inverse
/// declines it with `NotInvertible`.
#[test]
fn decline_real_surd_quadratic_inverse() {
    let (pool, n, z) = setup();
    let big_x = {
        let z_squared = pool.pow(z, pool.integer(2_i32));
        let quadratic = pool.add(vec![z_squared, neg(z, &pool), pool.integer(-1_i32)]);
        pool.mul(vec![z, pool.pow(quadratic, pool.integer(-1_i32))])
    };
    let err = inverse_z_transform(big_x, z, n, &pool).unwrap_err();
    assert!(matches!(err, ZTransformError::NotInvertible(_)));
}
/// Complex-conjugate poles on the unit circle (roots of z^2 - z + 1).
///
/// The inverse must print without the imaginary unit `I`, and transforming
/// it forward again must reproduce X(z).
#[test]
fn inverse_complex_pole_unit_circle() {
    let (pool, n, z) = setup();
    let big_x = {
        let z_squared = pool.pow(z, pool.integer(2_i32));
        let quadratic = pool.add(vec![z_squared, neg(z, &pool), pool.integer(1_i32)]);
        pool.mul(vec![z, pool.pow(quadratic, pool.integer(-1_i32))])
    };
    let x_n = inverse_z_transform(big_x, z, n, &pool).unwrap();

    let rendered = pool.display(x_n).to_string();
    assert!(
        !rendered.contains('I'),
        "complex-pole inverse must be real (no I): {}",
        rendered,
    );

    let back = z_transform(x_n, n, z, &pool).unwrap();
    assert_numeric_eq(back, big_x, z, &[2.0, 3.0, 5.0, 7.0], &pool);
}
/// Purely imaginary poles ±i, from z / (z^2 + 1).
///
/// The inverse must print without `I`, match sin(πn/2) at n = 0..=4, and
/// transform back to X(z).
#[test]
fn inverse_complex_pole_pure_imaginary() {
    let (pool, n, z) = setup();
    let big_x = {
        let z_squared = pool.pow(z, pool.integer(2_i32));
        let z_sq_plus_one = pool.add(vec![z_squared, pool.integer(1_i32)]);
        pool.mul(vec![z, pool.pow(z_sq_plus_one, pool.integer(-1_i32))])
    };
    let x_n = inverse_z_transform(big_x, z, n, &pool).unwrap();

    let rendered = pool.display(x_n).to_string();
    assert!(
        !rendered.contains('I'),
        "complex-pole inverse must be real (no I): {}",
        rendered,
    );

    // sin(πn/2) for n = 0, 1, 2, 3, 4.
    let sin_half_pi = [0.0, 1.0, 0.0, -1.0, 0.0];
    for (k, &want) in sin_half_pi.iter().enumerate() {
        let got = eval_at(x_n, n, k as f64, &pool).expect("evaluable");
        assert!(
            (got - want).abs() < 1e-9,
            "x[{k}] = {got}, want {want} for sin(πn/2)",
        );
    }

    let back = z_transform(x_n, n, z, &pool).unwrap();
    assert_numeric_eq(back, big_x, z, &[2.0, 3.0, 5.0], &pool);
}
/// Damped complex poles (roots of z^2 - z + 1/2, i.e. (1 ± i) / 2).
///
/// The inverse must print without `I` and transform back to X(z).
#[test]
fn inverse_complex_pole_damped() {
    let (pool, n, z) = setup();
    let big_x = {
        let z_squared = pool.pow(z, pool.integer(2_i32));
        let quadratic = pool.add(vec![
            z_squared,
            neg(z, &pool),
            pool.rational(1_i32, 2_i32),
        ]);
        pool.mul(vec![z, pool.pow(quadratic, pool.integer(-1_i32))])
    };
    let x_n = inverse_z_transform(big_x, z, n, &pool).unwrap();

    let rendered = pool.display(x_n).to_string();
    assert!(
        !rendered.contains('I'),
        "complex-pole inverse must be real (no I): {}",
        rendered,
    );

    let back = z_transform(x_n, n, z, &pool).unwrap();
    assert_numeric_eq(back, big_x, z, &[2.0, 3.0, 5.0, 8.0], &pool);
}
/// Solves the Fibonacci recurrence a[n+2] = a[n+1] + a[n] (a[0] = 0,
/// a[1] = 1) both in the Z-domain and with `rsolve`, and checks that the two
/// solutions agree.
///
/// Applying the advance theorem to both sides gives z^2 A - z = z A + A,
/// i.e. A(z) = z / (z^2 - z - 1). Its poles are real and irrational, so
/// `inverse_z_transform` declines it. The Z-domain solution is therefore
/// compared with `rsolve` via the series definition: the truncated sum
/// Σ_{k=0}^{N} a[k] z0^{-k} of the `rsolve` closed form must converge to A(z0)
/// for z0 inside the region of convergence (|z0| greater than the golden ratio).
#[test]
fn fibonacci_via_z_transform_matches_rsolve() {
    use crate::sum::rsolve;
    use std::collections::BTreeMap;
    let (pool, n, z) = setup();
    let a0 = pool.integer(0_i32);
    let a1 = pool.integer(1_i32);

    // Z-domain: transform both sides of the recurrence with the advance theorem.
    let big_a = pool.symbol("A", Domain::Real);
    let lhs = z_shift_advance(big_a, z, 2, &[a0, a1], &pool);
    let rhs = pool.add(vec![z_shift_advance(big_a, z, 1, &[a0], &pool), big_a]);
    let z2 = pool.pow(z, pool.integer(2_i32));
    let denom = pool.add(vec![z2, neg(z, &pool), pool.integer(-1_i32)]);
    let big_a_expr = pool.mul(vec![z, pool.pow(denom, pool.integer(-1_i32))]);

    // The candidate A(z) must satisfy the transformed recurrence.
    let mut map = HashMap::new();
    map.insert(big_a, big_a_expr);
    let lhs_sub = simp(crate::kernel::subs(lhs, &map, &pool), &pool);
    let rhs_sub = simp(crate::kernel::subs(rhs, &map, &pool), &pool);
    assert_numeric_eq(lhs_sub, rhs_sub, z, &[3.0, 5.0, 7.0], &pool);

    // Real irrational poles: the inverse transform declines this input.
    let inv_err = inverse_z_transform(big_a_expr, z, n, &pool).unwrap_err();
    assert!(matches!(inv_err, ZTransformError::NotInvertible(_)));

    // Time domain: f(n) - f(n-1) - f(n-2) = 0 with the same initial values.
    let f = |args: Vec<ExprId>| pool.func("f", args);
    let eq = simp(
        pool.add(vec![
            f(vec![n]),
            pool.mul(vec![
                f(vec![pool.add(vec![n, pool.integer(-1_i32)])]),
                pool.integer(-1_i32),
            ]),
            pool.mul(vec![
                f(vec![pool.add(vec![n, pool.integer(-2_i32)])]),
                pool.integer(-1_i32),
            ]),
        ]),
        &pool,
    );
    let mut init = BTreeMap::new();
    init.insert(0, a0);
    init.insert(1, a1);
    let rsolve_sol = rsolve(&pool, eq, n, "f", Some(&init)).expect("rsolve");

    // rsolve must reproduce the first Fibonacci numbers.
    let expected = [0.0, 1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0];
    for (ni, &exp) in expected.iter().enumerate() {
        let v = eval_at(rsolve_sol, n, ni as f64, &pool).expect("eval rsolve");
        assert!((v - exp).abs() < 1e-4, "n={ni}: rsolve={v} expected={exp}");
    }

    // Tie the two solutions together: Σ_{k=0}^{N} a[k] z0^{-k} -> A(z0).
    // At z0 = 3 the terms decay like (1.618/3)^k ≈ 0.54^k, so N = 60 leaves a
    // truncation error far below the tolerance. A(3) = 3 / 5 exactly.
    let z0 = 3.0_f64;
    let series: f64 = (0..=60)
        .map(|k: i32| {
            let ak = eval_at(rsolve_sol, n, f64::from(k), &pool).expect("eval rsolve");
            ak * z0.powi(-k)
        })
        .sum();
    let closed = eval_at(big_a_expr, z, z0, &pool).expect("eval A(z)");
    assert!(
        (series - closed).abs() < 1e-8,
        "truncated z-series of rsolve solution = {series}, A({z0}) = {closed}",
    );
}