compression_codecs/lz4/encoder.rs

1use crate::{lz4::params::EncoderParams, Encode};
2use compression_core::{unshared::Unshared, util::PartialBuffer};
3use lz4::liblz4::{
4    check_error, LZ4FCompressionContext, LZ4FPreferences, LZ4F_compressBegin, LZ4F_compressBound,
5    LZ4F_compressEnd, LZ4F_compressUpdate, LZ4F_createCompressionContext, LZ4F_flush,
6    LZ4F_freeCompressionContext, LZ4F_VERSION,
7};
8use std::io::{self, Result};
9
// Mirrors `LZ4F_HEADER_SIZE_MAX` from lz4frame.h: the largest size an LZ4 frame
// header can take. Used as the minimum destination size for `LZ4F_compressBegin`.
// https://github.com/lz4/lz4/blob/9d53d8bb6c4120345a0966e5d8b16d7def1f32c5/lib/lz4frame.h#L281
const LZ4F_HEADER_SIZE_MAX: usize = 19;
12
/// Owner of a native LZ4F compression context.
///
/// Created by `EncoderContext::new` and released with
/// `LZ4F_freeCompressionContext` when dropped.
#[derive(Debug)]
struct EncoderContext {
    /// Handle filled in by `LZ4F_createCompressionContext`.
    ctx: LZ4FCompressionContext,
}
17
/// Progress of the encoder through the LZ4 frame it is producing.
#[derive(Clone, Copy, Debug)]
enum State {
    /// The frame header has not been written yet.
    Header,
    /// The header has been written; input can be compressed and flushed.
    Encoding,
    /// `LZ4F_compressEnd` has run; bytes may still be waiting in the internal
    /// spill buffer to be copied out to the caller.
    Footer,
    /// The whole frame has been handed to the caller.
    Done,
}
25
26enum Lz4Fn<'a, T>
27where
28    T: AsRef<[u8]>,
29{
30    Begin,
31    Update { input: &'a mut PartialBuffer<T> },
32    Flush,
33    End,
34}
35
/// Streaming LZ4 frame encoder built on the LZ4F C API.
///
/// When the caller's output buffer is smaller than the worst case for the next
/// LZ4F call, output is produced into an internal spill buffer and copied out
/// over this and subsequent calls.
#[derive(Debug)]
pub struct Lz4Encoder {
    /// Native compression context.
    ctx: Unshared<EncoderContext>,
    /// How far through the frame the encoder has progressed.
    state: State,
    /// Frame preferences, passed to `LZ4F_compressBegin` and `LZ4F_compressBound`.
    preferences: LZ4FPreferences,
    /// Maximum number of input bytes passed to one `LZ4F_compressUpdate` call
    /// (set to the frame's block size in `Lz4Encoder::new`).
    limit: usize,
    /// Lazily allocated spill buffer for output that did not fit in the
    /// caller's buffer; `None` until first needed.
    maybe_buffer: Option<PartialBuffer<Vec<u8>>>,
    /// Minimum dst buffer size for a block
    block_buffer_size: usize,
    /// Minimum dst buffer size for flush/end
    flush_buffer_size: usize,
}
48
/// Minimum size of destination buffer for compressing `src_size` bytes, i.e.
/// the worst case reported by `LZ4F_compressBound` for `preferences`.
fn min_dst_size(src_size: usize, preferences: &LZ4FPreferences) -> usize {
    // SAFETY: `preferences` is a live reference for the whole call; the bound
    // computation only reads it (per lz4frame.h — NOTE(review): confirm the
    // library never retains the pointer).
    unsafe { LZ4F_compressBound(src_size, preferences) }
}
53
54impl EncoderContext {
55    fn new() -> Result<Self> {
56        let mut context = LZ4FCompressionContext(core::ptr::null_mut());
57        check_error(unsafe { LZ4F_createCompressionContext(&mut context, LZ4F_VERSION) })?;
58        Ok(Self { ctx: context })
59    }
60}
61
62impl Drop for EncoderContext {
63    fn drop(&mut self) {
64        unsafe { LZ4F_freeCompressionContext(self.ctx) };
65    }
66}
67
68impl Lz4Encoder {
69    pub fn new(params: EncoderParams) -> Self {
70        let preferences = LZ4FPreferences::from(params);
71        let block_size = preferences.frame_info.block_size_id.get_size();
72
73        let block_buffer_size = min_dst_size(block_size, &preferences);
74        let flush_buffer_size = min_dst_size(0, &preferences);
75
76        Self {
77            ctx: Unshared::new(EncoderContext::new().unwrap()),
78            state: State::Header,
79            preferences,
80            limit: block_size,
81            maybe_buffer: None,
82            block_buffer_size,
83            flush_buffer_size,
84        }
85    }
86
87    pub fn buffer_size(&self) -> usize {
88        self.block_buffer_size
89    }
90
91    fn drain_buffer(
92        &mut self,
93        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
94    ) -> (usize, usize) {
95        match self.maybe_buffer.as_mut() {
96            Some(buffer) => {
97                let drained_bytes = output.copy_unwritten_from(buffer);
98                (drained_bytes, buffer.unwritten().len())
99            }
100            None => (0, 0),
101        }
102    }
103
104    fn write<'a, T>(
105        &'a mut self,
106        lz4_fn: Lz4Fn<'a, T>,
107        output: &'a mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
108    ) -> Result<usize>
109    where
110        T: AsRef<[u8]>,
111    {
112        let (drained_before, undrained) = self.drain_buffer(output);
113        if undrained > 0 {
114            return Ok(drained_before);
115        }
116
117        let mut src_size = 0;
118
119        let min_dst_size = match &lz4_fn {
120            Lz4Fn::Begin => LZ4F_HEADER_SIZE_MAX,
121            Lz4Fn::Update { input } => {
122                src_size = input.unwritten().len().min(self.limit);
123                min_dst_size(src_size, &self.preferences)
124            }
125            Lz4Fn::Flush | Lz4Fn::End => self.flush_buffer_size,
126        };
127
128        let output_len = output.unwritten().len();
129
130        let (dst_buffer, dst_size, maybe_internal_buffer) = if min_dst_size > output_len {
131            let buffer_size = self.block_buffer_size;
132            let buffer = self
133                .maybe_buffer
134                .get_or_insert_with(|| PartialBuffer::new(Vec::with_capacity(buffer_size)));
135            buffer.reset();
136            (
137                buffer.unwritten_mut().as_mut_ptr(),
138                buffer_size,
139                Some(buffer),
140            )
141        } else {
142            (output.unwritten_mut().as_mut_ptr(), output_len, None)
143        };
144
145        let len = match lz4_fn {
146            Lz4Fn::Begin => {
147                let len = check_error(unsafe {
148                    LZ4F_compressBegin(
149                        self.ctx.get_mut().ctx,
150                        dst_buffer,
151                        dst_size,
152                        &self.preferences,
153                    )
154                })?;
155                self.state = State::Encoding;
156                len
157            }
158            Lz4Fn::Update { input } => {
159                let len = check_error(unsafe {
160                    LZ4F_compressUpdate(
161                        self.ctx.get_mut().ctx,
162                        dst_buffer,
163                        dst_size,
164                        input.unwritten().as_ptr(),
165                        src_size,
166                        core::ptr::null(),
167                    )
168                })?;
169                input.advance(src_size);
170                len
171            }
172            Lz4Fn::Flush => check_error(unsafe {
173                LZ4F_flush(
174                    self.ctx.get_mut().ctx,
175                    dst_buffer,
176                    dst_size,
177                    core::ptr::null(),
178                )
179            })?,
180            Lz4Fn::End => {
181                let len = check_error(unsafe {
182                    LZ4F_compressEnd(
183                        self.ctx.get_mut().ctx,
184                        dst_buffer,
185                        dst_size,
186                        core::ptr::null(),
187                    )
188                })?;
189                self.state = State::Footer;
190                len
191            }
192        };
193
194        let drained_after = if let Some(internal_buffer) = maybe_internal_buffer {
195            unsafe {
196                internal_buffer.get_mut().set_len(len);
197            }
198            let (d, _) = self.drain_buffer(output);
199            d
200        } else {
201            output.advance(len);
202            len
203        };
204
205        Ok(drained_before + drained_after)
206    }
207}
208
209impl Encode for Lz4Encoder {
210    fn encode(
211        &mut self,
212        input: &mut PartialBuffer<impl AsRef<[u8]>>,
213        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
214    ) -> Result<()> {
215        loop {
216            match self.state {
217                State::Header => {
218                    self.write(Lz4Fn::Begin::<&[u8]>, output)?;
219                }
220
221                State::Encoding => {
222                    self.write(Lz4Fn::Update { input }, output)?;
223                }
224
225                State::Footer | State::Done => {
226                    return Err(io::Error::other("encode after complete"));
227                }
228            }
229
230            if input.unwritten().is_empty() || output.unwritten().is_empty() {
231                return Ok(());
232            }
233        }
234    }
235
236    fn flush(
237        &mut self,
238        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
239    ) -> Result<bool> {
240        loop {
241            let done = match self.state {
242                State::Header => {
243                    self.write(Lz4Fn::Begin::<&[u8]>, output)?;
244                    false
245                }
246
247                State::Encoding => {
248                    let len = self.write(Lz4Fn::Flush::<&[u8]>, output)?;
249                    len == 0
250                }
251
252                State::Footer => {
253                    let (_, undrained) = self.drain_buffer(output);
254                    if undrained == 0 {
255                        self.state = State::Done;
256                        true
257                    } else {
258                        false
259                    }
260                }
261
262                State::Done => true,
263            };
264
265            if done {
266                return Ok(true);
267            }
268
269            if output.unwritten().is_empty() {
270                return Ok(false);
271            }
272        }
273    }
274
275    fn finish(
276        &mut self,
277        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
278    ) -> Result<bool> {
279        loop {
280            match self.state {
281                State::Header => {
282                    self.write(Lz4Fn::Begin::<&[u8]>, output)?;
283                }
284
285                State::Encoding => {
286                    self.write(Lz4Fn::End::<&[u8]>, output)?;
287                }
288
289                State::Footer => {
290                    let (_, undrained) = self.drain_buffer(output);
291                    if undrained == 0 {
292                        self.state = State::Done;
293                    }
294                }
295
296                State::Done => {}
297            }
298
299            if let State::Done = self.state {
300                return Ok(true);
301            }
302
303            if output.unwritten().is_empty() {
304                return Ok(false);
305            }
306        }
307    }
308}