#![forbid(unsafe_code)]
use std::pin::Pin;
use std::sync::Arc;
use base64::Engine;
use bytes::Bytes;
use http_body_util::Empty;
use hyper::Request;
use hyper::header::{HeaderName, HeaderValue};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
#[allow(unused_imports)]
use uuid::Uuid;
/// WebSocket subprotocol token: sent in the client's `Sec-WebSocket-Protocol` request
/// header by [`dial_initiator`], required to be offered by [`build_upgrade_response`],
/// and echoed back in the server's 101 response.
pub const SUBPROTOCOL: &str = "boardwalk-peer/2";
/// Fixed GUID from RFC 6455 that is appended to the client key before hashing when
/// computing `Sec-WebSocket-Accept` (see [`ws_accept_key`]).
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Errors produced while dialing, upgrading, or securing a peer tunnel.
#[derive(Debug, Error)]
pub enum TunnelError {
/// The WebSocket-style handshake was malformed or rejected: bad or missing
/// `Sec-WebSocket-*` headers, a request/response that failed to build, or a
/// TLS certificate verifier that could not be set up.
#[error("websocket upgrade: {0}")]
Upgrade(String),
/// The remote URL could not be used: parse failure, missing host, unsupported
/// scheme, or a host that is not a valid TLS server name.
#[error("invalid url: {0}")]
Url(String),
/// Wrapped `h2` error. Not produced by the code shown here; presumably raised
/// by HTTP/2 handling elsewhere in the crate (TODO confirm).
#[error("h2: {0}")]
H2(#[from] h2::Error),
/// Socket-level failure, e.g. the TCP connect or the TLS handshake.
#[error("io: {0}")]
Io(#[from] std::io::Error),
/// Failure inside hyper: the HTTP/1.1 client handshake, sending the upgrade
/// request, or taking over the upgraded connection.
#[error("hyper: {0}")]
Hyper(#[from] hyper::Error),
/// The server answered with an unexpected response (a status other than
/// `101 Switching Protocols`).
#[error("invalid response: {0}")]
Response(String),
}
/// Derive the `Sec-WebSocket-Accept` value a server must return for a client's
/// `Sec-WebSocket-Key`.
///
/// The result is the standard-alphabet base64 encoding of SHA-1 over the key
/// bytes followed by the RFC 6455 GUID.
pub fn ws_accept_key(client_key: &str) -> String {
    let mut hasher = Sha1::new();
    for chunk in [client_key, WS_GUID] {
        hasher.update(chunk.as_bytes());
    }
    let digest = hasher.finalize();
    base64::engine::general_purpose::STANDARD.encode(digest)
}
/// Generate a fresh `Sec-WebSocket-Key`: 16 random bytes from the operating
/// system RNG, base64-encoded with the standard alphabet (24 characters).
///
/// # Panics
///
/// Panics with `os rng` if the operating system RNG reports an error.
pub fn ws_new_key() -> String {
    use rand::TryRngCore;
    let nonce: [u8; 16] = {
        let mut bytes = [0u8; 16];
        rand::rngs::OsRng
            .try_fill_bytes(&mut bytes)
            .expect("os rng");
        bytes
    };
    base64::engine::general_purpose::STANDARD.encode(nonce)
}
/// Outcome of a successful [`dial_initiator`] handshake.
pub struct InitiatorReady {
/// The connection taken over from hyper after the `101 Switching Protocols`
/// response; raw bytes now flow over it.
pub upgraded: hyper::upgrade::Upgraded,
/// The value sent in the request's `Host` header: the URL host, plus `:port`
/// only when `url::Url::port` reported an explicit (non-default) port.
pub remote_authority: String,
}
pub async fn dial_initiator(
remote_url: &str,
local_name: &str,
connection_id: Uuid,
) -> Result<InitiatorReady, TunnelError> {
let url = url::Url::parse(remote_url).map_err(|e| TunnelError::Url(format!("{e}")))?;
let scheme = url.scheme();
let host = url
.host_str()
.ok_or_else(|| TunnelError::Url("no host".into()))?;
let port = url.port_or_known_default().unwrap_or(match scheme {
"https" | "wss" => 443,
_ => 80,
});
let tcp = TcpStream::connect((host, port)).await?;
let stream: Pin<Box<dyn AsyncReadWrite + Send>> = match scheme {
"http" | "ws" => Box::pin(tcp),
"https" | "wss" => Box::pin(tls_connect(host, tcp).await?),
other => {
return Err(TunnelError::Url(format!("scheme `{other}` not supported")));
}
};
let path = format!(
"/peers/{}?connectionId={}",
urlencoding::encode(local_name),
connection_id
);
let authority = if let Some(p) = url.port() {
format!("{host}:{p}")
} else {
host.to_string()
};
let key = ws_new_key();
let io = TokioIo::new(stream);
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
.handshake::<_, Empty<Bytes>>(io)
.await?;
let conn = conn.with_upgrades();
let conn_task = tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::builder()
.method("POST")
.uri(&path)
.header(hyper::header::HOST, authority.clone())
.header(hyper::header::CONNECTION, "Upgrade")
.header(hyper::header::UPGRADE, "websocket")
.header(HeaderName::from_static("sec-websocket-key"), key.clone())
.header(HeaderName::from_static("sec-websocket-version"), "13")
.header(
HeaderName::from_static("sec-websocket-protocol"),
HeaderValue::from_static(SUBPROTOCOL),
)
.body(Empty::<Bytes>::new())
.map_err(|e| TunnelError::Upgrade(format!("build request: {e}")))?;
let response = sender.send_request(req).await?;
drop(sender);
if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS {
return Err(TunnelError::Response(format!(
"expected 101, got {}",
response.status()
)));
}
let expected_accept = ws_accept_key(&key);
if let Some(got) = response.headers().get("sec-websocket-accept") {
if got.to_str().ok() != Some(expected_accept.as_str()) {
return Err(TunnelError::Upgrade("invalid Sec-WebSocket-Accept".into()));
}
} else {
return Err(TunnelError::Upgrade("missing Sec-WebSocket-Accept".into()));
}
let upgraded = hyper::upgrade::on(response).await?;
drop(conn_task);
Ok(InitiatorReady {
upgraded,
remote_authority: authority,
})
}
pub fn build_upgrade_response(
headers: &http::HeaderMap,
) -> Result<http::Response<()>, TunnelError> {
let key = headers
.get("sec-websocket-key")
.and_then(|v| v.to_str().ok())
.ok_or_else(|| TunnelError::Upgrade("missing Sec-WebSocket-Key".into()))?;
let version = headers
.get("sec-websocket-version")
.and_then(|v| v.to_str().ok());
if version != Some("13") {
return Err(TunnelError::Upgrade(
"missing or wrong Sec-WebSocket-Version".into(),
));
}
let offered = headers
.get("sec-websocket-protocol")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !offered.split(',').any(|tok| tok.trim() == SUBPROTOCOL) {
return Err(TunnelError::Upgrade(format!(
"client did not offer `{SUBPROTOCOL}` subprotocol; got `{offered}`"
)));
}
let accept = ws_accept_key(key);
let builder = http::Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-accept", accept)
.header("sec-websocket-protocol", SUBPROTOCOL);
builder
.body(())
.map_err(|e| TunnelError::Upgrade(format!("build response: {e}")))
}
pub use hyper_util::rt::TokioExecutor as H2Executor;
/// Object-safe union of [`AsyncRead`] and [`AsyncWrite`], so a plain TCP
/// stream and a TLS stream can both be boxed as `dyn AsyncReadWrite`.
pub trait AsyncReadWrite
where
    Self: AsyncRead + AsyncWrite,
{
}

/// Blanket impl: every async reader/writer, sized or not, satisfies the trait.
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + ?Sized {}
/// Run a TLS client handshake over an already-connected TCP stream.
///
/// On first use, this installs the aws-lc-rs crypto provider as the rustls
/// process default. Certificates are checked by the platform verifier. With
/// the `dangerous-test-tls` feature, an accept-everything verifier is used
/// instead.
///
/// `host` may be a DNS name, a bare IP address, or a bracketed IPv6 literal
/// as rendered in URLs (`[::1]`). The brackets are stripped before building
/// the TLS server name.
///
/// # Errors
///
/// - [`TunnelError::Upgrade`] if the platform verifier cannot be built.
/// - [`TunnelError::Url`] if `host` is not a valid TLS server name.
/// - [`TunnelError::Io`] if the handshake fails.
async fn tls_connect(
    host: &str,
    tcp: TcpStream,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, TunnelError> {
    static PROVIDER_INSTALLED: std::sync::Once = std::sync::Once::new();
    PROVIDER_INSTALLED.call_once(|| {
        // `install_default` fails only if a process-wide default provider is
        // already installed, in which case that one is used; ignore the error.
        let _ = rustls::crypto::CryptoProvider::install_default(
            rustls::crypto::aws_lc_rs::default_provider(),
        );
    });
    #[cfg(feature = "dangerous-test-tls")]
    let config = {
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(dangerous_test_verifier())
            .with_no_client_auth()
    };
    #[cfg(not(feature = "dangerous-test-tls"))]
    let config = {
        use rustls_platform_verifier::BuilderVerifierExt;
        rustls::ClientConfig::builder()
            .with_platform_verifier()
            .map_err(|e| TunnelError::Upgrade(format!("rustls platform verifier: {e}")))?
            .with_no_client_auth()
    };
    let connector = tokio_rustls::TlsConnector::from(Arc::new(config));
    // `ServerName` parses "::1" as an IP address but rejects "[::1]", so drop
    // the URL-style brackets around IPv6 literals.
    let bare_host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    let server_name = rustls_pki_types::ServerName::try_from(bare_host.to_string())
        .map_err(|e| TunnelError::Url(format!("invalid TLS server name: {e}")))?;
    let tls = connector
        .connect(server_name, tcp)
        .await
        .map_err(TunnelError::Io)?;
    Ok(tls)
}
/// Build a certificate verifier that accepts every server certificate and
/// every handshake signature without checking anything.
///
/// Only compiled with the `dangerous-test-tls` feature. It disables TLS
/// server authentication entirely, so it must never be enabled in production
/// builds.
#[cfg(feature = "dangerous-test-tls")]
fn dangerous_test_verifier() -> Arc<dyn rustls::client::danger::ServerCertVerifier> {
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{DigitallySignedStruct, SignatureScheme};

    /// Schemes advertised to the peer; signatures in any of them are accepted unchecked.
    const ADVERTISED_SCHEMES: [SignatureScheme; 9] = [
        SignatureScheme::ECDSA_NISTP256_SHA256,
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::ED25519,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::RSA_PKCS1_SHA256,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::RSA_PKCS1_SHA512,
    ];

    #[derive(Debug)]
    struct AcceptAnyCert;

    impl ServerCertVerifier for AcceptAnyCert {
        fn verify_server_cert(
            &self,
            _leaf: &CertificateDer<'_>,
            _chain: &[CertificateDer<'_>],
            _name: &ServerName<'_>,
            _ocsp: &[u8],
            _at: UnixTime,
        ) -> Result<ServerCertVerified, rustls::Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            _msg: &[u8],
            _signer: &CertificateDer<'_>,
            _signature: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn verify_tls13_signature(
            &self,
            _msg: &[u8],
            _signer: &CertificateDer<'_>,
            _signature: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            ADVERTISED_SCHEMES.to_vec()
        }
    }

    Arc::new(AcceptAnyCert)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 6455 §1.3 sample key ("the sample nonce" in base64).
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    /// Accept value that RFC 6455 §1.3 gives for `SAMPLE_KEY`.
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// Headers of a well-formed client handshake offering `protocols`.
    fn client_headers(protocols: &'static str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert("sec-websocket-key", HeaderValue::from_static(SAMPLE_KEY));
        headers.insert("sec-websocket-version", HeaderValue::from_static("13"));
        headers.insert("sec-websocket-protocol", HeaderValue::from_static(protocols));
        headers
    }

    #[test]
    fn rfc6455_example_accept_key() {
        assert_eq!(ws_accept_key(SAMPLE_KEY), SAMPLE_ACCEPT);
    }

    #[test]
    fn ws_new_key_is_base64_16_bytes() {
        let k = ws_new_key();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&k)
            .expect("ws_new_key must produce valid base64");
        assert_eq!(decoded.len(), 16);
    }

    #[test]
    fn upgrade_response_accepts_offered_subprotocol() {
        let resp = build_upgrade_response(&client_headers("chat, boardwalk-peer/2"))
            .expect("well-formed handshake must be accepted");
        assert_eq!(resp.status(), http::StatusCode::SWITCHING_PROTOCOLS);
        let h = resp.headers();
        assert_eq!(h["sec-websocket-accept"].to_str().unwrap(), SAMPLE_ACCEPT);
        assert_eq!(h["sec-websocket-protocol"].to_str().unwrap(), SUBPROTOCOL);
        assert_eq!(h["upgrade"].to_str().unwrap(), "websocket");
    }

    #[test]
    fn upgrade_response_rejects_wrong_version() {
        let mut headers = client_headers(SUBPROTOCOL);
        headers.insert("sec-websocket-version", HeaderValue::from_static("8"));
        assert!(matches!(
            build_upgrade_response(&headers),
            Err(TunnelError::Upgrade(_))
        ));
    }

    #[test]
    fn upgrade_response_rejects_missing_key() {
        let mut headers = client_headers(SUBPROTOCOL);
        headers.remove("sec-websocket-key");
        assert!(matches!(
            build_upgrade_response(&headers),
            Err(TunnelError::Upgrade(_))
        ));
    }

    #[test]
    fn upgrade_response_rejects_missing_subprotocol() {
        let mut headers = client_headers(SUBPROTOCOL);
        headers.remove("sec-websocket-protocol");
        assert!(matches!(
            build_upgrade_response(&headers),
            Err(TunnelError::Upgrade(_))
        ));
    }
}