Skip to main content

opus_rs/silk/
dec_api.rs

1use crate::range_coder::RangeCoder;
2use crate::silk::decode_frame::{FLAG_DECODE_NORMAL, FLAG_PACKET_LOST, silk_decode_frame};
3use crate::silk::decode_indices::silk_decode_indices;
4use crate::silk::decode_pulses::silk_decode_pulses;
5use crate::silk::decoder_structs::SilkDecoderState;
6use crate::silk::define::*;
7use crate::silk::init_decoder::{silk_decoder_set_fs, silk_init_decoder};
8use crate::silk::tables::{SILK_LBRR_FLAGS_2_ICDF, SILK_LBRR_FLAGS_3_ICDF};
9
10pub struct SilkDecoder {
11    pub channel_state: [SilkDecoderState; 2],
12
13    pub n_channels_api: i32,
14
15    pub n_channels_internal: i32,
16
17    pub prev_decode_only_middle: i32,
18}
19
20impl Default for SilkDecoder {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl SilkDecoder {
27    pub fn new() -> Self {
28        let mut dec = Self {
29            channel_state: [SilkDecoderState::default(), SilkDecoderState::default()],
30            n_channels_api: 1,
31            n_channels_internal: 1,
32            prev_decode_only_middle: 0,
33        };
34        silk_init_decoder(&mut dec.channel_state[0]);
35        silk_init_decoder(&mut dec.channel_state[1]);
36        dec
37    }
38
39    pub fn init(&mut self, sample_rate_hz: i32, channels: i32) -> i32 {
40        let fs_khz = sample_rate_hz / 1000;
41        let ret = silk_decoder_set_fs(&mut self.channel_state[0], fs_khz, sample_rate_hz);
42        if ret < 0 {
43            return ret;
44        }
45        if channels == 2 {
46            let ret = silk_decoder_set_fs(&mut self.channel_state[1], fs_khz, sample_rate_hz);
47            if ret < 0 {
48                return ret;
49            }
50        }
51
52        self.channel_state[0].n_frames_per_packet = 1;
53        self.n_channels_api = channels;
54        self.n_channels_internal = channels;
55        ret
56    }
57
58    pub fn decode(
59        &mut self,
60        range_dec: &mut RangeCoder,
61        output: &mut [i16],
62        lost_flag: i32,
63        new_packet: bool,
64        payload_size_ms: i32,
65        internal_sample_rate: i32,
66    ) -> i32 {
67        if new_packet {
68            self.channel_state[0].n_frames_decoded = 0;
69            self.channel_state[1].n_frames_decoded = 0;
70        }
71
72        if self.channel_state[0].n_frames_decoded == 0 {
73            match payload_size_ms {
74                0 | 10 => {
75                    self.channel_state[0].n_frames_per_packet = 1;
76                    self.channel_state[0].nb_subfr = 2;
77                }
78                20 => {
79                    self.channel_state[0].n_frames_per_packet = 1;
80                    self.channel_state[0].nb_subfr = MAX_NB_SUBFR as i32;
81                }
82                40 => {
83                    self.channel_state[0].n_frames_per_packet = 2;
84                    self.channel_state[0].nb_subfr = MAX_NB_SUBFR as i32;
85                }
86                60 => {
87                    self.channel_state[0].n_frames_per_packet = 3;
88                    self.channel_state[0].nb_subfr = MAX_NB_SUBFR as i32;
89                }
90                _ => return -1,
91            }
92
93            let fs_khz_dec = (internal_sample_rate >> 10) + 1;
94            if fs_khz_dec != 8 && fs_khz_dec != 12 && fs_khz_dec != 16 {
95                return -1;
96            }
97            let api_sample_rate = self.channel_state[0].fs_api_hz;
98            let ret = silk_decoder_set_fs(&mut self.channel_state[0], fs_khz_dec, api_sample_rate);
99            if ret < 0 {
100                return ret;
101            }
102
103            if payload_size_ms == 10 {
104                self.channel_state[0].nb_subfr = 2;
105                self.channel_state[0].frame_length = self.channel_state[0].subfr_length * 2;
106            }
107        }
108
109        if lost_flag != FLAG_PACKET_LOST && self.channel_state[0].n_frames_decoded == 0 {
110            let n_frames_per_packet = self.channel_state[0].n_frames_per_packet.max(1);
111
112            for i in 0..n_frames_per_packet as usize {
113                let vad = range_dec.decode_bit_logp(1);
114                self.channel_state[0].vad_flags[i] = if vad { 1 } else { 0 };
115            }
116
117            let lbrr = range_dec.decode_bit_logp(1);
118            self.channel_state[0].lbrr_flag = if lbrr { 1 } else { 0 };
119
120            self.channel_state[0].lbrr_flags.fill(0);
121            if self.channel_state[0].lbrr_flag != 0 {
122                if n_frames_per_packet == 1 {
123                    self.channel_state[0].lbrr_flags[0] = 1;
124                } else {
125                    let lbrr_icdf = match n_frames_per_packet {
126                        2 => &SILK_LBRR_FLAGS_2_ICDF[..],
127                        3 => &SILK_LBRR_FLAGS_3_ICDF[..],
128                        _ => &SILK_LBRR_FLAGS_2_ICDF[..],
129                    };
130                    let lbrr_symbol = range_dec.decode_icdf(lbrr_icdf, 8) + 1;
131                    for i in 0..n_frames_per_packet as usize {
132                        self.channel_state[0].lbrr_flags[i] = (lbrr_symbol >> i) & 1;
133                    }
134                }
135            }
136
137            if lost_flag == FLAG_DECODE_NORMAL {
138                for i in 0..n_frames_per_packet as usize {
139                    if self.channel_state[0].lbrr_flags[i] != 0 {
140                        let cond_coding = if i > 0 && self.channel_state[0].lbrr_flags[i - 1] != 0 {
141                            CODE_CONDITIONALLY
142                        } else {
143                            CODE_INDEPENDENTLY
144                        };
145
146                        silk_decode_indices(
147                            &mut self.channel_state[0],
148                            range_dec,
149                            i as i32,
150                            1,
151                            cond_coding,
152                        );
153
154                        let mut pulses = [0i16; MAX_FRAME_LENGTH];
155                        silk_decode_pulses(
156                            range_dec,
157                            &mut pulses,
158                            self.channel_state[0].indices.signal_type as i32,
159                            self.channel_state[0].indices.quant_offset_type as i32,
160                            self.channel_state[0].frame_length,
161                        );
162                    }
163                }
164            }
165        }
166
167        let mut n_samples_out: i32 = 0;
168        let frame_index = self.channel_state[0].n_frames_decoded;
169        let cond_coding = if frame_index == 0 {
170            CODE_INDEPENDENTLY
171        } else {
172            CODE_CONDITIONALLY
173        };
174
175        let channel = &mut self.channel_state[0];
176        let ret = silk_decode_frame(
177            channel,
178            range_dec,
179            output,
180            &mut n_samples_out,
181            lost_flag,
182            cond_coding,
183        );
184
185        channel.n_frames_decoded += 1;
186
187        if ret < 0 { ret } else { n_samples_out }
188    }
189
190    pub fn decode_bytes(&mut self, data: &[u8], output: &mut [i16], new_packet: bool) -> i32 {
191        let mut range_dec = RangeCoder::new_decoder(data);
192        let internal_rate = self.channel_state[0].fs_khz * 1000;
193        let payload_ms = if self.channel_state[0].nb_subfr == 2 {
194            10
195        } else {
196            20
197        };
198        self.decode(
199            &mut range_dec,
200            output,
201            FLAG_DECODE_NORMAL,
202            new_packet,
203            payload_ms,
204            internal_rate,
205        )
206    }
207
208    pub fn reset(&mut self) {
209        silk_init_decoder(&mut self.channel_state[0]);
210        silk_init_decoder(&mut self.channel_state[1]);
211        self.prev_decode_only_middle = 0;
212    }
213
214    pub fn frame_length(&self) -> i32 {
215        self.channel_state[0].frame_length
216    }
217
218    pub fn sample_rate(&self) -> i32 {
219        self.channel_state[0].fs_khz * 1000
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decoder_creation() {
        let dec = SilkDecoder::new();
        assert_eq!(dec.n_channels_api, 1);
        assert_eq!(dec.n_channels_internal, 1);
    }

    #[test]
    fn test_decoder_init() {
        let mut dec = SilkDecoder::new();

        let ret = dec.init(16000, 1);
        assert_eq!(ret, 0);
        assert_eq!(dec.sample_rate(), 16000);
    }

    #[test]
    fn test_decoder_16khz() {
        let mut dec = SilkDecoder::new();
        let ret = dec.init(16000, 1);
        assert_eq!(ret, 0);
        assert_eq!(dec.frame_length(), 320);
    }

    /// Unsupported packet durations are rejected before any bits are read.
    #[test]
    fn test_decode_rejects_unsupported_payload_duration() {
        let mut dec = SilkDecoder::new();
        assert_eq!(dec.init(16000, 1), 0);

        let data = [0u8; 16];
        let mut range_dec = RangeCoder::new_decoder(&data[..]);
        let mut output = vec![0i16; 320];
        let ret = dec.decode(&mut range_dec, &mut output, FLAG_DECODE_NORMAL, true, 25, 16000);
        assert_eq!(ret, -1);
    }

    /// Internal rates that do not map to 8/12/16 kHz are rejected.
    #[test]
    fn test_decode_rejects_unsupported_internal_rate() {
        let mut dec = SilkDecoder::new();
        assert_eq!(dec.init(16000, 1), 0);

        let data = [0u8; 16];
        let mut range_dec = RangeCoder::new_decoder(&data[..]);
        let mut output = vec![0i16; 320];
        let ret = dec.decode(&mut range_dec, &mut output, FLAG_DECODE_NORMAL, true, 20, 48000);
        assert_eq!(ret, -1);
    }
}