use crate::websocket::{
tls::TlsConfig, wss_info::WssInfo, BaseStream, SocketMap, TlsConnectResult, TlsMidHandshake,
TlsSrvMidHandshake, TlsStream, WsConnectResult, WsMidHandshake, WsSrvAcceptResult,
WsSrvMidHandshake, WsStream, WssConnectResult, WssMidHandshake, WssSrvAcceptResult,
WssSrvMidHandshake, WssStream,
};
use holochain_tracing::prelude::*;
use holochain_tracing_macros::newrelic_autotrace;
use lib3h::transport::error::{TransportError, TransportResult};
use lib3h_protocol::{uri::Lib3hUri, DidWork};
use lib3h_zombie_actor::GhostMutex;
use std::{
io::{Read, Write},
sync::Arc,
};
use lazy_static::lazy_static;
use url::Url;
use url2::prelude::*;
/// Milliseconds since a connection's last recorded activity (a received message
/// or a handshake step) after which `StreamManager` sends a websocket ping on a
/// ready (`ReadyWs` / `ReadyWss`) connection.
pub const DEFAULT_HEARTBEAT_MS: usize = 2000;
/// Milliseconds since a connection's last recorded activity after which
/// `StreamManager` is meant to drop the connection and queue
/// `StreamEvent::ConnectionClosed`.
pub const DEFAULT_HEARTBEAT_WAIT_MS: usize = 5000;
/// Per-connection state machine, advanced by one step on each
/// `StreamManager::process` call.
///
/// Client side: `Connecting` -> `WsMidHandshake`* -> `ReadyWs` when TLS is
/// disabled, otherwise `Connecting` -> `TlsMidHandshake`* -> `TlsReady` ->
/// `WssMidHandshake`* -> `ReadyWss`.
/// Server side follows the same shape through the `*Srv*` variants and also
/// ends in `ReadyWs` / `ReadyWss`. (`*` = repeated while the handshake would block.)
#[derive(Debug)]
pub enum WebsocketStreamState<T: Read + Write + std::fmt::Debug> {
    /// No usable socket. While processing, the real state is temporarily swapped
    /// out for this; if it is still `None` afterwards the connection is dropped.
    None,
    /// Raw client stream, before any TLS or websocket handshake.
    Connecting(BaseStream<T>),
    /// Raw accepted (server-side) stream, before any TLS or websocket handshake.
    // NOTE(review): never constructed in this file; presumably created elsewhere
    // (e.g. by `WssInfo`) — confirm before removing the allow.
    #[allow(dead_code)]
    ConnectingSrv(BaseStream<T>),
    /// Client TLS handshake in progress (the last attempt would have blocked).
    TlsMidHandshake(TlsMidHandshake<T>),
    /// Server TLS handshake in progress (the last attempt would have blocked).
    TlsSrvMidHandshake(TlsSrvMidHandshake<T>),
    /// Client TLS established; websocket handshake not started yet.
    TlsReady(TlsStream<T>),
    /// Server TLS established; websocket accept not started yet.
    TlsSrvReady(TlsStream<T>),
    /// Plain client websocket handshake in progress.
    WsMidHandshake(WsMidHandshake<T>),
    /// Plain server websocket handshake in progress.
    WsSrvMidHandshake(WsSrvMidHandshake<T>),
    /// Client websocket-over-TLS handshake in progress.
    WssMidHandshake(WssMidHandshake<T>),
    /// Server websocket-over-TLS handshake in progress.
    WssSrvMidHandshake(WssSrvMidHandshake<T>),
    /// Plain websocket ready for messages.
    ReadyWs(Box<WsStream<T>>),
    /// Websocket over TLS ready for messages.
    ReadyWss(Box<WssStream<T>>),
}
/// Coarse connection state reported by `StreamManager::connection_status`.
///
/// Derives `Debug`/`Clone`/`Copy`/`Eq` so callers can log, copy and compare it
/// freely (previously only `PartialEq` was available).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// No socket is registered for the URL.
    None,
    /// A socket exists but its TLS / websocket handshake has not completed yet.
    Initializing,
    /// The websocket handshake has completed; messages can be sent.
    Ready,
}
/// Events queued by `StreamManager` and returned from `StreamManager::process`.
#[derive(Debug, PartialEq, Clone)]
pub enum StreamEvent {
    /// Processing the connection for this URL failed. `StreamManager` then drops
    /// the connection and also queues `ConnectionClosed` for it.
    // The variant name's spelling is public API and is kept for compatibility.
    ErrorOccured(Url, TransportError),
    /// A client websocket handshake completed; the `String` is the connection's
    /// request id.
    ConnectResult(Url, String),
    /// A server-side (accepted) websocket handshake completed.
    IncomingConnectionEstablished(Url),
    /// Payload of a received text (as UTF-8 bytes) or binary message.
    ReceivedData(Url, Vec<u8>),
    /// The connection was removed from the manager.
    ConnectionClosed(Url),
}
/// Opens the raw stream for a `"host:port"` address; called by
/// `StreamManager::connect` (the parameter is named `uri`, but it receives `host:port`).
pub type StreamFactory<T> = fn(uri: &str) -> TransportResult<T>;
lazy_static! {
    // NOTE(review): not referenced in this file as shown, and an `Arc` is
    // unnecessary for a static — confirm whether this is still needed.
    static ref TRANSPORT_COUNT: Arc<GhostMutex<u64>> = Arc::new(GhostMutex::new(0));
}
/// Polled once per `StreamManager::process` call to accept one pending incoming
/// connection. Errors whose `is_ignorable()` is true are skipped; any other
/// error makes `process` panic.
pub type Acceptor<T> = Box<dyn FnMut() -> TransportResult<WssInfo<T>> + 'static + Send + Sync>;
/// Starts listening on the given URL and returns the URL that was bound together
/// with the `Acceptor` for it; called by `StreamManager::bind`.
pub type Bind<T> =
    Box<dyn FnMut(&Url) -> TransportResult<(Url2, Acceptor<T>)> + 'static + Send + Sync>;
/// Owns a set of client and server websocket connections and drives them through
/// their TLS / websocket handshakes each time `process` is called.
pub struct StreamManager<T: Read + Write + std::fmt::Debug> {
    /// `TlsConfig::Unencrypted` selects plain websockets; anything else uses TLS.
    tls_config: TlsConfig,
    /// Opens the raw client stream for `connect`.
    stream_factory: StreamFactory<T>,
    /// Connections currently managed, keyed by URI.
    stream_sockets: SocketMap<T>,
    /// Events waiting to be handed out by the next `process` call.
    event_queue: Vec<StreamEvent>,
    /// Starts listening; invoked by `bind`.
    bind: Bind<T>,
    /// Listener polled for incoming connections; `Err` until `bind` succeeds.
    acceptor: TransportResult<Acceptor<T>>,
}
#[newrelic_autotrace(SIM2H)]
impl<T: Read + Write + std::fmt::Debug> StreamManager<T> {
    /// Creates a manager with no sockets and no listener.
    ///
    /// * `stream_factory` - opens the raw `"host:port"` stream used by `connect`.
    /// * `bind` - called by `bind` to start listening; the acceptor it returns is
    ///   polled on every `process` call.
    /// * `tls_config` - `TlsConfig::Unencrypted` selects plain websockets; any other
    ///   value wraps every connection in TLS before the websocket handshake.
    pub fn new(stream_factory: StreamFactory<T>, bind: Bind<T>, tls_config: TlsConfig) -> Self {
        StreamManager {
            tls_config,
            stream_factory,
            stream_sockets: std::collections::HashMap::new(),
            event_queue: Vec::new(),
            bind,
            // Replaced by `bind`; until then `process` does not accept connections.
            acceptor: Err(TransportError::new("acceptor not initialized".into())),
        }
    }

    /// Opens a client connection to `uri`.
    ///
    /// Only the raw stream is created here. The TLS and websocket handshakes are
    /// driven by later `process` calls, which queue `StreamEvent::ConnectResult`
    /// once the websocket is ready. Any socket already registered under the same
    /// URI is replaced.
    ///
    /// # Errors
    ///
    /// Fails if `uri` has no host, if it has neither an explicit port nor a known
    /// default port for its scheme, or if the stream factory fails.
    pub fn connect(&mut self, uri: &Url) -> TransportResult<()> {
        let host = uri
            .host_str()
            .ok_or_else(|| TransportError::new("bad connect host".into()))?;
        // `Url::port()` is `None` whenever the port equals the scheme's default
        // (e.g. `wss://example.com:443`), so fall back to that known default.
        let port = uri
            .port_or_known_default()
            .ok_or_else(|| TransportError::new("bad connect port".into()))?;
        let socket = (self.stream_factory)(&format!("{}:{}", host, port))?;
        let info = WssInfo::client(uri.clone(), socket);
        self.stream_sockets.insert(uri.clone().into(), info);
        Ok(())
    }

    /// Removes the socket registered for `uri`, if any, and closes it.
    ///
    /// No `StreamEvent::ConnectionClosed` is queued for an explicit close.
    ///
    /// # Errors
    ///
    /// Propagates the error from closing the socket; it is removed either way.
    #[allow(dead_code)]
    pub fn close(&mut self, uri: &Url) -> TransportResult<()> {
        if let Some(mut info) = self.stream_sockets.remove(uri) {
            info.close()?;
        }
        Ok(())
    }

    /// Closes and removes every socket, continuing past individual failures.
    ///
    /// # Errors
    ///
    /// If any socket fails to close, all collected errors are combined into a
    /// single `TransportError`.
    #[allow(dead_code)]
    pub fn close_all(&mut self) -> TransportResult<()> {
        let errors: Vec<TransportError> = self
            .stream_sockets
            .drain()
            .filter_map(|(_, mut info)| info.close().err())
            .collect();
        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors.into())
        }
    }

    /// Drives every connection forward by one step and returns the queued events.
    ///
    /// Each call accepts at most one incoming connection, advances each pending
    /// handshake by one step, reads at most one message per ready socket, and
    /// applies the heartbeat (ping after `DEFAULT_HEARTBEAT_MS` of inactivity,
    /// drop after `DEFAULT_HEARTBEAT_WAIT_MS`). The flag reports whether any work
    /// was done; the vector holds every event queued since the previous call.
    ///
    /// # Errors
    ///
    /// Per-connection failures are not returned here. They are reported as
    /// `StreamEvent::ErrorOccured` followed by `StreamEvent::ConnectionClosed`.
    pub fn process(&mut self) -> TransportResult<(DidWork, Vec<StreamEvent>)> {
        let did_work = self.priv_process_stream_sockets()?;
        Ok((did_work, self.event_queue.drain(..).collect()))
    }

    /// Sends `payload` as a binary websocket message on the connection for `url`.
    ///
    /// A write that fails with `WouldBlock` is reported as success; the frame is
    /// assumed to remain queued inside tungstenite until a later read or write
    /// flushes it.
    ///
    /// # Errors
    ///
    /// Fails if no socket is registered for `url`, if its handshake has not
    /// completed (`NotConnected`), or on any other write error.
    pub fn send(&mut self, url: &Url, payload: &[u8]) -> TransportResult<()> {
        let info = self
            .stream_sockets
            .get_mut(url)
            .ok_or_else(|| format!("No socket found for URL: {}", url))?;
        let send_result = match &mut info.stateful_socket {
            WebsocketStreamState::ReadyWs(socket) => {
                socket.write_message(tungstenite::Message::Binary(payload.to_vec()))
            }
            WebsocketStreamState::ReadyWss(socket) => {
                socket.write_message(tungstenite::Message::Binary(payload.to_vec()))
            }
            _ => Err(tungstenite::Error::Io(std::io::Error::from(
                std::io::ErrorKind::NotConnected,
            ))),
        };
        match send_result {
            Err(tungstenite::Error::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
                Ok(())
            }
            other => other.map_err(TransportError::from),
        }
    }

    /// Starts listening on `url` via the configured `Bind` callback.
    ///
    /// On success, the returned acceptor is polled on every `process` call and
    /// replaces any previous one.
    ///
    /// # Errors
    ///
    /// Propagates the error from the `Bind` callback; the previous acceptor is
    /// kept in that case.
    pub fn bind(&mut self, url: &Url) -> TransportResult<Url> {
        let (bound_url, acceptor) = (self.bind)(url)?;
        self.acceptor = Ok(acceptor);
        Ok(bound_url.into())
    }

    /// Reports whether a connection to `url` exists and whether it is ready.
    pub fn connection_status(&self, url: &Url) -> ConnectionStatus {
        match self.stream_sockets.get(url) {
            None => ConnectionStatus::None,
            Some(info) => match info.stateful_socket {
                WebsocketStreamState::ReadyWs(_) | WebsocketStreamState::ReadyWss(_) => {
                    ConnectionStatus::Ready
                }
                _ => ConnectionStatus::Initializing,
            },
        }
    }

    /// Polls the acceptor once and registers any newly accepted connection.
    ///
    /// Returns `false` when nothing is bound yet or nothing was accepted.
    ///
    /// # Panics
    ///
    /// Panics if the acceptor returns an error that is not `is_ignorable()`.
    fn priv_process_accept(&mut self) -> DidWork {
        let accepted = match &mut self.acceptor {
            // `bind` has not succeeded yet, so there is no listener to poll.
            Err(_) => return false,
            Ok(acceptor) => (acceptor)(),
        };
        match accepted {
            Ok(wss_info) => {
                self.stream_sockets
                    .insert(wss_info.url.clone().into(), wss_info);
                true
            }
            Err(err) => {
                if !err.is_ignorable() {
                    panic!("Error when attempting to accept connections: {:?}", err);
                }
                false
            }
        }
    }

    /// Accepts new connections, advances every socket by one step, and applies
    /// the heartbeat.
    ///
    /// Sockets are taken out of the map while being processed and re-inserted
    /// only if they are still alive. A connection is dropped, with
    /// `StreamEvent::ConnectionClosed` queued, when:
    /// * processing left it without a socket (error or peer close);
    /// * nothing happened on it for more than `DEFAULT_HEARTBEAT_WAIT_MS`; or
    /// * the heartbeat ping could not be written.
    fn priv_process_stream_sockets(&mut self) -> TransportResult<DidWork> {
        let mut did_work = self.priv_process_accept();
        let sockets: Vec<(Lib3hUri, WssInfo<T>)> = self.stream_sockets.drain().collect();
        for (id, mut info) in sockets {
            if let Err(e) = self.priv_process_socket(&mut did_work, &mut info) {
                self.event_queue
                    .push(StreamEvent::ErrorOccured(info.url.clone(), e));
            }
            if let WebsocketStreamState::None = info.stateful_socket {
                self.event_queue
                    .push(StreamEvent::ConnectionClosed(info.url));
                continue;
            }

            // The wait timeout must be checked before the ping threshold: it is the
            // larger of the two, so testing the ping threshold first would shadow it
            // and a dead connection would never be dropped.
            let idle_ms = info.last_msg.elapsed().as_millis();
            if idle_ms > DEFAULT_HEARTBEAT_WAIT_MS as u128 {
                self.event_queue
                    .push(StreamEvent::ConnectionClosed(info.url));
                continue;
            }
            if idle_ms > DEFAULT_HEARTBEAT_MS as u128 {
                // NOTE(review): a ping goes out on every call until activity resets
                // `last_msg`; rate-limiting would need extra per-connection state.
                let ping_result = match &mut info.stateful_socket {
                    WebsocketStreamState::ReadyWs(socket) => {
                        socket.write_message(tungstenite::Message::Ping(vec![]))
                    }
                    WebsocketStreamState::ReadyWss(socket) => {
                        socket.write_message(tungstenite::Message::Ping(vec![]))
                    }
                    // Still handshaking: there is no websocket to ping yet.
                    _ => Ok(()),
                };
                match ping_result {
                    // Same treatment as `send`: a blocked flush is not a failure.
                    Err(tungstenite::Error::Io(ref e))
                        if e.kind() == std::io::ErrorKind::WouldBlock => {}
                    Err(e) => {
                        error!("Transport error trying to send ping over stream: {:?}. Dropping stream...", e);
                        self.event_queue
                            .push(StreamEvent::ConnectionClosed(info.url));
                        continue;
                    }
                    Ok(()) => {}
                }
            }
            self.stream_sockets.insert(id, info);
        }
        Ok(did_work)
    }

    /// Advances a single connection by one step of its state machine.
    ///
    /// The current state is swapped out for `WebsocketStreamState::None` and the
    /// next state is written back on success. On error, or when the peer closed
    /// the connection, the state is left as `None`, and the caller then drops it.
    fn priv_process_socket(
        &mut self,
        did_work: &mut bool,
        info: &mut WssInfo<T>,
    ) -> TransportResult<()> {
        let socket = std::mem::replace(&mut info.stateful_socket, WebsocketStreamState::None);
        trace!("transport_wss: socket={:?}", socket);
        match socket {
            WebsocketStreamState::None => Ok(()),
            WebsocketStreamState::Connecting(socket) => {
                info.last_msg = std::time::Instant::now();
                *did_work = true;
                match &self.tls_config {
                    TlsConfig::Unencrypted => {
                        info.stateful_socket = self.priv_ws_handshake(
                            &info.url,
                            &info.request_id,
                            tungstenite::client(info.url.clone(), socket),
                        )?;
                    }
                    _ => {
                        // NOTE(review): certificate and hostname checks are disabled,
                        // so this TLS layer does not authenticate the peer.
                        let connector = native_tls::TlsConnector::builder()
                            .danger_accept_invalid_certs(true)
                            .danger_accept_invalid_hostnames(true)
                            .build()
                            .expect("failed to build TlsConnector");
                        // native-tls wants a bare domain (used for SNI), not the full
                        // `wss://host:port/...` URL.
                        let domain = info.url.host_str().unwrap_or_else(|| info.url.as_str());
                        info.stateful_socket =
                            self.priv_tls_handshake(connector.connect(domain, socket))?;
                    }
                }
                Ok(())
            }
            WebsocketStreamState::ConnectingSrv(socket) => {
                info.last_msg = std::time::Instant::now();
                *did_work = true;
                match &self.tls_config {
                    TlsConfig::Unencrypted => {
                        info.stateful_socket =
                            self.priv_ws_srv_handshake(&info.url, tungstenite::accept(socket))?;
                    }
                    _ => {
                        let ident = self.tls_config.get_identity()?;
                        let acceptor = native_tls::TlsAcceptor::builder(ident)
                            .build()
                            .expect("failed to build TlsAcceptor");
                        info.stateful_socket =
                            self.priv_tls_srv_handshake(acceptor.accept(socket))?;
                    }
                }
                Ok(())
            }
            WebsocketStreamState::TlsMidHandshake(socket) => {
                info.stateful_socket = self.priv_tls_handshake(socket.handshake())?;
                Ok(())
            }
            WebsocketStreamState::TlsSrvMidHandshake(socket) => {
                info.stateful_socket = self.priv_tls_srv_handshake(socket.handshake())?;
                Ok(())
            }
            WebsocketStreamState::TlsReady(socket) => {
                info.last_msg = std::time::Instant::now();
                *did_work = true;
                info.stateful_socket = self.priv_wss_handshake(
                    &info.url,
                    &info.request_id,
                    tungstenite::client(info.url.clone(), socket),
                )?;
                Ok(())
            }
            WebsocketStreamState::TlsSrvReady(socket) => {
                info.last_msg = std::time::Instant::now();
                *did_work = true;
                info.stateful_socket =
                    self.priv_wss_srv_handshake(&info.url, tungstenite::accept(socket))?;
                Ok(())
            }
            WebsocketStreamState::WsMidHandshake(socket) => {
                info.stateful_socket =
                    self.priv_ws_handshake(&info.url, &info.request_id, socket.handshake())?;
                Ok(())
            }
            WebsocketStreamState::WsSrvMidHandshake(socket) => {
                info.stateful_socket = self.priv_ws_srv_handshake(&info.url, socket.handshake())?;
                Ok(())
            }
            WebsocketStreamState::WssMidHandshake(socket) => {
                info.stateful_socket =
                    self.priv_wss_handshake(&info.url, &info.request_id, socket.handshake())?;
                Ok(())
            }
            WebsocketStreamState::WssSrvMidHandshake(socket) => {
                info.stateful_socket =
                    self.priv_wss_srv_handshake(&info.url, socket.handshake())?;
                Ok(())
            }
            WebsocketStreamState::ReadyWs(mut socket) => match socket.read_message() {
                Err(tungstenite::error::Error::Io(e)) => {
                    if e.kind() == std::io::ErrorKind::WouldBlock {
                        // Nothing to read yet; keep the connection unchanged.
                        info.stateful_socket = WebsocketStreamState::ReadyWs(socket);
                        return Ok(());
                    }
                    Err(e.into())
                }
                Err(tungstenite::error::Error::ConnectionClosed) => {
                    // Leaving the state as `None` makes the caller drop the connection.
                    error!("Connection unexpectedly closed");
                    Ok(())
                }
                Err(e) => Err(e.into()),
                Ok(msg) => {
                    info.last_msg = std::time::Instant::now();
                    *did_work = true;
                    if let Some(payload) = Self::priv_message_payload(msg) {
                        self.event_queue
                            .push(StreamEvent::ReceivedData(info.url.clone(), payload));
                    }
                    info.stateful_socket = WebsocketStreamState::ReadyWs(socket);
                    Ok(())
                }
            },
            WebsocketStreamState::ReadyWss(mut socket) => match socket.read_message() {
                Err(tungstenite::error::Error::Io(e)) => {
                    if e.kind() == std::io::ErrorKind::WouldBlock {
                        // Nothing to read yet; keep the connection unchanged.
                        info.stateful_socket = WebsocketStreamState::ReadyWss(socket);
                        return Ok(());
                    }
                    Err(e.into())
                }
                Err(tungstenite::error::Error::ConnectionClosed) => {
                    // Leaving the state as `None` makes the caller drop the connection.
                    error!("Connection unexpectedly closed");
                    Ok(())
                }
                Err(e) => Err(e.into()),
                Ok(msg) => {
                    info.last_msg = std::time::Instant::now();
                    *did_work = true;
                    if let Some(payload) = Self::priv_message_payload(msg) {
                        self.event_queue
                            .push(StreamEvent::ReceivedData(info.url.clone(), payload));
                    }
                    info.stateful_socket = WebsocketStreamState::ReadyWss(socket);
                    Ok(())
                }
            },
        }
    }

    /// Extracts the application payload of a received websocket message.
    ///
    /// Text is returned as its UTF-8 bytes and binary as-is; every other message
    /// kind yields `None`.
    fn priv_message_payload(msg: tungstenite::Message) -> Option<Vec<u8>> {
        match msg {
            tungstenite::Message::Text(s) => Some(s.into_bytes()),
            tungstenite::Message::Binary(b) => Some(b),
            _ => None,
        }
    }

    /// Maps a client TLS handshake result to the next state.
    ///
    /// `WouldBlock` maps to `TlsMidHandshake` (retry on the next `process`), and
    /// success maps to `TlsReady`.
    fn priv_tls_handshake(
        &mut self,
        res: TlsConnectResult<T>,
    ) -> TransportResult<WebsocketStreamState<T>> {
        match res {
            Err(native_tls::HandshakeError::WouldBlock(socket)) => {
                Ok(WebsocketStreamState::TlsMidHandshake(socket))
            }
            Err(e) => Err(e.into()),
            Ok(socket) => Ok(WebsocketStreamState::TlsReady(socket)),
        }
    }

    /// Maps a server TLS accept result to the next state.
    ///
    /// `WouldBlock` maps to `TlsSrvMidHandshake`, and success maps to `TlsSrvReady`.
    fn priv_tls_srv_handshake(
        &mut self,
        res: TlsConnectResult<T>,
    ) -> TransportResult<WebsocketStreamState<T>> {
        trace!("[t] processing tls connect result: {:?}", res);
        match res {
            Err(native_tls::HandshakeError::WouldBlock(socket)) => {
                Ok(WebsocketStreamState::TlsSrvMidHandshake(socket))
            }
            Err(e) => Err(e.into()),
            Ok(socket) => Ok(WebsocketStreamState::TlsSrvReady(socket)),
        }
    }

    /// Maps a plain client websocket handshake result to the next state.
    ///
    /// On success, queues `StreamEvent::ConnectResult` with `request_id` and
    /// returns `ReadyWs`.
    fn priv_ws_handshake(
        &mut self,
        url: &Url,
        request_id: &str,
        res: WsConnectResult<T>,
    ) -> TransportResult<WebsocketStreamState<T>> {
        match res {
            Err(tungstenite::HandshakeError::Interrupted(socket)) => {
                Ok(WebsocketStreamState::WsMidHandshake(socket))
            }
            Err(e) => Err(e.into()),
            Ok((socket, _response)) => {
                self.event_queue.push(StreamEvent::ConnectResult(
                    url.clone(),
                    request_id.to_string(),
                ));
                Ok(WebsocketStreamState::ReadyWs(Box::new(socket)))
            }
        }
    }

    /// Maps a client websocket-over-TLS handshake result to the next state.
    ///
    /// On success, queues `StreamEvent::ConnectResult` with `request_id` and
    /// returns `ReadyWss`.
    fn priv_wss_handshake(
        &mut self,
        url: &Url,
        request_id: &str,
        res: WssConnectResult<T>,
    ) -> TransportResult<WebsocketStreamState<T>> {
        match res {
            Err(tungstenite::HandshakeError::Interrupted(socket)) => {
                Ok(WebsocketStreamState::WssMidHandshake(socket))
            }
            Err(e) => Err(e.into()),
            Ok((socket, _response)) => {
                self.event_queue.push(StreamEvent::ConnectResult(
                    url.clone(),
                    request_id.to_string(),
                ));
                Ok(WebsocketStreamState::ReadyWss(Box::new(socket)))
            }
        }
    }

    /// Maps a plain server websocket accept result to the next state.
    ///
    /// On success, queues `StreamEvent::IncomingConnectionEstablished` and
    /// returns `ReadyWs`.
    fn priv_ws_srv_handshake(
        &mut self,
        url: &Url,
        res: WsSrvAcceptResult<T>,
    ) -> TransportResult<WebsocketStreamState<T>> {
        match res {
            Err(tungstenite::HandshakeError::Interrupted(socket)) => {
                Ok(WebsocketStreamState::WsSrvMidHandshake(socket))
            }
            Err(e) => Err(e.into()),
            Ok(socket) => {
                self.event_queue
                    .push(StreamEvent::IncomingConnectionEstablished(url.clone()));
                Ok(WebsocketStreamState::ReadyWs(Box::new(socket)))
            }
        }
    }

    /// Maps a server websocket-over-TLS accept result to the next state.
    ///
    /// On success, queues `StreamEvent::IncomingConnectionEstablished` and
    /// returns `ReadyWss`.
    fn priv_wss_srv_handshake(
        &mut self,
        url: &Url,
        res: WssSrvAcceptResult<T>,
    ) -> TransportResult<WebsocketStreamState<T>> {
        match res {
            Err(tungstenite::HandshakeError::Interrupted(socket)) => {
                Ok(WebsocketStreamState::WssSrvMidHandshake(socket))
            }
            Err(e) => Err(e.into()),
            Ok(socket) => {
                self.event_queue
                    .push(StreamEvent::IncomingConnectionEstablished(url.clone()));
                Ok(WebsocketStreamState::ReadyWss(Box::new(socket)))
            }
        }
    }
}