dg_xch_pos/finite_state_entropy/
decompress.rs

1use crate::constants::FSE_MAX_SYMBOL_VALUE;
2use crate::finite_state_entropy::bitstream::{highbit_32, BitDstream, BitDstreamStatus};
3use crate::finite_state_entropy::{
4    fse_dtable_size_u32, fse_tablestep, FSE_MAX_TABLELOG, FSE_TABLELOG_ABSOLUTE_MAX,
5};
6use std::io::{Cursor, Error, ErrorKind, Read};
7use std::sync::Arc;
8
/// Header of a decoding table.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DTableH {
    /// Base-2 logarithm of the number of decoding cells (`build_dtable` uses `1 << table_log`).
    pub table_log: u16,
    /// Non-zero when the fast symbol decoder may be used with this table.
    pub fast_mode: u16,
}

/// One cell of a decoding table.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DTableEntry {
    /// Base of the next state; the bits read from the stream are added to it.
    pub new_state: u16,
    /// Symbol emitted when the decoder is in this cell.
    pub symbol: u8,
    /// Number of bits to read from the stream when leaving this cell.
    pub nb_bits: u8,
}

/// FSE decoding table: a header followed by the decoding cells.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DTable {
    /// Table-wide parameters.
    pub header: DTableH,
    /// Decoding cells, indexed by decoder state.
    pub table: Vec<DTableEntry>,
}
27#[must_use]
28pub fn parse_d_table(bytes: &[u8]) -> DTable {
29    let mut cursor = Cursor::new(bytes);
30    let mut u16_buf: [u8; 2] = [0; 2];
31    let mut u8_buf: [u8; 1] = [0; 1];
32    cursor.read_exact(&mut u16_buf).unwrap();
33    let table_log = u16::from_le_bytes(u16_buf);
34    cursor.read_exact(&mut u16_buf).unwrap();
35    let fast_mode = u16::from_le_bytes(u16_buf);
36    let mut table = vec![];
37    let max_size = fse_dtable_size_u32(u32::from(table_log));
38    for _ in 0..max_size {
39        let new_state = match cursor.read_exact(&mut u16_buf) {
40            Ok(()) => u16::from_le_bytes(u16_buf),
41            Err(_) => 0,
42        };
43        let symbol = match cursor.read_exact(&mut u8_buf) {
44            Ok(()) => u8::from_le_bytes(u8_buf),
45            Err(_) => 0,
46        };
47        let nb_bits = match cursor.read_exact(&mut u8_buf) {
48            Ok(()) => u8::from_le_bytes(u8_buf),
49            Err(_) => 0,
50        };
51        table.push(DTableEntry {
52            new_state,
53            symbol,
54            nb_bits,
55        });
56    }
57    DTable {
58        header: DTableH {
59            table_log,
60            fast_mode,
61        },
62        table,
63    }
64}
65
66pub struct DState {
67    pub state: usize,
68    pub table: Arc<DTable>,
69}
70impl DState {
71    pub fn new(bit_d: &mut BitDstream, dt: Arc<DTable>) -> Self {
72        let state = bit_d.read_bits(u32::from(dt.header.table_log));
73        bit_d.reload();
74        DState { state, table: dt }
75    }
76}
77
78fn create_dtable(table_log: u32) -> DTable {
79    let size = if table_log > FSE_TABLELOG_ABSOLUTE_MAX as u32 {
80        fse_dtable_size_u32(FSE_TABLELOG_ABSOLUTE_MAX as u32)
81    } else {
82        fse_dtable_size_u32(table_log)
83    } as usize;
84    DTable {
85        header: DTableH {
86            table_log: 0,
87            fast_mode: 0,
88        },
89        table: vec![DTableEntry::default(); size],
90    }
91}
92
93#[allow(clippy::cast_possible_truncation)]
94#[allow(clippy::cast_sign_loss)]
95pub fn build_dtable(
96    normalized_counter: &[i16],
97    max_symbol_value: u32,
98    table_log: u32,
99) -> Result<DTable, Error> {
100    let mut dt = create_dtable(table_log);
101    let mut symbol_next = vec![0u16; (FSE_MAX_SYMBOL_VALUE + 1) as usize];
102    let max_sv1 = max_symbol_value + 1;
103    let table_size = 1 << table_log;
104
105    /* Sanity Checks */
106    if max_symbol_value > FSE_MAX_SYMBOL_VALUE {
107        return Err(Error::new(
108            ErrorKind::InvalidInput,
109            "max_symbol_value too large",
110        ));
111    }
112    if table_log > FSE_MAX_TABLELOG {
113        return Err(Error::new(ErrorKind::InvalidInput, "table_log too large"));
114    }
115
116    /* Init, lay down lowprob symbols */
117    dt.header.table_log = table_log as u16;
118    dt.header.fast_mode = 1;
119    let large_limit = (1 << (table_log - 1)) as i16;
120    let mut high_threshold = table_size - 1;
121    for (index, (normalize, symbol_next)) in normalized_counter
122        .iter()
123        .zip(symbol_next.iter_mut())
124        .enumerate()
125        .take(max_sv1 as usize)
126    {
127        if *normalize == -1 {
128            dt.table[high_threshold as usize].symbol = index as u8;
129            high_threshold -= 1;
130            *symbol_next = 1;
131        } else {
132            if *normalize >= large_limit {
133                dt.header.fast_mode = 0;
134            }
135            *symbol_next = *normalize as u16;
136        }
137    }
138    /* Spread symbols */
139    let table_mask = table_size - 1;
140    let step = fse_tablestep(table_size);
141    let mut position: u32 = 0;
142    for s in 0..max_sv1 {
143        for _ in 0..normalized_counter[s as usize] {
144            dt.table[position as usize].symbol = s as u8;
145            position = (position + step) & table_mask;
146            while position > high_threshold {
147                /* lowprob area */
148                position = (position + step) & table_mask;
149            }
150        }
151    }
152    if position != 0 {
153        /* position must reach all cells once, otherwise normalizedCounter is incorrect */
154        return Err(Error::new(
155            ErrorKind::InvalidInput,
156            "normalized_counter is incorrect",
157        ));
158    }
159    /* Build Decoding table */
160    for table in &mut dt.table[0..(table_size as usize)] {
161        let next_state = symbol_next[table.symbol as usize];
162        symbol_next[table.symbol as usize] += 1;
163        table.nb_bits = (table_log - highbit_32(u32::from(next_state))) as u8;
164        table.new_state = (u32::from(next_state << table.nb_bits) - table_size) as u16;
165    }
166    Ok(dt)
167}
168
169pub fn decompress_using_dtable(
170    mut dst: impl AsMut<[u8]>,
171    dst_size: usize,
172    src: impl AsRef<[u8]>,
173    src_size: usize,
174    dt: Arc<DTable>,
175) -> Result<usize, Error> {
176    let fast = dt.header.fast_mode > 0;
177    fse_decompress_using_dtable_generic(dst.as_mut(), dst_size, src.as_ref(), src_size, dt, fast)
178}
179
/// Strategy used by the decompression loop to decode one symbol.
///
/// The two implementations in this file differ only in which `BitDstream` read method
/// they call.
trait SymbolFn {
    /// Decodes one symbol for `state`, consuming bits from `bit_d`.
    ///
    /// The implementations in this file return the symbol of the cell `state` currently
    /// points at and move `state` to its next cell.
    fn decode_symbol(&self, state: &mut DState, bit_d: &mut BitDstream) -> u8;
}
183
184pub fn fse_decompress_using_dtable_generic(
185    dst: &mut [u8],
186    dst_size: usize,
187    src: &[u8],
188    src_size: usize,
189    dt: Arc<DTable>,
190    fast: bool,
191) -> Result<usize, Error> {
192    let mut bit_d = match BitDstream::new(src, src_size) {
193        Ok(b) => b,
194        Err(e) => {
195            return Err(e);
196        }
197    };
198    /* Init */
199    let mut index = 0;
200    let limit = dst_size - 3;
201    let mut state1 = DState::new(&mut bit_d, dt.clone());
202    let mut state2 = DState::new(&mut bit_d, dt);
203    let symbol_fn: Box<dyn SymbolFn> = if fast {
204        Box::new(FastDecodeSymbol {})
205    } else {
206        Box::new(DecodeSymbol {})
207    };
208    /* 4 symbols per loop */
209    while bit_d.reload().eq(BitDstreamStatus::Unfinished) & (index < limit) {
210        dst[index] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
211        if FSE_MAX_TABLELOG * 2 + 7 > usize::BITS {
212            bit_d.reload();
213        }
214        dst[index + 1] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
215        if FSE_MAX_TABLELOG * 4 + 7 > usize::BITS && bit_d.reload().gt(BitDstreamStatus::Unfinished)
216        {
217            index += 2;
218            break;
219        }
220        dst[index + 2] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
221        if FSE_MAX_TABLELOG * 2 + 7 > usize::BITS {
222            bit_d.reload();
223        }
224        dst[index + 3] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
225        index += 4;
226    }
227    loop {
228        if index > dst_size - 2 {
229            return Err(Error::new(ErrorKind::InvalidInput, "dst_size too small"));
230        }
231        dst[index] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
232        index += 1;
233        if bit_d.reload().eq(BitDstreamStatus::Overflow) {
234            dst[index] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
235            break;
236        }
237        if index > dst_size - 2 {
238            return Err(Error::new(ErrorKind::InvalidInput, "dst_size too small"));
239        }
240        dst[index] = symbol_fn.decode_symbol(&mut state2, &mut bit_d);
241        index += 1;
242        if bit_d.reload().eq(BitDstreamStatus::Overflow) {
243            dst[index] = symbol_fn.decode_symbol(&mut state1, &mut bit_d);
244            break;
245        }
246    }
247    Ok(index)
248}
249
250pub struct DecodeSymbol {}
251impl SymbolFn for DecodeSymbol {
252    fn decode_symbol(&self, state: &mut DState, bit_d: &mut BitDstream) -> u8 {
253        let entry = &state.table.table[state.state];
254        let low_bits: usize = bit_d.read_bits(u32::from(entry.nb_bits));
255        state.state = entry.new_state as usize + low_bits;
256        entry.symbol
257    }
258}
259
260// FSE_decodeSymbolFast():unsafe, only works if no symbol has a probability > 50%
261pub struct FastDecodeSymbol {}
262impl SymbolFn for FastDecodeSymbol {
263    fn decode_symbol(&self, state: &mut DState, bit_d: &mut BitDstream) -> u8 {
264        let entry = &state.table.table[state.state];
265        let low_bits: usize = bit_d.read_bits_fast(u32::from(entry.nb_bits));
266        state.state = entry.new_state as usize + low_bits;
267        entry.symbol
268    }
269}