// nostro2_relay/relay.rs

use futures_util::{SinkExt, StreamExt};
use std::time::Duration;

/// Configuration for automatic reconnection with exponential backoff.
///
/// When a relay connection drops, it will automatically attempt to reconnect,
/// waiting [`ReconnectConfig::next_delay`] before each attempt. Reconnection is
/// enabled whenever `max_delay` is non-zero (see [`ReconnectConfig::is_enabled`]).
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of consecutive failed reconnection attempts (0 = infinite).
    pub max_retries: u32,
    /// Delay before the first reconnection attempt.
    pub initial_delay: Duration,
    /// Upper bound on the delay between reconnection attempts.
    ///
    /// A zero value disables reconnection entirely.
    pub max_delay: Duration,
    /// Multiplier for exponential backoff (e.g. `2.0` doubles the delay each time).
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Infinite retries, starting at 1 s, doubling each time, capped at 60 s.
    fn default() -> Self {
        Self {
            max_retries: 0, // Infinite retries by default
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 2.0,
        }
    }
}

impl ReconnectConfig {
    /// Creates a config with automatic reconnection turned off.
    ///
    /// `max_delay` is zero, which is exactly what [`Self::is_enabled`] checks.
    #[must_use]
    pub const fn disabled() -> Self {
        Self {
            max_retries: 0,
            initial_delay: Duration::ZERO,
            max_delay: Duration::ZERO,
            backoff_multiplier: 0.0,
        }
    }

    /// Returns `true` if reconnection is enabled, i.e. `max_delay` is non-zero.
    ///
    /// The full duration is checked, not just whole seconds, so a sub-second
    /// cap such as 500 ms still enables reconnection.
    #[must_use]
    pub const fn is_enabled(&self) -> bool {
        !self.max_delay.is_zero()
    }

    /// Computes the backoff delay before the zero-based reconnection `attempt`.
    ///
    /// The delay is `initial_delay * backoff_multiplier^attempt`, capped at
    /// `max_delay`. Returns [`Duration::ZERO`] when reconnection is disabled.
    /// Degenerate configurations never panic. Negative results become zero,
    /// and NaN, infinite or overflowing results fall back to `max_delay`.
    #[must_use]
    pub fn next_delay(&self, attempt: u32) -> Duration {
        if !self.is_enabled() {
            return Duration::ZERO;
        }

        let cap_secs = self.max_delay.as_secs_f64();
        let raw_secs =
            self.initial_delay.as_secs_f64() * self.backoff_multiplier.powf(f64::from(attempt));
        // `f64::min` returns the non-NaN operand, so NaN collapses to the cap.
        // `max(0.0)` then rules out the negative values (e.g. from a negative
        // multiplier) that would make `Duration::from_secs_f64` panic.
        let secs = raw_secs.min(cap_secs).max(0.0);
        Duration::try_from_secs_f64(secs).unwrap_or(self.max_delay)
    }
}

62#[derive(Clone)]
63pub struct NostrRelay {
64    /// Channel for receiving raw messages from the reader task
65    receiver: std::sync::Arc<
66        tokio::sync::RwLock<
67            tokio::sync::mpsc::UnboundedReceiver<tokio_tungstenite::tungstenite::Utf8Bytes>,
68        >,
69    >,
70    /// Channel for sending messages to the writer task
71    sender: tokio::sync::mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Utf8Bytes>,
72    /// URL of the relay for reconnection
73    #[allow(dead_code)]
74    url: std::sync::Arc<String>,
75    /// Reconnection configuration
76    #[allow(dead_code)]
77    reconnect_config: std::sync::Arc<ReconnectConfig>,
78}
79impl NostrRelay {
80    /// Creates a new relay connection with default reconnection settings.
81    ///
82    /// By default, the relay will automatically reconnect with exponential backoff
83    /// if the connection drops.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if the initial connection fails.
88    pub async fn new(url: &str) -> Result<Self, crate::errors::NostrRelayError> {
89        Self::with_reconnect(url, ReconnectConfig::default()).await
90    }
91
92    /// Creates a new relay connection with custom reconnection configuration.
93    ///
94    /// # Examples
95    ///
96    /// ```no_run
97    /// use nostro2_relay::{NostrRelay, ReconnectConfig};
98    /// use std::time::Duration;
99    ///
100    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
101    /// // Custom reconnection with max 10 retries
102    /// let config = ReconnectConfig {
103    ///     max_retries: 10,
104    ///     initial_delay: Duration::from_secs(1),
105    ///     max_delay: Duration::from_secs(30),
106    ///     backoff_multiplier: 2.0,
107    /// };
108    /// let relay = NostrRelay::with_reconnect("wss://relay.example.com", config).await?;
109    /// # Ok(())
110    /// # }
111    /// ```
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if the initial connection fails.
116    pub async fn with_reconnect(
117        url: &str,
118        reconnect_config: ReconnectConfig,
119    ) -> Result<Self, crate::errors::NostrRelayError> {
120        // Create persistent channels for communication
121        let (incoming_tx, incoming_rx) =
122            tokio::sync::mpsc::unbounded_channel::<tokio_tungstenite::tungstenite::Utf8Bytes>();
123        let (outgoing_tx, outgoing_rx) =
124            tokio::sync::mpsc::unbounded_channel::<tokio_tungstenite::tungstenite::Utf8Bytes>();
125
126        let url = url.to_string();
127        let url_arc = std::sync::Arc::new(url.clone());
128        let reconnect_config_arc = std::sync::Arc::new(reconnect_config.clone());
129
130        // Try initial connection
131        let initial_connection = Self::connect(&url).await?;
132        let (sink, stream) = futures_util::StreamExt::split(initial_connection);
133
134        // Spawn connection manager task
135        tokio::spawn(Self::connection_manager(
136            url,
137            reconnect_config,
138            incoming_tx,
139            outgoing_rx,
140            sink,
141            stream,
142        ));
143
144        Ok(Self {
145            receiver: std::sync::Arc::new(tokio::sync::RwLock::new(incoming_rx)),
146            sender: outgoing_tx,
147            url: url_arc,
148            reconnect_config: reconnect_config_arc,
149        })
150    }
151
152    /// Establishes a WebSocket connection to the relay
153    async fn connect(
154        url: &str,
155    ) -> Result<
156        tokio_tungstenite::WebSocketStream<
157            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
158        >,
159        crate::errors::NostrRelayError,
160    > {
161        let (websocket, _response) = tokio_tungstenite::connect_async_with_config(
162            url,
163            Some(
164                tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
165                    .max_write_buffer_size(5 << 20) // 5 MiB
166                    .max_frame_size(Some(256 << 10)) // 256 KiB
167                    .max_message_size(Some(5 << 20)) // 5 MiB
168                    .read_buffer_size(4 << 20) // 4 MiB
169                    .write_buffer_size(4 << 20), // 4 MiB
170            ),
171            false,
172        )
173        .await?;
174        Ok(websocket)
175    }
176
177    /// Manages the connection lifecycle with automatic reconnection
178    #[allow(clippy::too_many_lines)]
179    async fn connection_manager(
180        url: String,
181        config: ReconnectConfig,
182        incoming_tx: tokio::sync::mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Utf8Bytes>,
183        mut outgoing_rx: tokio::sync::mpsc::UnboundedReceiver<
184            tokio_tungstenite::tungstenite::Utf8Bytes,
185        >,
186        initial_sink: futures_util::stream::SplitSink<
187            tokio_tungstenite::WebSocketStream<
188                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
189            >,
190            tokio_tungstenite::tungstenite::Message,
191        >,
192        initial_stream: futures_util::stream::SplitStream<
193            tokio_tungstenite::WebSocketStream<
194                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
195            >,
196        >,
197    ) {
198        let mut attempt = 0;
199        let mut current_sink = initial_sink;
200        let mut current_stream = initial_stream;
201
202        loop {
203            // Run the connection until it fails
204            let _result = Self::run_connection(
205                &incoming_tx,
206                &mut outgoing_rx,
207                &mut current_sink,
208                &mut current_stream,
209            )
210            .await;
211
212            // Check if we should reconnect
213            if !config.is_enabled() {
214                // Reconnection disabled, exit
215                break;
216            }
217
218            if config.max_retries > 0 && attempt >= config.max_retries {
219                // Max retries reached
220                eprintln!("Max reconnection attempts ({}) reached for {}", config.max_retries, url);
221                break;
222            }
223
224            // Calculate backoff delay
225            let delay = config.next_delay(attempt);
226            if delay.as_secs() == 0 {
227                break;
228            }
229
230            eprintln!(
231                "Connection to {} lost, reconnecting in {:?} (attempt {})",
232                url,
233                delay,
234                attempt + 1
235            );
236            tokio::time::sleep(delay).await;
237
238            // Attempt to reconnect
239            match Self::connect(&url).await {
240                Ok(websocket) => {
241                    eprintln!("Successfully reconnected to {url}");
242                    let (sink, stream) = futures_util::StreamExt::split(websocket);
243                    current_sink = sink;
244                    current_stream = stream;
245                    attempt = 0; // Reset attempt counter on successful reconnection
246                }
247                Err(e) => {
248                    eprintln!("Failed to reconnect to {url}: {e}");
249                    attempt += 1;
250                }
251            }
252        }
253    }
254
255    /// Runs the connection, handling read/write operations
256    async fn run_connection(
257        incoming_tx: &tokio::sync::mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Utf8Bytes>,
258        outgoing_rx: &mut tokio::sync::mpsc::UnboundedReceiver<
259            tokio_tungstenite::tungstenite::Utf8Bytes,
260        >,
261        sink: &mut futures_util::stream::SplitSink<
262            tokio_tungstenite::WebSocketStream<
263                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
264            >,
265            tokio_tungstenite::tungstenite::Message,
266        >,
267        stream: &mut futures_util::stream::SplitStream<
268            tokio_tungstenite::WebSocketStream<
269                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
270            >,
271        >,
272    ) -> Result<(), ()> {
273        loop {
274            tokio::select! {
275                // Handle incoming messages from WebSocket
276                Some(msg) = stream.next() => {
277                    match msg {
278                        Ok(tokio_tungstenite::tungstenite::Message::Text(text)) => {
279                            if incoming_tx.send(text).is_err() {
280                                // Receiver dropped, exit
281                                return Err(());
282                            }
283                        }
284                        Ok(tokio_tungstenite::tungstenite::Message::Close(_)) | Err(_) => {
285                            // Connection closed or error
286                            return Err(());
287                        }
288                        _ => {
289                            // Ignore other message types (binary, ping, pong)
290                        }
291                    }
292                }
293                // Handle outgoing messages to WebSocket
294                Some(msg) = outgoing_rx.recv() => {
295                    if sink
296                        .send(tokio_tungstenite::tungstenite::Message::Text(msg))
297                        .await
298                        .is_err()
299                    {
300                        // Error writing to sink
301                        return Err(());
302                    }
303                }
304                else => {
305                    // Both channels closed
306                    let _ = sink.flush().await;
307                    return Err(());
308                }
309            }
310        }
311    }
312    /// Sends a message to the relay.
313    /// Message must implement `Into<NostrClientEvent>`.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the message fails to send.
318    pub fn send<T>(&self, msg: T) -> Result<(), crate::errors::NostrRelayError>
319    where
320        T: Into<nostro2::NostrClientEvent> + Send + Sync,
321    {
322        let msg: nostro2::NostrClientEvent = msg.into();
323        // Pre-serialize JSON before sending to writer task
324        let msg_str = serde_json::to_string(&msg).map_err(crate::errors::NostrRelayError::Serde)?;
325        self.sender
326            .send(msg_str.into())
327            .map_err(|_| crate::errors::NostrRelayError::SendError)?;
328        Ok(())
329    }
330    /// Sends multiple messages to the relay.
331    /// Messages are pre-serialized and sent through the writer task.
332    /// Message must implement `Into<NostrClientEvent>`.
333    ///
334    /// # Errors
335    ///
336    /// Returns an error if any message fails to send.
337    pub async fn send_all<St, T>(
338        &self,
339        mut stream: St,
340    ) -> Result<(), crate::errors::NostrRelayError>
341    where
342        T: Into<nostro2::NostrClientEvent> + Send + Sync + std::fmt::Debug,
343        St: futures_util::Stream<Item = T> + Unpin + Sized,
344    {
345        while let Some(msg) = stream.next().await {
346            let msg: nostro2::NostrClientEvent = msg.into();
347            let msg_str =
348                serde_json::to_string(&msg).map_err(crate::errors::NostrRelayError::Serde)?;
349            self.sender
350                .send(msg_str.into())
351                .map_err(|_| crate::errors::NostrRelayError::SendError)?;
352        }
353        Ok(())
354    }
355
356    /// Receives a message from the relay.
357    /// Pulls raw text from the reader task's channel and parses it.
358    ///
359    /// # Errors
360    ///
361    /// Returns None if the stream is closed or the message fails to parse.
362    pub async fn recv(&self) -> Option<nostro2::NostrRelayEvent> {
363        let msg_text = self.receiver.write().await.recv().await?;
364        // Parse raw string to NostrRelayEvent
365        msg_text
366            .parse()
367            .ok()
368            .or(Some(nostro2::NostrRelayEvent::Ping))
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Live smoke test against a public relay.
    ///
    /// Subscribes to kind-20001 events (limit 5000) and prints every event
    /// until the relay signals end-of-subscription or the stream ends.
    #[tokio::test]
    async fn test_relay() {
        let started = std::time::Instant::now();
        println!("Connecting to relay...");
        let relay = NostrRelay::new("wss://relay.illuminodes.com").await.unwrap();

        let filter = nostro2::NostrSubscription {
            kinds: vec![20001].into(),
            limit: 5000.into(),
            ..Default::default()
        };
        relay.send(filter).unwrap();
        println!("Connected in {:?}", started.elapsed());

        loop {
            let Some(event) = relay.recv().await else {
                break;
            };
            println!("{event:?}");
            if matches!(event, nostro2::NostrRelayEvent::EndOfSubscription(_, _)) {
                break;
            }
        }

        println!("Done in {:?}", started.elapsed());
    }
}