use std::{
fmt,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use async_wsocket::{ConnectionMode, Message, WebSocket};
use futures::{Sink, SinkExt, TryStreamExt, stream::StreamExt};
use nostr::{Url, util::BoxedFuture};
use nostr_relay_pool::transport::{
error::TransportError,
websocket::{WebSocketSink, WebSocketStream, WebSocketTransport},
};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::protocol::Message as TungsteniteMessage;
/// How long the IPv6 connection attempt runs on its own before the IPv4
/// attempt is started alongside it, when DNS returned addresses of both
/// families (see `happy_eyeballs_tcp`).
// NOTE(review): appears to follow the 250 ms "Connection Attempt Delay"
// recommended by RFC 8305 — confirm if tuning.
const HAPPY_EYEBALLS_DELAY: Duration = Duration::from_millis(250);
/// WebSocket transport that connects `ConnectionMode::Direct` relays by racing
/// IPv6 and IPv4 TCP connections (Happy Eyeballs). Every other
/// [`ConnectionMode`] is handed to `async_wsocket`'s own `WebSocket::connect`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HappyEyeballsTransport;

impl WebSocketTransport for HappyEyeballsTransport {
    /// Always reports ping support.
    fn support_ping(&self) -> bool {
        true
    }

    /// Opens a WebSocket to `url`.
    ///
    /// Direct connections use the Happy Eyeballs path; any other `mode` uses the
    /// default `async_wsocket` connector. `timeout` is forwarded to either path.
    fn connect<'a>(
        &'a self,
        url: &'a Url,
        mode: &'a ConnectionMode,
        timeout: Duration,
    ) -> BoxedFuture<'a, Result<(WebSocketSink, WebSocketStream), TransportError>> {
        Box::pin(async move {
            if matches!(mode, ConnectionMode::Direct) {
                connect_happy_eyeballs(url, timeout).await
            } else {
                connect_default(url, mode, timeout).await
            }
        })
    }
}
/// Connects using `async_wsocket`'s built-in connector for `mode`.
///
/// The resulting socket is split. The write half is wrapped in
/// `DefaultSinkAdapter` and the read half has its errors mapped to
/// [`TransportError`].
///
/// # Errors
///
/// Returns a [`TransportError`] wrapping the connector's error if the
/// connection cannot be established within `timeout`.
async fn connect_default(
    url: &Url,
    mode: &ConnectionMode,
    timeout: Duration,
) -> Result<(WebSocketSink, WebSocketStream), TransportError> {
    let (write_half, read_half) = WebSocket::connect(url, mode, timeout)
        .await
        .map_err(TransportError::backend)?
        .split();

    let stream: WebSocketStream = Box::pin(read_half.map_err(TransportError::backend));
    let sink: WebSocketSink = Box::new(DefaultSinkAdapter(write_half));
    Ok((sink, stream))
}
async fn connect_happy_eyeballs(
url: &Url,
timeout: Duration,
) -> Result<(WebSocketSink, WebSocketStream), TransportError> {
tokio::time::timeout(timeout, connect_happy_eyeballs_inner(url))
.await
.map_err(|_| {
TransportError::backend(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"connection timed out",
))
})?
}
/// Resolves `url` and races TCP connections across IPv6 and IPv4 with
/// [`happy_eyeballs_tcp`]. It then runs the WebSocket handshake (plus TLS for
/// `wss`) on the winning socket via `tokio_tungstenite::client_async_tls`.
///
/// Only `ws` and `wss` URLs are accepted. Any other scheme is rejected before
/// DNS resolution, so no socket is opened for a URL the handshake would refuse.
///
/// # Errors
///
/// Returns a [`TransportError`] in any of these cases:
/// - the URL has no host or uses an unsupported scheme;
/// - DNS resolution fails or yields no addresses;
/// - no TCP connection can be established;
/// - the TLS/WebSocket handshake fails.
async fn connect_happy_eyeballs_inner(
    url: &Url,
) -> Result<(WebSocketSink, WebSocketStream), TransportError> {
    let host = url
        .host_str()
        .ok_or_else(|| TransportError::backend(IoError("missing host in URL")))?;

    // Fail fast on anything but ws/wss. The WebSocket handshake rejects other
    // schemes anyway, so resolving and connecting first would only waste a DNS
    // lookup and a TCP connection.
    let default_port = match url.scheme() {
        "wss" => 443,
        "ws" => 80,
        _ => {
            return Err(TransportError::backend(IoError(
                "unsupported URL scheme (expected ws or wss)",
            )));
        }
    };
    let port = url.port().unwrap_or(default_port);

    // `host_str` keeps IPv6 literals in brackets (e.g. `[::1]`), so this is a
    // valid `host:port` string for domain names and IP literals alike.
    let addr_str = format!("{host}:{port}");
    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(&addr_str)
        .await
        .map_err(TransportError::backend)?
        .collect();
    if addrs.is_empty() {
        return Err(TransportError::backend(IoError(
            "DNS resolution returned no addresses",
        )));
    }

    // Split by address family, keeping resolver order within each family.
    let (ipv6_addrs, ipv4_addrs): (Vec<SocketAddr>, Vec<SocketAddr>) =
        addrs.into_iter().partition(SocketAddr::is_ipv6);
    let tcp_stream = happy_eyeballs_tcp(&ipv6_addrs, &ipv4_addrs).await?;

    // `client_async_tls` decides between plain and TLS from the URL scheme.
    let (ws_stream, _response) = tokio_tungstenite::client_async_tls(url.as_str(), tcp_stream)
        .await
        .map_err(TransportError::backend)?;

    let (native_sink, native_stream) = ws_stream.split();
    let sink: WebSocketSink = Box::new(NativeSinkAdapter(native_sink)) as WebSocketSink;
    let stream: WebSocketStream = Box::pin(native_stream.map(|result| {
        result
            .map(tungstenite_to_async_wsocket)
            .map_err(TransportError::backend)
    })) as WebSocketStream;
    Ok((sink, stream))
}
/// Opens a TCP connection to one of `ipv6_addrs` / `ipv4_addrs`, giving IPv6 a
/// head start when both families are available.
///
/// - Only one family resolved: its addresses are tried in order via
///   [`try_connect_addrs`].
/// - Both resolved: the IPv6 attempt starts alone.
///   - If it finishes before [`HAPPY_EYEBALLS_DELAY`] elapses and succeeds, its
///     stream is returned. If it fails, IPv4 is tried right away.
///   - Otherwise, when the delay elapses, the IPv4 attempt starts and races the
///     still-running IPv6 attempt. The first success wins. If one family fails
///     first, the other is awaited to completion.
/// - Neither resolved: an error is returned without connecting.
///
/// # Errors
///
/// When every attempt fails, the error from the family that failed last is
/// returned; the other family's error is discarded.
async fn happy_eyeballs_tcp(
    ipv6_addrs: &[SocketAddr],
    ipv4_addrs: &[SocketAddr],
) -> Result<TcpStream, TransportError> {
    match (ipv6_addrs.is_empty(), ipv4_addrs.is_empty()) {
        // IPv4 only.
        (true, false) => try_connect_addrs(ipv4_addrs).await,
        // IPv6 only.
        (false, true) => try_connect_addrs(ipv6_addrs).await,
        // Dual stack.
        (false, false) => {
            // Pinned so the same in-flight IPv6 attempt can be polled by
            // reference from both `select!`s below without being restarted.
            let ipv6_fut = try_connect_addrs(ipv6_addrs);
            tokio::pin!(ipv6_fut);
            let delay = tokio::time::sleep(HAPPY_EYEBALLS_DELAY);
            tokio::pin!(delay);
            tokio::select! {
                // Poll IPv6 before the timer, so an IPv6 result that is ready
                // when the delay also fires still takes the first branch.
                biased;
                result = &mut ipv6_fut => {
                    if let Ok(stream) = result {
                        return Ok(stream);
                    }
                    // IPv6 failed inside the head-start window: go straight to IPv4.
                    try_connect_addrs(ipv4_addrs).await
                }
                _ = &mut delay => {
                    // Head start used up: start IPv4 and race it against the
                    // IPv6 attempt that is still in flight.
                    let ipv4_fut = try_connect_addrs(ipv4_addrs);
                    tokio::pin!(ipv4_fut);
                    tokio::select! {
                        // If both are ready in the same poll, IPv6 is preferred.
                        biased;
                        result = &mut ipv6_fut => {
                            match result {
                                Ok(stream) => Ok(stream),
                                // IPv6 lost: the IPv4 outcome decides.
                                Err(_) => ipv4_fut.await,
                            }
                        }
                        result = &mut ipv4_fut => {
                            match result {
                                Ok(stream) => Ok(stream),
                                // IPv4 lost: the IPv6 outcome decides.
                                Err(_) => ipv6_fut.await,
                            }
                        }
                    }
                }
            }
        }
        // Nothing to dial.
        (true, true) => Err(TransportError::backend(IoError(
            "no addresses to connect to",
        ))),
    }
}
/// Tries each address in `addrs` in order and returns the first TCP stream
/// that connects.
///
/// Attempts run one after another with no per-address timeout. An address
/// that never answers therefore holds up the ones after it until an outer
/// timeout cancels the whole future.
///
/// # Errors
///
/// Returns the error from the last failed attempt. If `addrs` is empty, it
/// returns an `AddrNotAvailable` error instead.
async fn try_connect_addrs(addrs: &[SocketAddr]) -> Result<TcpStream, TransportError> {
    let mut failure: Option<std::io::Error> = None;
    for candidate in addrs.iter() {
        let err = match TcpStream::connect(candidate).await {
            Ok(tcp) => return Ok(tcp),
            Err(err) => err,
        };
        failure = Some(err);
    }

    let err = failure.unwrap_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            "no addresses to connect to",
        )
    });
    Err(TransportError::backend(err))
}
/// Converts a message read from a native `tokio-tungstenite` stream into the
/// `async_wsocket` [`Message`] type used by the transport layer.
///
/// Payloads are copied into owned `String` / `Vec<u8>` buffers. A close frame's
/// code is converted to `u16` and its reason to a `String`.
///
/// # Panics
///
/// Panics on `TungsteniteMessage::Frame`. Tungstenite documents that raw frames
/// are not produced while reading, and this function is only applied to the
/// read half of the socket.
fn tungstenite_to_async_wsocket(msg: TungsteniteMessage) -> Message {
    use async_wsocket::message::CloseFrame;

    match msg {
        TungsteniteMessage::Text(payload) => Message::Text(payload.to_string()),
        TungsteniteMessage::Binary(payload) => Message::Binary(payload.to_vec()),
        TungsteniteMessage::Ping(payload) => Message::Ping(payload.to_vec()),
        TungsteniteMessage::Pong(payload) => Message::Pong(payload.to_vec()),
        TungsteniteMessage::Close(None) => Message::Close(None),
        TungsteniteMessage::Close(Some(frame)) => Message::Close(Some(CloseFrame {
            code: frame.code.into(),
            reason: frame.reason.to_string(),
        })),
        TungsteniteMessage::Frame(_) => unreachable!(),
    }
}
struct NativeSinkAdapter<S>(S);
impl<S> Sink<Message> for NativeSinkAdapter<S>
where
S: Sink<TungsteniteMessage, Error = tokio_tungstenite::tungstenite::Error> + Unpin,
{
type Error = TransportError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_ready(cx)
.map_err(TransportError::backend)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
let native_msg: TungsteniteMessage = item.into();
Pin::new(&mut self.0)
.start_send(native_msg)
.map_err(TransportError::backend)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_flush(cx)
.map_err(TransportError::backend)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_close(cx)
.map_err(TransportError::backend)
}
}
/// Adapts the write half of an `async_wsocket` [`WebSocket`] to the transport's
/// sink type.
///
/// Messages pass through unchanged; only errors are wrapped with
/// `TransportError::backend`.
struct DefaultSinkAdapter(futures::stream::SplitSink<WebSocket, Message>);

// `SplitSink<WebSocket, Message>` is `Unpin`, so the `SinkExt::*_unpin` helpers
// are called on the field directly. No `Pin::new` wrapper is needed, so the
// call does not pass through the `Sink for Pin<P>` forwarding impl.
impl Sink<Message> for DefaultSinkAdapter {
    type Error = TransportError;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready_unpin(cx).map_err(TransportError::backend)
    }

    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        self.0.start_send_unpin(item).map_err(TransportError::backend)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_flush_unpin(cx).map_err(TransportError::backend)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_close_unpin(cx).map_err(TransportError::backend)
    }
}
/// Minimal error type carrying a fixed message.
///
/// Used for failures that have no underlying error value to wrap, such as a
/// URL without a host.
#[derive(Debug)]
struct IoError(&'static str);

impl std::error::Error for IoError {}

impl fmt::Display for IoError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let IoError(message) = self;
        write!(f, "{message}")
    }
}