csaf_walker/visitors/
filter.rs

1use crate::discover::{DiscoveredAdvisory, DiscoveredContext, DiscoveredVisitor};
2use std::collections::HashSet;
3
4/// A visitor, skipping advisories for existing files.
5pub struct FilteringVisitor<V: DiscoveredVisitor> {
6    pub visitor: V,
7
8    pub config: FilterConfig,
9}
10
/// Rules used by [`FilteringVisitor`] to decide which discovered advisories to skip.
///
/// The default configuration is empty and therefore skips nothing. All builder methods consume
/// and return the configuration, so they can be chained.
#[non_exhaustive]
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct FilterConfig {
    /// A set of distributions to ignore
    ///
    /// **NOTE:** The distributions will still be discovered, as this is a post-discovery visitor. If you want to
    /// even skip discovering the source, use [`crate::walker::Walker::with_distribution_filter`].
    pub ignored_distributions: HashSet<String>,
    /// File name prefixes to ignore: an advisory whose file name starts with any of these is
    /// skipped.
    pub ignored_prefixes: Vec<String>,
    /// File name prefixes to accept: if this list is not empty, an advisory whose file name starts
    /// with none of these is skipped.
    pub only_prefixes: Vec<String>,
}

impl FilterConfig {
    /// Create a new, empty configuration. Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replace the set of ignored distributions with the provided ones.
    #[must_use]
    pub fn ignored_distributions<I>(mut self, ignored_distributions: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_distributions = ignored_distributions.into_iter().collect();
        self
    }

    /// Add a single distribution to the set of ignored distributions.
    #[must_use]
    pub fn add_ignored_distribution(mut self, ignored_distribution: impl Into<String>) -> Self {
        self.ignored_distributions
            .insert(ignored_distribution.into());
        self
    }

    /// Add all provided distributions to the set of ignored distributions.
    #[must_use]
    pub fn extend_ignored_distributions<I>(mut self, ignored_distributions: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_distributions.extend(ignored_distributions);
        self
    }

    /// Replace the list of ignored prefixes with the provided ones.
    #[must_use]
    pub fn ignored_prefixes<I>(mut self, ignored_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_prefixes = ignored_prefixes.into_iter().collect();
        self
    }

    /// Append a single prefix to the list of ignored prefixes.
    #[must_use]
    pub fn add_ignored_prefix(mut self, ignored_prefix: impl Into<String>) -> Self {
        self.ignored_prefixes.push(ignored_prefix.into());
        self
    }

    /// Append all provided prefixes to the list of ignored prefixes.
    #[must_use]
    pub fn extend_ignored_prefixes<I>(mut self, ignored_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.ignored_prefixes.extend(ignored_prefixes);
        self
    }

    /// Replace the list of accepted ("only") prefixes with the provided ones.
    #[must_use]
    pub fn only_prefixes<I>(mut self, only_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.only_prefixes = only_prefixes.into_iter().collect();
        self
    }

    /// Append a single prefix to the list of accepted ("only") prefixes.
    #[must_use]
    pub fn add_only_prefix(mut self, only_prefix: impl Into<String>) -> Self {
        self.only_prefixes.push(only_prefix.into());
        self
    }

    /// Append all provided prefixes to the list of accepted ("only") prefixes.
    #[must_use]
    pub fn extend_only_prefixes<I>(mut self, only_prefixes: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        self.only_prefixes.extend(only_prefixes);
        self
    }
}
92
93impl<V: DiscoveredVisitor> DiscoveredVisitor for FilteringVisitor<V> {
94    type Error = V::Error;
95    type Context = V::Context;
96
97    async fn visit_context(
98        &self,
99        discovered: &DiscoveredContext<'_>,
100    ) -> Result<Self::Context, Self::Error> {
101        self.visitor.visit_context(discovered).await
102    }
103
104    async fn visit_advisory(
105        &self,
106        context: &Self::Context,
107        advisory: DiscoveredAdvisory,
108    ) -> Result<(), Self::Error> {
109        // ignore distributions
110
111        if self
112            .config
113            .ignored_distributions
114            .contains(advisory.context.url().as_str())
115        {
116            return Ok(());
117        };
118
119        // eval name
120
121        let name = advisory
122            .url
123            .path_segments()
124            .and_then(|mut seg| seg.next_back())
125            .unwrap_or(advisory.url.path());
126
127        // "ignore" prefix
128
129        for n in &self.config.ignored_prefixes {
130            if name.starts_with(n.as_str()) {
131                return Ok(());
132            }
133        }
134
135        // "only" prefix
136
137        if !self.config.only_prefixes.is_empty()
138            && !self
139                .config
140                .only_prefixes
141                .iter()
142                .any(|n| name.starts_with(n.as_str()))
143        {
144            return Ok(());
145        }
146
147        // ok to proceed
148
149        self.visitor.visit_advisory(context, advisory).await
150    }
151}
152
#[cfg(test)]
mod test {
    use super::*;
    use crate::discover::DistributionContext;
    use std::sync::Arc;
    use std::time::SystemTime;
    use tokio::sync::Mutex;
    use url::Url;

    /// The advisory file names fed through the filter by every test.
    const NAMES: [&str; 6] = ["foo-1", "foo-2", "bar-1", "bar-2", "baz-1", "baz-2"];

    /// A visitor recording every advisory which made it past the filter.
    #[derive(Default)]
    struct MockVisitor {
        pub items: Arc<Mutex<Vec<DiscoveredAdvisory>>>,
    }

    impl DiscoveredVisitor for MockVisitor {
        type Error = anyhow::Error;
        type Context = ();

        async fn visit_context(
            &self,
            _context: &DiscoveredContext<'_>,
        ) -> Result<Self::Context, Self::Error> {
            Ok(())
        }

        async fn visit_advisory(
            &self,
            _context: &Self::Context,
            advisory: DiscoveredAdvisory,
        ) -> Result<(), Self::Error> {
            self.items.lock().await.push(advisory);
            Ok(())
        }
    }

    /// Send an advisory named `name`, located at `https://localhost/{name}`, through the filter.
    async fn issue<V>(filter: &FilteringVisitor<V>, name: &str) -> Result<(), anyhow::Error>
    where
        V: DiscoveredVisitor<Error = anyhow::Error, Context = ()>,
    {
        let context = Arc::new(DistributionContext::Directory(Url::parse(
            "https://localhost",
        )?));
        let url = Url::parse(&format!("https://localhost/{name}"))?;
        let modified = SystemTime::now();

        filter
            .visit_advisory(
                &(),
                DiscoveredAdvisory {
                    context,
                    url,
                    digest: None,
                    signature: None,
                    modified,
                },
            )
            .await?;

        Ok(())
    }

    /// Run all of [`NAMES`] through a filter using `config`, returning the URL paths of the
    /// advisories which reached the wrapped visitor, in the order they arrived.
    async fn forwarded(config: FilterConfig) -> anyhow::Result<Vec<String>> {
        let filter = FilteringVisitor {
            config,
            visitor: MockVisitor::default(),
        };

        for name in NAMES {
            issue(&filter, name).await?;
        }

        let items = filter.visitor.items.lock().await;
        Ok(items
            .iter()
            .map(|advisory| advisory.url.path().to_string())
            .collect())
    }

    #[tokio::test]
    async fn ignored() -> anyhow::Result<()> {
        let paths = forwarded(
            FilterConfig::new()
                .add_ignored_prefix("foo-")
                .add_ignored_prefix("bar-"),
        )
        .await?;

        // only the names matching none of the ignored prefixes must get through
        assert_eq!(paths, ["/baz-1", "/baz-2"]);

        Ok(())
    }

    #[tokio::test]
    async fn only() -> anyhow::Result<()> {
        let paths = forwarded(
            FilterConfig::new()
                .add_only_prefix("foo-")
                .add_only_prefix("bar-"),
        )
        .await?;

        // only the names matching one of the accepted prefixes must get through
        assert_eq!(paths, ["/foo-1", "/foo-2", "/bar-1", "/bar-2"]);

        Ok(())
    }
}