use crate::silk::lin2log::lin2log;
pub const LTP_ORDER: usize = 5;
const SILK_FIX_CONST_1_001_Q15: i32 = 32_801;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VqWMatEcResult {
pub index: i8,
pub residual_energy_q15: i32,
pub rate_dist_q8: i32,
pub gain_q7: i32,
}
impl Default for VqWMatEcResult {
fn default() -> Self {
Self {
index: 0,
residual_energy_q15: i32::MAX,
rate_dist_q8: i32::MAX,
gain_q7: 0,
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn vq_wmat_ec(
xx_q17: &[i32; LTP_ORDER * LTP_ORDER],
x_x_q17: &[i32; LTP_ORDER],
cb_q7: &[[i8; LTP_ORDER]],
cb_gain_q7: &[u8],
cl_q5: &[u8],
subfr_len: i32,
max_gain_q7: i32,
) -> VqWMatEcResult {
let l = cb_gain_q7.len();
assert_eq!(cb_q7.len(), l, "codebook rows must match gain table");
assert_eq!(cl_q5.len(), l, "code lengths must match the codebook size");
let mut neg_xx_q24 = [0i32; LTP_ORDER];
for (dst, &value) in neg_xx_q24.iter_mut().zip(x_x_q17.iter()) {
*dst = value.wrapping_shl(7).wrapping_neg();
}
let mut best = VqWMatEcResult::default();
for (row_index, row) in cb_q7.iter().enumerate() {
let gain_tmp_q7 = i32::from(cb_gain_q7[row_index]);
let penalty = (gain_tmp_q7 - max_gain_q7).max(0).wrapping_shl(11);
let mut sum1_q15 = SILK_FIX_CONST_1_001_Q15;
let mut sum2_q24 = mla(neg_xx_q24[0], xx_q17[1], i32::from(row[1]));
sum2_q24 = mla(sum2_q24, xx_q17[2], i32::from(row[2]));
sum2_q24 = mla(sum2_q24, xx_q17[3], i32::from(row[3]));
sum2_q24 = mla(sum2_q24, xx_q17[4], i32::from(row[4]));
sum2_q24 = sum2_q24.wrapping_shl(1);
sum2_q24 = mla(sum2_q24, xx_q17[0], i32::from(row[0]));
sum1_q15 = smlawb(sum1_q15, sum2_q24, i32::from(row[0]));
sum2_q24 = mla(neg_xx_q24[1], xx_q17[7], i32::from(row[2]));
sum2_q24 = mla(sum2_q24, xx_q17[8], i32::from(row[3]));
sum2_q24 = mla(sum2_q24, xx_q17[9], i32::from(row[4]));
sum2_q24 = sum2_q24.wrapping_shl(1);
sum2_q24 = mla(sum2_q24, xx_q17[6], i32::from(row[1]));
sum1_q15 = smlawb(sum1_q15, sum2_q24, i32::from(row[1]));
sum2_q24 = mla(neg_xx_q24[2], xx_q17[13], i32::from(row[3]));
sum2_q24 = mla(sum2_q24, xx_q17[14], i32::from(row[4]));
sum2_q24 = sum2_q24.wrapping_shl(1);
sum2_q24 = mla(sum2_q24, xx_q17[12], i32::from(row[2]));
sum1_q15 = smlawb(sum1_q15, sum2_q24, i32::from(row[2]));
sum2_q24 = mla(neg_xx_q24[3], xx_q17[19], i32::from(row[4]));
sum2_q24 = sum2_q24.wrapping_shl(1);
sum2_q24 = mla(sum2_q24, xx_q17[18], i32::from(row[3]));
sum1_q15 = smlawb(sum1_q15, sum2_q24, i32::from(row[3]));
sum2_q24 = neg_xx_q24[4].wrapping_shl(1);
sum2_q24 = mla(sum2_q24, xx_q17[24], i32::from(row[4]));
sum1_q15 = smlawb(sum1_q15, sum2_q24, i32::from(row[4]));
if sum1_q15 >= 0 {
let sum_with_penalty = sum1_q15.wrapping_add(penalty);
let bits_res_q8 = smulbb(subfr_len, lin2log(sum_with_penalty) - (15 << 7));
let bits_tot_q8 = add_lshift32(bits_res_q8, i32::from(cl_q5[row_index]), 2);
if bits_tot_q8 <= best.rate_dist_q8 {
best.index = row_index as i8;
best.residual_energy_q15 = sum_with_penalty;
best.rate_dist_q8 = bits_tot_q8;
best.gain_q7 = gain_tmp_q7;
}
}
}
best
}
fn mla(a: i32, b: i32, c: i32) -> i32 {
a.wrapping_add(b.wrapping_mul(c))
}
fn smlawb(a: i32, b: i32, c: i32) -> i32 {
let product = (i64::from(b) * i64::from(i32::from(c as i16))) >> 16;
a.wrapping_add(product as i32)
}
fn smulbb(a: i32, b: i32) -> i32 {
i32::from((a as i16).wrapping_mul(b as i16))
}
fn add_lshift32(a: i32, b: i32, shift: u32) -> i32 {
a.wrapping_add(b.wrapping_shl(shift))
}
#[cfg(test)]
mod tests {
use super::{LTP_ORDER, VqWMatEcResult, vq_wmat_ec};
use alloc::vec::Vec;
#[test]
fn selects_single_codebook_entry() {
let xx_q17: [i32; 25] = [
20_000, 1_000, 2_000, 3_000, 4_000, 1_000, 19_000, 1_100, 1_200, 1_300, 2_000, 1_100,
18_000, 1_400, 1_500, 3_000, 1_200, 1_400, 17_000, 1_600, 4_000, 1_300, 1_500, 1_600,
16_000,
];
let x_x_q17: [i32; 5] = [300, 400, 500, 600, 700];
let cb_q7 = [[-3, -2, -1, 0, 1]];
let cb_gain_q7 = [12u8];
let cl_q5 = [5u8];
let result = vq_wmat_ec(&xx_q17, &x_x_q17, &cb_q7, &cb_gain_q7, &cl_q5, 60, 20);
assert_eq!(
result,
VqWMatEcResult {
index: 0,
residual_energy_q15: 32_810,
rate_dist_q8: 20,
gain_q7: 12,
},
);
}
#[test]
fn prefers_lower_rate_distortion() {
let xx_q17: [i32; 25] = [
21_000, 900, 1_100, 1_200, 1_400, 900, 21_000, 1_000, 1_100, 1_200, 1_100, 1_000,
21_000, 1_050, 1_100, 1_200, 1_100, 1_050, 21_000, 1_000, 1_400, 1_200, 1_100, 1_000,
21_000,
];
let x_x_q17: [i32; 5] = [250, 260, 270, 280, 290];
let cb_q7 = [
[3, 2, 1, 0, -1], [-2, -2, -2, -2, -2], ];
let cb_gain_q7 = [18u8, 10u8];
let cl_q5 = [10u8, 6u8];
let result = vq_wmat_ec(&xx_q17, &x_x_q17, &cb_q7, &cb_gain_q7, &cl_q5, 80, 25);
assert_eq!(
result,
VqWMatEcResult {
index: 1,
residual_energy_q15: 32_816,
rate_dist_q8: 24,
gain_q7: 10,
},
);
}
#[test]
fn returns_default_when_all_candidates_negative() {
let xx_q17: [i32; 25] = [
-1_000_000, 0, 0, 0, 0, 0, -1_000_000, 0, 0, 0, 0, 0, -1_000_000, 0, 0, 0, 0, 0,
-1_000_000, 0, 0, 0, 0, 0, -1_000_000,
];
let x_x_q17: [i32; 5] = [0, 0, 0, 0, 0];
let cb_q7 = [[127, 127, 127, 127, 127]];
let cb_gain_q7 = [127u8];
let cl_q5 = [0u8];
let result = vq_wmat_ec(&xx_q17, &x_x_q17, &cb_q7, &cb_gain_q7, &cl_q5, 40, 200);
assert_eq!(result, VqWMatEcResult::default());
}
#[test]
fn codebook_layout_matches_rows() {
let mut cb = Vec::new();
for row in 0..3 {
let mut taps = [0i8; LTP_ORDER];
for col in 0..LTP_ORDER {
taps[col] = (row * 10 + col) as i8;
}
cb.push(taps);
}
let xx_q17: [i32; 25] = [
25_000, 0, 0, 0, 0, 0, 25_000, 0, 0, 0, 0, 0, 25_000, 0, 0, 0, 0, 0, 25_000, 0, 0, 0,
0, 0, 25_000,
];
let x_x_q17: [i32; 5] = [0, 0, 0, 0, 0];
let cb_gain_q7 = [0u8, 0u8, 0u8];
let cl_q5 = [0u8, 0u8, 0u8];
let result = vq_wmat_ec(&xx_q17, &x_x_q17, &cb, &cb_gain_q7, &cl_q5, 20, 100);
assert_eq!(result.index, 0);
}
}