1use tokio::io;
2use tokio::io::AsyncWriteExt;
3use tokio::net::{TcpListener, TcpStream};
4
5use futures::FutureExt;
6use serde::Deserialize;
7use std::net::SocketAddr;
8
9#[derive(Deserialize)]
10pub struct SocketConfig {
11 pub server_addr: String,
12 pub to_addr: String,
13}
14
/// Verdict returned by the `validate` callback of `new_proxy` for a peer address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidateType {
    /// The peer is allowed.
    Normal,
    /// The peer is allowed, but the verdict warrants a warning.
    Warning,
    /// The peer is meant to be refused.
    Forbidden,
}
20
/// Decision returned by the error `callback` of `new_proxy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CondType {
    /// Keep serving connections.
    Continue,
    /// Request that the proxy stop.
    Stop,
}
25
26pub async fn new_proxy<F, C>(config: SocketConfig, validate: F, callback: C) -> anyhow::Result<()>
27where
28 F: Fn(&SocketAddr) -> anyhow::Result<ValidateType>,
29 C: Fn(anyhow::Error) -> anyhow::Result<CondType> + Sync + Send + Copy + 'static,
30{
31 let listen_addr = config.server_addr.clone();
32 let to_addr = config.to_addr;
33 let listener = TcpListener::bind(listen_addr).await?;
34
35 while let Ok((inbound, remote_addr)) = listener.accept().await {
36 if let Err(err) = validate(&remote_addr) {
37 error!("validate error:{}", err);
38 } else {
39 debug!("remote_addr:{}", remote_addr.to_string());
40 let transfer = transfer(inbound, to_addr.clone()).map(move |r| {
41 if let Err(e) = r {
42 match callback(e) {
43 Ok(CondType::Stop) => {
44 warn!("system call stop");
45 return;
46 }
47 Err(e) => {
48 error!("Failed to transfer error:{}", e);
49 }
50 _ => {}
51 }
52 }
53 });
54
55 tokio::spawn(transfer);
56 }
57 }
58
59 Ok(())
60}
61
62async fn transfer(mut inbound: TcpStream, to_addr: String) -> anyhow::Result<()> {
63 let mut outbound = TcpStream::connect(to_addr).await?;
64
65 let (mut ri, mut wi) = inbound.split();
66 let (mut ro, mut wo) = outbound.split();
67
68 let client_to_server = async {
69 io::copy(&mut ri, &mut wo).await?;
70 wo.shutdown().await
71 };
72
73 let server_to_client = async {
74 io::copy(&mut ro, &mut wi).await?;
75 wi.shutdown().await
76 };
77
78 tokio::try_join!(client_to_server, server_to_client)?;
79
80 Ok(())
81}