use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use tokio::net::TcpStream;
/// Upper bound on the number of bytes `tunnel` will forward in each direction
/// (2 GiB). Reading past this many bytes in one direction aborts the tunnel
/// with an error.
pub const MAX_TUNNEL_BYTES_PER_DIRECTION: u64 = 1 << 31;
/// Returns the authority component of `uri` (e.g. `example.com:8080`) as an
/// owned string, exactly as it appears in the URI.
///
/// Returns `None` when the URI carries no authority, such as an origin-form
/// request target like `/path`.
#[must_use]
pub fn host_addr(uri: &hyper::Uri) -> Option<String> {
    let authority = uri.authority()?;
    Some(authority.as_str().to_owned())
}
/// Connects to the first reachable address in `addrs` and relays bytes in both
/// directions between that server and the upgraded client connection.
///
/// Each direction is copied independently and is capped at
/// [`MAX_TUNNEL_BYTES_PER_DIRECTION`] bytes. When one side reaches EOF, the
/// write half facing the other side is shut down so the peer observes the
/// half-close (the same semantics as `tokio::io::copy_bidirectional`). The
/// function returns once both directions have finished.
///
/// # Errors
///
/// Returns an error if connecting to `addrs` fails (including when `addrs` is
/// empty), if any read, write, or shutdown fails, or if either direction
/// exceeds the byte cap. On the first error the other direction is dropped.
pub async fn tunnel(upgraded: Upgraded, addrs: Vec<SocketAddr>) -> std::io::Result<()> {
    let mut server = TcpStream::connect(addrs.as_slice()).await?;
    let mut upgraded = TokioIo::new(upgraded);
    let (mut up_r, mut up_w) = tokio::io::split(&mut upgraded);
    let (mut sv_r, mut sv_w) = server.split();
    let to_server = copy_capped(
        &mut up_r,
        &mut sv_w,
        MAX_TUNNEL_BYTES_PER_DIRECTION,
        "client→server",
    );
    let to_client = copy_capped(
        &mut sv_r,
        &mut up_w,
        MAX_TUNNEL_BYTES_PER_DIRECTION,
        "server→client",
    );
    tokio::try_join!(to_server, to_client)?;
    Ok(())
}

/// Copies `reader` into `writer` until EOF and returns the number of bytes
/// forwarded. Fails once more than `cap` bytes have been read; `direction`
/// labels the resulting error message. On EOF the writer is shut down so the
/// peer sees the half-close.
async fn copy_capped<R, W>(
    reader: &mut R,
    writer: &mut W,
    cap: u64,
    direction: &str,
) -> std::io::Result<u64>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    const BUF_SIZE: usize = 16 * 1024;
    let mut buf = vec![0u8; BUF_SIZE];
    let mut total: u64 = 0;
    loop {
        let n = reader.read(&mut buf).await?;
        if n == 0 {
            break;
        }
        // Check before writing so the chunk that crosses the cap is never forwarded.
        total = total.saturating_add(n as u64);
        if total > cap {
            return Err(std::io::Error::other(format!(
                "tunnel exceeded byte cap ({direction})"
            )));
        }
        writer.write_all(&buf[..n]).await?;
    }
    // Propagate EOF: without this, a peer that waits for the other side to
    // close before finishing would never see it, and the opposite direction
    // (joined with this one) could block indefinitely.
    writer.shutdown().await?;
    Ok(total)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn host_addr_extracts_authority_with_port() {
        let uri: hyper::Uri = "http://example.com:8080/path".parse().unwrap();
        assert_eq!(host_addr(&uri).as_deref(), Some("example.com:8080"));
    }

    #[test]
    fn host_addr_extracts_authority_without_explicit_port() {
        let uri: hyper::Uri = "https://example.com/path".parse().unwrap();
        assert_eq!(host_addr(&uri).as_deref(), Some("example.com"));
    }

    #[test]
    fn host_addr_extracts_ip_authority() {
        let uri: hyper::Uri = "http://127.0.0.1:9000/".parse().unwrap();
        assert_eq!(host_addr(&uri).as_deref(), Some("127.0.0.1:9000"));
    }

    #[test]
    fn host_addr_extracts_connect_authority_form() {
        // CONNECT requests carry an authority-form target with no scheme or path.
        let uri: hyper::Uri = "example.com:443".parse().unwrap();
        assert_eq!(host_addr(&uri).as_deref(), Some("example.com:443"));
    }

    #[test]
    fn host_addr_relative_uri_returns_none() {
        let uri: hyper::Uri = "/path".parse().unwrap();
        assert_eq!(host_addr(&uri), None);
    }

    #[test]
    fn max_tunnel_bytes_per_direction_is_2gib() {
        // Pin the literal value (2^31) instead of re-deriving it with the same
        // expression as the constant, so an accidental edit is actually caught.
        assert_eq!(MAX_TUNNEL_BYTES_PER_DIRECTION, 2_147_483_648);
    }

    #[test]
    fn max_tunnel_bytes_per_direction_is_under_u64_max() {
        // Asserting on a constant is deliberate here. `tunnel` counts bytes with
        // `saturating_add` and fails on `total > cap`. If the cap were raised to
        // u64::MAX, that check could never trigger. This keeps the cap far
        // below that ceiling.
        #[allow(clippy::assertions_on_constants)]
        {
            assert!(MAX_TUNNEL_BYTES_PER_DIRECTION < u64::MAX / 1000);
        }
    }
}