// libp2p_gossipsub/subscription_filter.rs

// Copyright 2020 Sigma Prime Pty Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use crate::types::GossipsubSubscription;
22use crate::TopicHash;
23use log::info;
24use std::collections::{BTreeSet, HashMap, HashSet};
25
26pub trait TopicSubscriptionFilter {
27    /// Returns true iff the topic is of interest and we can subscribe to it.
28    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool;
29
30    /// Filters a list of incoming subscriptions and returns a filtered set
31    /// By default this deduplicates the subscriptions and calls
32    /// [`Self::filter_incoming_subscription_set`] on the filtered set.
33    fn filter_incoming_subscriptions<'a>(
34        &mut self,
35        subscriptions: &'a [GossipsubSubscription],
36        currently_subscribed_topics: &BTreeSet<TopicHash>,
37    ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
38        let mut filtered_subscriptions: HashMap<TopicHash, &GossipsubSubscription> = HashMap::new();
39        for subscription in subscriptions {
40            use std::collections::hash_map::Entry::*;
41            match filtered_subscriptions.entry(subscription.topic_hash.clone()) {
42                Occupied(entry) => {
43                    if entry.get().action != subscription.action {
44                        entry.remove();
45                    }
46                }
47                Vacant(entry) => {
48                    entry.insert(subscription);
49                }
50            }
51        }
52        self.filter_incoming_subscription_set(
53            filtered_subscriptions.into_iter().map(|(_, v)| v).collect(),
54            currently_subscribed_topics,
55        )
56    }
57
58    /// Filters a set of deduplicated subscriptions
59    /// By default this filters the elements based on [`Self::allow_incoming_subscription`].
60    fn filter_incoming_subscription_set<'a>(
61        &mut self,
62        mut subscriptions: HashSet<&'a GossipsubSubscription>,
63        _currently_subscribed_topics: &BTreeSet<TopicHash>,
64    ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
65        subscriptions.retain(|s| {
66            if self.allow_incoming_subscription(s) {
67                true
68            } else {
69                info!("Filtered incoming subscription {:?}", s);
70                false
71            }
72        });
73        Ok(subscriptions)
74    }
75
76    /// Returns true iff we allow an incoming subscription.
77    /// This is used by the default implementation of filter_incoming_subscription_set to decide
78    /// whether to filter out a subscription or not.
79    /// By default this uses can_subscribe to decide the same for incoming subscriptions as for
80    /// outgoing ones.
81    fn allow_incoming_subscription(&mut self, subscription: &GossipsubSubscription) -> bool {
82        self.can_subscribe(&subscription.topic_hash)
83    }
84}
85
// Some commonly useful implementations of `TopicSubscriptionFilter`.
87
88/// Allows all subscriptions
89#[derive(Default, Clone)]
90pub struct AllowAllSubscriptionFilter {}
91
92impl TopicSubscriptionFilter for AllowAllSubscriptionFilter {
93    fn can_subscribe(&mut self, _: &TopicHash) -> bool {
94        true
95    }
96}
97
98/// Allows only whitelisted subscriptions
99#[derive(Default, Clone)]
100pub struct WhitelistSubscriptionFilter(pub HashSet<TopicHash>);
101
102impl TopicSubscriptionFilter for WhitelistSubscriptionFilter {
103    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
104        self.0.contains(topic_hash)
105    }
106}
107
108/// Adds a max count to a given subscription filter
109pub struct MaxCountSubscriptionFilter<T: TopicSubscriptionFilter> {
110    pub filter: T,
111    pub max_subscribed_topics: usize,
112    pub max_subscriptions_per_request: usize,
113}
114
115impl<T: TopicSubscriptionFilter> TopicSubscriptionFilter for MaxCountSubscriptionFilter<T> {
116    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
117        self.filter.can_subscribe(topic_hash)
118    }
119
120    fn filter_incoming_subscriptions<'a>(
121        &mut self,
122        subscriptions: &'a [GossipsubSubscription],
123        currently_subscribed_topics: &BTreeSet<TopicHash>,
124    ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
125        if subscriptions.len() > self.max_subscriptions_per_request {
126            return Err("too many subscriptions per request".into());
127        }
128        let result = self
129            .filter
130            .filter_incoming_subscriptions(subscriptions, currently_subscribed_topics)?;
131
132        use crate::types::GossipsubSubscriptionAction::*;
133
134        let mut unsubscribed = 0;
135        let mut new_subscribed = 0;
136        for s in &result {
137            let currently_contained = currently_subscribed_topics.contains(&s.topic_hash);
138            match s.action {
139                Unsubscribe => {
140                    if currently_contained {
141                        unsubscribed += 1;
142                    }
143                }
144                Subscribe => {
145                    if !currently_contained {
146                        new_subscribed += 1;
147                    }
148                }
149            }
150        }
151
152        if new_subscribed + currently_subscribed_topics.len()
153            > self.max_subscribed_topics + unsubscribed
154        {
155            return Err("too many subscribed topics".into());
156        }
157
158        Ok(result)
159    }
160}
161
162/// Combines two subscription filters
163pub struct CombinedSubscriptionFilters<T: TopicSubscriptionFilter, S: TopicSubscriptionFilter> {
164    pub filter1: T,
165    pub filter2: S,
166}
167
168impl<T, S> TopicSubscriptionFilter for CombinedSubscriptionFilters<T, S>
169where
170    T: TopicSubscriptionFilter,
171    S: TopicSubscriptionFilter,
172{
173    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
174        self.filter1.can_subscribe(topic_hash) && self.filter2.can_subscribe(topic_hash)
175    }
176
177    fn filter_incoming_subscription_set<'a>(
178        &mut self,
179        subscriptions: HashSet<&'a GossipsubSubscription>,
180        currently_subscribed_topics: &BTreeSet<TopicHash>,
181    ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
182        let intermediate = self
183            .filter1
184            .filter_incoming_subscription_set(subscriptions, currently_subscribed_topics)?;
185        self.filter2
186            .filter_incoming_subscription_set(intermediate, currently_subscribed_topics)
187    }
188}
189
190pub struct CallbackSubscriptionFilter<T>(pub T)
191where
192    T: FnMut(&TopicHash) -> bool;
193
194impl<T> TopicSubscriptionFilter for CallbackSubscriptionFilter<T>
195where
196    T: FnMut(&TopicHash) -> bool,
197{
198    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
199        (self.0)(topic_hash)
200    }
201}
202
#[cfg(feature = "regex-filter")]
pub mod regex {
    use super::TopicSubscriptionFilter;
    use crate::TopicHash;
    use regex::Regex;

    /// A subscription filter that accepts a topic when the wrapped regular expression matches
    /// its hash string.
    ///
    /// Matching uses `Regex::is_match`, which succeeds if the pattern matches anywhere in the
    /// string. Anchor the pattern with `^…$` to require the whole topic to match.
    pub struct RegexSubscriptionFilter(pub Regex);

    impl TopicSubscriptionFilter for RegexSubscriptionFilter {
        fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
            let RegexSubscriptionFilter(pattern) = self;
            pattern.is_match(topic_hash.as_str())
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;
        use crate::types::GossipsubSubscription;
        use crate::types::GossipsubSubscriptionAction::*;

        #[test]
        fn test_regex_subscription_filter() {
            let subscriptions: Vec<GossipsubSubscription> =
                ["tt", "et3t3te", "abcdefghijklmnopqrsuvwxyz"]
                    .iter()
                    .map(|raw| GossipsubSubscription {
                        action: Subscribe,
                        topic_hash: TopicHash::from_raw(*raw),
                    })
                    .collect();

            // "t.*t" occurs in the first two topics; the third contains no 't' at all.
            let mut filter = RegexSubscriptionFilter(Regex::new("t.*t").unwrap());
            let result = filter
                .filter_incoming_subscriptions(&subscriptions, &Default::default())
                .unwrap();
            assert_eq!(result, subscriptions[..2].iter().collect());
        }
    }
}
255
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::GossipsubSubscriptionAction::*;
    use std::iter::FromIterator;

    /// Builds a `Subscribe` entry for `topic_hash`.
    fn subscribe(topic_hash: &TopicHash) -> GossipsubSubscription {
        GossipsubSubscription {
            action: Subscribe,
            topic_hash: topic_hash.clone(),
        }
    }

    /// Builds an `Unsubscribe` entry for `topic_hash`.
    fn unsubscribe(topic_hash: &TopicHash) -> GossipsubSubscription {
        GossipsubSubscription {
            action: Unsubscribe,
            topic_hash: topic_hash.clone(),
        }
    }

    /// Builds `count` topics named `t0`, `t1`, ...
    fn numbered_topics(count: usize) -> Vec<TopicHash> {
        (0..count)
            .map(|i| TopicHash::from_raw(format!("t{}", i)))
            .collect()
    }

    #[test]
    fn test_filter_incoming_allow_all_with_duplicates() {
        let mut filter = AllowAllSubscriptionFilter {};

        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let old = BTreeSet::from_iter(vec![t1.clone()]);
        // Opposite actions on a topic cancel out: both t2 entries vanish, and of the three t1
        // entries only the final `Unsubscribe` is left.
        let subscriptions = vec![
            unsubscribe(&t1),
            unsubscribe(&t2),
            subscribe(&t2),
            subscribe(&t1),
            unsubscribe(&t1),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, std::iter::once(&subscriptions[4]).collect());
    }

    #[test]
    fn test_filter_incoming_whitelist() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let mut filter = WhitelistSubscriptionFilter(HashSet::from_iter(vec![t1.clone()]));
        let subscriptions = vec![subscribe(&t1), subscribe(&t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &BTreeSet::new())
            .unwrap();
        assert_eq!(result, std::iter::once(&subscriptions[0]).collect());
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions_per_request() {
        let t1 = TopicHash::from_raw("t1");

        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 100,
            max_subscriptions_per_request: 2,
        };

        // Three raw entries exceed the per-request limit of two, even though they would
        // collapse to a single entry after deduplication.
        let subscriptions = vec![subscribe(&t1), unsubscribe(&t1), subscribe(&t1)];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &BTreeSet::new());
        assert_eq!(result, Err("too many subscriptions per request".into()));
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions() {
        let t = numbered_topics(4);

        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 3,
            max_subscriptions_per_request: 2,
        };

        // Two topics already subscribed plus two new ones exceed the limit of three.
        let old: BTreeSet<TopicHash> = t[0..2].iter().cloned().collect();
        let subscriptions = vec![subscribe(&t[2]), subscribe(&t[3])];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &old);
        assert_eq!(result, Err("too many subscribed topics".into()));
    }

    #[test]
    fn test_filter_incoming_max_subscribed_valid() {
        let t = numbered_topics(5);

        let mut filter = MaxCountSubscriptionFilter {
            filter: WhitelistSubscriptionFilter(t.iter().take(4).cloned().collect()),
            max_subscribed_topics: 2,
            max_subscriptions_per_request: 5,
        };

        let old: BTreeSet<TopicHash> = t[0..2].iter().cloned().collect();
        // t4 is not whitelisted and gets dropped; the two remaining subscriptions are offset
        // by the two unsubscriptions, so the limit of two topics still holds.
        let subscriptions = vec![
            subscribe(&t[4]),
            subscribe(&t[2]),
            subscribe(&t[3]),
            unsubscribe(&t[0]),
            unsubscribe(&t[1]),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, subscriptions[1..].iter().collect());
    }

    #[test]
    fn test_callback_filter() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let mut filter = CallbackSubscriptionFilter(|h| h.as_str() == "t1");
        let subscriptions = vec![subscribe(&t1), subscribe(&t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &BTreeSet::new())
            .unwrap();
        assert_eq!(result, std::iter::once(&subscriptions[0]).collect());
    }
}