1use std::collections::HashSet;
4
5use archive_trait::{
6 Archive as ArchiveTrait, Member, MemberMetadata, MemberPayload as MemberPayloadTrait,
7 SpecialKind,
8};
9use tar_framing::{
10 ArchiveFormat, FrameError, PaxKeyword, PaxKind, PaxRecord, UstarKind,
11 logical::{MemberExtensions, MemberFrame, MemberPayload as FramingMemberPayload, TarReader},
12};
13use thiserror::Error;
14use tokio::io::AsyncRead;
15
16pub use tar_framing::{
17 DEFAULT_MAX_GLOBAL_PAX_EXTENSIONS_SIZE, DEFAULT_MAX_GNU_EXTENSION_SIZE,
18 DEFAULT_MAX_PAX_EXTENSION_SIZE,
19};
20
/// Archive decoder that reads member frames from a [`TarReader`] and checks
/// each one against a [`DecodePolicy`] before handing it out as a [`Member`].
pub struct TarArchive<R> {
    /// Framing-level reader that produces the member frames.
    reader: TarReader<R>,
    /// Policy every frame must satisfy before it is projected into a member.
    policy: DecodePolicy,
    /// Set once the reader reports end of input or any step of `next_member`
    /// fails; from then on `next_member` returns `Ok(None)` without touching
    /// the reader again.
    fused: bool,
}
30
31impl<R> TarArchive<R> {
32 pub fn new(reader: R) -> Self {
34 Self::new_with_policy(reader, DecodePolicy::default())
35 }
36
37 pub fn new_with_policy(reader: R, policy: DecodePolicy) -> Self {
39 let mut reader = TarReader::new(reader);
40 reader.set_max_pax_extension_size(policy.pax_policy.max_extension_size);
41 reader.set_max_global_pax_extensions_size(policy.pax_policy.max_global_extensions_size);
42 reader.set_allow_all_nul_numeric_fields(policy.allow_all_nul_numeric_fields);
43 reader.set_max_gnu_extension_size(policy.max_gnu_extension_size);
44 Self {
45 reader,
46 policy,
47 fused: false,
48 }
49 }
50}
51
/// Decoding rules that [`TarArchive`] applies on top of the framing layer.
///
/// Start from [`DecodePolicy::default`] and adjust individual rules with the
/// chained setter methods.
#[derive(Clone, Copy, Debug)]
pub struct DecodePolicy {
    /// When `false`, frames whose header format is `ArchiveFormat::Gnu` are
    /// rejected with `DecodePolicyViolation::GnuArchive`.
    allow_gnu: bool,
    /// Forwarded to `TarReader::set_allow_all_nul_numeric_fields` when the
    /// archive is constructed.
    // NOTE(review): presumably controls whether numeric header fields made up
    // entirely of NUL bytes are accepted — confirm against `tar_framing`.
    allow_all_nul_numeric_fields: bool,
    /// Forwarded to `TarReader::set_max_gnu_extension_size` when the archive
    /// is constructed.
    max_gnu_extension_size: u64,
    /// Rules specific to pax extended headers.
    pax_policy: PaxDecodePolicy,
}
62
/// Decoding rules for pax extended headers, used as part of a
/// [`DecodePolicy`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PaxDecodePolicy {
    /// Forwarded to `TarReader::set_max_pax_extension_size` when the archive
    /// is constructed.
    max_extension_size: u64,
    /// Forwarded to `TarReader::set_max_global_pax_extensions_size` when the
    /// archive is constructed.
    max_global_extensions_size: u64,
    /// When `false`, any global pax extension is rejected with
    /// `DecodePolicyViolation::GlobalPaxExtension`.
    allow_global_pax_extensions: bool,
    /// When `false`, any record parsed as `PaxRecord::Vendor` is rejected with
    /// `DecodePolicyViolation::PaxVendorExtension`.
    allow_unknown_pax_vendor_records: bool,
    /// When `false`, a second record with an already seen keyword in the same
    /// extended header is rejected with
    /// `DecodePolicyViolation::DuplicatePaxRecord`.
    allow_duplicate_pax_records: bool,
    /// When `false`, a global extended header carrying a `path`, `linkpath`
    /// or `size` record is rejected with
    /// `DecodePolicyViolation::GlobalPaxMemberMetadata`.
    allow_global_pax_member_metadata: bool,
}
75
76impl Default for PaxDecodePolicy {
77 fn default() -> Self {
78 Self {
79 max_extension_size: DEFAULT_MAX_PAX_EXTENSION_SIZE,
80 max_global_extensions_size: DEFAULT_MAX_GLOBAL_PAX_EXTENSIONS_SIZE,
81 allow_global_pax_extensions: true,
82 allow_unknown_pax_vendor_records: false,
83 allow_duplicate_pax_records: false,
84 allow_global_pax_member_metadata: false,
85 }
86 }
87}
88
89impl Default for DecodePolicy {
90 fn default() -> Self {
91 Self {
92 allow_gnu: true,
93 allow_all_nul_numeric_fields: true,
94 max_gnu_extension_size: DEFAULT_MAX_GNU_EXTENSION_SIZE,
95 pax_policy: PaxDecodePolicy::default(),
96 }
97 }
98}
99
100impl DecodePolicy {
101 pub fn allow_gnu(mut self, allow: bool) -> Self {
108 self.allow_gnu = allow;
109 self
110 }
111
112 pub fn allow_all_nul_numeric_fields(mut self, allow: bool) -> Self {
120 self.allow_all_nul_numeric_fields = allow;
121 self
122 }
123
124 pub fn max_gnu_extension_size(mut self, max_gnu_extension_size: u64) -> Self {
132 self.max_gnu_extension_size = max_gnu_extension_size;
133 self
134 }
135
136 pub fn pax_policy(mut self, policy: PaxDecodePolicy) -> Self {
138 self.pax_policy = policy;
139 self
140 }
141
142 fn check_format(&self, position: u64, format: ArchiveFormat) -> Result<(), DecodeError> {
143 if format == ArchiveFormat::Gnu && !self.allow_gnu {
144 return Err(DecodeError::policy_violation(
145 position,
146 DecodePolicyViolation::GnuArchive,
147 ));
148 }
149 Ok(())
150 }
151
152 fn check_global_pax(&self, position: u64, records: &[PaxRecord]) -> Result<(), DecodeError> {
153 self.pax_policy.check_global_pax_extension(position)?;
154 self.pax_policy
155 .check_pax_records(position, PaxKind::Global, records)
156 }
157
158 fn check_member<R>(&self, frame: &MemberFrame<'_, R>) -> Result<(), DecodeError> {
159 if let MemberExtensions::Pax(state) = &frame.extensions {
160 for extension in state
161 .extensions()
162 .filter(|extension| extension.kind == PaxKind::Global)
163 {
164 self.check_global_pax(extension.position, extension.records())?;
165 }
166 }
167 let format_position = match &frame.extensions {
168 MemberExtensions::Pax(_) => frame.header.position,
169 MemberExtensions::Gnu {
170 long_name,
171 long_link,
172 } => long_name
173 .iter()
174 .chain(long_link.iter())
175 .map(|header| header.position)
176 .min()
177 .unwrap_or(frame.header.position),
178 };
179 self.check_format(format_position, frame.header.format)?;
180 if let MemberExtensions::Pax(state) = &frame.extensions {
181 for extension in state
182 .extensions()
183 .filter(|extension| extension.kind == PaxKind::Local)
184 {
185 self.pax_policy.check_pax_records(
186 extension.position,
187 PaxKind::Local,
188 extension.records(),
189 )?;
190 }
191 }
192 Ok(())
193 }
194}
195
196impl PaxDecodePolicy {
197 pub fn max_extension_size(mut self, max_extension_size: u64) -> Self {
208 self.max_extension_size = max_extension_size;
209 self
210 }
211
212 pub fn max_global_extensions_size(mut self, max_global_extensions_size: u64) -> Self {
222 self.max_global_extensions_size = max_global_extensions_size;
223 self
224 }
225
226 pub fn allow_global_pax_extensions(mut self, allow: bool) -> Self {
235 self.allow_global_pax_extensions = allow;
236 self
237 }
238
239 pub fn allow_unknown_pax_vendor_records(mut self, allow: bool) -> Self {
254 self.allow_unknown_pax_vendor_records = allow;
255 self
256 }
257
258 pub fn allow_duplicate_pax_records(mut self, allow: bool) -> Self {
265 self.allow_duplicate_pax_records = allow;
266 self
267 }
268
269 pub fn allow_global_pax_member_metadata(mut self, allow: bool) -> Self {
277 self.allow_global_pax_member_metadata = allow;
278 self
279 }
280
281 fn check_global_pax_extension(&self, position: u64) -> Result<(), DecodeError> {
282 if !self.allow_global_pax_extensions {
283 return Err(DecodeError::policy_violation(
284 position,
285 DecodePolicyViolation::GlobalPaxExtension,
286 ));
287 }
288 Ok(())
289 }
290
291 fn check_pax_records(
292 &self,
293 position: u64,
294 kind: PaxKind,
295 records: &[PaxRecord],
296 ) -> Result<(), DecodeError> {
297 if !self.allow_unknown_pax_vendor_records {
298 for record in records {
299 if let PaxRecord::Vendor { vendor, name, .. } = record {
300 return Err(DecodeError::policy_violation(
301 position,
302 DecodePolicyViolation::PaxVendorExtension {
303 vendor: vendor.to_string(),
304 name: name.to_string(),
305 },
306 ));
307 }
308 }
309 }
310
311 if kind == PaxKind::Global && !self.allow_global_pax_member_metadata {
312 for record in records {
313 let keyword = match record.keyword() {
314 PaxKeyword::Path => Some("path"),
315 PaxKeyword::LinkPath => Some("linkpath"),
316 PaxKeyword::Size => Some("size"),
317 _ => None,
318 };
319 if let Some(keyword) = keyword {
320 return Err(DecodeError::policy_violation(
321 position,
322 DecodePolicyViolation::GlobalPaxMemberMetadata { keyword },
323 ));
324 }
325 }
326 }
327
328 if !self.allow_duplicate_pax_records {
329 let mut keywords = HashSet::new();
330 for record in records {
331 let keyword = record.keyword();
332 if !keywords.insert(keyword.clone()) {
333 return Err(DecodeError::policy_violation(
334 position,
335 DecodePolicyViolation::DuplicatePaxRecord {
336 keyword: keyword.to_string(),
337 },
338 ));
339 }
340 }
341 }
342
343 Ok(())
344 }
345}
346
/// The specific rule of a [`DecodePolicy`] that an input violated.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum DecodePolicyViolation {
    /// The member is in GNU format while GNU archives are disallowed.
    #[error("GNU archives are not allowed")]
    GnuArchive,
    /// A global pax extended header was present while they are disallowed.
    #[error("global pax extended headers are not allowed")]
    GlobalPaxExtension,
    /// A pax vendor record was present while vendor records are disallowed.
    #[error("pax vendor extension {vendor}.{name} is not allowed")]
    PaxVendorExtension {
        /// Vendor part of the offending record, rendered as text.
        vendor: String,
        /// Name part of the offending record, rendered as text.
        name: String,
    },
    /// One extended header repeated a record keyword while duplicates are
    /// disallowed.
    #[error("pax extended header contains duplicate record {keyword}")]
    DuplicatePaxRecord {
        /// The repeated keyword, rendered as text.
        keyword: String,
    },
    /// A global extended header carried a `path`, `linkpath` or `size` record
    /// while that is disallowed.
    #[error("global pax extended header contains restricted member metadata {keyword}")]
    GlobalPaxMemberMetadata {
        /// The restricted keyword that was found: `"path"`, `"linkpath"` or
        /// `"size"`.
        keyword: &'static str,
    },
}
377
/// Errors produced while decoding an archive through [`TarArchive`].
#[derive(Debug, Error)]
pub enum DecodeError {
    /// An error reported by the framing layer, passed through unchanged.
    #[error(transparent)]
    Framing(#[from] FrameError),
    /// A member's path or link path was not valid UTF-8.
    #[error("at byte {position}: {field} is not valid UTF-8")]
    InvalidUtf8 {
        /// Position of the member header whose field failed to decode.
        position: u64,
        /// Name of the offending field: `"path"` or `"linkpath"`.
        field: &'static str,
    },
    /// The input was well-formed but rejected by the configured
    /// [`DecodePolicy`].
    #[error("at byte {position}: decode policy rejected input: {violation}")]
    PolicyViolation {
        /// Position reported for the offending header or extension.
        position: u64,
        /// The rule that was violated.
        violation: DecodePolicyViolation,
    },
}
401
402impl DecodeError {
403 fn policy_violation(position: u64, violation: DecodePolicyViolation) -> Self {
404 Self::PolicyViolation {
405 position,
406 violation,
407 }
408 }
409}
410
/// Payload of a member yielded by [`TarArchive`]; a thin wrapper that adapts
/// the framing layer's payload to the archive trait's payload interface.
pub struct TarMemberPayload<'a, R> {
    /// The framing-level payload all calls are delegated to.
    payload: FramingMemberPayload<'a, R>,
}
415
416impl<R: AsyncRead + Unpin> MemberPayloadTrait for TarMemberPayload<'_, R> {
417 type Error = DecodeError;
418
419 async fn next_chunk(
420 &mut self,
421 buffer: &mut Vec<u8>,
422 target_len: usize,
423 ) -> Result<bool, Self::Error> {
424 self.payload
425 .next_chunk(buffer, target_len)
426 .await
427 .map_err(Into::into)
428 }
429
430 async fn skip(self) -> Result<(), Self::Error> {
431 self.payload.skip().await.map_err(Into::into)
432 }
433}
434
435impl<R: AsyncRead + Unpin> ArchiveTrait for TarArchive<R> {
436 type Error = DecodeError;
437 type Payload<'a>
438 = TarMemberPayload<'a, R>
439 where
440 Self: 'a;
441
442 async fn next_member<'a>(
443 &'a mut self,
444 ) -> Result<Option<Member<Self::Payload<'a>>>, Self::Error> {
445 if self.fused {
446 return Ok(None);
447 }
448 let frame = match self.reader.next_frame().await {
449 Ok(Some(frame)) => frame,
450 Ok(None) => {
451 self.fused = true;
452 return Ok(None);
453 }
454 Err(error) => {
455 self.fused = true;
456 return Err(error.into());
457 }
458 };
459 if let Err(error) = self.policy.check_member(&frame) {
460 self.fused = true;
461 return Err(error);
462 }
463 match project_member(frame) {
464 Ok(member) => Ok(Some(member)),
465 Err(error) => {
466 self.fused = true;
467 Err(error)
468 }
469 }
470 }
471}
472
473fn project_member<'a, R>(
474 frame: MemberFrame<'a, R>,
475) -> Result<Member<TarMemberPayload<'a, R>>, DecodeError> {
476 let position = frame.header.position;
477 let kind = frame.header.kind;
478 let size = frame.header.effective_size;
479 let executable = frame.header.mode.unwrap_or_default() & 0o111 != 0;
480 let path = std::str::from_utf8(frame.effective_path()?.as_ref())
481 .map(str::to_owned)
482 .map_err(|_| DecodeError::InvalidUtf8 {
483 position,
484 field: "path",
485 })?;
486 let target = if matches!(kind, UstarKind::HardLink | UstarKind::SymbolicLink) {
487 std::str::from_utf8(frame.effective_link_path()?.as_ref())
488 .map(str::to_owned)
489 .map_err(|_| DecodeError::InvalidUtf8 {
490 position,
491 field: "linkpath",
492 })?
493 } else {
494 String::new()
495 };
496 let metadata = MemberMetadata { path, position };
497
498 Ok(match kind {
499 UstarKind::Regular | UstarKind::Contiguous => Member::File {
500 metadata,
501 size,
502 executable,
503 payload: TarMemberPayload {
504 payload: frame.payload,
505 },
506 },
507 UstarKind::Directory => Member::Directory { metadata },
508 UstarKind::SymbolicLink => Member::SymbolicLink { metadata, target },
509 UstarKind::HardLink => Member::HardLink {
510 metadata,
511 target,
512 size,
513 payload: TarMemberPayload {
514 payload: frame.payload,
515 },
516 },
517 UstarKind::CharacterDevice => Member::Special {
518 metadata,
519 kind: SpecialKind::CharacterDevice,
520 },
521 UstarKind::BlockDevice => Member::Special {
522 metadata,
523 kind: SpecialKind::BlockDevice,
524 },
525 UstarKind::Fifo => Member::Special {
526 metadata,
527 kind: SpecialKind::Fifo,
528 },
529 })
530}