dg_xch_pos/finite_state_entropy/
decompress.rs1use crate::constants::FSE_MAX_SYMBOL_VALUE;
2use crate::finite_state_entropy::bitstream::{highbit_32, BitDstream, BitDstreamStatus};
3use crate::finite_state_entropy::{
4 fse_dtable_size_u32, fse_tablestep, FSE_MAX_TABLELOG, FSE_TABLELOG_ABSOLUTE_MAX,
5};
6use std::io::{Cursor, Error, ErrorKind, Read};
7use std::sync::Arc;
8
/// Header of a decoding table.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DTableH {
    /// Base-2 logarithm of the number of decoding cells (`build_dtable` uses
    /// `1 << table_log` cells).
    pub table_log: u16,
    /// Non-zero when the fast bit reader may be used; `decompress_using_dtable`
    /// selects the fast symbol decoder when this is greater than zero.
    pub fast_mode: u16,
}

/// One cell of a decoding table.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DTableEntry {
    /// Base value of the next state; the bits read from the stream are added to it.
    pub new_state: u16,
    /// Symbol emitted when the decoder is in this state.
    pub symbol: u8,
    /// Number of bits to read from the stream to reach the next state.
    pub nb_bits: u8,
}

/// A complete decoding table: a header plus its cells.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DTable {
    /// Table-wide parameters.
    pub header: DTableH,
    /// Decoding cells, indexed by decoder state.
    pub table: Vec<DTableEntry>,
}
27#[must_use]
28pub fn parse_d_table(bytes: &[u8]) -> DTable {
29 let mut cursor = Cursor::new(bytes);
30 let mut u16_buf: [u8; 2] = [0; 2];
31 let mut u8_buf: [u8; 1] = [0; 1];
32 cursor.read_exact(&mut u16_buf).unwrap();
33 let table_log = u16::from_le_bytes(u16_buf);
34 cursor.read_exact(&mut u16_buf).unwrap();
35 let fast_mode = u16::from_le_bytes(u16_buf);
36 let mut table = vec![];
37 let max_size = fse_dtable_size_u32(u32::from(table_log));
38 for _ in 0..max_size {
39 let new_state = match cursor.read_exact(&mut u16_buf) {
40 Ok(()) => u16::from_le_bytes(u16_buf),
41 Err(_) => 0,
42 };
43 let symbol = match cursor.read_exact(&mut u8_buf) {
44 Ok(()) => u8::from_le_bytes(u8_buf),
45 Err(_) => 0,
46 };
47 let nb_bits = match cursor.read_exact(&mut u8_buf) {
48 Ok(()) => u8::from_le_bytes(u8_buf),
49 Err(_) => 0,
50 };
51 table.push(DTableEntry {
52 new_state,
53 symbol,
54 nb_bits,
55 });
56 }
57 DTable {
58 header: DTableH {
59 table_log,
60 fast_mode,
61 },
62 table,
63 }
64}
65
66pub struct DState {
67 pub state: usize,
68 pub table: Arc<DTable>,
69}
70impl DState {
71 pub fn new(bit_d: &mut BitDstream, dt: Arc<DTable>) -> Self {
72 let state = bit_d.read_bits(u32::from(dt.header.table_log));
73 bit_d.reload();
74 DState { state, table: dt }
75 }
76}
77
78fn create_dtable(table_log: u32) -> DTable {
79 let size = if table_log > FSE_TABLELOG_ABSOLUTE_MAX as u32 {
80 fse_dtable_size_u32(FSE_TABLELOG_ABSOLUTE_MAX as u32)
81 } else {
82 fse_dtable_size_u32(table_log)
83 } as usize;
84 DTable {
85 header: DTableH {
86 table_log: 0,
87 fast_mode: 0,
88 },
89 table: vec![DTableEntry::default(); size],
90 }
91}
92
93#[allow(clippy::cast_possible_truncation)]
94#[allow(clippy::cast_sign_loss)]
95pub fn build_dtable(
96 normalized_counter: &[i16],
97 max_symbol_value: u32,
98 table_log: u32,
99) -> Result<DTable, Error> {
100 let mut dt = create_dtable(table_log);
101 let mut symbol_next = vec![0u16; (FSE_MAX_SYMBOL_VALUE + 1) as usize];
102 let max_sv1 = max_symbol_value + 1;
103 let table_size = 1 << table_log;
104
105 if max_symbol_value > FSE_MAX_SYMBOL_VALUE {
107 return Err(Error::new(
108 ErrorKind::InvalidInput,
109 "max_symbol_value too large",
110 ));
111 }
112 if table_log > FSE_MAX_TABLELOG {
113 return Err(Error::new(ErrorKind::InvalidInput, "table_log too large"));
114 }
115
116 dt.header.table_log = table_log as u16;
118 dt.header.fast_mode = 1;
119 let large_limit = (1 << (table_log - 1)) as i16;
120 let mut high_threshold = table_size - 1;
121 for (index, (normalize, symbol_next)) in normalized_counter
122 .iter()
123 .zip(symbol_next.iter_mut())
124 .enumerate()
125 .take(max_sv1 as usize)
126 {
127 if *normalize == -1 {
128 dt.table[high_threshold as usize].symbol = index as u8;
129 high_threshold -= 1;
130 *symbol_next = 1;
131 } else {
132 if *normalize >= large_limit {
133 dt.header.fast_mode = 0;
134 }
135 *symbol_next = *normalize as u16;
136 }
137 }
138 let table_mask = table_size - 1;
140 let step = fse_tablestep(table_size);
141 let mut position: u32 = 0;
142 for s in 0..max_sv1 {
143 for _ in 0..normalized_counter[s as usize] {
144 dt.table[position as usize].symbol = s as u8;
145 position = (position + step) & table_mask;
146 while position > high_threshold {
147 position = (position + step) & table_mask;
149 }
150 }
151 }
152 if position != 0 {
153 return Err(Error::new(
155 ErrorKind::InvalidInput,
156 "normalized_counter is incorrect",
157 ));
158 }
159 for table in &mut dt.table[0..(table_size as usize)] {
161 let next_state = symbol_next[table.symbol as usize];
162 symbol_next[table.symbol as usize] += 1;
163 table.nb_bits = (table_log - highbit_32(u32::from(next_state))) as u8;
164 table.new_state = (u32::from(next_state << table.nb_bits) - table_size) as u16;
165 }
166 Ok(dt)
167}
168
169pub fn decompress_using_dtable(
170 mut dst: impl AsMut<[u8]>,
171 dst_size: usize,
172 src: impl AsRef<[u8]>,
173 src_size: usize,
174 dt: Arc<DTable>,
175) -> Result<usize, Error> {
176 let fast = dt.header.fast_mode > 0;
177 fse_decompress_using_dtable_generic(dst.as_mut(), dst_size, src.as_ref(), src_size, dt, fast)
178}
179
180trait SymbolFn {
181 fn decode_symbol(&self, state: &mut DState, bit_d: &mut BitDstream) -> u8;
182}
183
184pub fn fse_decompress_using_dtable_generic(
185 dst: &mut [u8],
186 dst_size: usize,
187 src: &[u8],
188 src_size: usize,
189 dt: Arc<DTable>,
190 fast: bool,
191) -> Result<usize, Error> {
192 let mut bit_d = match BitDstream::new(src, src_size) {
193 Ok(b) => b,
194 Err(e) => {
195 return Err(e);
196 }
197 };
198 let mut index = 0;
200 let limit = dst_size - 3;
201 let mut state1 = DState::new(&mut bit_d, dt.clone());
202 let mut state2 = DState::new(&mut bit_d, dt);
203 let symbol_fn: Box<dyn SymbolFn> = if fast {
204 Box::new(FastDecodeSymbol {})
205 } else {
206 Box::new(DecodeSymbol {})
207 };
208 while bit_d.reload().eq(BitDstreamStatus::Unfinished) & (index < limit) {
210 dst[index] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
211 if FSE_MAX_TABLELOG * 2 + 7 > usize::BITS {
212 bit_d.reload();
213 }
214 dst[index + 1] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
215 if FSE_MAX_TABLELOG * 4 + 7 > usize::BITS && bit_d.reload().gt(BitDstreamStatus::Unfinished)
216 {
217 index += 2;
218 break;
219 }
220 dst[index + 2] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
221 if FSE_MAX_TABLELOG * 2 + 7 > usize::BITS {
222 bit_d.reload();
223 }
224 dst[index + 3] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
225 index += 4;
226 }
227 loop {
228 if index > dst_size - 2 {
229 return Err(Error::new(ErrorKind::InvalidInput, "dst_size too small"));
230 }
231 dst[index] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
232 index += 1;
233 if bit_d.reload().eq(BitDstreamStatus::Overflow) {
234 dst[index] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
235 break;
236 }
237 if index > dst_size - 2 {
238 return Err(Error::new(ErrorKind::InvalidInput, "dst_size too small"));
239 }
240 dst[index] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
241 index += 1;
242 if bit_d.reload().eq(BitDstreamStatus::Overflow) {
243 dst[index] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
244 break;
245 }
246 }
247 Ok(index)
248}
249
250pub struct DecodeSymbol {}
251impl SymbolFn for DecodeSymbol {
252 fn decode_symbol(&self, state: &mut DState, bit_d: &mut BitDstream) -> u8 {
253 let entry = &state.table.table[state.state];
254 let low_bits: usize = bit_d.read_bits(u32::from(entry.nb_bits));
255 state.state = entry.new_state as usize + low_bits;
256 entry.symbol
257 }
258}
259
260pub struct FastDecodeSymbol {}
262impl SymbolFn for FastDecodeSymbol {
263 fn decode_symbol(&self, state: &mut DState, bit_d: &mut BitDstream) -> u8 {
264 let entry = &state.table.table[state.state];
265 let low_bits: usize = bit_d.read_bits_fast(u32::from(entry.nb_bits));
266 state.state = entry.new_state as usize + low_bits;
267 entry.symbol
268 }
269}