use tokio_with_wasm::alias as tokio;
/// Errors returned by `Network::ensure_relay`.
#[derive(thiserror::Error, Debug)]
pub enum EnsureError {
/// The relay URL could not be normalized by `crate::normalize_url`.
#[error("URL normalization error")]
Normalize(#[from] url::ParseError),
/// `crate::Relay::connect` failed; the underlying error is included in the message.
#[error("failed to connect: {0}")]
Connect(#[from] crate::relay::ConnectError),
}
/// A pool of relay connections keyed by normalized relay URL.
///
/// The map lives behind an `Arc`, so every clone of a `Network` shares the same
/// set of relays.
#[derive(Debug)]
pub struct Network {
// Relays stored by `ensure_relay`, keyed by the normalized URL string.
relays: std::sync::Arc<tokio::sync::Mutex<std::collections::HashMap<String, crate::Relay>>>,
}
/// Outcome of publishing one event to one relay, as produced by `Network::publish_many`.
///
/// Derives `Debug` and `Clone` so results can be logged, asserted on and passed
/// around like the other public types in this module.
#[derive(Debug, Clone)]
pub struct PublishResult {
    /// `None` when the publish succeeded, otherwise a human-readable failure description.
    pub error: Option<String>,
    /// URL of the relay this result refers to.
    pub relay_url: String,
}
/// An item delivered on the channel returned by `Network::subscribe`.
#[derive(Debug)]
pub enum Occurrence {
/// An event received from a relay, paired with that relay's URL.
Event(Box<crate::Event>, url::Url),
/// Sent by `subscribe` when its shared EOSE countdown over the requested relays completes.
// NOTE(review): the all-caps name departs from Rust's `UpperCamelCase` variant
// convention (`Eose`); renaming would break callers, so it is left as is.
EOSE,
/// Sent by `subscribe` when the last per-relay task has finished.
Close,
}
impl Default for Network {
fn default() -> Self {
Self::new()
}
}
impl Network {
/// Creates a network with no relays connected yet.
pub fn new() -> Self {
    let empty = std::collections::HashMap::new();
    let relays = std::sync::Arc::new(tokio::sync::Mutex::new(empty));
    Self { relays }
}
/// Looks up an already-connected relay by URL without attempting to connect.
///
/// Returns `None` when `url` cannot be normalized or when no relay is stored
/// under the normalized URL.
pub async fn get_relay(&self, url: &str) -> Option<crate::Relay> {
    let key = match crate::normalize_url(url) {
        Ok(normalized) => normalized,
        Err(_) => return None,
    };
    let relays = self.relays.lock().await;
    relays.get(key.as_str()).cloned()
}
/// Returns the relay for `url`, connecting to it first if it is not in the pool yet.
///
/// # Errors
///
/// * `EnsureError::Normalize` when `url` cannot be normalized.
/// * `EnsureError::Connect` when `crate::Relay::connect` fails.
pub async fn ensure_relay<S>(&self, url: S) -> Result<crate::Relay, EnsureError>
where
    S: AsRef<str>,
{
    let key = crate::normalize_url(url.as_ref())?;

    // The guard is kept until the function returns, i.e. across the connect below.
    let mut relays = self.relays.lock().await;
    if let Some(existing) = relays.get(key.as_str()).cloned() {
        return Ok(existing);
    }

    // Watcher task: waits on the close notification and evicts the relay from the map.
    let (close_tx, close_rx) = tokio::sync::oneshot::channel::<String>();
    let watched_url = key.clone();
    let shared_relays = self.relays.clone();
    tokio::spawn(async move {
        match close_rx.await {
            Ok(reason) => {
                log::info!(
                    "[{}] relay connection closed: {}",
                    watched_url.as_str(),
                    reason
                );
                shared_relays.lock().await.remove(watched_url.as_str());
            }
            Err(err) => {
                log::info!(
                    "got an error from the handle_close oneshot for {}: {}",
                    watched_url.as_str(),
                    err
                );
            }
        }
    });

    // `?` converts the connect error into `EnsureError::Connect` via its `#[from]` impl.
    let relay = crate::Relay::connect(key.clone(), Some(close_tx)).await?;
    relays.insert(key.to_string(), relay.clone());
    Ok(relay)
}
/// Publishes `event` to every relay in `urls` concurrently.
///
/// One task is spawned per valid URL; each connects through `ensure_relay` if
/// needed, publishes the event and sends exactly one `PublishResult` on the
/// returned channel. Results arrive in completion order, and the channel closes
/// once every task has reported.
///
/// A URL that cannot be normalized produces an immediate failed `PublishResult`
/// carrying the URL as given, so the caller receives one result per input URL.
///
/// Takes `&self` because the pool is only cloned, never mutated.
pub async fn publish_many<S, T>(
    &self,
    urls: T,
    event: crate::Event,
) -> tokio::sync::mpsc::UnboundedReceiver<PublishResult>
where
    S: AsRef<str>,
    T: AsRef<[S]>,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    for raw_url in urls.as_ref() {
        let raw_url = raw_url.as_ref();
        let url = match crate::normalize_url(raw_url) {
            Ok(url) => url,
            Err(err) => {
                // Report the bad URL instead of dropping it silently. The send
                // cannot fail here because `rx` is still held by this function.
                let _ = tx.send(PublishResult {
                    error: Some(format!("URL normalization error: {}", err)),
                    relay_url: raw_url.to_owned(),
                });
                continue;
            }
        };

        let tx = tx.clone();
        let event = event.clone();
        let pool = self.clone();
        tokio::spawn(async move {
            let error = match pool.ensure_relay(&url).await {
                Ok(relay) => relay
                    .publish(event)
                    .await
                    .err()
                    .map(|err| err.to_string()),
                Err(err) => Some(err.to_string()),
            };
            // The receiver may already be gone; there is nobody left to tell.
            let _ = tx.send(PublishResult {
                error,
                relay_url: url.to_string(),
            });
        });
    }
    rx
}
/// Collects the events matching `filter` from all `urls`, newest first.
///
/// Subscribes through `subscribe`, gathers events until the first
/// `Occurrence::EOSE` or `Occurrence::Close` (or until the channel ends), and
/// returns them sorted by `created_at` in descending order.
pub async fn query<S, T>(
    &self,
    urls: T,
    filter: crate::Filter,
    subscription_options: crate::SubscriptionOptions,
) -> Vec<crate::Event>
where
    S: AsRef<str>,
    T: AsRef<[S]>,
{
    // Upper bound on the up-front allocation; the vector still grows on demand.
    const MAX_PREALLOCATED_EVENTS: usize = 4096;

    // The capacity is only a hint. Saturate and cap it so a very large `limit`
    // can neither overflow the multiplication nor force a huge allocation.
    let relay_count = urls.as_ref().len();
    let capacity_hint = filter.limit.unwrap_or(500).saturating_mul(relay_count) / 2;
    let mut events = Vec::with_capacity(capacity_hint.min(MAX_PREALLOCATED_EVENTS));

    let mut occurrences = self.subscribe(urls, filter, subscription_options).await;
    while let Some(occ) = occurrences.recv().await {
        match occ {
            Occurrence::Event(event, _) => events.push(*event),
            // Only the initial batch is wanted; stop at the first control message.
            Occurrence::EOSE | Occurrence::Close => break,
        }
    }

    // Newest first; the sort is stable, so equal timestamps keep arrival order.
    glidesort::sort_by_key(&mut events, |event| {
        std::cmp::Reverse(event.created_at.0)
    });
    events
}
/// Subscribes to `filter` on every relay in `urls` and merges the results into one channel.
///
/// One task is spawned per URL. Each task obtains its relay through
/// `ensure_relay`, opens a subscription and forwards every event as
/// `Occurrence::Event` together with the relay's URL. All subscriptions receive
/// the same shared `skip_ids` set, which replaces any `skip_ids` supplied in
/// `subscription_options` (presumably so an event already delivered by one relay
/// is skipped by the others — confirm in `Relay::subscribe`).
///
/// * `Occurrence::EOSE` is sent once, after every relay has either reported
///   EOSE, ended its subscription, or failed to connect.
/// * `Occurrence::Close` is sent once, after every per-relay task has finished.
///
/// The returned channel has a capacity of 256. Dropping the receiver stops each
/// task the next time it tries to send.
pub async fn subscribe<S, T>(
    &self,
    urls: T,
    filter: crate::Filter,
    subscription_options: crate::SubscriptionOptions,
) -> tokio::sync::mpsc::Receiver<Occurrence>
where
    S: AsRef<str>,
    T: AsRef<[S]>,
{
    let relay_count = urls.as_ref().len();
    let (tx, rx) = tokio::sync::mpsc::channel(256);
    let skip_ids = std::sync::Arc::new(dashmap::DashSet::new());
    // Number of relays that have not yet reached EOSE (or given up before it).
    let eose_counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(relay_count));
    // Number of per-relay tasks still running.
    let closed_counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(relay_count));
    // Extra guard so the aggregate EOSE is emitted at most once.
    let eosed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));

    for url in urls.as_ref() {
        let filter = filter.clone();
        let pool = self.clone();
        let opts = crate::SubscriptionOptions {
            skip_ids: Some(skip_ids.clone()),
            ..subscription_options.clone()
        };
        let tx = tx.clone();
        let eose_counter = eose_counter.clone();
        let closed_counter = closed_counter.clone();
        let eosed = eosed.clone();
        let url_owned = url.as_ref().to_owned();

        tokio::spawn(async move {
            // Whether this relay has already been subtracted from `eose_counter`.
            // Without it, a relay that sends EOSE and later closes would be counted
            // twice, emitting the aggregate EOSE early or wrapping the counter.
            let mut eose_counted = false;

            if let Ok(relay) = pool.ensure_relay(&url_owned).await {
                let mut sub = relay.subscribe(filter, opts).await;
                while let Some(occ) = sub.recv().await {
                    match occ {
                        crate::relay::Occurrence::Event(event) => {
                            if tx
                                .send(Occurrence::Event(Box::new(event), relay.url.clone()))
                                .await
                                .is_err()
                            {
                                // Receiver dropped: nobody is listening any more.
                                return;
                            }
                        }
                        crate::relay::Occurrence::Eose => {
                            if !eose_counted {
                                eose_counted = true;
                                if eose_counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
                                    == 1
                                    && !eosed.swap(true, std::sync::atomic::Ordering::SeqCst)
                                    && tx.send(Occurrence::EOSE).await.is_err()
                                {
                                    return;
                                }
                            }
                        }
                        crate::relay::Occurrence::Close(_) => break,
                    }
                }
            }

            // A relay that failed to connect, or whose subscription ended before
            // EOSE, must still release its slot so the aggregate EOSE can fire.
            if !eose_counted
                && eose_counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) == 1
                && !eosed.swap(true, std::sync::atomic::Ordering::SeqCst)
                && tx.send(Occurrence::EOSE).await.is_err()
            {
                return;
            }

            // The last task to finish announces that the whole subscription is closed.
            if closed_counter.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) == 1 {
                let _ = tx.send(Occurrence::Close).await;
            }
        });
    }

    // Drop the original sender so the channel closes once all tasks are done.
    drop(tx);
    rx
}
}
impl Clone for Network {
    /// Returns a handle that shares this network's relay map; only the `Arc` is cloned.
    fn clone(&self) -> Self {
        let relays = std::sync::Arc::clone(&self.relays);
        Self { relays }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;
    use std::collections::HashSet;

    /// Queries several relays at once and checks the merged result has a
    /// plausible size and no duplicate event ids.
    #[tokio::test]
    async fn test_pool_subscribe_multiple() {
        let network = Network::new();
        let relay_urls: Vec<String> = [
            "wss://nos.lol",
            "wss://nostr.wine",
            "wss://nostr.mom",
            "wss://relay.damus.io",
            "wss://relay.primal.net",
        ]
        .iter()
        .map(|url| url.to_string())
        .collect();
        let filter = Filter {
            kinds: Some(vec![Kind(1), Kind(1111)]),
            limit: Some(5),
            ..Default::default()
        };
        let options = SubscriptionOptions {
            label: Some("test".to_string()),
            ..Default::default()
        };

        let events = network.query(relay_urls, filter, options).await;

        assert!(events.len() > 10, "was {}", events.len());
        assert!(events.len() < 25, "was {}", events.len());

        // Every id must be new to the set, i.e. no event appears twice.
        let mut seen_ids = HashSet::new();
        for event in &events {
            assert!(seen_ids.insert(event.id));
        }
    }

    /// Ensuring the same URL twice must return handles to the same relay.
    #[tokio::test]
    async fn test_pool_ensure() {
        let network = Network::new();
        let first = network.ensure_relay("wss://nos.lol").await.unwrap();
        let second = network.ensure_relay("wss://nos.lol").await.unwrap();
        assert!(std::ptr::eq(
            first.sub_sender_map.as_ref() as *const _,
            second.sub_sender_map.as_ref() as *const _
        ));
    }
}