rathole/
server.rs

1use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
2use crate::config_watcher::{ConfigChange, ServerServiceChange};
3use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
4use crate::helper::retry_notify_with_deadline;
5use crate::multi_map::MultiMap;
6use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
7use crate::protocol::{
8    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
9    HASH_WIDTH_IN_BYTES,
10};
11use crate::transport::{SocketOpts, TcpTransport, Transport};
12use anyhow::{anyhow, bail, Context, Result};
13use backoff::backoff::Backoff;
14use backoff::ExponentialBackoff;
15
16use rand::RngCore;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpListener, TcpStream, UdpSocket};
22use tokio::sync::{broadcast, mpsc, RwLock};
23use tokio::time;
24use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
25
26#[cfg(feature = "noise")]
27use crate::transport::NoiseTransport;
28#[cfg(feature = "tls")]
29use crate::transport::TlsTransport;
30#[cfg(feature = "websocket")]
31use crate::transport::WebsocketTransport;
32
33type ServiceDigest = protocol::Digest; // SHA256 of a service name
34type Nonce = protocol::Digest; // Also called `session_key`
35
/// Number of data channels requested up front for a TCP service (cached connections).
const TCP_POOL_SIZE: usize = 8;
/// Number of data channels requested up front for a UDP service (cached connections).
const UDP_POOL_SIZE: usize = 2;
/// The capacity of various chans.
const CHAN_SIZE: usize = 2048;
/// Timeout for the transport handshake, in seconds.
const HANDSHAKE_TIMEOUT: u64 = 5;
40
41// The entrypoint of running a server
42pub async fn run_server(
43    config: Config,
44    shutdown_rx: broadcast::Receiver<bool>,
45    update_rx: mpsc::Receiver<ConfigChange>,
46) -> Result<()> {
47    let config = match config.server {
48            Some(config) => config,
49            None => {
50                return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
51            }
52        };
53
54    match config.transport.transport_type {
55        TransportType::Tcp => {
56            let mut server = Server::<TcpTransport>::from(config).await?;
57            server.run(shutdown_rx, update_rx).await?;
58        }
59        TransportType::Tls => {
60            #[cfg(feature = "tls")]
61            {
62                let mut server = Server::<TlsTransport>::from(config).await?;
63                server.run(shutdown_rx, update_rx).await?;
64            }
65            #[cfg(not(feature = "tls"))]
66            crate::helper::feature_not_compile("tls")
67        }
68        TransportType::Noise => {
69            #[cfg(feature = "noise")]
70            {
71                let mut server = Server::<NoiseTransport>::from(config).await?;
72                server.run(shutdown_rx, update_rx).await?;
73            }
74            #[cfg(not(feature = "noise"))]
75            crate::helper::feature_not_compile("noise")
76        }
77        TransportType::Websocket => {
78            #[cfg(feature = "websocket")]
79            {
80                let mut server = Server::<WebsocketTransport>::from(config).await?;
81                server.run(shutdown_rx, update_rx).await?;
82            }
83            #[cfg(not(feature = "websocket"))]
84            crate::helper::feature_not_compile("websocket")
85        }
86    }
87
88    Ok(())
89}
90
91// A hash map of ControlChannelHandles, indexed by ServiceDigest or Nonce
92// See also MultiMap
93type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;
94
95// Server holds all states of running a server
96struct Server<T: Transport> {
97    // `[server]` config
98    config: Arc<ServerConfig>,
99
100    // `[server.services]` config, indexed by ServiceDigest
101    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
102    // Collection of contorl channels
103    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
104    // Wrapper around the transport layer
105    transport: Arc<T>,
106}
107
108// Generate a hash map of services which is indexed by ServiceDigest
109fn generate_service_hashmap(
110    server_config: &ServerConfig,
111) -> HashMap<ServiceDigest, ServerServiceConfig> {
112    let mut ret = HashMap::new();
113    for u in &server_config.services {
114        ret.insert(protocol::digest(u.0.as_bytes()), (*u.1).clone());
115    }
116    ret
117}
118
119impl<T: 'static + Transport> Server<T> {
120    // Create a server from `[server]`
121    pub async fn from(config: ServerConfig) -> Result<Server<T>> {
122        let config = Arc::new(config);
123        let services = Arc::new(RwLock::new(generate_service_hashmap(&config)));
124        let control_channels = Arc::new(RwLock::new(ControlChannelMap::new()));
125        let transport = Arc::new(T::new(&config.transport)?);
126        Ok(Server {
127            config,
128            services,
129            control_channels,
130            transport,
131        })
132    }
133
134    // The entry point of Server
135    pub async fn run(
136        &mut self,
137        mut shutdown_rx: broadcast::Receiver<bool>,
138        mut update_rx: mpsc::Receiver<ConfigChange>,
139    ) -> Result<()> {
140        // Listen at `server.bind_addr`
141        let l = self
142            .transport
143            .bind(&self.config.bind_addr)
144            .await
145            .with_context(|| "Failed to listen at `server.bind_addr`")?;
146        info!("Listening at {}", self.config.bind_addr);
147
148        // Retry at least every 100ms
149        let mut backoff = ExponentialBackoff {
150            max_interval: Duration::from_millis(100),
151            max_elapsed_time: None,
152            ..Default::default()
153        };
154
155        // Wait for connections and shutdown signals
156        loop {
157            tokio::select! {
158                // Wait for incoming control and data channels
159                ret = self.transport.accept(&l) => {
160                    match ret {
161                        Err(err) => {
162                            // Detects whether it's an IO error
163                            if let Some(err) = err.downcast_ref::<io::Error>() {
164                                // If it is an IO error, then it's possibly an
165                                // EMFILE. So sleep for a while and retry
166                                // TODO: Only sleep for EMFILE, ENFILE, ENOMEM, ENOBUFS
167                                if let Some(d) = backoff.next_backoff() {
168                                    error!("Failed to accept: {:#}. Retry in {:?}...", err, d);
169                                    time::sleep(d).await;
170                                } else {
171                                    // This branch will never be executed according to the current retry policy
172                                    error!("Too many retries. Aborting...");
173                                    break;
174                                }
175                            }
176                            // If it's not an IO error, then it comes from
177                            // the transport layer, so just ignore it
178                        }
179                        Ok((conn, addr)) => {
180                            backoff.reset();
181
182                            // Do transport handshake with a timeout
183                            match time::timeout(Duration::from_secs(HANDSHAKE_TIMEOUT), self.transport.handshake(conn)).await {
184                                Ok(conn) => {
185                                    match conn.with_context(|| "Failed to do transport handshake") {
186                                        Ok(conn) => {
187                                            let services = self.services.clone();
188                                            let control_channels = self.control_channels.clone();
189                                            let server_config = self.config.clone();
190                                            tokio::spawn(async move {
191                                                if let Err(err) = handle_connection(conn, services, control_channels, server_config).await {
192                                                    error!("{:#}", err);
193                                                }
194                                            }.instrument(info_span!("connection", %addr)));
195                                        }, Err(e) => {
196                                            error!("{:#}", e);
197                                        }
198                                    }
199                                },
200                                Err(e) => {
201                                    error!("Transport handshake timeout: {}", e);
202                                }
203                            }
204                        }
205                    }
206                },
207                // Wait for the shutdown signal
208                _ = shutdown_rx.recv() => {
209                    info!("Shuting down gracefully...");
210                    break;
211                },
212                e = update_rx.recv() => {
213                    if let Some(e) = e {
214                        self.handle_hot_reload(e).await;
215                    }
216                }
217            }
218        }
219
220        info!("Shutdown");
221
222        Ok(())
223    }
224
225    async fn handle_hot_reload(&mut self, e: ConfigChange) {
226        match e {
227            ConfigChange::ServerChange(server_change) => match server_change {
228                ServerServiceChange::Add(cfg) => {
229                    let hash = protocol::digest(cfg.name.as_bytes());
230                    let mut wg = self.services.write().await;
231                    let _ = wg.insert(hash, cfg);
232
233                    let mut wg = self.control_channels.write().await;
234                    let _ = wg.remove1(&hash);
235                }
236                ServerServiceChange::Delete(s) => {
237                    let hash = protocol::digest(s.as_bytes());
238                    let _ = self.services.write().await.remove(&hash);
239
240                    let mut wg = self.control_channels.write().await;
241                    let _ = wg.remove1(&hash);
242                }
243            },
244            ignored => warn!("Ignored {:?} since running as a server", ignored),
245        }
246    }
247}
248
249// Handle connections to `server.bind_addr`
250async fn handle_connection<T: 'static + Transport>(
251    mut conn: T::Stream,
252    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
253    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
254    server_config: Arc<ServerConfig>,
255) -> Result<()> {
256    // Read hello
257    let hello = read_hello(&mut conn).await?;
258    match hello {
259        ControlChannelHello(_, service_digest) => {
260            do_control_channel_handshake(
261                conn,
262                services,
263                control_channels,
264                service_digest,
265                server_config,
266            )
267            .await?;
268        }
269        DataChannelHello(_, nonce) => {
270            do_data_channel_handshake(conn, control_channels, nonce).await?;
271        }
272    }
273    Ok(())
274}
275
276async fn do_control_channel_handshake<T: 'static + Transport>(
277    mut conn: T::Stream,
278    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
279    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
280    service_digest: ServiceDigest,
281    server_config: Arc<ServerConfig>,
282) -> Result<()> {
283    info!("Try to handshake a control channel");
284
285    T::hint(&conn, SocketOpts::for_control_channel());
286
287    // Generate a nonce
288    let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
289    rand::thread_rng().fill_bytes(&mut nonce);
290
291    // Send hello
292    let hello_send = Hello::ControlChannelHello(
293        protocol::CURRENT_PROTO_VERSION,
294        nonce.clone().try_into().unwrap(),
295    );
296    conn.write_all(&bincode::serialize(&hello_send).unwrap())
297        .await?;
298    conn.flush().await?;
299
300    // Lookup the service
301    let service_config = match services.read().await.get(&service_digest) {
302        Some(v) => v,
303        None => {
304            conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
305                .await?;
306            bail!("No such a service {}", hex::encode(service_digest));
307        }
308    }
309    .to_owned();
310
311    let service_name = &service_config.name;
312
313    // Calculate the checksum
314    let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
315    concat.append(&mut nonce);
316
317    // Read auth
318    let protocol::Auth(d) = read_auth(&mut conn).await?;
319
320    // Validate
321    let session_key = protocol::digest(&concat);
322    if session_key != d {
323        conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
324            .await?;
325        debug!(
326            "Expect {}, but got {}",
327            hex::encode(session_key),
328            hex::encode(d)
329        );
330        bail!("Service {} failed the authentication", service_name);
331    } else {
332        let mut h = control_channels.write().await;
333
334        // If there's already a control channel for the service, then drop the old one.
335        // Because a control channel doesn't report back when it's dead,
336        // the handle in the map could be stall, dropping the old handle enables
337        // the client to reconnect.
338        if h.remove1(&service_digest).is_some() {
339            warn!(
340                "Dropping previous control channel for service {}",
341                service_name
342            );
343        }
344
345        // Send ack
346        conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
347            .await?;
348        conn.flush().await?;
349
350        info!(service = %service_config.name, "Control channel established");
351        let handle =
352            ControlChannelHandle::new(conn, service_config, server_config.heartbeat_interval);
353
354        // Insert the new handle
355        let _ = h.insert(service_digest, session_key, handle);
356    }
357
358    Ok(())
359}
360
361async fn do_data_channel_handshake<T: 'static + Transport>(
362    conn: T::Stream,
363    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
364    nonce: Nonce,
365) -> Result<()> {
366    debug!("Try to handshake a data channel");
367
368    // Validate
369    let control_channels_guard = control_channels.read().await;
370    match control_channels_guard.get2(&nonce) {
371        Some(handle) => {
372            T::hint(&conn, SocketOpts::from_server_cfg(&handle.service));
373
374            // Send the data channel to the corresponding control channel
375            handle
376                .data_ch_tx
377                .send(conn)
378                .await
379                .with_context(|| "Data channel for a stale control channel")?;
380        }
381        None => {
382            warn!("Data channel has incorrect nonce");
383        }
384    }
385    Ok(())
386}
387
388pub struct ControlChannelHandle<T: Transport> {
389    // Shutdown the control channel by dropping it
390    _shutdown_tx: broadcast::Sender<bool>,
391    data_ch_tx: mpsc::Sender<T::Stream>,
392    service: ServerServiceConfig,
393}
394
395impl<T> ControlChannelHandle<T>
396where
397    T: 'static + Transport,
398{
399    // Create a control channel handle, where the control channel handling task
400    // and the connection pool task are created.
401    #[instrument(name = "handle", skip_all, fields(service = %service.name))]
402    fn new(
403        conn: T::Stream,
404        service: ServerServiceConfig,
405        heartbeat_interval: u64,
406    ) -> ControlChannelHandle<T> {
407        // Create a shutdown channel
408        let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
409
410        // Store data channels
411        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
412
413        // Store data channel creation requests
414        let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
415
416        // Cache some data channels for later use
417        let pool_size = match service.service_type {
418            ServiceType::Tcp => TCP_POOL_SIZE,
419            ServiceType::Udp => UDP_POOL_SIZE,
420        };
421
422        for _i in 0..pool_size {
423            if let Err(e) = data_ch_req_tx.send(true) {
424                error!("Failed to request data channel {}", e);
425            };
426        }
427
428        let shutdown_rx_clone = shutdown_tx.subscribe();
429        let bind_addr = service.bind_addr.clone();
430        match service.service_type {
431            ServiceType::Tcp => tokio::spawn(
432                async move {
433                    if let Err(e) = run_tcp_connection_pool::<T>(
434                        bind_addr,
435                        data_ch_rx,
436                        data_ch_req_tx,
437                        shutdown_rx_clone,
438                    )
439                    .await
440                    .with_context(|| "Failed to run TCP connection pool")
441                    {
442                        error!("{:#}", e);
443                    }
444                }
445                .instrument(Span::current()),
446            ),
447            ServiceType::Udp => tokio::spawn(
448                async move {
449                    if let Err(e) = run_udp_connection_pool::<T>(
450                        bind_addr,
451                        data_ch_rx,
452                        data_ch_req_tx,
453                        shutdown_rx_clone,
454                    )
455                    .await
456                    .with_context(|| "Failed to run TCP connection pool")
457                    {
458                        error!("{:#}", e);
459                    }
460                }
461                .instrument(Span::current()),
462            ),
463        };
464
465        // Create the control channel
466        let ch = ControlChannel::<T> {
467            conn,
468            shutdown_rx,
469            data_ch_req_rx,
470            heartbeat_interval,
471        };
472
473        // Run the control channel
474        tokio::spawn(
475            async move {
476                if let Err(err) = ch.run().await {
477                    error!("{:#}", err);
478                }
479            }
480            .instrument(Span::current()),
481        );
482
483        ControlChannelHandle {
484            _shutdown_tx: shutdown_tx,
485            data_ch_tx,
486            service,
487        }
488    }
489}
490
491// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
492struct ControlChannel<T: Transport> {
493    conn: T::Stream,                               // The connection of control channel
494    shutdown_rx: broadcast::Receiver<bool>,        // Receives the shutdown signal
495    data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
496    heartbeat_interval: u64,                       // Application-layer heartbeat interval in secs
497}
498
499impl<T: Transport> ControlChannel<T> {
500    async fn write_and_flush(&mut self, data: &[u8]) -> Result<()> {
501        self.conn
502            .write_all(data)
503            .await
504            .with_context(|| "Failed to write control cmds")?;
505        self.conn
506            .flush()
507            .await
508            .with_context(|| "Failed to flush control cmds")?;
509        Ok(())
510    }
511    // Run a control channel
512    #[instrument(skip_all)]
513    async fn run(mut self) -> Result<()> {
514        let create_ch_cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
515        let heartbeat = bincode::serialize(&ControlChannelCmd::HeartBeat).unwrap();
516
517        // Wait for data channel requests and the shutdown signal
518        loop {
519            tokio::select! {
520                val = self.data_ch_req_rx.recv() => {
521                    match val {
522                        Some(_) => {
523                            if let Err(e) = self.write_and_flush(&create_ch_cmd).await {
524                                error!("{:#}", e);
525                                break;
526                            }
527                        }
528                        None => {
529                            break;
530                        }
531                    }
532                },
533                _ = time::sleep(Duration::from_secs(self.heartbeat_interval)), if self.heartbeat_interval != 0 => {
534                            if let Err(e) = self.write_and_flush(&heartbeat).await {
535                                error!("{:#}", e);
536                                break;
537                            }
538                }
539                // Wait for the shutdown signal
540                _ = self.shutdown_rx.recv() => {
541                    break;
542                }
543            }
544        }
545
546        info!("Control channel shutdown");
547
548        Ok(())
549    }
550}
551
552fn tcp_listen_and_send(
553    addr: String,
554    data_ch_req_tx: mpsc::UnboundedSender<bool>,
555    mut shutdown_rx: broadcast::Receiver<bool>,
556) -> mpsc::Receiver<TcpStream> {
557    let (tx, rx) = mpsc::channel(CHAN_SIZE);
558
559    tokio::spawn(async move {
560        let l = retry_notify_with_deadline(listen_backoff(),  || async {
561            Ok(TcpListener::bind(&addr).await?)
562        }, |e, duration| {
563            error!("{:#}. Retry in {:?}", e, duration);
564        }, &mut shutdown_rx).await
565        .with_context(|| "Failed to listen for the service");
566
567        let l: TcpListener = match l {
568            Ok(v) => v,
569            Err(e) => {
570                error!("{:#}", e);
571                return;
572            }
573        };
574
575        info!("Listening at {}", &addr);
576
577        // Retry at least every 1s
578        let mut backoff = ExponentialBackoff {
579            max_interval: Duration::from_secs(1),
580            max_elapsed_time: None,
581            ..Default::default()
582        };
583
584        // Wait for visitors and the shutdown signal
585        loop {
586            tokio::select! {
587                val = l.accept() => {
588                    match val {
589                        Err(e) => {
590                            // `l` is a TCP listener so this must be a IO error
591                            // Possibly a EMFILE. So sleep for a while
592                            error!("{}. Sleep for a while", e);
593                            if let Some(d) = backoff.next_backoff() {
594                                time::sleep(d).await;
595                            } else {
596                                // This branch will never be reached for current backoff policy
597                                error!("Too many retries. Aborting...");
598                                break;
599                            }
600                        }
601                        Ok((incoming, addr)) => {
602                            // For every visitor, request to create a data channel
603                            if data_ch_req_tx.send(true).with_context(|| "Failed to send data chan create request").is_err() {
604                                // An error indicates the control channel is broken
605                                // So break the loop
606                                break;
607                            }
608
609                            backoff.reset();
610
611                            debug!("New visitor from {}", addr);
612
613                            // Send the visitor to the connection pool
614                            let _ = tx.send(incoming).await;
615                        }
616                    }
617                },
618                _ = shutdown_rx.recv() => {
619                    break;
620                }
621            }
622        }
623
624        info!("TCPListener shutdown");
625    }.instrument(Span::current()));
626
627    rx
628}
629
630#[instrument(skip_all)]
631async fn run_tcp_connection_pool<T: Transport>(
632    bind_addr: String,
633    mut data_ch_rx: mpsc::Receiver<T::Stream>,
634    data_ch_req_tx: mpsc::UnboundedSender<bool>,
635    shutdown_rx: broadcast::Receiver<bool>,
636) -> Result<()> {
637    let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx.clone(), shutdown_rx);
638    let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
639
640    'pool: while let Some(mut visitor) = visitor_rx.recv().await {
641        loop {
642            if let Some(mut ch) = data_ch_rx.recv().await {
643                if ch.write_all(&cmd).await.is_ok() {
644                    tokio::spawn(async move {
645                        let _ = copy_bidirectional(&mut ch, &mut visitor).await;
646                    });
647                    break;
648                } else {
649                    // Current data channel is broken. Request for a new one
650                    if data_ch_req_tx.send(true).is_err() {
651                        break 'pool;
652                    }
653                }
654            } else {
655                break 'pool;
656            }
657        }
658    }
659
660    info!("Shutdown");
661    Ok(())
662}
663
664#[instrument(skip_all)]
665async fn run_udp_connection_pool<T: Transport>(
666    bind_addr: String,
667    mut data_ch_rx: mpsc::Receiver<T::Stream>,
668    _data_ch_req_tx: mpsc::UnboundedSender<bool>,
669    mut shutdown_rx: broadcast::Receiver<bool>,
670) -> Result<()> {
671    // TODO: Load balance
672
673    let l = retry_notify_with_deadline(
674        listen_backoff(),
675        || async { Ok(UdpSocket::bind(&bind_addr).await?) },
676        |e, duration| {
677            warn!("{:#}. Retry in {:?}", e, duration);
678        },
679        &mut shutdown_rx,
680    )
681    .await
682    .with_context(|| "Failed to listen for the service")?;
683
684    info!("Listening at {}", &bind_addr);
685
686    let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
687
688    // Receive one data channel
689    let mut conn = data_ch_rx
690        .recv()
691        .await
692        .ok_or_else(|| anyhow!("No available data channels"))?;
693    conn.write_all(&cmd).await?;
694
695    let mut buf = [0u8; UDP_BUFFER_SIZE];
696    loop {
697        tokio::select! {
698            // Forward inbound traffic to the client
699            val = l.recv_from(&mut buf) => {
700                let (n, from) = val?;
701                UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
702            },
703
704            // Forward outbound traffic from the client to the visitor
705            hdr_len = conn.read_u8() => {
706                let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
707                l.send_to(&t.data, t.from).await?;
708            }
709
710            _ = shutdown_rx.recv() => {
711                break;
712            }
713        }
714    }
715
716    debug!("UDP pool dropped");
717
718    Ok(())
719}