// gnostr_relay/subscriber.rs

use std::{
    collections::{HashMap, HashSet},
    rc::{Rc, Weak},
};

use crate::{message::*, setting::SettingWrapper};
use actix::prelude::*;
use nostr_db::{EventIndex, Filter};

/// Identifies one filter inside one subscription of one session.
///
/// Used as the inner key of every bucket in the subscriber index, so a single
/// bucket can hold entries belonging to many different subscriptions.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct Key {
    /// Session that owns the subscription.
    session_id: usize,
    /// Subscription id chosen by the client.
    sub_id: String,
    /// Position of the filter within the subscription's filter list.
    index: usize,
}

17impl Key {
18    fn new(session_id: usize, sub_id: String, index: usize) -> Self {
19        Self {
20            session_id,
21            sub_id,
22            index,
23        }
24    }
25}
26
/// Builds the key used by the tag index: the tag name's bytes immediately
/// followed by the tag value's bytes, with no separator in between.
fn concat_tag<K, I>(key: K, val: I) -> Vec<u8>
where
    K: AsRef<[u8]>,
    I: AsRef<[u8]>,
{
    let (name, value) = (key.as_ref(), val.as_ref());
    let mut joined = Vec::with_capacity(name.len() + value.len());
    joined.extend_from_slice(name);
    joined.extend_from_slice(value);
    joined
}

35// index for fast filter
36#[derive(Debug, Default)]
37pub struct SubscriberIndex {
38    /// map session_id -> subscription_id -> filters
39    subscriptions: HashMap<usize, HashMap<String, Vec<Rc<Filter>>>>,
40    ids: HashMap<[u8; 32], HashMap<Key, Weak<Filter>>>,
41    authors: HashMap<[u8; 32], HashMap<Key, Weak<Filter>>>,
42    tags: HashMap<Vec<u8>, HashMap<Key, Weak<Filter>>>,
43    kinds: HashMap<u16, HashMap<Key, Weak<Filter>>>,
44    others: HashMap<Key, Weak<Filter>>,
45}
46
47impl SubscriberIndex {
48    fn install_index(&mut self, session_id: usize, sub_id: String, filters: &[Rc<Filter>]) {
49        for (index, filter) in filters.iter().enumerate() {
50            if !filter.ids.is_empty() {
51                for key in filter.ids.iter() {
52                    self.ids.entry(*key).or_default().insert(
53                        Key::new(session_id, sub_id.clone(), index),
54                        Rc::downgrade(filter),
55                    );
56                }
57            } else if !filter.authors.is_empty() {
58                for key in filter.authors.iter() {
59                    self.authors.entry(*key).or_default().insert(
60                        Key::new(session_id, sub_id.clone(), index),
61                        Rc::downgrade(filter),
62                    );
63                }
64            } else if !filter.tags.is_empty() {
65                for (tag, values) in filter.tags.iter() {
66                    for val in values.iter() {
67                        self.tags.entry(concat_tag(tag, val)).or_default().insert(
68                            Key::new(session_id, sub_id.clone(), index),
69                            Rc::downgrade(filter),
70                        );
71                    }
72                }
73            } else if !filter.kinds.is_empty() {
74                for key in filter.kinds.iter() {
75                    self.kinds.entry(*key).or_default().insert(
76                        Key::new(session_id, sub_id.clone(), index),
77                        Rc::downgrade(filter),
78                    );
79                }
80            } else {
81                self.others.insert(
82                    Key::new(session_id, sub_id.clone(), index),
83                    Rc::downgrade(filter),
84                );
85            }
86        }
87    }
88
89    fn uninstall_index(&mut self, session_id: usize, limit_sub_id: Option<&String>) {
90        if let Some(subs) = self.subscriptions.get(&session_id) {
91            for (sub_id, filters) in subs {
92                if let Some(limit_sub_id) = limit_sub_id {
93                    if limit_sub_id != sub_id {
94                        continue;
95                    }
96                }
97                for (index, filter) in filters.iter().enumerate() {
98                    if !filter.ids.is_empty() {
99                        for key in filter.ids.iter() {
100                            if let Some(map) = self.ids.get_mut(key) {
101                                map.remove(&Key::new(session_id, sub_id.clone(), index));
102                                if map.is_empty() {
103                                    self.ids.remove(key);
104                                }
105                            }
106                        }
107                    } else if !filter.authors.is_empty() {
108                        for key in filter.authors.iter() {
109                            if let Some(map) = self.authors.get_mut(key) {
110                                map.remove(&Key::new(session_id, sub_id.clone(), index));
111                                if map.is_empty() {
112                                    self.authors.remove(key);
113                                }
114                            }
115                        }
116                    } else if !filter.tags.is_empty() {
117                        for (tag, values) in filter.tags.iter() {
118                            for val in values.iter() {
119                                let key = concat_tag(tag, val);
120                                if let Some(map) = self.tags.get_mut(&key) {
121                                    map.remove(&Key::new(session_id, sub_id.clone(), index));
122                                    if map.is_empty() {
123                                        self.tags.remove(&key);
124                                    }
125                                }
126                            }
127                        }
128                    } else if !filter.kinds.is_empty() {
129                        for key in filter.kinds.iter() {
130                            if let Some(map) = self.kinds.get_mut(key) {
131                                map.remove(&Key::new(session_id, sub_id.clone(), index));
132                                if map.is_empty() {
133                                    self.kinds.remove(key);
134                                }
135                            }
136                        }
137                    } else {
138                        self.others
139                            .remove(&Key::new(session_id, sub_id.clone(), index));
140                    }
141                }
142            }
143        }
144    }
145
146    pub fn add(
147        &mut self,
148        session_id: usize,
149        sub_id: String,
150        filters: Vec<Filter>,
151        limit: usize,
152    ) -> Subscribed {
153        // according to NIP-01, <subscription_id> is an arbitrary, non-empty string of max length 64 chars
154        if sub_id.is_empty() || sub_id.len() > 64 {
155            return Subscribed::InvalidIdLength;
156        }
157
158        if let Some(subs) = self.subscriptions.get(&session_id) {
159            if subs.len() >= limit {
160                return Subscribed::Overlimit;
161            }
162        }
163
164        let filters = filters.into_iter().map(Rc::new).collect::<Vec<_>>();
165
166        // remove old
167        self.uninstall_index(session_id, Some(&sub_id));
168        self.install_index(session_id, sub_id.clone(), &filters);
169
170        let map = self.subscriptions.entry(session_id).or_default();
171
172        // NIP01: overwrite the previous subscription
173        map.insert(sub_id, filters);
174        Subscribed::Ok
175    }
176
177    pub fn remove(&mut self, session_id: usize, sub_id: Option<&String>) {
178        self.uninstall_index(session_id, sub_id);
179        if let Some(sub_id) = sub_id {
180            if let Some(map) = self.subscriptions.get_mut(&session_id) {
181                map.remove(sub_id);
182                if map.is_empty() {
183                    self.subscriptions.remove(&session_id);
184                }
185            }
186        } else {
187            self.subscriptions.remove(&session_id);
188        }
189    }
190
191    pub fn lookup(&self, event: &EventIndex, mut f: impl FnMut(&usize, &String)) {
192        let mut dup = HashMap::new();
193
194        fn check(
195            session_id: usize,
196            sub_id: &String,
197            filter: &Weak<Filter>,
198            event: &EventIndex,
199            dup: &mut HashMap<(usize, String), bool>,
200            mut f: impl FnMut(&usize, &String),
201        ) {
202            if let Some(filter) = filter.upgrade() {
203                if filter.r#match(event) {
204                    let key = (session_id, sub_id.clone());
205                    if dup.get(&key).is_none() {
206                        f(&session_id, sub_id);
207                        dup.insert(key, true);
208                    }
209                }
210            }
211        }
212
213        fn scan<T: std::cmp::Eq + std::hash::Hash>(
214            map: &HashMap<T, HashMap<Key, Weak<Filter>>>,
215            key: &T,
216            event: &EventIndex,
217            dup: &mut HashMap<(usize, String), bool>,
218            mut f: impl FnMut(&usize, &String),
219        ) {
220            if let Some(map) = map.get(key) {
221                for (k, filter) in map {
222                    check(k.session_id, &k.sub_id, filter, event, dup, &mut f);
223                }
224            }
225        }
226
227        scan(&self.ids, event.id(), event, &mut dup, &mut f);
228        scan(&self.authors, event.pubkey(), event, &mut dup, &mut f);
229        scan(&self.kinds, &event.kind(), event, &mut dup, &mut f);
230        for (key, val) in event.tags() {
231            scan(&self.tags, &concat_tag(key, val), event, &mut dup, &mut f);
232        }
233
234        for (k, filter) in &self.others {
235            check(k.session_id, &k.sub_id, filter, event, &mut dup, &mut f);
236        }
237    }
238
239    pub fn lookup1(&self, event: &EventIndex, mut f: impl FnMut(&usize, &String)) {
240        for (session_id, subs) in &self.subscriptions {
241            for (sub_id, filters) in subs {
242                for filter in filters {
243                    if filter.r#match(event) {
244                        f(session_id, sub_id);
245                        break;
246                    }
247                }
248            }
249        }
250    }
251}
252
253pub struct Subscriber {
254    pub addr: Recipient<SubscribeResult>,
255    /// map session_id -> subscription_id -> filters
256    pub subscriptions: HashMap<usize, HashMap<String, Vec<Filter>>>,
257    pub index: SubscriberIndex,
258    pub setting: SettingWrapper,
259}
260
261impl Subscriber {
262    pub fn new(addr: Recipient<SubscribeResult>, setting: SettingWrapper) -> Self {
263        Self {
264            addr,
265            subscriptions: HashMap::new(),
266            setting,
267            index: SubscriberIndex::default(),
268        }
269    }
270}
271
272impl Actor for Subscriber {
273    type Context = Context<Self>;
274    fn started(&mut self, ctx: &mut Self::Context) {
275        ctx.set_mailbox_capacity(10000);
276    }
277}
278
279impl Handler<Subscribe> for Subscriber {
280    type Result = Subscribed;
281    fn handle(&mut self, msg: Subscribe, _: &mut Self::Context) -> Subscribed {
282        self.index.add(
283            msg.id,
284            msg.subscription.id,
285            msg.subscription.filters,
286            self.setting.read().limitation.max_subscriptions,
287        )
288    }
289}
290
291impl Handler<Unsubscribe> for Subscriber {
292    type Result = ();
293    fn handle(&mut self, msg: Unsubscribe, _: &mut Self::Context) {
294        self.index.remove(msg.id, msg.sub_id.as_ref());
295    }
296}
297
298impl Handler<Dispatch> for Subscriber {
299    type Result = ();
300    fn handle(&mut self, msg: Dispatch, _: &mut Self::Context) {
301        let event = &msg.event;
302        let index = event.index();
303        let event_str = event.to_string();
304        self.index.lookup(index, |session_id, sub_id| {
305            self.addr.do_send(SubscribeResult {
306                id: *session_id,
307                msg: OutgoingMessage::event(sub_id, &event_str),
308                sub_id: sub_id.clone(),
309            });
310        });
311    }
312}
313
#[cfg(test)]
mod tests {
    use crate::Setting;

    use super::*;
    use actix_rt::time::sleep;
    use anyhow::Result;
    use nostr_db::{Event, Filter};
    use parking_lot::RwLock;
    use std::sync::Arc;
    use std::{str::FromStr, time::Duration};

    /// Test actor that records every `SubscribeResult` it receives.
    #[derive(Default)]
    struct Receiver(Arc<RwLock<Vec<SubscribeResult>>>);
    impl Actor for Receiver {
        type Context = Context<Self>;
    }

    impl Handler<SubscribeResult> for Receiver {
        type Result = ();
        fn handle(&mut self, msg: SubscribeResult, _ctx: &mut Self::Context) {
            self.0.write().push(msg);
        }
    }

    /// End-to-end check through the actor. Covers:
    /// - nothing is delivered before any subscription exists;
    /// - subscription-id length validation;
    /// - overwriting an existing id;
    /// - one delivery per matching subscription.
    #[actix_rt::test]
    async fn subscribe() -> Result<()> {
        let note = r#"
        {
            "content": "Good morning everyone 😃",
            "created_at": 1680690006,
            "id": "332747c0fab8a1a92def4b0937e177be6df4382ce6dd7724f86dc4710b7d4d7d",
            "kind": 1,
            "pubkey": "7abf57d516b1ff7308ca3bd5650ea6a4674d469c7c5057b1d005fb13d218bfef",
            "sig": "ef4ff4f69ac387239eb1401fb07d7a44a5d5d57127e0dc3466a0403cf7d5486b668608ebfcbe9ff1f8d3b5d710545999fe08ee767284ec0b474e4cf92537678f",
            "tags": [["t", "nostr"]]
          }
        "#;
        let event = Event::from_str(note)?;

        let receiver = Receiver::default();
        let messages = receiver.0.clone();
        let receiver = receiver.start();
        let addr = receiver.recipient();

        let subscriber = Subscriber::new(addr.clone(), Setting::default().into()).start();

        subscriber
            .send(Dispatch {
                id: 0,
                event: event.clone(),
            })
            .await?;

        sleep(Duration::from_millis(100)).await;
        // No subscriptions yet, so nothing may be delivered. The read guard is a
        // temporary and is released at the end of the statement.
        assert_eq!(messages.read().len(), 0);

        // An empty filter matches every event.
        let res = subscriber
            .send(Subscribe {
                id: 0,
                subscription: Subscription {
                    id: 0.to_string(),
                    filters: vec![Filter {
                        ..Default::default()
                    }],
                },
            })
            .await?;
        assert_eq!(res, Subscribed::Ok);

        // overwrite
        let res = subscriber
            .send(Subscribe {
                id: 0,
                subscription: Subscription {
                    id: 0.to_string(),
                    filters: vec![Filter {
                        ..Default::default()
                    }],
                },
            })
            .await?;
        assert_eq!(res, Subscribed::Ok);

        // Kind 1000 does not match the kind-1 note above.
        let res = subscriber
            .send(Subscribe {
                id: 0,
                subscription: Subscription {
                    id: 1.to_string(),
                    filters: vec![Filter {
                        kinds: vec![1000].into(),
                        ..Default::default()
                    }],
                },
            })
            .await?;
        assert_eq!(res, Subscribed::Ok);

        // Empty subscription ids are rejected.
        let res = subscriber
            .send(Subscribe {
                id: 0,
                subscription: Subscription {
                    id: "".to_string(),
                    filters: vec![Filter {
                        kinds: vec![1000].into(),
                        ..Default::default()
                    }],
                },
            })
            .await?;
        assert_eq!(res, Subscribed::InvalidIdLength);

        // 65 characters: one over the 64-character limit.
        let res = subscriber
            .send(Subscribe {
                id: 0,
                subscription: Subscription {
                    id: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefA"
                        .to_string(),
                    filters: vec![Filter {
                        kinds: vec![1000].into(),
                        ..Default::default()
                    }],
                },
            })
            .await?;
        assert_eq!(res, Subscribed::InvalidIdLength);

        subscriber
            .send(Dispatch {
                id: 0,
                event: event.clone(),
            })
            .await?;

        sleep(Duration::from_millis(100)).await;
        // Only subscription "0" matches.
        assert_eq!(messages.read().len(), 1);

        Ok(())
    }

    /// Runs both `lookup` and the brute-force `lookup1` on `event`, asserts
    /// they agree, and returns the sorted (session, subscription) pairs.
    fn lookup(index: &SubscriberIndex, event: &str) -> Result<Vec<(usize, String)>> {
        let event = Event::from_str(event)?;
        let mut result = vec![];
        let mut result1 = vec![];
        index.lookup(event.index(), |session_id, sub_id| {
            result.push((*session_id, sub_id.clone()));
        });
        index.lookup1(event.index(), |session_id, sub_id| {
            result1.push((*session_id, sub_id.clone()));
        });
        result.sort();
        result1.sort();
        assert_eq!(result, result1);
        Ok(result)
    }

    #[test]
    fn index() -> Result<()> {
        let mut index = SubscriberIndex::default();
        // all
        index.add(
            1,
            "all".to_owned(),
            vec![Filter::from_str("{}")?, Filter::from_str("{}")?],
            5,
        );

        index.add(
            1,
            "id".to_owned(),
            vec![
                Filter::from_str(
                    r###"
         {
            "ids": ["0000000000000000000000000000000000000000000000000000000000000000",
                    "0000000000000000000000000000000000000000000000000000000000000001"]
          }
        "###,
                )?,
                Filter::from_str(
                    r###"
         {
            "ids": ["0000000000000000000000000000000000000000000000000000000000000000"]
          }
        "###,
                )?,
            ],
            5,
        );
        index.add(
            2,
            "author".to_owned(),
            vec![
                Filter::from_str(
                    r###"
         {
            "authors": ["0000000000000000000000000000000000000000000000000000000000000000",
                    "0000000000000000000000000000000000000000000000000000000000000001"]
          }
        "###,
                )?,
                Filter::from_str(
                    r###"
         {
            "authors": ["0000000000000000000000000000000000000000000000000000000000000000"]
          }
        "###,
                )?,
            ],
            5,
        );
        index.add(
            3,
            "kind".to_owned(),
            vec![
                Filter::from_str(
                    r###"
         {
            "kinds": [0, 1]
          }
        "###,
                )?,
                Filter::from_str(
                    r###"
         {
            "kinds": [0]
          }
        "###,
                )?,
            ],
            5,
        );
        index.add(
            4,
            "tag1".to_owned(),
            vec![
                Filter::from_str(
                    r###"
         {
            "#p": ["0000000000000000000000000000000000000000000000000000000000000000",
                    "0000000000000000000000000000000000000000000000000000000000000001"]
          }
        "###,
                )?,
                Filter::from_str(
                    r###"
         {
            "#p": ["0000000000000000000000000000000000000000000000000000000000000000"]
          }
        "###,
                )?,
            ],
            5,
        );
        index.add(
            4,
            "tag2".to_owned(),
            vec![Filter::from_str(
                r###"
         {
            "#p": ["0000000000000000000000000000000000000000000000000000000000000000",
                    "0000000000000000000000000000000000000000000000000000000000000001"],
                    "#t": ["test"]
          }
        "###,
            )?],
            5,
        );
        // override
        let ok = index.add(
            4,
            "tag2".to_owned(),
            vec![Filter::from_str(
                r###"
         {
            "#p": ["0000000000000000000000000000000000000000000000000000000000000000",
                    "0000000000000000000000000000000000000000000000000000000000000001"],
                    "#t": ["test"]
          }
        "###,
            )?],
            5,
        );
        assert_eq!(ok, Subscribed::Ok);
        // The override must not leave stale entries behind:
        // - others: the two "{}" filters;
        // - ids / authors / kinds: 0 and 1;
        // - tags: p0, p1 and "ttest".
        assert_eq!(index.others.len(), 2);
        assert_eq!(index.ids.len(), 2);
        assert_eq!(index.authors.len(), 2);
        assert_eq!(index.kinds.len(), 2);
        assert_eq!(index.tags.len(), 3);

        // Matches all, id (id 0), author (pubkey 1) and kind (kind 1).
        let res = lookup(
            &index,
            r###"
        {
           "id": "0000000000000000000000000000000000000000000000000000000000000000",
           "pubkey": "0000000000000000000000000000000000000000000000000000000000000001",
           "kind": 1,
           "tags": [],
           "content": "",
           "created_at": 0,
           "sig": "633db60e2e7082c13a47a6b19d663d45b2a2ebdeaf0b4c35ef83be2738030c54fc7fd56d139652937cdca875ee61b51904a1d0d0588a6acd6168d7be2909d693"
         }
       "###,
        )?;
        assert_eq!(res.len(), 4);
        // Same as above, but id 2 is not subscribed to.
        let res = lookup(
            &index,
            r###"
        {
           "id": "0000000000000000000000000000000000000000000000000000000000000002",
           "pubkey": "0000000000000000000000000000000000000000000000000000000000000001",
           "kind": 1,
           "tags": [],
           "content": "",
           "created_at": 0,
           "sig": "633db60e2e7082c13a47a6b19d663d45b2a2ebdeaf0b4c35ef83be2738030c54fc7fd56d139652937cdca875ee61b51904a1d0d0588a6acd6168d7be2909d693"
         }
       "###,
        )?;
        assert_eq!(res.len(), 3);

        // Matches all and tag1; tag2 additionally requires the "t" = "test" tag.
        let res = lookup(
            &index,
            r###"
        {
           "id": "0000000000000000000000000000000000000000000000000000000000000008",
           "pubkey": "0000000000000000000000000000000000000000000000000000000000000008",
           "kind": 10,
           "tags": [["p", "0000000000000000000000000000000000000000000000000000000000000000"]],
           "content": "",
           "created_at": 0,
           "sig": "633db60e2e7082c13a47a6b19d663d45b2a2ebdeaf0b4c35ef83be2738030c54fc7fd56d139652937cdca875ee61b51904a1d0d0588a6acd6168d7be2909d693"
         }
       "###,
        )?;
        assert_eq!(res.len(), 2);

        // Matches all, tag1 and tag2; tag2 is reported once although two of its
        // tag values index it.
        let res = lookup(
            &index,
            r###"
        {
           "id": "0000000000000000000000000000000000000000000000000000000000000008",
           "pubkey": "0000000000000000000000000000000000000000000000000000000000000008",
           "kind": 10,
           "tags": [["p", "0000000000000000000000000000000000000000000000000000000000000000"], ["t", "test"]],
           "content": "",
           "created_at": 0,
           "sig": "633db60e2e7082c13a47a6b19d663d45b2a2ebdeaf0b4c35ef83be2738030c54fc7fd56d139652937cdca875ee61b51904a1d0d0588a6acd6168d7be2909d693"
         }
       "###,
        )?;
        assert_eq!(res.len(), 3);

        index.remove(1, Some(&"all".to_owned()));
        index.remove(1, Some(&"id".to_owned()));
        index.remove(2, Some(&"author".to_owned()));
        index.remove(3, Some(&"kind".to_owned()));
        index.remove(4, Some(&"tag1".to_owned()));
        index.remove(4, Some(&"tag2".to_owned()));

        // Removing everything must leave every table empty.
        assert_eq!(index.subscriptions.len(), 0);
        assert_eq!(index.others.len(), 0);
        assert_eq!(index.ids.len(), 0);
        assert_eq!(index.authors.len(), 0);
        assert_eq!(index.kinds.len(), 0);
        assert_eq!(index.tags.len(), 0);
        Ok(())
    }
}