/// Sample-distance threshold below which two inputs are treated as equal and the
/// divided-difference formulas switch to their ill-conditioning fallbacks.
const EPSILON: f64 = 1.0e-5;
20
21pub struct Adaa1 {
30 f: fn(f64) -> f64,
32 ad1: fn(f64) -> f64,
34 x_prev: f64,
36}
37
38impl Adaa1 {
39 pub fn new(f: fn(f64) -> f64, ad1: fn(f64) -> f64) -> Self {
41 Self {
42 f,
43 ad1,
44 x_prev: 0.0,
45 }
46 }
47
48 #[inline]
50 pub fn process(&mut self, x: f32) -> f32 {
51 let x = x as f64;
52 let x_prev = self.x_prev;
53 self.x_prev = x;
54
55 let diff = x - x_prev;
56 if diff.abs() < EPSILON {
57 (self.f)((x + x_prev) * 0.5) as f32
59 } else {
60 (((self.ad1)(x) - (self.ad1)(x_prev)) / diff) as f32
62 }
63 }
64
65 #[inline]
67 pub fn process_block(&mut self, buffer: &mut [f32]) {
68 for sample in buffer.iter_mut() {
69 *sample = self.process(*sample);
70 }
71 }
72
73 pub fn reset(&mut self) {
75 self.x_prev = 0.0;
76 }
77}
78
79pub struct Adaa2 {
88 f: fn(f64) -> f64,
90 ad1: fn(f64) -> f64,
92 ad2: fn(f64) -> f64,
94 x_prev1: f64,
96 x_prev2: f64,
97 d1_prev: f64,
99}
100
101impl Adaa2 {
102 pub fn new(f: fn(f64) -> f64, ad1: fn(f64) -> f64, ad2: fn(f64) -> f64) -> Self {
103 Self {
104 f,
105 ad1,
106 ad2,
107 x_prev1: 0.0,
108 x_prev2: 0.0,
109 d1_prev: 0.0,
110 }
111 }
112
113 #[inline]
114 fn compute_d1(&self, x: f64, x_prev: f64) -> f64 {
115 let diff = x - x_prev;
116 if diff.abs() < EPSILON {
117 (self.ad1)((x + x_prev) * 0.5)
118 } else {
119 ((self.ad2)(x) - (self.ad2)(x_prev)) / diff
120 }
121 }
122
123 #[inline]
125 pub fn process(&mut self, x: f32) -> f32 {
126 let x = x as f64;
127 let x_prev1 = self.x_prev1;
128
129 let d1 = self.compute_d1(x, x_prev1);
130
131 let diff = (x - self.x_prev2) * 0.5;
132 let result = if diff.abs() < EPSILON {
133 (self.f)((x + x_prev1 + self.x_prev2) / 3.0)
134 } else {
135 (d1 - self.d1_prev) / diff
136 };
137
138 self.x_prev2 = self.x_prev1;
139 self.x_prev1 = x;
140 self.d1_prev = d1;
141
142 result as f32
143 }
144
145 pub fn process_block(&mut self, buffer: &mut [f32]) {
146 for sample in buffer.iter_mut() {
147 *sample = self.process(*sample);
148 }
149 }
150
151 pub fn reset(&mut self) {
152 self.x_prev1 = 0.0;
153 self.x_prev2 = 0.0;
154 self.d1_prev = 0.0;
155 }
156}
157
/// The hyperbolic-tangent nonlinearity.
fn tanh_f(x: f64) -> f64 {
    f64::tanh(x)
}
/// First antiderivative of `tanh`: `ln(cosh(x))`, with `F1(0) = 0`.
///
/// Evaluated as `|x| + ln(1 + e^{-2|x|}) - ln 2`. This is algebraically identical but
/// never forms `cosh(x)`, so it cannot overflow for large `|x|`.
fn tanh_ad1(x: f64) -> f64 {
    let magnitude = x.abs();
    let correction = (-2.0 * magnitude).exp().ln_1p();
    magnitude + correction - std::f64::consts::LN_2
}
172fn tanh_ad2(x: f64) -> f64 {
173 let z = (-2.0 * x).exp();
177 let li2 = dilog_neg(z);
178 0.5 * (x * x + li2) - std::f64::consts::LN_2 * x
179}
180
/// Dilogarithm at a non-positive argument: returns `Li₂(-z)` for `z ≥ 0`.
///
/// For `0 ≤ z ≤ 1`, Landen's identity `Li₂(-z) = -Li₂(w) - ½·ln²(1 + z)` with
/// `w = z / (1 + z) ∈ [0, ½]` reduces the problem to the power series `Σ wᵏ/k²`.
/// That series has only positive terms and converges at least as fast as `2⁻ᵏ`, so
/// full double precision takes about 45 terms.
///
/// The plain alternating series `Σ (-z)ᵏ/k²` converges very slowly as `z → 1`. When
/// capped at a fixed term count, its truncation error is small but strongly curved.
/// ADAA2's second divided difference amplifies that curvature into large output
/// errors for `tanh` inputs near zero.
///
/// For `z > 1` the inversion formula `Li₂(-z) = -Li₂(-1/z) - π²/6 - ½·ln²(z)`
/// reduces to the case above. A NaN input yields NaN: it never satisfies `z > 1`, so it
/// cannot recurse. `z = +∞` yields `-∞`. Negative `z` is outside the intended domain.
fn dilog_neg(z: f64) -> f64 {
    use std::f64::consts::PI;

    // Hard cap on series terms; w ≤ 1/2 reaches f64 precision well before this.
    const MAX_TERMS: u32 = 64;

    if z > 1.0 {
        // Inversion: 1/z ∈ (0, 1) lands in the Landen branch below.
        let ln_z = z.ln();
        return -dilog_neg(1.0 / z) - PI * PI / 6.0 - 0.5 * ln_z * ln_z;
    }

    // Landen transform: w ∈ [0, 1/2] for z ∈ [0, 1].
    let w = z / (1.0 + z);
    let mut li2_w = 0.0;
    let mut w_pow = 1.0;
    for k in 1..=MAX_TERMS {
        w_pow *= w;
        let k = f64::from(k);
        let term = w_pow / (k * k);
        li2_w += term;
        if term.abs() <= f64::EPSILON * li2_w.abs() {
            break;
        }
    }

    let ln_1p_z = z.ln_1p();
    -li2_w - 0.5 * ln_1p_z * ln_1p_z
}
213
/// Rational soft clipper `x / (1 + |x|)`: odd, monotonic, bounded in `(-1, 1)`.
fn softclip_f(x: f64) -> f64 {
    let denominator = x.abs() + 1.0;
    x / denominator
}
/// First antiderivative of the soft clipper: `|x| - ln(1 + |x|)`, even, with `F1(0) = 0`.
fn softclip_ad1(x: f64) -> f64 {
    let a = x.abs();
    let log_term = (1.0 + a).ln();
    a - log_term
}
/// Second antiderivative of the soft clipper, normalised so that `F2(0) = 0`.
///
/// For `x ≥ 0` it is `x²/2 - (1 + x)·ln(1 + x) + x`. The first antiderivative is even,
/// so this one is odd: negative inputs reuse the magnitude with the sign flipped.
fn softclip_ad2(x: f64) -> f64 {
    let a = x.abs();
    let one_plus = 1.0 + a;
    let half_square = 0.5 * a * a;
    let magnitude = half_square - one_plus * one_plus.ln() + a;
    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    sign * magnitude
}
236
/// Hard clipper: limits `x` to the closed interval `[-1, 1]` (NaN passes through).
fn hardclip_f(x: f64) -> f64 {
    f64::clamp(x, -1.0, 1.0)
}
/// First antiderivative of the hard clipper, with `F1(0) = 0`.
///
/// It is `x²/2` inside `[-1, 1]` and the linear continuation `|x| - 1/2` outside,
/// matching value and slope at `±1`.
fn hardclip_ad1(x: f64) -> f64 {
    if x > 1.0 {
        x - 0.5
    } else if x < -1.0 {
        -x - 0.5
    } else {
        x * x * 0.5
    }
}
/// Second antiderivative of the hard clipper, odd with `F2(0) = 0`.
///
/// It is `x³/6` inside `[-1, 1]`. Outside, it is the antiderivative of the linear
/// tail `|x| - 1/2`, matched to `±1/6` at `±1`.
fn hardclip_ad2(x: f64) -> f64 {
    // Antiderivative of t - 1/2 for t > 1, equal to 1/6 at t = 1.
    let tail = |t: f64| t * t * 0.5 - 0.5 * t + 1.0 / 6.0;
    if x > 1.0 {
        tail(x)
    } else if x < -1.0 {
        -tail(x.abs())
    } else {
        x * x * x / 6.0
    }
}
270
271pub fn adaa1_tanh() -> Adaa1 {
273 Adaa1::new(tanh_f, tanh_ad1)
274}
275
276pub fn adaa1_softclip() -> Adaa1 {
278 Adaa1::new(softclip_f, softclip_ad1)
279}
280
281pub fn adaa1_hardclip() -> Adaa1 {
283 Adaa1::new(hardclip_f, hardclip_ad1)
284}
285
286pub fn adaa2_tanh() -> Adaa2 {
288 Adaa2::new(tanh_f, tanh_ad1, tanh_ad2)
289}
290
291pub fn adaa2_softclip() -> Adaa2 {
293 Adaa2::new(softclip_f, softclip_ad1, softclip_ad2)
294}
295
296pub fn adaa2_hardclip() -> Adaa2 {
298 Adaa2::new(hardclip_f, hardclip_ad1, hardclip_ad2)
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Central-difference estimate of `g'(x)` with step `h`.
    fn central_difference(g: fn(f64) -> f64, x: f64, h: f64) -> f64 {
        (g(x + h) - g(x - h)) / (2.0 * h)
    }

    #[test]
    fn test_tanh_ad1_identity() {
        for &x in &[0.0_f64, 0.5, 1.0, 2.0, -1.0, -3.0] {
            // The closed form must agree with the textbook antiderivative ln(cosh x).
            let expected = x.cosh().ln();
            let actual = tanh_ad1(x);
            assert!(
                (actual - expected).abs() < 1e-10,
                "tanh_ad1({x}): expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn test_softclip_ad1_identity() {
        // Differentiating AD1 numerically must give back the nonlinearity itself.
        for &x in &[0.1, 0.5, 1.0, 2.0, -0.5, -2.0] {
            let numerical_derivative = central_difference(softclip_ad1, x, 1e-7);
            let actual = softclip_f(x);
            assert!(
                (numerical_derivative - actual).abs() < 1e-4,
                "softclip AD1 derivative mismatch at x={x}: d/dx AD1={numerical_derivative}, f(x)={actual}"
            );
        }
    }

    #[test]
    fn test_ad2_derivatives_match_ad1() {
        // Each AD2 must be an antiderivative of its AD1. Sample points stay away from
        // the hard clipper's kinks at ±1 so the central difference is accurate.
        let pairs: [(&str, fn(f64) -> f64, fn(f64) -> f64); 3] = [
            ("tanh", tanh_ad1, tanh_ad2),
            ("softclip", softclip_ad1, softclip_ad2),
            ("hardclip", hardclip_ad1, hardclip_ad2),
        ];
        for &(name, ad1, ad2) in pairs.iter() {
            for &x in &[-2.5_f64, -0.5, 0.3, 1.5, 2.5] {
                let numeric = central_difference(ad2, x, 1e-5);
                let expected = ad1(x);
                assert!(
                    (numeric - expected).abs() < 1e-6,
                    "{name} AD2' mismatch at x={x}: d/dx AD2={numeric}, AD1={expected}"
                );
            }
        }
    }

    #[test]
    fn test_adaa1_tanh_basic() {
        let mut adaa = adaa1_tanh();
        let mut outputs = Vec::new();
        for i in 0..100 {
            let x = (i as f32 - 50.0) / 25.0;
            outputs.push(adaa.process(x));
        }
        for &y in &outputs {
            assert!(y.abs() <= 1.01, "Output out of bounds: {y}");
        }
    }

    #[test]
    fn test_adaa1_reduces_aliasing() {
        // Smoke test with a heavily driven high-frequency sine. It does not measure
        // aliasing spectrally. It checks that both paths produce signal and that the
        // ADAA output stays within tanh's range: a divided difference of ln(cosh) is an
        // average of tanh, so it can never exceed 1 in magnitude.
        let sr = 48000.0;
        let freq = 15000.0;
        let drive = 5.0;
        let n = 4096;

        let naive_output: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f64 / sr;
                (drive * (2.0 * std::f64::consts::PI * freq * t).sin()).tanh() as f32
            })
            .collect();

        let mut adaa = adaa1_tanh();
        let adaa_output: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f64 / sr;
                let x = (drive * (2.0 * std::f64::consts::PI * freq * t).sin()) as f32;
                adaa.process(x)
            })
            .collect();

        let naive_energy: f32 = naive_output.iter().map(|x| x * x).sum();
        let adaa_energy: f32 = adaa_output.iter().map(|x| x * x).sum();

        assert!(adaa_energy > 0.0, "ADAA produced silence");
        assert!(naive_energy > 0.0, "Naive produced silence");
        assert!(
            adaa_output.iter().all(|y| y.is_finite() && y.abs() <= 1.0),
            "ADAA output escaped tanh's range"
        );
    }

    #[test]
    fn test_adaa1_reset() {
        let mut adaa = adaa1_tanh();
        adaa.process(1.0);
        adaa.reset();
        let out = adaa.process(0.0);
        assert!(out.abs() < 0.01);
    }

    #[test]
    fn test_adaa2_tanh_bounded() {
        let mut adaa = adaa2_tanh();
        for i in 0..200 {
            let x = (i as f32 - 100.0) / 30.0;
            let y = adaa.process(x);
            assert!(y.abs() < 2.0, "ADAA2 output unbounded: {y} at x={x}");
        }
    }

    #[test]
    fn test_hardclip_adaa1() {
        let mut adaa = adaa1_hardclip();
        let mut outputs = Vec::new();
        for i in 0..100 {
            let x = (i as f32 - 50.0) / 25.0;
            outputs.push(adaa.process(x));
        }
        // Skip the first few samples, which still carry the zero initial state.
        for (i, &y) in outputs.iter().enumerate().skip(5) {
            assert!(
                y.abs() <= 1.5,
                "Hard clip ADAA output too large at i={i}: {y}"
            );
        }
    }

    #[test]
    fn test_consecutive_identical_samples() {
        let mut adaa = adaa1_tanh();
        for _ in 0..100 {
            // Repeated input exercises the ill-conditioning fallback.
            let y = adaa.process(0.5);
            assert!(y.is_finite(), "Non-finite output: {y}");
        }
    }

    #[test]
    fn test_adaa2_fallback_uses_three_sample_centroid() {
        fn f(x: f64) -> f64 {
            x
        }
        fn ad1(x: f64) -> f64 {
            0.5 * x * x
        }
        fn ad2(x: f64) -> f64 {
            x * x * x / 6.0
        }

        let mut adaa = Adaa2::new(f, ad1, ad2);
        adaa.x_prev1 = 2.0;
        adaa.x_prev2 = 0.0;

        let y = adaa.process(0.0);
        assert!(
            (y - (2.0_f32 / 3.0)).abs() < 1e-6,
            "ADAA2 fallback must evaluate the three-sample centroid, got {y}"
        );
    }

    #[test]
    fn test_dilog_neg_basic() {
        assert!((dilog_neg(0.0)).abs() < 1e-15);
        // Li2(-1) = -pi^2 / 12.
        let expected = -std::f64::consts::PI * std::f64::consts::PI / 12.0;
        let actual = dilog_neg(1.0);
        assert!(
            (actual - expected).abs() < 1e-4,
            "Li_2(-1): expected {expected}, got {actual}"
        );

        let near_one = dilog_neg(1.0 - 1e-13);
        assert!(
            (near_one - expected).abs() < 1e-12,
            "Li_2 near -1 should use the stable reference value"
        );
    }
}