1use std::collections::HashSet;
4
5use archive_trait::{
6 Archive as ArchiveTrait, Member, MemberMetadata, MemberPayload as MemberPayloadTrait,
7 SpecialKind,
8};
9use tar_framing::{
10 ArchiveFormat, FrameError, PaxKeyword, PaxKind, PaxRecord, PaxValue, UstarKind,
11 logical::{MemberExtensions, MemberFrame, MemberPayload as FramingMemberPayload, TarReader},
12};
13use thiserror::Error;
14use tokio::io::AsyncRead;
15
16pub use tar_framing::{
17 DEFAULT_MAX_GLOBAL_PAX_EXTENSIONS_SIZE, DEFAULT_MAX_GNU_EXTENSION_SIZE,
18 DEFAULT_MAX_PAX_EXTENSION_SIZE,
19};
20
/// Adapts a [`TarReader`] to the archive-trait interface, applying a [`DecodePolicy`] to every
/// member frame before it is yielded.
pub struct TarArchive<R> {
    /// Framing reader that produces raw member frames.
    reader: TarReader<R>,
    /// Policy checked against each frame in `next_member` before projection.
    policy: DecodePolicy,
    /// Set once `next_member` reaches end of archive or returns an error; after that every
    /// call returns `Ok(None)` without reading further.
    fused: bool,
}
30
31impl<R> TarArchive<R> {
32 pub fn new(reader: R) -> Self {
34 Self::new_with_policy(reader, DecodePolicy::default())
35 }
36
37 pub fn new_with_policy(reader: R, policy: DecodePolicy) -> Self {
39 let mut reader = TarReader::new(reader);
40 reader.set_max_pax_extension_size(policy.pax_policy.max_extension_size);
41 reader.set_max_global_pax_extensions_size(policy.pax_policy.max_global_extensions_size);
42 reader.set_allow_all_nul_numeric_fields(policy.allow_all_nul_numeric_fields);
43 reader.set_max_gnu_extension_size(policy.max_gnu_extension_size);
44 Self {
45 reader,
46 policy,
47 fused: false,
48 }
49 }
50}
51
52#[derive(Clone, Copy, Debug)]
56pub struct DecodePolicy {
57 allow_gnu: bool,
58 allow_all_nul_numeric_fields: bool,
59 max_gnu_extension_size: u64,
60 pax_policy: PaxDecodePolicy,
61}
62
/// Pax-specific limits and switches that form part of a [`DecodePolicy`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PaxDecodePolicy {
    /// Forwarded to the framing reader's `set_max_pax_extension_size`.
    max_extension_size: u64,
    /// Forwarded to the framing reader's `set_max_global_pax_extensions_size`.
    max_global_extensions_size: u64,
    /// Whether `PaxValue::Value` payloads of vendor records may be non-UTF-8.
    allow_non_utf8_pax_vendor_values: bool,
    /// Whether global pax extended headers are accepted at all.
    allow_global_pax_extensions: bool,
    /// Whether records parsed as `PaxRecord::Vendor` are accepted.
    allow_unknown_pax_vendor_records: bool,
    /// Whether a keyword may appear more than once within one extended header.
    allow_duplicate_pax_records: bool,
    /// Whether global headers may carry `path`, `linkpath`, or `size` records.
    allow_global_pax_member_metadata: bool,
}
76
77impl Default for PaxDecodePolicy {
78 fn default() -> Self {
79 Self {
80 max_extension_size: DEFAULT_MAX_PAX_EXTENSION_SIZE,
81 max_global_extensions_size: DEFAULT_MAX_GLOBAL_PAX_EXTENSIONS_SIZE,
82 allow_non_utf8_pax_vendor_values: true,
83 allow_global_pax_extensions: true,
84 allow_unknown_pax_vendor_records: false,
85 allow_duplicate_pax_records: false,
86 allow_global_pax_member_metadata: false,
87 }
88 }
89}
90
91impl Default for DecodePolicy {
92 fn default() -> Self {
93 Self {
94 allow_gnu: true,
95 allow_all_nul_numeric_fields: true,
96 max_gnu_extension_size: DEFAULT_MAX_GNU_EXTENSION_SIZE,
97 pax_policy: PaxDecodePolicy::default(),
98 }
99 }
100}
101
102impl DecodePolicy {
103 pub fn allow_gnu(mut self, allow: bool) -> Self {
110 self.allow_gnu = allow;
111 self
112 }
113
114 pub fn allow_all_nul_numeric_fields(mut self, allow: bool) -> Self {
122 self.allow_all_nul_numeric_fields = allow;
123 self
124 }
125
126 pub fn max_gnu_extension_size(mut self, max_gnu_extension_size: u64) -> Self {
134 self.max_gnu_extension_size = max_gnu_extension_size;
135 self
136 }
137
138 pub fn pax_policy(mut self, policy: PaxDecodePolicy) -> Self {
140 self.pax_policy = policy;
141 self
142 }
143
144 fn check_format(&self, position: u64, format: ArchiveFormat) -> Result<(), DecodeError> {
145 if format == ArchiveFormat::Gnu && !self.allow_gnu {
146 return Err(DecodeError::policy_violation(
147 position,
148 DecodePolicyViolation::GnuArchive,
149 ));
150 }
151 Ok(())
152 }
153
154 fn check_global_pax(&self, position: u64, records: &[PaxRecord]) -> Result<(), DecodeError> {
155 self.pax_policy.check_global_pax_extension(position)?;
156 self.pax_policy
157 .check_pax_records(position, PaxKind::Global, records)
158 }
159
160 fn check_member<R>(&self, frame: &MemberFrame<'_, R>) -> Result<(), DecodeError> {
161 if let MemberExtensions::Pax(state) = &frame.extensions {
162 for extension in state
163 .extensions()
164 .filter(|extension| extension.kind == PaxKind::Global)
165 {
166 self.check_global_pax(extension.position, extension.records())?;
167 }
168 }
169 let format_position = match &frame.extensions {
170 MemberExtensions::Pax(_) => frame.header.position,
171 MemberExtensions::Gnu {
172 long_name,
173 long_link,
174 } => long_name
175 .iter()
176 .chain(long_link.iter())
177 .map(|header| header.position)
178 .min()
179 .unwrap_or(frame.header.position),
180 };
181 self.check_format(format_position, frame.header.format)?;
182 if let MemberExtensions::Pax(state) = &frame.extensions {
183 for extension in state
184 .extensions()
185 .filter(|extension| extension.kind == PaxKind::Local)
186 {
187 self.pax_policy.check_pax_records(
188 extension.position,
189 PaxKind::Local,
190 extension.records(),
191 )?;
192 }
193 }
194 Ok(())
195 }
196}
197
198impl PaxDecodePolicy {
199 pub fn max_extension_size(mut self, max_extension_size: u64) -> Self {
210 self.max_extension_size = max_extension_size;
211 self
212 }
213
214 pub fn max_global_extensions_size(mut self, max_global_extensions_size: u64) -> Self {
224 self.max_global_extensions_size = max_global_extensions_size;
225 self
226 }
227
228 pub fn allow_non_utf8_pax_vendor_values(mut self, allow: bool) -> Self {
238 self.allow_non_utf8_pax_vendor_values = allow;
239 self
240 }
241
242 pub fn allow_global_pax_extensions(mut self, allow: bool) -> Self {
251 self.allow_global_pax_extensions = allow;
252 self
253 }
254
255 pub fn allow_unknown_pax_vendor_records(mut self, allow: bool) -> Self {
270 self.allow_unknown_pax_vendor_records = allow;
271 self
272 }
273
274 pub fn allow_duplicate_pax_records(mut self, allow: bool) -> Self {
281 self.allow_duplicate_pax_records = allow;
282 self
283 }
284
285 pub fn allow_global_pax_member_metadata(mut self, allow: bool) -> Self {
293 self.allow_global_pax_member_metadata = allow;
294 self
295 }
296
297 fn check_global_pax_extension(&self, position: u64) -> Result<(), DecodeError> {
298 if !self.allow_global_pax_extensions {
299 return Err(DecodeError::policy_violation(
300 position,
301 DecodePolicyViolation::GlobalPaxExtension,
302 ));
303 }
304 Ok(())
305 }
306
307 fn check_pax_records(
308 &self,
309 position: u64,
310 kind: PaxKind,
311 records: &[PaxRecord],
312 ) -> Result<(), DecodeError> {
313 for record in records {
314 if let PaxRecord::Vendor {
315 vendor,
316 name,
317 value,
318 } = record
319 {
320 if !self.allow_unknown_pax_vendor_records {
321 return Err(DecodeError::policy_violation(
322 position,
323 DecodePolicyViolation::PaxVendorExtension {
324 vendor: vendor.to_string(),
325 name: name.to_string(),
326 },
327 ));
328 }
329
330 if !self.allow_non_utf8_pax_vendor_values
331 && let PaxValue::Value(value) = value
332 && std::str::from_utf8(value).is_err()
333 {
334 return Err(DecodeError::policy_violation(
335 position,
336 DecodePolicyViolation::NonUtf8PaxVendorValue {
337 vendor: vendor.to_string(),
338 name: name.to_string(),
339 },
340 ));
341 }
342 }
343 }
344
345 if kind == PaxKind::Global && !self.allow_global_pax_member_metadata {
346 for record in records {
347 let keyword = match record.keyword() {
348 PaxKeyword::Path => Some("path"),
349 PaxKeyword::LinkPath => Some("linkpath"),
350 PaxKeyword::Size => Some("size"),
351 _ => None,
352 };
353 if let Some(keyword) = keyword {
354 return Err(DecodeError::policy_violation(
355 position,
356 DecodePolicyViolation::GlobalPaxMemberMetadata { keyword },
357 ));
358 }
359 }
360 }
361
362 if !self.allow_duplicate_pax_records {
363 let mut keywords = HashSet::new();
364 for record in records {
365 let keyword = record.keyword();
366 if !keywords.insert(keyword.clone()) {
367 return Err(DecodeError::policy_violation(
368 position,
369 DecodePolicyViolation::DuplicatePaxRecord {
370 keyword: keyword.to_string(),
371 },
372 ));
373 }
374 }
375 }
376
377 Ok(())
378 }
379}
380
/// The reason a [`DecodePolicy`] rejected part of an archive.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum DecodePolicyViolation {
    /// A member with a GNU-format header was found while GNU archives are disallowed.
    #[error("GNU archives are not allowed")]
    GnuArchive,
    /// A global pax extended header was found while global headers are disallowed.
    #[error("global pax extended headers are not allowed")]
    GlobalPaxExtension,
    /// A vendor pax record was found while vendor records are disallowed.
    #[error("pax vendor extension {vendor}.{name} is not allowed")]
    PaxVendorExtension {
        /// Vendor part of the offending record's keyword.
        vendor: String,
        /// Name part of the offending record's keyword.
        name: String,
    },
    /// A vendor pax record value was not valid UTF-8 while such values are disallowed.
    #[error("pax vendor extension {vendor}.{name} contains a non-UTF-8 value")]
    NonUtf8PaxVendorValue {
        /// Vendor part of the offending record's keyword.
        vendor: String,
        /// Name part of the offending record's keyword.
        name: String,
    },
    /// A keyword appeared more than once in a single extended header.
    #[error("pax extended header contains duplicate record {keyword}")]
    DuplicatePaxRecord {
        /// The repeated keyword, as rendered by its `Display` implementation.
        keyword: String,
    },
    /// A global header carried a `path`, `linkpath`, or `size` record while that is disallowed.
    #[error("global pax extended header contains restricted member metadata {keyword}")]
    GlobalPaxMemberMetadata {
        /// The restricted keyword that was found.
        keyword: &'static str,
    },
}
419
/// Errors produced while decoding a tar archive.
#[derive(Debug, Error)]
pub enum DecodeError {
    /// An error reported by the underlying tar framing layer.
    #[error(transparent)]
    Framing(#[from] FrameError),
    /// A member's effective path or link path was not valid UTF-8.
    #[error("at byte {position}: {field} is not valid UTF-8")]
    InvalidUtf8 {
        /// Byte position of the member's header.
        position: u64,
        /// Which field failed to decode; this module uses `"path"` and `"linkpath"`.
        field: &'static str,
    },
    /// The input was rejected by the configured [`DecodePolicy`].
    #[error("at byte {position}: decode policy rejected input: {violation}")]
    PolicyViolation {
        /// Byte position of the header that triggered the violation.
        position: u64,
        /// What the policy rejected.
        violation: DecodePolicyViolation,
    },
}
443
444impl DecodeError {
445 fn policy_violation(position: u64, violation: DecodePolicyViolation) -> Self {
446 Self::PolicyViolation {
447 position,
448 violation,
449 }
450 }
451}
452
/// Payload of a file or hard-link member.
///
/// Wraps the framing crate's payload so that its errors surface as [`DecodeError`].
pub struct TarMemberPayload<'a, R> {
    /// Underlying framing payload.
    payload: FramingMemberPayload<'a, R>,
}
457
458impl<R: AsyncRead + Unpin> MemberPayloadTrait for TarMemberPayload<'_, R> {
459 type Error = DecodeError;
460
461 async fn next_chunk(
462 &mut self,
463 buffer: &mut Vec<u8>,
464 target_len: usize,
465 ) -> Result<bool, Self::Error> {
466 self.payload
467 .next_chunk(buffer, target_len)
468 .await
469 .map_err(Into::into)
470 }
471
472 async fn skip(self) -> Result<(), Self::Error> {
473 self.payload.skip().await.map_err(Into::into)
474 }
475}
476
477impl<R: AsyncRead + Unpin> ArchiveTrait for TarArchive<R> {
478 type Error = DecodeError;
479 type Payload<'a>
480 = TarMemberPayload<'a, R>
481 where
482 Self: 'a;
483
484 async fn next_member<'a>(
485 &'a mut self,
486 ) -> Result<Option<Member<Self::Payload<'a>>>, Self::Error> {
487 if self.fused {
488 return Ok(None);
489 }
490 let frame = match self.reader.next_frame().await {
491 Ok(Some(frame)) => frame,
492 Ok(None) => {
493 self.fused = true;
494 return Ok(None);
495 }
496 Err(error) => {
497 self.fused = true;
498 return Err(error.into());
499 }
500 };
501 if let Err(error) = self.policy.check_member(&frame) {
502 self.fused = true;
503 return Err(error);
504 }
505 match project_member(frame) {
506 Ok(member) => Ok(Some(member)),
507 Err(error) => {
508 self.fused = true;
509 Err(error)
510 }
511 }
512 }
513}
514
515fn project_member<'a, R>(
516 frame: MemberFrame<'a, R>,
517) -> Result<Member<TarMemberPayload<'a, R>>, DecodeError> {
518 let position = frame.header.position;
519 let kind = frame.header.kind;
520 let size = frame.header.effective_size;
521 let executable = frame.header.mode.unwrap_or_default() & 0o111 != 0;
522 let path = std::str::from_utf8(frame.effective_path()?.as_ref())
523 .map(str::to_owned)
524 .map_err(|_| DecodeError::InvalidUtf8 {
525 position,
526 field: "path",
527 })?;
528 let target = if matches!(kind, UstarKind::HardLink | UstarKind::SymbolicLink) {
529 std::str::from_utf8(frame.effective_link_path()?.as_ref())
530 .map(str::to_owned)
531 .map_err(|_| DecodeError::InvalidUtf8 {
532 position,
533 field: "linkpath",
534 })?
535 } else {
536 String::new()
537 };
538 let metadata = MemberMetadata { path, position };
539
540 Ok(match kind {
541 UstarKind::Regular | UstarKind::Contiguous => Member::File {
542 metadata,
543 size,
544 executable,
545 payload: TarMemberPayload {
546 payload: frame.payload,
547 },
548 },
549 UstarKind::Directory => Member::Directory { metadata },
550 UstarKind::SymbolicLink => Member::SymbolicLink { metadata, target },
551 UstarKind::HardLink => Member::HardLink {
552 metadata,
553 target,
554 size,
555 payload: TarMemberPayload {
556 payload: frame.payload,
557 },
558 },
559 UstarKind::CharacterDevice => Member::Special {
560 metadata,
561 kind: SpecialKind::CharacterDevice,
562 },
563 UstarKind::BlockDevice => Member::Special {
564 metadata,
565 kind: SpecialKind::BlockDevice,
566 },
567 UstarKind::Fifo => Member::Special {
568 metadata,
569 kind: SpecialKind::Fifo,
570 },
571 })
572}