// opus_rs/silk/dec_api.rs

1use crate::range_coder::RangeCoder;
2use crate::silk::decode_frame::{FLAG_DECODE_NORMAL, FLAG_PACKET_LOST, silk_decode_frame};
3use crate::silk::decode_indices::{
4    silk_decode_indices, silk_stereo_decode_mid_only, silk_stereo_decode_pred,
5};
6use crate::silk::decode_pulses::silk_decode_pulses;
7use crate::silk::decoder_structs::SilkDecoderState;
8use crate::silk::define::*;
9use crate::silk::init_decoder::{silk_decoder_set_fs, silk_init_decoder};
10use crate::silk::tables::{SILK_LBRR_FLAGS_2_ICDF, SILK_LBRR_FLAGS_3_ICDF};
11
12pub struct SilkDecoder {
13    pub channel_state: [SilkDecoderState; 2],
14
15    pub n_channels_api: i32,
16
17    pub n_channels_internal: i32,
18
19    pub prev_decode_only_middle: i32,
20}
21
22impl Default for SilkDecoder {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl SilkDecoder {
29    pub fn new() -> Self {
30        let mut dec = Self {
31            channel_state: [SilkDecoderState::default(), SilkDecoderState::default()],
32            n_channels_api: 1,
33            n_channels_internal: 1,
34            prev_decode_only_middle: 0,
35        };
36        silk_init_decoder(&mut dec.channel_state[0]);
37        silk_init_decoder(&mut dec.channel_state[1]);
38        dec
39    }
40
41    pub fn init(&mut self, sample_rate_hz: i32, channels: i32) -> i32 {
42        let fs_khz = sample_rate_hz / 1000;
43        let ret = silk_decoder_set_fs(&mut self.channel_state[0], fs_khz, sample_rate_hz);
44        if ret < 0 {
45            return ret;
46        }
47        if channels == 2 {
48            let ret = silk_decoder_set_fs(&mut self.channel_state[1], fs_khz, sample_rate_hz);
49            if ret < 0 {
50                return ret;
51            }
52        }
53
54        self.channel_state[0].n_frames_per_packet = 1;
55        self.n_channels_api = channels;
56        self.n_channels_internal = channels;
57        ret
58    }
59
60    pub fn decode(
61        &mut self,
62        range_dec: &mut RangeCoder,
63        output: &mut [i16],
64        lost_flag: i32,
65        new_packet: bool,
66        payload_size_ms: i32,
67        internal_sample_rate: i32,
68    ) -> i32 {
69        if new_packet {
70            self.channel_state[0].n_frames_decoded = 0;
71            self.channel_state[1].n_frames_decoded = 0;
72        }
73
74        if self.channel_state[0].n_frames_decoded == 0 {
75            match payload_size_ms {
76                0 | 10 => {
77                    self.channel_state[0].n_frames_per_packet = 1;
78                    self.channel_state[0].nb_subfr = 2;
79                }
80                20 => {
81                    self.channel_state[0].n_frames_per_packet = 1;
82                    self.channel_state[0].nb_subfr = MAX_NB_SUBFR as i32;
83                }
84                40 => {
85                    self.channel_state[0].n_frames_per_packet = 2;
86                    self.channel_state[0].nb_subfr = MAX_NB_SUBFR as i32;
87                }
88                60 => {
89                    self.channel_state[0].n_frames_per_packet = 3;
90                    self.channel_state[0].nb_subfr = MAX_NB_SUBFR as i32;
91                }
92                _ => return -1,
93            }
94
95            let fs_khz_dec = (internal_sample_rate >> 10) + 1;
96            if fs_khz_dec != 8 && fs_khz_dec != 12 && fs_khz_dec != 16 {
97                return -1;
98            }
99            let api_sample_rate = self.channel_state[0].fs_api_hz;
100            let ret = silk_decoder_set_fs(&mut self.channel_state[0], fs_khz_dec, api_sample_rate);
101            if ret < 0 {
102                return ret;
103            }
104
105            if payload_size_ms == 10 {
106                self.channel_state[0].nb_subfr = 2;
107                self.channel_state[0].frame_length = self.channel_state[0].subfr_length * 2;
108            }
109        }
110
111        if lost_flag != FLAG_PACKET_LOST && self.channel_state[0].n_frames_decoded == 0 {
112            let n_frames_per_packet = self.channel_state[0].n_frames_per_packet.max(1);
113            let n_channels = self.n_channels_internal as usize;
114
115            for n in 0..n_channels {
116                for i in 0..n_frames_per_packet as usize {
117                    let vad = range_dec.decode_bit_logp(1);
118                    self.channel_state[n].vad_flags[i] = if vad { 1 } else { 0 };
119                }
120                let lbrr = range_dec.decode_bit_logp(1);
121                self.channel_state[n].lbrr_flag = if lbrr { 1 } else { 0 };
122            }
123
124            for n in 0..n_channels {
125                self.channel_state[n].lbrr_flags.fill(0);
126                if self.channel_state[n].lbrr_flag != 0 {
127                    if n_frames_per_packet == 1 {
128                        self.channel_state[n].lbrr_flags[0] = 1;
129                    } else {
130                        let lbrr_icdf = match n_frames_per_packet {
131                            2 => &SILK_LBRR_FLAGS_2_ICDF[..],
132                            3 => &SILK_LBRR_FLAGS_3_ICDF[..],
133                            _ => &SILK_LBRR_FLAGS_2_ICDF[..],
134                        };
135                        let lbrr_symbol = range_dec.decode_icdf(lbrr_icdf, 8) + 1;
136                        for i in 0..n_frames_per_packet as usize {
137                            self.channel_state[n].lbrr_flags[i] = (lbrr_symbol >> i) & 1;
138                        }
139                    }
140                }
141            }
142
143            // Skip LBRR data for all frames/channels
144            if lost_flag == FLAG_DECODE_NORMAL {
145                for i in 0..n_frames_per_packet as usize {
146                    for n in 0..n_channels {
147                        if self.channel_state[n].lbrr_flags[i] != 0 {
148                            if n_channels == 2 && n == 0 {
149                                silk_stereo_decode_pred(range_dec);
150                                if self.channel_state[1].lbrr_flags[i] == 0 {
151                                    silk_stereo_decode_mid_only(range_dec);
152                                }
153                            }
154                            let cond_coding =
155                                if i > 0 && self.channel_state[n].lbrr_flags[i - 1] != 0 {
156                                    CODE_CONDITIONALLY
157                                } else {
158                                    CODE_INDEPENDENTLY
159                                };
160                            silk_decode_indices(
161                                &mut self.channel_state[n],
162                                range_dec,
163                                i as i32,
164                                1,
165                                cond_coding,
166                            );
167                            let mut pulses = [0i16; MAX_FRAME_LENGTH];
168                            silk_decode_pulses(
169                                range_dec,
170                                &mut pulses,
171                                self.channel_state[n].indices.signal_type as i32,
172                                self.channel_state[n].indices.quant_offset_type as i32,
173                                self.channel_state[n].frame_length,
174                            );
175                        }
176                    }
177                }
178            }
179        }
180
181        let frame_index = self.channel_state[0].n_frames_decoded as usize;
182        if self.n_channels_internal == 2 && lost_flag == FLAG_DECODE_NORMAL {
183            silk_stereo_decode_pred(range_dec);
184            if self.channel_state[1].vad_flags[frame_index] == 0 {
185                silk_stereo_decode_mid_only(range_dec);
186            }
187        }
188        let cond_coding = if frame_index == 0 {
189            CODE_INDEPENDENTLY
190        } else {
191            CODE_CONDITIONALLY
192        };
193
194        let mut n_samples_out: i32 = 0;
195        let channel = &mut self.channel_state[0];
196        let ret = silk_decode_frame(
197            channel,
198            range_dec,
199            output,
200            &mut n_samples_out,
201            lost_flag,
202            cond_coding,
203        );
204
205        channel.n_frames_decoded += 1;
206
207        if ret < 0 { ret } else { n_samples_out }
208    }
209
210    pub fn decode_bytes(&mut self, data: &[u8], output: &mut [i16], new_packet: bool) -> i32 {
211        let mut range_dec = RangeCoder::new_decoder(data);
212        let internal_rate = self.channel_state[0].fs_khz * 1000;
213        let payload_ms = if self.channel_state[0].nb_subfr == 2 {
214            10
215        } else {
216            20
217        };
218        self.decode(
219            &mut range_dec,
220            output,
221            FLAG_DECODE_NORMAL,
222            new_packet,
223            payload_ms,
224            internal_rate,
225        )
226    }
227
228    pub fn reset(&mut self) {
229        silk_init_decoder(&mut self.channel_state[0]);
230        silk_init_decoder(&mut self.channel_state[1]);
231        self.prev_decode_only_middle = 0;
232    }
233
234    pub fn frame_length(&self) -> i32 {
235        self.channel_state[0].frame_length
236    }
237
238    pub fn sample_rate(&self) -> i32 {
239        self.channel_state[0].fs_khz * 1000
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a decoder and initialises it for 16 kHz mono, checking that `init` succeeds.
    fn mono_16k_decoder() -> SilkDecoder {
        let mut decoder = SilkDecoder::new();
        let status = decoder.init(16000, 1);
        assert_eq!(status, 0);
        decoder
    }

    /// A freshly constructed decoder reports a mono configuration.
    #[test]
    fn test_decoder_creation() {
        let decoder = SilkDecoder::new();
        assert_eq!(decoder.n_channels_api, 1);
        assert_eq!(decoder.n_channels_internal, 1);
    }

    /// After a 16 kHz init the decoder reports a 16 kHz sample rate.
    #[test]
    fn test_decoder_init() {
        let decoder = mono_16k_decoder();
        assert_eq!(decoder.sample_rate(), 16000);
    }

    /// After a 16 kHz init the frame length is 320 samples.
    #[test]
    fn test_decoder_16khz() {
        let decoder = mono_16k_decoder();
        assert_eq!(decoder.frame_length(), 320);
    }
}
270}