// puzzlefs_lib/compression/zstd_seekable_wrapper.rs
use std::cmp::min;
use std::io;
use std::io::{Read, Seek, Write};

use zstd_seekable::{CStream, Seekable, SeekableCStream};

use crate::compression::{Compression, Compressor, Decompressor};
8
/// Frame size in bytes handed to `SeekableCStream::new` when compressing.
///
/// In the zstd seekable format each frame is compressed independently and is the unit of random
/// access. Presumably larger frames compress better but cost more wasted decompression on small
/// random reads — TODO confirm the trade-off that motivated 4 KiB.
const FRAME_SIZE: usize = 4 * 1024;

/// zstd compression level handed to `SeekableCStream::new` when compressing.
const COMPRESSION_LEVEL: usize = 3;
19
/// Wraps any thread-safe error in an [`io::Error`] of kind [`io::ErrorKind::Other`].
///
/// This lets zstd and integer-conversion failures be propagated with `?` out of the
/// `std::io` trait methods implemented in this module.
fn err_to_io<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
23
24pub struct ZstdCompressor<W> {
25 f: W,
26 stream: SeekableCStream,
27 buf: Vec<u8>,
28}
29
30impl<W: Write> Compressor for ZstdCompressor<W> {
31 fn end(mut self: Box<Self>) -> io::Result<()> {
32 loop {
36 let size = self.stream.end_stream(&mut self.buf).map_err(err_to_io)?;
37 self.f.write_all(&self.buf[0..size])?;
38 if size == 0 {
39 break;
40 }
41 }
42 Ok(())
43 }
44}
45
46impl<W: Write> Write for ZstdCompressor<W> {
47 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
48 let (out_pos, in_pos) = self
50 .stream
51 .compress(&mut self.buf, buf)
52 .map_err(err_to_io)?;
53 self.f.write_all(&self.buf[0..out_pos])?;
54 Ok(in_pos)
55 }
56
57 fn flush(&mut self) -> io::Result<()> {
58 Ok(())
61 }
62}
63
64pub struct ZstdDecompressor<'a, R: Read + Seek> {
65 stream: Seekable<'a, R>,
66 offset: u64,
67 uncompressed_length: u64,
68}
69
70impl<'a, R: Seek + Read> Decompressor for ZstdDecompressor<'a, R> {
71 fn get_uncompressed_length(&mut self) -> io::Result<u64> {
72 Ok(self.uncompressed_length)
73 }
74}
75
76impl<'a, R: Seek + Read> Seek for ZstdDecompressor<'a, R> {
77 fn seek(&mut self, offset: io::SeekFrom) -> io::Result<u64> {
78 match offset {
79 io::SeekFrom::Start(s) => {
80 self.offset = s;
81 }
82 io::SeekFrom::End(e) => {
83 if e > 0 {
84 return Err(io::Error::new(io::ErrorKind::Other, "zstd seek past end"));
85 }
86 self.offset = self.uncompressed_length - u64::try_from(-e).map_err(err_to_io)?;
87 }
88 io::SeekFrom::Current(c) => {
89 if c > 0 {
90 self.offset += u64::try_from(c).map_err(err_to_io)?;
91 } else {
92 self.offset -= u64::try_from(-c).map_err(err_to_io)?;
93 }
94 }
95 }
96 Ok(self.offset)
97 }
98}
99
100impl<'a, R: Seek + Read> Read for ZstdDecompressor<'a, R> {
101 fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
102 let end = min(out.len(), (self.uncompressed_length - self.offset) as usize);
106 let size = self
107 .stream
108 .decompress(&mut out[0..end], self.offset)
109 .map_err(err_to_io)?;
110 self.offset += size as u64;
111 Ok(size)
112 }
113}
114
115pub struct Zstd {}
116
117impl Compression for Zstd {
118 fn compress<'a, W: Write + 'a>(dest: W) -> io::Result<Box<dyn Compressor + 'a>> {
119 let stream = SeekableCStream::new(COMPRESSION_LEVEL, FRAME_SIZE).map_err(err_to_io)?;
122 Ok(Box::new(ZstdCompressor {
123 f: dest,
124 stream,
125 buf: vec![0_u8; CStream::out_size()],
126 }))
127 }
128
129 fn decompress<'a, R: Read + Seek + 'a>(source: R) -> io::Result<Box<dyn Decompressor + 'a>> {
130 let stream = Seekable::init(Box::new(source)).map_err(err_to_io)?;
131
132 let uncompressed_length = (0..stream.get_num_frames())
136 .map(|i| stream.get_frame_decompressed_size(i) as u64)
137 .sum();
138 Ok(Box::new(ZstdDecompressor {
139 stream,
140 offset: 0,
141 uncompressed_length,
142 }))
143 }
144
145 fn append_extension(media_type: &str) -> String {
146 format!("{media_type}+zstd")
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::tests::{compress_decompress, compression_is_seekable};

    /// Runs the shared `compress_decompress` check against [`Zstd`].
    #[test]
    fn test_zstd_roundtrip() -> anyhow::Result<()> {
        compress_decompress::<Zstd>()
    }

    /// Runs the shared `compression_is_seekable` check against [`Zstd`].
    #[test]
    fn test_zstd_seekable() -> anyhow::Result<()> {
        compression_is_seekable::<Zstd>()
    }
}