1mod path;
7mod root;
8
9use std::path::Path;
10
11use self::{path::decode_member, root::ExtractionRoot};
12use super::*;
13
/// Settings that control how archive members are written to disk by `extract`.
///
/// The `Default` value uses `LinkPolicy::default()`, allows overwrites, and uses
/// `NameValidation::Default`. The type is `Copy`, so the builder methods return an
/// updated copy rather than mutating the caller's value in place.
#[derive(Clone, Copy, Debug)]
pub struct ExtractPolicy {
    // Rules for symbolic-link and hard-link members (checked in `check_member_policy`).
    pub(crate) link_policy: LinkPolicy,
    // Forwarded to `ExtractionRoot::open`; presumably governs replacing existing
    // entries under the destination — confirm against `root.rs`.
    pub(crate) allow_overwrites: bool,
    // Validation applied to member names via `ExtractPolicy::check_name`.
    pub(crate) name_validation: crate::name::NameValidation,
}
23
/// Rules for link members (symbolic and hard links) encountered during extraction.
///
/// By default, symlinks follow `SymlinkPolicy::default()` and hard links are refused.
/// Ambient targets are disallowed and missing targets are allowed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LinkPolicy {
    // What to do with symbolic-link members (preserve, skip or reject).
    pub(crate) symlink_policy: SymlinkPolicy,
    // When false, any hard-link member is a policy violation.
    pub(crate) allow_hard_links: bool,
    // Consumed outside this chunk (e.g. by `finalize_symlinks`); presumably permits
    // link targets outside the extraction root — TODO confirm.
    pub(crate) allow_ambient_targets: bool,
    // Consumed outside this chunk; presumably permits links whose target does not
    // exist after extraction — TODO confirm.
    pub(crate) allow_missing_targets: bool,
}
36
/// How symbolic-link members are handled during extraction.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SymlinkPolicy {
    /// Recreate the link (the default).
    ///
    /// Links are reserved while members are extracted and finalized after the
    /// member loop. On non-Unix targets this policy is rejected with
    /// `ExtractPolicyViolation::NativeSymlinkCreationUnsupported`.
    #[default]
    Preserve,
    /// Ignore symbolic-link members. They are still checked and decoded, but
    /// nothing is written for them.
    Skip,
    /// Fail extraction with `ExtractPolicyViolation::SymbolicLink` as soon as a
    /// symbolic-link member is encountered.
    Reject,
}
48
49impl Default for ExtractPolicy {
50 fn default() -> Self {
51 Self {
52 link_policy: LinkPolicy::default(),
53 allow_overwrites: true,
54 name_validation: crate::name::NameValidation::Default,
55 }
56 }
57}
58
59impl Default for LinkPolicy {
60 fn default() -> Self {
61 Self {
62 symlink_policy: SymlinkPolicy::default(),
63 allow_hard_links: false,
64 allow_ambient_targets: false,
65 allow_missing_targets: true,
66 }
67 }
68}
69
70impl ExtractPolicy {
71 pub fn link_policy(mut self, policy: LinkPolicy) -> Self {
73 self.link_policy = policy;
74 self
75 }
76
77 pub fn allow_overwrites(mut self, allow: bool) -> Self {
83 self.allow_overwrites = allow;
84 self
85 }
86
87 pub fn name_validator(mut self, validator: Option<NameValidator>) -> Self {
92 self.name_validation = crate::name::NameValidation::from_validator(validator);
93 self
94 }
95
96 fn check_name<E>(
97 self,
98 position: u64,
99 context: &'static str,
100 value: &str,
101 ) -> Result<(), ExtractError<E>> {
102 if !self.name_validation.accepts(value) {
103 return Err(ExtractError::policy_violation(
104 position,
105 ExtractPolicyViolation::NameRejected {
106 context,
107 value: value.to_owned(),
108 },
109 ));
110 }
111 Ok(())
112 }
113}
114
115impl LinkPolicy {
116 pub fn symlink_policy(mut self, policy: SymlinkPolicy) -> Self {
122 self.symlink_policy = policy;
123 self
124 }
125
126 pub fn allow_hard_links(mut self, allow: bool) -> Self {
132 self.allow_hard_links = allow;
133 self
134 }
135
136 pub fn allow_ambient_targets(mut self, allow: bool) -> Self {
142 self.allow_ambient_targets = allow;
143 self
144 }
145
146 pub fn allow_missing_targets(mut self, allow: bool) -> Self {
151 self.allow_missing_targets = allow;
152 self
153 }
154}
155
156pub(crate) async fn extract<A: Archive>(
158 mut members: Members<A>,
159 destination: &Path,
160 policy: ExtractPolicy,
161) -> Result<(), ExtractError<A::Error>> {
162 let mut root = ExtractionRoot::<A::Error>::open(destination, policy.allow_overwrites).await?;
163 let mut chunk_buffer = Vec::new();
165 let mut buffered_payload = Vec::new();
167 let result: Result<(), ExtractError<A::Error>> = async {
168 while let Some(member) = members.next().await.map_err(ExtractError::Archive)? {
169 check_member_policy(&member, policy)?;
170 let decoded = decode_member(&member, policy)?;
171 match member {
172 Member::File {
173 size,
174 executable,
175 payload,
176 ..
177 } => {
178 root.extract_file(
179 &decoded.path,
180 size,
181 executable,
182 payload,
183 &mut chunk_buffer,
184 &mut buffered_payload,
185 )
186 .await?;
187 }
188 Member::Directory { .. } => root.extract_directory(&decoded.path).await?,
189 Member::SymbolicLink { .. } => {
190 if policy.link_policy.symlink_policy == SymlinkPolicy::Preserve {
191 root.reserve_symlink(&decoded).await?;
192 }
193 }
194 Member::HardLink { size, payload, .. } => {
195 root.extract_hard_link(&decoded, size, payload, &mut chunk_buffer)
196 .await?;
197 }
198 Member::Special { kind, .. } => {
199 return Err(ExtractError::UnsupportedMember {
200 position: decoded.position,
201 path: decoded.path.to_path_buf(),
202 kind,
203 });
204 }
205 }
206 }
207 Ok(())
208 }
209 .await;
210 root.flush_buffered_files().await?;
212 result?;
213 root.finalize_symlinks(policy.link_policy).await
214}
215
216fn check_member_policy<E, P>(
217 member: &Member<P>,
218 policy: ExtractPolicy,
219) -> Result<(), ExtractError<E>> {
220 let position = member.metadata().position;
221 match member {
222 Member::SymbolicLink { .. } => {
223 let violation = match policy.link_policy.symlink_policy {
224 SymlinkPolicy::Reject => Some(ExtractPolicyViolation::SymbolicLink),
225 #[cfg(not(unix))]
226 SymlinkPolicy::Preserve => {
227 Some(ExtractPolicyViolation::NativeSymlinkCreationUnsupported)
228 }
229 _ => None,
230 };
231 if let Some(violation) = violation {
232 return Err(ExtractError::policy_violation(position, violation));
233 }
234 }
235 Member::HardLink { .. } if !policy.link_policy.allow_hard_links => {
236 return Err(ExtractError::policy_violation(
237 position,
238 ExtractPolicyViolation::HardLink,
239 ));
240 }
241 _ => {}
242 }
243 Ok(())
244}