/// Cutoff on input spans below which the ADAA divided-difference formulas are treated as
/// ill-conditioned (near 0/0) and a direct midpoint evaluation is used instead.
const EPSILON: f64 = 1e-5;

/// First-order antiderivative anti-aliasing (ADAA) wrapper for a memoryless nonlinearity `f`.
///
/// Each output is the divided difference `(ad1(x[n]) - ad1(x[n-1])) / (x[n] - x[n-1])`,
/// i.e. the mean of `f` over the segment between consecutive inputs. When the two inputs
/// are closer than `EPSILON`, `f` is evaluated at their midpoint instead. Arithmetic is done
/// in `f64`; samples are `f32` at the interface.
#[derive(Debug, Clone)]
pub struct Adaa1 {
    /// The nonlinearity; only evaluated in the ill-conditioned fallback.
    f: fn(f64) -> f64,
    /// An antiderivative of `f`.
    ad1: fn(f64) -> f64,
    /// Previous input sample `x[n-1]`; `0.0` after construction or [`Adaa1::reset`].
    x_prev: f64,
}

impl Adaa1 {
    /// Creates a processor for nonlinearity `f` with first antiderivative `ad1`.
    ///
    /// The input history starts at `0.0`.
    pub fn new(f: fn(f64) -> f64, ad1: fn(f64) -> f64) -> Self {
        Self {
            f,
            ad1,
            x_prev: 0.0,
        }
    }

    /// Processes one sample and returns the anti-aliased output.
    #[inline]
    pub fn process(&mut self, x: f32) -> f32 {
        let x = f64::from(x);
        // Store the new sample and fetch the previous one in a single step.
        let x_prev = std::mem::replace(&mut self.x_prev, x);
        let diff = x - x_prev;
        let y = if diff.abs() < EPSILON {
            // Near-equal inputs: the divided difference is numerically unstable, so use its
            // limit, f evaluated at the midpoint.
            (self.f)((x + x_prev) * 0.5)
        } else {
            ((self.ad1)(x) - (self.ad1)(x_prev)) / diff
        };
        y as f32
    }

    /// Processes `buffer` in place, carrying the input history across calls.
    #[inline]
    pub fn process_block(&mut self, buffer: &mut [f32]) {
        for sample in buffer.iter_mut() {
            *sample = self.process(*sample);
        }
    }

    /// Clears the input history back to `0.0`.
    pub fn reset(&mut self) {
        self.x_prev = 0.0;
    }
}
/// Second-order antiderivative anti-aliasing (ADAA) wrapper for a memoryless nonlinearity `f`.
///
/// Each output is `2 (D1[n] - D1[n-1]) / (x[n] - x[n-2])`, where
/// `D1[n] = (ad2(x[n]) - ad2(x[n-1])) / (x[n] - x[n-1])`. Spans closer than `EPSILON`
/// use limit formulas instead (see [`Adaa2::process`]). `ad1` must be an antiderivative of
/// `f`, and `ad2` an antiderivative of `ad1`.
#[derive(Debug, Clone)]
pub struct Adaa2 {
    /// The nonlinearity; used only when every span is ill-conditioned.
    f: fn(f64) -> f64,
    /// First antiderivative of `f`.
    ad1: fn(f64) -> f64,
    /// Second antiderivative of `f`.
    ad2: fn(f64) -> f64,
    /// Previous input `x[n-1]`.
    x_prev1: f64,
    /// Input before that, `x[n-2]`.
    x_prev2: f64,
    /// Divided difference `D1` computed on the previous call.
    d1_prev: f64,
}

impl Adaa2 {
    /// Creates a processor for `f` with antiderivatives `ad1` and `ad2`.
    ///
    /// The input history starts at `0.0`, and the stored divided difference is seeded with
    /// its limit for that history, `ad1(0.0)`. The first output is therefore correct even
    /// when the antiderivatives carry nonzero integration constants.
    pub fn new(f: fn(f64) -> f64, ad1: fn(f64) -> f64, ad2: fn(f64) -> f64) -> Self {
        Self {
            f,
            ad1,
            ad2,
            x_prev1: 0.0,
            x_prev2: 0.0,
            d1_prev: ad1(0.0),
        }
    }

    /// First-order divided difference of `ad2` between `x_prev` and `x`.
    ///
    /// Falls back to its limit, `ad1` at the midpoint, when the span is below `EPSILON`.
    #[inline]
    fn compute_d1(&self, x: f64, x_prev: f64) -> f64 {
        let diff = x - x_prev;
        if diff.abs() < EPSILON {
            (self.ad1)((x + x_prev) * 0.5)
        } else {
            ((self.ad2)(x) - (self.ad2)(x_prev)) / diff
        }
    }

    /// Limit of the second-order formula as `x[n] -> x[n-2]`.
    ///
    /// With `x̄ = (x[n] + x[n-2]) / 2` and `Δ = x̄ - x[n-1]`, the limit is
    /// `2/Δ · (ad1(x̄) + (ad2(x[n-1]) - ad2(x̄)) / Δ)`. When `Δ` is also below `EPSILON`,
    /// `f` is evaluated at the midpoint of `x̄` and `x[n-1]` instead.
    #[inline]
    fn compute_fallback(&self, x: f64, x_prev1: f64, x_prev2: f64) -> f64 {
        let x_bar = 0.5 * (x + x_prev2);
        let delta = x_bar - x_prev1;
        if delta.abs() < EPSILON {
            (self.f)(0.5 * (x_bar + x_prev1))
        } else {
            2.0 / delta * ((self.ad1)(x_bar) + ((self.ad2)(x_prev1) - (self.ad2)(x_bar)) / delta)
        }
    }

    /// Processes one sample and returns the anti-aliased output.
    ///
    /// When `x[n]` and `x[n-2]` are nearly equal, the exact limit of the second-order
    /// formula is used rather than plain `f`. This avoids a jump in the output when the
    /// input oscillates (e.g. a tone at a quarter of the sample rate) while `x[n-1]`
    /// stays far away.
    #[inline]
    pub fn process(&mut self, x: f32) -> f32 {
        let x = f64::from(x);
        let x_prev1 = self.x_prev1;
        let x_prev2 = self.x_prev2;
        let d1 = self.compute_d1(x, x_prev1);
        let half_span = (x - x_prev2) * 0.5;
        let result = if half_span.abs() < EPSILON {
            self.compute_fallback(x, x_prev1, x_prev2)
        } else {
            (d1 - self.d1_prev) / half_span
        };
        self.x_prev2 = x_prev1;
        self.x_prev1 = x;
        self.d1_prev = d1;
        result as f32
    }

    /// Processes `buffer` in place, carrying the input history across calls.
    #[inline]
    pub fn process_block(&mut self, buffer: &mut [f32]) {
        for sample in buffer.iter_mut() {
            *sample = self.process(*sample);
        }
    }

    /// Clears the input history to `0.0` and re-seeds the stored divided difference.
    pub fn reset(&mut self) {
        self.x_prev1 = 0.0;
        self.x_prev2 = 0.0;
        self.d1_prev = (self.ad1)(0.0);
    }
}
/// The `tanh` saturator itself.
fn tanh_f(x: f64) -> f64 {
    f64::tanh(x)
}
/// First antiderivative of `tanh`: `ln(cosh x)`, evaluated without forming `cosh`.
///
/// Uses `ln cosh x = |x| + ln(1 + e^(-2|x|)) - ln 2`. The exponent is never positive, so
/// large `|x|` cannot overflow.
fn tanh_ad1(x: f64) -> f64 {
    use std::f64::consts::LN_2;
    let magnitude = x.abs();
    let decay = f64::exp(-2.0 * magnitude);
    magnitude + f64::ln_1p(decay) - LN_2
}
/// Second antiderivative of `tanh`, i.e. an antiderivative of `ln cosh x`.
///
/// Closed form: `x²/2 - x·ln 2 + ½·Li₂(-e^(-2x))`. This follows from
/// `ln cosh x = x + ln(1 + e^(-2x)) - ln 2` together with
/// `∫ ln(1 + e^(-2x)) dx = ½·Li₂(-e^(-2x))`.
fn tanh_ad2(x: f64) -> f64 {
    let z = (-2.0 * x).exp();
    let li2 = dilog_neg(z);
    0.5 * (x * x + li2) - std::f64::consts::LN_2 * x
}

/// Dilogarithm at a negative argument: returns `Li₂(-z)` for `z >= 0`.
///
/// Each branch keeps the power-series argument at magnitude <= 0.5, so about 50 terms
/// reach full `f64` precision:
/// * `z <= 0.5`: direct series.
/// * `0.5 < z <= 1`: Landen's identity `Li₂(-z) = -Li₂(z/(1+z)) - ½·ln²(1+z)`. Summing the
///   direct series here (around `x ≈ 0` in [`tanh_ad2`]) converges too slowly. Truncation
///   errors in its derivatives then dominate the ADAA divided differences.
/// * `z > 1`: inversion `Li₂(-z) = -Li₂(-1/z) - π²/6 - ½·ln² z`.
///
/// NaN propagates, `+∞` yields `-∞`, and `z < 1e-15` returns `0.0`; the true value there
/// is about `-z`.
fn dilog_neg(z: f64) -> f64 {
    if z.is_nan() {
        // Guard explicitly: NaN fails every comparison below, and the previous
        // structure recursed on it forever.
        return z;
    }
    if z < 1e-15 {
        return 0.0;
    }
    if z > 1.0 {
        let ln_z = z.ln();
        // 1/z lies in [0, 1), so this recursion is at most one level deep.
        return -dilog_neg(1.0 / z)
            - std::f64::consts::PI * std::f64::consts::PI / 6.0
            - 0.5 * ln_z * ln_z;
    }
    if z <= 0.5 {
        li2_series(-z)
    } else {
        let ln_1p_z = z.ln_1p();
        -li2_series(z / (1.0 + z)) - 0.5 * ln_1p_z * ln_1p_z
    }
}

/// Power series `Σ_{k>=1} w^k / k²` for `Li₂(w)`.
///
/// Intended for `|w| <= 0.5`, where it converges geometrically. Iteration stops once a term
/// drops below `1e-17`, or after 200 terms.
fn li2_series(w: f64) -> f64 {
    let mut sum = 0.0;
    let mut w_pow = 1.0;
    for k in 1..=200_u32 {
        w_pow *= w;
        let term = w_pow / f64::from(k * k);
        sum += term;
        if term.abs() < 1e-17 {
            break;
        }
    }
    sum
}
/// Rational soft clipper `x / (1 + |x|)`, which approaches ±1 asymptotically.
fn softclip_f(x: f64) -> f64 {
    let denominator = 1.0 + x.abs();
    x / denominator
}
/// First antiderivative of the soft clipper: `|x| - ln(1 + |x|)`, with value 0 at `x = 0`.
///
/// Near zero the two terms cancel and the result behaves like `x²/2`. `ln_1p` keeps full
/// relative precision there, whereas `(1.0 + |x|).ln()` discards the low bits of `|x|`
/// when forming `1 + |x|`.
fn softclip_ad1(x: f64) -> f64 {
    let abs_x = x.abs();
    abs_x - abs_x.ln_1p()
}
/// Second antiderivative of the soft clipper, odd in `x` and zero at `x = 0`.
///
/// For `x >= 0` it is `x²/2 - (1 + x)·ln(1 + x) + x`; negative inputs are mirrored. Near
/// zero the terms cancel down to about `x³/6`, so `ln_1p` is used to avoid losing the
/// low bits of `x` inside `1 + x`.
fn softclip_ad2(x: f64) -> f64 {
    let abs_x = x.abs();
    let one_plus = 1.0 + abs_x;
    let magnitude = 0.5 * abs_x * abs_x - one_plus * abs_x.ln_1p() + abs_x;
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
/// Hard clipper: limits `x` to the closed interval `[-1, 1]` (NaN passes through).
fn hardclip_f(x: f64) -> f64 {
    const CEILING: f64 = 1.0;
    f64::clamp(x, -CEILING, CEILING)
}
/// First antiderivative of the hard clipper, even in `x` and zero at the origin.
///
/// Parabolic `x²/2` inside `[-1, 1]`, continuing linearly as `|x| - 1/2` outside.
fn hardclip_ad1(x: f64) -> f64 {
    match x {
        v if v > 1.0 => v - 0.5,
        v if v < -1.0 => -v - 0.5,
        v => v * v * 0.5,
    }
}
/// Second antiderivative of the hard clipper, odd in `x` and zero at the origin.
///
/// Cubic `x³/6` inside `[-1, 1]`. Outside, the quadratic `m²/2 - m/2 + 1/6` in `m = |x|`
/// carries the sign of `x`, which keeps the function continuous at `±1`.
fn hardclip_ad2(x: f64) -> f64 {
    let outer = |m: f64| m * m * 0.5 - 0.5 * m + 1.0 / 6.0;
    match x {
        v if v > 1.0 => outer(v),
        v if v < -1.0 => -outer(v.abs()),
        v => v * v * v / 6.0,
    }
}
/// First-order ADAA `tanh` saturator with zeroed input history.
#[must_use]
pub fn adaa1_tanh() -> Adaa1 {
    Adaa1::new(tanh_f, tanh_ad1)
}

/// First-order ADAA soft clipper `x / (1 + |x|)` with zeroed input history.
#[must_use]
pub fn adaa1_softclip() -> Adaa1 {
    Adaa1::new(softclip_f, softclip_ad1)
}

/// First-order ADAA hard clipper (clamp to `[-1, 1]`) with zeroed input history.
#[must_use]
pub fn adaa1_hardclip() -> Adaa1 {
    Adaa1::new(hardclip_f, hardclip_ad1)
}
/// Second-order ADAA `tanh` saturator with zeroed input history.
#[must_use]
pub fn adaa2_tanh() -> Adaa2 {
    Adaa2::new(tanh_f, tanh_ad1, tanh_ad2)
}

/// Second-order ADAA soft clipper `x / (1 + |x|)` with zeroed input history.
#[must_use]
pub fn adaa2_softclip() -> Adaa2 {
    Adaa2::new(softclip_f, softclip_ad1, softclip_ad2)
}

/// Second-order ADAA hard clipper (clamp to `[-1, 1]`) with zeroed input history.
#[must_use]
pub fn adaa2_hardclip() -> Adaa2 {
    Adaa2::new(hardclip_f, hardclip_ad1, hardclip_ad2)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `tanh_ad1` must agree with the textbook `ln(cosh x)` where the latter is finite.
    #[test]
    fn test_tanh_ad1_identity() {
        let probes = [0.0_f64, 0.5, 1.0, 2.0, -1.0, -3.0];
        probes.iter().copied().for_each(|x| {
            let expected = x.cosh().ln();
            let actual = tanh_ad1(x);
            assert!(
                (actual - expected).abs() < 1e-10,
                "tanh_ad1({x}): expected {expected}, got {actual}"
            );
        });
    }

    /// A central difference of `softclip_ad1` should reproduce `softclip_f`.
    #[test]
    fn test_softclip_ad1_identity() {
        const H: f64 = 1e-7;
        for x in [0.1, 0.5, 1.0, 2.0, -0.5, -2.0] {
            let numerical_derivative = (softclip_ad1(x + H) - softclip_ad1(x - H)) / (2.0 * H);
            let actual = softclip_f(x);
            assert!(
                (numerical_derivative - actual).abs() < 1e-4,
                "softclip AD1 derivative mismatch at x={x}: d/dx AD1={numerical_derivative}, f(x)={actual}"
            );
        }
    }

    /// A ramp through the tanh ADAA stays within the saturator's range (plus slack).
    #[test]
    fn test_adaa1_tanh_basic() {
        let mut adaa = adaa1_tanh();
        let outputs: Vec<f32> = (0..100)
            .map(|i| adaa.process((i as f32 - 50.0) / 25.0))
            .collect();
        for y in outputs {
            assert!(y.abs() <= 1.01, "Output out of bounds: {y}");
        }
    }

    /// Smoke test: both the naive and the ADAA paths produce non-silent output for a
    /// heavily driven high-frequency sine.
    #[test]
    fn test_adaa1_reduces_aliasing() {
        let sr = 48000.0;
        let freq = 15000.0;
        let drive = 5.0;
        let n = 4096;
        let driven = |i: i32| drive * (2.0 * std::f64::consts::PI * freq * (i as f64 / sr)).sin();

        let naive_output: Vec<f32> = (0..n).map(|i| driven(i).tanh() as f32).collect();
        let mut adaa = adaa1_tanh();
        let adaa_output: Vec<f32> = (0..n).map(|i| adaa.process(driven(i) as f32)).collect();

        let energy = |signal: &[f32]| -> f32 { signal.iter().map(|s| s * s).sum() };
        let naive_energy = energy(&naive_output);
        let adaa_energy = energy(&adaa_output);
        assert!(adaa_energy > 0.0, "ADAA produced silence");
        assert!(naive_energy > 0.0, "Naive produced silence");
    }

    /// After `reset`, a zero input must not be influenced by earlier history.
    #[test]
    fn test_adaa1_reset() {
        let mut adaa = adaa1_tanh();
        let _ = adaa.process(1.0);
        adaa.reset();
        let out = adaa.process(0.0);
        assert!(out.abs() < 0.01);
    }

    /// The second-order tanh ADAA stays bounded over a wide ramp.
    #[test]
    fn test_adaa2_tanh_bounded() {
        let mut adaa = adaa2_tanh();
        (0..200)
            .map(|i| (i as f32 - 100.0) / 30.0)
            .for_each(|x| {
                let y = adaa.process(x);
                assert!(y.abs() < 2.0, "ADAA2 output unbounded: {y} at x={x}");
            });
    }

    /// After a short warm-up, the hard-clip ADAA output stays close to the clip range.
    #[test]
    fn test_hardclip_adaa1() {
        let mut adaa = adaa1_hardclip();
        let outputs: Vec<f32> = (0..100)
            .map(|i| adaa.process((i as f32 - 50.0) / 25.0))
            .collect();
        for (i, y) in outputs.into_iter().enumerate().skip(5) {
            assert!(
                y.abs() <= 1.5,
                "Hard clip ADAA output too large at i={i}: {y}"
            );
        }
    }

    /// Repeated identical inputs exercise the ill-conditioned fallback and must stay finite.
    #[test]
    fn test_consecutive_identical_samples() {
        let mut adaa = adaa1_tanh();
        (0..100).for_each(|_| {
            let y = adaa.process(0.5);
            assert!(y.is_finite(), "Non-finite output: {y}");
        });
    }

    /// `dilog_neg(0) = 0` and `dilog_neg(1) = Li₂(-1) = -π²/12`.
    #[test]
    fn test_dilog_neg_basic() {
        use std::f64::consts::PI;
        assert!(dilog_neg(0.0).abs() < 1e-15);
        let expected = -PI * PI / 12.0;
        let actual = dilog_neg(1.0);
        assert!(
            (actual - expected).abs() < 1e-4,
            "Li_2(-1): expected {expected}, got {actual}"
        );
    }
}