zstd_framed/
writer.rs

1use crate::{
2    buffer::Buffer as _,
3    encoder::{ZstdFramedEncoder, ZstdFramedEncoderSeekTableConfig},
4    ZstdOutcome,
5};
6
7/// A writer that writes a compressed zstd stream to the underlying writer.
8///
9/// The underlying writer `W` must implement the following traits:
10///
11/// - [`std::io::Write`]
12///
13/// For async support, see [`crate::AsyncZstdWriter`].
14///
15/// ## Construction
16///
17/// Create a builder using [`ZstdWriter::builder`]. See [`ZstdWriterBuilder`]
18/// for builder options. Call [`ZstdWriterBuilder::build`] to build the
19/// [`ZstdWriter`] instance.
20///
21/// ```
22/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
23/// # let compressed_file = vec![];
24/// let mut writer = zstd_framed::ZstdWriter::builder(compressed_file)
25///     .with_compression_level(3) // Set custom compression level
26///     .with_seek_table(1024 * 1024) // Write zstd seekable format table
27///     .build()?;
28///
29/// // ...
30///
31/// writer.shutdown()?; // Optional, will shut down automatically on drop
32/// # Ok(())
33/// # }
34/// ```
35///
36/// ## Writing multiple frames
37///
38/// To allow for efficient seeking (e.g. when using [`ZstdReaderBuilder::with_seek_table`](crate::reader::ZstdReaderBuilder::with_seek_table)),
39/// you can write multiple zstd frames to the underlying writer. If the
40/// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table) option is
41/// given during construction, multiple frames will be created automatically
42/// to fit within the given `max_frame_size`.
43///
44/// Alternatively, you can use [`ZstdWriter::finish_frame()`] to explicitly
45/// split the underlying stream into multiple frames. [`.finish_frame()`](ZstdWriter::finish_frame)
46/// can be used even when not using the [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)
47/// option (but note the seek table will only be written when using
48/// [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)).
49///
50/// ## Clean shutdown
51///
52/// To ensure the writer shuts down cleanly (including flushing any in-memory
53/// buffers and writing the seek table if enabled with [`.with_seek_table()`](ZstdWriterBuilder::with_seek_table)),
54/// you can explicitly call the [`ZstdWriter::shutdown`] method. This
55/// method will also be called automatically on drop, but errors will
56/// be ignored.
57pub struct ZstdWriter<'dict, W>
58where
59    W: std::io::Write,
60{
61    writer: W,
62    encoder: ZstdFramedEncoder<'dict>,
63    buffer: crate::buffer::FixedBuffer<Vec<u8>>,
64}
65
66impl<W> ZstdWriter<'_, W>
67where
68    W: std::io::Write,
69{
70    /// Create a new zstd writer that writes a compressed zstd stream
71    /// to the underlying writer.
72    pub fn builder(writer: W) -> ZstdWriterBuilder<W> {
73        ZstdWriterBuilder::new(writer)
74    }
75
76    /// Explicitly finish the current zstd frame. If more data is written,
77    /// a new frame will be started.
78    ///
79    /// When using [`ZstdWriterBuilder::with_seek_table`], the just-finished
80    /// frame will be reflected in the resulting seek table.
81    pub fn finish_frame(&mut self) -> std::io::Result<()> {
82        self.encoder.finish_frame(&mut self.buffer)?;
83
84        Ok(())
85    }
86
87    /// Cleanly shut down the zstd stream. This will flush internal buffers,
88    /// finish writing any partially-written frames, and write the
89    /// seek table when using [`ZstdWriterBuilder::with_seek_table`].
90    ///
91    /// This method will be called automatically on drop, although
92    /// any errors will be ignored.
93    pub fn shutdown(&mut self) -> std::io::Result<()> {
94        loop {
95            // Flush any uncommitted data
96            self.flush_uncommitted()?;
97
98            // Shut down the encoder
99            let outcome = self.encoder.shutdown(&mut self.buffer)?;
100
101            match outcome {
102                ZstdOutcome::HasMore { .. } => {
103                    // Encoder still has more to write, so keep looping
104                }
105                ZstdOutcome::Complete(_) => {
106                    // Encoder has nothing else to do, so we're done
107                    break;
108                }
109            }
110        }
111
112        // Flush any final data from the encoder
113        self.flush_uncommitted()?;
114
115        // Flush the underlying writer for good measure
116        self.writer.flush()?;
117
118        Ok(())
119    }
120
121    /// Write all uncommitted buffered data to the underlying writer. After
122    /// returning `Ok(_)`, `self.buffer` will be empty.
123    fn flush_uncommitted(&mut self) -> std::io::Result<()> {
124        loop {
125            // Get the uncommitted data to write
126            let uncommitted = self.buffer.uncommitted();
127            if uncommitted.is_empty() {
128                // If there's no uncommitted data, we're done
129                return Ok(());
130            }
131
132            // Write the data to the underlying writer, and record it
133            // as committed
134            let committed = self.writer.write(uncommitted)?;
135            self.buffer.commit(committed);
136
137            if committed == 0 {
138                // The underlying reader didn't accept any more of our data
139
140                return Err(std::io::Error::new(
141                    std::io::ErrorKind::WriteZero,
142                    "failed to write buffered data",
143                ));
144            }
145        }
146    }
147}
148
149impl<W> std::io::Write for ZstdWriter<'_, W>
150where
151    W: std::io::Write,
152{
153    fn write(&mut self, data: &[u8]) -> Result<usize, std::io::Error> {
154        loop {
155            // Write all buffered data
156            self.flush_uncommitted()?;
157
158            // Encode the newly-written data
159            let outcome = self.encoder.encode(data, &mut self.buffer)?;
160
161            match outcome {
162                ZstdOutcome::HasMore { .. } => {
163                    // The encoder has more to do before data can be encoded
164                }
165                ZstdOutcome::Complete(consumed) => {
166                    // We've now encoded some data to the buffer, so we're done
167                    return Ok(consumed);
168                }
169            }
170        }
171    }
172
173    fn flush(&mut self) -> std::io::Result<()> {
174        loop {
175            // Write all buffered data
176            self.flush_uncommitted()?;
177
178            // Flush any data from the encoder to the interal buffer
179            let outcome = self.encoder.flush(&mut self.buffer)?;
180
181            match outcome {
182                ZstdOutcome::HasMore { .. } => {
183                    // zstd still has more data to flush, so loop again
184                }
185                ZstdOutcome::Complete(_) => {
186                    // No more data from the encoder
187                    break;
188                }
189            }
190        }
191
192        // Write any newly buffered data from the encoder
193        self.flush_uncommitted()?;
194
195        // Flush the underlying writer
196        self.writer.flush()
197    }
198}
199
200impl<W> Drop for ZstdWriter<'_, W>
201where
202    W: std::io::Write,
203{
204    fn drop(&mut self) {
205        // Try to shut down the writer
206        let _ = self.shutdown();
207    }
208}
209
210/// A builder that builds a [`ZstdWriter`] from the provided writer.
211pub struct ZstdWriterBuilder<W> {
212    writer: W,
213    compression_level: i32,
214    seek_table_config: Option<ZstdFramedEncoderSeekTableConfig>,
215}
216
217impl<W> ZstdWriterBuilder<W> {
218    fn new(writer: W) -> Self {
219        Self {
220            writer,
221            compression_level: 0,
222            seek_table_config: None,
223        }
224    }
225
226    /// Set the zstd compression level.
227    pub fn with_compression_level(mut self, level: i32) -> Self {
228        self.compression_level = level;
229        self
230    }
231
232    /// Write the stream using the [zstd seekable format].
233    ///
234    /// Once the current zstd frame reaches a decompressed size of
235    /// `max_frame_size`, a new frame will automatically be started. When
236    /// the writer is [shut down](ZstdWriter::shutdown), a final frame
237    /// containing a seek table will be written to the end of the writer.
238    /// This seek table can be used to efficiently seek through the file, such
239    /// as by using [crate::table::read_seek_table] along with
240    /// [`ZstdReaderBuilder::with_seek_table`](crate::reader::ZstdReaderBuilder::with_seek_table).
241    ///
242    /// [zstd seekable format]: https://github.com/facebook/zstd/tree/51eb7daf39c8e8a7c338ba214a9d4e2a6a086826/contrib/seekable_format
243    pub fn with_seek_table(mut self, max_frame_size: u32) -> Self {
244        assert!(max_frame_size > 0, "max frame size must be greater than 0");
245
246        self.seek_table_config = Some(ZstdFramedEncoderSeekTableConfig { max_frame_size });
247        self
248    }
249
250    /// Build the writer.
251    pub fn build(self) -> std::io::Result<ZstdWriter<'static, W>>
252    where
253        W: std::io::Write,
254    {
255        let zstd_encoder = zstd::stream::raw::Encoder::new(self.compression_level)?;
256        let buffer = crate::buffer::FixedBuffer::new(vec![0; zstd::zstd_safe::CCtx::out_size()]);
257        let encoder = ZstdFramedEncoder::new(zstd_encoder, self.seek_table_config);
258
259        Ok(ZstdWriter {
260            writer: self.writer,
261            encoder,
262            buffer,
263        })
264    }
265}