1use {
2 firestorm::{profile_fn, profile_method},
3 std::{
4 convert::{Infallible, TryInto},
5 error::Error,
6 fmt,
7 hint::unreachable_unchecked,
8 },
9};
10
/// One Simple16 layout: how the 28 payload bits of a packed word are split into value slots.
struct Case {
    /// Number of value slots in a word that uses this layout.
    count: usize,
    /// Bit width of each slot, in packing order; entries at index `count` and beyond are 0.
    bits: [u32; 28],
}
15
/// The 16 layouts, indexed by the 4-bit selector stored in the top bits of each packed word.
///
/// Within every entry the widths of the first `count` slots add up to 28. Entries are ordered
/// from the most values per word (28 one-bit slots) to the fewest (a single 28-bit slot), and
/// the encoder tries them in this order.
const CASES: [Case; 16] = [
    Case {
        count: 28,
        bits: [
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ],
    },
    Case {
        count: 21,
        bits: [
            2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 21,
        bits: [
            1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 21,
        bits: [
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 14,
        bits: [
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 9,
        bits: [
            4, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 8,
        bits: [
            3, 4, 4, 4, 4, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 7,
        bits: [
            4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 6,
        bits: [
            5, 5, 5, 5, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 6,
        bits: [
            4, 4, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 5,
        bits: [
            6, 6, 6, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 5,
        bits: [
            5, 5, 6, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 4,
        bits: [
            7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 3,
        bits: [
            10, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 2,
        bits: [
            14, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
    Case {
        count: 1,
        bits: [
            28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ],
    },
];
114
/// Error returned by the checked entry points when an input value is greater than [`MAX`].
///
/// The private unit field keeps code outside this module from constructing the error itself.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct ValueOutOfRange(());
117
// Marker impl: the error has no source or extra context, so the default `Error` methods are
// used unchanged.
impl Error for ValueOutOfRange {}
119
120impl fmt::Display for ValueOutOfRange {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 write!(
123 f,
124 "Value out of range for simple16. Maximum value is 268435455"
125 )
126 }
127}
128
129fn pack<T: Simple16>(values: &[T]) -> (u32, usize) {
130 unsafe {
131 let mut i = 0;
132 'try_again: loop {
133 let mut value = i << 28;
134 let Case { mut count, bits } = CASES.get_unchecked(i as usize);
135 count = count.min(values.len());
136 let mut packed = 0;
137 for j in 0..count {
138 let v = values.get_unchecked(j).as_();
139 let bits_j = *bits.get_unchecked(j);
140 if v >= 1 << bits_j {
141 i += 1;
142 continue 'try_again;
143 }
144 value |= v << packed;
145 packed += bits_j;
146 }
147 return (value, count);
148 }
149 }
150}
151
152fn consume<T: Simple16>(values: &[T]) -> usize {
153 unsafe {
154 let mut i = 0;
155 'try_again: loop {
156 let Case { mut count, bits } = CASES.get_unchecked(i as usize);
157 count = count.min(values.len());
158
159 for j in 0..count {
160 let values_j = values.get_unchecked(j).as_();
161 if values_j >= (1u32 << bits.get_unchecked(j)) {
162 i += 1;
163 continue 'try_again;
164 }
165 }
166 return count;
167 }
168 }
169}
170
/// Lets `?` and `Into` accept `Infallible` wherever a `ValueOutOfRange` is expected.
impl From<Infallible> for ValueOutOfRange {
    #[inline(always)]
    fn from(_: Infallible) -> Self {
        // SAFETY: `Infallible` is an enum with no variants, so no value of it can exist and
        // this function can never actually be called.
        unsafe { unreachable_unchecked() }
    }
}
177
/// Largest value that can be encoded: the biggest number that fits in the 28 payload bits of a
/// packed word (268435455).
pub const MAX: u32 = (1 << 28) - 1;
179
/// Integer types that can be fed to the Simple16 encoder.
///
/// # Safety
///
/// The packing helpers look up layouts in `CASES` without bounds checks and rely on the final
/// 28-bit layout accepting every value. For any slice on which `check` returns `Ok`, `as_`
/// must therefore return at most [`MAX`] for each element.
pub unsafe trait Simple16: Sized + Copy {
    /// Validates a whole slice up front; returns `Err` if some element cannot be encoded.
    fn check(data: &[Self]) -> Result<(), ValueOutOfRange>;
    /// Converts the value to the `u32` that is bit-packed.
    fn as_(self) -> u32;
}
185
186unsafe impl Simple16 for u32 {
187 fn check(data: &[Self]) -> Result<(), ValueOutOfRange> {
188 profile_method!(check);
189 for &value in data {
190 if value > MAX {
191 return Err(ValueOutOfRange(()));
192 }
193 }
194 Ok(())
195 }
196 #[inline(always)]
197 fn as_(self) -> u32 {
198 self
199 }
200}
201unsafe impl Simple16 for u64 {
202 fn check(data: &[Self]) -> Result<(), ValueOutOfRange> {
203 profile_method!(check);
204 for &value in data {
205 if value > MAX as u64 {
206 return Err(ValueOutOfRange(()));
207 }
208 }
209 Ok(())
210 }
211 #[inline(always)]
212 fn as_(self) -> u32 {
213 self as u32
214 }
215}
216
217unsafe impl Simple16 for u16 {
218 #[inline(always)]
219 fn check(_data: &[Self]) -> Result<(), ValueOutOfRange> {
220 Ok(())
221 }
222 #[inline(always)]
223 fn as_(self) -> u32 {
224 self as u32
225 }
226}
227
228unsafe impl Simple16 for u8 {
229 #[inline(always)]
230 fn check(_data: &[Self]) -> Result<(), ValueOutOfRange> {
231 Ok(())
232 }
233 #[inline(always)]
234 fn as_(self) -> u32 {
235 self as u32
236 }
237}
238
239pub fn calculate_size<T: Simple16>(data: &[T]) -> Result<usize, ValueOutOfRange> {
241 profile_fn!(calculate_size);
242 T::check(data)?;
243 let size = unsafe { calculate_size_unchecked(data) };
244 Ok(size)
245}
246
247pub unsafe fn calculate_size_unchecked<T: Simple16>(mut data: &[T]) -> usize {
248 let mut size = 0;
249 while data.len() > 0 {
250 let advanced = consume(data);
251 data = &data[advanced..];
252 size += 4;
253 }
254
255 size
256}
257
258pub unsafe fn compress_unchecked<T: Simple16>(mut values: &[T], out: &mut Vec<u8>) {
259 while values.len() > 0 {
260 let (next, advanced) = pack(values);
261 values = &values[advanced..];
262 out.extend_from_slice(&next.to_le_bytes());
263 }
264}
265
266pub fn compress<T: Simple16>(values: &[T], out: &mut Vec<u8>) -> Result<(), ValueOutOfRange> {
268 profile_fn!(compress);
269 T::check(values)?;
270 unsafe { compress_unchecked(values, out) }
271
272 Ok(())
273}
274
275pub fn decompress(bytes: &[u8], out: &mut Vec<u32>) -> Result<(), ()> {
277 profile_fn!(decompress);
278 if bytes.len() % 4 != 0 {
279 return Err(());
280 }
281 let mut offset = 0;
282 while offset < bytes.len() {
283 let start = offset;
284 offset += 4;
285 let slice = &bytes[start..offset];
286 let next = u32::from_le_bytes(slice.try_into().unwrap());
287 let num_idx = (next >> 28) as usize;
288 let Case { count, bits } = unsafe { CASES.get_unchecked(num_idx) };
289 let count = *count;
290 let mut j = 0;
291 let mut unpacked = 0;
292 while j < count {
293 let bits_j = unsafe { bits.get_unchecked(j) };
294 let value = (next >> unpacked) & (0xffffffff >> (32 - bits_j));
295 out.push(value);
296 unpacked += bits_j;
297 j += 1;
298 }
299 }
300
301 Ok(())
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryFrom;

    /// Compresses `data`, checks that `calculate_size` agrees with the number of bytes actually
    /// written, then decompresses and compares against the input.
    ///
    /// Decoding always yields every slot of the last word, so the output may carry trailing
    /// zeros; only the first `data.len()` values are compared.
    fn round_trip<T: Simple16 + TryFrom<u32> + std::fmt::Debug + Eq>(data: &[T]) {
        let mut bytes = Vec::new();
        compress(data, &mut bytes).expect("compress rejected in-range data");
        assert_eq!(
            calculate_size(data).expect("calculate_size rejected in-range data"),
            bytes.len()
        );

        let mut decoded = Vec::new();
        decompress(&bytes, &mut decoded).expect("decompress rejected compress output");
        let decoded: Vec<T> = decoded
            .into_iter()
            .map(|value| T::try_from(value).unwrap_or_else(|_| panic!("round trip failed")))
            .collect();

        assert_eq!(data, &decoded[..data.len()]);
    }

    #[test]
    fn t1() {
        round_trip(&[1u32, 5, 18, 99, 2023, 289981, 223389999]);
        round_trip(&[1u16, 5, 18, 99, 2023, u16::MAX]);
        round_trip(&[1u8, 5, 18, 99, u8::MAX]);
    }

    #[test]
    fn t2() {
        round_trip(&[1u32]);
    }

    #[test]
    fn t3() {
        round_trip(&[0u32, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]);
    }

    #[test]
    fn boundary_values_round_trip() {
        // `MAX` is the largest accepted value and needs the single 28-bit slot layout.
        round_trip(&[MAX]);
        round_trip(&[0u32, MAX, 0, MAX]);
        round_trip(&[u64::from(MAX), 0, 1]);
    }

    #[test]
    fn empty_input_round_trips() {
        round_trip::<u32>(&[]);
    }

    #[test]
    fn too_large_is_err() {
        assert!(compress(&[u32::MAX], &mut Vec::new()).is_err());
        assert!(compress(&[MAX + 1], &mut Vec::new()).is_err());
        assert!(compress(&[u64::from(MAX) + 1], &mut Vec::new()).is_err());
    }

    #[test]
    fn truncated_input_is_err() {
        // Packed data is a sequence of 4-byte words; any other length is rejected.
        assert!(decompress(&[0, 0, 0], &mut Vec::new()).is_err());
    }

    #[test]
    #[ignore = "Takes a while"]
    fn check_all() {
        let mut v = Vec::new();
        // Inclusive range: `MAX` itself is a valid value and must be covered too.
        for i in 0..=MAX {
            if compress(&[i], &mut v).is_err() {
                panic!("{}", i);
            }
            v.clear();
        }
    }
}