atlas_runtime/snapshot_utils/
archive_format.rs

1use {
2    std::{fmt, str::FromStr},
3    strum::Display,
4};
5
/// Compression types that can be specified on the command line
/// (parsed by [`ArchiveFormat::from_cli_arg`]).
pub const SUPPORTED_ARCHIVE_COMPRESSION: &[&str] = &["zstd", "lz4"];
/// The default compression type name; expected to be one of
/// [`SUPPORTED_ARCHIVE_COMPRESSION`].
pub const DEFAULT_ARCHIVE_COMPRESSION: &str = "zstd";

/// File extension (no leading `.`) of zstd-compressed tar snapshot archives.
pub const TAR_ZSTD_EXTENSION: &str = "tar.zst";
/// File extension (no leading `.`) of lz4-compressed tar snapshot archives.
pub const TAR_LZ4_EXTENSION: &str = "tar.lz4";
13
14/// The different archive formats used for snapshots
15#[derive(Copy, Clone, Debug, Eq, PartialEq, Display)]
16pub enum ArchiveFormat {
17    TarZstd { config: ZstdConfig },
18    TarLz4,
19}
20
21impl ArchiveFormat {
22    /// Get the file extension for the ArchiveFormat
23    pub fn extension(&self) -> &str {
24        match self {
25            ArchiveFormat::TarZstd { .. } => TAR_ZSTD_EXTENSION,
26            ArchiveFormat::TarLz4 => TAR_LZ4_EXTENSION,
27        }
28    }
29
30    pub fn from_cli_arg(archive_format_str: &str) -> Option<ArchiveFormat> {
31        match archive_format_str {
32            "zstd" => Some(ArchiveFormat::TarZstd {
33                config: ZstdConfig::default(),
34            }),
35            "lz4" => Some(ArchiveFormat::TarLz4),
36            _ => None,
37        }
38    }
39}
40
41// Change this to `impl<S: AsRef<str>> TryFrom<S> for ArchiveFormat [...]`
42// once this Rust bug is fixed: https://github.com/rust-lang/rust/issues/50133
43impl TryFrom<&str> for ArchiveFormat {
44    type Error = ParseError;
45
46    fn try_from(extension: &str) -> Result<Self, Self::Error> {
47        match extension {
48            TAR_ZSTD_EXTENSION => Ok(ArchiveFormat::TarZstd {
49                config: ZstdConfig::default(),
50            }),
51            TAR_LZ4_EXTENSION => Ok(ArchiveFormat::TarLz4),
52            _ => Err(ParseError::InvalidExtension(extension.to_string())),
53        }
54    }
55}
56
57impl FromStr for ArchiveFormat {
58    type Err = ParseError;
59
60    fn from_str(extension: &str) -> Result<Self, Self::Err> {
61        Self::try_from(extension)
62    }
63}
64
65pub enum ArchiveFormatDecompressor<R> {
66    Zstd(zstd::stream::read::Decoder<'static, R>),
67    Lz4(lz4::Decoder<R>),
68}
69
70impl<R: std::io::BufRead> ArchiveFormatDecompressor<R> {
71    pub fn new(format: ArchiveFormat, input: R) -> std::io::Result<Self> {
72        Ok(match format {
73            ArchiveFormat::TarZstd { .. } => {
74                Self::Zstd(zstd::stream::read::Decoder::with_buffer(input)?)
75            }
76            ArchiveFormat::TarLz4 => {
77                Self::Lz4(lz4::Decoder::new(input).map_err(std::io::Error::other)?)
78            }
79        })
80    }
81}
82
83impl<R: std::io::BufRead> std::io::Read for ArchiveFormatDecompressor<R> {
84    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
85        match self {
86            Self::Zstd(decoder) => decoder.read(buf),
87            Self::Lz4(decoder) => decoder.read(buf),
88        }
89    }
90}
91
/// Error returned when a string cannot be parsed into an [`ArchiveFormat`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ParseError {
    /// The input is not a recognized snapshot archive extension; holds the rejected input.
    InvalidExtension(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ParseError::InvalidExtension(extension) => {
                write!(f, "Invalid archive extension: {extension}")
            }
        }
    }
}

// `Debug` + `Display` are all `std::error::Error` requires. Implementing it lets callers box
// this as `dyn Error`, propagate it with `?` into boxed errors, and use it as an error source.
impl std::error::Error for ParseError {}
106
/// Settings used when zstd is the snapshot archive format
/// (carried by [`ArchiveFormat::TarZstd`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ZstdConfig {
    /// Compression level to use when archiving with zstd (`Default` gives `0`).
    pub compression_level: i32,
}
113
#[cfg(test)]
mod tests {
    use {super::*, std::iter::zip};
    const INVALID_EXTENSION: &str = "zip";

    #[test]
    fn test_extension() {
        assert_eq!(
            ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            }
            .extension(),
            TAR_ZSTD_EXTENSION
        );
        assert_eq!(ArchiveFormat::TarLz4.extension(), TAR_LZ4_EXTENSION);
    }

    /// Every format's extension must parse back to the same format.
    #[test]
    fn test_extension_round_trip() {
        for format in [
            ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            },
            ArchiveFormat::TarLz4,
        ] {
            assert_eq!(ArchiveFormat::try_from(format.extension()), Ok(format));
        }
    }

    #[test]
    fn test_try_from() {
        assert_eq!(
            ArchiveFormat::try_from(TAR_ZSTD_EXTENSION),
            Ok(ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            })
        );
        assert_eq!(
            ArchiveFormat::try_from(TAR_LZ4_EXTENSION),
            Ok(ArchiveFormat::TarLz4)
        );
        assert_eq!(
            ArchiveFormat::try_from(INVALID_EXTENSION),
            Err(ParseError::InvalidExtension(INVALID_EXTENSION.to_string()))
        );
    }

    #[test]
    fn test_from_str() {
        assert_eq!(
            ArchiveFormat::from_str(TAR_ZSTD_EXTENSION),
            Ok(ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            })
        );
        assert_eq!(
            ArchiveFormat::from_str(TAR_LZ4_EXTENSION),
            Ok(ArchiveFormat::TarLz4)
        );
        assert_eq!(
            ArchiveFormat::from_str(INVALID_EXTENSION),
            Err(ParseError::InvalidExtension(INVALID_EXTENSION.to_string()))
        );
    }

    #[test]
    fn test_from_cli_arg() {
        let golden = [
            Some(ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            }),
            Some(ArchiveFormat::TarLz4),
        ];

        // `zip` stops at the shorter side; make sure every supported CLI arg has an
        // expected value so a newly added one cannot go silently untested.
        assert_eq!(SUPPORTED_ARCHIVE_COMPRESSION.len(), golden.len());

        for (arg, expected) in zip(SUPPORTED_ARCHIVE_COMPRESSION.iter(), golden.into_iter()) {
            assert_eq!(ArchiveFormat::from_cli_arg(arg), expected);
        }

        assert_eq!(ArchiveFormat::from_cli_arg("bad"), None);
    }

    #[test]
    fn test_default_archive_compression_is_supported() {
        assert!(SUPPORTED_ARCHIVE_COMPRESSION.contains(&DEFAULT_ARCHIVE_COMPRESSION));
        assert!(ArchiveFormat::from_cli_arg(DEFAULT_ARCHIVE_COMPRESSION).is_some());
    }

    #[test]
    fn test_parse_error_display() {
        assert_eq!(
            ParseError::InvalidExtension(INVALID_EXTENSION.to_string()).to_string(),
            "Invalid archive extension: zip"
        );
    }
}