use std::{
any::Any,
future::{Future, IntoFuture},
net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6},
pin::Pin,
sync::Arc,
task::Poll,
time::Duration,
};
use anyhow::{anyhow, bail, Context, Result};
use derive_more::Debug;
use futures_lite::{Stream, StreamExt};
use pin_project::pin_project;
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use tracing::{debug, instrument, trace, warn};
use url::Url;
use crate::{
discovery::{
dns::DnsDiscovery, pkarr::PkarrPublisher, ConcurrentDiscovery, Discovery, DiscoveryTask,
},
dns::{default_resolver, DnsResolver},
key::{PublicKey, SecretKey},
magicsock::{self, Handle, QuicMappedAddr},
relay::{force_staging_infra, RelayMode, RelayUrl},
tls, NodeId,
};
mod rtt_actor;
pub use bytes::Bytes;
pub use iroh_base::node_addr::{AddrInfo, NodeAddr};
pub use quinn::{
AcceptBi, AcceptUni, AckFrequencyConfig, ApplicationClose, Chunk, ClosedStream, Connection,
ConnectionClose, ConnectionError, ConnectionStats, MtuDiscoveryConfig, OpenBi, OpenUni,
ReadDatagram, ReadError, ReadExactError, ReadToEndError, RecvStream, ResetError, RetryError,
SendDatagramError, SendStream, ServerConfig, StoppedError, StreamId, TransportConfig, VarInt,
WeakConnectionHandle, WriteError, ZeroRttAccepted,
};
pub use quinn_proto::{
congestion::{Controller, ControllerFactory},
crypto::{
AeadKey, CryptoError, ExportKeyingMaterialError, HandshakeTokenKey,
ServerConfig as CryptoServerConfig, UnsupportedVersion,
},
FrameStats, PathStats, TransportError, TransportErrorCode, UdpStats, Written,
};
use self::rtt_actor::RttMessage;
pub use super::magicsock::{
ConnectionType, ConnectionTypeStream, ControlMsg, DirectAddr, DirectAddrInfo, DirectAddrType,
DirectAddrsStream, RemoteInfo, Source,
};
/// Delay before starting discovery for a node whose address is already known.
///
/// When connecting to a node for which the magicsock already has a send address and the
/// caller supplied non-empty addressing information, discovery is only started after this
/// delay (see `Endpoint::get_mapping_addr_and_maybe_start_discovery`).
const DISCOVERY_WAIT_PERIOD: Duration = Duration::from_millis(500);
/// Deferred constructor for a discovery service.
///
/// It is invoked in `Builder::bind` with the endpoint's final [`SecretKey`] (which may only be
/// generated at that point) and may return `None` to contribute no service.
type DiscoveryBuilder = Box<dyn FnOnce(&SecretKey) -> Option<Box<dyn Discovery>> + Send + Sync>;
/// Builder for an [`Endpoint`].
///
/// Obtained from [`Endpoint::builder`] (or [`Default`]), configured through the setter methods
/// and finished with [`Builder::bind`].
#[derive(Debug)]
pub struct Builder {
    /// Identity of the endpoint; a fresh key is generated in `bind` when this is `None`.
    secret_key: Option<SecretKey>,
    /// Which relay servers the endpoint uses.
    relay_mode: RelayMode,
    /// ALPN protocols advertised for incoming connections.
    alpn_protocols: Vec<Vec<u8>>,
    /// QUIC transport config; `quinn::TransportConfig::default()` is used when `None`.
    transport_config: Option<quinn::TransportConfig>,
    /// Forwarded to the TLS config builders; presumably enables TLS key logging.
    keylog: bool,
    /// Deferred discovery constructors, invoked with the secret key in `bind`.
    #[debug(skip)]
    discovery: Vec<DiscoveryBuilder>,
    /// Optional proxy URL handed to the magicsock.
    proxy_url: Option<Url>,
    /// Initially known nodes handed to the magicsock.
    node_map: Option<Vec<NodeAddr>>,
    /// DNS resolver; the crate's default resolver is used when `None`.
    dns_resolver: Option<DnsResolver>,
    /// Skips TLS certificate verification for relay servers (tests only).
    #[cfg(any(test, feature = "test-utils"))]
    #[cfg_attr(iroh_docsrs, doc(cfg(any(test, feature = "test-utils"))))]
    insecure_skip_relay_cert_verify: bool,
    /// Explicit IPv4 address to bind to, if any.
    addr_v4: Option<SocketAddrV4>,
    /// Explicit IPv6 address to bind to, if any.
    addr_v6: Option<SocketAddrV6>,
}
impl Default for Builder {
fn default() -> Self {
Self {
secret_key: Default::default(),
relay_mode: default_relay_mode(),
alpn_protocols: Default::default(),
transport_config: Default::default(),
keylog: Default::default(),
discovery: Default::default(),
proxy_url: None,
node_map: None,
dns_resolver: None,
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_relay_cert_verify: false,
addr_v4: None,
addr_v6: None,
}
}
}
impl Builder {
    /// Binds the endpoint, consuming the builder.
    ///
    /// Unset options are resolved here:
    /// - a fresh [`SecretKey`] is generated if none was configured;
    /// - the transport config falls back to `quinn::TransportConfig::default()`;
    /// - the DNS resolver falls back to [`default_resolver`].
    ///
    /// Every registered discovery builder is invoked with the final secret key. A single
    /// resulting service is used as-is; several are combined in a [`ConcurrentDiscovery`].
    ///
    /// # Errors
    ///
    /// Returns an error if the magicsock, the server TLS configuration or the QUIC endpoint
    /// cannot be created.
    pub async fn bind(self) -> Result<Endpoint> {
        let relay_map = self.relay_mode.relay_map();
        let secret_key = self.secret_key.unwrap_or_else(SecretKey::generate);
        let static_config = StaticConfig {
            transport_config: Arc::new(self.transport_config.unwrap_or_default()),
            keylog: self.keylog,
            secret_key: secret_key.clone(),
        };
        let dns_resolver = self
            .dns_resolver
            .unwrap_or_else(|| default_resolver().clone());

        // Discovery services can only be built now, once the secret key is final.
        let mut discovery = self
            .discovery
            .into_iter()
            .filter_map(|f| f(&secret_key))
            .collect::<Vec<_>>();
        let discovery: Option<Box<dyn Discovery>> = match discovery.len() {
            0 => None,
            // A single service is used directly, without the concurrent wrapper.
            1 => discovery.pop(),
            _ => Some(Box::new(ConcurrentDiscovery::from_services(discovery))),
        };

        let msock_opts = magicsock::Options {
            addr_v4: self.addr_v4,
            addr_v6: self.addr_v6,
            secret_key,
            relay_map,
            node_map: self.node_map,
            discovery,
            proxy_url: self.proxy_url,
            dns_resolver,
            #[cfg(any(test, feature = "test-utils"))]
            insecure_skip_relay_cert_verify: self.insecure_skip_relay_cert_verify,
        };
        Endpoint::bind(static_config, msock_opts, self.alpn_protocols).await
    }

    /// Sets the IPv4 socket address the endpoint binds to.
    ///
    /// When not set, the choice is left to the magicsock.
    pub fn bind_addr_v4(mut self, addr: SocketAddrV4) -> Self {
        self.addr_v4 = Some(addr);
        self
    }

    /// Sets the IPv6 socket address the endpoint binds to.
    ///
    /// When not set, the choice is left to the magicsock.
    pub fn bind_addr_v6(mut self, addr: SocketAddrV6) -> Self {
        self.addr_v6 = Some(addr);
        self
    }

    /// Sets the secret key that identifies this endpoint.
    ///
    /// If never called, a new key is generated in [`Builder::bind`].
    pub fn secret_key(mut self, secret_key: SecretKey) -> Self {
        self.secret_key = Some(secret_key);
        self
    }

    /// Sets the ALPN protocols accepted for incoming connections.
    ///
    /// These can be changed after binding with [`Endpoint::set_alpns`].
    pub fn alpns(mut self, alpn_protocols: Vec<Vec<u8>>) -> Self {
        self.alpn_protocols = alpn_protocols;
        self
    }

    /// Sets which relay servers the endpoint uses.
    pub fn relay_mode(mut self, relay_mode: RelayMode) -> Self {
        self.relay_mode = relay_mode;
        self
    }

    /// Removes all discovery services configured so far.
    pub fn clear_discovery(mut self) -> Self {
        self.discovery.clear();
        self
    }

    /// Replaces all configured discovery services with `discovery`.
    pub fn discovery(mut self, discovery: Box<dyn Discovery>) -> Self {
        self.discovery.clear();
        self.discovery.push(Box::new(move |_| Some(discovery)));
        self
    }

    /// Adds a discovery service that is built at bind time from the endpoint's secret key.
    ///
    /// If `discovery` returns `None`, no service is added.
    pub fn add_discovery<F, D>(mut self, discovery: F) -> Self
    where
        F: FnOnce(&SecretKey) -> Option<D> + Send + Sync + 'static,
        D: Discovery + 'static,
    {
        let discovery: DiscoveryBuilder =
            Box::new(move |secret_key| discovery(secret_key).map(|x| Box::new(x) as _));
        self.discovery.push(discovery);
        self
    }

    /// Adds the n0 DNS-based discovery services.
    ///
    /// This adds a `PkarrPublisher::n0_dns` publisher for this node and a
    /// `DnsDiscovery::n0_dns` resolver.
    pub fn discovery_n0(mut self) -> Self {
        self.discovery.push(Box::new(|secret_key| {
            Some(Box::new(PkarrPublisher::n0_dns(secret_key.clone())))
        }));
        self.discovery
            .push(Box::new(|_| Some(Box::new(DnsDiscovery::n0_dns()))));
        self
    }

    /// Adds a pkarr DHT discovery service.
    ///
    /// If the DHT service cannot be constructed at bind time, a warning is logged and the
    /// service is skipped instead of panicking.
    #[cfg(feature = "discovery-pkarr-dht")]
    pub fn discovery_dht(mut self) -> Self {
        use crate::discovery::pkarr::dht::DhtDiscovery;
        self.discovery.push(Box::new(|secret_key| {
            match DhtDiscovery::builder()
                .secret_key(secret_key.clone())
                .build()
            {
                Ok(dht) => Some(Box::new(dht) as Box<dyn Discovery>),
                Err(err) => {
                    warn!("failed to create DHT discovery, skipping it: {err:?}");
                    None
                }
            }
        }));
        self
    }

    /// Adds a local-network discovery service.
    ///
    /// If the service cannot be created at bind time, it is silently skipped.
    #[cfg(feature = "discovery-local-network")]
    pub fn discovery_local_network(mut self) -> Self {
        use crate::discovery::local_swarm_discovery::LocalSwarmDiscovery;
        self.discovery.push(Box::new(|secret_key| {
            LocalSwarmDiscovery::new(secret_key.public())
                .map(|x| Box::new(x) as _)
                .ok()
        }));
        self
    }

    /// Sets addressing information for nodes that are known up front.
    pub fn known_nodes(mut self, nodes: Vec<NodeAddr>) -> Self {
        self.node_map = Some(nodes);
        self
    }

    /// Sets the QUIC transport config used for the endpoint's server config.
    ///
    /// Note: outgoing connections are currently created with their own default transport
    /// config (with a 1 s keep-alive) in `Endpoint::connect_quinn`, not with this one.
    pub fn transport_config(mut self, transport_config: quinn::TransportConfig) -> Self {
        self.transport_config = Some(transport_config);
        self
    }

    /// Sets the DNS resolver; by default the crate's shared default resolver is used.
    pub fn dns_resolver(mut self, dns_resolver: DnsResolver) -> Self {
        self.dns_resolver = Some(dns_resolver);
        self
    }

    /// Sets an explicit proxy URL handed to the magicsock.
    pub fn proxy_url(mut self, url: Url) -> Self {
        self.proxy_url = Some(url);
        self
    }

    /// Sets the proxy URL from the environment.
    ///
    /// Checks `HTTP_PROXY`, `http_proxy`, `HTTPS_PROXY` and `https_proxy`. If none of them
    /// yields a valid URL, any proxy URL configured earlier is cleared.
    pub fn proxy_from_env(mut self) -> Self {
        self.proxy_url = proxy_url_from_env();
        self
    }

    /// Sets the key-logging flag forwarded to the TLS configs.
    ///
    /// This presumably enables TLS key logging; see the `tls` module.
    pub fn keylog(mut self, keylog: bool) -> Self {
        self.keylog = keylog;
        self
    }

    /// Skips TLS certificate verification for relay servers.
    ///
    /// Only available in tests or with the `test-utils` feature.
    #[cfg(any(test, feature = "test-utils"))]
    #[cfg_attr(iroh_docsrs, doc(cfg(any(test, feature = "test-utils"))))]
    pub fn insecure_skip_relay_cert_verify(mut self, skip_verify: bool) -> Self {
        self.insecure_skip_relay_cert_verify = skip_verify;
        self
    }
}
/// Endpoint configuration that is not mutated after [`Endpoint`] creation.
///
/// Used to (re)build the QUIC server config (initially and in `Endpoint::set_alpns`) and
/// the client TLS config for outgoing connections.
#[derive(Debug)]
struct StaticConfig {
    /// The endpoint's identity; its public key is the endpoint's [`NodeId`].
    secret_key: SecretKey,
    /// Transport config attached to the QUIC server config.
    transport_config: Arc<quinn::TransportConfig>,
    /// Key-logging flag forwarded to the TLS config builders.
    keylog: bool,
}
impl StaticConfig {
fn create_server_config(&self, alpn_protocols: Vec<Vec<u8>>) -> Result<ServerConfig> {
let server_config = make_server_config(
&self.secret_key,
alpn_protocols,
self.transport_config.clone(),
self.keylog,
)?;
Ok(server_config)
}
}
/// Creates a QUIC [`ServerConfig`] for accepting incoming connections.
///
/// The TLS side is built by `tls::make_server_config` from `secret_key`, `alpn_protocols` and
/// `keylog`. `transport_config` is then attached to the resulting config.
///
/// # Errors
///
/// Returns an error if the TLS server configuration cannot be created.
pub fn make_server_config(
    secret_key: &SecretKey,
    alpn_protocols: Vec<Vec<u8>>,
    transport_config: Arc<TransportConfig>,
    keylog: bool,
) -> Result<ServerConfig> {
    let crypto = Arc::new(tls::make_server_config(secret_key, alpn_protocols, keylog)?);
    let mut config = ServerConfig::with_crypto(crypto);
    config.transport_config(transport_config);
    Ok(config)
}
/// An endpoint for QUIC connections to other nodes, addressed by [`NodeId`].
///
/// Wraps a [`quinn::Endpoint`] whose socket is the magicsock. All fields are handles, `Arc`s
/// or a [`CancellationToken`], so clones refer to the same underlying endpoint state.
#[derive(Clone, Debug)]
pub struct Endpoint {
    /// Handle to the magicsock that the QUIC endpoint uses as its socket.
    msock: Handle,
    /// The QUIC endpoint running on top of `msock`.
    endpoint: quinn::Endpoint,
    /// Actor that is notified of every new connection (`RttMessage::NewConnection`).
    rtt_actor: Arc<rtt_actor::RttHandle>,
    /// Cancelled by `close`; observed through `cancelled`.
    cancel_token: CancellationToken,
    /// Immutable configuration used to build TLS/QUIC configs.
    static_config: Arc<StaticConfig>,
}
impl Endpoint {
    /// Returns a [`Builder`] with default settings, used to configure and bind an endpoint.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Spawns the magicsock and creates the QUIC endpoint on top of it.
    ///
    /// This is the final step of [`Builder::bind`]. `initial_alpns` are the ALPN protocols
    /// advertised for incoming connections until [`Endpoint::set_alpns`] is called.
    #[instrument("ep", skip_all, fields(me = %static_config.secret_key.public().fmt_short()))]
    async fn bind(
        static_config: StaticConfig,
        msock_opts: magicsock::Options,
        initial_alpns: Vec<Vec<u8>>,
    ) -> Result<Self> {
        let msock = magicsock::MagicSock::spawn(msock_opts).await?;
        trace!("created magicsock");

        let server_config = static_config.create_server_config(initial_alpns)?;

        let mut endpoint_config = quinn::EndpointConfig::default();
        // QUIC bit greasing is disabled. NOTE(review): the reason is not visible here —
        // presumably a requirement of the magicsock transport; confirm before changing.
        endpoint_config.grease_quic_bit(false);

        let endpoint = quinn::Endpoint::new_with_abstract_socket(
            endpoint_config,
            Some(server_config),
            Arc::new(msock.clone()),
            Arc::new(quinn::TokioRuntime),
        )?;
        trace!("created quinn endpoint");
        debug!(version = env!("CARGO_PKG_VERSION"), "iroh Endpoint created");
        Ok(Self {
            msock,
            endpoint,
            rtt_actor: Arc::new(rtt_actor::RttHandle::new()),
            cancel_token: CancellationToken::new(),
            static_config: Arc::new(static_config),
        })
    }

    /// Replaces the ALPN protocols advertised for incoming connections.
    ///
    /// # Errors
    ///
    /// Returns an error if the new server configuration cannot be built.
    pub fn set_alpns(&self, alpns: Vec<Vec<u8>>) -> Result<()> {
        let server_config = self.static_config.create_server_config(alpns)?;
        self.endpoint.set_server_config(Some(server_config));
        Ok(())
    }

    /// Connects to the node described by `node_addr` using the ALPN protocol `alpn`.
    ///
    /// Any addressing information in `node_addr` is first added to the endpoint's node map.
    /// If no usable address is known, discovery is used to find one. If an address is known,
    /// a discovery task may still be started in the background. Any discovery task obtained
    /// here is cancelled once the connection attempt has finished, whether or not it
    /// succeeded.
    ///
    /// # Errors
    ///
    /// Fails when `node_addr` is this endpoint's own node id, when no addressing information
    /// can be obtained, or when the QUIC connection cannot be established.
    #[instrument(
        skip_all,
        fields(
            me = %self.node_id().fmt_short(),
            alpn = ?String::from_utf8_lossy(alpn),
            remote = tracing::field::Empty
        )
    )]
    pub async fn connect(&self, node_addr: impl Into<NodeAddr>, alpn: &[u8]) -> Result<Connection> {
        let node_addr = node_addr.into();
        // `remote` is declared as an empty span field above; without that declaration this
        // record call would be silently ignored by `tracing`.
        tracing::Span::current().record("remote", node_addr.node_id.fmt_short());
        if node_addr.node_id == self.node_id() {
            bail!(
                "Connecting to ourself is not supported ({} is the node id of this node)",
                node_addr.node_id.fmt_short()
            );
        }
        if !node_addr.info.is_empty() {
            self.add_node_addr(node_addr.clone())?;
        }
        let NodeAddr { node_id, info } = node_addr.clone();
        let (addr, discovery) = self
            .get_mapping_addr_and_maybe_start_discovery(node_addr)
            .await
            .with_context(|| {
                format!(
                    "No addressing information for NodeId({}), unable to connect",
                    node_id.fmt_short()
                )
            })?;
        debug!(
            "connecting to {}: (via {} - {:?})",
            node_id, addr, info.direct_addresses
        );
        let conn = self.connect_quinn(node_id, alpn, addr).await;
        if let Some(discovery) = discovery {
            discovery.cancel();
        }
        conn
    }

    /// Connects to `node_id` using only already-known or discovered addressing information.
    #[deprecated(
        since = "0.27.0",
        note = "Please use `connect` directly with a NodeId. This fn will be removed in 0.28.0."
    )]
    pub async fn connect_by_node_id(&self, node_id: NodeId, alpn: &[u8]) -> Result<Connection> {
        let addr = NodeAddr::new(node_id);
        self.connect(addr, alpn).await
    }

    /// Establishes the QUIC connection to `node_id` via the mapped address `addr`.
    ///
    /// On success the connection is registered with the RTT actor; failure to reach the
    /// actor is only logged.
    ///
    /// # Errors
    ///
    /// Fails if the client TLS config cannot be built, the connection cannot be started or
    /// completed, or the connection-type stream for `node_id` cannot be created.
    #[instrument(
        skip_all,
        fields(remote_node = node_id.fmt_short(), alpn = %String::from_utf8_lossy(alpn))
    )]
    async fn connect_quinn(
        &self,
        node_id: NodeId,
        alpn: &[u8],
        addr: QuicMappedAddr,
    ) -> Result<Connection> {
        debug!("Attempting connection...");
        let client_config = {
            let alpn_protocols = vec![alpn.to_vec()];
            let quic_client_config = tls::make_client_config(
                &self.static_config.secret_key,
                Some(node_id),
                alpn_protocols,
                self.static_config.keylog,
            )?;
            let mut client_config = quinn::ClientConfig::new(Arc::new(quic_client_config));
            // NOTE(review): outgoing connections use a fresh default transport config with a
            // 1 s keep-alive instead of `self.static_config.transport_config`, so settings
            // from `Builder::transport_config` only apply to incoming connections. Confirm
            // this is intended.
            let mut transport_config = quinn::TransportConfig::default();
            transport_config.keep_alive_interval(Some(Duration::from_secs(1)));
            client_config.transport_config(Arc::new(transport_config));
            client_config
        };

        // The server name is not used for verification; the expected peer is pinned via
        // `Some(node_id)` in the client TLS config above.
        let connect = self
            .endpoint
            .connect_with(client_config, addr.0, "localhost")?;

        let connection = connect
            .await
            .context("failed connecting to remote endpoint")?;

        let rtt_msg = RttMessage::NewConnection {
            connection: connection.weak_handle(),
            conn_type_changes: self.conn_type_stream(node_id)?,
            node_id,
        };
        if let Err(err) = self.rtt_actor.msg_tx.send(rtt_msg).await {
            warn!("rtt-actor not reachable: {err:#}");
        }
        debug!("Connection established");
        Ok(connection)
    }

    /// Returns a future resolving to the next incoming connection attempt.
    pub fn accept(&self) -> Accept<'_> {
        Accept {
            inner: self.endpoint.accept(),
            ep: self.clone(),
        }
    }

    /// Adds addressing information for a node, recorded with source `Source::App`.
    ///
    /// # Errors
    ///
    /// Fails if `node_addr` refers to this endpoint's own node id, or if the magicsock
    /// rejects the address.
    pub fn add_node_addr(&self, node_addr: NodeAddr) -> Result<()> {
        self.add_node_addr_inner(node_addr, magicsock::Source::App)
    }

    /// Like [`Endpoint::add_node_addr`], but records the address under the named source
    /// `source` (`Source::NamedApp`).
    pub fn add_node_addr_with_source(
        &self,
        node_addr: NodeAddr,
        source: &'static str,
    ) -> Result<()> {
        self.add_node_addr_inner(
            node_addr,
            magicsock::Source::NamedApp {
                name: source.into(),
            },
        )
    }

    /// Shared implementation of the `add_node_addr*` methods; rejects our own node id.
    fn add_node_addr_inner(&self, node_addr: NodeAddr, source: magicsock::Source) -> Result<()> {
        if node_addr.node_id == self.node_id() {
            bail!(
                "Adding our own address is not supported ({} is the node id of this node)",
                node_addr.node_id.fmt_short()
            );
        }
        self.msock.add_node_addr(node_addr, source)
    }

    /// Returns the secret key identifying this endpoint.
    pub fn secret_key(&self) -> &SecretKey {
        &self.static_config.secret_key
    }

    /// Returns this endpoint's node id (the public key of its secret key).
    pub fn node_id(&self) -> NodeId {
        self.static_config.secret_key.public()
    }

    /// Returns this endpoint's current addressing information.
    ///
    /// Waits for the first item of [`Endpoint::direct_addresses`] and combines it with the
    /// current home relay.
    ///
    /// # Errors
    ///
    /// Fails if the direct-address stream ends without yielding any item.
    pub async fn node_addr(&self) -> Result<NodeAddr> {
        let addrs = self
            .direct_addresses()
            .next()
            .await
            .ok_or_else(|| anyhow!("No IP endpoints found"))?;
        let relay = self.home_relay();
        Ok(NodeAddr::from_parts(
            self.node_id(),
            relay,
            addrs.into_iter().map(|x| x.addr),
        ))
    }

    /// Returns the current home relay URL, if any.
    pub fn home_relay(&self) -> Option<RelayUrl> {
        self.msock.my_relay()
    }

    /// Returns a stream of home relay URLs, as reported by the magicsock.
    pub fn watch_home_relay(&self) -> impl Stream<Item = RelayUrl> {
        self.msock.watch_home_relay()
    }

    /// Returns a stream of this endpoint's direct addresses.
    pub fn direct_addresses(&self) -> DirectAddrsStream {
        self.msock.direct_addresses()
    }

    /// Returns the locally bound socket addresses.
    ///
    /// The first element is presumably the IPv4 socket and the second the optional IPv6 one.
    pub fn bound_sockets(&self) -> (SocketAddr, Option<SocketAddr>) {
        self.msock.local_addr()
    }

    /// Returns what the endpoint knows about the remote node `node_id`, if anything.
    pub fn remote_info(&self, node_id: NodeId) -> Option<RemoteInfo> {
        self.msock.remote_info(node_id)
    }

    /// Returns information about all known remote nodes.
    pub fn remote_info_iter(&self) -> impl Iterator<Item = RemoteInfo> {
        self.msock.list_remote_infos().into_iter()
    }

    /// Returns a stream of connection-type changes for `node_id`.
    ///
    /// # Errors
    ///
    /// Forwards errors from the magicsock, e.g. when the node is unknown.
    pub fn conn_type_stream(&self, node_id: NodeId) -> Result<ConnectionTypeStream> {
        self.msock.conn_type_stream(node_id)
    }

    /// Returns the DNS resolver used by this endpoint.
    pub fn dns_resolver(&self) -> &DnsResolver {
        self.msock.dns_resolver()
    }

    /// Returns the configured discovery service, if any.
    pub fn discovery(&self) -> Option<&dyn Discovery> {
        self.msock.discovery()
    }

    /// Notifies the magicsock that the network may have changed.
    pub async fn network_change(&self) {
        self.msock.network_change().await;
    }

    /// Closes the endpoint.
    ///
    /// This cancels the endpoint's cancellation token, closes all QUIC connections with
    /// `error_code` and `reason`, and waits for the QUIC endpoint to become idle. It then shuts
    /// down the magicsock.
    ///
    /// # Errors
    ///
    /// Forwards errors from closing the magicsock.
    pub async fn close(self, error_code: VarInt, reason: &[u8]) -> Result<()> {
        let Endpoint {
            msock,
            endpoint,
            cancel_token,
            ..
        } = self;
        cancel_token.cancel();
        tracing::debug!("Closing connections");
        endpoint.close(error_code, reason);
        endpoint.wait_idle().await;
        // Release our handle to the QUIC endpoint before shutting down its socket.
        drop(endpoint);
        tracing::debug!("Connections closed");

        msock.close().await?;
        Ok(())
    }

    /// Returns a future that resolves once [`Endpoint::close`] has been called on any clone.
    pub(crate) fn cancelled(&self) -> WaitForCancellationFuture<'_> {
        self.cancel_token.cancelled()
    }

    /// Resolves `node_addr` to a QUIC-mapped address, starting discovery if needed.
    ///
    /// If the magicsock already has a send address for the node, its mapped address is used.
    /// A discovery task may be started in that case: with a `DISCOVERY_WAIT_PERIOD` delay when
    /// `node_addr` carried addressing info, or with delay `None` otherwise. Errors from
    /// starting that task are ignored.
    ///
    /// Otherwise discovery is started and its first result is awaited before looking up the
    /// mapped address again.
    ///
    /// # Errors
    ///
    /// Fails if discovery is required but cannot be started, fails, or does not produce
    /// addressing information.
    async fn get_mapping_addr_and_maybe_start_discovery(
        &self,
        node_addr: NodeAddr,
    ) -> Result<(QuicMappedAddr, Option<DiscoveryTask>)> {
        let node_id = node_addr.node_id;

        // Only trust the mapped address if the magicsock can actually send to the node.
        let addr = if self.msock.has_send_address(node_id) {
            self.msock.get_mapping_addr(node_id)
        } else {
            None
        };
        match addr {
            Some(addr) => {
                let delay = (!node_addr.info.is_empty()).then_some(DISCOVERY_WAIT_PERIOD);
                let discovery = DiscoveryTask::maybe_start_after_delay(self, node_id, delay)
                    .ok()
                    .flatten();
                Ok((addr, discovery))
            }

            None => {
                let mut discovery = DiscoveryTask::start(self.clone(), node_id)
                    .context("Discovery service required due to missing addressing information")?;
                discovery
                    .first_arrived()
                    .await
                    .context("Discovery service failed")?;
                if let Some(addr) = self.msock.get_mapping_addr(node_id) {
                    Ok((addr, Some(discovery)))
                } else {
                    bail!("Discovery did not find addressing information");
                }
            }
        }
    }

    /// Returns a handle to the underlying magicsock (tests only).
    #[cfg(test)]
    pub(crate) fn magic_sock(&self) -> Handle {
        self.msock.clone()
    }

    /// Returns the underlying quinn endpoint (tests only).
    #[cfg(test)]
    pub(crate) fn endpoint(&self) -> &quinn::Endpoint {
        &self.endpoint
    }
}
/// Future returned by [`Endpoint::accept`].
///
/// Resolves to the next [`Incoming`] connection attempt, or `None` when the underlying
/// `quinn::Accept` yields `None`.
#[derive(Debug)]
#[pin_project]
pub struct Accept<'a> {
    /// The wrapped quinn accept future.
    #[pin]
    #[debug("quinn::Accept")]
    inner: quinn::Accept<'a>,
    /// Endpoint cloned into every produced [`Incoming`].
    ep: Endpoint,
}
impl<'a> Future for Accept<'a> {
    type Output = Option<Incoming>;

    /// Polls the wrapped quinn accept future.
    ///
    /// Each connection attempt is wrapped in an [`Incoming`] that carries a clone of the
    /// endpoint.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let ep = this.ep;
        this.inner.poll(cx).map(|maybe_incoming| {
            maybe_incoming.map(|inner| Incoming {
                inner,
                ep: ep.clone(),
            })
        })
    }
}
/// An incoming connection attempt that has not been accepted yet.
///
/// Produced by [`Accept`]. It wraps a `quinn::Incoming` together with the endpoint so that
/// accepted connections can be registered with the endpoint's RTT actor.
#[derive(Debug)]
pub struct Incoming {
    /// The wrapped quinn connection attempt.
    inner: quinn::Incoming,
    /// The endpoint that received the attempt.
    ep: Endpoint,
}
impl Incoming {
    /// Accepts this connection attempt using the endpoint's current server configuration.
    ///
    /// # Errors
    ///
    /// Forwards the [`ConnectionError`] returned by `quinn::Incoming::accept`.
    pub fn accept(self) -> Result<Connecting, ConnectionError> {
        let Self { inner, ep } = self;
        let connecting = inner.accept()?;
        Ok(Connecting {
            inner: connecting,
            ep,
        })
    }

    /// Accepts this connection attempt using the given `server_config`.
    ///
    /// # Errors
    ///
    /// Forwards the [`ConnectionError`] returned by `quinn::Incoming::accept_with`.
    pub fn accept_with(
        self,
        server_config: Arc<ServerConfig>,
    ) -> Result<Connecting, ConnectionError> {
        let Self { inner, ep } = self;
        let connecting = inner.accept_with(server_config)?;
        Ok(Connecting {
            inner: connecting,
            ep,
        })
    }

    /// Refuses this connection attempt (delegates to `quinn::Incoming::refuse`).
    pub fn refuse(self) {
        let Self { inner, .. } = self;
        inner.refuse();
    }

    /// Responds with a retry packet (delegates to `quinn::Incoming::retry`).
    ///
    /// # Errors
    ///
    /// Forwards the [`RetryError`] returned by quinn.
    pub fn retry(self) -> Result<(), RetryError> {
        let Self { inner, .. } = self;
        inner.retry()
    }

    /// Ignores this connection attempt (delegates to `quinn::Incoming::ignore`).
    pub fn ignore(self) {
        let Self { inner, .. } = self;
        inner.ignore();
    }

    /// Returns the local IP address the attempt was received on, if known.
    pub fn local_ip(&self) -> Option<IpAddr> {
        let Self { inner, .. } = self;
        inner.local_ip()
    }

    /// Returns the remote socket address of the attempt.
    pub fn remote_address(&self) -> SocketAddr {
        let Self { inner, .. } = self;
        inner.remote_address()
    }

    /// Returns whether quinn considers the remote address validated.
    pub fn remote_address_validated(&self) -> bool {
        let Self { inner, .. } = self;
        inner.remote_address_validated()
    }
}
impl IntoFuture for Incoming {
    type Output = Result<Connection, ConnectionError>;
    type IntoFuture = IncomingFuture;

    /// Turns this attempt into a future that accepts it with the default server config and
    /// resolves to the established connection.
    fn into_future(self) -> Self::IntoFuture {
        let Self { inner, ep } = self;
        let pending = inner.into_future();
        IncomingFuture { inner: pending, ep }
    }
}
/// Future produced by awaiting an [`Incoming`] directly (via [`IntoFuture`]).
///
/// Resolves to the established connection and registers it with the endpoint's RTT actor.
#[derive(Debug)]
#[pin_project]
pub struct IncomingFuture {
    /// The wrapped quinn future.
    #[pin]
    inner: quinn::IncomingFuture,
    /// The endpoint the connection belongs to.
    ep: Endpoint,
}
impl Future for IncomingFuture {
type Output = Result<quinn::Connection, ConnectionError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.inner.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Ready(Ok(conn)) => {
try_send_rtt_msg(&conn, this.ep);
Poll::Ready(Ok(conn))
}
}
}
}
/// An in-progress connection handshake, returned by [`Incoming::accept`] and
/// [`Incoming::accept_with`].
///
/// Awaiting it yields the established connection and registers it with the RTT actor.
#[derive(Debug)]
#[pin_project]
pub struct Connecting {
    /// The wrapped quinn handshake future.
    #[pin]
    inner: quinn::Connecting,
    /// The endpoint the connection belongs to.
    ep: Endpoint,
}
impl Connecting {
pub fn into_0rtt(self) -> Result<(Connection, ZeroRttAccepted), Self> {
match self.inner.into_0rtt() {
Ok((conn, zrtt_accepted)) => {
try_send_rtt_msg(&conn, &self.ep);
Ok((conn, zrtt_accepted))
}
Err(inner) => Err(Self { inner, ep: self.ep }),
}
}
pub async fn handshake_data(&mut self) -> Result<Box<dyn Any>, ConnectionError> {
self.inner.handshake_data().await
}
pub fn local_ip(&self) -> Option<IpAddr> {
self.inner.local_ip()
}
pub fn remote_address(&self) -> SocketAddr {
self.inner.remote_address()
}
pub async fn alpn(&mut self) -> Result<Vec<u8>> {
let data = self.handshake_data().await?;
match data.downcast::<quinn::crypto::rustls::HandshakeData>() {
Ok(data) => match data.protocol {
Some(protocol) => Ok(protocol),
None => bail!("no ALPN protocol available"),
},
Err(_) => bail!("unknown handshake type"),
}
}
}
impl Future for Connecting {
type Output = Result<Connection, ConnectionError>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.inner.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Ready(Ok(conn)) => {
try_send_rtt_msg(&conn, this.ep);
Poll::Ready(Ok(conn))
}
}
}
}
pub fn get_remote_node_id(connection: &Connection) -> Result<PublicKey> {
let data = connection.peer_identity();
match data {
None => bail!("no peer certificate found"),
Some(data) => match data.downcast::<Vec<rustls::pki_types::CertificateDer>>() {
Ok(certs) => {
if certs.len() != 1 {
bail!(
"expected a single peer certificate, but {} found",
certs.len()
);
}
let cert = tls::certificate::parse(&certs[0])?;
Ok(cert.peer_id())
}
Err(_) => bail!("invalid peer certificate"),
},
}
}
/// Registers `conn` with the RTT actor of `magic_ep` without waiting.
///
/// This is best effort. If the remote node id or its connection-type stream cannot be
/// obtained, or the actor's channel rejects the message, a warning is logged and nothing else
/// happens.
fn try_send_rtt_msg(conn: &Connection, magic_ep: &Endpoint) {
    let node_id = match get_remote_node_id(conn) {
        Ok(id) => id,
        Err(_) => {
            warn!(?conn, "failed to get remote node id");
            return;
        }
    };
    let message = match magic_ep.conn_type_stream(node_id) {
        Ok(conn_type_changes) => RttMessage::NewConnection {
            connection: conn.weak_handle(),
            conn_type_changes,
            node_id,
        },
        Err(_) => {
            warn!(?conn, "failed to create conn_type_stream");
            return;
        }
    };
    if let Err(err) = magic_ep.rtt_actor.msg_tx.try_send(message) {
        warn!(?conn, "rtt-actor not reachable: {err:#}");
    }
}
/// Reads a proxy URL from the environment.
///
/// Checks `HTTP_PROXY`, `http_proxy`, `HTTPS_PROXY` and `https_proxy` in that order. It
/// returns the first variable that is set and parses as a [`Url`].
///
/// `HTTP_PROXY` is skipped, with a warning, when running as a CGI script (see `is_cgi`).
/// This is presumably to avoid the "httpoxy" issue, where a CGI server exposes a
/// client-supplied `Proxy` header as `HTTP_PROXY`.
fn proxy_url_from_env() -> Option<Url> {
    let read_url = |name: &str| -> Option<Url> {
        std::env::var(name)
            .ok()
            .and_then(|value| value.parse::<Url>().ok())
    };

    if let Some(url) = read_url("HTTP_PROXY") {
        if !is_cgi() {
            return Some(url);
        }
        warn!("HTTP_PROXY environment variable ignored in CGI");
    }

    ["http_proxy", "HTTPS_PROXY", "https_proxy"]
        .into_iter()
        .find_map(|name| read_url(name))
}
/// Returns the relay mode used when none is configured explicitly.
///
/// This is [`RelayMode::Staging`] when `force_staging_infra()` returns `true`, and
/// [`RelayMode::Default`] otherwise.
pub fn default_relay_mode() -> RelayMode {
    if force_staging_infra() {
        RelayMode::Staging
    } else {
        RelayMode::Default
    }
}
/// Reports whether the process appears to run as a CGI script.
///
/// The only signal used is whether the `REQUEST_METHOD` environment variable is set.
fn is_cgi() -> bool {
    matches!(std::env::var_os("REQUEST_METHOD"), Some(_))
}
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use iroh_test::CallOnDrop;
    use rand::SeedableRng;
    use tracing::{error_span, info, info_span, Instrument};

    use super::*;
    use crate::test_utils::{run_relay_server, run_relay_server_with};

    /// ALPN used by all endpoints in these tests.
    const TEST_ALPN: &[u8] = b"n0/iroh/test";

    /// Pins the `Debug` output of `AddrInfo`.
    #[test]
    fn test_addr_info_debug() {
        let info = AddrInfo {
            relay_url: Some("https://relay.example.com".parse().unwrap()),
            direct_addresses: vec![SocketAddr::from(([1, 2, 3, 4], 1234))]
                .into_iter()
                .collect(),
        };
        assert_eq!(
            format!("{:?}", info),
            r#"AddrInfo { relay_url: Some(RelayUrl("https://relay.example.com./")), direct_addresses: {1.2.3.4:1234} }"#
        );
    }

    /// Connecting to, or adding the address of, our own node must fail.
    #[tokio::test]
    async fn test_connect_self() {
        let _guard = iroh_test::logging::setup();
        let ep = Endpoint::builder()
            .alpns(vec![TEST_ALPN.to_vec()])
            .bind()
            .await
            .unwrap();
        let my_addr = ep.node_addr().await.unwrap();
        let res = ep.connect(my_addr.clone(), TEST_ALPN).await;
        assert!(res.is_err());
        let err = res.err().unwrap();
        assert!(err.to_string().starts_with("Connecting to ourself"));

        let res = ep.add_node_addr(my_addr);
        assert!(res.is_err());
        let err = res.err().unwrap();
        assert!(err.to_string().starts_with("Adding our own address"));
    }

    /// The server closes the connection after reading one stream; both sides must observe
    /// the expected close errors.
    #[tokio::test]
    async fn endpoint_connect_close() {
        let _guard = iroh_test::logging::setup();
        let (relay_map, relay_url, _relay_guard) = run_relay_server().await.unwrap();
        let server_secret_key = SecretKey::generate();
        let server_peer_id = server_secret_key.public();

        let server = {
            let relay_map = relay_map.clone();
            tokio::spawn(
                async move {
                    let ep = Endpoint::builder()
                        .secret_key(server_secret_key)
                        .alpns(vec![TEST_ALPN.to_vec()])
                        .relay_mode(RelayMode::Custom(relay_map))
                        .insecure_skip_relay_cert_verify(true)
                        .bind()
                        .await
                        .unwrap();
                    info!("accepting connection");
                    let incoming = ep.accept().await.unwrap();
                    let conn = incoming.await.unwrap();
                    let mut stream = conn.accept_uni().await.unwrap();
                    let mut buf = [0u8; 5];
                    stream.read_exact(&mut buf).await.unwrap();
                    info!("Accepted 1 stream, received {buf:?}. Closing now.");
                    // close the connection
                    conn.close(7u8.into(), b"bye");

                    let res = conn.accept_uni().await;
                    assert_eq!(res.unwrap_err(), quinn::ConnectionError::LocallyClosed);

                    let res = stream.read_to_end(10).await;
                    assert_eq!(
                        res.unwrap_err(),
                        quinn::ReadToEndError::Read(quinn::ReadError::ConnectionLost(
                            quinn::ConnectionError::LocallyClosed
                        ))
                    );
                    info!("server test completed");
                }
                .instrument(info_span!("test-server")),
            )
        };

        let client = tokio::spawn(
            async move {
                let ep = Endpoint::builder()
                    .alpns(vec![TEST_ALPN.to_vec()])
                    .relay_mode(RelayMode::Custom(relay_map))
                    .insecure_skip_relay_cert_verify(true)
                    .bind()
                    .await
                    .unwrap();
                info!("client connecting");
                let node_addr = NodeAddr::new(server_peer_id).with_relay_url(relay_url);
                let conn = ep.connect(node_addr, TEST_ALPN).await.unwrap();
                let mut stream = conn.open_uni().await.unwrap();

                // First write is accepted by server.  We need this bit of synchronisation
                // because if the server closes after simply accepting the connection we can
                // not be sure our .open_uni() call would succeed as it may already receive
                // the error.
                stream.write_all(b"hello").await.unwrap();

                info!("waiting for closed");
                // Remote now closes the connection, we should see an error sometime soon.
                let err = conn.closed().await;
                let expected_err =
                    quinn::ConnectionError::ApplicationClosed(quinn::ApplicationClose {
                        error_code: 7u8.into(),
                        reason: b"bye".to_vec().into(),
                    });
                assert_eq!(err, expected_err);

                info!("opening new - expect it to fail");
                let res = conn.open_uni().await;
                assert_eq!(res.unwrap_err(), expected_err);
                info!("client test completed");
            }
            .instrument(info_span!("test-client")),
        );

        let (server, client) = tokio::time::timeout(
            Duration::from_secs(30),
            futures_lite::future::zip(server, client),
        )
        .await
        .expect("timeout");
        server.unwrap();
        client.unwrap();
    }

    /// Node addresses exported from one endpoint can be fed into `known_nodes` of a
    /// restarted endpoint with the same secret key.
    #[tokio::test]
    async fn restore_peers() {
        let _guard = iroh_test::logging::setup();

        let secret_key = SecretKey::generate();

        // Binds an endpoint with a 10 s idle timeout and optional known nodes.
        async fn new_endpoint(secret_key: SecretKey, nodes: Option<Vec<NodeAddr>>) -> Endpoint {
            let mut transport_config = quinn::TransportConfig::default();
            transport_config.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));

            let mut builder = Endpoint::builder()
                .secret_key(secret_key.clone())
                .transport_config(transport_config);
            if let Some(nodes) = nodes {
                builder = builder.known_nodes(nodes);
            }
            builder
                .alpns(vec![TEST_ALPN.to_vec()])
                .bind()
                .await
                .unwrap()
        }

        // create the peer that will be added to the peer map
        let peer_id = SecretKey::generate().public();
        let direct_addr: SocketAddr =
            (std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), 8758u16).into();
        let node_addr = NodeAddr::new(peer_id).with_direct_addresses([direct_addr]);

        info!("setting up first endpoint");
        // first time, create a magic endpoint without peers but a peers file and add addressing
        // information for a peer
        let endpoint = new_endpoint(secret_key.clone(), None).await;
        assert_eq!(endpoint.remote_info_iter().count(), 0);
        endpoint.add_node_addr(node_addr.clone()).unwrap();

        // Grab the current addrs
        let node_addrs: Vec<NodeAddr> = endpoint.remote_info_iter().map(Into::into).collect();
        assert_eq!(node_addrs.len(), 1);
        assert_eq!(node_addrs[0], node_addr);

        info!("closing endpoint");
        // close the endpoint and restart it
        endpoint.close(0u32.into(), b"done").await.unwrap();

        info!("restarting endpoint");
        // now restart it and check the addressing info of the peer
        let endpoint = new_endpoint(secret_key, Some(node_addrs)).await;
        let RemoteInfo { mut addrs, .. } = endpoint.remote_info(peer_id).unwrap();
        let conn_addr = addrs.pop().unwrap().addr;
        assert_eq!(conn_addr, direct_addr);
    }

    /// Several clients connect one after another to a server via a relay and echo data.
    #[tokio::test]
    async fn endpoint_relay_connect_loop() {
        let _logging_guard = iroh_test::logging::setup();
        let start = Instant::now();
        let n_clients = 5;
        let n_chunks_per_client = 2;
        let chunk_size = 10;
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let (relay_map, relay_url, _relay_guard) = run_relay_server().await.unwrap();
        let server_secret_key = SecretKey::generate_with_rng(&mut rng);
        let server_node_id = server_secret_key.public();

        // The server accepts the connections of the clients sequentially.
        let server = {
            let relay_map = relay_map.clone();
            tokio::spawn(
                async move {
                    let ep = Endpoint::builder()
                        .insecure_skip_relay_cert_verify(true)
                        .secret_key(server_secret_key)
                        .alpns(vec![TEST_ALPN.to_vec()])
                        .relay_mode(RelayMode::Custom(relay_map))
                        .bind()
                        .await
                        .unwrap();
                    let eps = ep.bound_sockets();
                    info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "server bound");
                    for i in 0..n_clients {
                        let now = Instant::now();
                        println!("[server] round {}", i + 1);
                        let incoming = ep.accept().await.unwrap();
                        let conn = incoming.await.unwrap();
                        let peer_id = get_remote_node_id(&conn).unwrap();
                        info!(%i, peer = %peer_id.fmt_short(), "accepted connection");
                        let (mut send, mut recv) = conn.accept_bi().await.unwrap();
                        let mut buf = vec![0u8; chunk_size];
                        for _i in 0..n_chunks_per_client {
                            recv.read_exact(&mut buf).await.unwrap();
                            send.write_all(&buf).await.unwrap();
                        }
                        send.finish().unwrap();
                        send.stopped().await.unwrap();
                        recv.read_to_end(0).await.unwrap();
                        info!(%i, peer = %peer_id.fmt_short(), "finished");
                        println!("[server] round {} done in {:?}", i + 1, now.elapsed());
                    }
                }
                .instrument(error_span!("server")),
            )
        };
        let abort_handle = server.abort_handle();
        let _server_guard = CallOnDrop::new(move || {
            abort_handle.abort();
        });

        for i in 0..n_clients {
            let now = Instant::now();
            println!("[client] round {}", i + 1);
            let relay_map = relay_map.clone();
            let client_secret_key = SecretKey::generate_with_rng(&mut rng);
            let relay_url = relay_url.clone();
            async {
                info!("client binding");
                let ep = Endpoint::builder()
                    .alpns(vec![TEST_ALPN.to_vec()])
                    .insecure_skip_relay_cert_verify(true)
                    .relay_mode(RelayMode::Custom(relay_map))
                    .secret_key(client_secret_key)
                    .bind()
                    .await
                    .unwrap();
                let eps = ep.bound_sockets();
                info!(me = %ep.node_id().fmt_short(), ipv4=%eps.0, ipv6=?eps.1, "client bound");
                let node_addr = NodeAddr::new(server_node_id).with_relay_url(relay_url);
                info!(to = ?node_addr, "client connecting");
                let conn = ep.connect(node_addr, TEST_ALPN).await.unwrap();
                info!("client connected");
                let (mut send, mut recv) = conn.open_bi().await.unwrap();

                for i in 0..n_chunks_per_client {
                    let mut buf = vec![i; chunk_size];
                    send.write_all(&buf).await.unwrap();
                    recv.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![i; chunk_size]);
                }
                send.finish().unwrap();
                send.stopped().await.unwrap();
                recv.read_to_end(0).await.unwrap();
                info!("client finished");
                ep.close(0u32.into(), &[]).await.unwrap();
                info!("client closed");
            }
            .instrument(error_span!("client", %i))
            .await;
            println!("[client] round {} done in {:?}", i + 1, now.elapsed());
        }

        server.await.unwrap();

        // We appear to have seen this being very slow at times.  So ensure we fail if this
        // test is too slow.  We're only making two connections transferring very little data,
        // this really shouldn't take long.
        if start.elapsed() > Duration::from_secs(15) {
            panic!("Test too slow, something went wrong");
        }
    }

    /// Two relay-less endpoints connect to each other simultaneously and exchange data in
    /// both directions.
    #[tokio::test]
    async fn endpoint_bidi_send_recv() {
        let _logging_guard = iroh_test::logging::setup();
        let ep1 = Endpoint::builder()
            .alpns(vec![TEST_ALPN.to_vec()])
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await
            .unwrap();
        let ep2 = Endpoint::builder()
            .alpns(vec![TEST_ALPN.to_vec()])
            .relay_mode(RelayMode::Disabled)
            .bind()
            .await
            .unwrap();
        let ep1_nodeaddr = ep1.node_addr().await.unwrap();
        let ep2_nodeaddr = ep2.node_addr().await.unwrap();
        ep1.add_node_addr(ep2_nodeaddr.clone()).unwrap();
        ep2.add_node_addr(ep1_nodeaddr.clone()).unwrap();
        let ep1_nodeid = ep1.node_id();
        let ep2_nodeid = ep2.node_id();
        eprintln!("node id 1 {ep1_nodeid}");
        eprintln!("node id 2 {ep2_nodeid}");

        // Opens a bi stream to `dst`, sends "hello" and expects "world" back.
        async fn connect_hello(ep: Endpoint, dst: NodeAddr) {
            let conn = ep.connect(dst, TEST_ALPN).await.unwrap();
            let (mut send, mut recv) = conn.open_bi().await.unwrap();
            info!("sending hello");
            send.write_all(b"hello").await.unwrap();
            send.finish().unwrap();
            info!("receiving world");
            let m = recv.read_to_end(100).await.unwrap();
            assert_eq!(m, b"world");
            conn.close(1u8.into(), b"done");
        }

        // Accepts a connection from `src`, expects "hello" and replies with "world".
        async fn accept_world(ep: Endpoint, src: NodeId) {
            let incoming = ep.accept().await.unwrap();
            let mut iconn = incoming.accept().unwrap();
            let alpn = iconn.alpn().await.unwrap();
            let conn = iconn.await.unwrap();
            let node_id = get_remote_node_id(&conn).unwrap();
            assert_eq!(node_id, src);
            assert_eq!(alpn, TEST_ALPN);
            let (mut send, mut recv) = conn.accept_bi().await.unwrap();
            info!("receiving hello");
            let m = recv.read_to_end(100).await.unwrap();
            assert_eq!(m, b"hello");
            info!("sending world");
            send.write_all(b"world").await.unwrap();
            send.finish().unwrap();
            match conn.closed().await {
                ConnectionError::ApplicationClosed(closed) => {
                    assert_eq!(closed.error_code, 1u8.into());
                }
                _ => panic!("wrong close error"),
            }
        }

        let p1_accept = tokio::spawn(accept_world(ep1.clone(), ep2_nodeid).instrument(info_span!(
            "p1_accept",
            ep1 = %ep1.node_id().fmt_short(),
            dst = %ep2_nodeid.fmt_short(),
        )));
        let p2_accept = tokio::spawn(accept_world(ep2.clone(), ep1_nodeid).instrument(info_span!(
            "p2_accept",
            ep2 = %ep2.node_id().fmt_short(),
            dst = %ep1_nodeid.fmt_short(),
        )));
        let p1_connect = tokio::spawn(connect_hello(ep1.clone(), ep2_nodeaddr).instrument(
            info_span!(
                "p1_connect",
                ep1 = %ep1.node_id().fmt_short(),
                dst = %ep2_nodeid.fmt_short(),
            ),
        ));
        let p2_connect = tokio::spawn(connect_hello(ep2.clone(), ep1_nodeaddr).instrument(
            info_span!(
                "p2_connect",
                ep2 = %ep2.node_id().fmt_short(),
                dst = %ep1_nodeid.fmt_short(),
            ),
        ));

        p1_accept.await.unwrap();
        p2_accept.await.unwrap();
        p1_connect.await.unwrap();
        p2_connect.await.unwrap();
    }

    /// Two endpoints first connected via a relay must eventually report a direct
    /// connection type on both sides.
    #[tokio::test]
    async fn endpoint_conn_type_stream() {
        const TIMEOUT: Duration = std::time::Duration::from_secs(15);
        let _logging_guard = iroh_test::logging::setup();
        let (relay_map, _relay_url, _relay_guard) = run_relay_server().await.unwrap();
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let ep1_secret_key = SecretKey::generate_with_rng(&mut rng);
        let ep2_secret_key = SecretKey::generate_with_rng(&mut rng);
        let ep1 = Endpoint::builder()
            .secret_key(ep1_secret_key)
            .insecure_skip_relay_cert_verify(true)
            .alpns(vec![TEST_ALPN.to_vec()])
            .relay_mode(RelayMode::Custom(relay_map.clone()))
            .bind()
            .await
            .unwrap();
        let ep2 = Endpoint::builder()
            .secret_key(ep2_secret_key)
            .insecure_skip_relay_cert_verify(true)
            .alpns(vec![TEST_ALPN.to_vec()])
            .relay_mode(RelayMode::Custom(relay_map))
            .bind()
            .await
            .unwrap();

        // Waits until the connection type to `node_id` becomes `ConnectionType::Direct`.
        async fn handle_direct_conn(ep: &Endpoint, node_id: PublicKey) -> Result<()> {
            let mut stream = ep.conn_type_stream(node_id)?;
            let src = ep.node_id().fmt_short();
            let dst = node_id.fmt_short();
            while let Some(conn_type) = stream.next().await {
                tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type);
                if matches!(conn_type, ConnectionType::Direct(_)) {
                    return Ok(());
                }
            }
            anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`");
        }

        // Accepts one connection and returns the remote node id.
        async fn accept(ep: &Endpoint) -> NodeId {
            let incoming = ep.accept().await.unwrap();
            let conn = incoming.await.unwrap();
            let node_id = get_remote_node_id(&conn).unwrap();
            tracing::info!(node_id=%node_id.fmt_short(), "accepted connection");
            node_id
        }

        let ep1_nodeid = ep1.node_id();
        let ep2_nodeid = ep2.node_id();

        let ep1_nodeaddr = ep1.node_addr().await.unwrap();
        tracing::info!(
            "node id 1 {ep1_nodeid}, relay URL {:?}",
            ep1_nodeaddr.relay_url()
        );
        tracing::info!("node id 2 {ep2_nodeid}");

        let ep1_side = async move {
            accept(&ep1).await;
            handle_direct_conn(&ep1, ep2_nodeid).await
        };

        let ep2_side = async move {
            ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap();
            handle_direct_conn(&ep2, ep1_nodeid).await
        };

        let res_ep1 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep1_side));

        let ep1_abort_handle = res_ep1.abort_handle();
        let _ep1_guard = CallOnDrop::new(move || {
            ep1_abort_handle.abort();
        });

        let res_ep2 = tokio::spawn(tokio::time::timeout(TIMEOUT, ep2_side));
        let ep2_abort_handle = res_ep2.abort_handle();
        let _ep2_guard = CallOnDrop::new(move || {
            ep2_abort_handle.abort();
        });

        let (r1, r2) = tokio::try_join!(res_ep1, res_ep2).unwrap();
        r1.expect("ep1 timeout").unwrap();
        r2.expect("ep2 timeout").unwrap();
    }

    /// Direct addresses must still be reported when the relay server is started via
    /// `run_relay_server_with(None)` (per the test name, without STUN).
    #[tokio::test]
    async fn test_direct_addresses_no_stun_relay() {
        let _guard = iroh_test::logging::setup();
        let (relay_map, _, _relay_guard) = run_relay_server_with(None).await.unwrap();

        let ep = Endpoint::builder()
            .alpns(vec![TEST_ALPN.to_vec()])
            .relay_mode(RelayMode::Custom(relay_map))
            .insecure_skip_relay_cert_verify(true)
            .bind()
            .await
            .unwrap();

        tokio::time::timeout(Duration::from_secs(10), ep.direct_addresses().next())
            .await
            .unwrap()
            .unwrap();
    }
}