Skip to main content

hashtree_cli/webrtc/
multicast.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use nostr::{ClientMessage, Event, Filter, JsonUtil, Keys, RelayMessage};
4use socket2::{Domain, Protocol, Socket, Type};
5use std::collections::{HashMap, HashSet};
6use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::net::UdpSocket;
10use tokio::sync::{mpsc, watch, Mutex};
11use tracing::{debug, warn};
12
13use super::root_events::{
14    build_root_filter, is_hashtree_labeled_event, pick_latest_event, root_event_from_peer,
15    PeerRootEvent, HASHTREE_KIND,
16};
17use super::LocalNostrBus;
18use crate::nostr_relay::NostrRelay;
19
/// Settings for the multicast Nostr bus.
#[derive(Clone, Debug)]
pub struct MulticastConfig {
    /// Master switch; `is_enabled` additionally requires `max_peers > 0`.
    pub enabled: bool,
    /// IPv4 multicast group address (textual form), parsed when the bus is bound.
    pub group: String,
    /// UDP port bound locally and also used as the destination port for the group.
    pub port: u16,
    /// Peer limit as configured; a value of `0` makes `is_enabled` return `false`.
    pub max_peers: usize,
    /// Period, in milliseconds, between re-announcements of our own root events.
    pub announce_interval_ms: u64,
}
28
29#[async_trait]
30impl LocalNostrBus for MulticastNostrBus {
31    fn source_name(&self) -> &'static str {
32        "multicast"
33    }
34
35    async fn broadcast_event(&self, event: &Event) -> Result<()> {
36        MulticastNostrBus::broadcast_event(self, event).await
37    }
38
39    async fn query_root(
40        &self,
41        owner_pubkey: &str,
42        tree_name: &str,
43        timeout: Duration,
44    ) -> Option<PeerRootEvent> {
45        MulticastNostrBus::query_root(self, owner_pubkey, tree_name, timeout).await
46    }
47}
48
49impl MulticastConfig {
50    pub fn is_enabled(&self) -> bool {
51        self.enabled && self.max_peers > 0
52    }
53}
54
55impl Default for MulticastConfig {
56    fn default() -> Self {
57        Self {
58            enabled: false,
59            group: "239.255.42.98".to_string(),
60            port: 48555,
61            max_peers: 0,
62            announce_interval_ms: 2_000,
63        }
64    }
65}
66
67pub struct MulticastNostrBus {
68    config: MulticastConfig,
69    keys: Keys,
70    relay: Arc<NostrRelay>,
71    socket: Arc<UdpSocket>,
72    target_addr: SocketAddr,
73    pending_queries: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<RelayMessage>>>>,
74    announced_event_ids: Arc<Mutex<HashSet<String>>>,
75}
76
77impl MulticastNostrBus {
78    pub async fn bind(
79        config: MulticastConfig,
80        keys: Keys,
81        relay: Arc<NostrRelay>,
82    ) -> Result<Arc<Self>> {
83        let group: Ipv4Addr = config
84            .group
85            .parse()
86            .with_context(|| format!("invalid multicast group {}", config.group))?;
87        let std_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
88        std_socket.set_reuse_address(true)?;
89        #[cfg(unix)]
90        std_socket.set_reuse_port(true)?;
91        std_socket.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.port).into())?;
92        std_socket.set_multicast_loop_v4(true)?;
93        std_socket.join_multicast_v4(&group, &Ipv4Addr::UNSPECIFIED)?;
94        std_socket.set_nonblocking(true)?;
95
96        let socket = UdpSocket::from_std(std_socket.into())?;
97        let target_addr = SocketAddr::V4(SocketAddrV4::new(group, config.port));
98
99        Ok(Arc::new(Self {
100            config,
101            keys,
102            relay,
103            socket: Arc::new(socket),
104            target_addr,
105            pending_queries: Arc::new(Mutex::new(HashMap::new())),
106            announced_event_ids: Arc::new(Mutex::new(HashSet::new())),
107        }))
108    }
109
110    pub async fn run(
111        self: Arc<Self>,
112        mut shutdown_rx: watch::Receiver<bool>,
113        signaling_tx: mpsc::Sender<(String, Event)>,
114    ) -> Result<()> {
115        let mut announce_ticker = tokio::time::interval(Duration::from_millis(
116            self.config.announce_interval_ms.max(1),
117        ));
118        let mut buf = vec![0u8; 64 * 1024];
119
120        loop {
121            tokio::select! {
122                _ = shutdown_rx.changed() => {
123                    if *shutdown_rx.borrow() {
124                        break;
125                    }
126                }
127                _ = announce_ticker.tick() => {
128                    if let Err(err) = self.broadcast_known_root_updates().await {
129                        debug!("multicast root announcement failed: {}", err);
130                    }
131                }
132                recv = self.socket.recv_from(&mut buf) => {
133                    let (len, _src) = match recv {
134                        Ok(value) => value,
135                        Err(err) => {
136                            warn!("multicast receive failed: {}", err);
137                            continue;
138                        }
139                    };
140                    let text = match std::str::from_utf8(&buf[..len]) {
141                        Ok(text) => text,
142                        Err(err) => {
143                            debug!("ignoring non-utf8 multicast datagram: {}", err);
144                            continue;
145                        }
146                    };
147                    self.handle_datagram(text, &signaling_tx).await;
148                }
149            }
150        }
151
152        Ok(())
153    }
154
155    pub async fn broadcast_event(&self, event: &Event) -> Result<()> {
156        let payload = event.as_json();
157        let copies = if event.kind.is_ephemeral() { 3 } else { 1 };
158        for _ in 0..copies {
159            self.socket
160                .send_to(payload.as_bytes(), self.target_addr)
161                .await?;
162        }
163        Ok(())
164    }
165
166    pub async fn query_root(
167        &self,
168        owner_pubkey: &str,
169        tree_name: &str,
170        timeout: Duration,
171    ) -> Option<PeerRootEvent> {
172        let filter = build_root_filter(owner_pubkey, tree_name)?;
173        let subscription_id = format!("multicast-root-{}", rand::random::<u64>());
174        let request = ClientMessage::req(
175            nostr::SubscriptionId::new(subscription_id.clone()),
176            vec![filter],
177        );
178        let (tx, mut rx) = mpsc::unbounded_channel();
179        self.pending_queries
180            .lock()
181            .await
182            .insert(subscription_id.clone(), tx);
183
184        if self
185            .socket
186            .send_to(request.as_json().as_bytes(), self.target_addr)
187            .await
188            .is_err()
189        {
190            self.pending_queries.lock().await.remove(&subscription_id);
191            return None;
192        }
193
194        let mut events = Vec::new();
195        let deadline = tokio::time::sleep(timeout);
196        tokio::pin!(deadline);
197
198        loop {
199            tokio::select! {
200                _ = &mut deadline => break,
201                maybe_msg = rx.recv() => {
202                    let Some(msg) = maybe_msg else {
203                        break;
204                    };
205                    match msg {
206                        RelayMessage::Event { subscription_id: sid, event }
207                            if sid.to_string() == subscription_id =>
208                        {
209                            events.push(*event);
210                        }
211                        RelayMessage::EndOfStoredEvents(sid) if sid.to_string() == subscription_id => {
212                            break;
213                        }
214                        _ => {}
215                    }
216                }
217            }
218        }
219
220        self.pending_queries.lock().await.remove(&subscription_id);
221
222        let latest = pick_latest_event(events.iter())?;
223        root_event_from_peer(latest, self.source_name(), tree_name)
224    }
225
226    async fn handle_datagram(&self, text: &str, signaling_tx: &mpsc::Sender<(String, Event)>) {
227        if let Ok(event) = Event::from_json(text) {
228            if event.pubkey == self.keys.public_key() {
229                return;
230            }
231
232            if event.kind.is_ephemeral() {
233                let _ = signaling_tx.send(("multicast".to_string(), event)).await;
234                return;
235            }
236
237            if event.kind == nostr::Kind::Custom(HASHTREE_KIND)
238                && is_hashtree_labeled_event(&event)
239                && event.verify().is_ok()
240            {
241                let _ = self.relay.ingest_trusted_event(event).await;
242            }
243            return;
244        }
245
246        if let Ok(msg) = ClientMessage::from_json(text) {
247            if let ClientMessage::Req {
248                subscription_id,
249                filters,
250            } = msg
251            {
252                for filter in filters {
253                    let limit = filter.limit.unwrap_or(50).min(50);
254                    for event in self.relay.query_events(&filter, limit).await {
255                        let relay_msg = RelayMessage::event(subscription_id.clone(), event);
256                        let _ = self
257                            .socket
258                            .send_to(relay_msg.as_json().as_bytes(), self.target_addr)
259                            .await;
260                    }
261                }
262                let eose = RelayMessage::eose(subscription_id);
263                let _ = self
264                    .socket
265                    .send_to(eose.as_json().as_bytes(), self.target_addr)
266                    .await;
267            }
268            return;
269        }
270
271        if let Ok(msg) = RelayMessage::from_json(text) {
272            match &msg {
273                RelayMessage::Event {
274                    subscription_id,
275                    event,
276                } => {
277                    if event.kind == nostr::Kind::Custom(HASHTREE_KIND)
278                        && is_hashtree_labeled_event(event)
279                        && event.verify().is_ok()
280                    {
281                        let _ = self.relay.ingest_trusted_event((**event).clone()).await;
282                    }
283                    let tx = self
284                        .pending_queries
285                        .lock()
286                        .await
287                        .get(&subscription_id.to_string())
288                        .cloned();
289                    if let Some(tx) = tx {
290                        let _ = tx.send(msg);
291                    }
292                }
293                RelayMessage::EndOfStoredEvents(subscription_id) => {
294                    let tx = self
295                        .pending_queries
296                        .lock()
297                        .await
298                        .get(&subscription_id.to_string())
299                        .cloned();
300                    if let Some(tx) = tx {
301                        let _ = tx.send(msg);
302                    }
303                }
304                _ => {}
305            }
306        }
307    }
308
309    async fn broadcast_known_root_updates(&self) -> Result<()> {
310        let filter = Filter::new()
311            .kind(nostr::Kind::Custom(HASHTREE_KIND))
312            .author(self.keys.public_key())
313            .custom_tag(
314                nostr::SingleLetterTag::lowercase(nostr::Alphabet::L),
315                vec![super::root_events::HASHTREE_LABEL.to_string()],
316            )
317            .limit(256);
318        let events = self.relay.query_events(&filter, 256).await;
319        let mut announced = self.announced_event_ids.lock().await;
320        for event in events {
321            let event_id = event.id.to_hex();
322            if announced.insert(event_id) {
323                self.broadcast_event(&event).await?;
324            }
325        }
326        Ok(())
327    }
328}