use crate::range::RangeEncoder;
use crate::silk::SilkRangeDecoder;
use crate::silk::tables_pulses_per_block::SILK_SIGN_ICDF;
/// Base-2 logarithm of [`SHELL_CODEC_FRAME_LENGTH`].
const LOG2_SHELL_CODEC_FRAME_LENGTH: usize = 4;
/// Number of pulses in one shell block (16).
const SHELL_CODEC_FRAME_LENGTH: usize = 1 << LOG2_SHELL_CODEC_FRAME_LENGTH;
/// Maps a pulse value to the sign symbol written by the range coder:
/// `0` for negative values and `1` for zero or positive values.
#[inline]
fn silk_enc_map(value: i16) -> usize {
    let is_non_negative = value >= 0;
    usize::from(is_non_negative)
}
/// Converts a decoded sign symbol into a multiplier: `0` yields `-1` and `1` yields `+1`.
///
/// Only the symbols `0` and `1` are valid; this is checked in debug builds.
#[inline]
fn silk_dec_map(symbol: usize) -> i16 {
    debug_assert!(symbol <= 1);
    let sign_bit = symbol as i16;
    (sign_bit << 1) - 1
}
/// Returns the starting offset in `SILK_SIGN_ICDF` of the seven-entry group selected by
/// the signal type and quantization offset type (group number
/// `(signal_type << 1) + quant_offset_type`).
///
/// In debug builds this asserts that a full group of seven entries starts at the offset.
#[inline]
fn sign_icdf_base(signal_type: i32, quant_offset_type: i32) -> usize {
    let group = (signal_type << 1) + quant_offset_type;
    let offset = group * 7;
    debug_assert!(
        (0..=(SILK_SIGN_ICDF.len() as i32 - 7)).contains(&offset),
        "invalid range-coder context for pulse signs"
    );
    offset as usize
}
/// Number of shell blocks used to code `frame_length` pulses.
///
/// This is `frame_length / SHELL_CODEC_FRAME_LENGTH` rounded to the nearest integer, with
/// exact halves rounding up. For example, 120 pulses give 8 blocks (128 pulse slots).
fn number_of_shell_blocks(frame_length: usize) -> usize {
    let biased_length = frame_length + SHELL_CODEC_FRAME_LENGTH / 2;
    biased_length >> LOG2_SHELL_CODEC_FRAME_LENGTH
}
/// Range-encodes the sign of every non-zero pulse in a SILK frame.
///
/// The pulses are processed in shell blocks of `SHELL_CODEC_FRAME_LENGTH` entries.
/// `frame_length` is converted to a block count by rounding to the nearest whole block.
///
/// A block is skipped when its `sum_pulses` entry is `<= 0`. Otherwise one binary symbol
/// is written for each non-zero pulse in the block:
/// - `0` for a negative pulse, `1` for a positive pulse.
/// - The two-entry iCDF is picked from `SILK_SIGN_ICDF`.
/// - The row comes from `signal_type` and `quant_offset_type`.
/// - The entry within the row comes from the low five bits of the block's `sum_pulses`
///   value, saturated at 6.
///
/// The final block is clipped to the end of `pulses`, not to `frame_length`. When
/// `frame_length` is not a multiple of the block size (120 pulses → 8 blocks = 128 slots),
/// pass a buffer padded to whole shell blocks. Then every pulse that was shell-coded also
/// gets its sign coded, as the reference decoder expects. An unpadded buffer of exactly
/// `frame_length` entries is still accepted.
///
/// # Panics
///
/// Panics if `pulses` is shorter than `frame_length`, or if `sum_pulses` has fewer
/// entries than the number of shell blocks.
pub fn silk_encode_signs(
    encoder: &mut RangeEncoder,
    pulses: &[i8],
    frame_length: usize,
    signal_type: i32,
    quant_offset_type: i32,
    sum_pulses: &[i32],
) {
    assert!(
        frame_length <= pulses.len(),
        "pulse buffer shorter than frame length"
    );
    let num_blocks = number_of_shell_blocks(frame_length);
    assert!(
        sum_pulses.len() >= num_blocks,
        "sum_pulses slice shorter than required shell blocks"
    );
    let icdf_ptr = &SILK_SIGN_ICDF[sign_icdf_base(signal_type, quant_offset_type)..];
    // `chunks` covers the whole buffer, so the last block is bounded by the buffer length
    // rather than truncated at `frame_length`. Because `frame_length` rounds to
    // `num_blocks`, the buffer always yields at least `num_blocks` chunks.
    for (block, &total) in pulses
        .chunks(SHELL_CODEC_FRAME_LENGTH)
        .zip(sum_pulses)
        .take(num_blocks)
    {
        if total <= 0 {
            continue;
        }
        // Only the low five bits select the probability, capped at the table's last entry.
        let table_index = ((total & 0x1F) as usize).min(6);
        let icdf = [icdf_ptr[table_index], 0];
        for &pulse in block.iter().filter(|&&pulse| pulse != 0) {
            encoder.encode_icdf(silk_enc_map(pulse.into()), &icdf, 8);
        }
    }
}
/// Range-decodes the signs of the pulses in a SILK frame and applies them in place.
///
/// On entry, `pulses` holds pulse magnitudes. The pulses are processed in shell blocks of
/// `SHELL_CODEC_FRAME_LENGTH` entries. `frame_length` is converted to a block count by
/// rounding to the nearest whole block.
///
/// A block is skipped when its `sum_pulses` entry is `<= 0`. Otherwise one binary symbol
/// is read for each strictly positive entry in the block:
/// - Symbol `0` negates the entry; symbol `1` leaves it unchanged.
/// - The two-entry iCDF is picked from `SILK_SIGN_ICDF`.
/// - The row comes from `signal_type` and `quant_offset_type`.
/// - The entry within the row comes from the low five bits of the block's `sum_pulses`
///   value, saturated at 6.
///
/// The final block is clipped to the end of `pulses`, not to `frame_length`. When
/// `frame_length` is not a multiple of the block size (120 pulses → 8 blocks = 128 slots),
/// pass the buffer holding all decoded shell-block magnitudes. Then signs are read for the
/// padding pulses too, keeping the range decoder in step with the encoder. An unpadded
/// buffer of exactly `frame_length` entries is still accepted.
///
/// # Panics
///
/// Panics if `pulses` is shorter than `frame_length`, or if `sum_pulses` has fewer
/// entries than the number of shell blocks.
pub fn silk_decode_signs(
    decoder: &mut impl SilkRangeDecoder,
    pulses: &mut [i16],
    frame_length: usize,
    signal_type: i32,
    quant_offset_type: i32,
    sum_pulses: &[i32],
) {
    assert!(
        frame_length <= pulses.len(),
        "pulse buffer shorter than frame length"
    );
    let num_blocks = number_of_shell_blocks(frame_length);
    assert!(
        sum_pulses.len() >= num_blocks,
        "sum_pulses slice shorter than required shell blocks"
    );
    let icdf_ptr = &SILK_SIGN_ICDF[sign_icdf_base(signal_type, quant_offset_type)..];
    // `chunks_mut` covers the whole buffer, so the last block is bounded by the buffer
    // length rather than truncated at `frame_length`. Because `frame_length` rounds to
    // `num_blocks`, the buffer always yields at least `num_blocks` chunks.
    for (block, &total) in pulses
        .chunks_mut(SHELL_CODEC_FRAME_LENGTH)
        .zip(sum_pulses)
        .take(num_blocks)
    {
        if total <= 0 {
            continue;
        }
        // Only the low five bits select the probability, capped at the table's last entry.
        let table_index = ((total & 0x1F) as usize).min(6);
        let icdf = [icdf_ptr[table_index], 0];
        for pulse in block.iter_mut().filter(|pulse| **pulse > 0) {
            let symbol = decoder.decode_icdf(&icdf, 8);
            *pulse *= silk_dec_map(symbol);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::celt::EcDec;
    use alloc::{vec, vec::Vec};

    /// Builds a `sum_pulses` vector for `pulses`: one entry per shell block, each holding
    /// the sum of the absolute pulse values in that block.
    fn sums_for_blocks(pulses: &[i8]) -> Vec<i32> {
        let block_count = number_of_shell_blocks(pulses.len());
        pulses
            .chunks(SHELL_CODEC_FRAME_LENGTH)
            .take(block_count)
            .map(|block| {
                block
                    .iter()
                    .map(|&pulse| i32::from(pulse.abs()))
                    .sum::<i32>()
            })
            .collect()
    }

    /// Signs written by the encoder are restored by the decoder from magnitudes alone.
    #[test]
    fn encode_decode_roundtrip() {
        let frame_length = 32;
        let (signal_type, quant_offset_type) = (2, 0);
        let pulses: [i8; 32] = [
            3, -1, 0, 2, -2, 0, 1, -4, 0, 0, 2, -1, 0, 1, 0, -1, -2, 1, 0, -1, 3, 0, 0, -2, 0, 1,
            0, 2, -1, 0, 1, 0,
        ];
        let sums = sums_for_blocks(&pulses);

        let mut encoder = RangeEncoder::new();
        silk_encode_signs(
            &mut encoder,
            &pulses,
            frame_length,
            signal_type,
            quant_offset_type,
            &sums,
        );
        let mut encoded = encoder.finish();

        // The decoder starts from magnitudes only; sign decoding must restore the input.
        let mut magnitudes: Vec<i16> = pulses
            .iter()
            .map(|&pulse| i16::from(pulse.abs()))
            .collect();
        let mut decoder = EcDec::new(encoded.as_mut_slice());
        silk_decode_signs(
            &mut decoder,
            &mut magnitudes,
            frame_length,
            signal_type,
            quant_offset_type,
            &sums,
        );

        let reconstructed: Vec<i8> = magnitudes.iter().map(|&signed| signed as i8).collect();
        assert_eq!(reconstructed, pulses);
    }

    /// A frame whose only block has a zero sum produces an empty bitstream, and decoding
    /// it leaves the magnitudes untouched.
    #[test]
    fn zero_sum_blocks_emit_no_bits() {
        let frame_length = SHELL_CODEC_FRAME_LENGTH;
        let (signal_type, quant_offset_type) = (0, 1);
        let pulses = [0i8; SHELL_CODEC_FRAME_LENGTH];
        let sums = vec![0];

        let mut encoder = RangeEncoder::new();
        silk_encode_signs(
            &mut encoder,
            &pulses,
            frame_length,
            signal_type,
            quant_offset_type,
            &sums,
        );
        let mut encoded = encoder.finish();
        assert!(encoded.is_empty());

        let mut magnitudes = [0i16; SHELL_CODEC_FRAME_LENGTH];
        let mut decoder = EcDec::new(encoded.as_mut_slice());
        silk_decode_signs(
            &mut decoder,
            &mut magnitudes,
            frame_length,
            signal_type,
            quant_offset_type,
            &sums,
        );
        assert!(magnitudes.iter().all(|&value| value == 0));
    }
}