Skip to main content

hashtree_network/
multicast.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use nostr_sdk::nostr::{
4    ClientMessage, Event, Filter, JsonUtil, Keys, RelayMessage, SingleLetterTag, SubscriptionId,
5};
6use socket2::{Domain, Protocol, Socket, Type};
7use std::collections::{HashMap, HashSet};
8use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::net::UdpSocket;
12use tokio::sync::{mpsc, watch, Mutex};
13use tokio::time::Sleep;
14use tracing::{debug, warn};
15
16use crate::local_bus::LocalNostrBus;
17use crate::relay_bridge::SharedMeshEventStore;
18use crate::root_events::{
19    build_root_filter, is_hashtree_labeled_event, pick_latest_event, root_event_from_peer,
20    PeerRootEvent, HASHTREE_KIND, HASHTREE_LABEL,
21};
22
/// Settings for the UDP-multicast Nostr bus.
#[derive(Debug, Clone)]
pub struct MulticastConfig {
    /// Master switch; `is_enabled` also requires `max_peers > 0`.
    pub enabled: bool,
    /// IPv4 multicast group in dotted-quad form, parsed when the bus is bound.
    pub group: String,
    /// UDP port used both for the local bind and as the destination port of every send.
    pub port: u16,
    /// Peer budget; must be non-zero for `is_enabled` to report `true`.
    /// NOTE(review): any other use of this limit lives outside this file section.
    pub max_peers: usize,
    /// Period of the root-announcement ticker in milliseconds (0 is treated as 1 ms).
    pub announce_interval_ms: u64,
}
31
32#[async_trait]
33impl LocalNostrBus for MulticastNostrBus {
34    fn source_name(&self) -> &'static str {
35        "multicast"
36    }
37
38    async fn broadcast_event(&self, event: &Event) -> Result<()> {
39        MulticastNostrBus::broadcast_event(self, event).await
40    }
41
42    async fn query_root(
43        &self,
44        owner_pubkey: &str,
45        tree_name: &str,
46        timeout: Duration,
47    ) -> Option<PeerRootEvent> {
48        MulticastNostrBus::query_root(self, owner_pubkey, tree_name, timeout).await
49    }
50}
51
52impl MulticastConfig {
53    pub fn is_enabled(&self) -> bool {
54        self.enabled && self.max_peers > 0
55    }
56}
57
58impl Default for MulticastConfig {
59    fn default() -> Self {
60        Self {
61            enabled: false,
62            group: "239.255.42.98".to_string(),
63            port: 48555,
64            max_peers: 0,
65            announce_interval_ms: 2_000,
66        }
67    }
68}
69
70pub struct MulticastNostrBus {
71    config: MulticastConfig,
72    keys: Keys,
73    relay: SharedMeshEventStore,
74    socket: Arc<UdpSocket>,
75    target_addr: SocketAddr,
76    pending_queries: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<RelayMessage<'static>>>>>,
77    announced_event_ids: Arc<Mutex<HashSet<String>>>,
78}
79
80const QUERY_SETTLE_GRACE_MS: u64 = 150;
81
82impl MulticastNostrBus {
83    pub async fn bind(
84        config: MulticastConfig,
85        keys: Keys,
86        relay: SharedMeshEventStore,
87    ) -> Result<Arc<Self>> {
88        let group: Ipv4Addr = config
89            .group
90            .parse()
91            .with_context(|| format!("invalid multicast group {}", config.group))?;
92        let std_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
93        std_socket.set_reuse_address(true)?;
94        std_socket.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.port).into())?;
95        std_socket.set_multicast_loop_v4(true)?;
96        std_socket.join_multicast_v4(&group, &Ipv4Addr::UNSPECIFIED)?;
97        std_socket.set_nonblocking(true)?;
98
99        let socket = UdpSocket::from_std(std_socket.into())?;
100        let target_addr = SocketAddr::V4(SocketAddrV4::new(group, config.port));
101
102        Ok(Arc::new(Self {
103            config,
104            keys,
105            relay,
106            socket: Arc::new(socket),
107            target_addr,
108            pending_queries: Arc::new(Mutex::new(HashMap::new())),
109            announced_event_ids: Arc::new(Mutex::new(HashSet::new())),
110        }))
111    }
112
113    pub async fn run(
114        self: Arc<Self>,
115        mut shutdown_rx: watch::Receiver<bool>,
116        signaling_tx: mpsc::Sender<(String, Event)>,
117    ) -> Result<()> {
118        let mut announce_ticker = tokio::time::interval(Duration::from_millis(
119            self.config.announce_interval_ms.max(1),
120        ));
121        let mut buf = vec![0u8; 64 * 1024];
122
123        loop {
124            tokio::select! {
125                _ = shutdown_rx.changed() => {
126                    if *shutdown_rx.borrow() {
127                        break;
128                    }
129                }
130                _ = announce_ticker.tick() => {
131                    if let Err(err) = self.broadcast_known_root_updates().await {
132                        debug!("multicast root announcement failed: {}", err);
133                    }
134                }
135                recv = self.socket.recv_from(&mut buf) => {
136                    let (len, _src) = match recv {
137                        Ok(value) => value,
138                        Err(err) => {
139                            warn!("multicast receive failed: {}", err);
140                            continue;
141                        }
142                    };
143                    let text = match std::str::from_utf8(&buf[..len]) {
144                        Ok(text) => text,
145                        Err(err) => {
146                            debug!("ignoring non-utf8 multicast datagram: {}", err);
147                            continue;
148                        }
149                    };
150                    self.handle_datagram(text, &signaling_tx).await;
151                }
152            }
153        }
154
155        Ok(())
156    }
157
158    pub async fn broadcast_event(&self, event: &Event) -> Result<()> {
159        let payload = event.as_json();
160        let copies = if event.kind.is_ephemeral() { 3 } else { 1 };
161        for _ in 0..copies {
162            self.socket
163                .send_to(payload.as_bytes(), self.target_addr)
164                .await?;
165        }
166        Ok(())
167    }
168
169    pub async fn query_root(
170        &self,
171        owner_pubkey: &str,
172        tree_name: &str,
173        timeout: Duration,
174    ) -> Option<PeerRootEvent> {
175        let filter = build_root_filter(owner_pubkey, tree_name)?;
176        let subscription_id = format!("multicast-root-{}", rand::random::<u64>());
177        let request =
178            ClientMessage::req(SubscriptionId::new(subscription_id.clone()), vec![filter]);
179        let (tx, mut rx) = mpsc::unbounded_channel();
180        self.pending_queries
181            .lock()
182            .await
183            .insert(subscription_id.clone(), tx);
184
185        if self
186            .socket
187            .send_to(request.as_json().as_bytes(), self.target_addr)
188            .await
189            .is_err()
190        {
191            self.pending_queries.lock().await.remove(&subscription_id);
192            return None;
193        }
194
195        let mut events = Vec::new();
196        let deadline = tokio::time::sleep(timeout);
197        tokio::pin!(deadline);
198        let mut settle_deadline: Option<std::pin::Pin<Box<Sleep>>> = None;
199
200        loop {
201            tokio::select! {
202                _ = &mut deadline => break,
203                _ = async {
204                    if let Some(deadline) = &mut settle_deadline {
205                        deadline.as_mut().await;
206                    }
207                }, if settle_deadline.is_some() => break,
208                maybe_msg = rx.recv() => {
209                    let Some(msg) = maybe_msg else {
210                        break;
211                    };
212                    match msg {
213                        RelayMessage::Event { subscription_id: sid, event }
214                            if sid.to_string() == subscription_id =>
215                        {
216                            events.push(event.into_owned());
217                            settle_deadline = Some(Box::pin(tokio::time::sleep(Duration::from_millis(
218                                QUERY_SETTLE_GRACE_MS,
219                            ))));
220                        }
221                        RelayMessage::EndOfStoredEvents(sid)
222                            if sid.to_string() == subscription_id
223                                && !events.is_empty()
224                                && settle_deadline.is_none() =>
225                        {
226                            settle_deadline = Some(Box::pin(tokio::time::sleep(Duration::from_millis(
227                                QUERY_SETTLE_GRACE_MS,
228                            ))));
229                        }
230                        _ => {}
231                    }
232                }
233            }
234        }
235
236        self.pending_queries.lock().await.remove(&subscription_id);
237
238        let latest = pick_latest_event(events.iter())?;
239        root_event_from_peer(latest, self.source_name(), tree_name)
240    }
241
242    async fn handle_datagram(&self, text: &str, signaling_tx: &mpsc::Sender<(String, Event)>) {
243        if let Ok(event) = Event::from_json(text) {
244            if event.pubkey == self.keys.public_key() {
245                return;
246            }
247
248            if event.kind.is_ephemeral() {
249                let _ = signaling_tx.send(("multicast".to_string(), event)).await;
250                return;
251            }
252
253            if event.kind == nostr_sdk::nostr::Kind::Custom(HASHTREE_KIND)
254                && is_hashtree_labeled_event(&event)
255                && event.verify().is_ok()
256            {
257                let _ = self.relay.ingest_trusted_event(event).await;
258            }
259            return;
260        }
261
262        if let Ok(msg) = ClientMessage::from_json(text) {
263            if let ClientMessage::Req {
264                subscription_id,
265                filters,
266            } = msg
267            {
268                for filter in filters {
269                    let limit = filter.limit.unwrap_or(50).min(50);
270                    for event in self.relay.query_events(&filter, limit).await {
271                        let relay_msg =
272                            RelayMessage::event(subscription_id.clone().into_owned(), event);
273                        let _ = self
274                            .socket
275                            .send_to(relay_msg.as_json().as_bytes(), self.target_addr)
276                            .await;
277                    }
278                }
279                let eose = RelayMessage::eose(subscription_id.into_owned());
280                let _ = self
281                    .socket
282                    .send_to(eose.as_json().as_bytes(), self.target_addr)
283                    .await;
284            }
285            return;
286        }
287
288        if let Ok(msg) = RelayMessage::from_json(text) {
289            match &msg {
290                RelayMessage::Event {
291                    subscription_id,
292                    event,
293                } => {
294                    if event.kind == nostr_sdk::nostr::Kind::Custom(HASHTREE_KIND)
295                        && is_hashtree_labeled_event(event)
296                        && event.verify().is_ok()
297                    {
298                        let _ = self.relay.ingest_trusted_event((**event).clone()).await;
299                    }
300                    let tx = self
301                        .pending_queries
302                        .lock()
303                        .await
304                        .get(&subscription_id.to_string())
305                        .cloned();
306                    if let Some(tx) = tx {
307                        let _ = tx.send(RelayMessage::event(
308                            subscription_id.clone().into_owned(),
309                            event.clone().into_owned(),
310                        ));
311                    }
312                }
313                RelayMessage::EndOfStoredEvents(subscription_id) => {
314                    let tx = self
315                        .pending_queries
316                        .lock()
317                        .await
318                        .get(&subscription_id.to_string())
319                        .cloned();
320                    if let Some(tx) = tx {
321                        let _ = tx.send(RelayMessage::eose(subscription_id.clone().into_owned()));
322                    }
323                }
324                _ => {}
325            }
326        }
327    }
328
329    async fn broadcast_known_root_updates(&self) -> Result<()> {
330        let filter = Filter::new()
331            .kind(nostr_sdk::nostr::Kind::Custom(HASHTREE_KIND))
332            .author(self.keys.public_key())
333            .custom_tag(
334                SingleLetterTag::lowercase(nostr_sdk::nostr::Alphabet::L),
335                HASHTREE_LABEL.to_string(),
336            )
337            .limit(256);
338        let events = self.relay.query_events(&filter, 256).await;
339        let mut announced = self.announced_event_ids.lock().await;
340        for event in events {
341            let event_id = event.id.to_hex();
342            if announced.insert(event_id) {
343                self.broadcast_event(&event).await?;
344            }
345        }
346        Ok(())
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay_bridge::MeshEventStore;
    use anyhow::Result;
    use nostr_sdk::nostr::{Alphabet, EventBuilder, Kind, Tag, TagKind};
    use std::time::{SystemTime, UNIX_EPOCH};

    // `HASHTREE_LABEL` comes from the production import via `super::*`. A local copy
    // could silently drift from the label the bus actually filters on.

    /// In-memory `MeshEventStore`: stores every ingested event and answers queries by
    /// filter matching in insertion order.
    #[derive(Default)]
    struct TestEventStore {
        events: Mutex<Vec<Event>>,
    }

    #[async_trait]
    impl MeshEventStore for TestEventStore {
        async fn ingest_trusted_event(&self, event: Event) -> Result<()> {
            self.events.lock().await.push(event);
            Ok(())
        }

        async fn query_events(&self, filter: &Filter, limit: usize) -> Vec<Event> {
            self.events
                .lock()
                .await
                .iter()
                .filter(|event| filter.match_event(event, Default::default()))
                .take(limit)
                .cloned()
                .collect()
        }
    }

    /// Picks a port in `40000..42000` from the clock's sub-second nanoseconds, so
    /// concurrently running tests are unlikely to share a multicast port.
    fn unique_multicast_port() -> u16 {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos();
        40000 + (nanos % 2000) as u16
    }

    /// Builds a signed hashtree root event with `d`, `l` (label) and `hash` tags.
    fn build_root_event(keys: &Keys, tree_name: &str, hash_hex: &str) -> Event {
        EventBuilder::new(Kind::Custom(HASHTREE_KIND), "")
            .tags([
                Tag::identifier(tree_name.to_string()),
                Tag::custom(
                    TagKind::SingleLetter(SingleLetterTag::lowercase(Alphabet::L)),
                    vec![HASHTREE_LABEL.to_string()],
                ),
                Tag::custom(TagKind::Custom("hash".into()), vec![hash_hex.to_string()]),
            ])
            .sign_with_keys(keys)
            .expect("root event")
    }

    /// An EOSE that arrives before any event must not end the query. The event that
    /// follows it must still be returned.
    #[tokio::test]
    async fn query_root_ignores_early_eose_until_grace_period_expires() -> Result<()> {
        let keys = Keys::generate();
        let owner_keys = Keys::generate();
        let relay = Arc::new(TestEventStore::default()) as SharedMeshEventStore;
        let bus = MulticastNostrBus::bind(
            MulticastConfig {
                enabled: true,
                group: "239.255.43.10".to_string(),
                port: unique_multicast_port(),
                max_peers: 4,
                announce_interval_ms: 60_000,
            },
            keys,
            relay,
        )
        .await?;

        let tree_name = "eose-race";
        let hash_hex = "ef".repeat(32);
        let event = build_root_event(&owner_keys, tree_name, &hash_hex);

        let query_bus = Arc::clone(&bus);
        let query = tokio::spawn(async move {
            query_bus
                .query_root(
                    &owner_keys.public_key().to_hex(),
                    tree_name,
                    Duration::from_millis(500),
                )
                .await
        });

        // Wait until the spawned query has registered its subscription id.
        let subscription_id = tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if let Some(subscription_id) =
                    bus.pending_queries.lock().await.keys().next().cloned()
                {
                    break subscription_id;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("query registered pending subscription");

        let (signal_tx, _signal_rx) = mpsc::channel(1);
        bus.handle_datagram(
            &RelayMessage::eose(SubscriptionId::new(subscription_id.clone())).as_json(),
            &signal_tx,
        )
        .await;
        bus.handle_datagram(
            &RelayMessage::event(SubscriptionId::new(subscription_id), event.clone()).as_json(),
            &signal_tx,
        )
        .await;

        let resolved = query.await.expect("query task completed");
        let resolved = resolved.expect("query returned root event after early eose");
        assert_eq!(resolved.hash, hash_hex);
        assert_eq!(resolved.event_id, event.id.to_hex());
        Ok(())
    }
}
470}