// atlas_runtime/snapshot_utils/archive_format.rs
use {
    std::{fmt, str::FromStr},
    strum::Display,
};

/// Compression names understood by [`ArchiveFormat::from_cli_arg`].
pub const SUPPORTED_ARCHIVE_COMPRESSION: &[&str] = &["zstd", "lz4"];

/// Default compression name; it is one of [`SUPPORTED_ARCHIVE_COMPRESSION`].
pub const DEFAULT_ARCHIVE_COMPRESSION: &str = "zstd";

/// File extension (no leading dot) of a zstd-compressed tar archive.
pub const TAR_ZSTD_EXTENSION: &str = "tar.zst";

/// File extension (no leading dot) of an lz4-compressed tar archive.
pub const TAR_LZ4_EXTENSION: &str = "tar.lz4";

14#[derive(Copy, Clone, Debug, Eq, PartialEq, Display)]
16pub enum ArchiveFormat {
17 TarZstd { config: ZstdConfig },
18 TarLz4,
19}
20
21impl ArchiveFormat {
22 pub fn extension(&self) -> &str {
24 match self {
25 ArchiveFormat::TarZstd { .. } => TAR_ZSTD_EXTENSION,
26 ArchiveFormat::TarLz4 => TAR_LZ4_EXTENSION,
27 }
28 }
29
30 pub fn from_cli_arg(archive_format_str: &str) -> Option<ArchiveFormat> {
31 match archive_format_str {
32 "zstd" => Some(ArchiveFormat::TarZstd {
33 config: ZstdConfig::default(),
34 }),
35 "lz4" => Some(ArchiveFormat::TarLz4),
36 _ => None,
37 }
38 }
39}
40
41impl TryFrom<&str> for ArchiveFormat {
44 type Error = ParseError;
45
46 fn try_from(extension: &str) -> Result<Self, Self::Error> {
47 match extension {
48 TAR_ZSTD_EXTENSION => Ok(ArchiveFormat::TarZstd {
49 config: ZstdConfig::default(),
50 }),
51 TAR_LZ4_EXTENSION => Ok(ArchiveFormat::TarLz4),
52 _ => Err(ParseError::InvalidExtension(extension.to_string())),
53 }
54 }
55}
56
57impl FromStr for ArchiveFormat {
58 type Err = ParseError;
59
60 fn from_str(extension: &str) -> Result<Self, Self::Err> {
61 Self::try_from(extension)
62 }
63}
64
65pub enum ArchiveFormatDecompressor<R> {
66 Zstd(zstd::stream::read::Decoder<'static, R>),
67 Lz4(lz4::Decoder<R>),
68}
69
70impl<R: std::io::BufRead> ArchiveFormatDecompressor<R> {
71 pub fn new(format: ArchiveFormat, input: R) -> std::io::Result<Self> {
72 Ok(match format {
73 ArchiveFormat::TarZstd { .. } => {
74 Self::Zstd(zstd::stream::read::Decoder::with_buffer(input)?)
75 }
76 ArchiveFormat::TarLz4 => {
77 Self::Lz4(lz4::Decoder::new(input).map_err(std::io::Error::other)?)
78 }
79 })
80 }
81}
82
83impl<R: std::io::BufRead> std::io::Read for ArchiveFormatDecompressor<R> {
84 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
85 match self {
86 Self::Zstd(decoder) => decoder.read(buf),
87 Self::Lz4(decoder) => decoder.read(buf),
88 }
89 }
90}
91
/// Error produced when a string cannot be parsed into an [`ArchiveFormat`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ParseError {
    /// The input is not a recognized archive file extension; carries the rejected input.
    InvalidExtension(String),
}

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ParseError::InvalidExtension(extension) => {
                write!(f, "Invalid archive extension: {extension}")
            }
        }
    }
}

/// Lets `ParseError` be boxed as `dyn Error` and propagated with `?` into generic error types.
///
/// It wraps no underlying cause, so the default `source()` (which returns `None`) is kept.
impl std::error::Error for ParseError {}

/// Settings attached to [`ArchiveFormat::TarZstd`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ZstdConfig {
    /// zstd compression level. The derived `Default` sets it to `0`.
    ///
    /// NOTE(review): presumably forwarded to the zstd encoder, where level 0 selects
    /// zstd's built-in default level. Confirm at the encoder call site.
    pub compression_level: i32,
}

#[cfg(test)]
mod tests {
    use {super::*, std::iter::zip};

    const INVALID_EXTENSION: &str = "zip";

    #[test]
    fn test_extension() {
        assert_eq!(
            ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            }
            .extension(),
            TAR_ZSTD_EXTENSION
        );
        assert_eq!(ArchiveFormat::TarLz4.extension(), TAR_LZ4_EXTENSION);
    }

    #[test]
    fn test_try_from() {
        assert_eq!(
            ArchiveFormat::try_from(TAR_ZSTD_EXTENSION),
            Ok(ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            })
        );
        assert_eq!(
            ArchiveFormat::try_from(TAR_LZ4_EXTENSION),
            Ok(ArchiveFormat::TarLz4)
        );
        assert_eq!(
            ArchiveFormat::try_from(INVALID_EXTENSION),
            Err(ParseError::InvalidExtension(INVALID_EXTENSION.to_string()))
        );
    }

    #[test]
    fn test_from_str() {
        assert_eq!(
            ArchiveFormat::from_str(TAR_ZSTD_EXTENSION),
            Ok(ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            })
        );
        assert_eq!(
            ArchiveFormat::from_str(TAR_LZ4_EXTENSION),
            Ok(ArchiveFormat::TarLz4)
        );
        assert_eq!(
            ArchiveFormat::from_str(INVALID_EXTENSION),
            Err(ParseError::InvalidExtension(INVALID_EXTENSION.to_string()))
        );
    }

    #[test]
    fn test_from_cli_arg() {
        let golden = [
            Some(ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            }),
            Some(ArchiveFormat::TarLz4),
        ];

        // `zip` stops at the shorter input, so without this check a new entry in
        // `SUPPORTED_ARCHIVE_COMPRESSION` (or a stale golden table) would go untested.
        assert_eq!(SUPPORTED_ARCHIVE_COMPRESSION.len(), golden.len());

        for (arg, expected) in zip(SUPPORTED_ARCHIVE_COMPRESSION.iter(), golden.into_iter()) {
            assert_eq!(ArchiveFormat::from_cli_arg(arg), expected);
        }

        assert_eq!(ArchiveFormat::from_cli_arg("bad"), None);
    }

    #[test]
    fn test_default_archive_compression_is_supported() {
        assert!(SUPPORTED_ARCHIVE_COMPRESSION.contains(&DEFAULT_ARCHIVE_COMPRESSION));
        assert!(ArchiveFormat::from_cli_arg(DEFAULT_ARCHIVE_COMPRESSION).is_some());
    }

    #[test]
    fn test_extension_round_trip() {
        let formats = [
            ArchiveFormat::TarZstd {
                config: ZstdConfig::default(),
            },
            ArchiveFormat::TarLz4,
        ];
        for format in formats {
            assert_eq!(ArchiveFormat::from_str(format.extension()), Ok(format));
        }
    }

    #[test]
    fn test_parse_error_display() {
        assert_eq!(
            ParseError::InvalidExtension(INVALID_EXTENSION.to_string()).to_string(),
            "Invalid archive extension: zip"
        );
    }
}