csaf_walker/visitors/
filter.rs

1use crate::discover::{DiscoveredAdvisory, DiscoveredContext, DiscoveredVisitor};
2use std::collections::HashSet;
3
4/// A visitor, skipping advisories for existing files.
5pub struct FilteringVisitor<V: DiscoveredVisitor> {
6    pub visitor: V,
7
8    pub config: FilterConfig,
9}
10
/// Rules used by [`FilteringVisitor`] to decide which discovered advisories are passed on.
///
/// The checks are applied in the following order; the first one rejecting an advisory wins:
///
/// 1. [`ignored_distributions`](Self::ignored_distributions)
/// 2. [`ignored_prefixes`](Self::ignored_prefixes)
/// 3. [`only_prefixes`](Self::only_prefixes)
///
/// Prefix checks operate on the advisory's *file name*: the last path segment of the advisory
/// URL, or the complete URL path if the URL cannot be split into segments.
///
/// The default configuration lets every advisory pass.
#[non_exhaustive]
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct FilterConfig {
    /// A set of distributions to ignore
    ///
    /// Entries are compared, as exact strings, against the serialized URL of the distribution an
    /// advisory was discovered in. No normalization is applied, so an entry has to match that
    /// URL exactly (e.g. including any trailing slash).
    ///
    /// **NOTE:** The distributions will still be discovered, as this is a post-discovery visitor. If you want to
    /// even skip discovering the source, use [`crate::walker::Walker::with_distribution_filter`].
    pub ignored_distributions: HashSet<String>,
    /// Advisories whose file name starts with any of these prefixes are skipped.
    ///
    /// This check takes precedence over [`only_prefixes`](Self::only_prefixes).
    pub ignored_prefixes: Vec<String>,
    /// If non-empty, only advisories whose file name starts with one of these prefixes are
    /// passed on. An empty list accepts every file name.
    pub only_prefixes: Vec<String>,
}

impl FilterConfig {
    /// Create a configuration which lets every advisory pass.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replace the set of ignored distributions with `ignored_distributions`.
    #[must_use]
    pub fn ignored_distributions<I>(mut self, ignored_distributions: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_distributions = ignored_distributions.into_iter().collect();
        self
    }

    /// Add a single distribution URL to the set of ignored distributions.
    #[must_use]
    pub fn add_ignored_distribution(mut self, ignored_distribution: impl Into<String>) -> Self {
        self.ignored_distributions.insert(ignored_distribution.into());
        self
    }

    /// Add all of `ignored_distributions` to the set of ignored distributions.
    #[must_use]
    pub fn extend_ignored_distributions<I>(mut self, ignored_distributions: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_distributions.extend(ignored_distributions);
        self
    }

    /// Replace the list of ignored file name prefixes with `ignored_prefixes`.
    #[must_use]
    pub fn ignored_prefixes<I>(mut self, ignored_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_prefixes = ignored_prefixes.into_iter().collect();
        self
    }

    /// Append a single prefix to the list of ignored file name prefixes.
    #[must_use]
    pub fn add_ignored_prefix(mut self, ignored_prefix: impl Into<String>) -> Self {
        self.ignored_prefixes.push(ignored_prefix.into());
        self
    }

    /// Append all of `ignored_prefixes` to the list of ignored file name prefixes.
    #[must_use]
    pub fn extend_ignored_prefixes<I>(mut self, ignored_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_prefixes.extend(ignored_prefixes);
        self
    }

    /// Replace the list of "only" file name prefixes with `only_prefixes`.
    #[must_use]
    pub fn only_prefixes<I>(mut self, only_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.only_prefixes = only_prefixes.into_iter().collect();
        self
    }

    /// Append a single prefix to the list of "only" file name prefixes.
    #[must_use]
    pub fn add_only_prefix(mut self, only_prefix: impl Into<String>) -> Self {
        self.only_prefixes.push(only_prefix.into());
        self
    }

    /// Append all of `only_prefixes` to the list of "only" file name prefixes.
    #[must_use]
    pub fn extend_only_prefixes<I>(mut self, only_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.only_prefixes.extend(only_prefixes);
        self
    }
}
92
93impl<V: DiscoveredVisitor> DiscoveredVisitor for FilteringVisitor<V> {
94    type Error = V::Error;
95    type Context = V::Context;
96
97    async fn visit_context(
98        &self,
99        discovered: &DiscoveredContext<'_>,
100    ) -> Result<Self::Context, Self::Error> {
101        self.visitor.visit_context(discovered).await
102    }
103
104    async fn visit_advisory(
105        &self,
106        context: &Self::Context,
107        advisory: DiscoveredAdvisory,
108    ) -> Result<(), Self::Error> {
109        // ignore distributions
110
111        if self
112            .config
113            .ignored_distributions
114            .contains(advisory.context.url().as_str())
115        {
116            return Ok(());
117        };
118
119        // eval name
120
121        let name = advisory
122            .url
123            .path_segments()
124            .and_then(|seg| seg.last())
125            .unwrap_or(advisory.url.path());
126
127        // "ignore" prefix
128
129        for n in &self.config.ignored_prefixes {
130            if name.starts_with(n.as_str()) {
131                return Ok(());
132            }
133        }
134
135        // "only" prefix
136
137        if !self.config.only_prefixes.is_empty()
138            && !self
139                .config
140                .only_prefixes
141                .iter()
142                .any(|n| name.starts_with(n.as_str()))
143        {
144            return Ok(());
145        }
146
147        // ok to proceed
148
149        self.visitor.visit_advisory(context, advisory).await
150    }
151}
152
#[cfg(test)]
mod test {
    use super::*;
    use crate::discover::DistributionContext;
    use std::sync::Arc;
    use std::time::SystemTime;
    use tokio::sync::Mutex;
    use url::Url;

    /// URL of the (fake) distribution every test advisory is discovered in.
    const DISTRIBUTION: &str = "https://localhost";

    /// Advisory file names issued, in this order, by [`issue_all`].
    const NAMES: [&str; 6] = ["foo-1", "foo-2", "bar-1", "bar-2", "baz-1", "baz-2"];

    /// A visitor recording every advisory it receives.
    #[derive(Default)]
    struct MockVisitor {
        pub items: Arc<Mutex<Vec<DiscoveredAdvisory>>>,
    }

    impl DiscoveredVisitor for MockVisitor {
        type Error = anyhow::Error;
        type Context = ();

        async fn visit_context(
            &self,
            _context: &DiscoveredContext<'_>,
        ) -> Result<Self::Context, Self::Error> {
            Ok(())
        }

        async fn visit_advisory(
            &self,
            _context: &Self::Context,
            advisory: DiscoveredAdvisory,
        ) -> Result<(), Self::Error> {
            self.items.lock().await.push(advisory);
            Ok(())
        }
    }

    /// Create a filter using `config`, wrapping a fresh [`MockVisitor`].
    fn filtering(config: FilterConfig) -> FilteringVisitor<MockVisitor> {
        FilteringVisitor {
            visitor: MockVisitor::default(),
            config,
        }
    }

    /// Issue a single advisory named `name`, located in [`DISTRIBUTION`], to `filter`.
    async fn issue<V>(filter: &FilteringVisitor<V>, name: &str) -> Result<(), anyhow::Error>
    where
        V: DiscoveredVisitor<Error = anyhow::Error, Context = ()>,
    {
        let context = Arc::new(DistributionContext::Directory(Url::parse(DISTRIBUTION)?));
        let url = Url::parse(&format!("{DISTRIBUTION}/{name}"))?;
        let modified = SystemTime::now();

        filter
            .visit_advisory(
                &(),
                DiscoveredAdvisory {
                    context,
                    url,
                    modified,
                },
            )
            .await?;

        Ok(())
    }

    /// Issue one advisory for each entry of [`NAMES`], in order.
    async fn issue_all(filter: &FilteringVisitor<MockVisitor>) -> anyhow::Result<()> {
        for name in NAMES {
            issue(filter, name).await?;
        }
        Ok(())
    }

    /// File names of the advisories which made it through the filter, in visiting order.
    async fn visited(filter: &FilteringVisitor<MockVisitor>) -> Vec<String> {
        filter
            .visitor
            .items
            .lock()
            .await
            .iter()
            .map(|advisory| advisory.url.path().trim_start_matches('/').to_string())
            .collect()
    }

    #[tokio::test]
    async fn unfiltered() -> anyhow::Result<()> {
        let filter = filtering(FilterConfig::new());

        issue_all(&filter).await?;

        assert_eq!(visited(&filter).await, NAMES);

        Ok(())
    }

    #[tokio::test]
    async fn ignored() -> anyhow::Result<()> {
        let filter = filtering(
            FilterConfig::new()
                .add_ignored_prefix("foo-")
                .add_ignored_prefix("bar-"),
        );

        issue_all(&filter).await?;

        assert_eq!(visited(&filter).await, ["baz-1", "baz-2"]);

        Ok(())
    }

    #[tokio::test]
    async fn only() -> anyhow::Result<()> {
        let filter = filtering(
            FilterConfig::new()
                .add_only_prefix("foo-")
                .add_only_prefix("bar-"),
        );

        issue_all(&filter).await?;

        assert_eq!(visited(&filter).await, ["foo-1", "foo-2", "bar-1", "bar-2"]);

        Ok(())
    }

    #[tokio::test]
    async fn ignored_takes_precedence_over_only() -> anyhow::Result<()> {
        let filter = filtering(
            FilterConfig::new()
                .add_ignored_prefix("foo-1")
                .add_only_prefix("foo-"),
        );

        issue_all(&filter).await?;

        assert_eq!(visited(&filter).await, ["foo-2"]);

        Ok(())
    }

    #[tokio::test]
    async fn ignored_distribution() -> anyhow::Result<()> {
        // use the context's own URL rendering, as that is what the filter compares against
        let distribution = DistributionContext::Directory(Url::parse(DISTRIBUTION)?);
        let filter = filtering(
            FilterConfig::new().add_ignored_distribution(distribution.url().as_str()),
        );

        issue_all(&filter).await?;

        assert!(visited(&filter).await.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn other_distribution_not_ignored() -> anyhow::Result<()> {
        let filter =
            filtering(FilterConfig::new().add_ignored_distribution("https://example.com/"));

        issue_all(&filter).await?;

        assert_eq!(visited(&filter).await, NAMES);

        Ok(())
    }
}