1use crate::types::GossipsubSubscription;
22use crate::TopicHash;
23use log::info;
24use std::collections::{BTreeSet, HashMap, HashSet};
25
26pub trait TopicSubscriptionFilter {
27 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool;
29
30 fn filter_incoming_subscriptions<'a>(
34 &mut self,
35 subscriptions: &'a [GossipsubSubscription],
36 currently_subscribed_topics: &BTreeSet<TopicHash>,
37 ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
38 let mut filtered_subscriptions: HashMap<TopicHash, &GossipsubSubscription> = HashMap::new();
39 for subscription in subscriptions {
40 use std::collections::hash_map::Entry::*;
41 match filtered_subscriptions.entry(subscription.topic_hash.clone()) {
42 Occupied(entry) => {
43 if entry.get().action != subscription.action {
44 entry.remove();
45 }
46 }
47 Vacant(entry) => {
48 entry.insert(subscription);
49 }
50 }
51 }
52 self.filter_incoming_subscription_set(
53 filtered_subscriptions.into_iter().map(|(_, v)| v).collect(),
54 currently_subscribed_topics,
55 )
56 }
57
58 fn filter_incoming_subscription_set<'a>(
61 &mut self,
62 mut subscriptions: HashSet<&'a GossipsubSubscription>,
63 _currently_subscribed_topics: &BTreeSet<TopicHash>,
64 ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
65 subscriptions.retain(|s| {
66 if self.allow_incoming_subscription(s) {
67 true
68 } else {
69 info!("Filtered incoming subscription {:?}", s);
70 false
71 }
72 });
73 Ok(subscriptions)
74 }
75
76 fn allow_incoming_subscription(&mut self, subscription: &GossipsubSubscription) -> bool {
82 self.can_subscribe(&subscription.topic_hash)
83 }
84}
85
86#[derive(Default, Clone)]
90pub struct AllowAllSubscriptionFilter {}
91
92impl TopicSubscriptionFilter for AllowAllSubscriptionFilter {
93 fn can_subscribe(&mut self, _: &TopicHash) -> bool {
94 true
95 }
96}
97
98#[derive(Default, Clone)]
100pub struct WhitelistSubscriptionFilter(pub HashSet<TopicHash>);
101
102impl TopicSubscriptionFilter for WhitelistSubscriptionFilter {
103 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
104 self.0.contains(topic_hash)
105 }
106}
107
108pub struct MaxCountSubscriptionFilter<T: TopicSubscriptionFilter> {
110 pub filter: T,
111 pub max_subscribed_topics: usize,
112 pub max_subscriptions_per_request: usize,
113}
114
115impl<T: TopicSubscriptionFilter> TopicSubscriptionFilter for MaxCountSubscriptionFilter<T> {
116 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
117 self.filter.can_subscribe(topic_hash)
118 }
119
120 fn filter_incoming_subscriptions<'a>(
121 &mut self,
122 subscriptions: &'a [GossipsubSubscription],
123 currently_subscribed_topics: &BTreeSet<TopicHash>,
124 ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
125 if subscriptions.len() > self.max_subscriptions_per_request {
126 return Err("too many subscriptions per request".into());
127 }
128 let result = self
129 .filter
130 .filter_incoming_subscriptions(subscriptions, currently_subscribed_topics)?;
131
132 use crate::types::GossipsubSubscriptionAction::*;
133
134 let mut unsubscribed = 0;
135 let mut new_subscribed = 0;
136 for s in &result {
137 let currently_contained = currently_subscribed_topics.contains(&s.topic_hash);
138 match s.action {
139 Unsubscribe => {
140 if currently_contained {
141 unsubscribed += 1;
142 }
143 }
144 Subscribe => {
145 if !currently_contained {
146 new_subscribed += 1;
147 }
148 }
149 }
150 }
151
152 if new_subscribed + currently_subscribed_topics.len()
153 > self.max_subscribed_topics + unsubscribed
154 {
155 return Err("too many subscribed topics".into());
156 }
157
158 Ok(result)
159 }
160}
161
162pub struct CombinedSubscriptionFilters<T: TopicSubscriptionFilter, S: TopicSubscriptionFilter> {
164 pub filter1: T,
165 pub filter2: S,
166}
167
168impl<T, S> TopicSubscriptionFilter for CombinedSubscriptionFilters<T, S>
169where
170 T: TopicSubscriptionFilter,
171 S: TopicSubscriptionFilter,
172{
173 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
174 self.filter1.can_subscribe(topic_hash) && self.filter2.can_subscribe(topic_hash)
175 }
176
177 fn filter_incoming_subscription_set<'a>(
178 &mut self,
179 subscriptions: HashSet<&'a GossipsubSubscription>,
180 currently_subscribed_topics: &BTreeSet<TopicHash>,
181 ) -> Result<HashSet<&'a GossipsubSubscription>, String> {
182 let intermediate = self
183 .filter1
184 .filter_incoming_subscription_set(subscriptions, currently_subscribed_topics)?;
185 self.filter2
186 .filter_incoming_subscription_set(intermediate, currently_subscribed_topics)
187 }
188}
189
190pub struct CallbackSubscriptionFilter<T>(pub T)
191where
192 T: FnMut(&TopicHash) -> bool;
193
194impl<T> TopicSubscriptionFilter for CallbackSubscriptionFilter<T>
195where
196 T: FnMut(&TopicHash) -> bool,
197{
198 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
199 (self.0)(topic_hash)
200 }
201}
202
/// Regular-expression based topic filtering, available with the `regex-filter` feature.
#[cfg(feature = "regex-filter")]
pub mod regex {
    use super::TopicSubscriptionFilter;
    use crate::TopicHash;
    use regex::Regex;

    /// Accepts subscriptions to topics whose hash string matches the wrapped regular expression.
    ///
    /// Matching uses `Regex::is_match`, so an unanchored pattern may match anywhere in the string.
    pub struct RegexSubscriptionFilter(pub Regex);

    impl TopicSubscriptionFilter for RegexSubscriptionFilter {
        fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
            let Self(pattern) = self;
            pattern.is_match(topic_hash.as_str())
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;
        use crate::types::GossipsubSubscription;
        use crate::types::GossipsubSubscriptionAction::*;

        #[test]
        fn test_regex_subscription_filter() {
            // The third topic deliberately contains no 't', so "t.*t" cannot match it.
            let raw_topics = ["tt", "et3t3te", "abcdefghijklmnopqrsuvwxyz"];

            let mut filter = RegexSubscriptionFilter(Regex::new("t.*t").unwrap());

            let old = Default::default();
            let subscriptions: Vec<GossipsubSubscription> = raw_topics
                .iter()
                .map(|raw| GossipsubSubscription {
                    action: Subscribe,
                    topic_hash: TopicHash::from_raw(*raw),
                })
                .collect();

            let result = filter
                .filter_incoming_subscriptions(&subscriptions, &old)
                .unwrap();
            assert_eq!(result, subscriptions[..2].iter().collect());
        }
    }
}
255
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::GossipsubSubscriptionAction;
    use crate::types::GossipsubSubscriptionAction::*;
    use std::iter::FromIterator;

    /// Shorthand for building a subscription entry for `topic_hash`.
    fn subscription(
        action: GossipsubSubscriptionAction,
        topic_hash: &TopicHash,
    ) -> GossipsubSubscription {
        GossipsubSubscription {
            action,
            topic_hash: topic_hash.clone(),
        }
    }

    /// Creates the topics `t0`, `t1`, ..., `t{count - 1}`.
    fn topics(count: usize) -> Vec<TopicHash> {
        (0..count)
            .map(|i| TopicHash::from_raw(format!("t{}", i)))
            .collect()
    }

    #[test]
    fn test_filter_incoming_allow_all_with_duplicates() {
        let mut filter = AllowAllSubscriptionFilter {};

        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let old = BTreeSet::from_iter(vec![t1.clone()].into_iter());
        let subscriptions = vec![
            subscription(Unsubscribe, &t1),
            subscription(Unsubscribe, &t2),
            subscription(Subscribe, &t2),
            subscription(Subscribe, &t1),
            subscription(Unsubscribe, &t1),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // t2's unsubscribe/subscribe pair cancels out, as does t1's first pair; only the
        // final unsubscribe from t1 remains.
        assert_eq!(result, vec![&subscriptions[4]].into_iter().collect());
    }

    #[test]
    fn test_filter_incoming_whitelist() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let mut filter = WhitelistSubscriptionFilter(HashSet::from_iter(vec![t1.clone()]));

        let old = Default::default();
        let subscriptions = vec![subscription(Subscribe, &t1), subscription(Subscribe, &t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, vec![&subscriptions[0]].into_iter().collect());
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions_per_request() {
        let t1 = TopicHash::from_raw("t1");

        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 100,
            max_subscriptions_per_request: 2,
        };

        let old = Default::default();

        let subscriptions = vec![
            subscription(Subscribe, &t1),
            subscription(Unsubscribe, &t1),
            subscription(Subscribe, &t1),
        ];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &old);
        assert_eq!(result, Err("too many subscriptions per request".into()));
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions() {
        let t = topics(4);

        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 3,
            max_subscriptions_per_request: 2,
        };

        let old = t[0..2].iter().cloned().collect();

        let subscriptions = vec![subscription(Subscribe, &t[2]), subscription(Subscribe, &t[3])];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &old);
        assert_eq!(result, Err("too many subscribed topics".into()));
    }

    #[test]
    fn test_filter_incoming_max_subscribed_valid() {
        let t = topics(5);

        let mut filter = MaxCountSubscriptionFilter {
            filter: WhitelistSubscriptionFilter(t.iter().take(4).cloned().collect()),
            max_subscribed_topics: 2,
            max_subscriptions_per_request: 5,
        };

        let old = t[0..2].iter().cloned().collect();

        let subscriptions = vec![
            subscription(Subscribe, &t[4]),
            subscription(Subscribe, &t[2]),
            subscription(Subscribe, &t[3]),
            subscription(Unsubscribe, &t[0]),
            subscription(Unsubscribe, &t[1]),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // t4 is not whitelisted; the two new subscriptions are offset by two unsubscriptions.
        assert_eq!(result, subscriptions[1..].iter().collect());
    }

    #[test]
    fn test_callback_filter() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let mut filter = CallbackSubscriptionFilter(|h| h.as_str() == "t1");

        let old = Default::default();
        let subscriptions = vec![subscription(Subscribe, &t1), subscription(Subscribe, &t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, vec![&subscriptions[0]].into_iter().collect());
    }

    #[test]
    fn test_filter_incoming_combined() {
        let t = topics(3);

        let mut filter = CombinedSubscriptionFilters {
            filter1: WhitelistSubscriptionFilter(t[0..2].iter().cloned().collect()),
            filter2: CallbackSubscriptionFilter(|h: &TopicHash| h.as_str() != "t0"),
        };

        let old = Default::default();
        let subscriptions: Vec<_> = t.iter().map(|topic| subscription(Subscribe, topic)).collect();

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // t0 is rejected by the callback (filter2), t2 by the whitelist (filter1).
        assert_eq!(result, vec![&subscriptions[1]].into_iter().collect());
    }
}