use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
TcpStream,
tcp::{OwnedReadHalf, OwnedWriteHalf},
},
};
use crate::{
apperror::AppError,
channel::{Channel, ChannelReadHalf, ChannelWriteHalf},
model::{connect_info::ConnectInfo, packet::Packet},
rules::Rule,
};
/// Bridges a [`Channel`] to an outbound TCP connection whose destination is
/// chosen by a [`Rule`] from the `ConnectInfo` carried in the first packet.
pub struct TcpChannel {
    channel: Channel,
}

impl TcpChannel {
    /// Wraps `channel` so it can later be connected to a TCP target via [`TcpChannel::spawn`].
    pub fn new(channel: Channel) -> TcpChannel {
        TcpChannel { channel }
    }

    /// Deserializes `packet.data` into a [`ConnectInfo`] via `ConnectInfo::from_json_buffer`.
    ///
    /// # Errors
    /// Returns an [`AppError`] if the payload cannot be deserialized.
    fn parse_connect_info(&self, packet: Packet) -> Result<ConnectInfo, AppError> {
        let connect: ConnectInfo = ConnectInfo::from_json_buffer(&packet.data)
            .map_err(|_| AppError::new("failed to deserialize connect packet"))?;
        log::debug!("{:?}", connect);
        Ok(connect)
    }

    /// Closes the wrapped channel, logging (not propagating) any failure.
    async fn close_channel(self) {
        if let Err(e) = self.channel.close().await {
            log::warn!("close failed: {}", e);
        }
    }

    /// Joins `host` and `port` into a `host:port` string for `TcpStream::connect`.
    ///
    /// A host containing `:` that is not already bracketed is treated as a bare
    /// IPv6 literal and wrapped in `[...]`; without the brackets `::1:22` would
    /// be unparsable.
    fn format_target(host: &str, port: impl std::fmt::Display) -> String {
        if host.contains(':') && !host.starts_with('[') {
            format!("[{host}]:{port}")
        } else {
            format!("{host}:{port}")
        }
    }

    /// Copies every packet received from `channel_read` onto the TCP write half.
    ///
    /// When `recv` yields `None`, the TCP write side is shut down and the
    /// function returns.
    ///
    /// # Errors
    /// Returns an [`AppError`] if a TCP write, flush, or shutdown fails.
    async fn transfer_to_tcp(
        mut channel_read: ChannelReadHalf,
        mut tcp: OwnedWriteHalf,
    ) -> Result<(), AppError> {
        while let Some(packet) = channel_read.recv().await {
            // `write` may accept only part of the buffer and report a short count;
            // `write_all` keeps writing until the whole packet is handed off.
            tcp.write_all(&packet)
                .await
                .map_err(|_| AppError::new("tcp write failed"))?;
            tcp.flush()
                .await
                .map_err(|_| AppError::new("tcp flush failed"))?;
        }
        log::trace!("closing tcp");
        tcp.shutdown()
            .await
            .map_err(|_| AppError::new("tcp shutdown failed"))?;
        Ok(())
    }

    /// Reads from the TCP read half in chunks of up to 8192 bytes and forwards
    /// each chunk to `channel_write`.
    ///
    /// On EOF (a zero-length read) or a read error, the channel write half is
    /// closed and `Ok(())` is returned; read errors are logged at debug level.
    ///
    /// # Errors
    /// Returns an [`AppError`] if sending to or closing the channel fails.
    async fn transfer_from_tcp(
        channel_write: ChannelWriteHalf,
        mut tcp: OwnedReadHalf,
    ) -> Result<(), AppError> {
        let mut buf = vec![0u8; 8192];
        loop {
            let size = match tcp.read(&mut buf).await {
                Ok(0) => break,
                Ok(size) => size,
                Err(e) => {
                    // Treated like EOF (as before), but no longer silently.
                    log::debug!("tcp read failed: {}", e);
                    break;
                }
            };
            channel_write.send(&buf[..size]).await?;
        }
        channel_write.close().await?;
        Ok(())
    }

    /// Splits the TCP stream and the channel, then spawns one detached task per
    /// direction. Returns as soon as both tasks are spawned.
    ///
    /// Failures inside either task are logged, since no caller observes them.
    async fn transfer(self, tcp: TcpStream) -> Result<(), AppError> {
        let (tcp_read, tcp_write) = tcp.into_split();
        let (channel_read, channel_write) = self.channel.split();
        tokio::spawn(async move {
            if let Err(e) = TcpChannel::transfer_from_tcp(channel_write, tcp_read).await {
                log::warn!("transfer from tcp failed: {}", e);
            }
        });
        tokio::spawn(async move {
            if let Err(e) = TcpChannel::transfer_to_tcp(channel_read, tcp_write).await {
                log::warn!("transfer to tcp failed: {}", e);
            }
        });
        Ok(())
    }

    /// Parses `packet` as a connect request, asks `rule` for a target, and spawns
    /// a task that connects to it and starts forwarding.
    ///
    /// If parsing fails, `rule` yields no target, or the TCP connect fails, the
    /// problem is logged and the channel is closed. When the rule's target has
    /// no address, `127.0.0.1` is used.
    pub fn spawn(self, packet: Packet, rule: &Rule) {
        let connect_info = match self.parse_connect_info(packet) {
            Ok(connect_info) => connect_info,
            Err(e) => {
                log::warn!("handle connect failed {}", e);
                // Close the channel like every other failure path does.
                tokio::spawn(self.close_channel());
                return;
            }
        };
        // `rule` is borrowed, so it must be evaluated before moving into the task.
        let target = rule.evaluate(connect_info.port, connect_info.host);
        tokio::spawn(async move {
            let Some(target) = target else {
                log::warn!("disabled forwarding for {}", connect_info.port);
                self.close_channel().await;
                return;
            };
            let target = TcpChannel::format_target(
                target.address.as_deref().unwrap_or("127.0.0.1"),
                &target.port,
            );
            log::info!("forward to {}", target);
            let tcp = match TcpStream::connect(target.as_str()).await {
                Ok(tcp) => tcp,
                Err(e) => {
                    log::warn!(
                        "failed to connect to {} (port {}): {}",
                        target,
                        connect_info.port,
                        e
                    );
                    self.close_channel().await;
                    return;
                }
            };
            if let Err(e) = self.transfer(tcp).await {
                log::warn!("transfer failed: {}", e);
            }
        });
    }
}