/// Number of fractional bits in the Q15 fixed-point format.
pub const Q15_SHIFT: i32 = 15;
/// Largest representable Q15 value (just below 1.0); equal to `i16::MAX`.
pub const Q15_ONE: i32 = i16::MAX as i32;
/// Number of fractional bits in the Q31 fixed-point format.
pub const Q31_SHIFT: i32 = 31;
/// Largest representable Q31 value (just below 1.0); equal to `i32::MAX`.
pub const Q31_ONE: i32 = i32::MAX;
/// Normalization shift amount. NOTE(review): not referenced in this section — confirm its intended use.
pub const NORM_SHIFT: i32 = 24;
/// Saturates a 32-bit value into the `i16` range.
///
/// Values above `32767` become `32767`, values below `-32768` become `-32768`,
/// and everything in between is returned unchanged.
#[inline(always)]
pub fn sat16(x: i32) -> i16 {
    x.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
}
/// Saturates a 64-bit value into the `i32` range.
///
/// Values outside `i32::MIN..=i32::MAX` are pinned to the nearest bound;
/// in-range values are returned unchanged.
#[inline(always)]
pub fn sat32(x: i64) -> i32 {
    x.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
/// Multiplies two Q15 values with round-to-nearest and returns a Q15 result.
///
/// The only product that does not fit in Q15 is `-32768 * -32768`
/// (i.e. `-1.0 * -1.0`). It saturates to `32767` instead of wrapping back to `-32768`.
#[inline(always)]
pub fn mul_q15(a: i16, b: i16) -> i16 {
    // An i16*i16 product has magnitude at most 2^30, so adding the rounding
    // constant cannot overflow i32.
    let prod = (i32::from(a) * i32::from(b) + (1 << 14)) >> 15;
    // Same saturation as `sat16`, inlined so the result never wraps.
    prod.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
}
/// Full-precision product of two `i16` values, widened to `i32`.
///
/// The magnitude of the result is at most 2^30, so this cannot overflow.
#[inline(always)]
pub fn mul16_16(a: i16, b: i16) -> i32 {
    let (lhs, rhs) = (i32::from(a), i32::from(b));
    lhs * rhs
}
/// Multiplies two Q14 values with round-to-nearest and returns a Q14 result.
///
/// Q14 only covers roughly `[-2.0, 2.0)`, so many products (e.g. `1.9 * 1.9`)
/// fall outside it. Those now saturate to `i16::MIN` / `i16::MAX` instead of
/// wrapping through the `as i16` cast. Previously `mul_q14(-32768, -32768)`
/// returned `0`.
#[inline(always)]
pub fn mul_q14(a: i16, b: i16) -> i16 {
    // |a*b| <= 2^30, so the rounding add stays within i32.
    let prod = (i32::from(a) * i32::from(b) + (1 << 13)) >> 14;
    prod.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
}
/// Multiplies an `i16` by an `i32` and drops the low 16 bits of the product
/// (arithmetic shift, i.e. rounding toward negative infinity).
///
/// The 64-bit product has magnitude at most 2^46, so after `>> 16` it always
/// fits in `i32` and the final cast never truncates.
#[inline(always)]
pub fn mul16_32_q16(a: i16, b: i32) -> i32 {
    let wide = i64::from(a) * i64::from(b);
    (wide >> 16) as i32
}
/// Shifts `a` left by `shift` bits, saturating the result to the `i32` range.
///
/// A non-positive `shift` instead performs an arithmetic right shift by `-shift`.
/// Out-of-range shift amounts are handled explicitly instead of panicking
/// (debug) or being masked (release):
/// - Right shifts of 31 or more yield `0` or `-1`.
/// - Left shifts of 32 or more saturate any non-zero `a` toward its sign.
#[inline(always)]
pub fn shl32(a: i32, shift: i32) -> i32 {
    if shift <= 0 {
        // `unsigned_abs` avoids the overflow of `-i32::MIN`; after 31 positions
        // every bit already equals the sign bit, so larger shifts are equivalent.
        a >> shift.unsigned_abs().min(31)
    } else if a == 0 {
        0
    } else if shift >= 32 {
        // |a| >= 1, so a·2^32 or more is always outside the i32 range.
        if a > 0 {
            i32::MAX
        } else {
            i32::MIN
        }
    } else {
        // With shift <= 31 the widened value is at most 2^62 in magnitude,
        // so the i64 shift cannot lose bits or flip the sign.
        let wide = i64::from(a) << shift;
        wide.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
    }
}
/// Arithmetic right shift of `a` by `shift` bits. A non-positive `shift`
/// performs a plain (non-saturating) left shift by `-shift`.
///
/// Out-of-range shift amounts no longer panic in debug builds:
/// - Right shifts of 32 or more behave like a shift by 31 (result `0` or `-1`).
/// - Left shifts of 32 or more discard every bit and return `0`.
///
/// Left shifts below 32 keep the original wrapping behaviour: bits pushed past
/// bit 31 are lost.
#[inline(always)]
pub fn shr32(a: i32, shift: i32) -> i32 {
    if shift <= 0 {
        // `unsigned_abs` avoids negating `i32::MIN`; `checked_shl` is `None`
        // only when the amount is >= 32.
        a.checked_shl(shift.unsigned_abs()).unwrap_or(0)
    } else {
        a >> shift.min(31)
    }
}
/// Right shift of `a` by `shift` bits with round-to-nearest: half-way values
/// round toward positive infinity. A non-positive `shift` performs a plain
/// (non-saturating) left shift by `-shift`.
///
/// The rounding bias is added in 64-bit arithmetic, so inputs near `i32::MAX`
/// no longer overflow. Right shifts of 32 or more return `0`, and left shifts
/// of 32 or more return `0`; neither panics.
#[inline(always)]
pub fn pshr32(a: i32, shift: i32) -> i32 {
    if shift <= 0 {
        // `checked_shl` is `None` only for amounts >= 32.
        a.checked_shl(shift.unsigned_abs()).unwrap_or(0)
    } else {
        // For every shift >= 32 the rounded result is already 0, so clamping
        // to 62 changes nothing while keeping the i64 shift in range.
        let s = shift.min(62);
        let rounded = (i64::from(a) + (1i64 << (s - 1))) >> s;
        // For s >= 1 the result lies in [-2^30, 2^30], so it fits in i32.
        rounded as i32
    }
}
/// Absolute value of an `i16`, saturating at `i16::MAX`.
///
/// `|-32768|` is not representable in `i16`. Plain negation would panic in
/// debug builds and wrap to `-32768` in release builds, so it is pinned to
/// `32767` instead.
#[inline(always)]
pub fn abs16(x: i16) -> i16 {
    x.saturating_abs()
}
/// Absolute value of an `i32`, saturating at `i32::MAX`.
///
/// `|i32::MIN|` is not representable in `i32`. Plain negation would panic in
/// debug builds and wrap in release builds, so it is pinned to `i32::MAX` instead.
#[inline(always)]
pub fn abs32(x: i32) -> i32 {
    x.saturating_abs()
}
/// Converts a float to Q15 by scaling with `32767.0`, truncating toward zero,
/// and saturating to the `i16` range.
///
/// Inputs with magnitude above 1.0 saturate, and `NaN` maps to `0` (Rust's
/// float-to-int `as` cast sends NaN to 0 and saturates out-of-range values).
#[inline(always)]
pub fn float_to_q15(x: f32) -> i16 {
    let truncated = (x * 32767.0) as i32;
    // Same saturation `sat16` performs.
    truncated.clamp(-32768, 32767) as i16
}
/// Converts a Q15 value to a float by dividing by `32767.0`.
///
/// `32767` maps to exactly `1.0`; `-32768` maps slightly below `-1.0`.
#[inline(always)]
pub fn q15_to_float(x: i16) -> f32 {
    const SCALE: f32 = 32767.0;
    f32::from(x) / SCALE
}
/// Reciprocal of a Q15 value, computed in floating point and converted back
/// to Q15 through `float_to_q15`.
///
/// Non-positive inputs return `Q15_ONE`.
///
/// NOTE(review): for every positive input (`1..=32767`), `x / 32767.0` lies in
/// `(0, 1]`, so its reciprocal is at least `1.0` and `float_to_q15` saturates
/// it to `32767`. As written, this function therefore returns `32767` for every
/// input. Confirm whether a different output format was intended.
#[inline(always)]
pub fn recip_q15(x: i16) -> i16 {
    match x {
        i16::MIN..=0 => Q15_ONE as i16,
        _ => {
            let as_unit = f32::from(x) / 32767.0;
            float_to_q15(1.0 / as_unit)
        }
    }
}
/// Multiply-accumulate: returns `acc + a * b`, where `a * b` is the raw 32-bit
/// product of the two operands (no rescaling, as in `mul16_16`).
///
/// The addition saturates at the `i32` bounds. A plain `+` would panic on
/// overflow in debug builds and silently wrap in release builds.
#[inline(always)]
pub fn mac_q15(acc: i32, a: i16, b: i16) -> i32 {
    // The product itself is at most 2^30 in magnitude and cannot overflow.
    acc.saturating_add(i32::from(a) * i32::from(b))
}
/// Square of an `i16`, returned as the full-precision `i32` product
/// (Q30 for a Q15 input; no rescaling).
///
/// The largest result is `(-32768)^2 = 2^30`, so this cannot overflow.
#[inline(always)]
pub fn sqr_q15(a: i16) -> i32 {
    i32::from(a).pow(2)
}
/// Sum of squares of `x`, rescaled by `>> 15` (a Q15 result for Q15 input).
///
/// The sum is accumulated in `i64`, so it cannot overflow for any slice shorter
/// than 2^33 elements. The final value saturates at `i32::MAX` instead of being
/// truncated by an `as` cast. Truncation happened once the slice held more than
/// about 65 536 full-scale samples.
pub fn sum_sqr_q15(x: &[i16]) -> i32 {
    let sum: i64 = x
        .iter()
        .map(|&v| {
            let v = i64::from(v);
            v * v
        })
        .sum();
    // `sum` is non-negative, so the only possible failure is exceeding i32::MAX.
    i32::try_from(sum >> 15).unwrap_or(i32::MAX)
}
/// Dot product of two equal-length `i16` slices, rescaled by `>> 15`
/// (a Q15 result for Q15 inputs).
///
/// The sum is accumulated in `i64`. The final value saturates to the `i32`
/// range instead of being silently truncated by an `as` cast, which previously
/// happened for long, high-energy inputs.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
pub fn dot_product_q15(a: &[i16], b: &[i16]) -> i32 {
    assert_eq!(a.len(), b.len());
    let sum: i64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| i64::from(x) * i64::from(y))
        .sum();
    (sum >> 15).clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
/// Largest absolute value in `x`, or `0` for an empty slice.
///
/// `|-32768|` is not representable in `i16`, so it saturates to `32767`.
/// Previously it was negated directly, which panicked in debug builds and
/// wrapped to `-32768` in release builds.
pub fn maxabs16(x: &[i16]) -> i16 {
    x.iter().map(|v| v.saturating_abs()).max().unwrap_or(0)
}
/// Scales `x` so its peak absolute value maps to `32767` and writes the Q15
/// result into `out`.
///
/// The peak absolute value (`0.0` for an empty slice) is written to
/// `max_val_out`. If the peak is below `1e-15`, `out` is zero-filled instead.
/// Scaled samples are truncated toward zero and saturated to the `i16` range.
///
/// # Panics
///
/// Panics if `x.len() != out.len()`.
pub fn normalize_to_q15(x: &[f32], out: &mut [i16], max_val_out: &mut f32) {
    assert_eq!(x.len(), out.len());
    // NaN entries never win the `>` comparison, so they do not affect the peak.
    let peak = x
        .iter()
        .map(|v| v.abs())
        .fold(0.0f32, |best, cand| if cand > best { cand } else { best });
    *max_val_out = peak;
    if peak < 1e-15 {
        out.fill(0);
        return;
    }
    let gain = 32767.0 / peak;
    for (dst, &src) in out.iter_mut().zip(x) {
        // `as i32` truncates toward zero (NaN -> 0); the clamp mirrors `sat16`.
        *dst = ((src * gain) as i32).clamp(-32768, 32767) as i16;
    }
}
/// Unit tests for the saturation, multiplication and normalization helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sat16() {
        let cases: [(i32, i16); 3] = [(1000, 1000), (40000, 32767), (-40000, -32768)];
        for &(input, expected) in cases.iter() {
            assert_eq!(sat16(input), expected);
        }
    }

    #[test]
    fn test_mul_q15() {
        // 0.5 * 0.5 should come back as roughly 0.25.
        let half = float_to_q15(0.5);
        let quarter = q15_to_float(mul_q15(half, half));
        assert!((quarter - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_normalize() {
        let x: [f32; 5] = [0.0, 0.5, 1.0, -0.5, -1.0];
        let mut out = [0i16; 5];
        let mut max_val = 0.0;
        normalize_to_q15(&x, &mut out, &mut max_val);
        assert!(maxabs16(&out) > 32000);
        // Round-trip each sample back through the recorded peak.
        for (i, (&q, &orig)) in out.iter().zip(x.iter()).enumerate() {
            let back = q15_to_float(q) * max_val;
            assert!((back - orig).abs() < 0.02, "Mismatch at {}: {} vs {}", i, back, orig);
        }
    }
}