csaf_walker/
walker.rs

1//! The actual walker
2
3use crate::{
4    discover::{DiscoveredAdvisory, DiscoveredContext, DiscoveredVisitor, DistributionContext},
5    model::metadata::Distribution,
6    source::Source,
7};
8use futures::{Stream, StreamExt, TryFutureExt, TryStream, TryStreamExt, stream};
9use std::{fmt::Debug, sync::Arc};
10use tokio::sync::Mutex;
11use url::ParseError;
12use walker_common::progress::{Progress, ProgressBar};
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error<VE, SE>
16where
17    VE: std::fmt::Display + Debug,
18    SE: std::fmt::Display + Debug,
19{
20    #[error("Source error: {0}")]
21    Source(SE),
22    #[error("URL error: {0}")]
23    Url(#[from] ParseError),
24    #[error("Visitor error: {0}")]
25    Visitor(VE),
26}
27
28pub type DistributionFilter = Box<dyn Fn(&DistributionContext) -> bool>;
29
30pub struct Walker<S: Source, P: Progress> {
31    source: S,
32    progress: P,
33    distribution_filter: Option<DistributionFilter>,
34}
35
36impl<S: Source> Walker<S, ()> {
37    pub fn new(source: S) -> Self {
38        Self {
39            source,
40            progress: (),
41            distribution_filter: None,
42        }
43    }
44}
45
46impl<S: Source, P: Progress> Walker<S, P> {
47    pub fn with_progress<U: Progress>(self, progress: U) -> Walker<S, U> {
48        Walker {
49            progress,
50            source: self.source,
51            distribution_filter: self.distribution_filter,
52        }
53    }
54
55    /// Set a filter for distributions.
56    ///
57    /// Each distribution from the metadata file will be passed to this function, if it returns `false`, the distribution
58    /// will not even be fetched.
59    pub fn with_distribution_filter<F>(mut self, distribution_filter: F) -> Self
60    where
61        F: Fn(&DistributionContext) -> bool + 'static,
62    {
63        self.distribution_filter = Some(Box::new(distribution_filter));
64        self
65    }
66
67    fn collect_distributions(&self, distributions: Vec<Distribution>) -> Vec<DistributionContext> {
68        distributions
69            .into_iter()
70            .flat_map(|distribution| {
71                distribution
72                    .rolie
73                    .into_iter()
74                    .flat_map(|rolie| rolie.feeds)
75                    .map(|feed| DistributionContext::Feed(feed.url))
76                    .chain(
77                        distribution
78                            .directory_url
79                            .map(DistributionContext::Directory),
80                    )
81            })
82            .filter(|distribution| {
83                if let Some(filter) = &self.distribution_filter {
84                    filter(distribution)
85                } else {
86                    true
87                }
88            })
89            .collect()
90    }
91
92    pub async fn walk<V>(self, visitor: V) -> Result<(), Error<V::Error, S::Error>>
93    where
94        V: DiscoveredVisitor,
95    {
96        let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
97
98        let context = visitor
99            .visit_context(&DiscoveredContext {
100                metadata: &metadata,
101            })
102            .await
103            .map_err(Error::Visitor)?;
104
105        let distributions = self.collect_distributions(metadata.distributions);
106        log::info!("processing {} distribution URLs", distributions.len());
107
108        for distribution in distributions {
109            log::info!("Walking directory URL: {distribution:?}");
110            let index = self
111                .source
112                .load_index(distribution)
113                .await
114                .map_err(Error::Source)?;
115
116            let mut progress = self.progress.start(index.len());
117
118            for advisory in index {
119                log::debug!("  Discovered advisory: {advisory:?}");
120                progress
121                    .set_message(
122                        advisory
123                            .url
124                            .path()
125                            .rsplit_once('/')
126                            .map(|(_, s)| s)
127                            .unwrap_or(advisory.url.as_str())
128                            .to_string(),
129                    )
130                    .await;
131                visitor
132                    .visit_advisory(&context, advisory)
133                    .await
134                    .map_err(Error::Visitor)?;
135                progress.tick().await;
136            }
137
138            progress.finish().await;
139        }
140
141        Ok(())
142    }
143
144    pub async fn walk_parallel<V>(
145        self,
146        limit: usize,
147        visitor: V,
148    ) -> Result<(), Error<V::Error, S::Error>>
149    where
150        V: DiscoveredVisitor,
151    {
152        let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
153        let context = visitor
154            .visit_context(&DiscoveredContext {
155                metadata: &metadata,
156            })
157            .await
158            .map_err(Error::Visitor)?;
159
160        let context = Arc::new(context);
161        let visitor = Arc::new(visitor);
162
163        let distributions = self.collect_distributions(metadata.distributions);
164        log::info!("processing {} distribution URLs", distributions.len());
165
166        let advisories: Vec<_> = collect_advisories::<V, S>(&self.source, distributions)
167            .try_collect()
168            .await?;
169
170        let size = advisories.len();
171        log::info!("Discovered {size} advisories");
172
173        let progress = Arc::new(Mutex::new(self.progress.start(size)));
174
175        stream::iter(advisories)
176            .map(Ok)
177            .try_for_each_concurrent(limit, async |advisory| {
178                log::debug!("Discovered advisory: {}", advisory.url);
179
180                let result = visitor
181                    .visit_advisory(&context, advisory.clone())
182                    .map_err(Error::Visitor)
183                    .await;
184
185                progress.lock().await.tick().await;
186
187                result
188            })
189            .await?;
190
191        if let Ok(progress) = Arc::try_unwrap(progress) {
192            let progress = progress.into_inner();
193            progress.finish().await;
194        }
195
196        Ok(())
197    }
198}
199
200#[allow(clippy::needless_lifetimes)] // false positive
201fn collect_sources<'s, V: DiscoveredVisitor, S: Source>(
202    source: &'s S,
203    discover_contexts: Vec<DistributionContext>,
204) -> impl TryStream<Ok = impl Stream<Item = DiscoveredAdvisory>, Error = Error<V::Error, S::Error>> + 's
205{
206    stream::iter(discover_contexts).then(async |discover_context| {
207        log::debug!("Walking: {}", discover_context.url());
208        Ok(stream::iter(
209            source
210                .load_index(discover_context.clone())
211                .await
212                .map_err(Error::Source)?,
213        ))
214    })
215}
216
217fn collect_advisories<'s, V: DiscoveredVisitor + 's, S: Source>(
218    source: &'s S,
219    discover_contexts: Vec<DistributionContext>,
220) -> impl TryStream<Ok = DiscoveredAdvisory, Error = Error<V::Error, S::Error>> + 's {
221    collect_sources::<V, S>(source, discover_contexts)
222        .map_ok(|s| s.map(Ok))
223        .try_flatten()
224}