use core::cmp::Ordering;
use crate::silk::decode_indices::ConditionalCoding;
use crate::silk::encoder::control::EncoderControl;
use crate::silk::encoder::state::EncoderChannelState;
use crate::silk::gain_quant::silk_gains_quant;
use crate::silk::log2lin::log2lin;
use crate::silk::sigm_q15::sigm_q15;
use crate::silk::tables_other::SILK_QUANTIZATION_OFFSETS_Q10;
use crate::silk::{FrameQuantizationOffsetType, FrameSignalType, MAX_NB_SUBFR};
/// 12.0 in Q7; subtracted from the LTP coding gain before it enters the sigmoid.
const LTP_SIGMOID_OFFSET_Q7: i32 = 12 * ONE_Q7;
/// Approximately `(21 + 16 / 0.33) * 2^7`; base of the inverse max-squared-value exponent.
const INV_MAX_SQR_BASE_Q7: i32 = 8894;
/// Approximately `0.33 * 2^16`; scale applied to the inverse max-squared-value exponent.
const INV_MAX_SQR_EXP_Q16: i32 = 21_627;
/// 1.0 in Q7; threshold for choosing the voiced quantisation offset type.
const ONE_Q7: i32 = 1 << 7;
/// Constant term of lambda (approximately 1.2 in Q10).
const LAMBDA_OFFSET_Q10: i32 = 1_229;
/// Lambda weight per delayed-decision state (approximately -0.05 in Q10).
const LAMBDA_DELAYED_DECISIONS_Q10: i32 = -50;
/// Lambda weight for speech activity (approximately -0.2 in Q18).
const LAMBDA_SPEECH_ACT_Q18: i32 = -52_428;
/// Lambda weight for input quality (approximately -0.1 in Q12).
const LAMBDA_INPUT_QUALITY_Q12: i32 = -409;
/// Lambda weight for coding quality (approximately -0.2 in Q12).
const LAMBDA_CODING_QUALITY_Q12: i32 = -818;
/// Lambda weight for the quantisation offset (approximately 0.8 in Q16).
const LAMBDA_QUANT_OFFSET_Q16: i32 = 52_429;
/// Post-processes the per-subframe gains of one frame and derives the
/// quantiser rate/distortion weight `control.lambda_q10`.
///
/// In order, this function:
/// 1. for voiced frames, lowers every gain in `control.gains_q16` by roughly
///    `gain * sigm_q15(x) / 2^16`, where `x` is
///    `control.lt_pred_cod_gain_q7 - 12.0 (Q7)` shifted right by 4 with rounding;
/// 2. soft-limits each gain against the residual energy (`control.res_nrg`,
///    whose Q-domain is `control.res_nrg_q`) scaled by a factor derived from
///    `encoder.common.snr_db_q7` and `subfr_length`, and copies the result into
///    `control.gains_unq_q16`;
/// 3. quantises the gains with [`silk_gains_quant`], writing the gain indices
///    and updating `encoder.shape_state.last_gain_index` (its previous value is
///    saved in `control.last_gain_index_prev`);
/// 4. for voiced frames, selects the quantisation offset type from the LTP
///    coding gain plus `input_tilt_q15 >> 8`;
/// 5. computes `control.lambda_q10` as a weighted sum of the delayed-decision
///    state count, speech activity, input/coding quality and quantisation offset.
///
/// # Panics
///
/// Panics if `encoder.common.nb_subfr` is neither `MAX_NB_SUBFR` nor
/// `MAX_NB_SUBFR / 2`.
pub fn process_gains(
    encoder: &mut EncoderChannelState,
    control: &mut EncoderControl,
    cond_coding: ConditionalCoding,
) {
    let nb_subfr = encoder.common.nb_subfr;
    assert!(
        nb_subfr == MAX_NB_SUBFR || nb_subfr == MAX_NB_SUBFR / 2,
        "encoder supports 2 or 4 subframes"
    );

    // Gain reduction when the LTP coding gain is high.
    if matches!(encoder.common.indices.signal_type, FrameSignalType::Voiced) {
        let diff_q7 = control.lt_pred_cod_gain_q7 - LTP_SIGMOID_OFFSET_Q7;
        let scaled_q5 = rshift_round(diff_q7, 4);
        let reduction_q16 = -sigm_q15(scaled_q5);
        for gain in control.gains_q16.iter_mut().take(nb_subfr) {
            *gain = smlawb(*gain, *gain, reduction_q16);
        }
    }

    // Per-sample inverse of the maximum squared signal value, SNR dependent.
    let subfr_length = encoder.common.subfr_length as i32;
    let log_arg_q7 = smulwb(
        INV_MAX_SQR_BASE_Q7 - encoder.common.snr_db_q7,
        INV_MAX_SQR_EXP_Q16,
    );
    let inv_max_sqr_val_q16 = if subfr_length > 0 {
        div32_16(log2lin(log_arg_q7), subfr_length)
    } else {
        0
    };

    // Soft limit on the ratio between residual energy and squared gains.
    for k in 0..nb_subfr {
        let mut res_nrg_part = smulww(control.res_nrg[k], inv_max_sqr_val_q16);
        let res_q = control.res_nrg_q[k];
        if res_q > 0 {
            // A rounded right shift by 32 already maps every i32 to 0, so
            // capping the amount keeps the result exact while preventing a
            // shift overflow inside `rshift_round` for larger Q values.
            res_nrg_part = rshift_round(res_nrg_part, res_q.min(32));
        } else if res_q < 0 {
            // Bring a negative Q-domain back to Q0 with a saturating left shift.
            // `checked_shr` returns `None` for shift amounts >= 32; a plain `>>`
            // would panic in debug builds there and mask the shift amount in
            // release builds. Such shifts saturate, like the shift-31 case.
            let shift = res_q.unsigned_abs();
            res_nrg_part = match i32::MAX.checked_shr(shift) {
                Some(limit) if res_nrg_part < limit => res_nrg_part << shift,
                _ => i32::MAX,
            };
        }
        let gain = control.gains_q16[k];
        let gain_squared = add_sat32(res_nrg_part, smmul(gain, gain));
        if gain_squared < i32::from(i16::MAX) {
            // Small values: recompute in Q16 for precision, take the root in Q8.
            let precise = smla_ww(res_nrg_part.wrapping_shl(16), gain, gain);
            debug_assert!(precise > 0, "gain clamp expects positive precision");
            let root = sqrt_approx(precise);
            let clamped = root.min(i32::MAX >> 8);
            control.gains_q16[k] = lshift_sat32(clamped, 8);
        } else {
            // Root in Q0, then promote back to Q16.
            let root = sqrt_approx(gain_squared);
            let clamped = root.min(i32::MAX >> 16);
            control.gains_q16[k] = lshift_sat32(clamped, 16);
        }
    }

    // Save unquantised gains and the previous gain index, then quantise.
    control.gains_unq_q16[..nb_subfr].copy_from_slice(&control.gains_q16[..nb_subfr]);
    control.last_gain_index_prev = encoder.shape_state.last_gain_index as i8;
    let conditional = matches!(cond_coding, ConditionalCoding::Conditional);
    let mut last_gain_index = control.last_gain_index_prev;
    silk_gains_quant(
        &mut encoder.common.indices.gains_indices[..nb_subfr],
        &mut control.gains_q16[..nb_subfr],
        &mut last_gain_index,
        conditional,
    );
    encoder.shape_state.last_gain_index = i32::from(last_gain_index);

    // Voiced frames: larger offset when LTP gain is low or input tilt is high.
    if matches!(encoder.common.indices.signal_type, FrameSignalType::Voiced) {
        let combined = control.lt_pred_cod_gain_q7 + (encoder.common.input_tilt_q15 >> 8);
        encoder.common.indices.quant_offset_type = if combined > ONE_Q7 {
            FrameQuantizationOffsetType::Low
        } else {
            FrameQuantizationOffsetType::High
        };
    }

    // Quantiser rate/distortion weight.
    let signal_row = (i32::from(encoder.common.indices.signal_type) >> 1) as usize;
    let quant_col = match encoder.common.indices.quant_offset_type {
        FrameQuantizationOffsetType::Low => 0,
        FrameQuantizationOffsetType::High => 1,
    };
    let quant_offset_q10 = i32::from(SILK_QUANTIZATION_OFFSETS_Q10[signal_row][quant_col]);
    control.lambda_q10 = LAMBDA_OFFSET_Q10
        + smulbb(
            LAMBDA_DELAYED_DECISIONS_Q10,
            encoder.common.n_states_delayed_decision,
        )
        + smulwb(LAMBDA_SPEECH_ACT_Q18, encoder.common.speech_activity_q8)
        + smulwb(LAMBDA_INPUT_QUALITY_Q12, control.input_quality_q14)
        + smulwb(LAMBDA_CODING_QUALITY_Q12, control.coding_quality_q14)
        + smulwb(LAMBDA_QUANT_OFFSET_Q16, quant_offset_q10);
    debug_assert!(
        control.lambda_q10 > 0 && control.lambda_q10 < (2 << 10),
        "lambda must stay within the fixed-point range"
    );
}
/// Arithmetic right shift by `shift` bits, rounding to nearest with ties
/// toward positive infinity. `shift` must be positive.
fn rshift_round(value: i32, shift: i32) -> i32 {
    debug_assert!(shift > 0);
    match shift {
        // Dedicated single-bit form: `(value + 1) >> 1` could overflow at i32::MAX.
        1 => (value >> 1) + (value & 1),
        _ => ((value >> (shift - 1)) + 1) >> 1,
    }
}
/// Word × bottom multiply: `(a * (b as i16)) >> 16` computed in 64 bits.
///
/// Only the signed low 16 bits of `b` take part (the `SMULWB` convention,
/// and the same truncation `smulbb` applies). Every caller in this file
/// passes a `b` within `i16` range, so their results are unaffected.
fn smulwb(a: i32, b: i32) -> i32 {
    let bottom = i64::from(b as i16);
    ((i64::from(a) * bottom) >> 16) as i32
}
/// Word × word multiply keeping bits 16..48 of the 64-bit product
/// (`(a * b) >> 16`, truncated to `i32`).
fn smulww(a: i32, b: i32) -> i32 {
    let wide = i64::from(a) * i64::from(b);
    (wide >> 16) as i32
}
/// Multiply-accumulate: `a + ((b * c) >> 16)`, with the addition wrapping on overflow.
fn smla_ww(a: i32, b: i32, c: i32) -> i32 {
    let product = smulww(b, c);
    product.wrapping_add(a)
}
/// Returns the upper 32 bits of the 64-bit product `a * b`.
fn smmul(a: i32, b: i32) -> i32 {
    let wide = i64::from(a) * i64::from(b);
    (wide >> 32) as i32
}
/// Integer division `a / b` (truncating toward zero); a zero divisor yields 0.
fn div32_16(a: i32, b: i32) -> i32 {
    match b {
        0 => 0,
        divisor => a / divisor,
    }
}
/// Adds two `i32` values, clamping the result to `i32::MIN..=i32::MAX`.
fn add_sat32(a: i32, b: i32) -> i32 {
    a.saturating_add(b)
}
/// Left shift that saturates to `i32::MAX` / `i32::MIN` instead of overflowing.
///
/// A non-positive `shift` returns `value` unchanged. A shift of 31 or more
/// returns `i32::MAX`, `i32::MIN` or 0 according to the sign of `value`.
fn lshift_sat32(value: i32, shift: i32) -> i32 {
    if shift <= 0 {
        return value;
    }
    let saturated = match value.cmp(&0) {
        Ordering::Greater => i32::MAX,
        Ordering::Less => i32::MIN,
        Ordering::Equal => 0,
    };
    if shift >= 31 {
        return saturated;
    }
    let lowest = i32::MIN >> shift;
    let highest = i32::MAX >> shift;
    if (lowest..=highest).contains(&value) {
        value << shift
    } else {
        saturated
    }
}
/// Multiplies the signed low 16 bits of `a` by the signed low 16 bits of `b`.
fn smulbb(a: i32, b: i32) -> i32 {
    let [bottom_a, bottom_b] = [a as i16, b as i16];
    i32::from(bottom_a) * i32::from(bottom_b)
}
/// Multiply-accumulate: `a + smulwb(b, c)`, with the addition wrapping on overflow.
fn smlawb(a: i32, b: i32, c: i32) -> i32 {
    let product = smulwb(b, c);
    product.wrapping_add(a)
}
/// Approximate square root of a positive fixed-point value; returns 0 for `x <= 0`.
///
/// The estimate starts from a seed chosen by the parity of the leading-zero
/// count, scaled down by half that count. It is then refined linearly using
/// the 7 bits that follow the leading one bit.
fn sqrt_approx(x: i32) -> i32 {
    if x <= 0 {
        return 0;
    }
    let lz = x.leading_zeros() as i32;
    // Rotate so the 7 bits after the leading one land in bits 0..7; a negative
    // amount wraps modulo 32, which turns it into the equivalent left rotation.
    let rotation = (24 - lz).rem_euclid(32) as u32;
    let frac_q7 = ((x as u32).rotate_right(rotation) & 0x7f) as i32;
    // Odd counts seed with 32_768 (1.0 in Q15), even counts with 46_214 (roughly sqrt(2) in Q15).
    let seed = if lz % 2 == 1 { 32_768 } else { 46_214 };
    let estimate = seed >> (lz / 2);
    smlawb(estimate, estimate, smulbb(213, frac_q7))
}
#[cfg(test)]
mod tests {
    use super::process_gains;
    use crate::silk::decode_indices::{ConditionalCoding, SideInfoIndices};
    use crate::silk::encoder::control::EncoderControl;
    use crate::silk::encoder::state::{EncoderChannelState, EncoderStateCommon};
    use crate::silk::{FrameQuantizationOffsetType, FrameSignalType, MAX_NB_SUBFR};

    /// Builds an encoder state with `MAX_NB_SUBFR` subframes of 40 samples, an
    /// SNR of 20 dB (2_560 in Q7), the given signal type and a `Low`
    /// quantisation offset preset.
    fn encoder_state(signal_type: FrameSignalType) -> EncoderChannelState {
        let mut state = EncoderChannelState::default();
        state.common.nb_subfr = MAX_NB_SUBFR;
        state.common.subfr_length = 40;
        state.common.snr_db_q7 = 2_560;
        state.common.indices = SideInfoIndices {
            signal_type,
            quant_offset_type: FrameQuantizationOffsetType::Low,
            ..SideInfoIndices::default()
        };
        state
    }

    /// A high LTP coding gain on a voiced frame must shrink the unquantised
    /// gains below 1.0 (Q16) and keep the `Low` quantisation offset.
    #[test]
    fn voiced_frames_reduce_gains_when_ltp_gain_is_high() {
        let mut encoder = encoder_state(FrameSignalType::Voiced);
        encoder.common.input_tilt_q15 = 0;
        let mut control = EncoderControl::default();
        control.gains_q16 = [1 << 16; MAX_NB_SUBFR];
        control.lt_pred_cod_gain_q7 = 3_200;
        process_gains(&mut encoder, &mut control, ConditionalCoding::Independent);
        for gain in control.gains_unq_q16.iter() {
            assert!(
                *gain < 1 << 16,
                "voiced reduction should act on unquantised gains"
            );
        }
        assert_eq!(
            encoder.common.indices.quant_offset_type,
            FrameQuantizationOffsetType::Low
        );
    }

    /// A voiced frame whose LTP coding gain plus tilt contribution stays at or
    /// below 1.0 (Q7) must switch to the `High` quantisation offset.
    #[test]
    fn voiced_frames_use_high_quant_offset_when_ltp_gain_is_low() {
        let mut encoder = encoder_state(FrameSignalType::Voiced);
        encoder.common.input_tilt_q15 = 1 << 10;
        let mut control = EncoderControl::default();
        control.gains_q16 = [90_000; MAX_NB_SUBFR];
        control.lt_pred_cod_gain_q7 = 50;
        process_gains(&mut encoder, &mut control, ConditionalCoding::Independent);
        assert_eq!(
            encoder.common.indices.quant_offset_type,
            FrameQuantizationOffsetType::High
        );
    }

    /// With default side-info indices and conditional coding, lambda must stay
    /// below 2.0 (Q10), the stored last gain index must change from 0, and the
    /// quantised gains must not fall below 1.0 (Q16) for this residual energy.
    #[test]
    fn lambda_tracks_quality_and_speech_activity() {
        let mut encoder = EncoderChannelState::default();
        encoder.common = EncoderStateCommon {
            snr_db_q7: 0,
            subfr_length: 40,
            nb_subfr: MAX_NB_SUBFR,
            n_states_delayed_decision: 2,
            speech_activity_q8: 128,
            input_tilt_q15: 0,
            indices: SideInfoIndices::default(),
            ..EncoderStateCommon::default()
        };
        let mut control = EncoderControl::default();
        control.gains_q16 = [65_536; MAX_NB_SUBFR];
        control.res_nrg = [10_000; MAX_NB_SUBFR];
        control.res_nrg_q = [0; MAX_NB_SUBFR];
        control.input_quality_q14 = 1 << 13;
        control.coding_quality_q14 = 1 << 13;
        process_gains(&mut encoder, &mut control, ConditionalCoding::Conditional);
        assert!(control.lambda_q10 < 1 << 11);
        assert!(encoder.shape_state.last_gain_index != 0);
        for gain in control.gains_q16.iter() {
            assert!(*gain >= 65_536);
        }
    }
}