tcpwarp/
client.rs

1use super::*;
2
/// Client side of a tcpwarp tunnel.
///
/// Binds the forwarded ports locally on `bind_address` and relays every accepted connection
/// through a single TCP connection to the tunnel server at `tunnel_address`.
#[derive(Debug)]
pub struct TcpWarpClient {
    /// Local IP address on which the forwarded ports are bound.
    bind_address: IpAddr,
    /// Address of the tunnel server the client connects to.
    tunnel_address: SocketAddr,
}
7
8pub type TcpWarpClientResult = HashMap<Uuid, TcpWarpConnection>;
9
10impl TcpWarpClient {
11    pub fn new(bind_address: IpAddr, tunnel_address: SocketAddr) -> Self {
12        Self {
13            bind_address,
14            tunnel_address,
15        }
16    }
17
18    pub async fn connect(
19        &self,
20        addresses: Vec<TcpWarpPortConnection>,
21    ) -> Result<(TcpWarpClientResult, Arc<Vec<TcpWarpPortConnection>>), Box<dyn Error>> {
22        self.connect_with(HashMap::new(), Arc::new(addresses)).await
23    }
24
25    pub async fn connect_loop(
26        &self,
27        retry_delay: Duration,
28        keep_connections: bool,
29        mut addresses: Arc<Vec<TcpWarpPortConnection>>,
30    ) -> Result<(), Box<dyn Error>> {
31        let mut connections = HashMap::new();
32
33        while let Ok((data, addrs)) = self.connect_with(connections, addresses).await {
34            connections = if keep_connections {
35                data
36            } else {
37                HashMap::new()
38            };
39            addresses = addrs;
40            warn!("retrying in {:?}", retry_delay);
41            delay_for(retry_delay).await;
42        }
43
44        Ok(())
45    }
46
47    async fn connect_with(
48        &self,
49        mut connections: TcpWarpClientResult,
50        addresses: Arc<Vec<TcpWarpPortConnection>>,
51    ) -> Result<(TcpWarpClientResult, Arc<Vec<TcpWarpPortConnection>>), Box<dyn Error>> {
52        let stream = match TcpStream::connect(&self.tunnel_address).await {
53            Ok(stream) => stream,
54            Err(err) => {
55                error!("cannot connect to tunnel: {}", err);
56                return Ok((connections, addresses));
57            }
58        };
59        let (mut wtransport, mut rtransport) = Framed::new(stream, TcpWarpProto).split();
60
61        let (mut sender, mut receiver) = channel(100);
62
63        let forward_task = async move {
64            debug!("in receiver task");
65
66            let mut listeners = vec![];
67
68            while let Some(message) = receiver.next().await {
69                debug!("just received a message connect: {:?}", message);
70                let message = match message {
71                    TcpWarpMessage::Connect {
72                        connection_id,
73                        connection,
74                        sender,
75                        connected_sender,
76                    } => {
77                        debug!("adding connection: {}", connection_id);
78                        connections.insert(
79                            connection_id.clone(),
80                            TcpWarpConnection {
81                                sender,
82                                connected_sender: Some(connected_sender),
83                            },
84                        );
85                        TcpWarpMessage::HostConnect {
86                            connection_id,
87                            host: connection.host,
88                            port: connection.port,
89                        }
90                    }
91                    TcpWarpMessage::Listener(abort_handler) => {
92                        listeners.push(abort_handler);
93                        continue;
94                    }
95                    TcpWarpMessage::Disconnect => {
96                        debug!("stopping lesteners...");
97                        for listener in listeners {
98                            listener.abort();
99                        }
100                        debug!("stopped listeners");
101                        break;
102                    }
103                    TcpWarpMessage::DisconnectHost { ref connection_id } => {
104                        if let Some(mut connection) = connections.remove(connection_id) {
105                            if let Err(err) = connection.sender.send(message).await {
106                                error!("cannot send to channel: {}", err);
107                            }
108                        } else {
109                            error!("connection not found: {}", connection_id);
110                        }
111                        debug!("connections in pool: {}", connections.len());
112                        continue;
113                    }
114                    TcpWarpMessage::ConnectFailure { ref connection_id } => {
115                        if let Some(mut connection) = connections.remove(connection_id) {
116                            if let Some(connection_sender) = connection.connected_sender.take() {
117                                if let Err(err) = connection_sender.send(Err(io::Error::new(
118                                    io::ErrorKind::Other,
119                                    "disonnect propagated",
120                                ))) {
121                                    error!("cannot send to oneshot channel: {:?}", err);
122                                }
123                            }
124                            if let Err(err) = connection.sender.send(message).await {
125                                error!("cannot send to channel: {}", err);
126                            }
127                        } else {
128                            error!("connection not found: {}", connection_id);
129                        }
130                        debug!("connections in pool: {}", connections.len());
131                        continue;
132                    }
133                    TcpWarpMessage::Connected { ref connection_id } => {
134                        if let Some(connection) = connections.get_mut(&connection_id) {
135                            debug!("start connected loop: {}", connection_id);
136                            if let Some(connection_sender) = connection.connected_sender.take() {
137                                if let Err(err) = connection_sender.send(Ok(())) {
138                                    error!("cannot send to oneshot channel: {:?}", err);
139                                }
140                            }
141                        } else {
142                            error!("connection not found: {}", connection_id);
143                        }
144                        continue;
145                    }
146                    TcpWarpMessage::BytesHost {
147                        connection_id,
148                        data,
149                    } => {
150                        if let Some(connection) = connections.get_mut(&connection_id) {
151                            debug!(
152                                "forward message to host port of connection: {}",
153                                connection_id
154                            );
155                            if let Err(err) = connection
156                                .sender
157                                .send(TcpWarpMessage::BytesServer { data })
158                                .await
159                            {
160                                error!("cannot send to channel: {}", err);
161                            }
162                        } else {
163                            error!("connection not found: {}", connection_id);
164                        }
165                        continue;
166                    }
167                    regular_message => regular_message,
168                };
169                debug!("sending message {:?} from client to tunnel server", message);
170                wtransport.send(message).await?;
171            }
172
173            debug!("no more messages, closing forward task");
174
175            wtransport.close().await?;
176            receiver.close();
177
178            Ok::<TcpWarpClientResult, io::Error>(connections)
179        };
180
181        let bind_address = self.bind_address;
182
183        let _addresses = addresses.clone();
184        let processing_task = async move {
185            while let Some(Ok(message)) = rtransport.next().await {
186                process_host_to_client_message(
187                    message,
188                    sender.clone(),
189                    addresses.clone(),
190                    bind_address,
191                )
192                .await?;
193            }
194
195            debug!("processing task for host to client finished");
196
197            if let Err(err) = sender.send(TcpWarpMessage::Disconnect).await {
198                error!("could not send disconnect message {}", err);
199            }
200
201            Ok::<(), io::Error>(())
202        };
203
204        let (connections, _) = try_join!(forward_task, processing_task)?;
205
206        Ok((connections, _addresses))
207    }
208}
209
210// async fn publish
211async fn process_host_to_client_message(
212    message: TcpWarpMessage,
213    mut sender: Sender<TcpWarpMessage>,
214    addresses: Arc<Vec<TcpWarpPortConnection>>,
215    bind_address: IpAddr,
216) -> Result<(), io::Error> {
217    debug!("{} host to client: {:?}", bind_address, message);
218
219    match message {
220        TcpWarpMessage::AddPorts(_) => {
221            for address in addresses.iter().cloned() {
222                let bind_address =
223                    SocketAddr::new(bind_address, address.client_port.unwrap_or(address.port));
224                let sender_ = sender.clone();
225
226                let mut listener = match TcpListener::bind(bind_address).await {
227                    Ok(listener) => listener,
228                    Err(err) => {
229                        error!("could not start listen {}: {}", bind_address, err);
230                        return Err(err);
231                    }
232                };
233
234                debug!("listen: {:?}", bind_address);
235
236                let abortable_feature = async move {
237                    let mut incoming = listener.incoming();
238
239                    while let Some(Ok(stream)) = incoming.next().await {
240                        let sender__ = sender_.clone();
241
242                        let _address = address.clone();
243                        spawn(async move {
244                            if let Err(e) = process(stream, sender__, _address).await {
245                                error!("failed to process connection; error = {}", e);
246                            }
247                        });
248                    }
249
250                    debug!("done listen: {:?}", bind_address);
251
252                    Ok::<(), io::Error>(())
253                };
254                let (abortable_listener, abort_handler) = abortable(abortable_feature);
255                if let Err(err) = sender.send(TcpWarpMessage::Listener(abort_handler)).await {
256                    error!("cannot send message Listener to forward channel: {}", err);
257                }
258                spawn(abortable_listener);
259            }
260        }
261        TcpWarpMessage::BytesHost { .. } => {
262            if let Err(err) = sender.send(message).await {
263                error!("cannot send message BytesHost to forward channel: {}", err);
264            }
265        }
266        TcpWarpMessage::Connected { .. } => {
267            if let Err(err) = sender.send(message).await {
268                error!("cannot send message Connected to forward channel: {}", err);
269            }
270        }
271        TcpWarpMessage::DisconnectHost { .. } => {
272            if let Err(err) = sender.send(message).await {
273                error!(
274                    "cannot send message DisconnectHost to forward channel: {}",
275                    err
276                );
277            }
278        }
279        TcpWarpMessage::ConnectFailure { .. } => {
280            if let Err(err) = sender.send(message).await {
281                error!(
282                    "cannot send message ConnectFailure to forward channel: {}",
283                    err
284                );
285            }
286        }
287        other_message => warn!("unsupported message: {:?}", other_message),
288    }
289    Ok(())
290}
291
292async fn process(
293    stream: TcpStream,
294    mut host_sender: Sender<TcpWarpMessage>,
295    address: TcpWarpPortConnection,
296) -> Result<(), Box<dyn Error>> {
297    let connection_id = Uuid::new_v4();
298
299    debug!("new connection: {}", connection_id);
300
301    let (mut wtransport, mut rtransport) =
302        Framed::new(stream, TcpWarpProtoClient { connection_id }).split();
303
304    let (client_sender, mut client_receiver) = channel(100);
305
306    let forward_task = async move {
307        debug!("in receiver task");
308        while let Some(message) = client_receiver.next().await {
309            debug!(
310                "{} just received a message process: {:?}",
311                connection_id, message
312            );
313            match message {
314                TcpWarpMessage::ConnectFailure { .. } => break,
315                TcpWarpMessage::DisconnectHost { .. } => break,
316                TcpWarpMessage::BytesServer { data } => wtransport.send(data).await?,
317                _ => (),
318            }
319        }
320
321        debug!("{} no more messages, closing forward task", connection_id);
322        debug!(
323            "{} closing write channel to client side port",
324            connection_id
325        );
326        wtransport.close().await?;
327        client_receiver.close();
328        debug!("{} write channel to client side port closed", connection_id);
329
330        Ok::<(), io::Error>(())
331    };
332
333    let (connected_sender, connected_receiver) = oneshot::channel();
334
335    host_sender
336        .send(TcpWarpMessage::Connect {
337            connection_id,
338            connection: address,
339            sender: client_sender,
340            connected_sender,
341        })
342        .await?;
343
344    let processing_task = async move {
345        match connected_receiver.await {
346            Err(err) => {
347                error!("{} connection error: {}", connection_id, err);
348                return Ok(());
349            }
350            Ok(Err(err)) => {
351                error!("{} connection error: {}", connection_id, err);
352                return Ok(());
353            }
354            _ => (),
355        }
356
357        while let Some(Ok(message)) = rtransport.next().await {
358            if let Err(err) = host_sender.send(message).await {
359                error!("{} {}", connection_id, err);
360            }
361        }
362
363        debug!(
364            "{} processing task for incoming connection finished",
365            connection_id
366        );
367
368        let message = TcpWarpMessage::DisconnectClient { connection_id };
369        debug!("{} sending disconnect message {:?}", connection_id, message);
370        if let Err(err) = host_sender.send(message).await {
371            error!("{} {}", connection_id, err);
372        }
373        debug!("{} done processing", connection_id);
374
375        Ok::<(), io::Error>(())
376    };
377
378    try_join!(forward_task, processing_task)?;
379
380    debug!("{} full complete process", connection_id);
381
382    Ok(())
383}