use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
/// Failure modes of [`fourier_transform`] and [`inverse_fourier_transform`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FourierError {
    /// No rule in this module matched; the payload describes what was being
    /// transformed when matching failed.
    NoRule(String),
    /// The space variable and the frequency variable were the same symbol.
    SameVariable,
}

impl std::fmt::Display for FourierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SameVariable => f.write_str(
                "fourier_transform: space and frequency variables must differ [E-TRANSFORM-012]",
            ),
            Self::NoRule(what) => {
                write!(f, "fourier_transform: no rule for {what} [E-TRANSFORM-011]")
            }
        }
    }
}

/// `FourierError` wraps no underlying cause, so the default `source()` (None) applies.
impl std::error::Error for FourierError {}
/// True when `expr` does not depend on `var`; delegates to `is_free_of_var`
/// from the Risch polynomial-RDE module.
fn is_free_of(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    use crate::integrate::risch::poly_rde::is_free_of_var as free_of_var;
    free_of_var(expr, var, pool)
}

/// Runs the crate simplifier on `expr` and keeps only the resulting
/// expression (the `value` field of the simplifier's result).
fn simp(expr: ExprId, pool: &ExprPool) -> ExprId {
    let simplified = crate::simplify::simplify(expr, pool);
    simplified.value
}
/// Structural tidy-up applied to transform results after simplification.
///
/// Recurses through sums, products, function arguments and power bases, and
/// applies these rewrites:
/// * literal `1` factors are dropped from products (an emptied product becomes
///   `1`, a single survivor is returned on its own);
/// * `exp(0)` becomes `1`;
/// * a `-1` factor is stripped from the lone argument of `diracdelta` and
///   `abs` (both are even functions);
/// * `1^e` becomes `1`;
/// * a `-1` factor is stripped from the base of a power whose exponent is an
///   even integer literal that fits in `i64`.
///
/// Power exponents are left untouched; every other node kind is returned as-is.
fn normalize(expr: ExprId, pool: &ExprPool) -> ExprId {
    match pool.get(expr) {
        ExprData::Add(terms) => {
            let terms: Vec<ExprId> = terms.iter().map(|&t| normalize(t, pool)).collect();
            pool.add(terms)
        }
        ExprData::Mul(factors) => {
            let one = pool.integer(1_i32);
            let kept: Vec<ExprId> = factors
                .iter()
                .map(|&fac| normalize(fac, pool))
                .filter(|&fac| fac != one)
                .collect();
            if kept.len() > 1 {
                pool.mul(kept)
            } else {
                kept.first().copied().unwrap_or(one)
            }
        }
        ExprData::Func { name, args } => {
            let args: Vec<ExprId> = args.iter().map(|&a| normalize(a, pool)).collect();
            let unary = args.len() == 1;
            if name == "exp" && unary && args[0] == pool.integer(0_i32) {
                return pool.integer(1_i32);
            }
            let is_even_fn = name == "diracdelta" || name == "abs";
            if is_even_fn && unary {
                if let Some(inner) = strip_neg(args[0], pool) {
                    return pool.func(name, vec![normalize(inner, pool)]);
                }
            }
            pool.func(name, args)
        }
        ExprData::Pow { base, exp } => {
            let base = normalize(base, pool);
            let one = pool.integer(1_i32);
            if base == one {
                return one;
            }
            let even_exponent = match pool.get(exp) {
                ExprData::Integer(n) => matches!(n.0.to_i64(), Some(v) if v % 2 == 0),
                _ => false,
            };
            if even_exponent {
                if let Some(inner) = strip_neg(base, pool) {
                    return pool.pow(normalize(inner, pool), exp);
                }
            }
            pool.pow(base, exp)
        }
        _ => expr,
    }
}
fn strip_neg(expr: ExprId, pool: &ExprPool) -> Option<ExprId> {
if let ExprData::Mul(args) = pool.get(expr) {
let neg_one = pool.integer(-1_i32);
if let Some(pos) = args.iter().position(|&a| a == neg_one) {
let rest: Vec<ExprId> = args
.iter()
.enumerate()
.filter(|&(i, _)| i != pos)
.map(|(_, &a)| a)
.collect();
return Some(match rest.len() {
0 => pool.integer(1_i32),
1 => rest[0],
_ => pool.mul(rest),
});
}
}
None
}
/// Builds the (unsimplified) product `-1 · expr`.
fn neg(expr: ExprId, pool: &ExprPool) -> ExprId {
    let minus_one = pool.integer(-1_i32);
    pool.mul(vec![minus_one, expr])
}

/// Builds the (unsimplified) power `expr^(-1)`.
fn recip(expr: ExprId, pool: &ExprPool) -> ExprId {
    let minus_one = pool.integer(-1_i32);
    pool.pow(expr, minus_one)
}

/// The symbol `pi`, declared in the real domain.
fn pi(pool: &ExprPool) -> ExprId {
    pool.symbol("pi", Domain::Real)
}

/// The pool's imaginary unit `i`.
fn imag(pool: &ExprPool) -> ExprId {
    pool.imaginary_unit()
}

/// Builds the product `2 · pi · i`, the constant in the `e^{-2πi x ξ}` kernel.
fn two_pi_i(pool: &ExprPool) -> ExprId {
    let two = pool.integer(2_i32);
    pool.mul(vec![two, pi(pool), imag(pool)])
}
/// Replaces `from` with `to` in `expr` by calling `crate::kernel::subs` with a
/// single-entry substitution map.
fn subs_one(expr: ExprId, from: ExprId, to: ExprId, pool: &ExprPool) -> ExprId {
    let replacements: std::collections::HashMap<ExprId, ExprId> =
        std::iter::once((from, to)).collect();
    crate::kernel::subs(expr, &replacements, pool)
}
fn as_affine(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
if expr == var {
return Some((pool.integer(1_i32), pool.integer(0_i32)));
}
if is_free_of(expr, var, pool) {
return Some((pool.integer(0_i32), expr));
}
match pool.get(expr) {
ExprData::Mul(_) => {
let a = affine_coeff(expr, var, pool)?;
Some((a, pool.integer(0_i32)))
}
ExprData::Add(args) => {
let mut a_acc: Vec<ExprId> = Vec::new();
let mut b_acc: Vec<ExprId> = Vec::new();
for arg in args {
if is_free_of(arg, var, pool) {
b_acc.push(arg);
} else {
a_acc.push(affine_coeff(arg, var, pool)?);
}
}
let a = match a_acc.len() {
0 => pool.integer(0_i32),
1 => a_acc[0],
_ => pool.add(a_acc),
};
let b = match b_acc.len() {
0 => pool.integer(0_i32),
1 => b_acc[0],
_ => pool.add(b_acc),
};
Some((a, b))
}
_ => None,
}
}
fn affine_coeff(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<ExprId> {
if expr == var {
return Some(pool.integer(1_i32));
}
if let ExprData::Mul(args) = pool.get(expr) {
let pos = args.iter().position(|&a| a == var)?;
let others: Vec<ExprId> = args
.iter()
.enumerate()
.filter(|&(i, _)| i != pos)
.map(|(_, &a)| a)
.collect();
if others.iter().all(|&o| is_free_of(o, var, pool)) {
return Some(match others.len() {
0 => pool.integer(1_i32),
1 => others[0],
_ => pool.mul(others),
});
}
}
None
}
/// Product of `factors` with the entry at `idx` left out.
///
/// The result is `1` when nothing remains, the lone survivor when exactly one
/// remains, and a new product node otherwise. An out-of-range `idx` removes
/// nothing.
fn remove_index(factors: &[ExprId], idx: usize, pool: &ExprPool) -> ExprId {
    let mut kept: Vec<ExprId> = factors.to_vec();
    if idx < kept.len() {
        kept.remove(idx);
    }
    if kept.len() <= 1 {
        return kept.first().copied().unwrap_or_else(|| pool.integer(1_i32));
    }
    pool.mul(kept)
}
/// Builds the phase factor `exp(-2πi·a·xi)`, simplifying the exponent.
fn phase_minus(a: ExprId, xi: ExprId, pool: &ExprPool) -> ExprId {
    let kernel_arg = pool.mul(vec![two_pi_i(pool), a, xi]);
    let exponent = simp(neg(kernel_arg, pool), pool);
    pool.func("exp", vec![exponent])
}
/// Recursion limit for the rule engine. Deeper recursion is reported as
/// [`FourierError::NoRule`].
const MAX_DEPTH: usize = 32;

/// Symbolic Fourier transform of `f` with respect to `x`, expressed in `xi`.
///
/// Convention: `F(ξ) = ∫ f(x) e^{-2πi x ξ} dx`. For example, `δ(x − a)` maps
/// to `e^{-2πi a ξ}`. The result is simplified and then normalised.
///
/// # Errors
/// * [`FourierError::SameVariable`] if `x` and `xi` are the same expression.
/// * [`FourierError::NoRule`] if no rule matches some sub-expression, or if
///   the recursion exceeds `MAX_DEPTH`.
pub fn fourier_transform(
    f: ExprId,
    x: ExprId,
    xi: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, FourierError> {
    if x == xi {
        return Err(FourierError::SameVariable);
    }
    let raw = fourier_inner(f, x, xi, pool, 0)?;
    let simplified = simp(raw, pool);
    Ok(normalize(simplified, pool))
}
/// Inverse Fourier transform `f(x) = ∫ g(ξ) e^{+2πi x ξ} dξ`.
///
/// Computed as the forward transform of `g` (with `xi` as the integration
/// variable and `x` as the output variable), followed by `x ↦ -x`. The
/// inverse kernel is the forward kernel evaluated at `-x`.
///
/// # Errors
/// * [`FourierError::SameVariable`] if `xi` and `x` are the same expression.
/// * Any error from [`fourier_transform`].
pub fn inverse_fourier_transform(
    g: ExprId,
    xi: ExprId,
    x: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, FourierError> {
    if xi == x {
        return Err(FourierError::SameVariable);
    }
    let forward = fourier_transform(g, xi, x, pool)?;
    let reflected = subs_one(forward, x, neg(x, pool), pool);
    Ok(normalize(simp(reflected, pool), pool))
}
/// Recursive rule dispatcher behind [`fourier_transform`].
///
/// Rules are tried in this order:
/// 1. expressions free of `x` map to `c·δ(ξ)`;
/// 2. the Lorentzian rule ([`try_lorentzian`]);
/// 3. by node kind: sums by linearity, products via [`fourier_mul`], and
///    unary functions via [`fourier_func`].
///
/// Anything else is [`FourierError::NoRule`].
fn fourier_inner(
    f: ExprId,
    x: ExprId,
    xi: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<ExprId, FourierError> {
    if depth > MAX_DEPTH {
        return Err(FourierError::NoRule("recursion depth exceeded".into()));
    }
    if is_free_of(f, x, pool) {
        return Ok(pool.mul(vec![f, pool.func("diracdelta", vec![xi])]));
    }
    if let Some(lorentz) = try_lorentzian(f, x, xi, pool) {
        return Ok(lorentz);
    }
    match pool.get(f) {
        ExprData::Add(terms) => {
            let transformed = terms
                .into_iter()
                .map(|term| fourier_inner(term, x, xi, pool, depth + 1))
                .collect::<Result<Vec<ExprId>, FourierError>>()?;
            Ok(pool.add(transformed))
        }
        ExprData::Mul(factors) => fourier_mul(&factors, x, xi, pool, depth),
        ExprData::Func { name, args } if args.len() == 1 => {
            fourier_func(&name, args[0], f, x, xi, pool, depth)
        }
        _ => Err(FourierError::NoRule(pool.display(f).to_string())),
    }
}
/// Transforms a product.
///
/// Factors free of `x` are pulled out as a scalar (linearity). The remaining
/// factors are handed to [`fourier_product_body`], both as the list `rest`
/// and as their product `body`.
fn fourier_mul(
    args: &[ExprId],
    x: ExprId,
    xi: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<ExprId, FourierError> {
    let (consts, rest): (Vec<ExprId>, Vec<ExprId>) =
        args.iter().copied().partition(|&a| is_free_of(a, x, pool));
    let scalar = if consts.is_empty() {
        None
    } else if consts.len() == 1 {
        Some(consts[0])
    } else {
        Some(pool.mul(consts))
    };
    if rest.is_empty() {
        let c = scalar.unwrap_or_else(|| pool.integer(1_i32));
        return fourier_inner(c, x, xi, pool, depth + 1);
    }
    // `rest` is still needed after building `body`, hence the clone.
    let body = if rest.len() == 1 {
        rest[0]
    } else {
        pool.mul(rest.clone())
    };
    let transformed = fourier_product_body(body, &rest, x, xi, pool, depth)?;
    Ok(match scalar {
        None => transformed,
        Some(c) => pool.mul(vec![c, transformed]),
    })
}
/// Transforms a product whose factors (`factors`, with product `body`) all
/// depend on `x`.
///
/// Rules, in order:
/// 1. one-sided exponential `θ(x)·e^{-a x}` ([`try_one_sided_exponential`]);
/// 2. modulation: a factor `e^{2πi a x}` shifts the transform of the other
///    factors, `G(ξ) ↦ G(ξ − a)`;
/// 3. a `body` that is not itself a product is dispatched back to
///    [`fourier_inner`].
///
/// When rule 1 rejects a Heaviside product, its error is held back until rule
/// 2 has been tried. This lets products such as `θ(x)·e^{-x}·e^{2πi a x}`
/// reach the shift theorem. If rule 2 does not apply either, the held-back
/// error is returned unchanged.
fn fourier_product_body(
    body: ExprId,
    factors: &[ExprId],
    x: ExprId,
    xi: ExprId,
    pool: &ExprPool,
    depth: usize,
) -> Result<ExprId, FourierError> {
    let one_sided_failure = match try_one_sided_exponential(factors, x, xi, pool) {
        Ok(Some(res)) => return Ok(res),
        Ok(None) => None,
        Err(err) => Some(err),
    };
    for (i, &fac) in factors.iter().enumerate() {
        if let Some(a) = match_modulation(fac, x, pool) {
            let rest = remove_index(factors, i, pool);
            let g_transform = fourier_inner(rest, x, xi, pool, depth + 1)?;
            let xi_minus_a = simp(pool.add(vec![xi, neg(a, pool)]), pool);
            return Ok(subs_one(g_transform, xi, xi_minus_a, pool));
        }
    }
    if let Some(err) = one_sided_failure {
        return Err(err);
    }
    if !matches!(pool.get(body), ExprData::Mul(_)) {
        return fourier_inner(body, x, xi, pool, depth + 1);
    }
    Err(FourierError::NoRule(pool.display(body).to_string()))
}
fn match_modulation(fac: ExprId, x: ExprId, pool: &ExprPool) -> Option<ExprId> {
let ExprData::Func { name, args } = pool.get(fac) else {
return None;
};
if name != "exp" || args.len() != 1 {
return None;
}
let (coeff, off) = as_affine(args[0], x, pool)?;
if off != pool.integer(0_i32) || coeff == pool.integer(0_i32) {
return None;
}
let mut factors: Vec<ExprId> = match pool.get(coeff) {
ExprData::Mul(a) => a,
_ => vec![coeff],
};
let pi_sym = pi(pool);
let i_sym = imag(pool);
let two = pool.integer(2_i32);
for needle in [i_sym, pi_sym, two] {
let pos = factors.iter().position(|&f| f == needle)?;
factors.remove(pos);
}
let a = match factors.len() {
0 => pool.integer(1_i32),
1 => factors[0],
_ => pool.mul(factors),
};
if is_free_of(a, i_sym, pool) {
Some(simp(a, pool))
} else {
None
}
}
fn try_one_sided_exponential(
factors: &[ExprId],
x: ExprId,
xi: ExprId,
pool: &ExprPool,
) -> Result<Option<ExprId>, FourierError> {
let heaviside_idx = factors.iter().position(|&fac| {
if let ExprData::Func { name, args } = pool.get(fac) {
name == "heaviside" && args.len() == 1 && args[0] == x
} else {
false
}
});
let hi = match heaviside_idx {
Some(i) => i,
None => return Ok(None),
};
let rest = remove_index(factors, hi, pool);
if rest == pool.integer(1_i32) {
return Ok(None);
}
if let ExprData::Func { name, args } = pool.get(rest) {
if name == "exp" && args.len() == 1 {
let (coeff, off) = as_affine(args[0], x, pool)
.ok_or_else(|| FourierError::NoRule("one-sided exp: non-affine".into()))?;
if off == pool.integer(0_i32) && coeff != pool.integer(0_i32) {
let a = simp(neg(coeff, pool), pool);
let denom = pool.add(vec![a, pool.mul(vec![two_pi_i(pool), xi])]);
return Ok(Some(recip(denom, pool)));
}
}
}
Err(FourierError::NoRule(
"θ(x)·g(x): g is not a recognised one-sided exponential".into(),
))
}
fn try_lorentzian(f: ExprId, x: ExprId, xi: ExprId, pool: &ExprPool) -> Option<ExprId> {
let factors: Vec<ExprId> = match pool.get(f) {
ExprData::Mul(a) => a,
_ => vec![f],
};
let mut denom: Option<ExprId> = None;
let mut numer_parts: Vec<ExprId> = Vec::new();
for &fac in &factors {
if let ExprData::Pow { base, exp } = pool.get(fac) {
if exp == pool.integer(-1_i32) && !is_free_of(base, x, pool) {
if denom.is_some() {
return None; }
denom = Some(base);
continue;
}
}
numer_parts.push(fac);
}
let denom = denom?;
let numer = match numer_parts.len() {
0 => pool.integer(1_i32),
1 => numer_parts[0],
_ => pool.mul(numer_parts),
};
if !is_free_of(numer, x, pool) {
return None;
}
let (c0, c2) = quadratic_in_x(denom, x, pool)?;
let a = simp(pool.mul(vec![numer, pool.rational(1_i32, 2_i32)]), pool);
let a2 = pool.pow(a, pool.integer(2_i32));
if simp(pool.add(vec![c0, neg(a2, pool)]), pool) != pool.integer(0_i32) {
return None;
}
let four_pi2 = pool.mul(vec![
pool.integer(4_i32),
pool.pow(pi(pool), pool.integer(2_i32)),
]);
if simp(pool.add(vec![c2, neg(four_pi2, pool)]), pool) != pool.integer(0_i32) {
return None;
}
let abs_xi = pool.func("abs", vec![xi]);
let arg = neg(pool.mul(vec![a, abs_xi]), pool);
Some(pool.func("exp", vec![simp(arg, pool)]))
}
fn quadratic_in_x(expr: ExprId, x: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
let x2 = pool.pow(x, pool.integer(2_i32));
let terms: Vec<ExprId> = match pool.get(expr) {
ExprData::Add(a) => a,
_ => vec![expr],
};
let mut c0_parts: Vec<ExprId> = Vec::new();
let mut c2: Option<ExprId> = None;
for term in terms {
if is_free_of(term, x, pool) {
c0_parts.push(term);
continue;
}
if term == x2 {
if c2.is_some() {
return None;
}
c2 = Some(pool.integer(1_i32));
continue;
}
if let ExprData::Mul(margs) = pool.get(term) {
let pos = margs.iter().position(|&m| m == x2)?;
let others: Vec<ExprId> = margs
.iter()
.enumerate()
.filter(|&(i, _)| i != pos)
.map(|(_, &m)| m)
.collect();
if others.iter().all(|&o| is_free_of(o, x, pool)) {
let coeff = match others.len() {
0 => pool.integer(1_i32),
1 => others[0],
_ => pool.mul(others),
};
if c2.is_some() {
return None;
}
c2 = Some(coeff);
continue;
}
}
return None; }
let c0 = match c0_parts.len() {
0 => pool.integer(0_i32),
1 => c0_parts[0],
_ => pool.add(c0_parts),
};
Some((c0, c2?))
}
#[allow(clippy::too_many_arguments)]
fn fourier_func(
name: &str,
arg: ExprId,
f: ExprId,
x: ExprId,
xi: ExprId,
pool: &ExprPool,
depth: usize,
) -> Result<ExprId, FourierError> {
if name == "diracdelta" {
let (coeff, b) = as_affine(arg, x, pool).ok_or_else(|| {
FourierError::NoRule(format!(
"diracdelta of non-affine argument: {}",
pool.display(arg)
))
})?;
if coeff != pool.integer(1_i32) {
return Err(FourierError::NoRule(
"diracdelta(c·x − a): coefficient of x must be 1".into(),
));
}
let a = simp(neg(b, pool), pool); return Ok(phase_minus(a, xi, pool));
}
if name == "exp" {
return fourier_exp(arg, x, xi, pool);
}
let _ = (f, depth);
Err(FourierError::NoRule(format!("{name}(...)")))
}
/// Transform rules for `exp(arg)`, tried in order:
/// 1. Gaussian `exp(-a(x − b)² + d)` ↦ `√(π/a)·e^{-π²ξ²/a}·e^{-2πi b ξ}·e^{d}`,
///    with `(a, b, d)` from [`match_gaussian_quadratic`]. The phase factor is
///    omitted when `b = 0` and the `e^d` factor when `d = 0`.
/// 2. Two-sided exponential `exp(-a|x|)` ↦ `2a / (a² + 4π²ξ²)`
///    (see [`match_abs_neg`]).
/// 3. Pure modulation `exp(2πi a x)` ↦ `δ(ξ − a)` (see [`match_modulation`]).
fn fourier_exp(
    arg: ExprId,
    x: ExprId,
    xi: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, FourierError> {
    let square = |e: ExprId| pool.pow(e, pool.integer(2_i32));

    if let Some((a, b, d)) = match_gaussian_quadratic(arg, x, pool) {
        let pi_sym = pi(pool);
        let half = pool.rational(1_i32, 2_i32);
        let sqrt_pi_over_a = pool.pow(pool.mul(vec![pi_sym, recip(a, pool)]), half);
        let pi_sq = square(pi_sym);
        let xi_sq = square(xi);
        let gauss_exponent = neg(pool.mul(vec![pi_sq, xi_sq, recip(a, pool)]), pool);
        let mut pieces = vec![
            sqrt_pi_over_a,
            pool.func("exp", vec![simp(gauss_exponent, pool)]),
        ];
        if b != pool.integer(0_i32) {
            pieces.push(phase_minus(b, xi, pool));
        }
        if d != pool.integer(0_i32) {
            pieces.push(pool.func("exp", vec![d]));
        }
        return Ok(simp(pool.mul(pieces), pool));
    }

    if let Some(a) = match_abs_neg(arg, x, pool) {
        let numerator = pool.mul(vec![pool.integer(2_i32), a]);
        let a_sq = square(a);
        let pi_sq = square(pi(pool));
        let xi_sq = square(xi);
        let spread = pool.mul(vec![pool.integer(4_i32), pi_sq, xi_sq]);
        let denominator = pool.add(vec![a_sq, spread]);
        return Ok(pool.mul(vec![numerator, recip(denominator, pool)]));
    }

    match match_modulation(pool.func("exp", vec![arg]), x, pool) {
        Some(a) => {
            let shifted = simp(pool.add(vec![xi, neg(a, pool)]), pool);
            Ok(pool.func("diracdelta", vec![shifted]))
        }
        None => Err(FourierError::NoRule(format!(
            "exp({}): not a recognised Gaussian / two-sided-exponential / modulation form",
            pool.display(arg)
        ))),
    }
}
/// Completes the square for a Gaussian exponent.
///
/// Given `arg = A x² + B x + C` (split by [`quadratic_abc`]), returns
/// `(a, b, d)` with `a = -A`, `b = B/(2a)` and `d = C + B²/(4a)`, so that
/// `arg = -a (x − b)² + d`. Returns `None` when `arg` is not quadratic in `x`,
/// or when `a` is a numeric literal `≤ 0` (the exponential would not decay).
fn match_gaussian_quadratic(
    arg: ExprId,
    x: ExprId,
    pool: &ExprPool,
) -> Option<(ExprId, ExprId, ExprId)> {
    let (quad, lin, konst) = quadratic_abc(arg, x, pool)?;
    let a = simp(neg(quad, pool), pool);
    let not_decaying = matches!(literal_rational(a, pool), Some(r) if r <= 0);
    if not_decaying {
        return None;
    }
    let times_a = |k: i32| pool.mul(vec![pool.integer(k), a]);
    let twice_a = times_a(2);
    let b = simp(pool.mul(vec![lin, recip(twice_a, pool)]), pool);
    let four_a = times_a(4);
    let lin_sq = pool.pow(lin, pool.integer(2_i32));
    let completion = pool.mul(vec![lin_sq, recip(four_a, pool)]);
    let d = simp(pool.add(vec![konst, completion]), pool);
    Some((a, b, d))
}
/// Splits `expr` into `(A, B, C)` with `expr = A x² + B x + C`.
///
/// Each summand must be one of:
/// * free of `x` (contributes to `C`);
/// * a multiple of `x²` (contributes to `A`);
/// * a multiple of `x` (contributes to `B`);
/// * `k·(p x + q)²`, which is expanded into all three.
///
/// `A` is simplified and must be nonzero; `B` and `C` are returned
/// unsimplified. Returns `None` for any other summand shape.
fn quadratic_abc(expr: ExprId, x: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId, ExprId)> {
    let x_sq = pool.pow(x, pool.integer(2_i32));
    let summands: Vec<ExprId> = match pool.get(expr) {
        ExprData::Add(a) => a,
        _ => vec![expr],
    };
    let mut quad: Vec<ExprId> = Vec::new();
    let mut lin: Vec<ExprId> = Vec::new();
    let mut konst: Vec<ExprId> = Vec::new();
    for term in summands {
        if is_free_of(term, x, pool) {
            konst.push(term);
        } else if let Some(k) = coeff_of(term, x_sq, x, pool) {
            quad.push(k);
        } else if let Some(k) = coeff_of(term, x, x, pool) {
            lin.push(k);
        } else {
            // k·(p x + q)² = k p² x² + 2 k p q x + k q²
            let (k, p, q) = squared_affine(term, x, pool)?;
            let two = pool.integer(2_i32);
            quad.push(pool.mul(vec![k, pool.pow(p, two)]));
            lin.push(pool.mul(vec![k, two, p, q]));
            konst.push(pool.mul(vec![k, pool.pow(q, two)]));
        }
    }
    let a = simp(sum_or(&quad, 0, pool), pool);
    if a == pool.integer(0_i32) {
        // No x² contribution: not a quadratic.
        return None;
    }
    Some((a, sum_or(&lin, 0, pool), sum_or(&konst, 0, pool)))
}
/// Coefficient of `power` in `term`.
///
/// Returns `1` if `term == power`. If `term` is a product containing `power`
/// whose other factors are all free of `var`, returns the product of those
/// other factors. Otherwise returns `None`.
fn coeff_of(term: ExprId, power: ExprId, var: ExprId, pool: &ExprPool) -> Option<ExprId> {
    if term == power {
        return Some(pool.integer(1_i32));
    }
    let ExprData::Mul(factors) = pool.get(term) else {
        return None;
    };
    let at = factors.iter().position(|&m| m == power)?;
    let others: Vec<ExprId> = factors
        .iter()
        .enumerate()
        .filter_map(|(k, &m)| (k != at).then_some(m))
        .collect();
    others
        .iter()
        .all(|&o| is_free_of(o, var, pool))
        .then(|| sum_or(&others, 1, pool))
}
/// Combines `parts` with the operation whose identity is `identity`: a sum
/// when `identity == 0`, a product for any other value.
///
/// An empty slice yields the integer `identity`, and a single part is
/// returned as-is without creating a new node.
fn sum_or(parts: &[ExprId], identity: i32, pool: &ExprPool) -> ExprId {
    if let [single] = parts {
        return *single;
    }
    if parts.is_empty() {
        return pool.integer(identity);
    }
    if identity == 0 {
        pool.add(parts.to_vec())
    } else {
        pool.mul(parts.to_vec())
    }
}
/// Recognises `k·(p x + q)²` with `k` free of `x`, returning `(k, p, q)`.
///
/// Exactly one factor may be a square of an `x`-dependent base, and that base
/// must be affine in `x` (see [`as_affine`]). Returns `None` otherwise.
fn squared_affine(term: ExprId, x: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId, ExprId)> {
    let factors: Vec<ExprId> = match pool.get(term) {
        ExprData::Mul(a) => a,
        _ => vec![term],
    };
    let two = pool.integer(2_i32);
    let mut found: Option<(usize, ExprId, ExprId)> = None;
    for (idx, &fac) in factors.iter().enumerate() {
        let ExprData::Pow { base, exp } = pool.get(fac) else {
            continue;
        };
        if exp != two || is_free_of(base, x, pool) {
            continue;
        }
        let (p, q) = as_affine(base, x, pool)?;
        if found.replace((idx, p, q)).is_some() {
            // A second x-dependent square: not of the form k·(px+q)².
            return None;
        }
    }
    let (idx, p, q) = found?;
    let k = remove_index(&factors, idx, pool);
    is_free_of(k, x, pool).then_some((k, p, q))
}
/// Exact value of `expr` when it is an integer or rational literal node;
/// `None` for every other node kind.
fn literal_rational(expr: ExprId, pool: &ExprPool) -> Option<rug::Rational> {
    let value = match pool.get(expr) {
        ExprData::Rational(r) => r.0.clone(),
        ExprData::Integer(n) => rug::Rational::from(n.0.clone()),
        _ => return None,
    };
    Some(value)
}
fn match_abs_neg(arg: ExprId, x: ExprId, pool: &ExprPool) -> Option<ExprId> {
let absx = abs_forms(x, pool);
if let ExprData::Mul(args) = pool.get(arg) {
let pos = args.iter().position(|&a| absx.contains(&a))?;
let others: Vec<ExprId> = args
.iter()
.enumerate()
.filter(|&(i, _)| i != pos)
.map(|(_, &a)| a)
.collect();
if others.iter().all(|&o| is_free_of(o, x, pool)) {
let c = match others.len() {
0 => pool.integer(1_i32),
1 => others[0],
_ => pool.mul(others),
};
return Some(simp(neg(c, pool), pool));
}
}
None
}
/// The structural spellings of `|x|` accepted by [`match_abs_neg`], in this
/// order: `abs(x)`, `(x²)^{1/2}` and `sqrt(x²)`.
fn abs_forms(x: ExprId, pool: &ExprPool) -> Vec<ExprId> {
    let abs_x = pool.func("abs", vec![x]);
    let x_squared = pool.pow(x, pool.integer(2_i32));
    vec![
        abs_x,
        pool.pow(x_squared, pool.rational(1_i32, 2_i32)),
        pool.func("sqrt", vec![x_squared]),
    ]
}
/// Derivative rule: `F{f⁽ⁿ⁾}(ξ) = (2πiξ)ⁿ · F{f}(ξ)`.
///
/// `f_transform` is the already-computed transform of `f`, expressed in `xi`.
/// The product is simplified and then normalised.
///
/// NOTE(review): `order` is converted with `as i32`, so values above
/// `i32::MAX` would wrap.
pub fn fourier_derivative_rule(
    f_transform: ExprId,
    xi: ExprId,
    order: u32,
    pool: &ExprPool,
) -> ExprId {
    let base = pool.mul(vec![two_pi_i(pool), xi]);
    let multiplier = pool.pow(base, pool.integer(order as i32));
    let product = pool.mul(vec![multiplier, f_transform]);
    normalize(simp(product, pool), pool)
}
// Unit tests live in the out-of-line `tests` submodule, compiled only for test builds.
#[cfg(test)]
mod tests;