// hashtree_cli/webrtc/multicast.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use nostr::{ClientMessage, Event, Filter, JsonUtil, Keys, RelayMessage};
4use socket2::{Domain, Protocol, Socket, Type};
5use std::collections::{HashMap, HashSet};
6use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::net::UdpSocket;
10use tokio::sync::{mpsc, watch, Mutex};
11use tokio::time::Sleep;
12use tracing::{debug, warn};
13
14use super::root_events::{
15    build_root_filter, is_hashtree_labeled_event, pick_latest_event, root_event_from_peer,
16    PeerRootEvent, HASHTREE_KIND,
17};
18use super::LocalNostrBus;
19use crate::nostr_relay::NostrRelay;
20
/// Configuration for the LAN multicast Nostr transport.
#[derive(Debug, Clone)]
pub struct MulticastConfig {
    /// Master on/off switch for the transport.
    pub enabled: bool,
    /// IPv4 multicast group address to join (dotted-quad string, parsed in `bind`).
    pub group: String,
    /// UDP port the multicast group communicates on.
    pub port: u16,
    /// Maximum number of peers; 0 effectively disables the transport
    /// (see `is_enabled`).
    pub max_peers: usize,
    /// Interval between periodic root-event announcements, in milliseconds.
    pub announce_interval_ms: u64,
}
29
// Exposes the multicast transport through the generic `LocalNostrBus` trait
// by delegating to the inherent methods on `MulticastNostrBus` below.
#[async_trait]
impl LocalNostrBus for MulticastNostrBus {
    fn source_name(&self) -> &'static str {
        "multicast"
    }

    // Delegates to the inherent `broadcast_event` (fully-qualified call avoids
    // recursing into this trait method).
    async fn broadcast_event(&self, event: &Event) -> Result<()> {
        MulticastNostrBus::broadcast_event(self, event).await
    }

    // Delegates to the inherent `query_root`.
    async fn query_root(
        &self,
        owner_pubkey: &str,
        tree_name: &str,
        timeout: Duration,
    ) -> Option<PeerRootEvent> {
        MulticastNostrBus::query_root(self, owner_pubkey, tree_name, timeout).await
    }
}
49
50impl MulticastConfig {
51    pub fn is_enabled(&self) -> bool {
52        self.enabled && self.max_peers > 0
53    }
54}
55
56impl Default for MulticastConfig {
57    fn default() -> Self {
58        Self {
59            enabled: false,
60            group: "239.255.42.98".to_string(),
61            port: 48555,
62            max_peers: 0,
63            announce_interval_ms: 2_000,
64        }
65    }
66}
67
/// Nostr message bus that broadcasts and receives events over IPv4 UDP
/// multicast on the local network.
pub struct MulticastNostrBus {
    /// Transport settings captured at bind time.
    config: MulticastConfig,
    /// Local identity; used to drop our own datagrams on receive.
    keys: Keys,
    /// Embedded relay used to store and serve hashtree root events.
    relay: Arc<NostrRelay>,
    /// UDP socket joined to the multicast group.
    socket: Arc<UdpSocket>,
    /// Multicast group + port every outgoing datagram is sent to.
    target_addr: SocketAddr,
    /// Subscription id -> channel for in-flight `query_root` calls;
    /// `handle_datagram` routes matching relay messages here.
    pending_queries: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<RelayMessage>>>>,
    /// Event ids already announced, so the periodic announcer broadcasts each
    /// stored root event at most once per process lifetime.
    announced_event_ids: Arc<Mutex<HashSet<String>>>,
}
77
/// Extra time `query_root` keeps listening after the first matching event (or
/// an EOSE with results) so slower peers can still contribute a newer root
/// before the query settles.
const QUERY_SETTLE_GRACE_MS: u64 = 150;
79
impl MulticastNostrBus {
    /// Binds a UDP socket to the configured multicast group/port and wraps it
    /// in a bus handle.
    ///
    /// Uses `socket2` so address reuse (and, on Unix, port reuse) can be
    /// enabled *before* binding, allowing several local processes to share
    /// the same multicast endpoint.
    pub async fn bind(
        config: MulticastConfig,
        keys: Keys,
        relay: Arc<NostrRelay>,
    ) -> Result<Arc<Self>> {
        let group: Ipv4Addr = config
            .group
            .parse()
            .with_context(|| format!("invalid multicast group {}", config.group))?;
        let std_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
        std_socket.set_reuse_address(true)?;
        #[cfg(unix)]
        std_socket.set_reuse_port(true)?;
        // Bind to INADDR_ANY so datagrams arriving on any interface are seen.
        std_socket.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.port).into())?;
        // Loopback enabled so other sockets on this host (tests, sibling
        // processes) receive our datagrams too.
        std_socket.set_multicast_loop_v4(true)?;
        std_socket.join_multicast_v4(&group, &Ipv4Addr::UNSPECIFIED)?;
        // Tokio's UdpSocket requires the underlying fd to be non-blocking.
        std_socket.set_nonblocking(true)?;

        let socket = UdpSocket::from_std(std_socket.into())?;
        let target_addr = SocketAddr::V4(SocketAddrV4::new(group, config.port));

        Ok(Arc::new(Self {
            config,
            keys,
            relay,
            socket: Arc::new(socket),
            target_addr,
            pending_queries: Arc::new(Mutex::new(HashMap::new())),
            announced_event_ids: Arc::new(Mutex::new(HashSet::new())),
        }))
    }

    /// Main receive/announce loop.
    ///
    /// Runs until `shutdown_rx` flips to `true`. Each `select!` iteration
    /// either (a) re-announces known root events on the ticker, or (b) reads
    /// one datagram and dispatches it via `handle_datagram`. Ephemeral
    /// signaling events received from peers are forwarded on `signaling_tx`.
    pub async fn run(
        self: Arc<Self>,
        mut shutdown_rx: watch::Receiver<bool>,
        signaling_tx: mpsc::Sender<(String, Event)>,
    ) -> Result<()> {
        // `.max(1)` guards against a zero interval, which would panic.
        let mut announce_ticker = tokio::time::interval(Duration::from_millis(
            self.config.announce_interval_ms.max(1),
        ));
        // 64 KiB: maximum UDP payload; buffer reused across iterations.
        let mut buf = vec![0u8; 64 * 1024];

        loop {
            tokio::select! {
                _ = shutdown_rx.changed() => {
                    if *shutdown_rx.borrow() {
                        break;
                    }
                }
                _ = announce_ticker.tick() => {
                    // Best-effort: announcement failures are logged, not fatal.
                    if let Err(err) = self.broadcast_known_root_updates().await {
                        debug!("multicast root announcement failed: {}", err);
                    }
                }
                recv = self.socket.recv_from(&mut buf) => {
                    let (len, _src) = match recv {
                        Ok(value) => value,
                        Err(err) => {
                            warn!("multicast receive failed: {}", err);
                            continue;
                        }
                    };
                    // All protocol messages are JSON text; drop anything else.
                    let text = match std::str::from_utf8(&buf[..len]) {
                        Ok(text) => text,
                        Err(err) => {
                            debug!("ignoring non-utf8 multicast datagram: {}", err);
                            continue;
                        }
                    };
                    self.handle_datagram(text, &signaling_tx).await;
                }
            }
        }

        Ok(())
    }

    /// Serializes `event` to JSON and sends it to the multicast group.
    ///
    /// Ephemeral events are sent three times to compensate for UDP loss
    /// (duplicates are harmless — receivers treat them idempotently).
    pub async fn broadcast_event(&self, event: &Event) -> Result<()> {
        let payload = event.as_json();
        let copies = if event.kind.is_ephemeral() { 3 } else { 1 };
        for _ in 0..copies {
            self.socket
                .send_to(payload.as_bytes(), self.target_addr)
                .await?;
        }
        Ok(())
    }

    /// Broadcasts a REQ for the root event of `owner_pubkey`/`tree_name` and
    /// collects peer responses until `timeout`, returning the latest match.
    ///
    /// Responses are routed back to this call through `pending_queries` by
    /// `handle_datagram`. After the first event (or an EOSE with results), a
    /// short settle grace period (`QUERY_SETTLE_GRACE_MS`) keeps the query
    /// open so slower peers can still supply a newer root.
    pub async fn query_root(
        &self,
        owner_pubkey: &str,
        tree_name: &str,
        timeout: Duration,
    ) -> Option<PeerRootEvent> {
        let filter = build_root_filter(owner_pubkey, tree_name)?;
        // Random suffix keeps concurrent queries distinguishable on the wire.
        let subscription_id = format!("multicast-root-{}", rand::random::<u64>());
        let request = ClientMessage::req(
            nostr::SubscriptionId::new(subscription_id.clone()),
            vec![filter],
        );
        let (tx, mut rx) = mpsc::unbounded_channel();
        // Register before sending so responses can't race the registration.
        self.pending_queries
            .lock()
            .await
            .insert(subscription_id.clone(), tx);

        if self
            .socket
            .send_to(request.as_json().as_bytes(), self.target_addr)
            .await
            .is_err()
        {
            // Undo the registration if the REQ never made it out.
            self.pending_queries.lock().await.remove(&subscription_id);
            return None;
        }

        let mut events = Vec::new();
        let deadline = tokio::time::sleep(timeout);
        tokio::pin!(deadline);
        // Armed once results start arriving; boxed so it can be (re)created
        // inside the loop while staying pinned for polling.
        let mut settle_deadline: Option<std::pin::Pin<Box<Sleep>>> = None;

        loop {
            tokio::select! {
                // Hard overall timeout.
                _ = &mut deadline => break,
                // Settle grace expired: we have results and waited long enough.
                _ = async {
                    if let Some(deadline) = &mut settle_deadline {
                        deadline.as_mut().await;
                    }
                }, if settle_deadline.is_some() => break,
                maybe_msg = rx.recv() => {
                    let Some(msg) = maybe_msg else {
                        break;
                    };
                    match msg {
                        RelayMessage::Event { subscription_id: sid, event }
                            if sid.to_string() == subscription_id =>
                        {
                            events.push(*event);
                            // Each new event restarts the settle window.
                            settle_deadline = Some(Box::pin(tokio::time::sleep(Duration::from_millis(
                                QUERY_SETTLE_GRACE_MS,
                            ))));
                        }
                        RelayMessage::EndOfStoredEvents(sid) if sid.to_string() == subscription_id => {
                            // EOSE with results (and no window yet) also arms
                            // the settle window; EOSE with no results keeps
                            // waiting for the full timeout.
                            if !events.is_empty() && settle_deadline.is_none() {
                                settle_deadline = Some(Box::pin(tokio::time::sleep(Duration::from_millis(
                                    QUERY_SETTLE_GRACE_MS,
                                ))));
                            }
                        }
                        _ => {}
                    }
                }
            }
        }

        self.pending_queries.lock().await.remove(&subscription_id);

        let latest = pick_latest_event(events.iter())?;
        root_event_from_peer(latest, self.source_name(), tree_name)
    }

    /// Dispatches one received datagram.
    ///
    /// Tries to parse it, in order, as: a bare `Event` (own events dropped;
    /// ephemeral events forwarded to signaling; signed hashtree root events
    /// ingested into the relay), a `ClientMessage::Req` (answered from the
    /// local relay, capped at 50 events per filter, followed by EOSE), or a
    /// `RelayMessage` (routed to the matching pending `query_root` call).
    async fn handle_datagram(&self, text: &str, signaling_tx: &mpsc::Sender<(String, Event)>) {
        if let Ok(event) = Event::from_json(text) {
            // Loopback is enabled, so we see our own broadcasts: skip them.
            if event.pubkey == self.keys.public_key() {
                return;
            }

            if event.kind.is_ephemeral() {
                let _ = signaling_tx.send(("multicast".to_string(), event)).await;
                return;
            }

            // Persist only labeled hashtree root events with valid signatures.
            if event.kind == nostr::Kind::Custom(HASHTREE_KIND)
                && is_hashtree_labeled_event(&event)
                && event.verify().is_ok()
            {
                let _ = self.relay.ingest_trusted_event(event).await;
            }
            return;
        }

        if let Ok(msg) = ClientMessage::from_json(text) {
            if let ClientMessage::Req {
                subscription_id,
                filters,
            } = msg
            {
                for filter in filters {
                    // Cap each filter at 50 results to bound datagram volume.
                    let limit = filter.limit.unwrap_or(50).min(50);
                    for event in self.relay.query_events(&filter, limit).await {
                        let relay_msg = RelayMessage::event(subscription_id.clone(), event);
                        let _ = self
                            .socket
                            .send_to(relay_msg.as_json().as_bytes(), self.target_addr)
                            .await;
                    }
                }
                let eose = RelayMessage::eose(subscription_id);
                let _ = self
                    .socket
                    .send_to(eose.as_json().as_bytes(), self.target_addr)
                    .await;
            }
            return;
        }

        if let Ok(msg) = RelayMessage::from_json(text) {
            match &msg {
                RelayMessage::Event {
                    subscription_id,
                    event,
                } => {
                    // Opportunistically ingest valid root events even when
                    // they answer someone else's subscription.
                    if event.kind == nostr::Kind::Custom(HASHTREE_KIND)
                        && is_hashtree_labeled_event(event)
                        && event.verify().is_ok()
                    {
                        let _ = self.relay.ingest_trusted_event((**event).clone()).await;
                    }
                    // Clone the sender out of the map so the lock isn't held
                    // across the send.
                    let tx = self
                        .pending_queries
                        .lock()
                        .await
                        .get(&subscription_id.to_string())
                        .cloned();
                    if let Some(tx) = tx {
                        let _ = tx.send(msg);
                    }
                }
                RelayMessage::EndOfStoredEvents(subscription_id) => {
                    let tx = self
                        .pending_queries
                        .lock()
                        .await
                        .get(&subscription_id.to_string())
                        .cloned();
                    if let Some(tx) = tx {
                        let _ = tx.send(msg);
                    }
                }
                _ => {}
            }
        }
    }

    /// Broadcasts this node's own labeled hashtree root events, skipping any
    /// event id already announced during this process's lifetime.
    async fn broadcast_known_root_updates(&self) -> Result<()> {
        let filter = Filter::new()
            .kind(nostr::Kind::Custom(HASHTREE_KIND))
            .author(self.keys.public_key())
            .custom_tag(
                nostr::SingleLetterTag::lowercase(nostr::Alphabet::L),
                vec![super::root_events::HASHTREE_LABEL.to_string()],
            )
            .limit(256);
        let events = self.relay.query_events(&filter, 256).await;
        let mut announced = self.announced_event_ids.lock().await;
        for event in events {
            let event_id = event.id.to_hex();
            // `insert` returns true only for ids not seen before.
            if announced.insert(event_id) {
                self.broadcast_event(&event).await?;
            }
        }
        Ok(())
    }
}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nostr_relay::{NostrRelay, NostrRelayConfig};
    use crate::socialgraph;
    use nostr::{Alphabet, EventBuilder, Kind, SingleLetterTag, Tag, TagKind};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tempfile::TempDir;

    const HASHTREE_LABEL: &str = "hashtree";

    /// Derives a pseudo-random port in [40000, 42000) from the clock's
    /// sub-second nanoseconds so concurrent tests usually pick distinct ports.
    // NOTE(review): two tests sampling the same nanosecond bucket would
    // collide; presumably reuse-addr/reuse-port in `bind` keeps that benign —
    // confirm if tests run in parallel.
    fn unique_multicast_port() -> u16 {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos();
        40000 + (nanos % 2000) as u16
    }

    /// Builds a signed hashtree root event for `tree_name` carrying
    /// `hash_hex`, with the identifier and label tags `query_root` expects.
    fn build_root_event(keys: &Keys, tree_name: &str, hash_hex: &str) -> Event {
        EventBuilder::new(
            Kind::Custom(HASHTREE_KIND),
            "",
            [
                Tag::identifier(tree_name.to_string()),
                Tag::custom(
                    TagKind::SingleLetter(SingleLetterTag::lowercase(Alphabet::L)),
                    vec![HASHTREE_LABEL.to_string()],
                ),
                Tag::custom(TagKind::Custom("hash".into()), vec![hash_hex.to_string()]),
            ],
        )
        .to_event(keys)
        .expect("root event")
    }

    /// Creates a `NostrRelay` backed by a temp-dir social-graph store that
    /// allows exactly `allowed_pubkey`.
    async fn make_relay(dir: &TempDir, allowed_pubkey: String) -> Result<Arc<NostrRelay>> {
        let graph_store =
            socialgraph::open_social_graph_store_with_mapsize(dir.path(), Some(128 * 1024 * 1024))?;
        let backend: Arc<dyn socialgraph::SocialGraphBackend> = graph_store.clone();
        let mut allowed = HashSet::new();
        allowed.insert(allowed_pubkey.clone());
        let access = Arc::new(socialgraph::SocialGraphAccessControl::new(
            Arc::clone(&backend),
            0,
            allowed,
        ));

        Ok(Arc::new(NostrRelay::new(
            backend,
            dir.path().to_path_buf(),
            HashSet::from([allowed_pubkey]),
            Some(access),
            NostrRelayConfig {
                spambox_db_max_bytes: 0,
                ..Default::default()
            },
        )?))
    }

    /// Regression test for the EOSE race: an EOSE arriving *before* any
    /// event must not terminate `query_root` early — the event delivered
    /// afterwards (within the settle grace period) should still be returned.
    #[tokio::test]
    async fn query_root_ignores_early_eose_until_grace_period_expires() -> Result<()> {
        let keys = Keys::generate();
        let owner_keys = Keys::generate();
        let dir = TempDir::new()?;
        let relay = make_relay(&dir, keys.public_key().to_hex()).await?;
        let bus = MulticastNostrBus::bind(
            MulticastConfig {
                enabled: true,
                group: "239.255.43.10".to_string(),
                port: unique_multicast_port(),
                max_peers: 4,
                // Long interval so the announcer never fires during the test.
                announce_interval_ms: 60_000,
            },
            keys,
            relay,
        )
        .await?;

        let tree_name = "eose-race";
        let hash_hex = "ef".repeat(32);
        let event = build_root_event(&owner_keys, tree_name, &hash_hex);

        // Start the query on a separate task so we can inject datagrams while
        // it is waiting.
        let query_bus = Arc::clone(&bus);
        let query = tokio::spawn(async move {
            query_bus
                .query_root(
                    &owner_keys.public_key().to_hex(),
                    tree_name,
                    Duration::from_millis(500),
                )
                .await
        });

        // Poll until the query has registered its pending subscription, so we
        // know the subscription id to answer with.
        let subscription_id = tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if let Some(subscription_id) =
                    bus.pending_queries.lock().await.keys().next().cloned()
                {
                    break subscription_id;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("query registered pending subscription");

        // Deliver EOSE first (the race), then the actual event.
        let (signal_tx, _signal_rx) = mpsc::channel(1);
        bus.handle_datagram(
            &RelayMessage::eose(nostr::SubscriptionId::new(subscription_id.clone())).as_json(),
            &signal_tx,
        )
        .await;
        bus.handle_datagram(
            &RelayMessage::event(nostr::SubscriptionId::new(subscription_id), event.clone())
                .as_json(),
            &signal_tx,
        )
        .await;

        let resolved = query.await.expect("query task completed");
        let resolved = resolved.expect("query returned root event after early eose");
        assert_eq!(resolved.hash, hash_hex);
        assert_eq!(resolved.event_id, event.id.to_hex());
        Ok(())
    }
}