1use futures_util::{SinkExt, StreamExt};
2use std::time::Duration;
3
/// Policy controlling how a dropped relay connection is re-established.
///
/// Reconnect attempt `n` (0-based, counting consecutive failures) waits
/// `initial_delay * backoff_multiplier^n`, capped at `max_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectConfig {
    /// Maximum number of consecutive failed reconnect attempts before giving up.
    /// `0` means "retry without limit".
    pub max_retries: u32,
    /// Delay before the first reconnect attempt after the connection is lost.
    pub initial_delay: Duration,
    /// Upper bound on any single delay; a `max_delay` of zero disables reconnection.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after every failed attempt.
    pub backoff_multiplier: f64,
}
19
20impl Default for ReconnectConfig {
21 fn default() -> Self {
22 Self {
23 max_retries: 0, initial_delay: Duration::from_secs(1),
25 max_delay: Duration::from_secs(60),
26 backoff_multiplier: 2.0,
27 }
28 }
29}
30
31impl ReconnectConfig {
32 #[must_use]
34 pub const fn disabled() -> Self {
35 Self {
36 max_retries: 0,
37 initial_delay: Duration::from_secs(0),
38 max_delay: Duration::from_secs(0),
39 backoff_multiplier: 0.0,
40 }
41 }
42
43 #[must_use]
45 pub const fn is_enabled(&self) -> bool {
46 self.max_delay.as_secs() > 0
47 }
48
49 #[must_use]
51 pub fn next_delay(&self, attempt: u32) -> Duration {
52 if !self.is_enabled() {
53 return Duration::from_secs(0);
54 }
55
56 let delay_secs = self.initial_delay.as_secs_f64()
57 * self.backoff_multiplier.powf(f64::from(attempt));
58 Duration::from_secs_f64(delay_secs.min(self.max_delay.as_secs_f64()))
59 }
60}
61
62#[derive(Clone)]
63pub struct NostrRelay {
64 receiver: std::sync::Arc<
66 tokio::sync::RwLock<
67 tokio::sync::mpsc::UnboundedReceiver<tokio_tungstenite::tungstenite::Utf8Bytes>,
68 >,
69 >,
70 sender: tokio::sync::mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Utf8Bytes>,
72 #[allow(dead_code)]
74 url: std::sync::Arc<String>,
75 #[allow(dead_code)]
77 reconnect_config: std::sync::Arc<ReconnectConfig>,
78}
79impl NostrRelay {
80 pub async fn new(url: &str) -> Result<Self, crate::errors::NostrRelayError> {
89 Self::with_reconnect(url, ReconnectConfig::default()).await
90 }
91
92 pub async fn with_reconnect(
117 url: &str,
118 reconnect_config: ReconnectConfig,
119 ) -> Result<Self, crate::errors::NostrRelayError> {
120 let (incoming_tx, incoming_rx) =
122 tokio::sync::mpsc::unbounded_channel::<tokio_tungstenite::tungstenite::Utf8Bytes>();
123 let (outgoing_tx, outgoing_rx) =
124 tokio::sync::mpsc::unbounded_channel::<tokio_tungstenite::tungstenite::Utf8Bytes>();
125
126 let url = url.to_string();
127 let url_arc = std::sync::Arc::new(url.clone());
128 let reconnect_config_arc = std::sync::Arc::new(reconnect_config.clone());
129
130 let initial_connection = Self::connect(&url).await?;
132 let (sink, stream) = futures_util::StreamExt::split(initial_connection);
133
134 tokio::spawn(Self::connection_manager(
136 url,
137 reconnect_config,
138 incoming_tx,
139 outgoing_rx,
140 sink,
141 stream,
142 ));
143
144 Ok(Self {
145 receiver: std::sync::Arc::new(tokio::sync::RwLock::new(incoming_rx)),
146 sender: outgoing_tx,
147 url: url_arc,
148 reconnect_config: reconnect_config_arc,
149 })
150 }
151
152 async fn connect(
154 url: &str,
155 ) -> Result<
156 tokio_tungstenite::WebSocketStream<
157 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
158 >,
159 crate::errors::NostrRelayError,
160 > {
161 let (websocket, _response) = tokio_tungstenite::connect_async_with_config(
162 url,
163 Some(
164 tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
165 .max_write_buffer_size(5 << 20) .max_frame_size(Some(256 << 10)) .max_message_size(Some(5 << 20)) .read_buffer_size(4 << 20) .write_buffer_size(4 << 20), ),
171 false,
172 )
173 .await?;
174 Ok(websocket)
175 }
176
177 #[allow(clippy::too_many_lines)]
179 async fn connection_manager(
180 url: String,
181 config: ReconnectConfig,
182 incoming_tx: tokio::sync::mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Utf8Bytes>,
183 mut outgoing_rx: tokio::sync::mpsc::UnboundedReceiver<
184 tokio_tungstenite::tungstenite::Utf8Bytes,
185 >,
186 initial_sink: futures_util::stream::SplitSink<
187 tokio_tungstenite::WebSocketStream<
188 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
189 >,
190 tokio_tungstenite::tungstenite::Message,
191 >,
192 initial_stream: futures_util::stream::SplitStream<
193 tokio_tungstenite::WebSocketStream<
194 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
195 >,
196 >,
197 ) {
198 let mut attempt = 0;
199 let mut current_sink = initial_sink;
200 let mut current_stream = initial_stream;
201
202 loop {
203 let _result = Self::run_connection(
205 &incoming_tx,
206 &mut outgoing_rx,
207 &mut current_sink,
208 &mut current_stream,
209 )
210 .await;
211
212 if !config.is_enabled() {
214 break;
216 }
217
218 if config.max_retries > 0 && attempt >= config.max_retries {
219 eprintln!("Max reconnection attempts ({}) reached for {}", config.max_retries, url);
221 break;
222 }
223
224 let delay = config.next_delay(attempt);
226 if delay.as_secs() == 0 {
227 break;
228 }
229
230 eprintln!(
231 "Connection to {} lost, reconnecting in {:?} (attempt {})",
232 url,
233 delay,
234 attempt + 1
235 );
236 tokio::time::sleep(delay).await;
237
238 match Self::connect(&url).await {
240 Ok(websocket) => {
241 eprintln!("Successfully reconnected to {url}");
242 let (sink, stream) = futures_util::StreamExt::split(websocket);
243 current_sink = sink;
244 current_stream = stream;
245 attempt = 0; }
247 Err(e) => {
248 eprintln!("Failed to reconnect to {url}: {e}");
249 attempt += 1;
250 }
251 }
252 }
253 }
254
255 async fn run_connection(
257 incoming_tx: &tokio::sync::mpsc::UnboundedSender<tokio_tungstenite::tungstenite::Utf8Bytes>,
258 outgoing_rx: &mut tokio::sync::mpsc::UnboundedReceiver<
259 tokio_tungstenite::tungstenite::Utf8Bytes,
260 >,
261 sink: &mut futures_util::stream::SplitSink<
262 tokio_tungstenite::WebSocketStream<
263 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
264 >,
265 tokio_tungstenite::tungstenite::Message,
266 >,
267 stream: &mut futures_util::stream::SplitStream<
268 tokio_tungstenite::WebSocketStream<
269 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
270 >,
271 >,
272 ) -> Result<(), ()> {
273 loop {
274 tokio::select! {
275 Some(msg) = stream.next() => {
277 match msg {
278 Ok(tokio_tungstenite::tungstenite::Message::Text(text)) => {
279 if incoming_tx.send(text).is_err() {
280 return Err(());
282 }
283 }
284 Ok(tokio_tungstenite::tungstenite::Message::Close(_)) | Err(_) => {
285 return Err(());
287 }
288 _ => {
289 }
291 }
292 }
293 Some(msg) = outgoing_rx.recv() => {
295 if sink
296 .send(tokio_tungstenite::tungstenite::Message::Text(msg))
297 .await
298 .is_err()
299 {
300 return Err(());
302 }
303 }
304 else => {
305 let _ = sink.flush().await;
307 return Err(());
308 }
309 }
310 }
311 }
312 pub fn send<T>(&self, msg: T) -> Result<(), crate::errors::NostrRelayError>
319 where
320 T: Into<nostro2::NostrClientEvent> + Send + Sync,
321 {
322 let msg: nostro2::NostrClientEvent = msg.into();
323 let msg_str = serde_json::to_string(&msg).map_err(crate::errors::NostrRelayError::Serde)?;
325 self.sender
326 .send(msg_str.into())
327 .map_err(|_| crate::errors::NostrRelayError::SendError)?;
328 Ok(())
329 }
330 pub async fn send_all<St, T>(
338 &self,
339 mut stream: St,
340 ) -> Result<(), crate::errors::NostrRelayError>
341 where
342 T: Into<nostro2::NostrClientEvent> + Send + Sync + std::fmt::Debug,
343 St: futures_util::Stream<Item = T> + Unpin + Sized,
344 {
345 while let Some(msg) = stream.next().await {
346 let msg: nostro2::NostrClientEvent = msg.into();
347 let msg_str =
348 serde_json::to_string(&msg).map_err(crate::errors::NostrRelayError::Serde)?;
349 self.sender
350 .send(msg_str.into())
351 .map_err(|_| crate::errors::NostrRelayError::SendError)?;
352 }
353 Ok(())
354 }
355
356 pub async fn recv(&self) -> Option<nostro2::NostrRelayEvent> {
363 let msg_text = self.receiver.write().await.recv().await?;
364 msg_text
366 .parse()
367 .ok()
368 .or(Some(nostro2::NostrRelayEvent::Ping))
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test against a live public relay.
    ///
    /// Ignored by default because it needs outbound network access and depends on a third-party
    /// service; run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to wss://relay.illuminodes.com"]
    async fn test_relay() {
        let time = std::time::Instant::now();
        println!("Connecting to relay...");
        let relay = NostrRelay::new("wss://relay.illuminodes.com")
            .await
            .unwrap();
        let subscription = nostro2::NostrSubscription {
            kinds: vec![20001].into(),
            limit: 5000.into(),
            ..Default::default()
        };
        relay.send(subscription).unwrap();
        println!("Connected in {:?}", time.elapsed());

        // Bound the wait: with the default (unlimited) reconnect policy an unresponsive relay
        // would otherwise hang the test forever.
        let finished = tokio::time::timeout(std::time::Duration::from_secs(30), async {
            while let Some(msg) = relay.recv().await {
                println!("{msg:?}");
                if matches!(msg, nostro2::NostrRelayEvent::EndOfSubscription(_, _)) {
                    return true;
                }
            }
            false
        })
        .await;
        assert!(
            matches!(finished, Ok(true)),
            "relay did not finish the subscription: {finished:?}"
        );
        println!("Done in {:?}", time.elapsed());
    }
}