use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::shared::SharedState;
/// Size in bytes (16 KiB) of the scratch buffer used for each read from the
/// server socket; a single forwarded chunk is never larger than this.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;
pub fn spawn_tcp_proxy(
handle: &tokio::runtime::Handle,
dst: SocketAddr,
from_smoltcp: mpsc::Receiver<Bytes>,
to_smoltcp: mpsc::Sender<Bytes>,
shared: Arc<SharedState>,
) {
handle.spawn(async move {
if let Err(e) = tcp_proxy_task(dst, from_smoltcp, to_smoltcp, shared).await {
tracing::debug!(dst = %dst, error = %e, "TCP proxy task ended");
}
});
}
async fn tcp_proxy_task(
dst: SocketAddr,
mut from_smoltcp: mpsc::Receiver<Bytes>,
to_smoltcp: mpsc::Sender<Bytes>,
shared: Arc<SharedState>,
) -> io::Result<()> {
let stream = TcpStream::connect(dst).await?;
let (mut server_rx, mut server_tx) = stream.into_split();
let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
loop {
tokio::select! {
data = from_smoltcp.recv() => {
match data {
Some(bytes) => {
if let Err(e) = server_tx.write_all(&bytes).await {
tracing::debug!(dst = %dst, error = %e, "write to server failed");
break;
}
}
None => break,
}
}
result = server_rx.read(&mut server_buf) => {
match result {
Ok(0) => break, Ok(n) => {
let data = Bytes::copy_from_slice(&server_buf[..n]);
if to_smoltcp.send(data).await.is_err() {
break;
}
shared.proxy_wake.wake();
}
Err(e) => {
tracing::debug!(dst = %dst, error = %e, "read from server failed");
break;
}
}
}
}
}
Ok(())
}