Skip to main content

haagenti_stream/
writer.rs

1//! Write adapters for streaming compression.
2
3use std::io::{self, Write};
4use std::mem::ManuallyDrop;
5
6use haagenti_core::Compressor;
7
8use crate::{StreamBuffer, DEFAULT_BUFFER_SIZE};
9
10/// A writer that compresses data before writing to the inner writer.
11///
12/// Data is buffered and compressed when the buffer is full or when
13/// `flush()` or `finish()` is called.
14pub struct CompressWriter<W: Write, C: Compressor> {
15    inner: ManuallyDrop<W>,
16    compressor: C,
17    buffer: StreamBuffer,
18    finished: bool,
19}
20
21impl<W: Write, C: Compressor> CompressWriter<W, C> {
22    /// Create a new compressing writer with default buffer size.
23    pub fn new(inner: W, compressor: C) -> Self {
24        Self::with_buffer_size(inner, compressor, DEFAULT_BUFFER_SIZE)
25    }
26
27    /// Create a new compressing writer with specified buffer size.
28    pub fn with_buffer_size(inner: W, compressor: C, buffer_size: usize) -> Self {
29        Self {
30            inner: ManuallyDrop::new(inner),
31            compressor,
32            buffer: StreamBuffer::with_capacity(buffer_size),
33            finished: false,
34        }
35    }
36
37    /// Get a reference to the inner writer.
38    pub fn get_ref(&self) -> &W {
39        &self.inner
40    }
41
42    /// Get a mutable reference to the inner writer.
43    pub fn get_mut(&mut self) -> &mut W {
44        &mut self.inner
45    }
46
47    /// Get a reference to the compressor.
48    pub fn compressor(&self) -> &C {
49        &self.compressor
50    }
51
52    /// Finish compression and flush all remaining data.
53    ///
54    /// This must be called before dropping to ensure all data is written.
55    pub fn finish(mut self) -> io::Result<W> {
56        self.do_finish()?;
57        // Safety: we're consuming self, so inner won't be dropped twice
58        let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
59        std::mem::forget(self); // Prevent Drop from running
60        Ok(inner)
61    }
62
63    /// Internal finish implementation.
64    fn do_finish(&mut self) -> io::Result<()> {
65        if self.finished {
66            return Ok(());
67        }
68
69        // Compress any remaining buffered data
70        if self.buffer.available() > 0 {
71            self.flush_buffer()?;
72        }
73
74        self.finished = true;
75        Ok(())
76    }
77
78    /// Flush the internal buffer by compressing and writing.
79    fn flush_buffer(&mut self) -> io::Result<()> {
80        if self.buffer.is_empty() {
81            return Ok(());
82        }
83
84        let data = self.buffer.readable();
85        let compressed = self
86            .compressor
87            .compress(data)
88            .map_err(|e| io::Error::other(e.to_string()))?;
89
90        self.inner.write_all(&compressed)?;
91        self.buffer.clear();
92
93        Ok(())
94    }
95}
96
97impl<W: Write, C: Compressor> Write for CompressWriter<W, C> {
98    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
99        if self.finished {
100            return Err(io::Error::other("writer already finished"));
101        }
102
103        // Write to buffer
104        let mut written = 0;
105        while written < buf.len() {
106            let n = self.buffer.write(&buf[written..]);
107            written += n;
108
109            // If buffer is full, flush it
110            if self.buffer.is_full() {
111                self.flush_buffer()?;
112            }
113        }
114
115        Ok(written)
116    }
117
118    fn flush(&mut self) -> io::Result<()> {
119        self.flush_buffer()?;
120        self.inner.flush()
121    }
122}
123
124impl<W: Write, C: Compressor> Drop for CompressWriter<W, C> {
125    fn drop(&mut self) {
126        // Best effort finish on drop
127        let _ = self.do_finish();
128        // Safety: we're in drop, so this is the only time inner is dropped
129        unsafe { ManuallyDrop::drop(&mut self.inner) };
130    }
131}
132
/// Unbuffered write adapter that passes every chunk through a transform.
///
/// Nothing is staged inside the adapter: each `write` call transforms its
/// input and forwards the result to the inner writer right away.
pub struct WriteAdapter<W, F>
where
    W: Write,
{
    /// Writer that receives the transformed bytes.
    inner: W,
    /// Transformation applied to each chunk before it is written.
    transform: F,
}
141
142impl<W: Write, F> WriteAdapter<W, F>
143where
144    F: FnMut(&[u8]) -> io::Result<Vec<u8>>,
145{
146    /// Create a new write adapter.
147    pub fn new(inner: W, transform: F) -> Self {
148        Self { inner, transform }
149    }
150
151    /// Get a reference to the inner writer.
152    pub fn get_ref(&self) -> &W {
153        &self.inner
154    }
155
156    /// Get a mutable reference to the inner writer.
157    pub fn get_mut(&mut self) -> &mut W {
158        &mut self.inner
159    }
160
161    /// Consume the adapter and return the inner writer.
162    pub fn into_inner(self) -> W {
163        self.inner
164    }
165}
166
167impl<W: Write, F> Write for WriteAdapter<W, F>
168where
169    F: FnMut(&[u8]) -> io::Result<Vec<u8>>,
170{
171    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
172        let transformed = (self.transform)(buf)?;
173        self.inner.write_all(&transformed)?;
174        Ok(buf.len())
175    }
176
177    fn flush(&mut self) -> io::Result<()> {
178        self.inner.flush()
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in compressor for the tests: the "compressed" form is the input
    /// prefixed with its length as a 4-byte little-endian integer.
    struct MockCompressor;

    impl Compressor for MockCompressor {
        fn algorithm(&self) -> haagenti_core::Algorithm {
            haagenti_core::Algorithm::Lz4
        }

        fn level(&self) -> haagenti_core::CompressionLevel {
            haagenti_core::CompressionLevel::Default
        }

        fn compress(&self, input: &[u8]) -> haagenti_core::Result<Vec<u8>> {
            let header = (input.len() as u32).to_le_bytes();
            let mut framed = Vec::with_capacity(header.len() + input.len());
            framed.extend_from_slice(&header);
            framed.extend_from_slice(input);
            Ok(framed)
        }

        fn compress_to(&self, input: &[u8], output: &mut [u8]) -> haagenti_core::Result<usize> {
            let framed = self.compress(input)?;
            let (needed, capacity) = (framed.len(), output.len());
            if needed > capacity {
                return Err(haagenti_core::Error::buffer_too_small(needed, capacity));
            }
            output[..needed].copy_from_slice(&framed);
            Ok(needed)
        }

        fn max_compressed_size(&self, input_len: usize) -> usize {
            4 + input_len
        }
    }

    /// A single small write followed by `finish()` yields exactly one framed block.
    #[test]
    fn test_compress_writer() {
        let mut sink = Vec::<u8>::new();
        let mut writer = CompressWriter::with_buffer_size(&mut sink, MockCompressor, 16);
        writer.write_all(b"Hello").unwrap();
        writer.finish().unwrap();

        // Expect the 4-byte length prefix followed by the payload.
        assert_eq!(sink.len(), 4 + 5);
        let (prefix, payload) = sink.split_at(4);
        assert_eq!(u32::from_le_bytes(prefix.try_into().unwrap()), 5);
        assert_eq!(payload, b"Hello");
    }

    /// Input longer than the buffer is expected to be emitted as several blocks.
    #[test]
    fn test_compress_writer_multiple_flushes() {
        let mut sink = Vec::<u8>::new();
        let mut writer = CompressWriter::with_buffer_size(&mut sink, MockCompressor, 8);

        // Far more data than the 8-byte buffer size requested above.
        writer.write_all(b"Hello, World! This is a test.").unwrap();
        writer.finish().unwrap();

        // More than a lone length prefix must have been written.
        assert!(sink.len() > 4);
    }

    /// The adapter applies its transform to the bytes that pass through it.
    #[test]
    fn test_write_adapter() {
        let mut sink = Vec::new();
        {
            // Transform: ASCII upper-casing.
            let upper = |data: &[u8]| Ok(data.to_ascii_uppercase());
            let mut adapter = WriteAdapter::new(&mut sink, upper);
            adapter.write_all(b"hello").unwrap();
        }

        assert_eq!(sink, b"HELLO");
    }
}