1use std::io::{self, Write};
2
3use super::{Flags, Model, RangeCoder};
4use crate::io::writer::num::{write_u8, write_uint7};
5
6pub fn encode(mut flags: Flags, src: &[u8]) -> io::Result<Vec<u8>> {
7 use crate::codecs::rans_nx16::encode::bit_pack;
8
9 let mut src = src.to_vec();
10 let mut dst = Vec::new();
11
12 write_u8(&mut dst, u8::from(flags))?;
13
14 if !flags.contains(Flags::NO_SIZE) {
15 let ulen =
16 u32::try_from(src.len()).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
17 write_uint7(&mut dst, ulen)?;
18 }
19
20 if flags.contains(Flags::STRIPE) {
21 let buf = encode_stripe(&src)?;
22 dst.extend(buf);
23 return Ok(dst);
24 }
25
26 if flags.contains(Flags::PACK) {
27 match bit_pack::build_context(&src) {
28 Ok(ctx) => {
29 src = bit_pack::encode(&src, &ctx);
30 bit_pack::write_context(&mut dst, &ctx, src.len())?;
31 }
32 Err(
33 bit_pack::context::BuildContextError::EmptyAlphabet
34 | bit_pack::context::BuildContextError::TooManySymbols(_),
35 ) => {
36 flags.remove(Flags::PACK);
37 dst[0] = u8::from(flags);
38 }
39 }
40 }
41
42 if flags.contains(Flags::CAT) {
43 dst.write_all(&src)?;
44 } else if flags.contains(Flags::EXT) {
45 encode_ext(&src, &mut dst)?;
46 } else if flags.contains(Flags::RLE) {
47 if flags.contains(Flags::ORDER) {
48 encode_rle_1(&src, &mut dst)?;
49 } else {
50 encode_rle_0(&src, &mut dst)?;
51 }
52 } else if flags.contains(Flags::ORDER) {
53 encode_order_1(&src, &mut dst)?;
54 } else {
55 encode_order_0(&src, &mut dst)?;
56 }
57
58 Ok(dst)
59}
60
61fn encode_stripe(src: &[u8]) -> io::Result<Vec<u8>> {
62 const N: usize = 4;
63
64 let mut ulens = Vec::with_capacity(N);
65 let mut t = Vec::with_capacity(N);
66
67 for j in 0..N {
68 let mut ulen = src.len() / N;
69
70 if src.len() % N > j {
71 ulen += 1;
72 }
73
74 let chunk = vec![0; ulen];
75
76 ulens.push(ulen);
77 t.push(chunk);
78 }
79
80 let mut x = 0;
81 let mut i = 0;
82
83 while i < src.len() {
84 for j in 0..N {
85 if x < ulens[j] {
86 t[j][x] = src[i + j];
87 }
88 }
89
90 x += 1;
91 i += N;
92 }
93
94 let mut chunks = vec![Vec::new(); N];
95
96 for (chunk, s) in chunks.iter_mut().zip(t.iter()) {
97 *chunk = encode(Flags::empty(), s)?;
98 }
99
100 let mut dst = Vec::new();
101
102 write_u8(&mut dst, N as u8)?;
103
104 for chunk in &chunks {
105 let clen = chunk.len() as u32;
106 write_uint7(&mut dst, clen)?;
107 }
108
109 for chunk in &chunks {
110 dst.write_all(chunk)?;
111 }
112
113 Ok(dst)
114}
115
116fn encode_ext(src: &[u8], dst: &mut Vec<u8>) -> io::Result<()> {
117 use bzip2::write::BzEncoder;
118
119 let mut encoder = BzEncoder::new(dst, Default::default());
120 encoder.write_all(src)?;
121 encoder.finish()?;
122
123 Ok(())
124}
125
126fn encode_rle_0(src: &[u8], dst: &mut Vec<u8>) -> io::Result<()> {
127 let max_sym = src.iter().max().copied().unwrap_or(0);
128 write_u8(dst, max_sym.overflowing_add(1).0)?;
129
130 let mut model_lit = Model::new(max_sym);
131 let mut model_run = vec![Model::new(3); 258];
132
133 let mut range_coder = RangeCoder::default();
134
135 let mut i = 0;
136
137 while i < src.len() {
138 let sym = src[i];
139 model_lit.encode(dst, &mut range_coder, sym)?;
140
141 let mut run = src[i + 1..].iter().position(|&s| s != sym).unwrap_or(0);
142 i += run + 1;
143
144 let mut rctx = usize::from(sym);
145
146 let mut part = run.min(3);
147 model_run[rctx].encode(dst, &mut range_coder, part as u8)?;
148 rctx = 256;
149 run -= part;
150
151 while part == 3 {
152 part = run.min(3);
153 model_run[rctx].encode(dst, &mut range_coder, part as u8)?;
154 rctx = 257;
155 run -= part;
156 }
157 }
158
159 range_coder.range_encode_end(dst)?;
160
161 Ok(())
162}
163
164fn encode_rle_1(src: &[u8], dst: &mut Vec<u8>) -> io::Result<()> {
165 let max_sym = src.iter().max().copied().unwrap_or(0);
166 write_u8(dst, max_sym.overflowing_add(1).0)?;
167
168 let model_lit_count = usize::from(max_sym) + 1;
169 let mut model_lit = vec![Model::new(max_sym); model_lit_count];
170 let mut model_run = vec![Model::new(3); 258];
171
172 let mut range_coder = RangeCoder::default();
173
174 let mut i = 0;
175 let mut last = 0;
176
177 while i < src.len() {
178 let sym = src[i];
179 model_lit[last].encode(dst, &mut range_coder, sym)?;
180
181 let mut run = src[i + 1..].iter().position(|&s| s != sym).unwrap_or(0);
182 i += run + 1;
183
184 let mut rctx = usize::from(sym);
185 last = usize::from(sym);
186
187 let mut part = run.min(3);
188 model_run[rctx].encode(dst, &mut range_coder, part as u8)?;
189 rctx = 256;
190 run -= part;
191
192 while part == 3 {
193 part = run.min(3);
194 model_run[rctx].encode(dst, &mut range_coder, part as u8)?;
195 rctx = 257;
196 run -= part;
197 }
198 }
199
200 range_coder.range_encode_end(dst)?;
201
202 Ok(())
203}
204
205fn encode_order_0(src: &[u8], dst: &mut Vec<u8>) -> io::Result<()> {
206 let max_sym = src.iter().max().copied().unwrap_or(0);
207 write_u8(dst, max_sym.overflowing_add(1).0)?;
208
209 let mut model = Model::new(max_sym);
210 let mut range_coder = RangeCoder::default();
211
212 for &sym in src {
213 model.encode(dst, &mut range_coder, sym)?;
214 }
215
216 range_coder.range_encode_end(dst)?;
217
218 Ok(())
219}
220
221fn encode_order_1(src: &[u8], dst: &mut Vec<u8>) -> io::Result<()> {
222 let max_sym = src.iter().max().copied().unwrap_or(0);
223 write_u8(dst, max_sym.overflowing_add(1).0)?;
224
225 let model_count = usize::from(max_sym) + 1;
226 let mut models = vec![Model::new(max_sym); model_count];
227
228 let mut range_coder = RangeCoder::default();
229
230 models[0].encode(dst, &mut range_coder, src[0])?;
231
232 for window in src.windows(2) {
233 let sym_0 = usize::from(window[0]);
234 let sym_1 = window[1];
235 models[sym_0].encode(dst, &mut range_coder, sym_1)?;
236 }
237
238 range_coder.range_encode_end(dst)?;
239
240 Ok(())
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    // In every expected buffer below, the first byte is the flags byte and
    // the second is the length of the input.

    #[test]
    fn test_encode_ext() -> io::Result<()> {
        use crate::codecs::bzip2;

        let src = b"noodles";

        // The payload must match what the bzip2 codec produces on its own.
        let mut expected = vec![0x04, 0x07];
        expected.extend(bzip2::encode(::bzip2::Compression::default(), src)?);

        assert_eq!(encode(Flags::EXT, src)?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_stripe() -> io::Result<()> {
        let expected = [
            0x08, 0x07, 0x04, 0x09, 0x09, 0x09, 0x08, 0x00, 0x02, 0x6f, 0x00, 0xff, 0xa7, 0xab,
            0x62, 0x00, 0x00, 0x02, 0x70, 0x00, 0xff, 0x84, 0x92, 0x1b, 0x00, 0x00, 0x02, 0x74,
            0x00, 0xf7, 0x27, 0xdb, 0x24, 0x00, 0x00, 0x01, 0x65, 0x00, 0xfd, 0x77, 0x20, 0xb0,
        ];

        assert_eq!(encode(Flags::STRIPE, b"noodles")?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_order_0() -> io::Result<()> {
        let expected = [
            0x00, 0x07, 0x74, 0x00, 0xf4, 0xe5, 0xb7, 0x4e, 0x50, 0x0f, 0x2e, 0x97, 0x00,
        ];

        assert_eq!(encode(Flags::empty(), b"noodles")?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_order_1() -> io::Result<()> {
        let expected = [
            0x01, 0x07, 0x74, 0x00, 0xf4, 0xe3, 0x83, 0x41, 0xe2, 0x9a, 0xef, 0x53, 0x50, 0x00,
        ];

        assert_eq!(encode(Flags::ORDER, b"noodles")?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_cat() -> io::Result<()> {
        // With `CAT`, the input bytes follow the header unchanged.
        let expected = [0x20, 0x07, 0x6e, 0x6f, 0x6f, 0x64, 0x6c, 0x65, 0x73];

        assert_eq!(encode(Flags::CAT, b"noodles")?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_rle_with_order_0() -> io::Result<()> {
        let expected = [
            0x40, 0x0d, 0x74, 0x00, 0xf3, 0x4b, 0x21, 0x10, 0xa8, 0xe3, 0x84, 0xfe, 0x6b, 0x22,
            0x00,
        ];

        assert_eq!(encode(Flags::RLE, b"noooooooodles")?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_rle_with_order_1() -> io::Result<()> {
        let expected = [
            0x41, 0x0d, 0x74, 0x00, 0xf3, 0x4a, 0x89, 0x79, 0xc1, 0xe8, 0xc3, 0xc5, 0x62, 0x31,
            0x00,
        ];

        assert_eq!(encode(Flags::ORDER | Flags::RLE, b"noooooooodles")?, expected);

        Ok(())
    }

    #[test]
    fn test_encode_pack() -> io::Result<()> {
        let expected = [
            0xa0, 0x07, 0x06, 0x64, 0x65, 0x6c, 0x6e, 0x6f, 0x73, 0x04, 0x43, 0x04, 0x12, 0x05,
        ];

        assert_eq!(encode(Flags::CAT | Flags::PACK, b"noodles")?, expected);

        Ok(())
    }
}