use nostro2::NostrRelayEvent;
use nostro2_cache::Cache;
use quetzalcoatl::broadcast;
use quetzalcoatl::capacity::Capacity;
use quetzalcoatl::mpsc::{Consumer, Producer, RingBuffer};
use std::net::TcpStream;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tungstenite::stream::MaybeTlsStream;
use tungstenite::{connect, Message, WebSocket};
/// Message delivered from relay connection threads to the pool's consumer.
#[derive(Debug, Clone)]
pub enum PoolMessage {
/// A text frame received from `relay_url` that parsed as a [`NostrRelayEvent`].
RelayEvent {
relay_url: String,
event: NostrRelayEvent,
},
/// The connection thread for `relay_url` has stopped.
///
/// `error` is `None` when the connection loop ended without an error (shutdown
/// was requested or the relay sent a Close frame) and carries the error's text
/// when connecting, subscribing, or socket I/O failed.
ConnectionClosed {
relay_url: String,
error: Option<String>,
},
}
/// Cloneable handle for sending client messages to the relay connections of a pool.
#[derive(Clone)]
pub struct PoolSender {
// Broadcast ring producer. Each relay thread holds its own consumer clone and
// writes everything it pops to its socket as a text frame.
producer: broadcast::Producer<String>,
}
impl PoolSender {
    /// Serializes `msg` as a Nostr client message and broadcasts the JSON to
    /// the relay connections.
    ///
    /// # Errors
    /// Returns the serializer's error text if encoding fails. Otherwise returns
    /// whatever [`send_raw`](Self::send_raw) returns.
    pub fn send<T: Into<nostro2::NostrClientEvent>>(&self, msg: T) -> Result<(), String> {
        let event: nostro2::NostrClientEvent = msg.into();
        match serde_json::to_string(&event) {
            Ok(encoded) => self.send_raw(encoded),
            Err(err) => Err(err.to_string()),
        }
    }

    /// Broadcasts an already-serialized JSON message without validating it.
    ///
    /// # Errors
    /// Propagates the broadcast ring's push result unchanged. NOTE(review): the
    /// `Err` payload appears to be the rejected `json` handed back by the ring
    /// (e.g. when it is full), not an error description. Confirm against
    /// `quetzalcoatl::broadcast::Producer::push`.
    pub fn send_raw(&self, json: String) -> Result<(), String> {
        self.producer.push(json)
    }
}
/// Handle to the background thread that owns one relay's WebSocket.
///
/// Dropping the handle raises the shutdown flag and joins the thread.
pub struct RelayConnection {
relay_url: String,
// `None` once the thread has been joined.
thread_handle: Option<std::thread::JoinHandle<()>>,
// Stop flag shared with the thread's connection loop.
shutdown: Arc<AtomicBool>,
}
impl RelayConnection {
pub fn spawn(
relay_url: String,
mut producer: Producer<PoolMessage>,
outbound: broadcast::Consumer<String>,
shutdown: Arc<AtomicBool>,
) -> Self {
let url = relay_url.clone();
let shutdown_clone = Arc::clone(&shutdown);
let thread_handle = std::thread::spawn(move || {
match Self::run_connection(&url, &mut producer, outbound, &shutdown_clone) {
Ok(()) => {
let _ = producer.push(PoolMessage::ConnectionClosed {
relay_url: url.clone(),
error: None,
});
}
Err(e) => {
let _ = producer.push(PoolMessage::ConnectionClosed {
relay_url: url.clone(),
error: Some(e.to_string()),
});
}
}
});
Self {
relay_url,
thread_handle: Some(thread_handle),
shutdown,
}
}
pub fn is_finished(&self) -> bool {
self.thread_handle
.as_ref()
.is_some_and(|h| h.is_finished())
}
pub fn request_shutdown(&self) {
self.shutdown.store(true, Ordering::Relaxed);
}
fn shutdown_and_join(&mut self) {
self.shutdown.store(true, Ordering::Relaxed);
if let Some(handle) = self.thread_handle.take() {
let _ = handle.join();
}
}
fn run_connection(
url: &str,
producer: &mut Producer<PoolMessage>,
mut outbound: broadcast::Consumer<String>,
shutdown: &AtomicBool,
) -> Result<(), Box<dyn std::error::Error>> {
let _ = rustls::crypto::ring::default_provider().install_default();
let (mut socket, _response) = connect(url)?;
let subscription = nostro2::NostrSubscription {
kinds: vec![1].into(),
limit: Some(1000),
..Default::default()
};
let client_event: nostro2::NostrClientEvent = subscription.into();
let subscription_json = serde_json::to_string(&client_event)?;
socket.send(Message::Text(subscription_json.into()))?;
set_nonblocking(&socket, true)?;
loop {
if shutdown.load(Ordering::Relaxed) {
let _ = socket.send(Message::Close(None));
break;
}
let mut had_work = false;
match socket.read() {
Ok(Message::Text(text)) => {
if let Ok(event) = text.parse::<NostrRelayEvent>() {
let mut pool_msg = PoolMessage::RelayEvent {
relay_url: url.to_string(),
event,
};
loop {
match producer.push(pool_msg) {
Ok(()) => break,
Err(returned) => {
pool_msg = returned;
std::hint::spin_loop();
}
}
}
}
had_work = true;
}
Ok(Message::Close(_)) => break,
Ok(Message::Ping(data)) => {
let _ = socket.send(Message::Pong(data));
had_work = true;
}
Ok(_) => {
had_work = true;
}
Err(tungstenite::Error::Io(ref e))
if e.kind() == std::io::ErrorKind::WouldBlock =>
{
}
Err(e) => return Err(e.into()),
}
while let Some(json) = outbound.pop() {
match socket.send(Message::Text(json.into())) {
Ok(()) => {
had_work = true;
}
Err(tungstenite::Error::Io(ref e))
if e.kind() == std::io::ErrorKind::WouldBlock =>
{
had_work = true;
break;
}
Err(e) => return Err(e.into()),
}
}
if !had_work {
std::thread::sleep(std::time::Duration::from_millis(1));
}
}
Ok(())
}
pub fn relay_url(&self) -> &str {
&self.relay_url
}
}
impl Drop for RelayConnection {
fn drop(&mut self) {
self.shutdown.store(true, Ordering::Relaxed);
if let Some(handle) = self.thread_handle.take() {
let _ = handle.join();
}
}
}
/// Switches the TCP socket underneath `socket` into or out of non-blocking mode.
///
/// Plain and rustls-wrapped streams are handled. Any other stream kind is left
/// untouched and reported as success.
fn set_nonblocking(
    socket: &WebSocket<MaybeTlsStream<TcpStream>>,
    nonblocking: bool,
) -> std::io::Result<()> {
    let tcp: &TcpStream = match socket.get_ref() {
        MaybeTlsStream::Plain(stream) => stream,
        MaybeTlsStream::Rustls(tls_stream) => tls_stream.get_ref(),
        _ => return Ok(()),
    };
    tcp.set_nonblocking(nonblocking)
}
/// Receiving side of a pool: pops messages from the shared ring and skips
/// notes whose id has already been delivered.
pub struct PoolConsumer {
consumer: Consumer<PoolMessage>,
// Ids of notes already returned, used to drop duplicates. Sized by
// `cache_size` in `new`; eviction behaviour depends on `nostro2_cache::Cache`
// (TODO confirm).
dedup_cache: Cache,
}
impl PoolConsumer {
pub fn new(consumer: Consumer<PoolMessage>, cache_size: usize) -> Self {
Self {
consumer,
dedup_cache: Cache::new(cache_size),
}
}
pub fn try_recv(&mut self) -> Option<PoolMessage> {
loop {
match self.consumer.pop()? {
PoolMessage::RelayEvent {
relay_url,
event: NostrRelayEvent::NewNote(tag, sub_id, note),
} => {
if let Some(ref event_id) = note.id {
if self.dedup_cache.insert(event_id.clone()) {
return Some(PoolMessage::RelayEvent {
relay_url,
event: NostrRelayEvent::NewNote(tag, sub_id, note),
});
}
continue;
}
return Some(PoolMessage::RelayEvent {
relay_url,
event: NostrRelayEvent::NewNote(tag, sub_id, note),
});
}
other => {
return Some(other);
}
}
}
}
pub fn recv(&mut self) -> PoolMessage {
loop {
if let Some(msg) = self.try_recv() {
return msg;
}
std::hint::spin_loop();
}
}
}
/// A set of relay connections that share one inbound ring (relay -> pool)
/// and one outbound broadcast ring (pool -> relays).
pub struct RelayPool {
// Connections added via `add_relay`. Finished ones stay listed until `cleanup`.
connections: Vec<RelayConnection>,
// Deduplicating reader of the inbound ring.
consumer: PoolConsumer,
// Cloned out by `sender()` for outbound messages.
sender: PoolSender,
// Template consumer, cloned for each new relay thread and never popped itself.
// NOTE(review): if the broadcast ring waits for its slowest consumer, this idle
// template could make pushes fail once the ring fills. Confirm against
// quetzalcoatl's broadcast semantics.
broadcast_consumer: broadcast::Consumer<String>,
// Template producer, cloned for each new relay thread.
mpsc_producer: Producer<PoolMessage>,
}
impl RelayPool {
pub fn new(
ring_capacity: usize,
cache_size: usize,
broadcast_capacity: usize,
max_relays: usize,
) -> Self {
let (mpsc_producer, mpsc_consumer) =
RingBuffer::new(Capacity::at_least(ring_capacity)).split();
let (bc_producer, bc_consumer) =
broadcast::RingBuffer::new(Capacity::at_least(broadcast_capacity), max_relays + 1)
.split();
Self {
connections: Vec::new(),
consumer: PoolConsumer::new(mpsc_consumer, cache_size),
sender: PoolSender {
producer: bc_producer,
},
broadcast_consumer: bc_consumer,
mpsc_producer,
}
}
pub fn add_relay(&mut self, relay_url: String) {
self.cleanup();
let shutdown = Arc::new(AtomicBool::new(false));
let bc_consumer = self.broadcast_consumer.clone();
let mpsc_producer = self.mpsc_producer.clone();
let connection =
RelayConnection::spawn(relay_url, mpsc_producer, bc_consumer, shutdown);
self.connections.push(connection);
}
pub fn remove_relay(&mut self, relay_url: &str) -> bool {
if let Some(pos) = self
.connections
.iter()
.position(|c| c.relay_url == relay_url)
{
let mut conn = self.connections.swap_remove(pos);
conn.shutdown_and_join();
true
} else {
false
}
}
pub fn cleanup(&mut self) {
self.connections.retain_mut(|conn| {
if conn.is_finished() {
if let Some(handle) = conn.thread_handle.take() {
let _ = handle.join();
}
false
} else {
true
}
});
}
pub fn sender(&self) -> PoolSender {
self.sender.clone()
}
pub fn recv(&mut self) -> PoolMessage {
self.consumer.recv()
}
pub fn try_recv(&mut self) -> Option<PoolMessage> {
self.consumer.try_recv()
}
pub fn connection_count(&self) -> usize {
self.connections.len()
}
pub fn active_connection_count(&self) -> usize {
self.connections.iter().filter(|c| !c.is_finished()).count()
}
pub fn relay_urls(&self) -> Vec<&str> {
self.connections.iter().map(|c| c.relay_url.as_str()).collect()
}
pub fn active_relay_urls(&self) -> Vec<&str> {
self.connections
.iter()
.filter(|c| !c.is_finished())
.map(|c| c.relay_url.as_str())
.collect()
}
}
impl Drop for RelayPool {
fn drop(&mut self) {
for conn in &self.connections {
conn.request_shutdown();
}
for conn in &mut self.connections {
if let Some(handle) = conn.thread_handle.take() {
let _ = handle.join();
}
}
}
}
/// Builds a standalone inbound ring of at least `ring_capacity` slots.
///
/// Returns its deduplicating reader (cache of `cache_size` ids) together with
/// a producer that callers can clone and feed themselves.
pub fn create_pool(ring_capacity: usize, cache_size: usize) -> (PoolConsumer, Producer<PoolMessage>) {
    let ring = RingBuffer::new(Capacity::at_least(ring_capacity));
    let (producer, raw_consumer) = ring.split();
    let deduplicating = PoolConsumer::new(raw_consumer, cache_size);
    (deduplicating, producer)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for relay threads to notice a failure.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Polls `cond` every few milliseconds until it holds or `timeout` elapses.
    ///
    /// Returns whether the condition held. This replaces fixed sleeps, which
    /// were always slow and could still be too short on a loaded machine.
    fn wait_until(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        loop {
            if cond() {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            std::thread::sleep(Duration::from_millis(5));
        }
    }

    /// Binds a local listener that never answers the WebSocket upgrade and
    /// returns it with its `ws://` base URL.
    ///
    /// The kernel completes the TCP handshake from the listen backlog, so a
    /// relay thread pointed here stays alive inside `connect`. Dropping the
    /// listener aborts the queued connections (TODO confirm on non-Linux
    /// hosts) and lets the thread exit.
    ///
    /// Tests that need several connections to coexist use this instead of a
    /// refused port. `add_relay` runs `cleanup` first, so an already-refused
    /// connection can be reaped before the next assertion.
    ///
    /// Declare the listener AFTER the pool so it is dropped first, even on panic.
    fn silent_relay() -> (TcpListener, String) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind local listener");
        let addr = listener.local_addr().expect("listener address");
        (listener, format!("ws://{addr}"))
    }

    #[test]
    fn test_pool_creation() {
        let pool = RelayPool::new(1024, 10_000, 64, 8);
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_create_pool_helper() {
        let (_consumer, _producer) = create_pool(1024, 10_000);
    }

    #[test]
    fn test_pool_sender_clone_and_broadcast() {
        let (bc_producer, mut c1) =
            broadcast::RingBuffer::<String>::new(Capacity::exact(16), 4).split();
        let mut c2 = c1.clone();
        let sender = PoolSender {
            producer: bc_producer,
        };
        let sender2 = sender.clone();
        sender.send_raw("hello".to_string()).unwrap();
        sender2.send_raw("world".to_string()).unwrap();
        assert_eq!(c1.pop(), Some("hello".to_string()));
        assert_eq!(c1.pop(), Some("world".to_string()));
        assert_eq!(c2.pop(), Some("hello".to_string()));
        assert_eq!(c2.pop(), Some("world".to_string()));
    }

    #[test]
    fn test_pool_sender_via_relay_pool() {
        let pool = RelayPool::new(1024, 10_000, 64, 8);
        let sender = pool.sender();
        let sender2 = pool.sender();
        assert!(!sender.producer.is_full());
        assert!(!sender2.producer.is_full());
    }

    #[test]
    fn test_shutdown_flag_stops_thread() {
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = Arc::clone(&shutdown);
        let handle = std::thread::spawn(move || {
            while !shutdown_clone.load(Ordering::Relaxed) {
                std::thread::sleep(std::time::Duration::from_millis(1));
            }
        });
        assert!(!handle.is_finished());
        shutdown.store(true, Ordering::Relaxed);
        handle.join().unwrap();
    }

    #[test]
    fn test_cleanup_removes_dead_connections() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        assert_eq!(pool.connection_count(), 1);
        assert!(wait_until(WAIT_TIMEOUT, || pool.active_connection_count() == 0));
        pool.cleanup();
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_remove_relay() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        assert_eq!(pool.connection_count(), 1);
        assert!(pool.remove_relay("ws://127.0.0.1:1"));
        assert_eq!(pool.connection_count(), 0);
        assert!(!pool.remove_relay("ws://127.0.0.1:2"));
    }

    #[test]
    fn test_active_connection_count() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        let (listener, base) = silent_relay();
        pool.add_relay(format!("{base}/1"));
        pool.add_relay(format!("{base}/2"));
        assert_eq!(pool.connection_count(), 2);
        // Closing the listener makes both pending handshakes fail.
        drop(listener);
        assert!(wait_until(WAIT_TIMEOUT, || pool.active_connection_count() == 0));
        // Finished connections stay listed until `cleanup` reaps them.
        assert_eq!(pool.connection_count(), 2);
        pool.cleanup();
        assert_eq!(pool.connection_count(), 0);
    }

    #[test]
    fn test_relay_urls() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        let (_listener, base) = silent_relay();
        let (url1, url2) = (format!("{base}/1"), format!("{base}/2"));
        pool.add_relay(url1.clone());
        pool.add_relay(url2.clone());
        let urls = pool.relay_urls();
        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&url1.as_str()));
        assert!(urls.contains(&url2.as_str()));
    }

    #[test]
    fn test_pool_drop_joins_threads() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 8);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());
        drop(pool);
    }

    #[test]
    fn test_add_after_remove_reuses_slots() {
        let mut pool = RelayPool::new(1024, 10_000, 64, 2);
        pool.add_relay("ws://127.0.0.1:1".to_string());
        pool.add_relay("ws://127.0.0.1:2".to_string());
        pool.remove_relay("ws://127.0.0.1:1");
        assert_eq!(pool.connection_count(), 1);
        pool.add_relay("ws://127.0.0.1:3".to_string());
        assert!(pool.relay_urls().contains(&"ws://127.0.0.1:3"));
    }
}