1use super::*;
2
/// Client side of a TCP tunnel.
///
/// Local listeners are opened on `bind_address`; the tunnel server is reached
/// at `tunnel_address`.
#[derive(Debug, Clone)]
pub struct TcpWarpClient {
    /// Local IP address on which the forwarded ports are bound.
    bind_address: IpAddr,
    /// Address of the tunnel server.
    tunnel_address: SocketAddr,
}
7
/// Pool of live tunnelled connections, keyed by connection id.
pub type TcpWarpClientResult = HashMap<Uuid, TcpWarpConnection>;
9
10impl TcpWarpClient {
11 pub fn new(bind_address: IpAddr, tunnel_address: SocketAddr) -> Self {
12 Self {
13 bind_address,
14 tunnel_address,
15 }
16 }
17
18 pub async fn connect(
19 &self,
20 addresses: Vec<TcpWarpPortConnection>,
21 ) -> Result<(TcpWarpClientResult, Arc<Vec<TcpWarpPortConnection>>), Box<dyn Error>> {
22 self.connect_with(HashMap::new(), Arc::new(addresses)).await
23 }
24
25 pub async fn connect_loop(
26 &self,
27 retry_delay: Duration,
28 keep_connections: bool,
29 mut addresses: Arc<Vec<TcpWarpPortConnection>>,
30 ) -> Result<(), Box<dyn Error>> {
31 let mut connections = HashMap::new();
32
33 while let Ok((data, addrs)) = self.connect_with(connections, addresses).await {
34 connections = if keep_connections {
35 data
36 } else {
37 HashMap::new()
38 };
39 addresses = addrs;
40 warn!("retrying in {:?}", retry_delay);
41 delay_for(retry_delay).await;
42 }
43
44 Ok(())
45 }
46
47 async fn connect_with(
48 &self,
49 mut connections: TcpWarpClientResult,
50 addresses: Arc<Vec<TcpWarpPortConnection>>,
51 ) -> Result<(TcpWarpClientResult, Arc<Vec<TcpWarpPortConnection>>), Box<dyn Error>> {
52 let stream = match TcpStream::connect(&self.tunnel_address).await {
53 Ok(stream) => stream,
54 Err(err) => {
55 error!("cannot connect to tunnel: {}", err);
56 return Ok((connections, addresses));
57 }
58 };
59 let (mut wtransport, mut rtransport) = Framed::new(stream, TcpWarpProto).split();
60
61 let (mut sender, mut receiver) = channel(100);
62
63 let forward_task = async move {
64 debug!("in receiver task");
65
66 let mut listeners = vec![];
67
68 while let Some(message) = receiver.next().await {
69 debug!("just received a message connect: {:?}", message);
70 let message = match message {
71 TcpWarpMessage::Connect {
72 connection_id,
73 connection,
74 sender,
75 connected_sender,
76 } => {
77 debug!("adding connection: {}", connection_id);
78 connections.insert(
79 connection_id.clone(),
80 TcpWarpConnection {
81 sender,
82 connected_sender: Some(connected_sender),
83 },
84 );
85 TcpWarpMessage::HostConnect {
86 connection_id,
87 host: connection.host,
88 port: connection.port,
89 }
90 }
91 TcpWarpMessage::Listener(abort_handler) => {
92 listeners.push(abort_handler);
93 continue;
94 }
95 TcpWarpMessage::Disconnect => {
96 debug!("stopping lesteners...");
97 for listener in listeners {
98 listener.abort();
99 }
100 debug!("stopped listeners");
101 break;
102 }
103 TcpWarpMessage::DisconnectHost { ref connection_id } => {
104 if let Some(mut connection) = connections.remove(connection_id) {
105 if let Err(err) = connection.sender.send(message).await {
106 error!("cannot send to channel: {}", err);
107 }
108 } else {
109 error!("connection not found: {}", connection_id);
110 }
111 debug!("connections in pool: {}", connections.len());
112 continue;
113 }
114 TcpWarpMessage::ConnectFailure { ref connection_id } => {
115 if let Some(mut connection) = connections.remove(connection_id) {
116 if let Some(connection_sender) = connection.connected_sender.take() {
117 if let Err(err) = connection_sender.send(Err(io::Error::new(
118 io::ErrorKind::Other,
119 "disonnect propagated",
120 ))) {
121 error!("cannot send to oneshot channel: {:?}", err);
122 }
123 }
124 if let Err(err) = connection.sender.send(message).await {
125 error!("cannot send to channel: {}", err);
126 }
127 } else {
128 error!("connection not found: {}", connection_id);
129 }
130 debug!("connections in pool: {}", connections.len());
131 continue;
132 }
133 TcpWarpMessage::Connected { ref connection_id } => {
134 if let Some(connection) = connections.get_mut(&connection_id) {
135 debug!("start connected loop: {}", connection_id);
136 if let Some(connection_sender) = connection.connected_sender.take() {
137 if let Err(err) = connection_sender.send(Ok(())) {
138 error!("cannot send to oneshot channel: {:?}", err);
139 }
140 }
141 } else {
142 error!("connection not found: {}", connection_id);
143 }
144 continue;
145 }
146 TcpWarpMessage::BytesHost {
147 connection_id,
148 data,
149 } => {
150 if let Some(connection) = connections.get_mut(&connection_id) {
151 debug!(
152 "forward message to host port of connection: {}",
153 connection_id
154 );
155 if let Err(err) = connection
156 .sender
157 .send(TcpWarpMessage::BytesServer { data })
158 .await
159 {
160 error!("cannot send to channel: {}", err);
161 }
162 } else {
163 error!("connection not found: {}", connection_id);
164 }
165 continue;
166 }
167 regular_message => regular_message,
168 };
169 debug!("sending message {:?} from client to tunnel server", message);
170 wtransport.send(message).await?;
171 }
172
173 debug!("no more messages, closing forward task");
174
175 wtransport.close().await?;
176 receiver.close();
177
178 Ok::<TcpWarpClientResult, io::Error>(connections)
179 };
180
181 let bind_address = self.bind_address;
182
183 let _addresses = addresses.clone();
184 let processing_task = async move {
185 while let Some(Ok(message)) = rtransport.next().await {
186 process_host_to_client_message(
187 message,
188 sender.clone(),
189 addresses.clone(),
190 bind_address,
191 )
192 .await?;
193 }
194
195 debug!("processing task for host to client finished");
196
197 if let Err(err) = sender.send(TcpWarpMessage::Disconnect).await {
198 error!("could not send disconnect message {}", err);
199 }
200
201 Ok::<(), io::Error>(())
202 };
203
204 let (connections, _) = try_join!(forward_task, processing_task)?;
205
206 Ok((connections, _addresses))
207 }
208}
209
210async fn process_host_to_client_message(
212 message: TcpWarpMessage,
213 mut sender: Sender<TcpWarpMessage>,
214 addresses: Arc<Vec<TcpWarpPortConnection>>,
215 bind_address: IpAddr,
216) -> Result<(), io::Error> {
217 debug!("{} host to client: {:?}", bind_address, message);
218
219 match message {
220 TcpWarpMessage::AddPorts(_) => {
221 for address in addresses.iter().cloned() {
222 let bind_address =
223 SocketAddr::new(bind_address, address.client_port.unwrap_or(address.port));
224 let sender_ = sender.clone();
225
226 let mut listener = match TcpListener::bind(bind_address).await {
227 Ok(listener) => listener,
228 Err(err) => {
229 error!("could not start listen {}: {}", bind_address, err);
230 return Err(err);
231 }
232 };
233
234 debug!("listen: {:?}", bind_address);
235
236 let abortable_feature = async move {
237 let mut incoming = listener.incoming();
238
239 while let Some(Ok(stream)) = incoming.next().await {
240 let sender__ = sender_.clone();
241
242 let _address = address.clone();
243 spawn(async move {
244 if let Err(e) = process(stream, sender__, _address).await {
245 error!("failed to process connection; error = {}", e);
246 }
247 });
248 }
249
250 debug!("done listen: {:?}", bind_address);
251
252 Ok::<(), io::Error>(())
253 };
254 let (abortable_listener, abort_handler) = abortable(abortable_feature);
255 if let Err(err) = sender.send(TcpWarpMessage::Listener(abort_handler)).await {
256 error!("cannot send message Listener to forward channel: {}", err);
257 }
258 spawn(abortable_listener);
259 }
260 }
261 TcpWarpMessage::BytesHost { .. } => {
262 if let Err(err) = sender.send(message).await {
263 error!("cannot send message BytesHost to forward channel: {}", err);
264 }
265 }
266 TcpWarpMessage::Connected { .. } => {
267 if let Err(err) = sender.send(message).await {
268 error!("cannot send message Connected to forward channel: {}", err);
269 }
270 }
271 TcpWarpMessage::DisconnectHost { .. } => {
272 if let Err(err) = sender.send(message).await {
273 error!(
274 "cannot send message DisconnectHost to forward channel: {}",
275 err
276 );
277 }
278 }
279 TcpWarpMessage::ConnectFailure { .. } => {
280 if let Err(err) = sender.send(message).await {
281 error!(
282 "cannot send message ConnectFailure to forward channel: {}",
283 err
284 );
285 }
286 }
287 other_message => warn!("unsupported message: {:?}", other_message),
288 }
289 Ok(())
290}
291
292async fn process(
293 stream: TcpStream,
294 mut host_sender: Sender<TcpWarpMessage>,
295 address: TcpWarpPortConnection,
296) -> Result<(), Box<dyn Error>> {
297 let connection_id = Uuid::new_v4();
298
299 debug!("new connection: {}", connection_id);
300
301 let (mut wtransport, mut rtransport) =
302 Framed::new(stream, TcpWarpProtoClient { connection_id }).split();
303
304 let (client_sender, mut client_receiver) = channel(100);
305
306 let forward_task = async move {
307 debug!("in receiver task");
308 while let Some(message) = client_receiver.next().await {
309 debug!(
310 "{} just received a message process: {:?}",
311 connection_id, message
312 );
313 match message {
314 TcpWarpMessage::ConnectFailure { .. } => break,
315 TcpWarpMessage::DisconnectHost { .. } => break,
316 TcpWarpMessage::BytesServer { data } => wtransport.send(data).await?,
317 _ => (),
318 }
319 }
320
321 debug!("{} no more messages, closing forward task", connection_id);
322 debug!(
323 "{} closing write channel to client side port",
324 connection_id
325 );
326 wtransport.close().await?;
327 client_receiver.close();
328 debug!("{} write channel to client side port closed", connection_id);
329
330 Ok::<(), io::Error>(())
331 };
332
333 let (connected_sender, connected_receiver) = oneshot::channel();
334
335 host_sender
336 .send(TcpWarpMessage::Connect {
337 connection_id,
338 connection: address,
339 sender: client_sender,
340 connected_sender,
341 })
342 .await?;
343
344 let processing_task = async move {
345 match connected_receiver.await {
346 Err(err) => {
347 error!("{} connection error: {}", connection_id, err);
348 return Ok(());
349 }
350 Ok(Err(err)) => {
351 error!("{} connection error: {}", connection_id, err);
352 return Ok(());
353 }
354 _ => (),
355 }
356
357 while let Some(Ok(message)) = rtransport.next().await {
358 if let Err(err) = host_sender.send(message).await {
359 error!("{} {}", connection_id, err);
360 }
361 }
362
363 debug!(
364 "{} processing task for incoming connection finished",
365 connection_id
366 );
367
368 let message = TcpWarpMessage::DisconnectClient { connection_id };
369 debug!("{} sending disconnect message {:?}", connection_id, message);
370 if let Err(err) = host_sender.send(message).await {
371 error!("{} {}", connection_id, err);
372 }
373 debug!("{} done processing", connection_id);
374
375 Ok::<(), io::Error>(())
376 };
377
378 try_join!(forward_task, processing_task)?;
379
380 debug!("{} full complete process", connection_id);
381
382 Ok(())
383}