1use nostro2::NostrRelayEvent;
2use nostro2_cache::Cache;
3use quetzalcoatl::broadcast;
4use quetzalcoatl::capacity::Capacity;
5use quetzalcoatl::mpsc::{Consumer, Producer, RingBuffer};
6use std::net::TcpStream;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use tungstenite::stream::MaybeTlsStream;
10use tungstenite::{connect, Message, WebSocket};
11
/// A message sent from a relay connection thread to the pool's consumer.
#[derive(Debug, Clone)]
pub enum PoolMessage {
    /// An event parsed from a text frame received from a relay.
    RelayEvent {
        /// URL of the relay the event arrived from.
        relay_url: String,
        /// The parsed relay event.
        event: NostrRelayEvent,
    },
    /// The connection thread for a relay has exited.
    ConnectionClosed {
        /// URL of the relay whose connection ended.
        relay_url: String,
        /// `None` when the connection loop ended without error (shutdown was
        /// requested or the relay sent a close frame); otherwise the error text.
        error: Option<String>,
    },
}
28
/// Cloneable handle for broadcasting client messages (as JSON strings) to the
/// relay connections of a pool.
#[derive(Clone)]
pub struct PoolSender {
    /// Producer side of the outbound broadcast ring; each relay thread holds a
    /// consumer cloned from the pool's template consumer.
    producer: broadcast::Producer<String>,
}
37
38impl PoolSender {
39 pub fn send<T: Into<nostro2::NostrClientEvent>>(&self, msg: T) -> Result<(), String> {
44 let client_event: nostro2::NostrClientEvent = msg.into();
45 let json = serde_json::to_string(&client_event).map_err(|e| e.to_string())?;
46 self.producer.push(json)
47 }
48
49 pub fn send_raw(&self, json: String) -> Result<(), String> {
53 self.producer.push(json)
54 }
55}
56
/// A single relay connection serviced by its own background thread.
pub struct RelayConnection {
    /// URL the connection was opened against.
    relay_url: String,
    /// Handle of the worker thread; `None` once it has been joined.
    thread_handle: Option<std::thread::JoinHandle<()>>,
    /// Flag shared with the worker thread; setting it asks the worker to send a
    /// close frame and exit its loop.
    shutdown: Arc<AtomicBool>,
}
66
67impl RelayConnection {
68 pub fn spawn(
73 relay_url: String,
74 mut producer: Producer<PoolMessage>,
75 outbound: broadcast::Consumer<String>,
76 shutdown: Arc<AtomicBool>,
77 ) -> Self {
78 let url = relay_url.clone();
79 let shutdown_clone = Arc::clone(&shutdown);
80 let thread_handle = std::thread::spawn(move || {
81 match Self::run_connection(&url, &mut producer, outbound, &shutdown_clone) {
82 Ok(()) => {
83 let _ = producer.push(PoolMessage::ConnectionClosed {
84 relay_url: url.clone(),
85 error: None,
86 });
87 }
88 Err(e) => {
89 let _ = producer.push(PoolMessage::ConnectionClosed {
90 relay_url: url.clone(),
91 error: Some(e.to_string()),
92 });
93 }
94 }
95 });
96
97 Self {
98 relay_url,
99 thread_handle: Some(thread_handle),
100 shutdown,
101 }
102 }
103
104 pub fn is_finished(&self) -> bool {
106 self.thread_handle
107 .as_ref()
108 .is_some_and(|h| h.is_finished())
109 }
110
111 pub fn request_shutdown(&self) {
116 self.shutdown.store(true, Ordering::Relaxed);
117 }
118
119 fn shutdown_and_join(&mut self) {
121 self.shutdown.store(true, Ordering::Relaxed);
122 if let Some(handle) = self.thread_handle.take() {
123 let _ = handle.join();
124 }
125 }
126
127 fn run_connection(
134 url: &str,
135 producer: &mut Producer<PoolMessage>,
136 mut outbound: broadcast::Consumer<String>,
137 shutdown: &AtomicBool,
138 ) -> Result<(), Box<dyn std::error::Error>> {
139 let _ = rustls::crypto::ring::default_provider().install_default();
141
142 let (mut socket, _response) = connect(url)?;
143
144 let subscription = nostro2::NostrSubscription {
146 kinds: vec![1].into(),
147 limit: Some(1000),
148 ..Default::default()
149 };
150
151 let client_event: nostro2::NostrClientEvent = subscription.into();
153 let subscription_json = serde_json::to_string(&client_event)?;
154 socket.send(Message::Text(subscription_json.into()))?;
155
156 set_nonblocking(&socket, true)?;
158
159 loop {
160 if shutdown.load(Ordering::Relaxed) {
162 let _ = socket.send(Message::Close(None));
163 break;
164 }
165
166 let mut had_work = false;
167
168 match socket.read() {
170 Ok(Message::Text(text)) => {
171 if let Ok(event) = text.parse::<NostrRelayEvent>() {
172 let mut pool_msg = PoolMessage::RelayEvent {
173 relay_url: url.to_string(),
174 event,
175 };
176 loop {
177 match producer.push(pool_msg) {
178 Ok(()) => break,
179 Err(returned) => {
180 pool_msg = returned;
181 std::hint::spin_loop();
182 }
183 }
184 }
185 }
186 had_work = true;
187 }
188 Ok(Message::Close(_)) => break,
189 Ok(Message::Ping(data)) => {
190 let _ = socket.send(Message::Pong(data));
193 had_work = true;
194 }
195 Ok(_) => {
196 had_work = true;
197 }
198 Err(tungstenite::Error::Io(ref e))
199 if e.kind() == std::io::ErrorKind::WouldBlock =>
200 {
201 }
203 Err(e) => return Err(e.into()),
204 }
205
206 while let Some(json) = outbound.pop() {
208 match socket.send(Message::Text(json.into())) {
209 Ok(()) => {
210 had_work = true;
211 }
212 Err(tungstenite::Error::Io(ref e))
213 if e.kind() == std::io::ErrorKind::WouldBlock =>
214 {
215 had_work = true;
219 break;
220 }
221 Err(e) => return Err(e.into()),
222 }
223 }
224
225 if !had_work {
227 std::thread::sleep(std::time::Duration::from_millis(1));
228 }
229 }
230
231 Ok(())
232 }
233
234 pub fn relay_url(&self) -> &str {
236 &self.relay_url
237 }
238}
239
240impl Drop for RelayConnection {
241 fn drop(&mut self) {
242 self.shutdown.store(true, Ordering::Relaxed);
243 if let Some(handle) = self.thread_handle.take() {
244 let _ = handle.join();
245 }
246 }
247}
248
249fn set_nonblocking(
251 socket: &WebSocket<MaybeTlsStream<TcpStream>>,
252 nonblocking: bool,
253) -> std::io::Result<()> {
254 match socket.get_ref() {
255 MaybeTlsStream::Plain(tcp) => tcp.set_nonblocking(nonblocking),
256 MaybeTlsStream::Rustls(tls) => tls.get_ref().set_nonblocking(nonblocking),
257 _ => Ok(()),
258 }
259}
260
/// Receiving end of a pool: pops messages from the shared ring and skips notes
/// whose id has already been delivered.
pub struct PoolConsumer {
    /// Consumer side of the ring that relay threads push into.
    consumer: Consumer<PoolMessage>,
    /// Ids of notes already delivered. A note with a previously seen id (for
    /// example the same note arriving from another relay) is skipped.
    dedup_cache: Cache,
}
266
267impl PoolConsumer {
268 pub fn new(consumer: Consumer<PoolMessage>, cache_size: usize) -> Self {
270 Self {
271 consumer,
272 dedup_cache: Cache::new(cache_size),
273 }
274 }
275
276 pub fn try_recv(&mut self) -> Option<PoolMessage> {
281 loop {
282 match self.consumer.pop()? {
283 PoolMessage::RelayEvent {
284 relay_url,
285 event: NostrRelayEvent::NewNote(tag, sub_id, note),
286 } => {
287 if let Some(ref event_id) = note.id {
289 if self.dedup_cache.insert(event_id.clone()) {
290 return Some(PoolMessage::RelayEvent {
292 relay_url,
293 event: NostrRelayEvent::NewNote(tag, sub_id, note),
294 });
295 }
296 continue;
298 }
299 return Some(PoolMessage::RelayEvent {
301 relay_url,
302 event: NostrRelayEvent::NewNote(tag, sub_id, note),
303 });
304 }
305 other => {
306 return Some(other);
308 }
309 }
310 }
311 }
312
313 pub fn recv(&mut self) -> PoolMessage {
318 loop {
319 if let Some(msg) = self.try_recv() {
320 return msg;
321 }
322 std::hint::spin_loop();
323 }
324 }
325}
326
/// A set of relay connections that feed one deduplicating consumer and share
/// one broadcast sender.
pub struct RelayPool {
    /// Tracked relay connections, including finished ones not yet removed by
    /// `cleanup`.
    connections: Vec<RelayConnection>,
    /// Deduplicating receiver for messages from all relay threads.
    consumer: PoolConsumer,
    /// Broadcast handle that `sender()` clones out.
    sender: PoolSender,
    /// Template broadcast consumer, cloned for each relay added. The pool never
    /// pops from it itself.
    /// NOTE(review): if the broadcast ring applies back-pressure from its
    /// slowest consumer, this undrained template could eventually make sends
    /// fail. Confirm against quetzalcoatl's broadcast semantics.
    broadcast_consumer: broadcast::Consumer<String>,
    /// Producer cloned into each relay thread to feed `consumer`.
    mpsc_producer: Producer<PoolMessage>,
}
338
339impl RelayPool {
340 pub fn new(
348 ring_capacity: usize,
349 cache_size: usize,
350 broadcast_capacity: usize,
351 max_relays: usize,
352 ) -> Self {
353 let (mpsc_producer, mpsc_consumer) =
354 RingBuffer::new(Capacity::at_least(ring_capacity)).split();
355 let (bc_producer, bc_consumer) =
357 broadcast::RingBuffer::new(Capacity::at_least(broadcast_capacity), max_relays + 1)
358 .split();
359 Self {
360 connections: Vec::new(),
361 consumer: PoolConsumer::new(mpsc_consumer, cache_size),
362 sender: PoolSender {
363 producer: bc_producer,
364 },
365 broadcast_consumer: bc_consumer,
366 mpsc_producer,
367 }
368 }
369
370 pub fn add_relay(&mut self, relay_url: String) {
377 self.cleanup();
378 let shutdown = Arc::new(AtomicBool::new(false));
379 let bc_consumer = self.broadcast_consumer.clone();
380 let mpsc_producer = self.mpsc_producer.clone();
381 let connection =
382 RelayConnection::spawn(relay_url, mpsc_producer, bc_consumer, shutdown);
383 self.connections.push(connection);
384 }
385
386 pub fn remove_relay(&mut self, relay_url: &str) -> bool {
393 if let Some(pos) = self
394 .connections
395 .iter()
396 .position(|c| c.relay_url == relay_url)
397 {
398 let mut conn = self.connections.swap_remove(pos);
399 conn.shutdown_and_join();
400 true
401 } else {
402 false
403 }
404 }
405
406 pub fn cleanup(&mut self) {
412 self.connections.retain_mut(|conn| {
413 if conn.is_finished() {
414 if let Some(handle) = conn.thread_handle.take() {
415 let _ = handle.join();
416 }
417 false
418 } else {
419 true
420 }
421 });
422 }
423
424 pub fn sender(&self) -> PoolSender {
428 self.sender.clone()
429 }
430
431 pub fn recv(&mut self) -> PoolMessage {
433 self.consumer.recv()
434 }
435
436 pub fn try_recv(&mut self) -> Option<PoolMessage> {
438 self.consumer.try_recv()
439 }
440
441 pub fn connection_count(&self) -> usize {
443 self.connections.len()
444 }
445
446 pub fn active_connection_count(&self) -> usize {
448 self.connections.iter().filter(|c| !c.is_finished()).count()
449 }
450
451 pub fn relay_urls(&self) -> Vec<&str> {
453 self.connections.iter().map(|c| c.relay_url.as_str()).collect()
454 }
455
456 pub fn active_relay_urls(&self) -> Vec<&str> {
458 self.connections
459 .iter()
460 .filter(|c| !c.is_finished())
461 .map(|c| c.relay_url.as_str())
462 .collect()
463 }
464}
465
466impl Drop for RelayPool {
467 fn drop(&mut self) {
468 for conn in &self.connections {
470 conn.request_shutdown();
471 }
472 for conn in &mut self.connections {
474 if let Some(handle) = conn.thread_handle.take() {
475 let _ = handle.join();
476 }
477 }
478 }
479}
480
481pub fn create_pool(ring_capacity: usize, cache_size: usize) -> (PoolConsumer, Producer<PoolMessage>) {
487 let (producer, consumer) = RingBuffer::new(Capacity::at_least(ring_capacity)).split();
488 (PoolConsumer::new(consumer, cache_size), producer)
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for relay threads to exit.
    const THREAD_EXIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Polls `cond` every few milliseconds until it holds or `timeout` elapses,
    /// and returns the final result.
    ///
    /// Used instead of fixed sleeps, so tests neither flake on slow machines
    /// nor wait longer than needed on fast ones.
    fn wait_until(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            std::thread::sleep(Duration::from_millis(5));
        }
    }

    /// Binds a local listener that completes TCP handshakes but never answers
    /// the WebSocket upgrade.
    ///
    /// A relay thread pointed at it stays blocked in the handshake (still
    /// "active") until `hang_up` closes the accepted stream.
    fn silent_relay() -> (TcpListener, String) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind local listener");
        let url = format!("ws://{}", listener.local_addr().expect("listener address"));
        (listener, url)
    }

    /// Accepts the pending relay connection on `listener` and closes it at once.
    /// This makes the relay thread's handshake fail, so the thread exits.
    fn hang_up(listener: &TcpListener) {
        let (stream, _) = listener.accept().expect("accept relay connection");
        drop(stream);
    }

    #[test]
    fn test_pool_creation() {
        let pool = RelayPool::new(1024, 10_000, 64, 8);
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_create_pool_helper() {
        let (_consumer, _producer) = create_pool(1024, 10_000);
    }

    #[test]
    fn test_pool_sender_clone_and_broadcast() {
        let (bc_producer, mut c1) =
            broadcast::RingBuffer::<String>::new(Capacity::exact(16), 4).split();
        let mut c2 = c1.clone();

        let sender = PoolSender {
            producer: bc_producer,
        };
        let sender2 = sender.clone();

        sender.send_raw("hello".to_string()).unwrap();
        sender2.send_raw("world".to_string()).unwrap();

        assert_eq!(c1.pop(), Some("hello".to_string()));
        assert_eq!(c1.pop(), Some("world".to_string()));
        assert_eq!(c2.pop(), Some("hello".to_string()));
        assert_eq!(c2.pop(), Some("world".to_string()));
    }

    #[test]
    fn test_pool_sender_via_relay_pool() {
        let pool = RelayPool::new(1024, 10_000, 64, 8);
        let sender = pool.sender();
        let sender2 = pool.sender();

        assert!(!sender.producer.is_full());
        assert!(!sender2.producer.is_full());
    }

    #[test]
    fn test_shutdown_flag_stops_thread() {
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = Arc::clone(&shutdown);
        let handle = std::thread::spawn(move || {
            while !shutdown_clone.load(Ordering::Relaxed) {
                std::thread::sleep(Duration::from_millis(1));
            }
        });
        assert!(!handle.is_finished());
        shutdown.store(true, Ordering::Relaxed);
        handle.join().unwrap();
    }

    #[test]
    fn test_cleanup_removes_dead_connections() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        // Nothing listens on port 1, so the relay thread fails and exits by itself.
        pool.add_relay("ws://127.0.0.1:1".to_string());
        assert_eq!(pool.connection_count(), 1);

        assert!(
            wait_until(THREAD_EXIT_TIMEOUT, || pool.active_connection_count() == 0),
            "relay thread did not exit"
        );

        pool.cleanup();
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_remove_relay() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        assert_eq!(pool.connection_count(), 1);

        assert!(pool.remove_relay("ws://127.0.0.1:1"));
        assert_eq!(pool.connection_count(), 0);

        assert!(!pool.remove_relay("ws://127.0.0.1:2"));
    }

    #[test]
    fn test_active_connection_count() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        // Silent relays keep both threads alive while they are added. Otherwise
        // the `cleanup` inside the second `add_relay` could prune the first.
        // The listeners are declared after `pool` so they drop (unblocking the
        // threads) before the pool joins them, even if an assertion panics.
        let (relay1, url1) = silent_relay();
        let (relay2, url2) = silent_relay();
        pool.add_relay(url1);
        pool.add_relay(url2);
        assert_eq!(pool.connection_count(), 2);
        assert_eq!(pool.active_connection_count(), 2);

        hang_up(&relay1);
        hang_up(&relay2);
        assert!(
            wait_until(THREAD_EXIT_TIMEOUT, || pool.active_connection_count() == 0),
            "relay threads did not exit"
        );

        // Finished connections stay tracked until cleanup runs.
        assert_eq!(pool.connection_count(), 2);

        pool.cleanup();
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_relay_urls() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        // Silent relays prevent `add_relay`'s cleanup from racing the assertions.
        let (relay1, url1) = silent_relay();
        let (relay2, url2) = silent_relay();
        pool.add_relay(url1.clone());
        pool.add_relay(url2.clone());

        let urls = pool.relay_urls();
        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&url1.as_str()));
        assert!(urls.contains(&url2.as_str()));

        // Let the blocked handshakes fail so dropping the pool can join them.
        hang_up(&relay1);
        hang_up(&relay2);
    }

    #[test]
    fn test_pool_drop_joins_threads() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());
        drop(pool);
    }

    #[test]
    fn test_add_after_remove_reuses_slots() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 2);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());

        pool.remove_relay("ws://127.0.0.1:1");
        assert_eq!(pool.connection_count(), 1);

        pool.add_relay("ws://127.0.0.1:3".to_string());
        assert!(pool.relay_urls().contains(&"ws://127.0.0.1:3"));
    }
}