use std::io::{self, Write};
use std::mem::ManuallyDrop;

use haagenti_core::Compressor;

use crate::{StreamBuffer, DEFAULT_BUFFER_SIZE};
/// A writer that compresses data before forwarding it to an inner [`Write`].
///
/// Incoming bytes are staged in an internal [`StreamBuffer`]; each time the
/// buffer fills (or on flush/finish) its contents are compressed in a single
/// `compress` call and written to `inner`.
pub struct CompressWriter<W: Write, C: Compressor> {
    // ManuallyDrop lets `finish` move the inner writer out of `self` even
    // though this type has a custom `Drop` impl.
    inner: ManuallyDrop<W>,
    compressor: C,
    buffer: StreamBuffer,
    // Set once the stream has been finalized; further writes are rejected.
    finished: bool,
}
20
impl<W: Write, C: Compressor> CompressWriter<W, C> {
    /// Creates a writer with the default staging-buffer size.
    pub fn new(inner: W, compressor: C) -> Self {
        Self::with_buffer_size(inner, compressor, DEFAULT_BUFFER_SIZE)
    }

    /// Creates a writer whose staging buffer is sized to `buffer_size` bytes.
    pub fn with_buffer_size(inner: W, compressor: C, buffer_size: usize) -> Self {
        Self {
            inner: ManuallyDrop::new(inner),
            compressor,
            buffer: StreamBuffer::with_capacity(buffer_size),
            finished: false,
        }
    }

    /// Returns a shared reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.inner
    }

    /// Returns a mutable reference to the underlying writer.
    ///
    /// Writing to it directly may corrupt the compressed stream.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Returns a reference to the compressor.
    pub fn compressor(&self) -> &C {
        &self.compressor
    }

    /// Flushes any buffered data and returns the inner writer.
    ///
    /// # Errors
    ///
    /// Propagates any compression or I/O error from the final flush; in that
    /// case `self` is dropped normally (and the inner writer with it).
    pub fn finish(mut self) -> io::Result<W> {
        self.do_finish()?;
        // SAFETY: `inner` is taken exactly once here; the `mem::forget`
        // below prevents `Drop` from running, so it is never dropped twice.
        let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
        std::mem::forget(self);
        Ok(inner)
    }

    /// Idempotent finalizer: flushes remaining buffered data at most once.
    /// `finished` is only set after a successful flush, so a failed `finish`
    /// is retried (best-effort) by `Drop`.
    fn do_finish(&mut self) -> io::Result<()> {
        if self.finished {
            return Ok(());
        }

        // NOTE(review): this gates on `available() > 0` while `flush_buffer`
        // gates on `is_empty()`. Confirm `StreamBuffer::available` means
        // "readable bytes pending", not remaining write capacity — if it is
        // the latter, a completely full buffer would be skipped here.
        if self.buffer.available() > 0 {
            self.flush_buffer()?;
        }

        self.finished = true;
        Ok(())
    }

    /// Compresses the currently buffered bytes as one block, writes the
    /// result to the inner writer, then clears the buffer. No-op when empty.
    fn flush_buffer(&mut self) -> io::Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        let data = self.buffer.readable();
        let compressed = self
            .compressor
            .compress(data)
            .map_err(|e| io::Error::other(e.to_string()))?;

        self.inner.write_all(&compressed)?;
        self.buffer.clear();

        Ok(())
    }
}
96
97impl<W: Write, C: Compressor> Write for CompressWriter<W, C> {
98 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
99 if self.finished {
100 return Err(io::Error::other("writer already finished"));
101 }
102
103 let mut written = 0;
105 while written < buf.len() {
106 let n = self.buffer.write(&buf[written..]);
107 written += n;
108
109 if self.buffer.is_full() {
111 self.flush_buffer()?;
112 }
113 }
114
115 Ok(written)
116 }
117
118 fn flush(&mut self) -> io::Result<()> {
119 self.flush_buffer()?;
120 self.inner.flush()
121 }
122}
123
impl<W: Write, C: Compressor> Drop for CompressWriter<W, C> {
    /// Best-effort finalization: flushes buffered data (errors ignored,
    /// as `drop` cannot report them) and drops the inner writer.
    fn drop(&mut self) {
        let _ = self.do_finish();
        // SAFETY: `drop` runs at most once, and `finish` forgets `self`
        // before returning, so `inner` has not already been taken.
        unsafe { ManuallyDrop::drop(&mut self.inner) };
    }
}
132
/// A writer that passes each written chunk through a transform closure and
/// forwards the transformed bytes to an inner [`Write`].
pub struct WriteAdapter<W: Write, F> {
    inner: W,
    // Invoked on every `write` call; its output is what actually gets written.
    transform: F,
}
141
142impl<W: Write, F> WriteAdapter<W, F>
143where
144 F: FnMut(&[u8]) -> io::Result<Vec<u8>>,
145{
146 pub fn new(inner: W, transform: F) -> Self {
148 Self { inner, transform }
149 }
150
151 pub fn get_ref(&self) -> &W {
153 &self.inner
154 }
155
156 pub fn get_mut(&mut self) -> &mut W {
158 &mut self.inner
159 }
160
161 pub fn into_inner(self) -> W {
163 self.inner
164 }
165}
166
167impl<W: Write, F> Write for WriteAdapter<W, F>
168where
169 F: FnMut(&[u8]) -> io::Result<Vec<u8>>,
170{
171 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
172 let transformed = (self.transform)(buf)?;
173 self.inner.write_all(&transformed)?;
174 Ok(buf.len())
175 }
176
177 fn flush(&mut self) -> io::Result<()> {
178 self.inner.flush()
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Fake compressor: frames the input with a little-endian u32 length
    /// prefix, so tests can verify block boundaries without a real codec.
    struct MockCompressor;

    impl Compressor for MockCompressor {
        fn algorithm(&self) -> haagenti_core::Algorithm {
            haagenti_core::Algorithm::Lz4
        }

        fn level(&self) -> haagenti_core::CompressionLevel {
            haagenti_core::CompressionLevel::Default
        }

        fn compress(&self, input: &[u8]) -> haagenti_core::Result<Vec<u8>> {
            let mut framed = Vec::with_capacity(input.len() + 4);
            framed.extend_from_slice(&u32::to_le_bytes(input.len() as u32));
            framed.extend_from_slice(input);
            Ok(framed)
        }

        fn compress_to(&self, input: &[u8], output: &mut [u8]) -> haagenti_core::Result<usize> {
            let framed = self.compress(input)?;
            if output.len() < framed.len() {
                return Err(haagenti_core::Error::buffer_too_small(
                    framed.len(),
                    output.len(),
                ));
            }
            output[..framed.len()].copy_from_slice(&framed);
            Ok(framed.len())
        }

        fn max_compressed_size(&self, input_len: usize) -> usize {
            input_len + 4
        }
    }

    #[test]
    fn test_compress_writer() {
        let mut output = Vec::new();
        {
            let mut writer = CompressWriter::with_buffer_size(&mut output, MockCompressor, 16);
            writer.write_all(b"Hello").unwrap();
            writer.finish().unwrap();
        }

        // Exactly one frame: a 4-byte length header plus the 5 payload bytes.
        assert_eq!(output.len(), 9);
        assert_eq!(u32::from_le_bytes(output[..4].try_into().unwrap()), 5);
        assert_eq!(&output[4..], b"Hello");
    }

    #[test]
    fn test_compress_writer_multiple_flushes() {
        let mut output = Vec::new();
        {
            let mut writer = CompressWriter::with_buffer_size(&mut output, MockCompressor, 8);
            writer.write_all(b"Hello, World! This is a test.").unwrap();
            writer.finish().unwrap();
        }

        // The input exceeds the 8-byte buffer, so several frames were
        // emitted; at minimum more than one header's worth of bytes.
        assert!(output.len() > 4);
    }

    #[test]
    fn test_write_adapter() {
        let mut output = Vec::new();
        {
            let mut adapter =
                WriteAdapter::new(&mut output, |data: &[u8]| Ok(data.to_ascii_uppercase()));
            adapter.write_all(b"hello").unwrap();
        }

        assert_eq!(output, b"HELLO");
    }
}