Skip to main content

csaf_walker/visitors/
store.rs

1use crate::{
2    discover::DiscoveredAdvisory,
3    model::{metadata::ProviderMetadata, store::distribution_base},
4    retrieve::{RetrievalContext, RetrievedAdvisory, RetrievedVisitor},
5    source::{HttpSourceError, Source},
6    validation::{ValidatedAdvisory, ValidatedVisitor, ValidationContext, ValidationError},
7};
8use anyhow::Context;
9use sequoia_openpgp::{Cert, armor::Kind, serialize::SerializeInto};
10use std::{
11    any::Any,
12    collections::HashSet,
13    fmt::Debug,
14    io::{ErrorKind, Write},
15    path::{Path, PathBuf},
16    rc::Rc,
17};
18use tokio::fs;
19use walker_common::{
20    fetcher,
21    retrieve::RetrievalError,
22    store::{Document, ErrorData, StoreError, store_document, store_errors},
23    utils::openpgp::PublicKey,
24};
25
/// Name of the directory below the store base which holds the provider metadata file and the
/// `keys` directory.
pub const DIR_METADATA: &str = "metadata";
27
28/// Stores all data so that it can be used as a [`crate::source::Source`] later.
29#[non_exhaustive]
30pub struct StoreVisitor {
31    /// the output base
32    pub base: PathBuf,
33
34    /// whether to set the file modification timestamps
35    pub no_timestamps: bool,
36
37    /// whether to store additional metadata (like the etag) using extended attributes
38    pub no_xattrs: bool,
39
40    /// the clients errors which can be ignored
41    pub allowed_client_errors: HashSet<reqwest::StatusCode>,
42}
43
44impl StoreVisitor {
45    pub fn new(base: impl Into<PathBuf>) -> Self {
46        Self {
47            base: base.into(),
48            no_timestamps: false,
49            no_xattrs: false,
50            allowed_client_errors: Default::default(),
51        }
52    }
53
54    pub fn no_timestamps(mut self, no_timestamps: bool) -> Self {
55        self.no_timestamps = no_timestamps;
56        self
57    }
58
59    pub fn no_xattrs(mut self, no_xattrs: bool) -> Self {
60        self.no_xattrs = no_xattrs;
61        self
62    }
63
64    pub fn allow_client_errors(
65        mut self,
66        allowed_client_errors: HashSet<reqwest::StatusCode>,
67    ) -> Self {
68        self.allowed_client_errors = allowed_client_errors;
69        self
70    }
71
72    /// Similar to [`Self::allow_client_errors`], but accepting any iterable and removing duplicates
73    /// in the process.
74    pub fn allow_client_errors_iter(
75        self,
76        allowed_client_errors: impl IntoIterator<Item = reqwest::StatusCode>,
77    ) -> Self {
78        self.allow_client_errors(allowed_client_errors.into_iter().collect())
79    }
80}
81
82#[derive(Debug, thiserror::Error)]
83#[allow(clippy::large_enum_variant)]
84pub enum StoreRetrievedError<S: Source> {
85    #[error(transparent)]
86    Store(#[from] StoreError),
87    #[error(transparent)]
88    Retrieval(#[from] RetrievalError<DiscoveredAdvisory, S>),
89}
90
91#[allow(clippy::large_enum_variant)]
92#[derive(Debug, thiserror::Error)]
93pub enum StoreValidatedError<S: Source> {
94    #[error(transparent)]
95    Store(#[from] StoreError),
96    #[error(transparent)]
97    Validation(#[from] ValidationError<S>),
98}
99
100impl<S: Source + Debug> RetrievedVisitor<S> for StoreVisitor
101where
102    S::Error: 'static,
103{
104    type Error = StoreRetrievedError<S>;
105    type Context = Rc<ProviderMetadata>;
106
107    async fn visit_context(
108        &self,
109        context: &RetrievalContext<'_>,
110    ) -> Result<Self::Context, Self::Error> {
111        self.store_provider_metadata(context.metadata).await?;
112        self.prepare_distributions(context.metadata).await?;
113        self.store_keys(context.keys).await?;
114
115        Ok(Rc::new(context.metadata.clone()))
116    }
117
118    /// Stores a retrieved advisory or its retrieval error.
119    /// Fails if storing fails.
120    async fn visit_advisory(
121        &self,
122        _context: &Self::Context,
123        result: Result<RetrievedAdvisory, RetrievalError<DiscoveredAdvisory, S>>,
124    ) -> Result<(), Self::Error> {
125        match result {
126            Ok(advisory) => {
127                self.store_advisory(&advisory).await?;
128                Ok(())
129            }
130            Err(err) => {
131                match Self::get_client_error_status_code(&err) {
132                    Some(status) if self.allowed_client_errors.contains(&status) => {
133                        self.store_error(status, err.discovered()).await?;
134                    }
135                    _ => return Err(StoreRetrievedError::Retrieval(err)),
136                }
137                Ok(())
138            }
139        }
140    }
141}
142
143impl<S: Source> ValidatedVisitor<S> for StoreVisitor {
144    type Error = StoreValidatedError<S>;
145    type Context = ();
146
147    async fn visit_context(
148        &self,
149        context: &ValidationContext<'_>,
150    ) -> Result<Self::Context, Self::Error> {
151        self.store_provider_metadata(context.metadata).await?;
152        self.prepare_distributions(context.metadata).await?;
153        self.store_keys(context.retrieval.keys).await?;
154        Ok(())
155    }
156
157    async fn visit_advisory(
158        &self,
159        _context: &Self::Context,
160        result: Result<ValidatedAdvisory, ValidationError<S>>,
161    ) -> Result<(), Self::Error> {
162        self.store_advisory(&result?.retrieved).await?;
163        Ok(())
164    }
165}
166
167impl StoreVisitor {
168    async fn prepare_distributions(&self, metadata: &ProviderMetadata) -> Result<(), StoreError> {
169        for dist in &metadata.distributions {
170            if let Some(directory_url) = &dist.directory_url {
171                let base = distribution_base(&self.base, directory_url.as_str());
172                log::debug!("Creating base distribution directory: {}", base.display());
173
174                fs::create_dir_all(&base)
175                    .await
176                    .with_context(|| {
177                        format!(
178                            "Unable to create distribution directory: {}",
179                            base.display()
180                        )
181                    })
182                    .map_err(StoreError::Io)?;
183            }
184            if let Some(rolie) = &dist.rolie {
185                for feed in &rolie.feeds {
186                    let base = distribution_base(&self.base, feed.url.as_str());
187                    fs::create_dir_all(&base)
188                        .await
189                        .with_context(|| {
190                            format!(
191                                "Unable to create distribution directory: {}",
192                                base.display()
193                            )
194                        })
195                        .map_err(StoreError::Io)?;
196                }
197            }
198        }
199
200        Ok(())
201    }
202
203    async fn store_provider_metadata(&self, metadata: &ProviderMetadata) -> Result<(), StoreError> {
204        let metadir = self.base.join(DIR_METADATA);
205
206        fs::create_dir(&metadir)
207            .await
208            .or_else(|err| match err.kind() {
209                ErrorKind::AlreadyExists => Ok(()),
210                _ => Err(err),
211            })
212            .with_context(|| format!("Failed to create metadata directory: {}", metadir.display()))
213            .map_err(StoreError::Io)?;
214
215        let file = metadir.join("provider-metadata.json");
216        let mut out = std::fs::File::create(&file)
217            .with_context(|| {
218                format!(
219                    "Unable to open provider metadata file for writing: {}",
220                    file.display()
221                )
222            })
223            .map_err(StoreError::Io)?;
224        serde_json::to_writer_pretty(&mut out, metadata)
225            .context("Failed serializing provider metadata")
226            .map_err(StoreError::Io)?;
227        Ok(())
228    }
229
230    async fn store_keys(&self, keys: &[PublicKey]) -> Result<(), StoreError> {
231        let metadata = self.base.join(DIR_METADATA).join("keys");
232        std::fs::create_dir(&metadata)
233            // ignore if the directory already exists
234            .or_else(|err| match err.kind() {
235                ErrorKind::AlreadyExists => Ok(()),
236                _ => Err(err),
237            })
238            .with_context(|| {
239                format!(
240                    "Failed to create metadata directory: {}",
241                    metadata.display()
242                )
243            })
244            .map_err(StoreError::Io)?;
245
246        for cert in keys.iter().flat_map(|k| &k.certs) {
247            log::info!("Storing key: {}", cert.fingerprint());
248            self.store_cert(cert, &metadata).await?;
249        }
250
251        Ok(())
252    }
253
254    async fn store_cert(&self, cert: &Cert, path: &Path) -> Result<(), StoreError> {
255        let name = path.join(format!("{}.txt", cert.fingerprint().to_hex()));
256
257        let data = Self::serialize_key(cert).map_err(StoreError::SerializeKey)?;
258
259        fs::write(&name, data)
260            .await
261            .with_context(|| format!("Failed to store key: {}", name.display()))
262            .map_err(StoreError::Io)?;
263        Ok(())
264    }
265
266    fn serialize_key(cert: &Cert) -> Result<Vec<u8>, anyhow::Error> {
267        let mut writer = sequoia_openpgp::armor::Writer::new(Vec::new(), Kind::PublicKey)?;
268        writer.write_all(&cert.to_vec()?)?;
269        Ok(writer.finalize()?)
270    }
271
272    async fn store_advisory(&self, advisory: &RetrievedAdvisory) -> Result<(), StoreError> {
273        log::info!(
274            "Storing: {} (modified: {:?})",
275            advisory.url,
276            advisory.metadata.last_modification
277        );
278
279        let relative_url_result = advisory.context.url().make_relative(&advisory.url);
280        let name = match &relative_url_result {
281            Some(name) => name,
282            None => return Err(StoreError::Filename(advisory.url.to_string())),
283        };
284
285        // create a distribution base
286        let distribution_base = distribution_base(&self.base, advisory.context.url().as_str());
287
288        // put the file there
289        let file = distribution_base.join(name);
290
291        store_document(
292            &file,
293            Document {
294                data: &advisory.data,
295                changed: advisory.modified,
296                metadata: &advisory.metadata,
297                sha256: &advisory.sha256,
298                sha512: &advisory.sha512,
299                signature: &advisory.signature,
300                no_timestamps: self.no_timestamps,
301                no_xattrs: self.no_xattrs,
302            },
303        )
304        .await?;
305
306        Ok(())
307    }
308
309    fn get_client_error_status_code<S: Source + Debug>(
310        err: &RetrievalError<DiscoveredAdvisory, S>,
311    ) -> Option<reqwest::StatusCode>
312    where
313        S::Error: 'static,
314    {
315        // Get the underlying source error by pattern matching
316        let source_error = match err {
317            RetrievalError::Source { err, .. } => err,
318        };
319
320        if let Some(http_error) = (source_error as &dyn Any).downcast_ref::<HttpSourceError>()
321            && let HttpSourceError::Fetcher(fetcher::Error::ClientError(status)) = http_error
322        {
323            return Some(*status);
324        }
325
326        None
327    }
328
329    async fn store_error(
330        &self,
331        status_code: reqwest::StatusCode,
332        discovered: &DiscoveredAdvisory,
333    ) -> Result<(), StoreError> {
334        log::warn!("Storing retrieval error for: {}", discovered.url);
335
336        let relative_url_result = discovered.context.url().make_relative(&discovered.url);
337        let name = match &relative_url_result {
338            Some(name) => name,
339            None => return Err(StoreError::Filename(discovered.url.to_string())),
340        };
341
342        let distribution_base = distribution_base(&self.base, discovered.context.url().as_str());
343        let file = distribution_base.join(name);
344
345        store_errors(
346            &file,
347            ErrorData {
348                status_code: status_code.as_u16(),
349            },
350        )
351        .await?;
352
353        Ok(())
354    }
355}