puzzlefs_lib/compression/
zstd_seekable_wrapper.rs

1use std::cmp::min;
2use std::io;
3use std::io::{Read, Seek, Write};
4
5use zstd_seekable::{CStream, Seekable, SeekableCStream};
6
7use crate::compression::{Compression, Compressor, Decompressor};
8
// Files are compressed in 4KB frames. The ideal frame size is not clear: every frame is
// compressed independently, so larger frames give better compression savings, but larger frames
// also mean more decompression work to reach data in the middle of a frame (e.g. when someone
// mmap()s a region that starts mid-frame).
//
// Another consideration is the average chunk size produced by FastCDC: if the frame size equals
// the chunk size, there is no real point in using seekable compression at all, at least for
// files. It is also possible that metadata blobs and file content want different frame sizes.
const FRAME_SIZE: usize = 4 * 1024;
// Compression level handed to the seekable zstd compression stream.
const COMPRESSION_LEVEL: usize = 3;
19
/// Wraps an arbitrary error in an [`io::Error`] of kind [`io::ErrorKind::Other`].
///
/// The original error is kept as the payload of the returned `io::Error`.
fn err_to_io<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
23
24pub struct ZstdCompressor<W> {
25    f: W,
26    stream: SeekableCStream,
27    buf: Vec<u8>,
28}
29
30impl<W: Write> Compressor for ZstdCompressor<W> {
31    fn end(mut self: Box<Self>) -> io::Result<()> {
32        // end_stream has to be called multiple times until 0 is returned, see
33        // https://docs.rs/zstd-seekable/0.1.23/src/zstd_seekable/lib.rs.html#224-237 and
34        // https://fossies.org/linux/zstd/contrib/seekable_format/zstd_seekable.h
35        loop {
36            let size = self.stream.end_stream(&mut self.buf).map_err(err_to_io)?;
37            self.f.write_all(&self.buf[0..size])?;
38            if size == 0 {
39                break;
40            }
41        }
42        Ok(())
43    }
44}
45
46impl<W: Write> Write for ZstdCompressor<W> {
47    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
48        // TODO: we could try to consume all the input, but for now we just consume a single block
49        let (out_pos, in_pos) = self
50            .stream
51            .compress(&mut self.buf, buf)
52            .map_err(err_to_io)?;
53        self.f.write_all(&self.buf[0..out_pos])?;
54        Ok(in_pos)
55    }
56
57    fn flush(&mut self) -> io::Result<()> {
58        // we could self.stream.flush(), but that adversely affects compression ratio... let's
59        // cheat for now.
60        Ok(())
61    }
62}
63
64pub struct ZstdDecompressor<'a, R: Read + Seek> {
65    stream: Seekable<'a, R>,
66    offset: u64,
67    uncompressed_length: u64,
68}
69
70impl<'a, R: Seek + Read> Decompressor for ZstdDecompressor<'a, R> {
71    fn get_uncompressed_length(&mut self) -> io::Result<u64> {
72        Ok(self.uncompressed_length)
73    }
74}
75
76impl<'a, R: Seek + Read> Seek for ZstdDecompressor<'a, R> {
77    fn seek(&mut self, offset: io::SeekFrom) -> io::Result<u64> {
78        match offset {
79            io::SeekFrom::Start(s) => {
80                self.offset = s;
81            }
82            io::SeekFrom::End(e) => {
83                if e > 0 {
84                    return Err(io::Error::new(io::ErrorKind::Other, "zstd seek past end"));
85                }
86                self.offset = self.uncompressed_length - u64::try_from(-e).map_err(err_to_io)?;
87            }
88            io::SeekFrom::Current(c) => {
89                if c > 0 {
90                    self.offset += u64::try_from(c).map_err(err_to_io)?;
91                } else {
92                    self.offset -= u64::try_from(-c).map_err(err_to_io)?;
93                }
94            }
95        }
96        Ok(self.offset)
97    }
98}
99
100impl<'a, R: Seek + Read> Read for ZstdDecompressor<'a, R> {
101    fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
102        // decompress() gets angry (ZSTD("Corrupted block detected")) if you pass it a buffer
103        // longer than the uncompressable data, so let's be careful to truncate the buffer if it
104        // would make zstd angry. maybe soon they'll implement a real read() API :)
105        let end = min(out.len(), (self.uncompressed_length - self.offset) as usize);
106        let size = self
107            .stream
108            .decompress(&mut out[0..end], self.offset)
109            .map_err(err_to_io)?;
110        self.offset += size as u64;
111        Ok(size)
112    }
113}
114
115pub struct Zstd {}
116
117impl Compression for Zstd {
118    fn compress<'a, W: Write + 'a>(dest: W) -> io::Result<Box<dyn Compressor + 'a>> {
119        // a "pretty high" compression level, since decompression should be nearly the same no
120        // matter what compression level. Maybe we should turn this to 22 or whatever the max is...
121        let stream = SeekableCStream::new(COMPRESSION_LEVEL, FRAME_SIZE).map_err(err_to_io)?;
122        Ok(Box::new(ZstdCompressor {
123            f: dest,
124            stream,
125            buf: vec![0_u8; CStream::out_size()],
126        }))
127    }
128
129    fn decompress<'a, R: Read + Seek + 'a>(source: R) -> io::Result<Box<dyn Decompressor + 'a>> {
130        let stream = Seekable::init(Box::new(source)).map_err(err_to_io)?;
131
132        // zstd-seekable doesn't like it when we pass a buffer past the end of the uncompressed
133        // stream, so let's figure out the size of the uncompressed file so we can implement
134        // ::read() in a reasonable way. This also lets us implement SeekFrom::End.
135        let uncompressed_length = (0..stream.get_num_frames())
136            .map(|i| stream.get_frame_decompressed_size(i) as u64)
137            .sum();
138        Ok(Box::new(ZstdDecompressor {
139            stream,
140            offset: 0,
141            uncompressed_length,
142        }))
143    }
144
145    fn append_extension(media_type: &str) -> String {
146        format!("{media_type}+zstd")
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::tests::{compress_decompress, compression_is_seekable};

    /// Runs the shared `compress_decompress` check against the zstd implementation.
    #[test]
    fn test_zstd_roundtrip() -> anyhow::Result<()> {
        compress_decompress::<Zstd>()
    }

    /// Runs the shared `compression_is_seekable` check against the zstd implementation.
    #[test]
    fn test_zstd_seekable() -> anyhow::Result<()> {
        compression_is_seekable::<Zstd>()
    }
}