1use crate::{
4 discover::{DiscoveredAdvisory, DiscoveredContext, DiscoveredVisitor, DistributionContext},
5 model::metadata::Distribution,
6 source::Source,
7};
8use futures::{Stream, StreamExt, TryFutureExt, TryStream, TryStreamExt, stream};
9use std::{fmt::Debug, sync::Arc};
10use tokio::sync::Mutex;
11use url::ParseError;
12use walker_common::progress::{Progress, ProgressBar};
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error<VE, SE>
16where
17 VE: std::fmt::Display + Debug,
18 SE: std::fmt::Display + Debug,
19{
20 #[error("Source error: {0}")]
21 Source(SE),
22 #[error("URL error: {0}")]
23 Url(#[from] ParseError),
24 #[error("Visitor error: {0}")]
25 Visitor(VE),
26}
27
28pub type DistributionFilter = Box<dyn Fn(&DistributionContext) -> bool>;
29
30pub struct Walker<S: Source, P: Progress> {
31 source: S,
32 progress: P,
33 distribution_filter: Option<DistributionFilter>,
34}
35
36impl<S: Source> Walker<S, ()> {
37 pub fn new(source: S) -> Self {
38 Self {
39 source,
40 progress: (),
41 distribution_filter: None,
42 }
43 }
44}
45
46impl<S: Source, P: Progress> Walker<S, P> {
47 pub fn with_progress<U: Progress>(self, progress: U) -> Walker<S, U> {
48 Walker {
49 progress,
50 source: self.source,
51 distribution_filter: self.distribution_filter,
52 }
53 }
54
55 pub fn with_distribution_filter<F>(mut self, distribution_filter: F) -> Self
60 where
61 F: Fn(&DistributionContext) -> bool + 'static,
62 {
63 self.distribution_filter = Some(Box::new(distribution_filter));
64 self
65 }
66
67 fn collect_distributions(&self, distributions: Vec<Distribution>) -> Vec<DistributionContext> {
68 distributions
69 .into_iter()
70 .flat_map(|distribution| {
71 distribution
72 .rolie
73 .into_iter()
74 .flat_map(|rolie| rolie.feeds)
75 .map(|feed| DistributionContext::Feed(feed.url))
76 .chain(
77 distribution
78 .directory_url
79 .map(DistributionContext::Directory),
80 )
81 })
82 .filter(|distribution| {
83 if let Some(filter) = &self.distribution_filter {
84 filter(distribution)
85 } else {
86 true
87 }
88 })
89 .collect()
90 }
91
92 pub async fn walk<V>(self, visitor: V) -> Result<(), Error<V::Error, S::Error>>
93 where
94 V: DiscoveredVisitor,
95 {
96 let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
97
98 let context = visitor
99 .visit_context(&DiscoveredContext {
100 metadata: &metadata,
101 })
102 .await
103 .map_err(Error::Visitor)?;
104
105 let distributions = self.collect_distributions(metadata.distributions);
106 log::info!("processing {} distribution URLs", distributions.len());
107
108 for distribution in distributions {
109 log::info!("Walking directory URL: {distribution:?}");
110 let index = self
111 .source
112 .load_index(distribution)
113 .await
114 .map_err(Error::Source)?;
115
116 let mut progress = self.progress.start(index.len());
117
118 for advisory in index {
119 log::debug!(" Discovered advisory: {advisory:?}");
120 progress
121 .set_message(
122 advisory
123 .url
124 .path()
125 .rsplit_once('/')
126 .map(|(_, s)| s)
127 .unwrap_or(advisory.url.as_str())
128 .to_string(),
129 )
130 .await;
131 visitor
132 .visit_advisory(&context, advisory)
133 .await
134 .map_err(Error::Visitor)?;
135 progress.tick().await;
136 }
137
138 progress.finish().await;
139 }
140
141 Ok(())
142 }
143
144 pub async fn walk_parallel<V>(
145 self,
146 limit: usize,
147 visitor: V,
148 ) -> Result<(), Error<V::Error, S::Error>>
149 where
150 V: DiscoveredVisitor,
151 {
152 let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
153 let context = visitor
154 .visit_context(&DiscoveredContext {
155 metadata: &metadata,
156 })
157 .await
158 .map_err(Error::Visitor)?;
159
160 let context = Arc::new(context);
161 let visitor = Arc::new(visitor);
162
163 let distributions = self.collect_distributions(metadata.distributions);
164 log::info!("processing {} distribution URLs", distributions.len());
165
166 let advisories: Vec<_> = collect_advisories::<V, S>(&self.source, distributions)
167 .try_collect()
168 .await?;
169
170 let size = advisories.len();
171 log::info!("Discovered {size} advisories");
172
173 let progress = Arc::new(Mutex::new(self.progress.start(size)));
174
175 stream::iter(advisories)
176 .map(Ok)
177 .try_for_each_concurrent(limit, async |advisory| {
178 log::debug!("Discovered advisory: {}", advisory.url);
179
180 let result = visitor
181 .visit_advisory(&context, advisory.clone())
182 .map_err(Error::Visitor)
183 .await;
184
185 progress.lock().await.tick().await;
186
187 result
188 })
189 .await?;
190
191 if let Ok(progress) = Arc::try_unwrap(progress) {
192 let progress = progress.into_inner();
193 progress.finish().await;
194 }
195
196 Ok(())
197 }
198}
199
200#[allow(clippy::needless_lifetimes)] fn collect_sources<'s, V: DiscoveredVisitor, S: Source>(
202 source: &'s S,
203 discover_contexts: Vec<DistributionContext>,
204) -> impl TryStream<Ok = impl Stream<Item = DiscoveredAdvisory>, Error = Error<V::Error, S::Error>> + 's
205{
206 stream::iter(discover_contexts).then(async |discover_context| {
207 log::debug!("Walking: {}", discover_context.url());
208 Ok(stream::iter(
209 source
210 .load_index(discover_context.clone())
211 .await
212 .map_err(Error::Source)?,
213 ))
214 })
215}
216
217fn collect_advisories<'s, V: DiscoveredVisitor + 's, S: Source>(
218 source: &'s S,
219 discover_contexts: Vec<DistributionContext>,
220) -> impl TryStream<Ok = DiscoveredAdvisory, Error = Error<V::Error, S::Error>> + 's {
221 collect_sources::<V, S>(source, discover_contexts)
222 .map_ok(|s| s.map(Ok))
223 .try_flatten()
224}