tcpwarp/
server.rs

1use super::*;
2
3pub struct TcpWarpServer {
4    listen_address: SocketAddr,
5    connect_address: IpAddr,
6}
7
8impl TcpWarpServer {
9    pub fn new(listen_address: SocketAddr, connect_address: IpAddr) -> Self {
10        Self {
11            listen_address,
12            connect_address,
13        }
14    }
15
16    pub async fn listen(&self) -> Result<(), Box<dyn Error>> {
17        let mut listener = TcpListener::bind(&self.listen_address).await?;
18        let mut incoming = listener.incoming();
19        let connect_address = self.connect_address;
20
21        while let Some(Ok(stream)) = incoming.next().await {
22            spawn(async move {
23                if let Err(e) = process(stream, connect_address).await {
24                    println!("failed to process connection; error = {}", e);
25                }
26            });
27        }
28        Ok(())
29    }
30}
31
32async fn process(stream: TcpStream, connect_address: IpAddr) -> Result<(), Box<dyn Error>> {
33    let mut transport = Framed::new(stream, TcpWarpProto);
34
35    transport.send(TcpWarpMessage::AddPorts(vec![])).await?;
36
37    let (mut wtransport, mut rtransport) = transport.split();
38
39    let (sender, mut receiver) = channel(100);
40
41    let mut connections = HashMap::new();
42
43    let forward_task = async move {
44        debug!("in receiver task process");
45        while let Some(message) = receiver.next().await {
46            debug!("received in fw message: {:?}", message);
47            let message = match message {
48                TcpWarpMessage::ConnectForward {
49                    connection_id,
50                    sender,
51                    connected_sender,
52                } => {
53                    debug!("adding connection: {}", connection_id);
54                    if let Err(err) = connected_sender.send(Ok(())) {
55                        error!("connected sender errored: {:?}", err);
56                    }
57                    connections.insert(connection_id.clone(), sender.clone());
58                    TcpWarpMessage::Connected { connection_id }
59                }
60                TcpWarpMessage::DisconnectClient { ref connection_id } => {
61                    debug!(
62                        "{} client connection disconnected, handle server disconnect",
63                        connection_id
64                    );
65                    if let Some(mut sender) = connections.remove(connection_id) {
66                        if let Err(err) = sender.send(message).await {
67                            error!("cannot send to channel: {}", err);
68                        }
69                    } else {
70                        error!("connection not found: {}", connection_id);
71                    }
72                    debug!("connections in pool: {}", connections.len());
73                    continue;
74                }
75                TcpWarpMessage::BytesClient {
76                    connection_id,
77                    data,
78                } => {
79                    if let Some(sender) = connections.get_mut(&connection_id) {
80                        debug!(
81                            "forward message to host port of connection: {}",
82                            connection_id
83                        );
84                        if let Err(err) = sender.send(TcpWarpMessage::BytesServer { data }).await {
85                            error!("cannot send to channel: {}", err);
86                        };
87                    } else {
88                        error!("connection not found: {}", connection_id);
89                    }
90                    continue;
91                }
92                regular_message => regular_message,
93            };
94            debug!("sending message {:?} from server to tunnel client", message);
95            wtransport.send(message).await?
96        }
97
98        debug!("no more messages, closing forward to tunnel client task");
99        wtransport.close().await?;
100        receiver.close();
101
102        Ok::<(), io::Error>(())
103    };
104
105    let processing_task = async move {
106        while let Some(Ok(message)) = rtransport.next().await {
107            debug!("server received from tunnel client {:?}", message);
108            if let Err(err) =
109                process_client_to_host_message(message, sender.clone(), connect_address).await
110            {
111                error!("error in processing: {}", err);
112            }
113        }
114
115        debug!("processing task for client to host tunnel finished");
116
117        Ok::<(), io::Error>(())
118    };
119
120    let (_, _) = try_join!(forward_task, processing_task)?;
121
122    debug!("finished process of tunnel connection");
123
124    Ok(())
125}
126
127async fn process_client_to_host_message(
128    message: TcpWarpMessage,
129    mut client_sender: Sender<TcpWarpMessage>,
130    connect_address: IpAddr,
131) -> Result<(), io::Error> {
132    match message {
133        TcpWarpMessage::HostConnect {
134            connection_id,
135            host,
136            port,
137        } => {
138            let client_sender_ = client_sender.clone();
139            spawn(async move {
140                let connect_address = connect_address.to_string();
141                let socket_address = format!(
142                    "{}:{}",
143                    host.unwrap_or_else(|| connect_address.to_string()),
144                    port
145                );
146                debug!("host connection to {}", socket_address);
147                if let Err(err) =
148                    process_host_connection(client_sender_, connection_id, socket_address).await
149                {
150                    error!(
151                        "failed connection {} {}: {}",
152                        connect_address, connection_id, err
153                    );
154                }
155            });
156        }
157        TcpWarpMessage::DisconnectClient { .. } => {
158            if let Err(err) = client_sender.send(message).await {
159                error!(
160                    "cannot send message DisconnectClient to forward channel: {}",
161                    err
162                );
163            }
164        }
165        TcpWarpMessage::BytesClient { .. } => {
166            if let Err(err) = client_sender.send(message).await {
167                error!(
168                    "cannot send message BytesClient to forward channel: {}",
169                    err
170                );
171            }
172        }
173        other_message => warn!("unsupported message: {:?}", other_message),
174    }
175    Ok(())
176}
177
178async fn process_host_connection<S: ToSocketAddrs>(
179    mut client_sender: Sender<TcpWarpMessage>,
180    connection_id: Uuid,
181    socket_address: S,
182) -> Result<(), Box<dyn Error>> {
183    debug!("{} new connection", connection_id);
184
185    let stream = match TcpStream::connect(socket_address).await {
186        Ok(stream) => stream,
187        Err(err) => {
188            client_sender
189                .send(TcpWarpMessage::ConnectFailure { connection_id })
190                .await?;
191            return Err(err.into());
192        }
193    };
194
195    let (mut wtransport, mut rtransport) =
196        Framed::new(stream, TcpWarpProtoHost { connection_id }).split();
197
198    let (host_sender, mut host_receiver) = channel(100);
199
200    let forward_task = async move {
201        debug!("{} in receiver task process_host_connection", connection_id);
202
203        while let Some(message) = host_receiver.next().await {
204            debug!("{} just received a message: {:?}", connection_id, message);
205            match message {
206                TcpWarpMessage::DisconnectClient { .. } => break,
207                TcpWarpMessage::BytesServer { data } => wtransport.send(data).await?,
208                _ => (),
209            }
210        }
211
212        debug!(
213            "{} no more messages, closing process host forward task",
214            connection_id
215        );
216        wtransport.close().await?;
217        host_receiver.close();
218        debug!("{} closed write transport", connection_id);
219
220        Ok::<(), io::Error>(())
221    };
222
223    let (connected_sender, connected_receiver) = oneshot::channel();
224
225    client_sender
226        .send(TcpWarpMessage::ConnectForward {
227            connection_id,
228            sender: host_sender,
229            connected_sender,
230        })
231        .await?;
232
233    debug!("{} sended connect to client", connection_id);
234
235    let mut client_sender_ = client_sender.clone();
236
237    let processing_task = async move {
238        if let Err(err) = connected_receiver.await {
239            error!("{} connection error: {}", connection_id, err);
240        }
241        while let Some(Ok(message)) = rtransport.next().await {
242            if let Err(err) = client_sender_.send(message).await {
243                error!("{} {}", connection_id, err);
244            }
245        }
246
247        let message = TcpWarpMessage::DisconnectHost { connection_id };
248
249        debug!("{} sending disconnect host message", connection_id);
250
251        if let Err(err) = client_sender_.send(message).await {
252            error!("{} err: {}", connection_id, err);
253        }
254
255        debug!("{} host connection processing task done", connection_id);
256
257        Ok::<(), io::Error>(())
258    };
259
260    try_join!(forward_task, processing_task)?;
261
262    debug!("{} disconnect, processing task done", connection_id);
263
264    Ok(())
265}