1use crate::{
4 discover::{DiscoveredAdvisory, DiscoveredContext, DiscoveredVisitor, DistributionContext},
5 model::metadata::Distribution,
6 source::Source,
7};
8use futures::{stream, Stream, StreamExt, TryFutureExt, TryStream, TryStreamExt};
9use std::{fmt::Debug, sync::Arc};
10use tokio::sync::Mutex;
11use url::ParseError;
12use walker_common::progress::{Progress, ProgressBar};
13
14#[derive(Debug, thiserror::Error)]
15pub enum Error<VE, SE>
16where
17 VE: std::fmt::Display + Debug,
18 SE: std::fmt::Display + Debug,
19{
20 #[error("Source error: {0}")]
21 Source(SE),
22 #[error("URL error: {0}")]
23 Url(#[from] ParseError),
24 #[error("Visitor error: {0}")]
25 Visitor(VE),
26}
27
28pub type DistributionFilter = Box<dyn Fn(&DistributionContext) -> bool>;
29
30pub struct Walker<S: Source, P: Progress> {
31 source: S,
32 progress: P,
33 distribution_filter: Option<DistributionFilter>,
34}
35
36impl<S: Source> Walker<S, ()> {
37 pub fn new(source: S) -> Self {
38 Self {
39 source,
40 progress: (),
41 distribution_filter: None,
42 }
43 }
44}
45
46impl<S: Source, P: Progress> Walker<S, P> {
47 pub fn with_progress<U: Progress>(self, progress: U) -> Walker<S, U> {
48 Walker {
49 progress,
50 source: self.source,
51 distribution_filter: self.distribution_filter,
52 }
53 }
54
55 pub fn with_distribution_filter<F>(mut self, distribution_filter: F) -> Self
60 where
61 F: Fn(&DistributionContext) -> bool + 'static,
62 {
63 self.distribution_filter = Some(Box::new(distribution_filter));
64 self
65 }
66
67 fn collect_distributions(&self, distributions: Vec<Distribution>) -> Vec<DistributionContext> {
68 distributions
69 .into_iter()
70 .flat_map(|distribution| {
71 distribution
72 .rolie
73 .into_iter()
74 .flat_map(|rolie| rolie.feeds)
75 .map(|feed| DistributionContext::Feed(feed.url))
76 .chain(
77 distribution
78 .directory_url
79 .map(DistributionContext::Directory),
80 )
81 })
82 .filter(|distribution| {
83 if let Some(filter) = &self.distribution_filter {
84 filter(distribution)
85 } else {
86 true
87 }
88 })
89 .collect()
90 }
91
92 pub async fn walk<V>(self, visitor: V) -> Result<(), Error<V::Error, S::Error>>
93 where
94 V: DiscoveredVisitor,
95 {
96 let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
97
98 let context = visitor
99 .visit_context(&DiscoveredContext {
100 metadata: &metadata,
101 })
102 .await
103 .map_err(Error::Visitor)?;
104
105 let distributions = self.collect_distributions(metadata.distributions);
106 log::info!("processing {} distribution URLs", distributions.len());
107
108 for distribution in distributions {
109 log::info!("Walking directory URL: {:?}", distribution);
110 let index = self
111 .source
112 .load_index(distribution)
113 .await
114 .map_err(Error::Source)?;
115
116 let mut progress = self.progress.start(index.len());
117
118 for advisory in index {
119 log::debug!(" Discovered advisory: {advisory:?}");
120 progress
121 .set_message(
122 advisory
123 .url
124 .path()
125 .rsplit_once('/')
126 .map(|(_, s)| s)
127 .unwrap_or(advisory.url.as_str())
128 .to_string(),
129 )
130 .await;
131 visitor
132 .visit_advisory(&context, advisory)
133 .await
134 .map_err(Error::Visitor)?;
135 progress.tick().await;
136 }
137
138 progress.finish().await;
139 }
140
141 Ok(())
142 }
143
144 pub async fn walk_parallel<V>(
145 self,
146 limit: usize,
147 visitor: V,
148 ) -> Result<(), Error<V::Error, S::Error>>
149 where
150 V: DiscoveredVisitor,
151 {
152 let metadata = self.source.load_metadata().await.map_err(Error::Source)?;
153 let context = visitor
154 .visit_context(&DiscoveredContext {
155 metadata: &metadata,
156 })
157 .await
158 .map_err(Error::Visitor)?;
159
160 let context = Arc::new(context);
161 let visitor = Arc::new(visitor);
162
163 let distributions = self.collect_distributions(metadata.distributions);
164 log::info!("processing {} distribution URLs", distributions.len());
165
166 let advisories: Vec<_> = collect_advisories::<V, S>(&self.source, distributions)
167 .try_collect()
168 .await?;
169
170 let size = advisories.len();
171 log::info!("Discovered {size} advisories");
172
173 let progress = Arc::new(Mutex::new(self.progress.start(size)));
174
175 stream::iter(advisories)
176 .map(Ok)
177 .try_for_each_concurrent(limit, |advisory| {
178 log::debug!("Discovered advisory: {}", advisory.url);
179 let context = context.clone();
180 let visitor = visitor.clone();
181 let progress = progress.clone();
182
183 async move {
184 let result = visitor
185 .visit_advisory(&context, advisory.clone())
186 .map_err(Error::Visitor)
187 .await;
188
189 progress.lock().await.tick().await;
190
191 result
192 }
193 })
194 .await?;
195
196 if let Ok(progress) = Arc::try_unwrap(progress) {
197 let progress = progress.into_inner();
198 progress.finish().await;
199 }
200
201 Ok(())
202 }
203}
204
205#[allow(clippy::needless_lifetimes)] fn collect_sources<'s, V: DiscoveredVisitor, S: Source>(
207 source: &'s S,
208 discover_contexts: Vec<DistributionContext>,
209) -> impl TryStream<Ok = impl Stream<Item = DiscoveredAdvisory>, Error = Error<V::Error, S::Error>> + 's
210{
211 stream::iter(discover_contexts).then(move |discover_context| async move {
212 log::debug!("Walking: {}", discover_context.url());
213 Ok(stream::iter(
214 source
215 .load_index(discover_context.clone())
216 .await
217 .map_err(Error::Source)?,
218 ))
219 })
220}
221
222fn collect_advisories<'s, V: DiscoveredVisitor + 's, S: Source>(
223 source: &'s S,
224 discover_contexts: Vec<DistributionContext>,
225) -> impl TryStream<Ok = DiscoveredAdvisory, Error = Error<V::Error, S::Error>> + 's {
226 collect_sources::<V, S>(source, discover_contexts)
227 .map_ok(|s| s.map(Ok))
228 .try_flatten()
229}