use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
use crate::frame::{read_length_prefixed, write_length_prefixed};
/// Capacity (64 KiB) of the scratch buffer used by one `Transport::recv` call;
/// this is also the most bytes a single `recv` can return.
const RECV_BUF: usize = 64 * 1024;
/// Byte-stream transport over a single TCP connection, either wrapped directly
/// (`from_tcp`) or established through an HTTP proxy `CONNECT` handshake
/// (`connect_via_proxy`).
///
/// The stream sits behind a shared async mutex so every operation takes `&self`.
#[derive(Debug)]
pub struct HttpTunnelTransport {
    /// The live connection; `None` once the transport has been closed.
    stream: Arc<Mutex<Option<TcpStream>>>,
}
impl HttpTunnelTransport {
#[must_use]
pub fn from_tcp(stream: TcpStream) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
}
}
pub async fn connect_via_proxy(
proxy: SocketAddr,
target_authority: &str,
) -> crate::error::Result<Self> {
let mut stream = TcpStream::connect(proxy)
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let req = format!(
"CONNECT {target_authority} HTTP/1.1\r\nHost: {target_authority}\r\nProxy-Connection: keep-alive\r\n\r\n"
);
stream
.write_all(req.as_bytes())
.await
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let mut buf = Vec::with_capacity(2048);
let mut tmp = [0u8; 256];
loop {
let n = stream.read(&mut tmp).await.map_err(|e| {
SrxError::Transport(TransportError::ConnectionFailed(e.to_string()))
})?;
if n == 0 {
return Err(SrxError::Transport(TransportError::ChannelClosed));
}
buf.extend_from_slice(&tmp[..n]);
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
if buf.len() > 16 * 1024 {
return Err(SrxError::Transport(TransportError::ConnectionFailed(
"CONNECT response headers too large".into(),
)));
}
}
let head_end = buf
.windows(4)
.position(|w| w == b"\r\n\r\n")
.expect("checked")
+ 4;
let head = std::str::from_utf8(&buf[..head_end])
.map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
let status = head.lines().next().unwrap_or("");
if !status.starts_with("HTTP/1.") || !status.contains("200") {
return Err(SrxError::Transport(TransportError::ConnectionFailed(
format!("CONNECT failed: {status}"),
)));
}
Ok(Self {
stream: Arc::new(Mutex::new(Some(stream))),
})
}
pub async fn send_framed(&self, payload: &[u8]) -> crate::error::Result<()> {
let mut guard = self.stream.lock().await;
let stream = guard
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
write_length_prefixed(stream, payload).await
}
pub async fn recv_framed(&self) -> crate::error::Result<Bytes> {
let mut guard = self.stream.lock().await;
let stream = guard
.as_mut()
.ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
let v = read_length_prefixed(stream).await?;
Ok(Bytes::from(v))
}
}
#[async_trait]
impl Transport for HttpTunnelTransport {
    fn kind(&self) -> TransportKind {
        // NOTE(review): the tunnel is an HTTP/1.1 CONNECT, yet it reports `Http2`.
        // Presumably that is the closest existing variant; confirm against `TransportKind`.
        TransportKind::Http2
    }

    /// Writes `data` to the tunnel unchanged (no framing) and flushes it.
    ///
    /// The stream lock is held for the entire write, so concurrent sends never
    /// interleave their bytes.
    ///
    /// # Errors
    ///
    /// `TransportError::ChannelClosed` if the transport was closed.
    /// `TransportError::ConnectionFailed` if the write or flush fails.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let mut g = self.stream.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        s.write_all(&data)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        s.flush()
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        Ok(())
    }

    /// Returns the next chunk of bytes available on the tunnel (1..=`RECV_BUF`
    /// bytes). Chunk boundaries need not match the peer's writes.
    ///
    /// The stream lock is held while waiting for data. A pending `recv`
    /// therefore delays `send` and `close` on the same transport until bytes
    /// arrive or the read fails.
    ///
    /// # Errors
    ///
    /// `TransportError::ChannelClosed` if the transport was closed or the peer
    /// reached EOF. `TransportError::ConnectionFailed` on read errors.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        let mut g = self.stream.lock().await;
        let s = g
            .as_mut()
            .ok_or(SrxError::Transport(TransportError::ChannelClosed))?;
        let mut buf = vec![0u8; RECV_BUF];
        let n = s
            .read(&mut buf)
            .await
            .map_err(|e| SrxError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
        if n == 0 {
            return Err(SrxError::Transport(TransportError::ChannelClosed));
        }
        // Copy out only the received bytes. In current `bytes` releases,
        // `Bytes::from(Vec)` on a vector with spare capacity keeps the whole
        // RECV_BUF-sized allocation alive as long as the returned `Bytes` lives.
        Ok(Bytes::copy_from_slice(&buf[..n]))
    }

    /// Returns `true` until `close` has been called.
    ///
    /// This reflects local state only and does not probe the socket. A transport
    /// whose `recv` has already seen EOF is still reported healthy.
    async fn is_healthy(&self) -> bool {
        self.stream.lock().await.is_some()
    }

    /// Drops the stream after a best-effort `shutdown` (errors are ignored).
    /// Later operations return `ChannelClosed`. Calling it again is a no-op.
    async fn close(&self) -> crate::error::Result<()> {
        let mut g = self.stream.lock().await;
        if let Some(mut s) = g.take() {
            let _ = s.shutdown().await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// CONNECT target sent by the client; the fake proxy never dials it.
    const TARGET: &str = "127.0.0.1:9";
    /// Successful CONNECT reply from the fake proxy.
    const ESTABLISHED: &[u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n";

    /// Plays the proxy side of a CONNECT handshake. Accepts one client, consumes
    /// its request head, answers with `response` and returns the server-side stream.
    async fn accept_connect(listener: tokio::net::TcpListener, response: &[u8]) -> TcpStream {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut buf = vec![0u8; 2048];
        let mut total = 0usize;
        while !buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
            let n = stream.read(&mut buf[total..]).await.unwrap();
            assert_ne!(n, 0, "eof before end of CONNECT request");
            total += n;
        }
        stream.write_all(response).await.unwrap();
        stream
    }

    #[tokio::test]
    async fn connect_proxy_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();
        let serve = tokio::spawn(async move {
            let mut stream = accept_connect(listener, ESTABLISHED).await;
            // `read_exact` so a split TCP segment cannot make the test flaky.
            let mut ping = [0u8; 4];
            stream.read_exact(&mut ping).await.unwrap();
            assert_eq!(&ping, b"ping");
            stream.write_all(b"pong").await.unwrap();
        });
        let t = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET)
            .await
            .unwrap();
        t.send(Bytes::from_static(b"ping")).await.unwrap();
        // `recv` returns whatever is available, so gather until the whole reply is in.
        let mut reply = Vec::new();
        while reply.len() < 4 {
            reply.extend_from_slice(&t.recv().await.unwrap());
        }
        assert_eq!(reply, b"pong");
        t.close().await.unwrap();
        serve.await.unwrap();
    }

    #[tokio::test]
    async fn framed_roundtrip_over_http_tunnel() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();
        let payload = b"framed-payload-srx";
        let serve = tokio::spawn(async move {
            let stream = accept_connect(listener, ESTABLISHED).await;
            let t = HttpTunnelTransport::from_tcp(stream);
            let got = t.recv_framed().await.unwrap();
            assert_eq!(got.as_ref(), payload);
            t.send_framed(b"ack").await.unwrap();
        });
        let t = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET)
            .await
            .unwrap();
        t.send_framed(payload).await.unwrap();
        let reply = t.recv_framed().await.unwrap();
        assert_eq!(reply.as_ref(), b"ack");
        t.close().await.unwrap();
        serve.await.unwrap();
    }

    #[tokio::test]
    async fn connect_proxy_rejects_error_status() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();
        let serve = tokio::spawn(async move {
            drop(accept_connect(listener, b"HTTP/1.1 403 Forbidden\r\n\r\n").await);
        });
        let result = HttpTunnelTransport::connect_via_proxy(proxy_addr, TARGET).await;
        assert!(result.is_err());
        serve.await.unwrap();
    }
}