1use crate::{
2 discover::DiscoveredAdvisory,
3 model::{metadata::ProviderMetadata, store::distribution_base},
4 retrieve::{RetrievalContext, RetrievedAdvisory, RetrievedVisitor},
5 source::{HttpSourceError, Source},
6 validation::{ValidatedAdvisory, ValidatedVisitor, ValidationContext, ValidationError},
7};
8use anyhow::Context;
9use sequoia_openpgp::{Cert, armor::Kind, serialize::SerializeInto};
10use std::{
11 any::Any,
12 collections::HashSet,
13 fmt::Debug,
14 io::{ErrorKind, Write},
15 path::{Path, PathBuf},
16 rc::Rc,
17};
18use tokio::fs;
19use walker_common::{
20 fetcher,
21 retrieve::RetrievalError,
22 store::{Document, ErrorData, StoreError, store_document, store_errors},
23 utils::openpgp::PublicKey,
24};
25
/// Name of the directory below [`StoreVisitor::base`] that receives
/// `provider-metadata.json` and the `keys` sub-directory with the provider's public keys.
pub const DIR_METADATA: &str = "metadata";
27
/// Visitor that persists provider metadata, public keys and advisories to the local
/// file system below [`StoreVisitor::base`].
///
/// It implements both [`RetrievedVisitor`] and [`ValidatedVisitor`], so it can be attached
/// after either the retrieval or the validation stage.
#[non_exhaustive]
pub struct StoreVisitor {
    /// Root directory of the store. Provider metadata and keys are written into the
    /// [`DIR_METADATA`] sub-directory; advisories go into per-distribution directories
    /// computed by [`distribution_base`].
    pub base: PathBuf,

    /// Forwarded unchanged to `store_document` as `Document::no_timestamps`.
    /// NOTE(review): presumably disables applying the advisory's modification time to the
    /// stored file — confirm in `walker_common::store`.
    pub no_timestamps: bool,

    /// Forwarded unchanged to `store_document` as `Document::no_xattrs`.
    /// NOTE(review): presumably disables writing extended file attributes — confirm in
    /// `walker_common::store`.
    pub no_xattrs: bool,

    /// HTTP client error status codes that are tolerated while retrieving advisories: a
    /// retrieval failure with one of these codes is recorded via `store_errors` instead of
    /// being returned as an error. Only consulted by the [`RetrievedVisitor`] implementation.
    pub allowed_client_errors: HashSet<reqwest::StatusCode>,
}
43
44impl StoreVisitor {
45 pub fn new(base: impl Into<PathBuf>) -> Self {
46 Self {
47 base: base.into(),
48 no_timestamps: false,
49 no_xattrs: false,
50 allowed_client_errors: Default::default(),
51 }
52 }
53
54 pub fn no_timestamps(mut self, no_timestamps: bool) -> Self {
55 self.no_timestamps = no_timestamps;
56 self
57 }
58
59 pub fn no_xattrs(mut self, no_xattrs: bool) -> Self {
60 self.no_xattrs = no_xattrs;
61 self
62 }
63
64 pub fn allow_client_errors(
65 mut self,
66 allowed_client_errors: HashSet<reqwest::StatusCode>,
67 ) -> Self {
68 self.allowed_client_errors = allowed_client_errors;
69 self
70 }
71
72 pub fn allow_client_errors_iter(
75 self,
76 allowed_client_errors: impl IntoIterator<Item = reqwest::StatusCode>,
77 ) -> Self {
78 self.allow_client_errors(allowed_client_errors.into_iter().collect())
79 }
80}
81
/// Error returned by the [`RetrievedVisitor`] implementation of [`StoreVisitor`].
// NOTE(review): the large-variant lint is silenced rather than boxing the larger variant;
// confirm this trade-off is intentional.
#[derive(Debug, thiserror::Error)]
#[allow(clippy::large_enum_variant)]
pub enum StoreRetrievedError<S: Source> {
    /// Writing data to the store failed (see [`StoreError`]).
    #[error(transparent)]
    Store(#[from] StoreError),
    /// A retrieval error that was not tolerated; `visit_advisory` returns it unchanged when
    /// its status code is not in [`StoreVisitor::allowed_client_errors`].
    #[error(transparent)]
    Retrieval(#[from] RetrievalError<DiscoveredAdvisory, S>),
}
90
/// Error returned by the [`ValidatedVisitor`] implementation of [`StoreVisitor`].
// NOTE(review): the large-variant lint is silenced rather than boxing the larger variant;
// confirm this trade-off is intentional.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, thiserror::Error)]
pub enum StoreValidatedError<S: Source> {
    /// Writing data to the store failed (see [`StoreError`]).
    #[error(transparent)]
    Store(#[from] StoreError),
    /// The incoming validation error, propagated unchanged by `visit_advisory`.
    #[error(transparent)]
    Validation(#[from] ValidationError<S>),
}
99
100impl<S: Source + Debug> RetrievedVisitor<S> for StoreVisitor
101where
102 S::Error: 'static,
103{
104 type Error = StoreRetrievedError<S>;
105 type Context = Rc<ProviderMetadata>;
106
107 async fn visit_context(
108 &self,
109 context: &RetrievalContext<'_>,
110 ) -> Result<Self::Context, Self::Error> {
111 self.store_provider_metadata(context.metadata).await?;
112 self.prepare_distributions(context.metadata).await?;
113 self.store_keys(context.keys).await?;
114
115 Ok(Rc::new(context.metadata.clone()))
116 }
117
118 async fn visit_advisory(
121 &self,
122 _context: &Self::Context,
123 result: Result<RetrievedAdvisory, RetrievalError<DiscoveredAdvisory, S>>,
124 ) -> Result<(), Self::Error> {
125 match result {
126 Ok(advisory) => {
127 self.store_advisory(&advisory).await?;
128 Ok(())
129 }
130 Err(err) => {
131 match Self::get_client_error_status_code(&err) {
132 Some(status) if self.allowed_client_errors.contains(&status) => {
133 self.store_error(status, err.discovered()).await?;
134 }
135 _ => return Err(StoreRetrievedError::Retrieval(err)),
136 }
137 Ok(())
138 }
139 }
140 }
141}
142
143impl<S: Source> ValidatedVisitor<S> for StoreVisitor {
144 type Error = StoreValidatedError<S>;
145 type Context = ();
146
147 async fn visit_context(
148 &self,
149 context: &ValidationContext<'_>,
150 ) -> Result<Self::Context, Self::Error> {
151 self.store_provider_metadata(context.metadata).await?;
152 self.prepare_distributions(context.metadata).await?;
153 self.store_keys(context.retrieval.keys).await?;
154 Ok(())
155 }
156
157 async fn visit_advisory(
158 &self,
159 _context: &Self::Context,
160 result: Result<ValidatedAdvisory, ValidationError<S>>,
161 ) -> Result<(), Self::Error> {
162 self.store_advisory(&result?.retrieved).await?;
163 Ok(())
164 }
165}
166
167impl StoreVisitor {
168 async fn prepare_distributions(&self, metadata: &ProviderMetadata) -> Result<(), StoreError> {
169 for dist in &metadata.distributions {
170 if let Some(directory_url) = &dist.directory_url {
171 let base = distribution_base(&self.base, directory_url.as_str());
172 log::debug!("Creating base distribution directory: {}", base.display());
173
174 fs::create_dir_all(&base)
175 .await
176 .with_context(|| {
177 format!(
178 "Unable to create distribution directory: {}",
179 base.display()
180 )
181 })
182 .map_err(StoreError::Io)?;
183 }
184 if let Some(rolie) = &dist.rolie {
185 for feed in &rolie.feeds {
186 let base = distribution_base(&self.base, feed.url.as_str());
187 fs::create_dir_all(&base)
188 .await
189 .with_context(|| {
190 format!(
191 "Unable to create distribution directory: {}",
192 base.display()
193 )
194 })
195 .map_err(StoreError::Io)?;
196 }
197 }
198 }
199
200 Ok(())
201 }
202
203 async fn store_provider_metadata(&self, metadata: &ProviderMetadata) -> Result<(), StoreError> {
204 let metadir = self.base.join(DIR_METADATA);
205
206 fs::create_dir(&metadir)
207 .await
208 .or_else(|err| match err.kind() {
209 ErrorKind::AlreadyExists => Ok(()),
210 _ => Err(err),
211 })
212 .with_context(|| format!("Failed to create metadata directory: {}", metadir.display()))
213 .map_err(StoreError::Io)?;
214
215 let file = metadir.join("provider-metadata.json");
216 let mut out = std::fs::File::create(&file)
217 .with_context(|| {
218 format!(
219 "Unable to open provider metadata file for writing: {}",
220 file.display()
221 )
222 })
223 .map_err(StoreError::Io)?;
224 serde_json::to_writer_pretty(&mut out, metadata)
225 .context("Failed serializing provider metadata")
226 .map_err(StoreError::Io)?;
227 Ok(())
228 }
229
230 async fn store_keys(&self, keys: &[PublicKey]) -> Result<(), StoreError> {
231 let metadata = self.base.join(DIR_METADATA).join("keys");
232 std::fs::create_dir(&metadata)
233 .or_else(|err| match err.kind() {
235 ErrorKind::AlreadyExists => Ok(()),
236 _ => Err(err),
237 })
238 .with_context(|| {
239 format!(
240 "Failed to create metadata directory: {}",
241 metadata.display()
242 )
243 })
244 .map_err(StoreError::Io)?;
245
246 for cert in keys.iter().flat_map(|k| &k.certs) {
247 log::info!("Storing key: {}", cert.fingerprint());
248 self.store_cert(cert, &metadata).await?;
249 }
250
251 Ok(())
252 }
253
254 async fn store_cert(&self, cert: &Cert, path: &Path) -> Result<(), StoreError> {
255 let name = path.join(format!("{}.txt", cert.fingerprint().to_hex()));
256
257 let data = Self::serialize_key(cert).map_err(StoreError::SerializeKey)?;
258
259 fs::write(&name, data)
260 .await
261 .with_context(|| format!("Failed to store key: {}", name.display()))
262 .map_err(StoreError::Io)?;
263 Ok(())
264 }
265
266 fn serialize_key(cert: &Cert) -> Result<Vec<u8>, anyhow::Error> {
267 let mut writer = sequoia_openpgp::armor::Writer::new(Vec::new(), Kind::PublicKey)?;
268 writer.write_all(&cert.to_vec()?)?;
269 Ok(writer.finalize()?)
270 }
271
272 async fn store_advisory(&self, advisory: &RetrievedAdvisory) -> Result<(), StoreError> {
273 log::info!(
274 "Storing: {} (modified: {:?})",
275 advisory.url,
276 advisory.metadata.last_modification
277 );
278
279 let relative_url_result = advisory.context.url().make_relative(&advisory.url);
280 let name = match &relative_url_result {
281 Some(name) => name,
282 None => return Err(StoreError::Filename(advisory.url.to_string())),
283 };
284
285 let distribution_base = distribution_base(&self.base, advisory.context.url().as_str());
287
288 let file = distribution_base.join(name);
290
291 store_document(
292 &file,
293 Document {
294 data: &advisory.data,
295 changed: advisory.modified,
296 metadata: &advisory.metadata,
297 sha256: &advisory.sha256,
298 sha512: &advisory.sha512,
299 signature: &advisory.signature,
300 no_timestamps: self.no_timestamps,
301 no_xattrs: self.no_xattrs,
302 },
303 )
304 .await?;
305
306 Ok(())
307 }
308
309 fn get_client_error_status_code<S: Source + Debug>(
310 err: &RetrievalError<DiscoveredAdvisory, S>,
311 ) -> Option<reqwest::StatusCode>
312 where
313 S::Error: 'static,
314 {
315 let source_error = match err {
317 RetrievalError::Source { err, .. } => err,
318 };
319
320 if let Some(http_error) = (source_error as &dyn Any).downcast_ref::<HttpSourceError>()
321 && let HttpSourceError::Fetcher(fetcher::Error::ClientError(status)) = http_error
322 {
323 return Some(*status);
324 }
325
326 None
327 }
328
329 async fn store_error(
330 &self,
331 status_code: reqwest::StatusCode,
332 discovered: &DiscoveredAdvisory,
333 ) -> Result<(), StoreError> {
334 log::warn!("Storing retrieval error for: {}", discovered.url);
335
336 let relative_url_result = discovered.context.url().make_relative(&discovered.url);
337 let name = match &relative_url_result {
338 Some(name) => name,
339 None => return Err(StoreError::Filename(discovered.url.to_string())),
340 };
341
342 let distribution_base = distribution_base(&self.base, discovered.context.url().as_str());
343 let file = distribution_base.join(name);
344
345 store_errors(
346 &file,
347 ErrorData {
348 status_code: status_code.as_u16(),
349 },
350 )
351 .await?;
352
353 Ok(())
354 }
355}