use std::{
collections::{BTreeMap, btree_map::Entry},
sync::{Arc, RwLock},
};
use iroh_base::EndpointId;
use n0_future::{
boxed::BoxStream,
stream::{self, StreamExt},
time::SystemTime,
};
use super::{AddressLookup, EndpointData, EndpointInfo, Error, Item};
/// An [`AddressLookup`] that answers from endpoint information held in memory.
///
/// Entries are added, replaced, read and removed through the inherent methods. The store
/// sits behind an `Arc`, so clones of a `MemoryLookup` share (and observe) the same entries.
#[derive(Debug, Clone)]
pub struct MemoryLookup {
    /// Stored entries keyed by endpoint id; shared between clones.
    endpoints: Arc<RwLock<BTreeMap<EndpointId, StoredEndpointInfo>>>,
    /// Provenance tag attached to every item produced by `resolve`.
    provenance: &'static str,
}
impl Default for MemoryLookup {
    /// Builds an empty lookup tagged with [`MemoryLookup::PROVENANCE`].
    ///
    /// Equivalent to `MemoryLookup::with_provenance(MemoryLookup::PROVENANCE)`.
    fn default() -> Self {
        Self::with_provenance(Self::PROVENANCE)
    }
}
/// One entry in a [`MemoryLookup`]: the endpoint's data plus when it was last written.
#[derive(Debug)]
struct StoredEndpointInfo {
    /// Endpoint data, including its addresses and optional user data.
    data: EndpointData,
    /// Wall-clock time of the most recent `set_endpoint_info`/`add_endpoint_info` call
    /// that wrote this entry; reported by `resolve` as microseconds since the Unix epoch.
    last_updated: SystemTime,
}
impl MemoryLookup {
    /// Provenance tag used by lookups created via [`MemoryLookup::new`] or [`Default`].
    pub const PROVENANCE: &'static str = "memory_lookup";

    /// Creates an empty lookup whose resolved items carry [`Self::PROVENANCE`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty lookup whose resolved items carry the given `provenance` tag.
    pub fn with_provenance(provenance: &'static str) -> Self {
        let endpoints = Default::default();
        Self {
            endpoints,
            provenance,
        }
    }

    /// Creates a lookup (tagged with [`Self::PROVENANCE`]) pre-populated from `infos`.
    ///
    /// Each item is fed through [`Self::add_endpoint_info`] in iteration order, so repeated
    /// endpoint ids are merged into a single entry rather than replacing one another.
    pub fn from_endpoint_info(infos: impl IntoIterator<Item = impl Into<EndpointInfo>>) -> Self {
        let lookup = Self::default();
        infos
            .into_iter()
            .for_each(|info| lookup.add_endpoint_info(info));
        lookup
    }

    /// Stores `endpoint_info`, replacing any existing entry for the same endpoint id.
    ///
    /// The entry's timestamp is set to the current time. Returns the data that was
    /// previously stored for this endpoint, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set_endpoint_info(
        &self,
        endpoint_info: impl Into<EndpointInfo>,
    ) -> Option<EndpointData> {
        let now = SystemTime::now();
        let EndpointInfo { endpoint_id, data } = endpoint_info.into();
        let fresh = StoredEndpointInfo {
            data,
            last_updated: now,
        };
        self.endpoints
            .write()
            .expect("poisoned")
            .insert(endpoint_id, fresh)
            .map(|replaced| replaced.data)
    }

    /// Merges `endpoint_info` into the stored entry for its endpoint id.
    ///
    /// If no entry exists, one is created. Otherwise:
    /// - the new addresses are added to the existing ones;
    /// - the user data is set to the value carried by `endpoint_info`, which may be `None`;
    /// - the entry's timestamp is refreshed to the current time.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn add_endpoint_info(&self, endpoint_info: impl Into<EndpointInfo>) {
        let now = SystemTime::now();
        let EndpointInfo { endpoint_id, data } = endpoint_info.into();
        let mut endpoints = self.endpoints.write().expect("poisoned");
        match endpoints.entry(endpoint_id) {
            Entry::Vacant(slot) => {
                slot.insert(StoredEndpointInfo {
                    data,
                    last_updated: now,
                });
            }
            Entry::Occupied(mut slot) => {
                let stored = slot.get_mut();
                stored.data.add_addrs(data.addrs().cloned());
                stored.data.set_user_data(data.user_data().cloned());
                stored.last_updated = now;
            }
        }
    }

    /// Returns a copy of the stored information for `endpoint_id`, or `None` if absent.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_endpoint_info(&self, endpoint_id: EndpointId) -> Option<EndpointInfo> {
        let endpoints = self.endpoints.read().expect("poisoned");
        endpoints
            .get(&endpoint_id)
            .map(|stored| EndpointInfo::from_parts(endpoint_id, stored.data.clone()))
    }

    /// Removes the entry for `endpoint_id`, returning its information if one was stored.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn remove_endpoint_info(&self, endpoint_id: EndpointId) -> Option<EndpointInfo> {
        self.endpoints
            .write()
            .expect("poisoned")
            .remove(&endpoint_id)
            .map(|stored| EndpointInfo::from_parts(endpoint_id, stored.data))
    }
}
impl AddressLookup for MemoryLookup {
fn publish(&self, _data: &EndpointData) {}
fn resolve(&self, endpoint_id: EndpointId) -> Option<BoxStream<Result<super::Item, Error>>> {
let guard = self.endpoints.read().expect("poisoned");
let info = guard.get(&endpoint_id);
match info {
Some(endpoint_info) => {
let last_updated = endpoint_info
.last_updated
.duration_since(SystemTime::UNIX_EPOCH)
.expect("time drift")
.as_micros() as u64;
let item = Item::new(
EndpointInfo::from_parts(endpoint_id, endpoint_info.data.clone()),
self.provenance,
Some(last_updated),
);
Some(stream::iter(Some(Ok(item))).boxed())
}
None => None,
}
}
}
// Tests for `MemoryLookup`; compiled only when a crypto provider is configured.
#[cfg(all(test, with_crypto_provider))]
mod tests {
    use iroh_base::{EndpointAddr, SecretKey, TransportAddr};
    use n0_error::{Result, StackResultExt};
    use super::*;
    use crate::{Endpoint, endpoint::presets};

    /// Installs the lookup on an endpoint, then round-trips one entry through
    /// `add_endpoint_info`, `get_endpoint_info` and `remove_endpoint_info`.
    #[tokio::test]
    async fn test_basic() -> Result {
        let address_lookup = MemoryLookup::new();
        // The endpoint receives a clone; clones share the same underlying store.
        let _ep = Endpoint::builder(presets::Minimal)
            .address_lookup(address_lookup.clone())
            .bind()
            .await?;
        // Fixed all-zero key so the endpoint id is deterministic.
        let key = SecretKey::from_bytes(&[0u8; 32]);
        let addr = EndpointAddr::from_parts(
            key.public(),
            [TransportAddr::Relay("https://example.com".parse()?)],
        );
        let user_data = Some("foobar".parse().unwrap());
        let endpoint_info = EndpointInfo::from(addr.clone()).with_user_data(user_data.clone());
        address_lookup.add_endpoint_info(endpoint_info.clone());
        // Reading back must return exactly what was stored, user data included.
        let back = address_lookup
            .get_endpoint_info(key.public())
            .context("no addr")?;
        assert_eq!(back, endpoint_info);
        assert_eq!(back.user_data(), user_data.as_ref());
        assert_eq!(back.into_endpoint_addr(), addr);
        // Removal returns the stored info and leaves nothing behind.
        let removed = address_lookup
            .remove_endpoint_info(key.public())
            .context("nothing removed")?;
        assert_eq!(removed, endpoint_info);
        let res = address_lookup.get_endpoint_info(key.public());
        assert!(res.is_none());
        Ok(())
    }

    /// Checks that `resolve` tags items with the custom provenance passed to
    /// `with_provenance` and returns the stored relay URL.
    #[tokio::test]
    async fn test_provenance() -> Result {
        let address_lookup = MemoryLookup::with_provenance("foo");
        let key = SecretKey::from_bytes(&[0u8; 32]);
        let addr = EndpointAddr::from_parts(
            key.public(),
            [TransportAddr::Relay("https://example.com".parse()?)],
        );
        address_lookup.add_endpoint_info(addr);
        // `resolve` yields a single-item stream for a known endpoint.
        let mut stream = address_lookup.resolve(key.public()).unwrap();
        let item = stream.next().await.unwrap()?;
        assert_eq!(item.provenance(), "foo");
        assert_eq!(
            item.relay_urls().next(),
            Some(&("https://example.com".parse()?))
        );
        Ok(())
    }
}