#![allow(dead_code)]
use core::f32::consts::PI;
use cfg_if::cfg_if;
use crate::celt::entcode::ec_ilog;
use crate::celt::float_cast;
use crate::celt::types::OpusInt32;
#[cfg(not(miri))]
use libm::sqrt;
use libm::{cosf, exp, fmaf, log};
#[cfg(test)]
mod fast_atan2_trace {
extern crate std;
use std::env;
use std::sync::OnceLock;
#[derive(Clone, Copy)]
struct TraceConfig {
enabled: bool,
has_target: bool,
target_x: f32,
target_y: f32,
eps: f32,
}
fn env_truthy(key: &str) -> bool {
env::var(key).map_or(false, |value| !value.is_empty() && value != "0")
}
fn config() -> &'static TraceConfig {
static TRACE_CONFIG: OnceLock<TraceConfig> = OnceLock::new();
TRACE_CONFIG.get_or_init(|| {
let enabled = env_truthy("ANALYSIS_TRACE_FAST_ATAN2");
if !enabled {
return TraceConfig {
enabled: false,
has_target: false,
target_x: 0.0,
target_y: 0.0,
eps: 0.0,
};
}
let target_x = env::var("ANALYSIS_TRACE_FAST_ATAN2_X")
.ok()
.and_then(|value| value.parse::<f32>().ok());
let target_y = env::var("ANALYSIS_TRACE_FAST_ATAN2_Y")
.ok()
.and_then(|value| value.parse::<f32>().ok());
let has_target = target_x.is_some() && target_y.is_some();
let eps = env::var("ANALYSIS_TRACE_FAST_ATAN2_EPS")
.ok()
.and_then(|value| value.parse::<f32>().ok())
.unwrap_or(if has_target { 1e-9 } else { 0.0 });
TraceConfig {
enabled,
has_target,
target_x: target_x.unwrap_or(0.0),
target_y: target_y.unwrap_or(0.0),
eps,
}
})
}
pub(crate) fn maybe_dump(
y: f32,
x: f32,
x2: f32,
y2: f32,
branch_x2_lt_y2: bool,
cb: f32,
cc: f32,
t1: f32,
t2: f32,
den: f32,
xy: f32,
num_term: f32,
num: f32,
) {
let cfg = config();
if !cfg.enabled {
return;
}
if cfg.has_target {
let dx = (x - cfg.target_x).abs();
let dy = (y - cfg.target_y).abs();
if dx > cfg.eps || dy > cfg.eps {
return;
}
}
crate::test_trace::trace_println!("analysis_fast_atan2.x2={:.9e}", x2 as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.x2_bits=0x{:08x}", x2.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.y2={:.9e}", y2 as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.y2_bits=0x{:08x}", y2.to_bits());
if branch_x2_lt_y2 {
crate::test_trace::trace_println!("analysis_fast_atan2.branch=x2<y2");
} else {
crate::test_trace::trace_println!("analysis_fast_atan2.branch=x2>=y2");
}
crate::test_trace::trace_println!("analysis_fast_atan2.cb={:.9e}", cb as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.cb_bits=0x{:08x}", cb.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.cc={:.9e}", cc as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.cc_bits=0x{:08x}", cc.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.t1={:.9e}", t1 as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.t1_bits=0x{:08x}", t1.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.t2={:.9e}", t2 as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.t2_bits=0x{:08x}", t2.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.den={:.9e}", den as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.den_bits=0x{:08x}", den.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.xy={:.9e}", xy as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.xy_bits=0x{:08x}", xy.to_bits());
crate::test_trace::trace_println!("analysis_fast_atan2.num_term={:.9e}", num_term as f64);
crate::test_trace::trace_println!(
"analysis_fast_atan2.num_term_bits=0x{:08x}",
num_term.to_bits()
);
crate::test_trace::trace_println!("analysis_fast_atan2.num={:.9e}", num as f64);
crate::test_trace::trace_println!("analysis_fast_atan2.num_bits=0x{:08x}", num.to_bits());
}
}
pub(crate) fn isqrt32(mut value: u32) -> u32 {
if value == 0 {
return 0;
}
let mut root = 0u32;
let mut bit_shift = ((32 - value.leading_zeros()) as i32 - 1) >> 1;
let mut bit = 1u32 << (bit_shift as u32);
while bit_shift >= 0 {
let trial = ((root << 1) + bit) << (bit_shift as u32);
if trial <= value {
root += bit;
value -= trial;
}
bit >>= 1;
bit_shift -= 1;
}
root
}
#[must_use]
pub(crate) fn celt_ilog2(value: OpusInt32) -> OpusInt32 {
assert!(value > 0, "celt_ilog2 expects a strictly positive input");
ec_ilog(value as u32) - 1
}
#[must_use]
pub(crate) fn celt_zlog2(value: OpusInt32) -> OpusInt32 {
if value <= 0 { 0 } else { celt_ilog2(value) }
}
#[allow(clippy::many_single_char_names)]
pub(crate) fn fast_atan2f(y: f32, x: f32) -> f32 {
const CA: f32 = 0.431_579_74_f32;
const CB: f32 = f32::from_bits(0x3f2d_b121);
const CC: f32 = 0.085_955_42_f32;
const CE: f32 = PI / 2.0;
#[inline]
fn mul_add_c_order(a: f32, b: f32, c: f32) -> f32 {
mul_add_f32(a, b, c)
}
let x2 = x * x;
let y2 = y * y;
if x2 + y2 < 1e-18 {
return 0.0;
}
if x2 < y2 {
let t1 = mul_add_c_order(CB, x2, y2);
let t2 = mul_add_c_order(CC, x2, y2);
let den = t1 * t2;
let xy = x * y;
let num_term = mul_add_c_order(CA, x2, y2);
let num = -xy * num_term;
let result = num / den + if y < 0.0 { -CE } else { CE };
#[cfg(test)]
fast_atan2_trace::maybe_dump(y, x, x2, y2, true, CB, CC, t1, t2, den, xy, num_term, num);
result
} else {
let t1 = mul_add_c_order(CB, y2, x2);
let t2 = mul_add_c_order(CC, y2, x2);
let den = t1 * t2;
let xy = x * y;
let num_term = mul_add_c_order(CA, y2, x2);
let num = xy * num_term;
let result =
num / den + if y < 0.0 { -CE } else { CE } - if x * y < 0.0 { -CE } else { CE };
#[cfg(test)]
fast_atan2_trace::maybe_dump(y, x, x2, y2, false, CB, CC, t1, t2, den, xy, num_term, num);
result
}
}
#[inline]
pub(crate) fn mul_add_f32(a: f32, b: f32, c: f32) -> f32 {
fmaf(a, b, c)
}
#[cfg(test)]
mod fast_atan2f_regression {
use super::fast_atan2f;
#[test]
fn fast_atan2f_bin28_parity() {
let x = f32::from_bits(0x3b2f_ef98);
let y = f32::from_bits(0x3b6e_b557);
let atan = fast_atan2f(y, x);
assert_eq!(atan.to_bits(), 0x3f6f_86d4);
}
}
#[inline]
pub(crate) fn celt_log2(x: f32) -> f32 {
(core::f64::consts::LOG2_E * log(x as f64)) as f32
}
#[inline]
pub(crate) fn celt_exp2(x: f32) -> f32 {
exp(core::f64::consts::LN_2 * (x as f64)) as f32
}
#[inline]
pub(crate) fn celt_div(a: f32, b: f32) -> f32 {
a / b
}
#[inline]
pub(crate) fn celt_cos_norm(x: f32) -> f32 {
cosf(0.5 * PI * x)
}
#[inline]
pub(crate) fn celt_sqrt(x: f32) -> f32 {
#[cfg(miri)]
{
return sqrtf_fallback(x);
}
#[cfg(not(miri))]
{
sqrt(x as f64) as f32
}
}
#[cfg(miri)]
#[inline]
fn sqrtf_fallback(x: f32) -> f32 {
if !(x > 0.0) {
if x == 0.0 {
return 0.0;
}
if x.is_sign_negative() {
return f32::NAN;
}
return x;
}
let mut guess = f32::from_bits((x.to_bits() >> 1) + 0x1fc0_0000);
for _ in 0..4 {
guess = 0.5 * (guess + x / guess);
}
guess
}
#[inline]
pub(crate) fn celt_rsqrt(x: f32) -> f32 {
1.0 / celt_sqrt(x)
}
#[inline]
pub(crate) fn celt_rsqrt_norm(x: f32) -> f32 {
celt_rsqrt(x)
}
#[inline]
pub(crate) fn celt_rcp(x: f32) -> f32 {
1.0 / x
}
#[inline]
pub(crate) fn frac_div32(a: f32, b: f32) -> f32 {
a / b
}
#[inline]
pub(crate) fn frac_div32_q29(a: f32, b: f32) -> f32 {
frac_div32(a, b)
}
pub(crate) fn celt_maxabs16(samples: &[f32]) -> f32 {
let mut max_abs = 0.0f32;
for &value in samples {
let abs = value.abs();
if abs > max_abs {
max_abs = abs;
}
}
max_abs
}
#[inline]
pub(crate) fn celt_maxabs32(samples: &[f32]) -> f32 {
celt_maxabs16(samples)
}
#[inline]
fn opus_limit2_checkwithin1_scalar(samples: &mut [f32]) -> bool {
if samples.is_empty() {
return true;
}
for sample in samples {
*sample = sample.clamp(-2.0, 2.0);
}
false
}
#[inline]
fn celt_float2int16_scalar(input: &[f32], output: &mut [i16]) {
for (dst, &sample) in output.iter_mut().zip(input.iter()) {
*dst = float_cast::float2int16(sample);
}
}
cfg_if! {
if #[cfg(all(target_arch = "aarch64", not(feature = "force-scalar")))] {
mod aarch64_neon;
pub(crate) fn opus_limit2_checkwithin1(samples: &mut [f32]) -> bool {
aarch64_neon::opus_limit2_checkwithin1(samples)
}
pub(crate) fn celt_float2int16(input: &[f32], output: &mut [i16]) {
assert_eq!(
input.len(),
output.len(),
"input and output slices must have the same length"
);
aarch64_neon::celt_float2int16(input, output);
}
} else if #[cfg(all(target_arch = "arm", target_feature = "neon", not(feature = "force-scalar")))] {
mod arm_neon;
pub(crate) fn opus_limit2_checkwithin1(samples: &mut [f32]) -> bool {
arm_neon::opus_limit2_checkwithin1(samples)
}
pub(crate) fn celt_float2int16(input: &[f32], output: &mut [i16]) {
assert_eq!(
input.len(),
output.len(),
"input and output slices must have the same length"
);
arm_neon::celt_float2int16(input, output);
}
} else {
pub(crate) fn opus_limit2_checkwithin1(samples: &mut [f32]) -> bool {
opus_limit2_checkwithin1_scalar(samples)
}
pub(crate) fn celt_float2int16(input: &[f32], output: &mut [i16]) {
assert_eq!(
input.len(),
output.len(),
"input and output slices must have the same length"
);
celt_float2int16_scalar(input, output);
}
}
}
#[cfg(test)]
mod tests {
use core::f32::consts::PI;
use alloc::vec;
use libm::cosf;
use super::isqrt32;
use crate::celt::float_cast::CELT_SIG_SCALE;
use super::{
celt_cos_norm, celt_div, celt_exp2, celt_float2int16, celt_ilog2, celt_log2, celt_maxabs16,
celt_maxabs32, celt_rcp, celt_rsqrt, celt_rsqrt_norm, celt_sqrt, celt_zlog2, fast_atan2f,
frac_div32, frac_div32_q29, opus_limit2_checkwithin1,
};
use crate::celt::entcode::ec_ilog;
#[test]
fn fast_atan2f_matches_std() {
let samples = [
(0.0_f32, 0.0_f32),
(0.0, 1.0),
(1.0, 0.0),
(-1.0, 1.0),
(0.5, -0.75),
(3.0, 4.0),
(-2.0, -5.0),
];
for &(y, x) in &samples {
let approx = fast_atan2f(y, x);
let exact = y.atan2(x);
let diff = (approx - exact).abs();
assert!(diff <= 5e-3, "diff {} for y={}, x={}", diff, y, x);
}
}
#[test]
fn log2_matches_std() {
let values = [0.125_f32, 0.5, 1.0, 2.0, 10.0, 42.5];
for &value in &values {
let diff = (celt_log2(value) - value.log2()).abs();
assert!(diff <= 1e-6, "diff {} for value {}", diff, value);
}
}
#[test]
fn log2_matches_reference_harness() {
let error_threshold = 2.2e-6_f32;
let mut max_error = 0.0_f32;
let mut x = 0.001_f32;
while x < 1_677_700.0 {
let expected = x.ln() * core::f32::consts::LOG2_E;
let actual = celt_log2(x);
let error = (expected - actual).abs();
if error > max_error {
max_error = error;
}
assert!(
error <= error_threshold,
"celt_log2 failed: x = {}, error = {} (threshold = {})",
x,
error,
error_threshold
);
x += x / 8.0;
}
}
#[test]
fn exp2_matches_std() {
let values = [-5.0_f32, -1.0, 0.0, 0.25, 1.5, 4.0];
for &value in &values {
let diff = (celt_exp2(value) - value.exp2()).abs();
let eps = if cfg!(miri) { 1e-5 } else { 1e-6 };
assert!(diff <= eps, "diff {} for value {}", diff, value);
}
}
#[test]
fn exp2_matches_reference_harness() {
let error_threshold = 2.5e-6_f32;
let mut max_error = 0.0_f32;
let mut x = -11.0_f32;
while x < 24.0 {
let actual = celt_exp2(x);
let expected = actual.ln() * core::f32::consts::LOG2_E;
let error = (x - expected).abs();
if error > max_error {
max_error = error;
}
assert!(
error <= error_threshold,
"celt_exp2 failed: x = {}, error = {} (threshold = {})",
x,
error,
error_threshold
);
x += 0.0007;
}
}
#[test]
fn exp2_log2_roundtrip() {
let error_threshold = 2.0e-6_f32;
let mut max_error = 0.0_f32;
let mut x = -11.0_f32;
while x < 24.0 {
let roundtrip = celt_log2(celt_exp2(x));
let error = (x - roundtrip).abs();
if error > max_error {
max_error = error;
}
assert!(
error <= error_threshold,
"celt_exp2/celt_log2 roundtrip failed: x = {}, error = {} (threshold = {})",
x,
error,
error_threshold
);
x += 0.0007;
}
}
#[test]
fn div_matches_std() {
let samples = [(1.0_f32, 2.0_f32), (5.5, 1.1), (100.0, -25.0), (-3.75, 0.5)];
for &(a, b) in &samples {
assert!((celt_div(a, b) - a / b).abs() <= f32::EPSILON * 2.0);
}
}
#[test]
fn sqrt_matches_libm() {
let values = [0.0_f32, 0.25, 1.0, 2.5, 16.0];
for &value in &values {
let diff = (celt_sqrt(value) - value.sqrt()).abs();
assert!(diff <= 1e-6, "diff {} for value {}", diff, value);
}
}
#[test]
fn sqrt_matches_reference_harness() {
let mut i: i64 = 1;
while i <= 1_000_000_000 {
let fi = i as f32;
let val = celt_sqrt(fi);
let expected = fi.sqrt();
let ratio = val / expected;
let tolerance_ratio = 0.0005;
let tolerance_abs = 2.0;
assert!(
(ratio - 1.0).abs() <= tolerance_ratio || (val - expected).abs() <= tolerance_abs,
"sqrt failed: sqrt({}) = {} (ratio = {}, expected = {})",
i,
val,
ratio,
expected
);
i += (i >> 10).max(1);
}
}
#[test]
fn rsqrt_matches_inverse_sqrt() {
let values = [0.25_f32, 1.0, 4.0, 16.0];
for &value in &values {
let expected = 1.0 / value.sqrt();
assert!((celt_rsqrt(value) - expected).abs() <= 1e-6);
assert!((celt_rsqrt_norm(value) - expected).abs() <= 1e-6);
}
}
#[test]
fn rcp_matches_inverse() {
let values = [1.0_f32, -0.5, 2.0, -4.0];
for &value in &values {
let diff = (celt_rcp(value) - 1.0 / value).abs();
assert!(diff <= 1e-6, "diff {} for value {}", diff, value);
}
}
#[test]
fn rcp_matches_reference_harness() {
for i in 1..=327_670 {
let val = celt_rcp(i as f32);
let prod = val * (i as f32);
assert!(
(prod - 1.0).abs() <= 0.00025,
"div failed: 1/{} = {} (product = {})",
i,
val,
prod
);
}
}
#[test]
fn frac_div_helpers_match_division() {
let samples = [(1.0_f32, 4.0_f32), (5.0, -2.0), (-6.0, 3.0)];
for &(a, b) in &samples {
assert!((frac_div32(a, b) - a / b).abs() <= f32::EPSILON * 2.0);
assert!((frac_div32_q29(a, b) - a / b).abs() <= f32::EPSILON * 2.0);
}
}
#[test]
fn maxabs16_matches_manual_scan() {
let samples = [0.0f32, -1.5, 3.25, -0.875, 2.0];
assert!((celt_maxabs16(&samples) - 3.25).abs() <= f32::EPSILON);
assert_eq!(celt_maxabs16(&[]), 0.0);
}
#[test]
fn maxabs32_aliases_16() {
let samples = [-4.0f32, 1.0, 2.5];
assert_eq!(celt_maxabs32(&samples), celt_maxabs16(&samples));
}
#[test]
fn cos_norm_matches_reference() {
let inputs = [0.0_f32, 0.25, 0.5, 0.75, 1.0];
for &input in &inputs {
let expected = cosf(0.5 * PI * input);
assert!((celt_cos_norm(input) - expected).abs() <= 1e-6);
}
}
#[test]
fn limit2_clamps_and_returns_hint() {
let mut samples = [-3.5_f32, -2.0, -0.5, 0.75, 1.5, 3.75];
let hint = opus_limit2_checkwithin1(&mut samples);
assert!(!hint);
assert_eq!(samples, [-2.0, -2.0, -0.5, 0.75, 1.5, 2.0]);
let mut empty: [f32; 0] = [];
assert!(opus_limit2_checkwithin1(&mut empty));
}
#[test]
fn float2int16_matches_reference_scaling() {
let input = [-2.0_f32, -1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0, 2.0];
let mut output = [0_i16; 9];
celt_float2int16(&input, &mut output);
assert_eq!(
output,
[
-32_768, -32_768, -16_384, -8_192, 0, 8_192, 16_384, 32_767, 32_767
]
);
}
#[test]
fn float2int16_uses_saturating_round_to_nearest() {
let input = [
-1.001_f32, -1.000_03, -0.999_9, -0.500_3, -0.499_7, 0.499_7, 0.500_3, 0.999_9,
1.000_03, 1.001,
];
let mut output = [0_i16; 10];
celt_float2int16(&input, &mut output);
assert_eq!(output[0], -32_768);
assert_eq!(output[1], -32_768);
assert_eq!(output[8], 32_767);
assert_eq!(output[9], 32_767);
for (&input_sample, &output_sample) in input.iter().zip(&output) {
let output_sample = i32::from(output_sample);
assert!((-32_768..=32_767).contains(&output_sample));
let scaled = (input_sample * CELT_SIG_SCALE).clamp(-32_768.0, 32_767.0);
let diff = scaled - output_sample as f32;
assert!(
diff.abs() <= 0.500_1,
"diff {} for input {} (scaled {})",
diff,
input_sample,
scaled
);
}
}
#[test]
#[should_panic(expected = "input and output slices must have the same length")]
fn float2int16_panics_on_length_mismatch() {
let input = [0.0_f32, 1.0];
let mut output = [0_i16; 1];
celt_float2int16(&input, &mut output);
}
#[test]
fn celt_ilog2_matches_ec_ilog_minus_one() {
let values = [
1, 2, 3, 4, 7, 8, 9, 15, 16, 17, 31, 32, 33, 1_024, 65_535, 65_536, 1_048_576,
];
for &value in &values {
assert_eq!(celt_ilog2(value), ec_ilog(value as u32) - 1);
}
}
#[test]
fn celt_zlog2_handles_non_positive_inputs() {
assert_eq!(celt_zlog2(0), 0);
assert_eq!(celt_zlog2(-123), 0);
assert_eq!(celt_zlog2(1), celt_ilog2(1));
assert_eq!(celt_zlog2(2_048), celt_ilog2(2_048));
}
#[test]
fn isqrt32_matches_f64_reference() {
let mut values = vec![
1u32,
2,
3,
4,
7,
9,
15,
16,
24,
36,
64,
65,
255,
256,
257,
1_000,
65_535,
65_536,
1_048_575,
u32::MAX,
];
for shift in 0..31 {
let base = 1u32 << shift;
values.push(base.saturating_sub(1));
values.push(base);
values.push(base.saturating_add(1));
}
values.sort_unstable();
values.dedup();
for value in values {
let expected = (f64::from(value).sqrt().floor()) as u32;
assert_eq!(isqrt32(value), expected, "value {}", value);
}
assert_eq!(isqrt32(0), 0);
}
}