use crate::arg::ClientArgs;
use crate::tls::{MaybeTlsStream, tls_connect};
use http::header::HeaderValue;
use penguin_mux::{Dupe, PROTOCOL_VERSION};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::{client::IntoClientRequest, handshake::client::Request};
use tokio_tungstenite::{WebSocketStream, client_async};
use tracing::{debug, warn};
#[tracing::instrument(skip_all, fields(server = %args.server.0), level = "debug")]
async fn handshake_inner(
args: &ClientArgs,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, super::Error> {
let is_tls = args
.server
.scheme()
.expect("URL scheme should be present (this is a bug)")
.as_str()
== "wss";
let host = args
.server
.0
.host()
.expect("URL host should be present (this is a bug)");
let host = crate::parse_remote::remove_brackets(host);
let port = args
.server
.0
.port_u16()
.expect("URL port should be present (this is a bug)");
let mut tls_server_name = host;
let mut req: Request = args.server.0.dupe().into_client_request()?;
let req_headers = req.headers_mut();
req_headers.insert(
"sec-websocket-protocol",
HeaderValue::from_static(PROTOCOL_VERSION),
);
if let Some(ref ws_psk) = args.ws_psk {
req_headers.insert("x-penguin-psk", ws_psk.dupe());
}
if let Some(ref hostname) = args.hostname {
req_headers.insert("host", hostname.dupe());
tls_server_name = hostname.to_str().map_err(super::Error::InvalidDomainName)?;
}
if let Some(tls_sni) = args.tls_server_name.as_deref() {
tls_server_name = tls_sni;
}
for header in &args.header {
req_headers.insert(&header.name, header.value.dupe());
}
let stream = if is_tls {
MaybeTlsStream::Tls(
tls_connect(
host,
port,
tls_server_name,
args.tls_cert.as_deref(),
args.tls_key.as_deref(),
args.tls_ca.as_deref(),
args.tls_skip_verify,
)
.await?,
)
} else {
warn!("Using insecure WebSocket connection");
MaybeTlsStream::Plain(
TcpStream::connect((host, port))
.await
.map_err(super::Error::TcpConnect)?,
)
};
let (ws_stream, _response) = client_async(req, stream).await?;
debug!("WebSocket handshake succeeded");
Ok(ws_stream)
}
pub async fn handshake(
args: &ClientArgs,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, super::Error> {
tokio::select! {
result = handshake_inner(args) => result,
() = args.handshake_timeout.sleep() => Err(super::Error::HandshakeTimeout),
Ok(()) = tokio::signal::ctrl_c() => Err(super::Error::HandshakeCancelled),
}
}