1mod order_0;
2mod order_1;
3mod rle;
4mod stripe;
5
6use std::{
7 borrow::Cow,
8 io::{self, Write},
9 num::NonZero,
10};
11
12use super::Flags;
13use crate::io::writer::num::{write_u8, write_uint7};
14
15pub fn encode(mut flags: Flags, src: &[u8]) -> io::Result<Vec<u8>> {
16 use crate::codecs::rans_nx16::encode::bit_pack;
17
18 let mut src = Cow::from(src);
19 let mut dst = Vec::new();
20
21 write_flags(&mut dst, flags)?;
22
23 if flags.has_uncompressed_size() {
24 write_uncompressed_size(&mut dst, src.len())?;
25 }
26
27 if flags.is_striped() {
28 let buf = stripe::encode(&src)?;
29 dst.extend(buf);
30 return Ok(dst);
31 }
32
33 if flags.is_bit_packed() {
34 match bit_pack::build_context(&src) {
35 Ok(ctx) => {
36 src = Cow::from(bit_pack::encode(&src, &ctx));
37 bit_pack::write_context(&mut dst, &ctx, src.len())?;
38 }
39 Err(
40 bit_pack::context::BuildContextError::EmptyAlphabet
41 | bit_pack::context::BuildContextError::TooManySymbols(_),
42 ) => {
43 flags.remove(Flags::PACK);
44 dst[0] = u8::from(flags);
45 }
46 }
47 }
48
49 if flags.is_uncompressed() {
50 dst.write_all(&src)?;
51 } else if flags.uses_external_codec() {
52 encode_ext(&src, &mut dst)?;
53 } else if flags.is_rle() {
54 rle::encode(&src, flags, &mut dst)?;
55 } else if flags.order() == 0 {
56 order_0::encode(&src, &mut dst)?;
57 } else {
58 order_1::encode(&src, &mut dst)?;
59 }
60
61 Ok(dst)
62}
63
64fn write_flags(dst: &mut Vec<u8>, flags: Flags) -> io::Result<()> {
65 write_u8(dst, u8::from(flags))
66}
67
68fn write_uncompressed_size(dst: &mut Vec<u8>, uncompressed_size: usize) -> io::Result<()> {
69 let n = u32::try_from(uncompressed_size)
70 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
71
72 write_uint7(dst, n)
73}
74
/// Returns the alphabet size needed to represent `src`: its largest byte value plus one.
///
/// The result is always in `1..=256`.
///
/// # Panics
///
/// Panics if `src` is empty.
fn count_symbols(src: &[u8]) -> NonZero<usize> {
    assert!(!src.is_empty());

    let highest = src.iter().copied().fold(u8::MIN, u8::max);

    // `highest` is at most 255, so `1 + highest` never saturates.
    NonZero::<usize>::MIN.saturating_add(usize::from(highest))
}
86
87fn write_symbol_count(dst: &mut Vec<u8>, symbol_count: NonZero<usize>) -> io::Result<()> {
88 let n = u8::try_from(usize::from(symbol_count))
89 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
90
91 write_u8(dst, n)
92}
93
94fn encode_ext(src: &[u8], dst: &mut Vec<u8>) -> io::Result<()> {
95 use bzip2::write::BzEncoder;
96
97 let mut encoder = BzEncoder::new(dst, Default::default());
98 encoder.write_all(src)?;
99 encoder.finish()?;
100
101 Ok(())
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    const NOODLES: &[u8] = b"noodles";
    const RUN_HEAVY: &[u8] = b"noooooooodles";

    /// Asserts that encoding `src` with `flags` produces exactly `expected`.
    fn assert_encodes(flags: Flags, src: &[u8], expected: &[u8]) -> io::Result<()> {
        let actual = encode(flags, src)?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn test_encode_ext() -> io::Result<()> {
        use crate::codecs::bzip2;

        let compressed = bzip2::encode(::bzip2::Compression::default(), NOODLES)?;

        // flags (EXT) and uncompressed size, followed by the bzip2 stream
        let mut expected = vec![0x04, 0x07];
        expected.extend(compressed);

        assert_encodes(Flags::EXT, NOODLES, &expected)
    }

    #[test]
    fn test_encode_stripe() -> io::Result<()> {
        let expected: &[u8] = &[
            0x08, // flags (STRIPE)
            0x07, // uncompressed size
            0x04, // stripe count
            0x08, 0x08, 0x08, 0x07, // compressed stripe sizes
            0x10, 0x6f, 0x00, 0xff, 0xa7, 0xab, 0x62, 0x00, // stripe 0
            0x10, 0x70, 0x00, 0xff, 0x84, 0x92, 0x1b, 0x00, // stripe 1
            0x10, 0x74, 0x00, 0xf7, 0x27, 0xdb, 0x24, 0x00, // stripe 2
            0x10, 0x65, 0x00, 0xfd, 0x77, 0x20, 0xb0, // stripe 3
        ];

        assert_encodes(Flags::STRIPE, NOODLES, expected)
    }

    #[test]
    fn test_encode_order_0() -> io::Result<()> {
        assert_encodes(
            Flags::empty(),
            NOODLES,
            &[
                0x00, 0x07, 0x74, 0x00, 0xf4, 0xe5, 0xb7, 0x4e, 0x50, 0x0f, 0x2e, 0x97, 0x00,
            ],
        )
    }

    #[test]
    fn test_encode_order_1() -> io::Result<()> {
        assert_encodes(
            Flags::ORDER,
            NOODLES,
            &[
                0x01, 0x07, 0x74, 0x00, 0xf4, 0xe3, 0x83, 0x41, 0xe2, 0x9a, 0xef, 0x53, 0x50, 0x00,
            ],
        )
    }

    #[test]
    fn test_encode_cat() -> io::Result<()> {
        // flags (CAT) and uncompressed size, followed by the raw input
        assert_encodes(Flags::CAT, NOODLES, b"\x20\x07noodles")
    }

    #[test]
    fn test_encode_rle_with_order_0() -> io::Result<()> {
        assert_encodes(
            Flags::RLE,
            RUN_HEAVY,
            &[
                0x40, 0x0d, 0x74, 0x00, 0xf3, 0x4b, 0x21, 0x10, 0xa8, 0xe3, 0x84, 0xfe, 0x6b, 0x22,
                0x00,
            ],
        )
    }

    #[test]
    fn test_encode_rle_with_order_1() -> io::Result<()> {
        assert_encodes(
            Flags::ORDER | Flags::RLE,
            RUN_HEAVY,
            &[
                0x41, 0x0d, 0x74, 0x00, 0xf3, 0x4a, 0x89, 0x79, 0xc1, 0xe8, 0xc3, 0xc5, 0x62, 0x31,
                0x00,
            ],
        )
    }

    #[test]
    fn test_encode_pack() -> io::Result<()> {
        assert_encodes(
            Flags::CAT | Flags::PACK,
            NOODLES,
            &[
                0xa0, 0x07, 0x06, 0x64, 0x65, 0x6c, 0x6e, 0x6f, 0x73, 0x04, 0x43, 0x04, 0x12, 0x05,
            ],
        )
    }

    #[test]
    fn test_count_symbols() {
        let cases: [(&[u8], usize); 3] = [(&[0x00], 1), (&[0xff], 256), (b"range_coder", 115)];

        for (src, expected) in cases {
            assert_eq!(count_symbols(src), NonZero::new(expected).unwrap());
        }
    }
}