use rug::Integer;
use crate::kernel::{ExprData, ExprId, ExprPool};
/// Errors reported by the forward and inverse Laplace transform routines.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LaplaceError {
    /// No forward-transform rule applied; the payload describes what failed.
    NoRule(String),
    /// The inverse transform could not be computed; the payload says why.
    NotInvertible(String),
    /// The time and frequency variables were the same expression.
    SameVariable,
}

impl std::fmt::Display for LaplaceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NoRule(detail) => write!(
                f,
                "laplace_transform: no rule for {detail} [E-TRANSFORM-001]"
            ),
            Self::NotInvertible(detail) => write!(
                f,
                "inverse_laplace_transform: cannot invert {detail} [E-TRANSFORM-002]"
            ),
            Self::SameVariable => f.write_str(
                "laplace_transform: time and frequency variables must differ [E-TRANSFORM-003]",
            ),
        }
    }
}

impl std::error::Error for LaplaceError {}
/// Returns `true` when `expr` does not depend on `var`.
///
/// Thin wrapper over the Risch integrator's `is_free_of_var`, so every
/// dependence test in this module goes through the same helper.
fn is_free_of(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    use crate::integrate::risch::poly_rde::is_free_of_var;
    is_free_of_var(expr, var, pool)
}
/// Builds the exact integer `n!` as an expression (`0! = 1! = 1`).
fn factorial(n: u64, pool: &ExprPool) -> ExprId {
    // Accumulate in an arbitrary-precision `Integer` so large `n` cannot overflow.
    let product = (2..=n).fold(Integer::from(1), |running, k| running * Integer::from(k));
    pool.integer(product)
}
/// Tries to write `expr` as `a·var + b` with `a` and `b` free of `var`.
///
/// Recognised shapes:
/// * `var` itself → `(1, 0)`;
/// * any `var`-free expression → `(0, expr)`;
/// * a product accepted by [`as_affine_term`] → `(c, 0)`;
/// * a sum whose summands are each `var`-free or accepted by
///   [`as_affine_term`] → (sum of the coefficients, sum of the free summands).
///
/// Anything else yields `None`. The returned parts are built with `pool.add`
/// but are not simplified.
fn as_affine(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
    if expr == var {
        return Some((pool.integer(1_i32), pool.integer(0_i32)));
    }
    if is_free_of(expr, var, pool) {
        return Some((pool.integer(0_i32), expr));
    }
    let summands = match pool.get(expr) {
        // A bare product must be a pure `c·var` monomial with no offset.
        ExprData::Mul(_) => {
            let (coeff, offset) = as_affine_term(expr, var, pool)?;
            return if offset == pool.integer(0_i32) {
                Some((coeff, pool.integer(0_i32)))
            } else {
                None
            };
        }
        ExprData::Add(summands) => summands,
        _ => return None,
    };

    // Separate the `var`-free offset summands from the `c·var` summands.
    let (offset_parts, linear_parts): (Vec<ExprId>, Vec<ExprId>) = summands
        .into_iter()
        .partition(|&term| is_free_of(term, var, pool));
    let mut coeff_parts = Vec::with_capacity(linear_parts.len());
    for term in linear_parts {
        let (coeff, offset) = as_affine_term(term, var, pool)?;
        if offset != pool.integer(0_i32) {
            return None;
        }
        coeff_parts.push(coeff);
    }

    // Collapse a list of summands: empty → 0, singleton → itself, else a sum.
    let collapse = |parts: Vec<ExprId>| match parts.len() {
        0 => pool.integer(0_i32),
        1 => parts[0],
        _ => pool.add(parts),
    };
    Some((collapse(coeff_parts), collapse(offset_parts)))
}
/// Recognises a single `c·var` monomial.
///
/// Returns `(1, 0)` for `var` and `(0, expr)` for a `var`-free `expr`. A product
/// is accepted when it contains `var` as a direct factor and all its other
/// factors are free of `var`; the result is then `(c, 0)`, where `c` is the
/// product of those other factors (`1` if there are none). Only the first
/// direct occurrence of `var` is removed, so `var·var` is rejected. Everything
/// else yields `None`.
fn as_affine_term(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
    if expr == var {
        return Some((pool.integer(1_i32), pool.integer(0_i32)));
    }
    if is_free_of(expr, var, pool) {
        return Some((pool.integer(0_i32), expr));
    }
    let mut cofactors = match pool.get(expr) {
        ExprData::Mul(factors) => factors,
        _ => return None,
    };
    let var_at = cofactors.iter().position(|&factor| factor == var)?;
    cofactors.remove(var_at);
    if cofactors.iter().any(|&factor| !is_free_of(factor, var, pool)) {
        return None;
    }
    let coeff = match cofactors.len() {
        0 => pool.integer(1_i32),
        1 => cofactors[0],
        _ => pool.mul(cofactors),
    };
    Some((coeff, pool.integer(0_i32)))
}
/// Runs the crate simplifier on `expr`, keeping only the simplified expression.
fn simp(expr: ExprId, pool: &ExprPool) -> ExprId {
    let outcome = crate::simplify::simplify(expr, pool);
    outcome.value
}
/// Builds the unsimplified product `-1 · expr`.
fn neg(expr: ExprId, pool: &ExprPool) -> ExprId {
    let minus_one = pool.integer(-1_i32);
    pool.mul(vec![minus_one, expr])
}
/// Builds the unsimplified power `expr^-1`.
fn recip(expr: ExprId, pool: &ExprPool) -> ExprId {
    let minus_one = pool.integer(-1_i32);
    pool.pow(expr, minus_one)
}
/// Computes the Laplace transform `F(s) = L{f(t)}(s)` of `f` with respect to `t`.
///
/// The expression is transformed by the rule table in `laplace_inner`. The
/// result is simplified once at the end.
///
/// # Errors
///
/// * [`LaplaceError::SameVariable`] if `t` and `s` are the same expression.
/// * [`LaplaceError::NoRule`] if some sub-expression matches no rule, or the
///   rule recursion exceeds its depth limit.
pub fn laplace_transform(
    f: ExprId,
    t: ExprId,
    s: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, LaplaceError> {
    if t == s {
        return Err(LaplaceError::SameVariable);
    }
    laplace_inner(f, t, s, pool, 0).map(|raw| simp(raw, pool))
}
/// Recursion limit for `laplace_inner`; exceeding it reports `NoRule`.
const MAX_DEPTH: usize = 32;

/// Recursive worker behind `laplace_transform`; its result is not simplified.
///
/// Rules, in the order they are tried:
/// * `f` free of `t`: `L{c} = c/s`;
/// * `f == t`: `L{t} = 1/s²`;
/// * sums: transformed term by term (linearity);
/// * products: delegated to `laplace_mul`;
/// * `t^n` with a non-negative integer literal `n`: `n!/s^(n+1)`;
/// * one-argument functions: delegated to `laplace_func`.
///
/// # Errors
///
/// `NoRule` when `depth` exceeds [`MAX_DEPTH`], when `t` is raised to a
/// non-integer power, or when no rule matches.
fn laplace_inner(
    f: ExprId,
    t: ExprId,
    s: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<ExprId, LaplaceError> {
    if depth > MAX_DEPTH {
        return Err(LaplaceError::NoRule("recursion depth exceeded".into()));
    }
    if is_free_of(f, t, pool) {
        return Ok(pool.mul(vec![f, recip(s, pool)]));
    }
    if f == t {
        return Ok(recip(pool.pow(s, pool.integer(2_i32)), pool));
    }
    match pool.get(f) {
        ExprData::Add(summands) => {
            let transformed = summands
                .into_iter()
                .map(|term| laplace_inner(term, t, s, pool, depth + 1))
                .collect::<Result<Vec<ExprId>, LaplaceError>>()?;
            Ok(pool.add(transformed))
        }
        ExprData::Mul(factors) => laplace_mul(&factors, t, s, pool, depth),
        ExprData::Pow { base, exp } if base == t => match nonneg_int_exp(exp, pool) {
            Some(n) => {
                let numer = factorial(n, pool);
                let s_power = pool.pow(s, pool.integer(Integer::from(n + 1)));
                Ok(pool.mul(vec![numer, recip(s_power, pool)]))
            }
            None => Err(LaplaceError::NoRule(format!(
                "t^e with non-integer exponent: {}",
                pool.display(f)
            ))),
        },
        ExprData::Func { name, args } if args.len() == 1 => {
            laplace_func(&name, args[0], t, s, pool, depth)
        }
        _ => Err(LaplaceError::NoRule(pool.display(f).to_string())),
    }
}
/// Returns `exp` as a `u64` when it is a non-negative integer literal that
/// fits in `u64`; otherwise `None`.
fn nonneg_int_exp(exp: ExprId, pool: &ExprPool) -> Option<u64> {
    match pool.get(exp) {
        ExprData::Integer(value) if value.0 >= 0 => value.0.to_u64(),
        _ => None,
    }
}
/// Transforms a product by first pulling out its `t`-free factors.
///
/// Uses `L{c·g(t)} = c·L{g(t)}`. The factors of `args` that are free of `t`
/// form the scalar `c`, and the remaining factors are handed to
/// `laplace_product_body`. If no factor depends on `t`, the result is `c/s`.
fn laplace_mul(
    args: &[ExprId],
    t: ExprId,
    s: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<ExprId, LaplaceError> {
    let mut scalar_factors: Vec<ExprId> = Vec::new();
    let mut time_factors: Vec<ExprId> = Vec::new();
    for &factor in args {
        if is_free_of(factor, t, pool) {
            scalar_factors.push(factor);
        } else {
            time_factors.push(factor);
        }
    }
    let scalar = match scalar_factors.len() {
        0 => None,
        1 => Some(scalar_factors[0]),
        _ => Some(pool.mul(scalar_factors)),
    };
    let inner = match time_factors.len() {
        0 => {
            let c = scalar.unwrap_or_else(|| pool.integer(1_i32));
            return Ok(pool.mul(vec![c, recip(s, pool)]));
        }
        1 => time_factors[0],
        _ => pool.mul(time_factors),
    };
    let transformed = laplace_product_body(inner, t, s, pool, depth)?;
    Ok(scalar.map_or(transformed, |c| pool.mul(vec![c, transformed])))
}
/// Transforms `body`, a product whose factors all depend on `t`, or a single
/// such factor.
///
/// Rules are tried in a fixed order, and the first that applies wins. Here `g`
/// is the product of the remaining factors and `G = L{g}`.
///
/// 1. Frequency shift, for a factor `exp(a·t)`: `L{e^{a t}·g} = G(s − a)`.
/// 2. Time shift, for a factor `heaviside(t − a)` (see `try_time_shift`).
/// 3. Multiplication by `tⁿ`: `L{tⁿ·g} = (−1)ⁿ·dⁿG/dsⁿ`.
/// 4. A `body` that is not a `Mul` is handed back to `laplace_inner`.
///
/// # Errors
///
/// `NoRule` when no rule fits a genuine product or the differentiation in
/// rule 3 fails. Errors from recursive transforms are also propagated.
fn laplace_product_body(
    body: ExprId,
    t: ExprId,
    s: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<ExprId, LaplaceError> {
    let factors: Vec<ExprId> = match pool.get(body) {
        ExprData::Mul(items) => items,
        _ => vec![body],
    };

    // Rule 1: frequency shift by the first `exp(a·t)` factor.
    let exp_hit = factors
        .iter()
        .enumerate()
        .find_map(|(idx, &fac)| match_exp_linear(fac, t, pool).map(|rate| (idx, rate)));
    if let Some((idx, rate)) = exp_hit {
        let cofactor = remove_index(&factors, idx, pool);
        let g = laplace_inner(cofactor, t, s, pool, depth + 1)?;
        let shifted_s = simp(pool.add(vec![s, neg(rate, pool)]), pool);
        return Ok(subs_one(g, s, shifted_s, pool));
    }

    // Rule 2: time shift by a `heaviside(t − a)` factor.
    if let Some(shifted) = try_time_shift(&factors, t, s, pool, depth)? {
        return Ok(shifted);
    }

    // Rule 3: multiplication by `tⁿ` corresponds to n-fold differentiation in `s`.
    let power_hit = factors
        .iter()
        .enumerate()
        .find_map(|(idx, &fac)| match_t_power(fac, t, pool).map(|n| (idx, n)));
    if let Some((idx, n)) = power_hit {
        let cofactor = remove_index(&factors, idx, pool);
        let mut g = laplace_inner(cofactor, t, s, pool, depth + 1)?;
        for _ in 0..n {
            g = crate::diff::diff(g, s, pool)
                .map_err(|_| LaplaceError::NoRule("frequency-diff failed".into()))?
                .value;
        }
        let sign = pool.integer(if n % 2 == 0 { 1_i32 } else { -1_i32 });
        return Ok(pool.mul(vec![sign, g]));
    }

    // Rule 4: a lone factor goes back through the general dispatcher.
    match pool.get(body) {
        ExprData::Mul(_) => Err(LaplaceError::NoRule(pool.display(body).to_string())),
        _ => laplace_inner(body, t, s, pool, depth + 1),
    }
}
/// Returns the rate `a` when `fac` is `exp(a·t)` with `a ≠ 0` and no constant
/// offset in the exponent; otherwise `None`.
fn match_exp_linear(fac: ExprId, t: ExprId, pool: &ExprPool) -> Option<ExprId> {
    let exponent = match pool.get(fac) {
        ExprData::Func { name, args } if name == "exp" && args.len() == 1 => args[0],
        _ => return None,
    };
    let (rate, offset) = as_affine(exponent, t, pool)?;
    let zero = pool.integer(0_i32);
    if offset != zero || rate == zero {
        None
    } else {
        Some(rate)
    }
}
/// Returns `n` when `fac` is `t` (`n = 1`) or `t^n` with an integer literal
/// `n ≥ 1`; otherwise `None`.
fn match_t_power(fac: ExprId, t: ExprId, pool: &ExprPool) -> Option<u64> {
    if fac == t {
        return Some(1);
    }
    match pool.get(fac) {
        ExprData::Pow { base, exp } if base == t => match nonneg_int_exp(exp, pool) {
            Some(n) if n >= 1 => Some(n),
            _ => None,
        },
        _ => None,
    }
}
/// Time-shift rule for a product that contains a `heaviside(t − a)` factor.
///
/// Let `h(t)` be the product of the other factors. Then
/// `L{heaviside(t − a)·h(t)} = e^{−a s}·L{h(t + a)}`. When there are no other
/// factors, the result is `e^{−a s}/s`. Only a heaviside argument whose
/// `t`-coefficient is exactly `1` is recognised, and the first such factor is
/// used.
///
/// Returns `Ok(None)` when no such factor exists.
///
/// NOTE(review): the formula presumes `a ≥ 0`; the sign of `a` is not checked.
fn try_time_shift(
    factors: &[ExprId],
    t: ExprId,
    s: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<Option<ExprId>, LaplaceError> {
    let step = factors.iter().enumerate().find_map(|(idx, &fac)| {
        let arg = match pool.get(fac) {
            ExprData::Func { name, args } if name == "heaviside" && args.len() == 1 => args[0],
            _ => return None,
        };
        let (coeff, offset) = as_affine(arg, t, pool)?;
        if coeff == pool.integer(1_i32) {
            Some((idx, neg(offset, pool)))
        } else {
            None
        }
    });
    let (step_idx, delay) = match step {
        Some((idx, raw_delay)) => (idx, simp(raw_delay, pool)),
        None => return Ok(None),
    };
    let cofactor = remove_index(factors, step_idx, pool);
    let exp_neg_as = pool.func("exp", vec![simp(neg(pool.mul(vec![delay, s]), pool), pool)]);
    if cofactor == pool.integer(1_i32) {
        return Ok(Some(pool.mul(vec![exp_neg_as, recip(s, pool)])));
    }
    let t_plus_a = simp(pool.add(vec![t, delay]), pool);
    let advanced = subs_one(cofactor, t, t_plus_a, pool);
    let g = laplace_inner(simp(advanced, pool), t, s, pool, depth + 1)?;
    Ok(Some(pool.mul(vec![exp_neg_as, g])))
}
/// Transforms a single-argument function application `name(arg)`.
///
/// Supported table (with `a`, `b` free of `t`):
///
/// | `f(t)`              | `F(s)`        |
/// |---------------------|---------------|
/// | `exp(a t)`          | `1/(s − a)`   |
/// | `sin(b t)`          | `b/(s² + b²)` |
/// | `cos(b t)`          | `s/(s² + b²)` |
/// | `sinh(b t)`         | `b/(s² − b²)` |
/// | `cosh(b t)`         | `s/(s² − b²)` |
/// | `heaviside(t − a)`  | `e^{−a s}/s`  |
/// | `diracdelta(t − a)` | `e^{−a s}`    |
///
/// NOTE(review): the heaviside and diracdelta entries are the textbook forms
/// for a shift `a ≥ 0`; the sign of `a` is not checked here.
///
/// `_depth` is currently unused.
///
/// # Errors
///
/// `NoRule` for an unsupported function name, or for an argument that does
/// not have the shape required by the table.
fn laplace_func(
    name: &str,
    arg: ExprId,
    t: ExprId,
    s: ExprId,
    pool: &ExprPool,
    _depth: usize,
) -> Result<ExprId, LaplaceError> {
    match name {
        "exp" => {
            let (rate, offset) = as_affine(arg, t, pool).ok_or_else(|| {
                LaplaceError::NoRule(format!("exp of non-affine argument: {}", pool.display(arg)))
            })?;
            if offset != pool.integer(0_i32) {
                return Err(LaplaceError::NoRule(
                    "exp(a t + b): nonzero constant offset".into(),
                ));
            }
            Ok(recip(pool.add(vec![s, neg(rate, pool)]), pool))
        }
        "sin" | "cos" | "sinh" | "cosh" => {
            let (freq, offset) = as_affine(arg, t, pool).ok_or_else(|| {
                LaplaceError::NoRule(format!(
                    "{name} of non-affine argument: {}",
                    pool.display(arg)
                ))
            })?;
            if offset != pool.integer(0_i32) || freq == pool.integer(0_i32) {
                return Err(LaplaceError::NoRule(format!(
                    "{name}(b t): argument must be a nonzero multiple of t"
                )));
            }
            let b2 = pool.pow(freq, pool.integer(2_i32));
            let s2 = pool.pow(s, pool.integer(2_i32));
            // Hyperbolic variants flip the sign of b² in the denominator.
            let denom = if matches!(name, "sinh" | "cosh") {
                pool.add(vec![s2, neg(b2, pool)])
            } else {
                pool.add(vec![s2, b2])
            };
            // sin/sinh put the frequency on top; cos/cosh put `s` on top.
            let numer = if matches!(name, "sin" | "sinh") { freq } else { s };
            Ok(pool.mul(vec![numer, recip(denom, pool)]))
        }
        "heaviside" | "diracdelta" => {
            let (coeff, offset) = as_affine(arg, t, pool).ok_or_else(|| {
                LaplaceError::NoRule(format!(
                    "{name} of non-affine argument: {}",
                    pool.display(arg)
                ))
            })?;
            if coeff != pool.integer(1_i32) {
                return Err(LaplaceError::NoRule(format!(
                    "{name}(c·t − a): coefficient of t must be 1"
                )));
            }
            let delay = simp(neg(offset, pool), pool);
            let exp_neg_as =
                pool.func("exp", vec![simp(neg(pool.mul(vec![delay, s]), pool), pool)]);
            Ok(if name == "heaviside" {
                pool.mul(vec![exp_neg_as, recip(s, pool)])
            } else {
                exp_neg_as
            })
        }
        _ => Err(LaplaceError::NoRule(format!("{name}(...)"))),
    }
}
/// Substitutes `to` for `from` in `expr` by calling `crate::kernel::subs` with
/// a single-entry replacement map.
fn subs_one(expr: ExprId, from: ExprId, to: ExprId, pool: &ExprPool) -> ExprId {
    let replacement = std::collections::HashMap::from([(from, to)]);
    crate::kernel::subs(expr, &replacement, pool)
}
/// Multiplies together every factor except the one at `idx`.
///
/// No survivors gives `1`, a single survivor is returned as-is, and several
/// survivors become a product. An out-of-range `idx` removes nothing.
fn remove_index(factors: &[ExprId], idx: usize, pool: &ExprPool) -> ExprId {
    let mut kept = factors.to_vec();
    if idx < kept.len() {
        kept.remove(idx);
    }
    match kept.len() {
        0 => pool.integer(1_i32),
        1 => kept[0],
        _ => pool.mul(kept),
    }
}
pub fn laplace_derivative_rule(
f_transform: ExprId,
s: ExprId,
order: u32,
initial_values: &[ExprId],
pool: &ExprPool,
) -> ExprId {
let s_n = pool.pow(s, pool.integer(order as i32));
let mut terms = vec![pool.mul(vec![s_n, f_transform])];
for k in 0..order {
let f_k0 = initial_values
.get(k as usize)
.copied()
.unwrap_or_else(|| pool.integer(0_i32));
if f_k0 == pool.integer(0_i32) {
continue;
}
let power = (order - 1 - k) as i32;
let s_pow = pool.pow(s, pool.integer(power));
terms.push(pool.mul(vec![pool.integer(-1_i32), s_pow, f_k0]));
}
simp(pool.add(terms), pool)
}
/// Computes the inverse Laplace transform `f(t) = L⁻¹{F(s)}(t)`.
///
/// Two strategies are tried:
///
/// 1. Delay. If `F` has a factor `exp(c·s)` with `c ≠ 0` (see `split_delay`),
///    set `a = −c` and use `L⁻¹{e^{−a s}·G(s)} = heaviside(t − a)·g(t − a)`,
///    with `g` computed recursively.
/// 2. Otherwise `F` is decomposed by `crate::poly::apart`, and each
///    partial-fraction term is inverted by `invert_term`.
///
/// The result is simplified before being returned.
///
/// # Errors
///
/// * [`LaplaceError::SameVariable`] if `s` and `t` are the same expression.
/// * [`LaplaceError::NotInvertible`] if `apart` fails or a partial-fraction
///   term falls outside the supported table.
pub fn inverse_laplace_transform(
    big_f: ExprId,
    s: ExprId,
    t: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, LaplaceError> {
    if s == t {
        return Err(LaplaceError::SameVariable);
    }
    match split_delay(big_f, s, pool) {
        Some((delay, undelayed)) => {
            let g = inverse_laplace_transform(undelayed, s, t, pool)?;
            let t_minus_a = simp(pool.add(vec![t, neg(delay, pool)]), pool);
            let g_shifted = subs_one(g, t, t_minus_a, pool);
            let step = pool.func("heaviside", vec![t_minus_a]);
            Ok(simp(pool.mul(vec![step, g_shifted]), pool))
        }
        None => {
            let pf = crate::poly::apart(big_f, s, pool)
                .map_err(|e| LaplaceError::NotInvertible(format!("apart failed: {e}")))?;
            let terms: Vec<ExprId> = match pool.get(pf) {
                ExprData::Add(args) => args,
                _ => vec![pf],
            };
            let inverted = terms
                .into_iter()
                .map(|term| invert_term(term, s, t, pool))
                .collect::<Result<Vec<ExprId>, LaplaceError>>()?;
            Ok(simp(pool.add(inverted), pool))
        }
    }
}
/// Looks for a delay factor `exp(c·s)` in `big_f`, with `c ≠ 0` and no
/// constant offset in the exponent.
///
/// On success, returns `(a, G)` with `a = −c` (simplified) and `G` the product
/// of the remaining factors, so that `big_f = e^{−a s}·G`. Only the first
/// matching factor is split off.
fn split_delay(big_f: ExprId, s: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
    let factors: Vec<ExprId> = match pool.get(big_f) {
        ExprData::Mul(items) => items,
        _ => vec![big_f],
    };
    factors.iter().enumerate().find_map(|(idx, &fac)| {
        let exponent = match pool.get(fac) {
            ExprData::Func { name, args } if name == "exp" && args.len() == 1 => args[0],
            _ => return None,
        };
        let (coeff, offset) = as_affine(exponent, s, pool)?;
        if offset != pool.integer(0_i32) || coeff == pool.integer(0_i32) {
            return None;
        }
        let delay = simp(neg(coeff, pool), pool);
        Some((delay, remove_index(&factors, idx, pool)))
    })
}
fn invert_term(
term: ExprId,
s: ExprId,
t: ExprId,
pool: &ExprPool,
) -> Result<ExprId, LaplaceError> {
let (numer, base, n) = split_rational_term(term, pool)
.ok_or_else(|| LaplaceError::NotInvertible(pool.display(term).to_string()))?;
if n == 0 {
return Err(LaplaceError::NotInvertible(format!(
"polynomial part {} (δ / derivatives of δ — improper rational)",
pool.display(term)
)));
}
match poly_degree(base, s, pool) {
Some(1) => invert_linear_pole(numer, base, n, s, t, pool),
Some(2) => invert_quadratic(numer, base, n, s, t, pool),
_ => Err(LaplaceError::NotInvertible(format!(
"denominator factor of degree > 2: {}",
pool.display(base)
))),
}
}
/// Splits a product `term` into `(numerator, base, n)` with `term = numerator · base^-n`.
///
/// Every factor of the form `b^e` with a negative integer literal `e` is taken
/// as the pole. The other factors, multiplied together (or `1` if none), form
/// the numerator. Returns `None` if two pole factors have different bases, or
/// if `-e` does not fit in `u64`. With no pole factor the result is
/// `(numerator, 1, 0)`.
///
/// NOTE(review): if the same base appears in two pole factors, the exponent of
/// the last one wins rather than the exponents being added.
fn split_rational_term(term: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId, u64)> {
    let factors: Vec<ExprId> = match pool.get(term) {
        ExprData::Mul(items) => items,
        _ => vec![term],
    };
    let mut numer_parts: Vec<ExprId> = Vec::new();
    let mut pole: Option<(ExprId, u64)> = None;
    for fac in factors {
        // `Some((b, e))` when `fac` is `b^e` with a negative integer literal `e`.
        let negative_power = match pool.get(fac) {
            ExprData::Pow { base, exp } => match pool.get(exp) {
                ExprData::Integer(e) if e.0 < 0 => Some((base, e.0)),
                _ => None,
            },
            _ => None,
        };
        match negative_power {
            None => numer_parts.push(fac),
            Some((b, e)) => {
                if let Some((prev, _)) = pole {
                    if prev != b {
                        return None;
                    }
                }
                pole = Some((b, (-e).to_u64()?));
            }
        }
    }
    let numer = match numer_parts.len() {
        0 => pool.integer(1_i32),
        1 => numer_parts[0],
        _ => pool.mul(numer_parts),
    };
    Some(match pole {
        Some((b, n)) => (numer, b, n),
        None => (numer, pool.integer(1_i32), 0),
    })
}
/// Degree of `base` as a polynomial in `s`, or `None` if it is not one.
///
/// * `s` itself has degree 1.
/// * A sum has the largest degree among its summands (see `monomial_degree`).
/// * Powers and products are measured directly by `monomial_degree`.
/// * Any other `s`-free expression has degree 0.
fn poly_degree(base: ExprId, s: ExprId, pool: &ExprPool) -> Option<u64> {
    if base == s {
        return Some(1);
    }
    match pool.get(base) {
        ExprData::Add(terms) => terms.into_iter().try_fold(0u64, |highest, term| {
            monomial_degree(term, s, pool).map(|d| highest.max(d))
        }),
        ExprData::Pow { .. } | ExprData::Mul(_) => monomial_degree(base, s, pool),
        _ => {
            if is_free_of(base, s, pool) {
                Some(0)
            } else {
                None
            }
        }
    }
}
/// Degree in `s` of a single monomial, or `None` if `term` is not one.
///
/// * `s` has degree 1, and `s`-free terms have degree 0.
/// * `s^k` with a non-negative integer literal `k` has degree `k`.
/// * A product's degree is the sum of its factors' degrees.
fn monomial_degree(term: ExprId, s: ExprId, pool: &ExprPool) -> Option<u64> {
    if term == s {
        return Some(1);
    }
    if is_free_of(term, s, pool) {
        return Some(0);
    }
    match pool.get(term) {
        ExprData::Mul(factors) => factors.into_iter().try_fold(0u64, |total, factor| {
            monomial_degree(factor, s, pool).map(|d| total + d)
        }),
        ExprData::Pow { base, exp } if base == s => nonneg_int_exp(exp, pool),
        _ => None,
    }
}
/// Inverts `N · base^-n` for a linear `base = c·s + b` and an `s`-free numerator `N`.
///
/// For a monic denominator (`c = 1`, pole at `a = −b`) this uses
/// `L⁻¹{N/(s − a)ⁿ} = N·e^{a t}·t^{n−1}/(n−1)!`; the `t`-power and factorial
/// appear only for `n ≥ 2`.
///
/// A non-monic denominator is normalised via
/// `(c·s + b)^-n = c^-n·(s + b/c)^-n`. The pole becomes `a = −b/c` and an
/// extra factor `c^-n` is added; this case used to be rejected.
///
/// # Errors
///
/// `NotInvertible` if `N` depends on `s`, if `base` is not affine in `s`, or
/// if its `s`-coefficient is `0`.
fn invert_linear_pole(
    numer: ExprId,
    base: ExprId,
    n: u64,
    s: ExprId,
    t: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, LaplaceError> {
    if !is_free_of(numer, s, pool) {
        return Err(LaplaceError::NotInvertible(format!(
            "linear-pole numerator depends on s: {}",
            pool.display(numer)
        )));
    }
    let (coeff, b) = as_affine(base, s, pool)
        .ok_or_else(|| LaplaceError::NotInvertible(pool.display(base).to_string()))?;
    if coeff == pool.integer(0_i32) {
        return Err(LaplaceError::NotInvertible(format!(
            "degenerate linear denominator: {}",
            pool.display(base)
        )));
    }
    let monic = coeff == pool.integer(1_i32);
    // Pole location: base = coeff·(s − a), hence a = −b / coeff.
    let a = if monic {
        simp(neg(b, pool), pool)
    } else {
        simp(neg(pool.mul(vec![b, recip(coeff, pool)]), pool), pool)
    };
    let exp_at = pool.func("exp", vec![pool.mul(vec![a, t])]);
    let mut parts = vec![numer, exp_at];
    if n >= 2 {
        let t_pow = pool.pow(t, pool.integer(Integer::from(n - 1)));
        parts.push(t_pow);
        let fact = factorial(n - 1, pool);
        parts.push(recip(fact, pool));
    }
    if !monic {
        // (coeff·(s − a))^-n = coeff^-n · (s − a)^-n.
        parts.push(pool.pow(coeff, pool.integer(-Integer::from(n))));
    }
    Ok(pool.mul(parts))
}
/// Inverts `N(s) / (α s² + β s + γ)` for an affine numerator `N(s) = B s + C`
/// and a simple (`n = 1`) quadratic pole.
///
/// Steps:
/// 1. A non-monic denominator (`α ≠ 1`) is normalised by dividing `B`, `C`,
///    `β` and `γ` by `α`; this case used to be rejected.
/// 2. Complete the square: `s² + β s + γ = (s − p)² + ω²`, with `p = −β/2`
///    and `ω² = γ − β²/4`.
/// 3. Rewrite the numerator as `B s + C = B(s − p) + (C + B p)`.
///
/// The result is `e^{p t}·(B cos ω t + (C + B p)/ω · sin ω t)`.
///
/// NOTE(review): `ω²` is not checked for sign or zero. This presumes the
/// partial-fraction step left only irreducible quadratics.
///
/// # Errors
///
/// `NotInvertible` in these cases:
/// * the pole is repeated (`n ≠ 1`);
/// * the denominator coefficients cannot be extracted;
/// * `α` is `0`;
/// * the numerator is not affine in `s`.
fn invert_quadratic(
    numer: ExprId,
    base: ExprId,
    n: u64,
    s: ExprId,
    t: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, LaplaceError> {
    if n != 1 {
        return Err(LaplaceError::NotInvertible(
            "repeated irreducible quadratic pole (n ≥ 2) not in table".into(),
        ));
    }
    let (alpha, beta, gamma) = quadratic_coeffs(base, s, pool)
        .ok_or_else(|| LaplaceError::NotInvertible(pool.display(base).to_string()))?;
    if alpha == pool.integer(0_i32) {
        return Err(LaplaceError::NotInvertible(format!(
            "degenerate quadratic denominator: {}",
            pool.display(base)
        )));
    }
    let (bb, cc) = as_affine(numer, s, pool)
        .ok_or_else(|| LaplaceError::NotInvertible(pool.display(numer).to_string()))?;
    // Normalise to a monic denominator: (B s + C)/(α Q(s)) = (B/α · s + C/α)/Q(s).
    let (beta, gamma, bb, cc) = if alpha == pool.integer(1_i32) {
        (beta, gamma, bb, cc)
    } else {
        let inv_alpha = recip(alpha, pool);
        let scale = |x: ExprId| simp(pool.mul(vec![x, inv_alpha]), pool);
        (scale(beta), scale(gamma), scale(bb), scale(cc))
    };
    let half = pool.rational(1_i32, 2_i32);
    // p = −β/2
    let p = simp(pool.mul(vec![neg(beta, pool), half]), pool);
    // ω² = γ − β²/4
    let beta2 = pool.pow(beta, pool.integer(2_i32));
    let quarter = pool.rational(1_i32, 4_i32);
    let omega_sq = simp(
        pool.add(vec![gamma, neg(pool.mul(vec![beta2, quarter]), pool)]),
        pool,
    );
    let omega = simp(pool.pow(omega_sq, half), pool);
    let exp_pt = pool.func("exp", vec![pool.mul(vec![p, t])]);
    let omega_t = pool.mul(vec![omega, t]);
    let cos_term = pool.mul(vec![bb, pool.func("cos", vec![omega_t])]);
    let bp = pool.mul(vec![bb, p]);
    let sin_coeff = pool.mul(vec![pool.add(vec![cc, bp]), recip(omega, pool)]);
    let sin_term = pool.mul(vec![sin_coeff, pool.func("sin", vec![omega_t])]);
    Ok(pool.mul(vec![exp_pt, pool.add(vec![cos_term, sin_term])]))
}
fn quadratic_coeffs(base: ExprId, s: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId, ExprId)> {
let args: Vec<ExprId> = match pool.get(base) {
ExprData::Add(a) => a,
_ => vec![base],
};
let mut alpha = pool.integer(0_i32);
let mut beta = pool.integer(0_i32);
let mut gamma_parts: Vec<ExprId> = Vec::new();
for term in args {
match monomial_degree(term, s, pool)? {
2 => alpha = monomial_coeff(term, s, 2, pool)?,
1 => beta = monomial_coeff(term, s, 1, pool)?,
0 => gamma_parts.push(term),
_ => return None,
}
}
let gamma = match gamma_parts.len() {
0 => pool.integer(0_i32),
1 => gamma_parts[0],
_ => pool.add(gamma_parts),
};
Some((alpha, beta, gamma))
}
/// Extracts the coefficient `c` of `term = c·s^deg`.
///
/// * `deg == 0` returns `term` unchanged.
/// * If `term` is exactly `s^deg` (or `s` when `deg == 1`), the coefficient is `1`.
/// * Otherwise `term` must be a product with at least one such factor. Every
///   matching factor is dropped, and `c` is the product of the rest (`1` if
///   none remain).
///
/// Returns `None` when no matching factor is found.
fn monomial_coeff(term: ExprId, s: ExprId, deg: u64, pool: &ExprPool) -> Option<ExprId> {
    if deg == 0 {
        return Some(term);
    }
    // True when `factor` is literally `s` (for deg 1) or `s^deg`.
    let is_s_to_deg = |factor: ExprId| -> bool {
        if deg == 1 && factor == s {
            return true;
        }
        match pool.get(factor) {
            ExprData::Pow { base, exp } => base == s && nonneg_int_exp(exp, pool) == Some(deg),
            _ => false,
        }
    };
    if is_s_to_deg(term) {
        return Some(pool.integer(1_i32));
    }
    let factors = match pool.get(term) {
        ExprData::Mul(items) => items,
        _ => return None,
    };
    let (matched, others): (Vec<ExprId>, Vec<ExprId>) =
        factors.into_iter().partition(|&factor| is_s_to_deg(factor));
    if matched.is_empty() {
        return None;
    }
    Some(match others.len() {
        0 => pool.integer(1_i32),
        1 => others[0],
        _ => pool.mul(others),
    })
}
#[cfg(test)]
mod tests;