mod framed;
use std::{
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use airio_core::{DialOpts, ListenerId, TransportError, TransportEvent, utils::RwStreamSink};
pub use async_tungstenite::tungstenite::protocol::WebSocketConfig;
use async_tungstenite::{
accept_async_with_config, client_async_with_config,
tungstenite::{self, http::Uri},
};
pub use framed::BytesWebSocketStream;
use futures::{FutureExt, TryFutureExt};
/// WebSocket transport layered on top of an [`airio_tcp::Transport`].
///
/// All listening and dialing is done by the wrapped TCP transport; this type
/// adds the WebSocket handshake on top of each resulting TCP stream.
pub struct Transport {
    /// WebSocket settings passed unchanged to both the client and the server
    /// handshake functions of `async_tungstenite`.
    config: Option<WebSocketConfig>,
    /// The TCP transport that owns the listeners and performs the dials.
    inner: airio_tcp::Transport,
}
impl Transport {
pub fn new(config: Option<WebSocketConfig>, inner: airio_tcp::Transport) -> Self {
Self { config, inner }
}
}
/// WebSocket-over-TCP implementation of [`airio_core::Transport`].
///
/// Listener management and TCP connection setup are delegated to the inner
/// [`airio_tcp::Transport`]; every resulting stream is then upgraded with a
/// WebSocket handshake (client side for dials, server side for inbound
/// connections) and exposed as a byte stream/sink.
impl airio_core::Transport for Transport {
    type Output = RwStreamSink<BytesWebSocketStream<airio_tcp::TcpStream>>;
    type Error = tungstenite::Error;
    type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;

    /// Starts listening on `addr` through the inner TCP transport.
    ///
    /// Errors from the inner transport are converted into
    /// [`tungstenite::Error`] without otherwise being altered.
    fn listen_on(
        &mut self,
        id: ListenerId,
        addr: SocketAddr,
    ) -> Result<(), TransportError<Self::Error>> {
        self.inner
            .listen_on(id, addr)
            .map_err(|e| e.map(|r| r.into()))
    }

    /// Removes the listener `id` from the inner TCP transport, returning
    /// whatever the inner transport reports.
    fn remove_listener(&mut self, id: ListenerId) -> bool {
        self.inner.remove_listener(id)
    }

    /// Dials `addr` over TCP, then performs the WebSocket client handshake
    /// with request target `ws://<ip>:<port>/`.
    ///
    /// Errors from the inner dial are converted into [`tungstenite::Error`];
    /// a failure to build the request URI is returned as
    /// [`TransportError::Other`].
    fn dial(
        &mut self,
        addr: SocketAddr,
        opts: DialOpts,
    ) -> Result<Self::Dial, TransportError<Self::Error>> {
        let dial_fut = self
            .inner
            .dial(addr, opts)
            .map_err(|e| e.map(|r| r.into()))?;
        let config = self.config.clone();
        // The URI only feeds the handshake request (target and `Host`); the TCP
        // connection above already uses the full `addr`. Rebuild the address
        // from ip + port so any IPv6 zone id is dropped: `SocketAddrV6`'s
        // `Display` writes a non-zero scope as `[ip%scope]:port`, which is not a
        // valid URI authority and would make dials to scoped link-local
        // addresses fail here before any handshake is attempted.
        let authority = SocketAddr::new(addr.ip(), addr.port()).to_string();
        let request = Uri::builder()
            .scheme("ws")
            .authority(authority)
            .path_and_query("/")
            .build()
            .map_err(tungstenite::Error::from)
            .map_err(TransportError::Other)?;
        Ok(dial_fut
            .map_err(tungstenite::Error::from)
            .and_then(move |stream| client_async_with_config(request, stream, config))
            // The handshake response is not needed once the stream is established.
            .map_ok(|(s, _)| BytesWebSocketStream::new(s))
            .map_ok(RwStreamSink::new)
            .boxed())
    }

    /// Polls the inner TCP transport for events.
    ///
    /// Any inbound-connection upgrade future is extended with the WebSocket
    /// server handshake, and error values are converted into
    /// [`tungstenite::Error`].
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
        // Cloned up front: `self` is mutably borrowed by the inner poll below,
        // so the closure cannot read `self.config` itself.
        let config = self.config.clone();
        Pin::new(&mut self.inner).poll(cx).map(|r| {
            r.map_upgrade(|u| {
                u.map_err(tungstenite::Error::from)
                    .and_then(move |s| accept_async_with_config(s, config))
                    .map_ok(BytesWebSocketStream::new)
                    .map_ok(RwStreamSink::new)
                    .boxed()
            })
            .map_err(tungstenite::Error::from)
        })
    }
}
/// Converts a [`tungstenite::Error`] into an [`io::Error`].
///
/// If tungstenite is only wrapping an I/O error, that inner error is returned
/// unchanged, so its original `ErrorKind` is preserved. Every other variant is
/// wrapped in an `io::Error` of kind [`io::ErrorKind::Other`].
fn into_io_error(error: tungstenite::Error) -> io::Error {
    match error {
        tungstenite::Error::Io(e) => e,
        // Shorthand for `io::Error::new(io::ErrorKind::Other, e)`.
        e => io::Error::other(e),
    }
}