use std::collections::HashSet;
use archive_trait::{
Archive as ArchiveTrait, Member, MemberMetadata, MemberPayload as MemberPayloadTrait,
SpecialKind,
};
use tar_framing::{
ArchiveFormat, FrameError, PaxKeyword, PaxKind, PaxRecord, UstarKind,
logical::{MemberExtensions, MemberFrame, MemberPayload as FramingMemberPayload, TarReader},
};
use thiserror::Error;
use tokio::io::AsyncRead;
pub use tar_framing::{
DEFAULT_MAX_GLOBAL_PAX_EXTENSIONS_SIZE, DEFAULT_MAX_GNU_EXTENSION_SIZE,
DEFAULT_MAX_PAX_EXTENSION_SIZE,
};
pub struct TarArchive<R> {
reader: TarReader<R>,
policy: DecodePolicy,
fused: bool,
}
impl<R> TarArchive<R> {
pub fn new(reader: R) -> Self {
Self::new_with_policy(reader, DecodePolicy::default())
}
pub fn new_with_policy(reader: R, policy: DecodePolicy) -> Self {
let mut reader = TarReader::new(reader);
reader.set_max_pax_extension_size(policy.pax_policy.max_extension_size);
reader.set_max_global_pax_extensions_size(policy.pax_policy.max_global_extensions_size);
reader.set_allow_all_nul_numeric_fields(policy.allow_all_nul_numeric_fields);
reader.set_max_gnu_extension_size(policy.max_gnu_extension_size);
Self {
reader,
policy,
fused: false,
}
}
}
/// Decoder-level policy: which archive dialects and extension shapes are accepted.
///
/// Start from [`DecodePolicy::default`] and adjust with the builder-style
/// setters. Derives `PartialEq`/`Eq` like its nested [`PaxDecodePolicy`], so
/// configured policies can be compared directly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DecodePolicy {
    /// Accept members whose header reports the GNU format.
    allow_gnu: bool,
    /// Forwarded to `TarReader::set_allow_all_nul_numeric_fields`.
    allow_all_nul_numeric_fields: bool,
    /// Forwarded to `TarReader::set_max_gnu_extension_size`
    /// (presumably a byte limit — confirm against tar_framing).
    max_gnu_extension_size: u64,
    /// Nested policy governing pax extended headers.
    pax_policy: PaxDecodePolicy,
}
/// Policy for pax extended headers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PaxDecodePolicy {
    /// Forwarded to `TarReader::set_max_pax_extension_size`.
    max_extension_size: u64,
    /// Forwarded to `TarReader::set_max_global_pax_extensions_size`.
    max_global_extensions_size: u64,
    /// Accept global pax extended headers at all.
    allow_global_pax_extensions: bool,
    /// Accept records decoded as `PaxRecord::Vendor`.
    allow_unknown_pax_vendor_records: bool,
    /// Accept a header that repeats the same keyword.
    allow_duplicate_pax_records: bool,
    /// Accept `path`, `linkpath` and `size` records in global headers.
    allow_global_pax_member_metadata: bool,
}
impl Default for PaxDecodePolicy {
    /// Uses the framing crate's default size limits, accepts global pax
    /// headers, and rejects vendor records, duplicate keywords, and member
    /// metadata inside global headers.
    fn default() -> Self {
        let (max_extension_size, max_global_extensions_size) = (
            DEFAULT_MAX_PAX_EXTENSION_SIZE,
            DEFAULT_MAX_GLOBAL_PAX_EXTENSIONS_SIZE,
        );
        Self {
            max_extension_size,
            max_global_extensions_size,
            allow_global_pax_extensions: true,
            allow_unknown_pax_vendor_records: false,
            allow_duplicate_pax_records: false,
            allow_global_pax_member_metadata: false,
        }
    }
}
impl Default for DecodePolicy {
    /// Accepts GNU archives and all-NUL numeric fields, uses the framing
    /// crate's default GNU extension limit and the default pax policy.
    fn default() -> Self {
        let pax_policy = PaxDecodePolicy::default();
        Self {
            allow_gnu: true,
            allow_all_nul_numeric_fields: true,
            max_gnu_extension_size: DEFAULT_MAX_GNU_EXTENSION_SIZE,
            pax_policy,
        }
    }
}
impl DecodePolicy {
    /// Accepts (`true`) or rejects (`false`) members whose header reports the GNU format.
    pub fn allow_gnu(self, allow: bool) -> Self {
        Self {
            allow_gnu: allow,
            ..self
        }
    }

    /// Sets the all-NUL numeric field leniency forwarded to the framing reader.
    pub fn allow_all_nul_numeric_fields(self, allow: bool) -> Self {
        Self {
            allow_all_nul_numeric_fields: allow,
            ..self
        }
    }

    /// Sets the GNU extension limit forwarded to the framing reader.
    pub fn max_gnu_extension_size(self, max_gnu_extension_size: u64) -> Self {
        Self {
            max_gnu_extension_size,
            ..self
        }
    }

    /// Replaces the nested pax header policy wholesale.
    pub fn pax_policy(self, policy: PaxDecodePolicy) -> Self {
        Self {
            pax_policy: policy,
            ..self
        }
    }

    /// Rejects a GNU-format member (reported at `position`) unless GNU is allowed.
    fn check_format(&self, position: u64, format: ArchiveFormat) -> Result<(), DecodeError> {
        let gnu_forbidden = format == ArchiveFormat::Gnu && !self.allow_gnu;
        if gnu_forbidden {
            Err(DecodeError::policy_violation(
                position,
                DecodePolicyViolation::GnuArchive,
            ))
        } else {
            Ok(())
        }
    }

    /// Validates one global pax header: it must be permitted at all, and its
    /// records must then pass the record-level checks.
    fn check_global_pax(&self, position: u64, records: &[PaxRecord]) -> Result<(), DecodeError> {
        self.pax_policy
            .check_global_pax_extension(position)
            .and_then(|()| {
                self.pax_policy
                    .check_pax_records(position, PaxKind::Global, records)
            })
    }

    /// Applies the whole policy to one member frame.
    ///
    /// Checks run in a fixed order — global pax headers, then the archive
    /// format, then local pax headers — and the first violation found is the
    /// one returned.
    fn check_member<R>(&self, frame: &MemberFrame<'_, R>) -> Result<(), DecodeError> {
        if let MemberExtensions::Pax(state) = &frame.extensions {
            let globals = state
                .extensions()
                .filter(|ext| ext.kind == PaxKind::Global);
            for global in globals {
                self.check_global_pax(global.position, global.records())?;
            }
        }
        self.check_format(Self::format_position(frame), frame.header.format)?;
        if let MemberExtensions::Pax(state) = &frame.extensions {
            let locals = state
                .extensions()
                .filter(|ext| ext.kind == PaxKind::Local);
            for local in locals {
                self.pax_policy
                    .check_pax_records(local.position, PaxKind::Local, local.records())?;
            }
        }
        Ok(())
    }

    /// Byte offset reported for a format violation: the earliest GNU long
    /// name/link header when present, otherwise the member header itself.
    fn format_position<R>(frame: &MemberFrame<'_, R>) -> u64 {
        match &frame.extensions {
            MemberExtensions::Gnu {
                long_name,
                long_link,
            } => long_name
                .iter()
                .chain(long_link.iter())
                .map(|header| header.position)
                .min()
                .unwrap_or(frame.header.position),
            MemberExtensions::Pax(_) => frame.header.position,
        }
    }
}
impl PaxDecodePolicy {
pub fn max_extension_size(mut self, max_extension_size: u64) -> Self {
self.max_extension_size = max_extension_size;
self
}
pub fn max_global_extensions_size(mut self, max_global_extensions_size: u64) -> Self {
self.max_global_extensions_size = max_global_extensions_size;
self
}
pub fn allow_global_pax_extensions(mut self, allow: bool) -> Self {
self.allow_global_pax_extensions = allow;
self
}
pub fn allow_unknown_pax_vendor_records(mut self, allow: bool) -> Self {
self.allow_unknown_pax_vendor_records = allow;
self
}
pub fn allow_duplicate_pax_records(mut self, allow: bool) -> Self {
self.allow_duplicate_pax_records = allow;
self
}
pub fn allow_global_pax_member_metadata(mut self, allow: bool) -> Self {
self.allow_global_pax_member_metadata = allow;
self
}
fn check_global_pax_extension(&self, position: u64) -> Result<(), DecodeError> {
if !self.allow_global_pax_extensions {
return Err(DecodeError::policy_violation(
position,
DecodePolicyViolation::GlobalPaxExtension,
));
}
Ok(())
}
fn check_pax_records(
&self,
position: u64,
kind: PaxKind,
records: &[PaxRecord],
) -> Result<(), DecodeError> {
if !self.allow_unknown_pax_vendor_records {
for record in records {
if let PaxRecord::Vendor { vendor, name, .. } = record {
return Err(DecodeError::policy_violation(
position,
DecodePolicyViolation::PaxVendorExtension {
vendor: vendor.to_string(),
name: name.to_string(),
},
));
}
}
}
if kind == PaxKind::Global && !self.allow_global_pax_member_metadata {
for record in records {
let keyword = match record.keyword() {
PaxKeyword::Path => Some("path"),
PaxKeyword::LinkPath => Some("linkpath"),
PaxKeyword::Size => Some("size"),
_ => None,
};
if let Some(keyword) = keyword {
return Err(DecodeError::policy_violation(
position,
DecodePolicyViolation::GlobalPaxMemberMetadata { keyword },
));
}
}
}
if !self.allow_duplicate_pax_records {
let mut keywords = HashSet::new();
for record in records {
let keyword = record.keyword();
if !keywords.insert(keyword.clone()) {
return Err(DecodeError::policy_violation(
position,
DecodePolicyViolation::DuplicatePaxRecord {
keyword: keyword.to_string(),
},
));
}
}
}
Ok(())
}
}
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum DecodePolicyViolation {
#[error("GNU archives are not allowed")]
GnuArchive,
#[error("global pax extended headers are not allowed")]
GlobalPaxExtension,
#[error("pax vendor extension {vendor}.{name} is not allowed")]
PaxVendorExtension {
vendor: String,
name: String,
},
#[error("pax extended header contains duplicate record {keyword}")]
DuplicatePaxRecord {
keyword: String,
},
#[error("global pax extended header contains restricted member metadata {keyword}")]
GlobalPaxMemberMetadata {
keyword: &'static str,
},
}
/// Error produced while decoding a tar archive.
#[derive(Debug, Error)]
pub enum DecodeError {
    /// Error from the framing layer, displayed unchanged; `#[from]` lets `?`
    /// convert a `FrameError` directly.
    #[error(transparent)]
    Framing(#[from] FrameError),
    /// A textual header field (`field`) was not valid UTF-8.
    #[error("at byte {position}: {field} is not valid UTF-8")]
    InvalidUtf8 { position: u64, field: &'static str },
    /// The configured decode policy rejected the input at `position`.
    #[error("at byte {position}: decode policy rejected input: {violation}")]
    PolicyViolation {
        position: u64,
        violation: DecodePolicyViolation,
    },
}
impl DecodeError {
    /// Wraps `violation` together with the byte offset at which it was detected.
    fn policy_violation(position: u64, violation: DecodePolicyViolation) -> Self {
        DecodeError::PolicyViolation {
            violation,
            position,
        }
    }
}
/// Payload handle handed out for file and hard-link members; it adapts the
/// framing layer's payload so its errors surface as [`DecodeError`].
pub struct TarMemberPayload<'a, R> {
    payload: FramingMemberPayload<'a, R>,
}
impl<R: AsyncRead + Unpin> MemberPayloadTrait for TarMemberPayload<'_, R> {
    type Error = DecodeError;

    /// Delegates to the framing payload and returns its flag unchanged
    /// (presumably `true` while more data may follow — confirm in tar_framing).
    async fn next_chunk(
        &mut self,
        buffer: &mut Vec<u8>,
        target_len: usize,
    ) -> Result<bool, Self::Error> {
        let more = self.payload.next_chunk(buffer, target_len).await?;
        Ok(more)
    }

    /// Consumes the handle and delegates skipping to the framing payload.
    async fn skip(self) -> Result<(), Self::Error> {
        self.payload.skip().await?;
        Ok(())
    }
}
impl<R: AsyncRead + Unpin> ArchiveTrait for TarArchive<R> {
    type Error = DecodeError;
    type Payload<'a>
        = TarMemberPayload<'a, R>
    where
        Self: 'a;

    /// Reads the next frame, enforces the decode policy, and projects it into
    /// a [`Member`].
    ///
    /// End of archive and every error fuse the decoder, so later calls return
    /// `Ok(None)` without touching the reader again.
    async fn next_member<'a>(
        &'a mut self,
    ) -> Result<Option<Member<Self::Payload<'a>>>, Self::Error> {
        if self.fused {
            return Ok(None);
        }
        let outcome = match self.reader.next_frame().await {
            Ok(Some(frame)) => self
                .policy
                .check_member(&frame)
                .and_then(|()| project_member(frame))
                .map(Some),
            Ok(None) => Ok(None),
            Err(error) => Err(error.into()),
        };
        // Only a successfully projected member keeps the stream open.
        if !matches!(outcome, Ok(Some(_))) {
            self.fused = true;
        }
        outcome
    }
}
/// Converts a policy-checked frame into the archive-agnostic [`Member`] shape.
///
/// The path, and for hard/symbolic links the link target, must be valid UTF-8;
/// otherwise [`DecodeError::InvalidUtf8`] is returned with the member header's
/// position and the offending field name. Only file and hard-link members
/// carry a payload handle.
fn project_member<'a, R>(
    frame: MemberFrame<'a, R>,
) -> Result<Member<TarMemberPayload<'a, R>>, DecodeError> {
    let position = frame.header.position;
    let kind = frame.header.kind;
    let size = frame.header.effective_size;
    // Any execute bit (owner, group or other) marks the file executable;
    // a missing mode counts as not executable.
    let executable = (frame.header.mode.unwrap_or_default() & 0o111) != 0;
    let path = utf8_field(frame.effective_path()?.as_ref(), position, "path")?;
    let target = match kind {
        UstarKind::HardLink | UstarKind::SymbolicLink => {
            utf8_field(frame.effective_link_path()?.as_ref(), position, "linkpath")?
        }
        _ => String::new(),
    };
    let metadata = MemberMetadata { path, position };
    let member = match kind {
        UstarKind::Directory => Member::Directory { metadata },
        UstarKind::SymbolicLink => Member::SymbolicLink { metadata, target },
        UstarKind::Regular | UstarKind::Contiguous => Member::File {
            metadata,
            size,
            executable,
            payload: TarMemberPayload {
                payload: frame.payload,
            },
        },
        UstarKind::HardLink => Member::HardLink {
            metadata,
            target,
            size,
            payload: TarMemberPayload {
                payload: frame.payload,
            },
        },
        UstarKind::CharacterDevice => Member::Special {
            metadata,
            kind: SpecialKind::CharacterDevice,
        },
        UstarKind::BlockDevice => Member::Special {
            metadata,
            kind: SpecialKind::BlockDevice,
        },
        UstarKind::Fifo => Member::Special {
            metadata,
            kind: SpecialKind::Fifo,
        },
    };
    Ok(member)
}
/// Decodes `bytes` as UTF-8 into an owned string, tagging failures with the
/// header `position` and the name of the header `field` being decoded.
fn utf8_field(bytes: &[u8], position: u64, field: &'static str) -> Result<String, DecodeError> {
    match std::str::from_utf8(bytes) {
        Ok(text) => Ok(text.to_owned()),
        Err(_) => Err(DecodeError::InvalidUtf8 { position, field }),
    }
}