Skip to main content

hashtree_cli/
nostr_relay.rs

1use std::collections::HashMap;
2use std::collections::{HashSet, VecDeque};
3use std::path::PathBuf;
4use std::sync::{
5    atomic::{AtomicU64, Ordering},
6    Arc,
7};
8use std::time::{Duration, Instant};
9
10use tokio::sync::{mpsc, Mutex};
11
12use nostr::{ClientMessage as NostrClientMessage, JsonUtil, RelayMessage as NostrRelayMessage};
13use nostr::{Event, EventId, Filter as NostrFilter, SubscriptionId};
14
15use crate::socialgraph;
16
/// Tunables for the embedded Nostr relay.
#[derive(Clone, Debug)]
pub struct NostrRelayConfig {
    /// Size cap (bytes) for the on-disk spambox store; `0` keeps the spambox in memory.
    pub spambox_db_max_bytes: u64,
    /// Upper bound on the number of events returned per filter.
    pub max_query_limit: usize,
    /// Maximum number of open subscriptions per client.
    pub max_subs_per_client: usize,
    /// Maximum number of filters kept per REQ; extras are dropped.
    pub max_filters_per_sub: usize,
    /// Per-minute cap on untrusted events a single client may submit.
    pub spambox_max_events_per_min: u32,
    /// Per-minute cap on REQ/COUNT messages a single client may send.
    pub spambox_max_reqs_per_min: u32,
}

impl Default for NostrRelayConfig {
    /// 1 GiB on-disk spambox, 200 events per filter, 64 subscriptions of up to 32 filters,
    /// and 120 untrusted events / 120 requests per client per minute.
    fn default() -> Self {
        const GIB: u64 = 1 << 30;
        NostrRelayConfig {
            spambox_db_max_bytes: GIB,
            max_query_limit: 200,
            max_subs_per_client: 64,
            max_filters_per_sub: 32,
            spambox_max_events_per_min: 120,
            spambox_max_reqs_per_min: 120,
        }
    }
}
39
40mod imp {
41    use super::*;
42    use anyhow::Result;
43
44    use crate::socialgraph::{SocialGraphAccessControl, SocialGraphBackend};
45    use tracing::warn;
46
47    struct NostrStore {
48        store: Arc<dyn SocialGraphBackend>,
49    }
50
51    impl NostrStore {
52        fn new(store: Arc<dyn SocialGraphBackend>) -> Self {
53            Self { store }
54        }
55
56        fn ingest(&self, event: &Event) -> Result<()> {
57            crate::socialgraph::ingest_parsed_event(self.store.as_ref(), event)
58        }
59
60        fn query(&self, filter: &NostrFilter, limit: usize) -> Vec<Event> {
61            crate::socialgraph::query_events(self.store.as_ref(), filter, limit)
62        }
63    }
64
65    #[derive(Debug, Clone)]
66    struct ClientQuota {
67        last_reset: Instant,
68        spambox_events: u32,
69        reqs: u32,
70    }
71
72    impl ClientQuota {
73        fn new() -> Self {
74            Self {
75                last_reset: Instant::now(),
76                spambox_events: 0,
77                reqs: 0,
78            }
79        }
80
81        fn reset_if_needed(&mut self) {
82            if self.last_reset.elapsed() >= Duration::from_secs(60) {
83                self.last_reset = Instant::now();
84                self.spambox_events = 0;
85                self.reqs = 0;
86            }
87        }
88
89        fn allow_spambox_event(&mut self, limit: u32) -> bool {
90            self.reset_if_needed();
91            if self.spambox_events >= limit {
92                return false;
93            }
94            self.spambox_events += 1;
95            true
96        }
97
98        fn allow_req(&mut self, limit: u32) -> bool {
99            self.reset_if_needed();
100            if self.reqs >= limit {
101                return false;
102            }
103            self.reqs += 1;
104            true
105        }
106    }
107
108    struct ClientState {
109        sender: mpsc::UnboundedSender<String>,
110        pubkey: Option<String>,
111        quota: ClientQuota,
112    }
113
114    struct RecentEvents {
115        order: VecDeque<EventId>,
116        events: HashMap<EventId, Event>,
117        max_len: usize,
118    }
119
120    impl RecentEvents {
121        fn new(max_len: usize) -> Self {
122            Self {
123                order: VecDeque::new(),
124                events: HashMap::new(),
125                max_len: max_len.max(128),
126            }
127        }
128
129        fn insert(&mut self, event: Event) {
130            if self.events.contains_key(&event.id) {
131                return;
132            }
133            self.order.push_back(event.id);
134            self.events.insert(event.id, event);
135            while self.order.len() > self.max_len {
136                if let Some(oldest) = self.order.pop_front() {
137                    self.events.remove(&oldest);
138                }
139            }
140        }
141
142        fn matching(&self, filter: &NostrFilter) -> Vec<Event> {
143            self.events
144                .values()
145                .filter(|event| filter.match_event(event))
146                .cloned()
147                .collect()
148        }
149    }
150
151    enum SpamboxStore {
152        Persistent(NostrStore),
153        Memory(MemorySpambox),
154    }
155
156    struct MemorySpambox {
157        events: Mutex<VecDeque<Event>>,
158        max_len: usize,
159    }
160
161    impl MemorySpambox {
162        fn new(max_len: usize) -> Self {
163            Self {
164                events: Mutex::new(VecDeque::new()),
165                max_len: max_len.max(128),
166            }
167        }
168
169        async fn ingest(&self, event: &Event) -> bool {
170            let mut events = self.events.lock().await;
171            events.push_back(event.clone());
172            while events.len() > self.max_len {
173                events.pop_front();
174            }
175            true
176        }
177    }
178
179    impl SpamboxStore {
180        async fn ingest(&self, event: &Event) -> bool {
181            match self {
182                SpamboxStore::Persistent(store) => store.ingest(event).is_ok(),
183                SpamboxStore::Memory(store) => store.ingest(event).await,
184            }
185        }
186    }
187
188    pub struct NostrRelay {
189        config: NostrRelayConfig,
190        trusted: NostrStore,
191        spambox: Option<SpamboxStore>,
192        social_graph: Option<Arc<SocialGraphAccessControl>>,
193        clients: Mutex<HashMap<u64, ClientState>>,
194        subscriptions: Mutex<HashMap<u64, HashMap<SubscriptionId, Vec<NostrFilter>>>>,
195        recent_events: Mutex<RecentEvents>,
196        next_client_id: AtomicU64,
197    }
198
199    impl NostrRelay {
200        pub fn new(
201            trusted_store: Arc<dyn SocialGraphBackend>,
202            data_dir: PathBuf,
203            social_graph: Option<Arc<SocialGraphAccessControl>>,
204            config: NostrRelayConfig,
205        ) -> Result<Self> {
206            let spambox = if config.spambox_db_max_bytes == 0 {
207                Some(SpamboxStore::Memory(MemorySpambox::new(
208                    config.max_query_limit * 2,
209                )))
210            } else {
211                let spam_dir = data_dir.join("socialgraph_spambox");
212                match socialgraph::open_social_graph_store_at_path(
213                    &spam_dir,
214                    Some(config.spambox_db_max_bytes),
215                ) {
216                    Ok(store) => Some(SpamboxStore::Persistent(NostrStore::new(store))),
217                    Err(err) => {
218                        warn!(
219                            "Failed to open social graph spambox (falling back to memory): {}",
220                            err
221                        );
222                        Some(SpamboxStore::Memory(MemorySpambox::new(
223                            config.max_query_limit * 2,
224                        )))
225                    }
226                }
227            };
228
229            let recent_size = config.max_query_limit.saturating_mul(2);
230
231            Ok(Self {
232                config,
233                trusted: NostrStore::new(trusted_store),
234                spambox,
235                social_graph,
236                clients: Mutex::new(HashMap::new()),
237                subscriptions: Mutex::new(HashMap::new()),
238                recent_events: Mutex::new(RecentEvents::new(recent_size)),
239                next_client_id: AtomicU64::new(1),
240            })
241        }
242
243        pub fn next_client_id(&self) -> u64 {
244            self.next_client_id.fetch_add(1, Ordering::SeqCst)
245        }
246
247        pub async fn register_client(
248            &self,
249            client_id: u64,
250            sender: mpsc::UnboundedSender<String>,
251            pubkey: Option<String>,
252        ) {
253            let mut clients = self.clients.lock().await;
254            clients.insert(
255                client_id,
256                ClientState {
257                    sender,
258                    pubkey,
259                    quota: ClientQuota::new(),
260                },
261            );
262        }
263
264        pub async fn unregister_client(&self, client_id: u64) {
265            let mut clients = self.clients.lock().await;
266            clients.remove(&client_id);
267            drop(clients);
268            let mut subs = self.subscriptions.lock().await;
269            subs.remove(&client_id);
270        }
271
272        pub async fn handle_client_message(&self, client_id: u64, msg: NostrClientMessage) {
273            match msg {
274                NostrClientMessage::Event(event) => {
275                    self.handle_event(client_id, *event).await;
276                }
277                NostrClientMessage::Req {
278                    subscription_id,
279                    filters,
280                } => {
281                    self.handle_req(client_id, subscription_id, filters).await;
282                }
283                NostrClientMessage::Count {
284                    subscription_id,
285                    filters,
286                } => {
287                    self.handle_count(client_id, subscription_id, filters).await;
288                }
289                NostrClientMessage::Close(subscription_id) => {
290                    self.handle_close(client_id, subscription_id).await;
291                }
292                NostrClientMessage::Auth(event) => {
293                    self.handle_auth(client_id, *event).await;
294                }
295                NostrClientMessage::NegOpen { .. }
296                | NostrClientMessage::NegMsg { .. }
297                | NostrClientMessage::NegClose { .. } => {
298                    self.send_to_client(
299                        client_id,
300                        NostrRelayMessage::notice("negentropy not supported"),
301                    )
302                    .await;
303                }
304            }
305        }
306
307        async fn handle_auth(&self, client_id: u64, event: Event) {
308            let ok = event.verify().is_ok();
309            let message = if ok { "" } else { "invalid auth" };
310            self.send_to_client(client_id, NostrRelayMessage::ok(event.id, ok, message))
311                .await;
312        }
313
314        async fn handle_close(&self, client_id: u64, subscription_id: SubscriptionId) {
315            let mut subs = self.subscriptions.lock().await;
316            if let Some(map) = subs.get_mut(&client_id) {
317                map.remove(&subscription_id);
318            }
319        }
320
321        async fn handle_event(&self, client_id: u64, event: Event) {
322            let ok = event.verify().is_ok();
323            if !ok {
324                self.send_to_client(
325                    client_id,
326                    NostrRelayMessage::ok(event.id, false, "invalid: signature"),
327                )
328                .await;
329                return;
330            }
331
332            let trusted = self.is_trusted_event(client_id, &event).await;
333            if !trusted && !self.allow_spambox_event(client_id).await {
334                self.send_to_client(
335                    client_id,
336                    NostrRelayMessage::ok(event.id, false, "rate limited"),
337                )
338                .await;
339                return;
340            }
341
342            let is_ephemeral = event.kind.is_ephemeral();
343            if trusted {
344                let mut recent = self.recent_events.lock().await;
345                recent.insert(event.clone());
346            }
347            if !is_ephemeral {
348                let stored = if trusted {
349                    self.trusted.ingest(&event).is_ok()
350                } else {
351                    match self.spambox.as_ref() {
352                        Some(spambox) => spambox.ingest(&event).await,
353                        None => false,
354                    }
355                };
356
357                if !stored {
358                    let message = if trusted {
359                        "store failed"
360                    } else {
361                        "spambox full"
362                    };
363                    self.send_to_client(client_id, NostrRelayMessage::ok(event.id, false, message))
364                        .await;
365                    return;
366                }
367            }
368
369            let message = if trusted { "" } else { "spambox" };
370            self.send_to_client(client_id, NostrRelayMessage::ok(event.id, true, message))
371                .await;
372
373            if trusted {
374                self.broadcast_event(&event).await;
375            }
376        }
377
378        async fn handle_req(
379            &self,
380            client_id: u64,
381            subscription_id: SubscriptionId,
382            mut filters: Vec<NostrFilter>,
383        ) {
384            if !self.allow_req(client_id).await {
385                self.send_to_client(
386                    client_id,
387                    NostrRelayMessage::closed(subscription_id, "rate limited"),
388                )
389                .await;
390                return;
391            }
392
393            if filters.len() > self.config.max_filters_per_sub {
394                filters.truncate(self.config.max_filters_per_sub);
395            }
396
397            {
398                let mut subs = self.subscriptions.lock().await;
399                let entry = subs.entry(client_id).or_default();
400                if !entry.contains_key(&subscription_id)
401                    && entry.len() >= self.config.max_subs_per_client
402                {
403                    self.send_to_client(
404                        client_id,
405                        NostrRelayMessage::closed(subscription_id, "too many subscriptions"),
406                    )
407                    .await;
408                    return;
409                }
410                entry.insert(subscription_id.clone(), filters.clone());
411            }
412
413            let mut seen: HashSet<EventId> = HashSet::new();
414            for filter in &filters {
415                let limit = filter
416                    .limit
417                    .unwrap_or(self.config.max_query_limit)
418                    .min(self.config.max_query_limit);
419                if limit == 0 {
420                    continue;
421                }
422
423                let recent = {
424                    let cache = self.recent_events.lock().await;
425                    cache.matching(filter)
426                };
427                for event in recent {
428                    if seen.insert(event.id) {
429                        self.send_to_client(
430                            client_id,
431                            NostrRelayMessage::event(subscription_id.clone(), event),
432                        )
433                        .await;
434                    }
435                }
436
437                for event in self.trusted.query(filter, limit) {
438                    if seen.insert(event.id) {
439                        self.send_to_client(
440                            client_id,
441                            NostrRelayMessage::event(subscription_id.clone(), event),
442                        )
443                        .await;
444                    }
445                }
446            }
447
448            self.send_to_client(client_id, NostrRelayMessage::eose(subscription_id))
449                .await;
450        }
451
452        async fn handle_count(
453            &self,
454            client_id: u64,
455            subscription_id: SubscriptionId,
456            filters: Vec<NostrFilter>,
457        ) {
458            if !self.allow_req(client_id).await {
459                self.send_to_client(
460                    client_id,
461                    NostrRelayMessage::closed(subscription_id, "rate limited"),
462                )
463                .await;
464                return;
465            }
466
467            let mut seen: HashSet<EventId> = HashSet::new();
468            for filter in &filters {
469                let limit = filter
470                    .limit
471                    .unwrap_or(self.config.max_query_limit)
472                    .min(self.config.max_query_limit);
473                if limit == 0 {
474                    continue;
475                }
476                let recent = {
477                    let cache = self.recent_events.lock().await;
478                    cache.matching(filter)
479                };
480                for event in recent {
481                    seen.insert(event.id);
482                }
483                for event in self.trusted.query(filter, limit) {
484                    seen.insert(event.id);
485                }
486            }
487
488            self.send_to_client(
489                client_id,
490                NostrRelayMessage::count(subscription_id, seen.len()),
491            )
492            .await;
493        }
494
495        async fn is_trusted_event(&self, client_id: u64, event: &Event) -> bool {
496            if let Some(ref social_graph) = self.social_graph {
497                return social_graph.check_write_access(&event.pubkey.to_hex());
498            }
499            let client_pubkey = {
500                let clients = self.clients.lock().await;
501                clients
502                    .get(&client_id)
503                    .and_then(|state| state.pubkey.clone())
504            };
505            if let Some(pubkey) = client_pubkey {
506                return pubkey == event.pubkey.to_hex();
507            }
508            true
509        }
510
511        async fn allow_spambox_event(&self, client_id: u64) -> bool {
512            let mut clients = self.clients.lock().await;
513            let Some(state) = clients.get_mut(&client_id) else {
514                return false;
515            };
516            state
517                .quota
518                .allow_spambox_event(self.config.spambox_max_events_per_min)
519        }
520
521        async fn allow_req(&self, client_id: u64) -> bool {
522            let mut clients = self.clients.lock().await;
523            let Some(state) = clients.get_mut(&client_id) else {
524                return false;
525            };
526            state.quota.allow_req(self.config.spambox_max_reqs_per_min)
527        }
528
529        async fn broadcast_event(&self, event: &Event) {
530            let subscriptions = self.subscriptions.lock().await;
531            let mut deliveries: Vec<(u64, SubscriptionId)> = Vec::new();
532            for (client_id, subs) in subscriptions.iter() {
533                for (sub_id, filters) in subs.iter() {
534                    if filters.iter().any(|f| f.match_event(event)) {
535                        deliveries.push((*client_id, sub_id.clone()));
536                    }
537                }
538            }
539            drop(subscriptions);
540
541            for (client_id, sub_id) in deliveries {
542                self.send_to_client(client_id, NostrRelayMessage::event(sub_id, event.clone()))
543                    .await;
544            }
545        }
546
547        async fn send_to_client(&self, client_id: u64, msg: NostrRelayMessage) {
548            let sender = {
549                let clients = self.clients.lock().await;
550                clients.get(&client_id).map(|state| state.sender.clone())
551            };
552            if let Some(tx) = sender {
553                let _ = tx.send(msg.as_json());
554            }
555        }
556    }
557}
558
559pub use imp::NostrRelay;
560
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use nostr::{EventBuilder, Filter, JsonUtil, Keys, Kind, RelayMessage, SubscriptionId};
    use std::collections::HashSet;
    use tempfile::TempDir;
    use tokio::time::{timeout, Duration};

    /// Waits up to one second for the next outbound message and parses it as a relay message.
    async fn recv_relay_message(rx: &mut mpsc::UnboundedReceiver<String>) -> Result<RelayMessage> {
        let msg = timeout(Duration::from_secs(1), rx.recv())
            .await?
            .ok_or_else(|| anyhow::anyhow!("channel closed"))?;
        Ok(RelayMessage::from_json(msg)?)
    }

    /// An event whose author is in the access-control set is acknowledged, and a later REQ
    /// serves it back.
    #[tokio::test]
    async fn relay_stores_and_serves_events() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };
        let keys = Keys::generate();
        let mut allowed = HashSet::new();
        allowed.insert(keys.public_key().to_hex());
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();

        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            allowed,
        ));

        let relay_config = NostrRelayConfig {
            spambox_db_max_bytes: 0,
            ..Default::default()
        };
        let relay = NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            Some(access),
            relay_config,
        )?;

        let (tx, mut rx) = mpsc::unbounded_channel();
        relay.register_client(1, tx, None).await;

        let event = EventBuilder::new(Kind::TextNote, "hello", []).to_event(&keys)?;
        relay
            .handle_client_message(1, NostrClientMessage::event(event.clone()))
            .await;

        match recv_relay_message(&mut rx).await? {
            RelayMessage::Ok { status, .. } => assert!(status),
            other => anyhow::bail!("expected OK, got {:?}", other),
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        let sub_id = SubscriptionId::new("sub-1");
        let filter = Filter::new()
            .authors(vec![event.pubkey])
            .kinds(vec![event.kind]);
        let mut got_event = false;
        for _ in 0..3 {
            relay
                .handle_client_message(
                    1,
                    NostrClientMessage::req(sub_id.clone(), vec![filter.clone()]),
                )
                .await;

            match recv_relay_message(&mut rx).await? {
                RelayMessage::Event {
                    subscription_id,
                    event: ev,
                } => {
                    assert_eq!(subscription_id, sub_id);
                    assert_eq!(ev.id, event.id);
                    got_event = true;
                    break;
                }
                RelayMessage::EndOfStoredEvents(id) => {
                    assert_eq!(id, sub_id);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
                other => anyhow::bail!("expected EVENT/EOSE, got {:?}", other),
            }
        }

        if !got_event {
            anyhow::bail!("event not available in time");
        }

        match recv_relay_message(&mut rx).await? {
            RelayMessage::EndOfStoredEvents(id) => assert_eq!(id, sub_id),
            other => anyhow::bail!("expected EOSE, got {:?}", other),
        }

        Ok(())
    }

    /// An untrusted event is acknowledged (spambox), but a REQ returns only EOSE for it.
    #[tokio::test]
    async fn relay_spambox_does_not_serve_untrusted_events() -> Result<()> {
        let tmp = TempDir::new()?;
        let graph_store = {
            let _guard = crate::socialgraph::test_lock();
            crate::socialgraph::open_social_graph_store_with_mapsize(
                tmp.path(),
                Some(128 * 1024 * 1024),
            )?
        };

        crate::socialgraph::set_social_graph_root(&graph_store, &[1u8; 32]);
        // Presumably gives the graph time to apply the new root. Use the async sleep so the
        // test runtime's thread is not blocked.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let backend: Arc<dyn crate::socialgraph::SocialGraphBackend> = graph_store.clone();

        let access = Arc::new(crate::socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            HashSet::new(),
        ));

        let relay_config = NostrRelayConfig {
            spambox_db_max_bytes: 0,
            ..Default::default()
        };
        let relay = NostrRelay::new(
            Arc::clone(&backend),
            tmp.path().to_path_buf(),
            Some(access),
            relay_config,
        )?;

        let (tx, mut rx) = mpsc::unbounded_channel();
        relay.register_client(2, tx, None).await;

        let keys = Keys::generate();
        let event = EventBuilder::new(Kind::TextNote, "spam", []).to_event(&keys)?;
        relay
            .handle_client_message(2, NostrClientMessage::event(event.clone()))
            .await;

        match recv_relay_message(&mut rx).await? {
            RelayMessage::Ok { status, .. } => assert!(status),
            other => anyhow::bail!("expected OK, got {:?}", other),
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        let sub_id = SubscriptionId::new("sub-2");
        let filter = Filter::new()
            .authors(vec![event.pubkey])
            .kinds(vec![event.kind]);
        relay
            .handle_client_message(2, NostrClientMessage::req(sub_id.clone(), vec![filter]))
            .await;

        match recv_relay_message(&mut rx).await? {
            RelayMessage::EndOfStoredEvents(id) => assert_eq!(id, sub_id),
            other => anyhow::bail!("expected EOSE only, got {:?}", other),
        }

        Ok(())
    }
}