use std::{
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{self, Poll},
};
use anyhow::{anyhow, bail, Result};
use conn::Conn;
use iroh_base::{RelayUrl, SecretKey};
use n0_future::{
split::{split, SplitSink, SplitStream},
Sink, Stream,
};
#[cfg(any(test, feature = "test-utils"))]
use tracing::warn;
use tracing::{debug, event, trace, Level};
use url::Url;
pub use self::conn::{ConnSendError, ReceivedMessage, SendMessage};
#[cfg(not(wasm_browser))]
use crate::dns::DnsResolver;
use crate::{
http::{Protocol, RELAY_PATH},
KeyCache,
};
pub(crate) mod conn;
#[cfg(not(wasm_browser))]
pub(crate) mod streams;
#[cfg(not(wasm_browser))]
mod tls;
#[cfg(not(wasm_browser))]
mod util;
/// Builder for a relay [`Client`].
///
/// Collects the relay URL, the transport [`Protocol`], this node's secret key and
/// a handful of connection options; `ClientBuilder::connect` then dials the relay.
#[derive(derive_more::Debug, Clone)]
pub struct ClientBuilder {
/// Optional callback returning a `bool`, set via `address_family_selector`; it is
/// stored but not read anywhere in this file.
/// NOTE(review): presumably decides the IP address family to prefer when dialing — confirm.
/// `Debug` always prints the fixed placeholder below, whether or not a callback is set.
#[debug("address family selector callback")]
address_family_selector: Option<Arc<dyn Fn() -> bool + Send + Sync>>,
/// Flag set via `is_prober`; not read in this file.
/// NOTE(review): presumably marks the client as a probe connection — confirm.
is_prober: bool,
/// URL of the relay server to connect to.
url: RelayUrl,
/// Transport selected by `connect`: websocket or the relay protocol.
/// Browser builds can only use websockets.
protocol: Protocol,
/// Test-only switch set via `insecure_skip_cert_verify`; not read in this file.
/// NOTE(review): presumably makes TLS setup skip certificate verification — confirm.
#[cfg(any(test, feature = "test-utils"))]
insecure_skip_cert_verify: bool,
/// Optional proxy URL set via `proxy_url`; not read in this file.
proxy_url: Option<Url>,
/// This node's secret key; the browser `connect_ws` passes it to `Conn::new_ws_browser`.
secret_key: SecretKey,
/// DNS resolver, only present outside browser builds; not read in this file.
#[cfg(not(wasm_browser))]
dns_resolver: DnsResolver,
/// Key cache handed (as a clone) to new connections; `new` creates it with
/// `KeyCache::new(128)` and `key_cache_capacity` replaces it.
key_cache: KeyCache,
}
impl ClientBuilder {
    /// Capacity handed to `KeyCache::new` when a builder is first created.
    const DEFAULT_KEY_CACHE_CAPACITY: usize = 128;

    /// Creates a builder that will connect to the relay at `url` using `secret_key`.
    ///
    /// Defaults: [`Protocol::default`], no proxy, no address family selector,
    /// `is_prober == false`, and a key cache of `DEFAULT_KEY_CACHE_CAPACITY` entries.
    /// Outside of browser builds a [`DnsResolver`] must be supplied as well.
    pub fn new(
        url: impl Into<RelayUrl>,
        secret_key: SecretKey,
        #[cfg(not(wasm_browser))] dns_resolver: DnsResolver,
    ) -> Self {
        Self {
            address_family_selector: None,
            is_prober: false,
            url: url.into(),
            protocol: Protocol::default(),
            #[cfg(any(test, feature = "test-utils"))]
            insecure_skip_cert_verify: false,
            proxy_url: None,
            secret_key,
            #[cfg(not(wasm_browser))]
            dns_resolver,
            key_cache: KeyCache::new(Self::DEFAULT_KEY_CACHE_CAPACITY),
        }
    }

    /// Selects the transport protocol used by [`ClientBuilder::connect`].
    pub fn protocol(self, protocol: Protocol) -> Self {
        Self { protocol, ..self }
    }

    /// Installs a callback that the connection logic may consult; it returns a `bool`.
    pub fn address_family_selector<S>(self, selector: S) -> Self
    where
        S: Fn() -> bool + Send + Sync + 'static,
    {
        let callback: Arc<dyn Fn() -> bool + Send + Sync> = Arc::new(selector);
        Self {
            address_family_selector: Some(callback),
            ..self
        }
    }

    /// Sets the prober flag carried by this builder.
    pub fn is_prober(self, is: bool) -> Self {
        Self {
            is_prober: is,
            ..self
        }
    }

    /// Test-only: sets the flag requesting that TLS certificate checks be skipped.
    #[cfg(any(test, feature = "test-utils"))]
    pub fn insecure_skip_cert_verify(self, skip: bool) -> Self {
        Self {
            insecure_skip_cert_verify: skip,
            ..self
        }
    }

    /// Sets the proxy URL, replacing any previously configured one.
    pub fn proxy_url(self, url: Url) -> Self {
        Self {
            proxy_url: Some(url),
            ..self
        }
    }

    /// Replaces the key cache with a fresh, empty one of the given capacity.
    pub fn key_cache_capacity(self, capacity: usize) -> Self {
        Self {
            key_cache: KeyCache::new(capacity),
            ..self
        }
    }

    /// Connects to the configured relay and returns a ready [`Client`].
    ///
    /// Outside browsers, the chosen protocol's dialer also reports the local
    /// socket address, which is stored in the client; in browsers it is `None`.
    ///
    /// # Errors
    ///
    /// Propagates any dial/handshake failure. In browser builds,
    /// [`Protocol::Relay`] is rejected because only websockets are available.
    pub async fn connect(&self) -> Result<Client> {
        #[cfg(not(wasm_browser))]
        let (conn, local_addr) = {
            let (conn, addr) = match self.protocol {
                Protocol::Websocket => self.connect_ws().await?,
                Protocol::Relay => self.connect_relay().await?,
            };
            (conn, Some(addr))
        };

        #[cfg(wasm_browser)]
        let (conn, local_addr) = match self.protocol {
            // Browsers do not expose the local socket address.
            Protocol::Websocket => (self.connect_ws().await?, None),
            Protocol::Relay => {
                bail!("Can only connect to relay using websockets in browsers.")
            }
        };

        event!(
            target: "events.net.relay.connected",
            Level::DEBUG,
            url = %self.url,
            protocol = ?self.protocol,
        );
        trace!("connect done");

        Ok(Client { conn, local_addr })
    }

    /// Browser-only websocket dialer.
    ///
    /// Rewrites the relay URL to point at [`RELAY_PATH`], maps `http`/`ws` to `ws`
    /// and every other scheme to `wss`, then performs the relay handshake.
    #[cfg(wasm_browser)]
    async fn connect_ws(&self) -> Result<Conn> {
        let ws_scheme = match self.url.scheme() {
            "http" | "ws" => "ws",
            _ => "wss",
        };

        let mut dial_url = (*self.url).clone();
        dial_url.set_path(RELAY_PATH);
        dial_url
            .set_scheme(ws_scheme)
            .map_err(|()| anyhow!("Invalid URL"))?;

        debug!(%dial_url, "Dialing relay by websocket");
        // The `WsMeta` half is intentionally discarded right away, as before.
        let (_, ws_stream) = ws_stream_wasm::WsMeta::connect(dial_url.as_str(), None).await?;

        let conn =
            Conn::new_ws_browser(ws_stream, self.key_cache.clone(), &self.secret_key).await?;
        Ok(conn)
    }
}
/// A connected relay client, produced by `ClientBuilder::connect`.
///
/// Implements `Stream` (received messages) and `Sink<SendMessage>` by delegating
/// to the underlying connection; `Client::split` yields separate read and write halves.
#[derive(Debug)]
pub struct Client {
/// The relay connection every stream/sink call is forwarded to.
conn: Conn,
/// Local socket address of the connection; always `None` in browser builds.
local_addr: Option<SocketAddr>,
}
impl Client {
    /// Splits the client into an independent read half and write half.
    ///
    /// The read half ([`ClientStream`]) keeps the local address; the write
    /// half ([`ClientSink`]) only carries the sending side of the connection.
    pub fn split(self) -> (ClientStream, ClientSink) {
        let Self { conn, local_addr } = self;
        let (sink, stream) = split(conn);

        let read_half = ClientStream { stream, local_addr };
        let write_half = ClientSink { sink };
        (read_half, write_half)
    }
}
impl Stream for Client {
    type Item = Result<ReceivedMessage>;

    /// Polls the underlying connection for the next received message.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        // `Client` is `Unpin` (its fields are), so unpinning is sound.
        let conn = &mut self.get_mut().conn;
        Pin::new(conn).poll_next(cx)
    }
}
impl Sink<SendMessage> for Client {
type Error = ConnSendError;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
<Conn as Sink<SendMessage>>::poll_ready(Pin::new(&mut self.conn), cx)
}
fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> {
Pin::new(&mut self.conn).start_send(item)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
<Conn as Sink<SendMessage>>::poll_flush(Pin::new(&mut self.conn), cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
<Conn as Sink<SendMessage>>::poll_close(Pin::new(&mut self.conn), cx)
}
}
/// Write half of a [`Client`], obtained from `Client::split`.
///
/// Implements `Sink<SendMessage>` by forwarding to the split connection.
#[derive(Debug)]
pub struct ClientSink {
/// Sending half of the split relay connection.
sink: SplitSink<Conn, SendMessage>,
}
impl Sink<SendMessage> for ClientSink {
type Error = ConnSendError;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> {
Pin::new(&mut self.sink).start_send(item)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_flush(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_close(cx)
}
}
/// Read half of a [`Client`], obtained from `Client::split`.
///
/// Implements `Stream` of received messages and remembers the connection's
/// local address.
#[derive(Debug)]
pub struct ClientStream {
/// Receiving half of the split relay connection.
stream: SplitStream<Conn>,
/// Local socket address copied from the originating `Client`; `None` in browser builds.
local_addr: Option<SocketAddr>,
}
impl ClientStream {
    /// Returns the local socket address of the relay connection, if known.
    ///
    /// This is `None` in browser builds, where the address is not exposed.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        let Self { local_addr, .. } = self;
        *local_addr
    }
}
impl Stream for ClientStream {
    type Item = Result<ReceivedMessage>;

    /// Polls the split stream for the next received message.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).poll_next(cx)
    }
}
/// Builds a TLS 1.3-only rustls client config that accepts ANY server certificate.
///
/// Test-only: it uses the `ring` crypto provider and installs [`NoCertVerifier`],
/// so server identity is never checked. A warning is logged every time it is called.
///
/// # Panics
///
/// Panics if the `ring` provider rejects TLS 1.3, which is not expected.
#[cfg(any(test, feature = "test-utils"))]
pub fn make_dangerous_client_config() -> rustls::ClientConfig {
    warn!(
        "Insecure config: SSL certificates from relay servers will be trusted without verification"
    );

    let provider = Arc::new(rustls::crypto::ring::default_provider());
    let verifier: Arc<dyn rustls::client::danger::ServerCertVerifier> = Arc::new(NoCertVerifier);

    let versioned = rustls::client::ClientConfig::builder_with_provider(provider)
        .with_protocol_versions(&[&rustls::version::TLS13])
        .expect("protocols supported by ring");

    versioned
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth()
}
/// Test-only certificate verifier that accepts every server certificate and
/// handshake signature without checking them.
///
/// Installed by `make_dangerous_client_config`; never use it outside tests.
#[cfg(any(test, feature = "test-utils"))]
#[derive(Debug)]
struct NoCertVerifier;
/// Accept-everything verifier: every check unconditionally succeeds.
#[cfg(any(test, feature = "test-utils"))]
impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
    /// Accepts any certificate chain, server name, OCSP response and time.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer,
        _intermediates: &[rustls::pki_types::CertificateDer],
        _server_name: &rustls::pki_types::ServerName,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        use rustls::client::danger::ServerCertVerified;
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        use rustls::client::danger::HandshakeSignatureValid;
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Advertises every signature scheme the `ring` provider can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let ring_provider = rustls::crypto::ring::default_provider();
        ring_provider
            .signature_verification_algorithms
            .supported_schemes()
    }
}