
nostro2_ring_relay/lib.rs

use nostro2::NostrRelayEvent;
use nostro2_cache::Cache;
use quetzalcoatl::broadcast;
use quetzalcoatl::capacity::Capacity;
use quetzalcoatl::mpsc::{Consumer, Producer, RingBuffer};
use std::net::TcpStream;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tungstenite::stream::MaybeTlsStream;
use tungstenite::{connect, Message, WebSocket};

/// Messages that flow through the ring buffer from relay threads to the consumer
#[derive(Debug, Clone)]
pub enum PoolMessage {
    /// Event received from a relay
    RelayEvent {
        /// URL of the relay that sent this event
        relay_url: String,
        /// The actual relay event
        event: NostrRelayEvent,
    },
    /// Connection error or closed
    ConnectionClosed {
        relay_url: String,
        error: Option<String>,
    },
}

/// Sender handle for broadcasting messages to all connected relays.
///
/// Clone this to send from multiple threads — the broadcast Producer uses CAS
/// internally, so concurrent pushes are lock-free.
#[derive(Clone)]
pub struct PoolSender {
    producer: broadcast::Producer<String>,
}

impl PoolSender {
    /// Send a `NostrClientEvent` to all connected relays.
    ///
    /// Serializes to JSON once; each relay thread sends the pre-serialized string.
    /// Returns `Err` if the broadcast ring is full (the slowest relay thread has fallen behind).
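    ///
    /// # Example
    ///
    /// Sketch of broadcasting a subscription (the same shape `run_connection`
    /// sends on connect):
    ///
    /// ```no_run
    /// use nostro2_ring_relay::RelayPool;
    ///
    /// let pool = RelayPool::new(1024, 10_000, 64, 8);
    /// let subscription = nostro2::NostrSubscription {
    ///     kinds: vec![1].into(),
    ///     limit: Some(10),
    ///     ..Default::default()
    /// };
    /// // `NostrSubscription` converts into `NostrClientEvent`, so it can be
    /// // passed to `send` directly.
    /// pool.sender().send(subscription).expect("broadcast ring full");
    /// ```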
    pub fn send<T: Into<nostro2::NostrClientEvent>>(&self, msg: T) -> Result<(), String> {
        let client_event: nostro2::NostrClientEvent = msg.into();
        let json = serde_json::to_string(&client_event).map_err(|e| e.to_string())?;
        self.producer.push(json)
    }

    /// Send a raw pre-serialized JSON string to all relays.
    ///
    /// Use this when you've already serialized the message.
    pub fn send_raw(&self, json: String) -> Result<(), String> {
        self.producer.push(json)
    }
}

/// Handle to a relay WebSocket connection running in its own thread.
///
/// Each connection runs in a dedicated OS thread with non-blocking I/O.
/// The thread can be signaled to shut down via an atomic flag.
pub struct RelayConnection {
    relay_url: String,
    thread_handle: Option<std::thread::JoinHandle<()>>,
    shutdown: Arc<AtomicBool>,
}

impl RelayConnection {
    /// Spawn a new thread that connects to a relay with bidirectional messaging.
    ///
    /// The thread reads inbound events into the MPSC ring buffer and sends
    /// outbound messages from the broadcast consumer to the WebSocket.
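    ///
    /// # Example
    ///
    /// Sketch of wiring one connection by hand (most callers should prefer
    /// [`RelayPool::add_relay`]); the relay URL is a placeholder:
    ///
    /// ```no_run
    /// use std::sync::{atomic::AtomicBool, Arc};
    /// use nostro2_ring_relay::{create_pool, RelayConnection};
    /// use quetzalcoatl::{broadcast, capacity::Capacity};
    ///
    /// let (_consumer, producer) = create_pool(1024, 10_000);
    /// let (_outbound_tx, outbound_rx) =
    ///     broadcast::RingBuffer::<String>::new(Capacity::at_least(64), 2).split();
    /// let shutdown = Arc::new(AtomicBool::new(false));
    ///
    /// let conn = RelayConnection::spawn(
    ///     "wss://relay.example.com".to_string(),
    ///     producer,
    ///     outbound_rx,
    ///     Arc::clone(&shutdown),
    /// );
    /// conn.request_shutdown(); // Drop also signals and joins
    /// ```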
    pub fn spawn(
        relay_url: String,
        mut producer: Producer<PoolMessage>,
        outbound: broadcast::Consumer<String>,
        shutdown: Arc<AtomicBool>,
    ) -> Self {
        let url = relay_url.clone();
        let shutdown_clone = Arc::clone(&shutdown);
        let thread_handle = std::thread::spawn(move || {
            match Self::run_connection(&url, &mut producer, outbound, &shutdown_clone) {
                Ok(()) => {
                    let _ = producer.push(PoolMessage::ConnectionClosed {
                        relay_url: url.clone(),
                        error: None,
                    });
                }
                Err(e) => {
                    let _ = producer.push(PoolMessage::ConnectionClosed {
                        relay_url: url.clone(),
                        error: Some(e.to_string()),
                    });
                }
            }
        });

        Self {
            relay_url,
            thread_handle: Some(thread_handle),
            shutdown,
        }
    }

    /// Returns `true` if the connection thread has exited.
    pub fn is_finished(&self) -> bool {
        self.thread_handle
            .as_ref()
            .is_some_and(|h| h.is_finished())
    }

    /// Signal the connection thread to shut down gracefully.
    ///
    /// The thread will send a WebSocket Close frame and exit within one
    /// poll cycle (~1ms). Does not block.
    pub fn request_shutdown(&self) {
        self.shutdown.store(true, Ordering::Relaxed);
    }

    /// Signal shutdown and block until the thread exits.
    fn shutdown_and_join(&mut self) {
        self.shutdown.store(true, Ordering::Relaxed);
        if let Some(handle) = self.thread_handle.take() {
            let _ = handle.join();
        }
    }

    /// Main connection loop — non-blocking, multiplexed read/write.
    ///
    /// 1. Connects and performs WebSocket handshake (blocking)
    /// 2. Sends the initial subscription (blocking)
    /// 3. Switches to non-blocking mode
    /// 4. Loops: check shutdown → try read inbound → drain outbound → sleep if idle
    fn run_connection(
        url: &str,
        producer: &mut Producer<PoolMessage>,
        mut outbound: broadcast::Consumer<String>,
        shutdown: &AtomicBool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Install default crypto provider for this thread (required for rustls 0.23+)
        let _ = rustls::crypto::ring::default_provider().install_default();

        let (mut socket, _response) = connect(url)?;

        // Subscribe to kind 1 events (text notes) with limit 1000
        let subscription = nostro2::NostrSubscription {
            kinds: vec![1].into(),
            limit: Some(1000),
            ..Default::default()
        };

        // Convert to NostrClientEvent and send (still blocking at this point)
        let client_event: nostro2::NostrClientEvent = subscription.into();
        let subscription_json = serde_json::to_string(&client_event)?;
        socket.send(Message::Text(subscription_json.into()))?;

        // Switch to non-blocking for the multiplexed loop
        set_nonblocking(&socket, true)?;

        loop {
            // Check shutdown signal (Relaxed — no need for immediate visibility)
            if shutdown.load(Ordering::Relaxed) {
                let _ = socket.send(Message::Close(None));
                break;
            }

            let mut had_work = false;

            // 1. Try reading inbound (returns WouldBlock instantly if empty)
            match socket.read() {
                Ok(Message::Text(text)) => {
                    if let Ok(event) = text.parse::<NostrRelayEvent>() {
                        let mut pool_msg = PoolMessage::RelayEvent {
                            relay_url: url.to_string(),
                            event,
                        };
                        loop {
                            match producer.push(pool_msg) {
                                Ok(()) => break,
                                Err(returned) => {
                                    pool_msg = returned;
                                    std::hint::spin_loop();
                                }
                            }
                        }
                    }
                    had_work = true;
                }
                Ok(Message::Close(_)) => break,
                Ok(Message::Ping(data)) => {
                    // Pong may WouldBlock — data is buffered internally by tungstenite
                    // and will flush on the next successful I/O operation
                    let _ = socket.send(Message::Pong(data));
                    had_work = true;
                }
                Ok(_) => {
                    had_work = true;
                }
                Err(tungstenite::Error::Io(ref e))
                    if e.kind() == std::io::ErrorKind::WouldBlock =>
                {
                    // No data available — fall through to check outbound
                }
                Err(e) => return Err(e.into()),
            }

            // 2. Drain outbound broadcast messages
            while let Some(json) = outbound.pop() {
                match socket.send(Message::Text(json.into())) {
                    Ok(()) => {
                        had_work = true;
                    }
                    Err(tungstenite::Error::Io(ref e))
                        if e.kind() == std::io::ErrorKind::WouldBlock =>
                    {
                        // Write buffer full — frame is in tungstenite's internal buffer,
                        // will flush on next successful I/O. Stop draining to avoid
                        // growing the buffer unboundedly.
                        had_work = true;
                        break;
                    }
                    Err(e) => return Err(e.into()),
                }
            }

            // 3. Avoid burning CPU when idle
            if !had_work {
                std::thread::sleep(std::time::Duration::from_millis(1));
            }
        }

        Ok(())
    }

    /// Get the relay URL.
    pub fn relay_url(&self) -> &str {
        &self.relay_url
    }
}

impl Drop for RelayConnection {
    fn drop(&mut self) {
        self.shutdown.store(true, Ordering::Relaxed);
        if let Some(handle) = self.thread_handle.take() {
            let _ = handle.join();
        }
    }
}

/// Set non-blocking mode on the underlying TCP stream through tungstenite's layers.
fn set_nonblocking(
    socket: &WebSocket<MaybeTlsStream<TcpStream>>,
    nonblocking: bool,
) -> std::io::Result<()> {
    match socket.get_ref() {
        MaybeTlsStream::Plain(tcp) => tcp.set_nonblocking(nonblocking),
        MaybeTlsStream::Rustls(tls) => tls.get_ref().set_nonblocking(nonblocking),
        _ => Ok(()),
    }
}

/// Consumer side of the pool; reads events from all relays in a single thread.
pub struct PoolConsumer {
    consumer: Consumer<PoolMessage>,
    dedup_cache: Cache,
}

impl PoolConsumer {
    /// Create a new pool consumer with a deduplication cache.
    pub fn new(consumer: Consumer<PoolMessage>, cache_size: usize) -> Self {
        Self {
            consumer,
            dedup_cache: Cache::new(cache_size),
        }
    }

    /// Receive the next message from any relay (non-blocking).
    ///
    /// Returns `Some(message)` if one is available and not a duplicate, or `None`
    /// if the ring buffer is empty.
    /// Automatically deduplicates `NewNote` events based on event ID.
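    ///
    /// # Example
    ///
    /// Sketch of a polling loop over a standalone consumer:
    ///
    /// ```no_run
    /// use nostro2_ring_relay::{create_pool, PoolMessage};
    ///
    /// let (mut consumer, _producer) = create_pool(1024, 10_000);
    /// loop {
    ///     match consumer.try_recv() {
    ///         Some(PoolMessage::RelayEvent { relay_url, event }) => {
    ///             println!("{relay_url}: {event:?}");
    ///         }
    ///         Some(PoolMessage::ConnectionClosed { relay_url, error }) => {
    ///             eprintln!("{relay_url} closed: {error:?}");
    ///         }
    ///         None => std::thread::sleep(std::time::Duration::from_millis(1)),
    ///     }
    /// }
    /// ```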
    pub fn try_recv(&mut self) -> Option<PoolMessage> {
        loop {
            match self.consumer.pop()? {
                PoolMessage::RelayEvent {
                    relay_url,
                    event: NostrRelayEvent::NewNote(tag, sub_id, note),
                } => {
                    // Check for duplicate event ID
                    if let Some(ref event_id) = note.id {
                        if self.dedup_cache.insert(event_id.clone()) {
                            // New event, return it
                            return Some(PoolMessage::RelayEvent {
                                relay_url,
                                event: NostrRelayEvent::NewNote(tag, sub_id, note),
                            });
                        }
                        // Duplicate, continue to next message
                        continue;
                    }
                    // No ID, pass through
                    return Some(PoolMessage::RelayEvent {
                        relay_url,
                        event: NostrRelayEvent::NewNote(tag, sub_id, note),
                    });
                }
                other => {
                    // Pass through non-NewNote messages
                    return Some(other);
                }
            }
        }
    }

    /// Blocking receive: spins until a message is available.
    ///
    /// This is the main event loop for the consumer thread.
    /// Automatically deduplicates `NewNote` events based on event ID.
    pub fn recv(&mut self) -> PoolMessage {
        loop {
            if let Some(msg) = self.try_recv() {
                return msg;
            }
            std::hint::spin_loop();
        }
    }
}

/// The relay pool — manages multiple WebSocket connections with bidirectional messaging.
///
/// Inbound events flow through an MPSC ring buffer with deduplication.
/// Outbound messages are broadcast to all relay threads via a lock-free broadcast ring.
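///
/// # Example
///
/// End-to-end sketch; the relay URLs are placeholders:
///
/// ```no_run
/// use nostro2_ring_relay::{PoolMessage, RelayPool};
///
/// let mut pool = RelayPool::new(1024, 10_000, 64, 8);
/// pool.add_relay("wss://relay-a.example.com".to_string());
/// pool.add_relay("wss://relay-b.example.com".to_string());
///
/// // Outbound: one push fans out to every relay thread.
/// pool.sender()
///     .send_raw(r#"["REQ","sub1",{"kinds":[1]}]"#.to_string())
///     .expect("broadcast ring full");
///
/// // Inbound: deduplicated events from all relays, one at a time.
/// loop {
///     match pool.recv() {
///         PoolMessage::RelayEvent { relay_url, event } => {
///             println!("{relay_url}: {event:?}");
///         }
///         PoolMessage::ConnectionClosed { relay_url, .. } => {
///             eprintln!("{relay_url} closed");
///         }
///     }
/// }
/// ```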
pub struct RelayPool {
    connections: Vec<RelayConnection>,
    consumer: PoolConsumer,
    sender: PoolSender,
    broadcast_consumer: broadcast::Consumer<String>,
    mpsc_producer: Producer<PoolMessage>,
}

impl RelayPool {
    /// Create a new relay pool with bidirectional messaging.
    ///
    /// # Arguments
    /// * `ring_capacity` - MPSC ring buffer size for inbound event throughput
    /// * `cache_size` - Deduplication cache size (e.g. 10,000)
    /// * `broadcast_capacity` - Broadcast ring buffer size for outbound messages
    /// * `max_relays` - Maximum number of relay connections (broadcast consumer slots)
    pub fn new(
        ring_capacity: usize,
        cache_size: usize,
        broadcast_capacity: usize,
        max_relays: usize,
    ) -> Self {
        let (mpsc_producer, mpsc_consumer) =
            RingBuffer::new(Capacity::at_least(ring_capacity)).split();
        // +1 because split() creates the template consumer that we clone per relay
        let (bc_producer, bc_consumer) =
            broadcast::RingBuffer::new(Capacity::at_least(broadcast_capacity), max_relays + 1)
                .split();
        Self {
            connections: Vec::new(),
            consumer: PoolConsumer::new(mpsc_consumer, cache_size),
            sender: PoolSender {
                producer: bc_producer,
            },
            broadcast_consumer: bc_consumer,
            mpsc_producer,
        }
    }

    /// Add a relay connection to the pool.
    ///
    /// Spawns a new thread that connects to the relay, reads inbound events,
    /// and sends outbound messages from the broadcast ring.
    ///
    /// Automatically cleans up dead connections first to free broadcast slots.
    pub fn add_relay(&mut self, relay_url: String) {
        self.cleanup();
        let shutdown = Arc::new(AtomicBool::new(false));
        let bc_consumer = self.broadcast_consumer.clone();
        let mpsc_producer = self.mpsc_producer.clone();
        let connection =
            RelayConnection::spawn(relay_url, mpsc_producer, bc_consumer, shutdown);
        self.connections.push(connection);
    }

    /// Remove a relay from the pool by URL.
    ///
    /// Signals the relay thread to shut down and blocks until it exits (~1-2ms).
    /// The broadcast consumer slot is freed immediately.
    ///
    /// Returns `true` if the relay was found and removed.
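    ///
    /// # Example
    ///
    /// Sketch with a placeholder URL:
    ///
    /// ```no_run
    /// use nostro2_ring_relay::RelayPool;
    ///
    /// let mut pool = RelayPool::new(1024, 10_000, 64, 8);
    /// pool.add_relay("wss://relay.example.com".to_string());
    ///
    /// assert!(pool.remove_relay("wss://relay.example.com"));
    /// assert!(!pool.remove_relay("wss://relay.example.com")); // already gone
    /// ```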
    pub fn remove_relay(&mut self, relay_url: &str) -> bool {
        if let Some(pos) = self
            .connections
            .iter()
            .position(|c| c.relay_url == relay_url)
        {
            let mut conn = self.connections.swap_remove(pos);
            conn.shutdown_and_join();
            true
        } else {
            false
        }
    }

    /// Remove dead connections from the pool.
    ///
    /// Joins finished threads and frees their broadcast consumer slots.
    /// Called automatically by [`Self::add_relay`], but can be called explicitly
    /// to update [`Self::connection_count`].
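    ///
    /// # Example
    ///
    /// Sketch mirroring the `test_cleanup_removes_dead_connections` test: the
    /// address is unreachable, so the thread fails fast and cleanup reaps it:
    ///
    /// ```no_run
    /// use nostro2_ring_relay::RelayPool;
    ///
    /// let mut pool = RelayPool::new(1024, 10_000, 64, 8);
    /// pool.add_relay("ws://127.0.0.1:1".to_string());
    /// std::thread::sleep(std::time::Duration::from_millis(500));
    ///
    /// assert_eq!(pool.connection_count(), 1); // stale handle still counted
    /// pool.cleanup();
    /// assert_eq!(pool.connection_count(), 0); // joined and removed
    /// ```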
    pub fn cleanup(&mut self) {
        self.connections.retain_mut(|conn| {
            if conn.is_finished() {
                if let Some(handle) = conn.thread_handle.take() {
                    let _ = handle.join();
                }
                false
            } else {
                true
            }
        });
    }

    /// Get a cloneable sender handle for broadcasting to all relays.
    ///
    /// Multiple threads can hold a `PoolSender` and send concurrently.
    pub fn sender(&self) -> PoolSender {
        self.sender.clone()
    }

    /// Receive the next event from any relay (blocking).
    pub fn recv(&mut self) -> PoolMessage {
        self.consumer.recv()
    }

    /// Receive the next event from any relay (non-blocking).
    pub fn try_recv(&mut self) -> Option<PoolMessage> {
        self.consumer.try_recv()
    }

    /// Get the total number of connections (including dead ones not yet cleaned up).
    pub fn connection_count(&self) -> usize {
        self.connections.len()
    }

    /// Get the number of connections whose threads are still running.
    pub fn active_connection_count(&self) -> usize {
        self.connections.iter().filter(|c| !c.is_finished()).count()
    }

    /// Get the relay URLs of all connections in the pool.
    pub fn relay_urls(&self) -> Vec<&str> {
        self.connections.iter().map(|c| c.relay_url.as_str()).collect()
    }

    /// Get the relay URLs of only active (thread still running) connections.
    pub fn active_relay_urls(&self) -> Vec<&str> {
        self.connections
            .iter()
            .filter(|c| !c.is_finished())
            .map(|c| c.relay_url.as_str())
            .collect()
    }
}

impl Drop for RelayPool {
    fn drop(&mut self) {
        // Phase 1: Signal all threads to shut down (they exit in parallel)
        for conn in &self.connections {
            conn.request_shutdown();
        }
        // Phase 2: Join all threads
        for conn in &mut self.connections {
            if let Some(handle) = conn.thread_handle.take() {
                let _ = handle.join();
            }
        }
    }
}

/// Create a standalone `PoolConsumer` and the MPSC `Producer` used to spawn
/// relay connections manually (without a [`RelayPool`]).
///
/// # Arguments
/// * `ring_capacity` - Ring buffer size for event throughput
/// * `cache_size` - Deduplication cache size
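///
/// # Example
///
/// Sketch of manual wiring:
///
/// ```no_run
/// use nostro2_ring_relay::create_pool;
///
/// let (mut consumer, _producer) = create_pool(1024, 10_000);
/// // Clone `_producer` into each `RelayConnection::spawn` call, then drain
/// // deduplicated messages here:
/// while let Some(msg) = consumer.try_recv() {
///     println!("{msg:?}");
/// }
/// ```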
pub fn create_pool(ring_capacity: usize, cache_size: usize) -> (PoolConsumer, Producer<PoolMessage>) {
    let (producer, consumer) = RingBuffer::new(Capacity::at_least(ring_capacity)).split();
    (PoolConsumer::new(consumer, cache_size), producer)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_creation() {
        let pool = RelayPool::new(1024, 10_000, 64, 8);
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_create_pool_helper() {
        let (_consumer, _producer) = create_pool(1024, 10_000);
        // Just testing that it compiles and runs
    }

    #[test]
    fn test_pool_sender_clone_and_broadcast() {
        let (bc_producer, mut c1) =
            broadcast::RingBuffer::<String>::new(Capacity::exact(16), 4).split();
        let mut c2 = c1.clone();

        let sender = PoolSender {
            producer: bc_producer,
        };
        let sender2 = sender.clone();

        // Send from two different senders
        sender.send_raw("hello".to_string()).unwrap();
        sender2.send_raw("world".to_string()).unwrap();

        // Both consumers see both messages
        assert_eq!(c1.pop(), Some("hello".to_string()));
        assert_eq!(c1.pop(), Some("world".to_string()));
        assert_eq!(c2.pop(), Some("hello".to_string()));
        assert_eq!(c2.pop(), Some("world".to_string()));
    }

    #[test]
    fn test_pool_sender_via_relay_pool() {
        let pool = RelayPool::new(1024, 10_000, 64, 8);
        let sender = pool.sender();
        let sender2 = pool.sender();

        // Both senders are valid clones
        assert!(!sender.producer.is_full());
        assert!(!sender2.producer.is_full());
    }

    #[test]
    fn test_shutdown_flag_stops_thread() {
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = Arc::clone(&shutdown);
        let handle = std::thread::spawn(move || {
            while !shutdown_clone.load(Ordering::Relaxed) {
                std::thread::sleep(std::time::Duration::from_millis(1));
            }
        });
        assert!(!handle.is_finished());
        shutdown.store(true, Ordering::Relaxed);
        handle.join().unwrap();
    }

    #[test]
    fn test_cleanup_removes_dead_connections() {
        // Connect to an invalid address — thread will fail fast
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        assert_eq!(pool.connection_count(), 1);

        // Wait for the thread to fail and exit
        std::thread::sleep(std::time::Duration::from_millis(500));

        pool.cleanup();
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_remove_relay() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        assert_eq!(pool.connection_count(), 1);

        assert!(pool.remove_relay("ws://127.0.0.1:1"));
        assert_eq!(pool.connection_count(), 0);

        // Removing a non-existent relay returns false
        assert!(!pool.remove_relay("ws://127.0.0.1:2"));
    }

    #[test]
    fn test_active_connection_count() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        // Invalid address — thread will die quickly
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());
        assert_eq!(pool.connection_count(), 2);

        // Wait for threads to fail
        std::thread::sleep(std::time::Duration::from_millis(500));

        // connection_count still 2 (stale), active_connection_count is 0
        assert_eq!(pool.connection_count(), 2);
        assert_eq!(pool.active_connection_count(), 0);

        // cleanup brings connection_count in sync
        pool.cleanup();
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_relay_urls() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());

        let urls = pool.relay_urls();
        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"ws://127.0.0.1:1"));
        assert!(urls.contains(&"ws://127.0.0.1:2"));
    }

    #[test]
    fn test_pool_drop_joins_threads() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());
        // Drop should signal shutdown and join — no panic
        drop(pool);
    }

    #[test]
    fn test_add_after_remove_reuses_slots() {
        // max_relays=2 means only 2 broadcast consumer slots available
        let mut pool = RelayPool::new(1024, 10_000, 64, 2);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());

        // Remove one — frees a broadcast consumer slot via blocking join
        pool.remove_relay("ws://127.0.0.1:1");
        assert_eq!(pool.connection_count(), 1);

        // Add a new relay — should reuse the freed slot without panic
        pool.add_relay("ws://127.0.0.1:3".to_string());
        assert!(pool.relay_urls().contains(&"ws://127.0.0.1:3"));
    }
}